From 5541d53dc37b60b0427f12cff012dcb41eddaf07 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Wed, 13 May 2026 10:48:31 +0900 Subject: [PATCH] fix: unify memory optional dependency import errors --- src/agents/extensions/memory/__init__.py | 129 ++++----------- .../extensions/memory/_optional_imports.py | 19 +++ src/agents/extensions/memory/dapr_session.py | 11 +- .../extensions/memory/mongodb_session.py | 12 +- src/agents/extensions/memory/redis_session.py | 11 +- .../extensions/memory/test_memory_imports.py | 154 ++++++++++++++++++ 6 files changed, 232 insertions(+), 104 deletions(-) create mode 100644 src/agents/extensions/memory/_optional_imports.py create mode 100644 tests/extensions/memory/test_memory_imports.py diff --git a/src/agents/extensions/memory/__init__.py b/src/agents/extensions/memory/__init__.py index 3b0c74da5f..3fcb71ecaf 100644 --- a/src/agents/extensions/memory/__init__.py +++ b/src/agents/extensions/memory/__init__.py @@ -8,8 +8,11 @@ from __future__ import annotations +from importlib import import_module from typing import TYPE_CHECKING, Any +from ._optional_imports import raise_optional_dependency_error + if TYPE_CHECKING: from .advanced_sqlite_session import AdvancedSQLiteSession from .async_sqlite_session import AsyncSQLiteSession @@ -35,99 +38,37 @@ "SQLAlchemySession", ] +_LAZY_EXPORTS: dict[str, tuple[str, tuple[str, str] | None]] = { + "EncryptedSession": (".encrypt_session", ("cryptography", "encrypt")), + "RedisSession": (".redis_session", ("redis", "redis")), + "SQLAlchemySession": (".sqlalchemy_session", ("sqlalchemy", "sqlalchemy")), + "AdvancedSQLiteSession": (".advanced_sqlite_session", None), + "AsyncSQLiteSession": (".async_sqlite_session", None), + "DaprSession": (".dapr_session", ("dapr", "dapr")), + "DAPR_CONSISTENCY_EVENTUAL": (".dapr_session", ("dapr", "dapr")), + "DAPR_CONSISTENCY_STRONG": (".dapr_session", ("dapr", "dapr")), + "MongoDBSession": (".mongodb_session", ("mongodb", "mongodb")), +} -def __getattr__(name: str) -> Any: - if name == "EncryptedSession": - try: - from .encrypt_session import EncryptedSession # noqa: F401 - - return EncryptedSession - except ModuleNotFoundError as e: - raise ImportError( - "EncryptedSession requires the 'cryptography' extra. " - "Install it with: pip install openai-agents[encrypt]" - ) from e - - if name == "RedisSession": - try: - from .redis_session import RedisSession # noqa: F401 - - return RedisSession - except ModuleNotFoundError as e: - raise ImportError( - "RedisSession requires the 'redis' extra. " - "Install it with: pip install openai-agents[redis]" - ) from e - - if name == "SQLAlchemySession": - try: - from .sqlalchemy_session import SQLAlchemySession # noqa: F401 - - return SQLAlchemySession - except ModuleNotFoundError as e: - raise ImportError( - "SQLAlchemySession requires the 'sqlalchemy' extra. " - "Install it with: pip install openai-agents[sqlalchemy]" - ) from e - - if name == "AdvancedSQLiteSession": - try: - from .advanced_sqlite_session import AdvancedSQLiteSession # noqa: F401 - - return AdvancedSQLiteSession - except ModuleNotFoundError as e: - raise ImportError(f"Failed to import AdvancedSQLiteSession: {e}") from e - - if name == "AsyncSQLiteSession": - try: - from .async_sqlite_session import AsyncSQLiteSession # noqa: F401 - return AsyncSQLiteSession - except ModuleNotFoundError as e: - raise ImportError(f"Failed to import AsyncSQLiteSession: {e}") from e - - if name == "DaprSession": - try: - from .dapr_session import DaprSession # noqa: F401 - - return DaprSession - except ModuleNotFoundError as e: - raise ImportError( - "DaprSession requires the 'dapr' extra. " - "Install it with: pip install openai-agents[dapr]" - ) from e - - if name == "DAPR_CONSISTENCY_EVENTUAL": - try: - from .dapr_session import DAPR_CONSISTENCY_EVENTUAL # noqa: F401 - - return DAPR_CONSISTENCY_EVENTUAL - except ModuleNotFoundError as e: - raise ImportError( - "DAPR_CONSISTENCY_EVENTUAL requires the 'dapr' extra. " - "Install it with: pip install openai-agents[dapr]" - ) from e - - if name == "DAPR_CONSISTENCY_STRONG": - try: - from .dapr_session import DAPR_CONSISTENCY_STRONG # noqa: F401 - - return DAPR_CONSISTENCY_STRONG - except ModuleNotFoundError as e: - raise ImportError( - "DAPR_CONSISTENCY_STRONG requires the 'dapr' extra. " - "Install it with: pip install openai-agents[dapr]" - ) from e - - if name == "MongoDBSession": - try: - from .mongodb_session import MongoDBSession # noqa: F401 - - return MongoDBSession - except ModuleNotFoundError as e: - raise ImportError( - "MongoDBSession requires the 'mongodb' extra. " - "Install it with: pip install openai-agents[mongodb]" - ) from e - - raise AttributeError(f"module {__name__} has no attribute {name}") +def __getattr__(name: str) -> Any: + if name not in _LAZY_EXPORTS: + raise AttributeError(f"module {__name__} has no attribute {name}") + + module_name, optional_dependency = _LAZY_EXPORTS[name] + try: + module = import_module(module_name, __name__) + except ModuleNotFoundError as e: + if optional_dependency is None: + raise ImportError(f"Failed to import {name}: {e}") from e + dependency_name, extra_name = optional_dependency + raise_optional_dependency_error( + name, + dependency_name=dependency_name, + extra_name=extra_name, + cause=e, + ) + + value = getattr(module, name) + globals()[name] = value + return value diff --git a/src/agents/extensions/memory/_optional_imports.py b/src/agents/extensions/memory/_optional_imports.py new file mode 100644 index 0000000000..422d9cb679 --- /dev/null +++ b/src/agents/extensions/memory/_optional_imports.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import NoReturn + + +def raise_optional_dependency_error( + export_name: str, + *, + dependency_name: str, + extra_name: str, + cause: ImportError | None = None, +) -> NoReturn: + error = ImportError( + f"{export_name} requires the '{dependency_name}' extra. " + f"Install it with: pip install openai-agents[{extra_name}]" + ) + if cause is None: + raise error + raise error from cause diff --git a/src/agents/extensions/memory/dapr_session.py b/src/agents/extensions/memory/dapr_session.py index 8d92872406..6ac68f6020 100644 --- a/src/agents/extensions/memory/dapr_session.py +++ b/src/agents/extensions/memory/dapr_session.py @@ -29,13 +29,18 @@ import time from typing import Any, Final, Literal +from ._optional_imports import raise_optional_dependency_error + try: from dapr.aio.clients import DaprClient from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions except ImportError as e: - raise ImportError( - "DaprSession requires the 'dapr' package. Install it with: pip install dapr" - ) from e + raise_optional_dependency_error( + "DaprSession", + dependency_name="dapr", + extra_name="dapr", + cause=e, + ) from ...items import TResponseInputItem from ...logger import logger diff --git a/src/agents/extensions/memory/mongodb_session.py b/src/agents/extensions/memory/mongodb_session.py index 0abaa2bfe2..113acdc6af 100644 --- a/src/agents/extensions/memory/mongodb_session.py +++ b/src/agents/extensions/memory/mongodb_session.py @@ -37,6 +37,8 @@ from datetime import datetime, timezone from typing import Any +from ._optional_imports import raise_optional_dependency_error + try: from importlib.metadata import version as _get_version @@ -49,10 +51,12 @@ from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.driver_info import DriverInfo except ImportError as e: - raise ImportError( - "MongoDBSession requires the 'pymongo' package (>=4.14). " - "Install it with: pip install openai-agents[mongodb]" - ) from e + raise_optional_dependency_error( + "MongoDBSession", + dependency_name="mongodb", + extra_name="mongodb", + cause=e, + ) from ...items import TResponseInputItem from ...memory.session import SessionABC diff --git a/src/agents/extensions/memory/redis_session.py b/src/agents/extensions/memory/redis_session.py index 60e863428a..11e2dd838b 100644 --- a/src/agents/extensions/memory/redis_session.py +++ b/src/agents/extensions/memory/redis_session.py @@ -26,13 +26,18 @@ import time from typing import Any +from ._optional_imports import raise_optional_dependency_error + try: import redis.asyncio as redis from redis.asyncio import Redis except ImportError as e: - raise ImportError( - "RedisSession requires the 'redis' package. Install it with: pip install redis" - ) from e + raise_optional_dependency_error( + "RedisSession", + dependency_name="redis", + extra_name="redis", + cause=e, + ) from ...items import TResponseInputItem from ...memory.session import SessionABC diff --git a/tests/extensions/memory/test_memory_imports.py b/tests/extensions/memory/test_memory_imports.py new file mode 100644 index 0000000000..955d5a79d4 --- /dev/null +++ b/tests/extensions/memory/test_memory_imports.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +import importlib.abc +import sys +from types import ModuleType + +import pytest + +_PACKAGE_EXPORTS: tuple[tuple[str, str, str, str, str], ...] = ( + ( + "EncryptedSession", + "agents.extensions.memory.encrypt_session", + "agents.extensions.memory.encrypt_session", + "cryptography", + "encrypt", + ), + ("RedisSession", "agents.extensions.memory.redis_session", "redis.asyncio", "redis", "redis"), + ( + "SQLAlchemySession", + "agents.extensions.memory.sqlalchemy_session", + "agents.extensions.memory.sqlalchemy_session", + "sqlalchemy", + "sqlalchemy", + ), + ("DaprSession", "agents.extensions.memory.dapr_session", "dapr.aio.clients", "dapr", "dapr"), + ( + "DAPR_CONSISTENCY_EVENTUAL", + "agents.extensions.memory.dapr_session", + "dapr.aio.clients", + "dapr", + "dapr", + ), + ( + "DAPR_CONSISTENCY_STRONG", + "agents.extensions.memory.dapr_session", + "dapr.aio.clients", + "dapr", + "dapr", + ), + ( + "MongoDBSession", + "agents.extensions.memory.mongodb_session", + "pymongo.asynchronous.collection", + "mongodb", + "mongodb", + ), +) + +_DIRECT_MODULE_IMPORTS: tuple[tuple[str, str, str, str], ...] = ( + ("agents.extensions.memory.redis_session", "redis.asyncio", "redis", "redis"), + ("agents.extensions.memory.dapr_session", "dapr.aio.clients", "dapr", "dapr"), + ( + "agents.extensions.memory.mongodb_session", + "pymongo.asynchronous.collection", + "mongodb", + "mongodb", + ), +) + + +class _BrokenImportFinder(importlib.abc.MetaPathFinder): + def __init__(self, broken_module: str, error_cls: type[ImportError]) -> None: + self._broken_module = broken_module + self._error_cls = error_cls + + def find_spec( + self, + fullname: str, + path: object | None, + target: ModuleType | None = None, + ) -> None: + if fullname == self._broken_module: + raise self._error_cls("simulated dependency import failure") + return None + + +def _reset_package_imports( + monkeypatch: pytest.MonkeyPatch, + memory_module: ModuleType, + symbol: str, + module_name: str, + broken_module: str, +) -> None: + monkeypatch.delitem(memory_module.__dict__, symbol, raising=False) + _reset_loaded_module(monkeypatch, module_name) + _reset_loaded_module(monkeypatch, broken_module) + + +def _reset_loaded_module(monkeypatch: pytest.MonkeyPatch, module_name: str) -> None: + monkeypatch.delitem(sys.modules, module_name, raising=False) + parent_name, short_name = module_name.rsplit(".", 1) + parent_module = sys.modules.get(parent_name) + if parent_module is not None: + monkeypatch.delitem(parent_module.__dict__, short_name, raising=False) + + +def _reset_module_imports( + monkeypatch: pytest.MonkeyPatch, + module_name: str, + broken_module: str, +) -> None: + _reset_loaded_module(monkeypatch, module_name) + _reset_loaded_module(monkeypatch, broken_module) + + +@pytest.mark.parametrize( + ("symbol", "module_name", "broken_module", "dependency_name", "extra_name"), + _PACKAGE_EXPORTS, +) +def test_memory_package_imports_point_to_optional_extra( + monkeypatch: pytest.MonkeyPatch, + symbol: str, + module_name: str, + broken_module: str, + dependency_name: str, + extra_name: str, +) -> None: + import agents.extensions.memory as memory_module + + _reset_package_imports(monkeypatch, memory_module, symbol, module_name, broken_module) + finder = _BrokenImportFinder(broken_module, ModuleNotFoundError) + monkeypatch.setattr(sys, "meta_path", [finder, *sys.meta_path]) + + with pytest.raises(ImportError) as exc_info: + getattr(memory_module, symbol) + + assert f"requires the '{dependency_name}' extra" in str(exc_info.value) + assert f"openai-agents[{extra_name}]" in str(exc_info.value) + assert isinstance(exc_info.value.__cause__, ImportError) + + +@pytest.mark.parametrize( + ("module_name", "broken_module", "dependency_name", "extra_name"), + _DIRECT_MODULE_IMPORTS, +) +@pytest.mark.parametrize("error_cls", [ImportError, ModuleNotFoundError]) +def test_memory_direct_module_imports_point_to_optional_extra( + monkeypatch: pytest.MonkeyPatch, + module_name: str, + broken_module: str, + dependency_name: str, + extra_name: str, + error_cls: type[ImportError], +) -> None: + _reset_module_imports(monkeypatch, module_name, broken_module) + finder = _BrokenImportFinder(broken_module, error_cls) + monkeypatch.setattr(sys, "meta_path", [finder, *sys.meta_path]) + + with pytest.raises(ImportError) as exc_info: + __import__(module_name) + + assert f"requires the '{dependency_name}' extra" in str(exc_info.value) + assert f"openai-agents[{extra_name}]" in str(exc_info.value) + assert isinstance(exc_info.value.__cause__, ImportError)