Skip to content

Commit a132fb9

Browse files
committed
feat: enforce mypy type checking across all source files
1 parent 11e2ac5 commit a132fb9

22 files changed

Lines changed: 199 additions & 185 deletions

File tree

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ jobs:
4141
- name: Run pre-commit
4242
run: uv run pre-commit run --all-files
4343

44+
- name: Run mypy
45+
run: uv run mypy src/
46+
4447
test:
4548
name: Test Python ${{ matrix.python-version }}
4649
runs-on: ubuntu-latest

src/bedrock_agentcore/evaluation/integrations/strands_agents_evals/evaluator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _is_valid_adot_document(item: Any) -> bool:
3131
return isinstance(item, dict) and "scope" in item and "traceId" in item and "spanId" in item
3232

3333

34-
def _validate_spans(spans):
34+
def _validate_spans(spans: Any) -> bool:
3535
"""Validate spans are OpenTelemetry Span objects."""
3636
if not spans:
3737
return False
@@ -127,14 +127,14 @@ def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> List[Eva
127127
]
128128

129129
# Check if spans are already in ADOT format or need conversion
130-
if _is_adot_format(evaluation_case.actual_trajectory):
130+
if _is_adot_format(evaluation_case.actual_trajectory): # type: ignore[arg-type]
131131
# Already in ADOT format (fetched from CloudWatch), use as-is
132132
spans = evaluation_case.actual_trajectory
133133
else:
134134
# Raw OTel spans from in-memory exporter, validate and convert
135135
if not _validate_spans(evaluation_case.actual_trajectory):
136136
return [EvaluationOutput(score=0.0, test_pass=False, reason="Invalid span objects")]
137-
spans = convert_strands_to_adot(evaluation_case.actual_trajectory)
137+
spans = convert_strands_to_adot(evaluation_case.actual_trajectory) # type: ignore[arg-type]
138138

139139
request_payload = {"evaluatorId": self.evaluator_id, "evaluationInput": {"sessionSpans": spans}}
140140

@@ -165,7 +165,7 @@ async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT])
165165
return await asyncio.to_thread(self.evaluate, evaluation_case)
166166

167167

168-
def create_strands_evaluator(evaluator_id: str, **kwargs) -> StrandsEvalsAgentCoreEvaluator:
168+
def create_strands_evaluator(evaluator_id: str, **kwargs: Any) -> StrandsEvalsAgentCoreEvaluator:
169169
"""Create Strands-compatible evaluator backed by AgentCore Evaluation API.
170170
171171
Args:

src/bedrock_agentcore/evaluation/span_to_adot_serializer/adot_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class SpanParser:
7777
"""
7878

7979
@staticmethod
80-
def extract_metadata(span) -> SpanMetadata:
80+
def extract_metadata(span: Any) -> SpanMetadata:
8181
"""Extract core span metadata."""
8282
if not hasattr(span, "context") or not span.context:
8383
raise ValueError(f"Span '{getattr(span, 'name', 'unknown')}' missing required context")
@@ -96,7 +96,7 @@ def extract_metadata(span) -> SpanMetadata:
9696
)
9797

9898
@staticmethod
99-
def extract_resource_info(span) -> ResourceInfo:
99+
def extract_resource_info(span: Any) -> ResourceInfo:
100100
"""Extract resource and scope information."""
101101
resource_attrs = {}
102102
if hasattr(span, "resource") and span.resource and hasattr(span.resource, "attributes"):
@@ -115,7 +115,7 @@ def extract_resource_info(span) -> ResourceInfo:
115115
)
116116

117117
@staticmethod
118-
def get_span_attributes(span) -> Dict[str, Any]:
118+
def get_span_attributes(span: Any) -> Dict[str, Any]:
119119
"""Safely extract span attributes."""
120120
return dict(span.attributes) if hasattr(span, "attributes") and span.attributes else {}
121121

src/bedrock_agentcore/evaluation/span_to_adot_serializer/strands_converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,13 @@ def extract_tool_execution(cls, events: List[Any]) -> Optional[ToolExecution]:
120120
class StrandsToADOTConverter:
121121
"""Convert Strands OTel spans to ADOT format."""
122122

