diff --git a/src/bedrock_agentcore/_utils/snake_case.py b/src/bedrock_agentcore/_utils/snake_case.py new file mode 100644 index 00000000..d6be4b77 --- /dev/null +++ b/src/bedrock_agentcore/_utils/snake_case.py @@ -0,0 +1,46 @@ +"""Utilities for wrapping boto3 methods to accept snake_case kwargs.""" + +import functools +import re +from typing import Any, Callable, Dict + +_VALID_SNAKE_RE = re.compile(r"^[a-z][a-z0-9]*(_[a-z0-9]+)*$") + + +def snake_to_camel(name: str) -> str: + """Convert a snake_case string to camelCase. + + Already-camelCase strings pass through unchanged (no underscores to split on). + Raises ValueError for malformed snake_case (e.g. leading/trailing underscores, + consecutive underscores, uppercase characters). + """ + if "_" not in name: + return name + if not _VALID_SNAKE_RE.match(name): + raise ValueError(f"Invalid parameter name: '{name}'") + parts = name.split("_") + return parts[0] + "".join(p.title() for p in parts[1:]) + + +def accept_snake_case_kwargs(method: Callable[..., Any]) -> Callable[..., Any]: + """Wrap a boto3 method to accept both snake_case and camelCase kwargs. + + Converts all snake_case kwargs to camelCase before forwarding. + Raises TypeError if both forms are provided (e.g. memory_id and memoryId). + """ + + @functools.wraps(method) + def wrapper(*args: Any, **kwargs: Any) -> Any: + converted: Dict[str, Any] = {} + original_keys: Dict[str, str] = {} + for key, value in kwargs.items(): + camel_key = snake_to_camel(key) + if camel_key in converted: + raise TypeError( + f"Got both '{original_keys[camel_key]}' and '{key}' for the same parameter. Use one or the other." + ) + original_keys[camel_key] = key + converted[camel_key] = value + return method(*args, **converted) + + return wrapper diff --git a/src/bedrock_agentcore/memory/client.py b/src/bedrock_agentcore/memory/client.py index 53bb52af..f7c1be46 100644 --- a/src/bedrock_agentcore/memory/client.py +++ b/src/bedrock_agentcore/memory/client.py @@ -20,6 +20,7 @@ from botocore.config import Config from botocore.exceptions import ClientError +from bedrock_agentcore._utils.snake_case import accept_snake_case_kwargs from bedrock_agentcore._utils.user_agent import build_user_agent_suffix from .constants import ( @@ -126,12 +127,12 @@ def __getattr__(self, name: str): if name in self._ALLOWED_GMDP_METHODS and hasattr(self.gmdp_client, name): method = getattr(self.gmdp_client, name) logger.debug("Forwarding method '%s' to gmdp_client", name) - return method + return accept_snake_case_kwargs(method) if name in self._ALLOWED_GMCP_METHODS and hasattr(self.gmcp_client, name): method = getattr(self.gmcp_client, name) logger.debug("Forwarding method '%s' to gmcp_client", name) - return method + return accept_snake_case_kwargs(method) # Method not found on either client raise AttributeError( diff --git a/src/bedrock_agentcore/memory/session.py b/src/bedrock_agentcore/memory/session.py index bf88f278..279f29c5 100644 --- a/src/bedrock_agentcore/memory/session.py +++ b/src/bedrock_agentcore/memory/session.py @@ -10,6 +10,8 @@ from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError +from bedrock_agentcore._utils.snake_case import accept_snake_case_kwargs + from .constants import BlobMessage, ConversationalMessage, MessageRole, RetrievalConfig from .models import ( ActorSummary, @@ -240,7 +242,7 @@ def __getattr__(self, name: str): if name in self._ALLOWED_DATA_PLANE_METHODS and hasattr(self._data_plane_client, name): method = getattr(self._data_plane_client, name) logger.debug("Forwarding method '%s' to _data_plane_client", name) - return method + return accept_snake_case_kwargs(method) # Method not found on client raise AttributeError( diff --git a/tests/bedrock_agentcore/memory/test_session.py b/tests/bedrock_agentcore/memory/test_session.py index b4e07b96..56a12da1 100644 --- a/tests/bedrock_agentcore/memory/test_session.py +++ b/tests/bedrock_agentcore/memory/test_session.py @@ -360,11 +360,20 @@ def test_getattr_allowed_method(self): manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") # Test accessing an allowed method - mock_method = MagicMock() + mock_method = MagicMock(return_value={"records": []}) mock_client_instance.retrieve_memory_records = mock_method result = manager.retrieve_memory_records - assert result == mock_method + assert callable(result) + + # camelCase works (backward compat) + result(memoryId="mem-1", namespace="ns/") + mock_method.assert_called_once_with(memoryId="mem-1", namespace="ns/") + + # snake_case is converted to camelCase + mock_method.reset_mock() + result(memory_id="mem-1", namespace="ns/") + mock_method.assert_called_once_with(memoryId="mem-1", namespace="ns/") def test_getattr_disallowed_method(self): """Test __getattr__ raises AttributeError for disallowed methods.""" @@ -3652,7 +3661,7 @@ def test_getattr_debug_logging(self): manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") # Mock an allowed method - mock_method = MagicMock() + mock_method = MagicMock(return_value={"records": []}) mock_client_instance.retrieve_memory_records = mock_method with patch("bedrock_agentcore.memory.session.logger") as mock_logger: @@ -3662,7 +3671,9 @@ def test_getattr_debug_logging(self): mock_logger.debug.assert_called_once_with( "Forwarding method '%s' to _data_plane_client", "retrieve_memory_records" ) - assert result == mock_method + assert callable(result) + result(memoryId="mem-1") + mock_method.assert_called_once_with(memoryId="mem-1") def test_process_turn_with_llm_no_retrieval_namespace(self): """Test process_turn_with_llm without retrieval_config (no memory retrieval).""" diff --git a/tests/bedrock_agentcore/test_snake_case.py b/tests/bedrock_agentcore/test_snake_case.py new file mode 100644 index 00000000..11e1a827 --- /dev/null +++ b/tests/bedrock_agentcore/test_snake_case.py @@ -0,0 +1,101 @@ +"""Tests for snake_case kwargs utilities.""" + +from unittest.mock import MagicMock + +import pytest + +from bedrock_agentcore._utils.snake_case import accept_snake_case_kwargs, snake_to_camel + + +class TestSnakeToCamel: + """Tests for snake_to_camel conversion.""" + + def test_single_word(self): + assert snake_to_camel("name") == "name" + + def test_two_words(self): + assert snake_to_camel("memory_id") == "memoryId" + + def test_already_camel_case_passthrough(self): + assert snake_to_camel("memoryId") == "memoryId" + + def test_multi_segment_snake(self): + assert snake_to_camel("memory_execution_role_arn") == "memoryExecutionRoleArn" + + def test_empty_string(self): + assert snake_to_camel("") == "" + + # Reject malformed snake_case early rather than silently converting it. + # We don't want users depending on conversion quirks (e.g. "a__b" → "aB") + # that only work by accident of the current implementation. + + def test_rejects_leading_underscore(self): + with pytest.raises(ValueError, match="Invalid parameter name"): + snake_to_camel("_private") + + def test_rejects_consecutive_underscores(self): + with pytest.raises(ValueError, match="Invalid parameter name"): + snake_to_camel("a__b") + + def test_rejects_trailing_underscore(self): + with pytest.raises(ValueError, match="Invalid parameter name"): + snake_to_camel("name_") + + def test_rejects_uppercase_in_snake(self): + with pytest.raises(ValueError, match="Invalid parameter name"): + snake_to_camel("memory_ID") + + +class TestAcceptSnakeCaseKwargs: + """Tests for accept_snake_case_kwargs wrapper.""" + + def setup_method(self): + self.mock_method = MagicMock(return_value={"result": "ok"}) + + def test_snake_case_converted(self): + wrapped = accept_snake_case_kwargs(self.mock_method) + wrapped(memory_id="mem-1", actor_id="user-1") + self.mock_method.assert_called_once_with(memoryId="mem-1", actorId="user-1") + + def test_camel_case_passthrough(self): + wrapped = accept_snake_case_kwargs(self.mock_method) + wrapped(memoryId="mem-1", actorId="user-1") + self.mock_method.assert_called_once_with(memoryId="mem-1", actorId="user-1") + + def test_mixed_snake_and_camel_different_params(self): + wrapped = accept_snake_case_kwargs(self.mock_method) + wrapped(memory_id="mem-1", actorId="user-1") + self.mock_method.assert_called_once_with(memoryId="mem-1", actorId="user-1") + + def test_collision_raises_type_error(self): + wrapped = accept_snake_case_kwargs(self.mock_method) + with pytest.raises(TypeError, match="memoryId.*memory_id"): + wrapped(memoryId="mem-1", memory_id="mem-2") + + def test_return_value_forwarded(self): + wrapped = accept_snake_case_kwargs(self.mock_method) + result = wrapped(memory_id="mem-1") + assert result == {"result": "ok"} + + def test_positional_args_forwarded(self): + wrapped = accept_snake_case_kwargs(self.mock_method) + wrapped("pos1", "pos2", memory_id="mem-1") + self.mock_method.assert_called_once_with("pos1", "pos2", memoryId="mem-1") + + def test_no_kwargs(self): + wrapped = accept_snake_case_kwargs(self.mock_method) + wrapped() + self.mock_method.assert_called_once_with() + + def test_exception_propagated(self): + self.mock_method.side_effect = ValueError("boom") + wrapped = accept_snake_case_kwargs(self.mock_method) + with pytest.raises(ValueError, match="boom"): + wrapped(memory_id="mem-1") + + def test_preserves_function_name(self): + def my_boto3_method(): + pass + + wrapped = accept_snake_case_kwargs(my_boto3_method) + assert wrapped.__name__ == "my_boto3_method" diff --git a/tests_integ/memory/test_memory_client.py b/tests_integ/memory/test_memory_client.py index 7f2764fc..67074e95 100644 --- a/tests_integ/memory/test_memory_client.py +++ b/tests_integ/memory/test_memory_client.py @@ -94,13 +94,12 @@ def test_stream_delivery_create_and_update(self): assert memory["streamDeliveryResources"] == delivery_config # Test update via MemoryClient.__getattr__ passthrough to boto3 client. - # Uses camelCase params because the passthrough forwards directly to boto3 - # without the snake_case translation that explicit SDK methods provide. + # Uses snake_case params — the passthrough wrapper converts to camelCase. updated_config = self._make_delivery_config("METADATA_ONLY") response = self.client.update_memory( - memoryId=memory_id, - clientToken=str(uuid.uuid4()), - streamDeliveryResources=updated_config, + memory_id=memory_id, + client_token=str(uuid.uuid4()), + stream_delivery_resources=updated_config, ) assert response["memory"]["streamDeliveryResources"] == updated_config