-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnova_client.py
More file actions
126 lines (102 loc) · 3.79 KB
/
nova_client.py
File metadata and controls
126 lines (102 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
NovaClient - AWS Bedrock Runtime wrapper for Amazon Nova model.
Uses the Converse API with extended thinking enabled.
"""
import os
from pathlib import Path
import boto3
from botocore.exceptions import ClientError, NoCredentialsError
# Load a .env file that sits next to this module when python-dotenv is
# installed, so the Bedrock API key can be kept in a file instead of the shell
# environment. python-dotenv is optional: without it only the real environment
# is used.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    # Only the optional import is guarded; an ImportError raised while actually
    # loading the file is a real failure and must not be silently swallowed.
    load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env")
class NovaClient:
"""Client for invoking Amazon Nova model via AWS Bedrock Converse API."""
MODEL_ID = "global.amazon.nova-2-lite-v1:0"
REGION = "us-east-1"
def __init__(self, api_key: str | None = None):
"""
Initialize the Bedrock Runtime client in us-east-1.
Args:
api_key: Optional Bedrock API key. If provided, sets
AWS_BEARER_TOKEN_BEDROCK. Otherwise uses that env var.
"""
if api_key is not None:
os.environ["AWS_BEARER_TOKEN_BEDROCK"] = api_key
self._client = boto3.client(
"bedrock-runtime",
region_name=self.REGION,
)
def invoke(self, prompt: str) -> str:
"""
Invoke the Nova model with the given prompt using the Converse API.
Args:
prompt: The user prompt to send to the model.
Returns:
The text content from the model's response.
Raises:
ClientError: On API errors (except ThrottlingException, which is retried).
"""
messages = [
{
"role": "user",
"content": [{"text": prompt}],
}
]
request_kwargs = {
"modelId": self.MODEL_ID,
"messages": messages,
"additionalModelRequestFields": {
"reasoningConfig": {
"type": "enabled",
"maxReasoningEffort": "low",
}
},
}
try:
response = self._client.converse(**request_kwargs)
return self._parse_response(response)
except NoCredentialsError as e:
raise NoCredentialsError(
"Unable to locate AWS credentials. If you are using a Bedrock API key, ensure "
"AWS_BEARER_TOKEN_BEDROCK is set and your boto3/botocore version supports Bedrock API key auth. "
"Alternatively configure IAM credentials (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY or AWS profile)."
) from e
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code == "ThrottlingException":
raise ThrottlingException(
"Rate limit exceeded. Please retry after a short delay."
) from e
raise
def _parse_response(self, response: dict) -> str:
"""
Extract text content from the Converse API response.
Args:
response: The raw response from bedrock-runtime.converse().
Returns:
Concatenated text content from the output message.
"""
output = response.get("output", {})
message = output.get("message", {})
content = message.get("content", [])
text_parts = []
for block in content:
if block.get("text"):
text_parts.append(block["text"])
return "".join(text_parts)
class ThrottlingException(Exception):
    """Bedrock rejected a request because of rate limiting.

    ``NovaClient.invoke`` raises this in place of a botocore ``ClientError``
    whose error code is ``ThrottlingException``. It chains the original error,
    so callers can handle rate limiting separately from other API failures.
    """
if __name__ == "__main__":
print("Testing NovaClient...")
try:
client = NovaClient()
response = client.invoke("Say hello in one sentence.")
print("Response:", response)
except ThrottlingException as e:
print("Throttled:", e)
except Exception as e:
print("Error:", e)