123-
def __init__(self):
123+
def __init__(self) -> None:
124124
"""Initialize converter with parsers and builder."""
125125
self.span_parser = SpanParser()
126126
self.event_parser = StrandsEventParser()
127127
self.doc_builder = ADOTDocumentBuilder()
128128

129-
def convert_span(self, span) -> List[Dict[str, Any]]:
129+
def convert_span(self, span: Any) -> List[Dict[str, Any]]:
130130
"""Convert a single span to ADOT documents."""
131131
documents = []
132132

src/bedrock_agentcore/identity/auth.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def _get_iam_jwt_token(region: str) -> str:
187187
try:
188188
response = sts_client.get_web_identity_token(**params)
189189
logger.info("Successfully obtained AWS IAM JWT token")
190-
return response["WebIdentityToken"]
190+
return response["WebIdentityToken"] # type: ignore[no-any-return]
191191
except ClientError as e:
192192
error_code = e.response.get("Error", {}).get("Code", "")
193193
if error_code in ["FeatureDisabledException", "FeatureDisabled"]:
@@ -231,7 +231,7 @@ def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable:
231231
def decorator(func: Callable) -> Callable:
232232
client = IdentityClient(_get_region())
233233

234-
async def _get_api_key():
234+
async def _get_api_key() -> str:
235235
return await client.get_api_key(
236236
provider_name=provider_name,
237237
agent_identity_token=await _get_workload_access_token(client),
@@ -268,7 +268,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
268268
return decorator
269269

270270

271-
def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]):
271+
def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]) -> Optional[str]:
272272
if user_provided_oauth2_callback_url:
273273
return user_provided_oauth2_callback_url
274274

@@ -298,7 +298,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str:
298298

299299
config_path = Path(".agentcore.json")
300300
workload_identity_name = None
301-
config = {}
301+
config: dict[str, str] = {}
302302
if config_path.exists():
303303
try:
304304
with open(config_path, "r", encoding="utf-8") as file:
@@ -327,7 +327,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str:
327327
except Exception:
328328
print("Warning: could not write the created workload identity to file")
329329

330-
return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"]
330+
return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"] # type: ignore[no-any-return]
331331

332332

333333
def _get_region() -> str:

src/bedrock_agentcore/memory/client.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898
self.gmdp_client.meta.region_name,
9999
)
100100

101-
def __getattr__(self, name: str):
101+
def __getattr__(self, name: str) -> Any:
102102
"""Dynamically forward method calls to the appropriate boto3 client.
103103
104104
This method enables access to all boto3 client methods without explicitly
@@ -203,7 +203,7 @@ def create_or_get_memory(
203203
try:
204204
memory = self.create_memory_and_wait(
205205
name=name,
206-
strategies=strategies,
206+
strategies=strategies, # type: ignore[arg-type]
207207
description=description,
208208
event_expiry_days=event_expiry_days,
209209
memory_execution_role_arn=memory_execution_role_arn,
@@ -213,7 +213,7 @@ def create_or_get_memory(
213213
except ClientError as e:
214214
if e.response["Error"]["Code"] == "ValidationException" and "already exists" in str(e):
215215
memories = self.list_memories()
216-
memory = next((m for m in memories if m["id"].startswith(name)), None)
216+
memory = next((m for m in memories if m["id"].startswith(name)), None) # type: ignore[arg-type]
217217
logger.info("Memory already exists. Using existing memory ID: %s", memory["id"])
218218
return memory
219219
else:
@@ -338,7 +338,7 @@ def retrieve_memories(
338338
memoryId=memory_id, namespace=namespace, searchCriteria={"searchQuery": query, "topK": top_k}
339339
)
340340

341-
memories = response.get("memoryRecordSummaries", [])
341+
memories: list[Dict[str, Any]] = response.get("memoryRecordSummaries", [])
342342
logger.info("Retrieved %d memories from namespace: %s", len(memories), namespace)
343343
return memories
344344

@@ -473,7 +473,7 @@ def create_event(
473473

474474
response = self.gmdp_client.create_event(**params)
475475

476-
event = response["event"]
476+
event: Dict[str, Any] = response["event"]
477477
logger.info("Created event: %s", event["eventId"])
478478

479479
return event
@@ -539,7 +539,7 @@ def create_blob_event(
539539

540540
response = self.gmdp_client.create_event(**params)
541541

542-
event = response["event"]
542+
event: Dict[str, Any] = response["event"]
543543
logger.info("Created blob event: %s", event["eventId"])
544544

545545
return event
@@ -635,7 +635,7 @@ def save_conversation(
635635

636636
response = self.gmdp_client.create_event(**params)
637637

638-
event = response["event"]
638+
event: Dict[str, Any] = response["event"]
639639
logger.info("Created event: %s", event["eventId"])
640640

641641
return event
@@ -777,7 +777,7 @@ def list_events(
777777
)
778778
"""
779779
try:
780-
all_events = []
780+
all_events: List[Dict[str, Any]] = []
781781
next_token = None
782782

783783
while len(all_events) < max_results:
@@ -793,7 +793,7 @@ def list_events(
793793
params["nextToken"] = next_token
794794

795795
# Build filter map
796-
filter_map = {}
796+
filter_map: Dict[str, Any] = {}
797797

798798
# Add branch filter if specified (but not for "main")
799799
if branch_name and branch_name != "main":
@@ -937,7 +937,7 @@ def list_branch_events(
937937
params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_parent_branches}}
938938

939939
response = self.gmdp_client.list_events(**params)
940-
events = response.get("events", [])
940+
events: list[Dict[str, Any]] = response.get("events", [])
941941

942942
# Handle pagination
943943
next_token = response.get("nextToken")
@@ -992,7 +992,11 @@ def get_conversation_tree(self, memory_id: str, actor_id: str, session_id: str)
992992
break
993993

994994
# Build tree structure
995-
tree = {"session_id": session_id, "actor_id": actor_id, "main_branch": {"events": [], "branches": {}}}
995+
tree: Dict[str, Any] = {
996+
"session_id": session_id,
997+
"actor_id": actor_id,
998+
"main_branch": {"events": [], "branches": {}},
999+
}
9961000

9971001
# Group events by branch
9981002
for event in all_events:
@@ -1094,7 +1098,7 @@ def get_last_k_turns(
10941098
Returns:
10951099
List of turns, where each turn is a list of message dictionaries
10961100
"""
1097-
base_params = {
1101+
base_params: Dict[str, Any] = {
10981102
"memoryId": memory_id,
10991103
"actorId": actor_id,
11001104
"sessionId": session_id,
@@ -1222,7 +1226,7 @@ def get_memory_status(self, memory_id: str) -> str:
12221226
"""Get current memory status."""
12231227
try:
12241228
response = self.gmcp_client.get_memory(memoryId=memory_id) # Input uses old field name
1225-
return response["memory"]["status"]
1229+
return response["memory"]["status"] # type: ignore[no-any-return]
12261230
except ClientError as e:
12271231
logger.error("Failed to get memory status: %s", e)
12281232
raise
@@ -1265,7 +1269,7 @@ def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]:
12651269
def delete_memory(self, memory_id: str) -> Dict[str, Any]:
12661270
"""Delete a memory resource."""
12671271
try:
1268-
response = self.gmcp_client.delete_memory(
1272+
response: Dict[str, Any] = self.gmcp_client.delete_memory(
12691273
memoryId=memory_id, clientToken=str(uuid.uuid4())
12701274
) # Input uses old field name
12711275
logger.info("Deleted memory: %s", memory_id)

src/bedrock_agentcore/memory/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class ConversationalMessage:
128128
text: str
129129
role: MessageRole
130130

131-
def __post_init__(self):
131+
def __post_init__(self) -> None:
132132
"""Validate message fields after initialization."""
133133
if not isinstance(self.text, str):
134134
raise ValueError("ConversationalMessage.text must be a string")

src/bedrock_agentcore/memory/controlplane.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def create_memory(
9292

9393
try:
9494
response = self.client.create_memory(**params)
95-
memory = response["memory"]
95+
memory: Dict[str, Any] = response["memory"]
9696
memory_id = memory["id"]
9797

9898
logger.info("Created memory: %s", memory_id)
@@ -118,7 +118,7 @@ def get_memory(self, memory_id: str, include_strategies: bool = True) -> Dict[st
118118
"""
119119
try:
120120
response = self.client.get_memory(memoryId=memory_id)
121-
memory = response["memory"]
121+
memory: Dict[str, Any] = response["memory"]
122122

123123
# Add strategy count
124124
strategies = memory.get("strategies", [])
@@ -144,7 +144,7 @@ def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]:
144144
List of memory summaries
145145
"""
146146
try:
147-
memories = []
147+
memories: List[Dict[str, Any]] = []
148148
next_token = None
149149

150150
while len(memories) < max_results:
@@ -239,7 +239,7 @@ def update_memory(
239239

240240
try:
241241
response = self.client.update_memory(**params)
242-
memory = response["memory"]
242+
memory: Dict[str, Any] = response["memory"]
243243
logger.info("Updated memory: %s", memory_id)
244244

245245
if wait_for_active:
@@ -300,7 +300,7 @@ def delete_memory(
300300
logger.warning("Error waiting for strategies to become ACTIVE: %s", e)
301301

302302
# Now delete the memory
303-
response = self.client.delete_memory(memoryId=memory_id, clientToken=str(uuid.uuid4()))
303+
response: Dict[str, Any] = self.client.delete_memory(memoryId=memory_id, clientToken=str(uuid.uuid4()))
304304

305305
logger.info("Initiated deletion of memory: %s", memory_id)
306306

@@ -399,7 +399,8 @@ def get_strategy(self, memory_id: str, strategy_id: str) -> Dict[str, Any]:
399399

400400
for strategy in strategies:
401401
if strategy.get("strategyId") == strategy_id:
402-
return strategy
402+
result: Dict[str, Any] = strategy
403+
return result
403404

404405
raise ValueError(f"Strategy {strategy_id} not found in memory {memory_id}")
405406

@@ -567,7 +568,7 @@ def _wait_for_status(
567568

568569
start_time = time.time()
569570
last_memory_status = None
570-
strategy_statuses = {}
571+
strategy_statuses: Dict[str, str] = {}
571572

572573
while time.time() - start_time < max_wait:
573574
try:

src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
7979
if "conversational" in payload_item:
8080
conv = payload_item["conversational"]
8181
session_msg = SessionMessage.from_dict(json.loads(conv["content"]["text"]))
82-
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message)
82+
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) # type: ignore[assignment, arg-type]
8383
if session_msg.message.get("content"):
8484
messages.append(session_msg)
8585
elif "blob" in payload_item:
@@ -88,7 +88,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
8888
if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2:
8989
try:
9090
session_msg = SessionMessage.from_dict(json.loads(blob_data[0]))
91-
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message)
91+
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) # type: ignore[assignment, arg-type]
9292
if session_msg.message.get("content"):
9393
messages.append(session_msg)
9494
except (json.JSONDecodeError, ValueError):

src/bedrock_agentcore/memory/integrations/strands/converters/openai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _bedrock_to_openai(message: dict) -> dict:
5959
}
6060
)
6161

62-
result: dict[str, Any] = {"role": role}
62+
result: dict[str, Any] = {"role": role} # type: ignore[no-redef]
6363

6464
if tool_calls:
6565
result["content"] = "\n".join(text_parts) if text_parts else None
@@ -144,7 +144,7 @@ def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]
144144
if not has_non_empty:
145145
return []
146146

147-
openai_msg = _bedrock_to_openai(message)
147+
openai_msg = _bedrock_to_openai(message) # type: ignore[arg-type]
148148
role = openai_msg.get("role", "user")
149149
return [(json.dumps(openai_msg), role)]
150150

@@ -177,7 +177,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
177177
if openai_msg and isinstance(openai_msg, dict):
178178
bedrock_msg = _openai_to_bedrock(openai_msg)
179179
if bedrock_msg.get("content"):
180-
session_msg = SessionMessage(message=bedrock_msg, message_id=0)
180+
session_msg = SessionMessage(message=bedrock_msg, message_id=0) # type: ignore[arg-type]
181181
messages.append(session_msg)
182182

183183
return messages

0 commit comments

Comments
 (0)