Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 35 additions & 94 deletions src/agents/extensions/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@

from __future__ import annotations

from importlib import import_module
from typing import TYPE_CHECKING, Any

from ._optional_imports import raise_optional_dependency_error

if TYPE_CHECKING:
from .advanced_sqlite_session import AdvancedSQLiteSession
from .async_sqlite_session import AsyncSQLiteSession
Expand All @@ -35,99 +38,37 @@
"SQLAlchemySession",
]

_LAZY_EXPORTS: dict[str, tuple[str, tuple[str, str] | None]] = {
"EncryptedSession": (".encrypt_session", ("cryptography", "encrypt")),
"RedisSession": (".redis_session", ("redis", "redis")),
"SQLAlchemySession": (".sqlalchemy_session", ("sqlalchemy", "sqlalchemy")),
"AdvancedSQLiteSession": (".advanced_sqlite_session", None),
"AsyncSQLiteSession": (".async_sqlite_session", None),
"DaprSession": (".dapr_session", ("dapr", "dapr")),
"DAPR_CONSISTENCY_EVENTUAL": (".dapr_session", ("dapr", "dapr")),
"DAPR_CONSISTENCY_STRONG": (".dapr_session", ("dapr", "dapr")),
"MongoDBSession": (".mongodb_session", ("mongodb", "mongodb")),
}

def __getattr__(name: str) -> Any:
if name == "EncryptedSession":
try:
from .encrypt_session import EncryptedSession # noqa: F401

return EncryptedSession
except ModuleNotFoundError as e:
raise ImportError(
"EncryptedSession requires the 'cryptography' extra. "
"Install it with: pip install openai-agents[encrypt]"
) from e

if name == "RedisSession":
try:
from .redis_session import RedisSession # noqa: F401

return RedisSession
except ModuleNotFoundError as e:
raise ImportError(
"RedisSession requires the 'redis' extra. "
"Install it with: pip install openai-agents[redis]"
) from e

if name == "SQLAlchemySession":
try:
from .sqlalchemy_session import SQLAlchemySession # noqa: F401

return SQLAlchemySession
except ModuleNotFoundError as e:
raise ImportError(
"SQLAlchemySession requires the 'sqlalchemy' extra. "
"Install it with: pip install openai-agents[sqlalchemy]"
) from e

if name == "AdvancedSQLiteSession":
try:
from .advanced_sqlite_session import AdvancedSQLiteSession # noqa: F401

return AdvancedSQLiteSession
except ModuleNotFoundError as e:
raise ImportError(f"Failed to import AdvancedSQLiteSession: {e}") from e

if name == "AsyncSQLiteSession":
try:
from .async_sqlite_session import AsyncSQLiteSession # noqa: F401

return AsyncSQLiteSession
except ModuleNotFoundError as e:
raise ImportError(f"Failed to import AsyncSQLiteSession: {e}") from e

if name == "DaprSession":
try:
from .dapr_session import DaprSession # noqa: F401

return DaprSession
except ModuleNotFoundError as e:
raise ImportError(
"DaprSession requires the 'dapr' extra. "
"Install it with: pip install openai-agents[dapr]"
) from e

if name == "DAPR_CONSISTENCY_EVENTUAL":
try:
from .dapr_session import DAPR_CONSISTENCY_EVENTUAL # noqa: F401

return DAPR_CONSISTENCY_EVENTUAL
except ModuleNotFoundError as e:
raise ImportError(
"DAPR_CONSISTENCY_EVENTUAL requires the 'dapr' extra. "
"Install it with: pip install openai-agents[dapr]"
) from e

if name == "DAPR_CONSISTENCY_STRONG":
try:
from .dapr_session import DAPR_CONSISTENCY_STRONG # noqa: F401

return DAPR_CONSISTENCY_STRONG
except ModuleNotFoundError as e:
raise ImportError(
"DAPR_CONSISTENCY_STRONG requires the 'dapr' extra. "
"Install it with: pip install openai-agents[dapr]"
) from e

if name == "MongoDBSession":
try:
from .mongodb_session import MongoDBSession # noqa: F401

return MongoDBSession
except ModuleNotFoundError as e:
raise ImportError(
"MongoDBSession requires the 'mongodb' extra. "
"Install it with: pip install openai-agents[mongodb]"
) from e

raise AttributeError(f"module {__name__} has no attribute {name}")
def __getattr__(name: str) -> Any:
if name not in _LAZY_EXPORTS:
raise AttributeError(f"module {__name__} has no attribute {name}")

module_name, optional_dependency = _LAZY_EXPORTS[name]
try:
module = import_module(module_name, __name__)
except ModuleNotFoundError as e:
if optional_dependency is None:
raise ImportError(f"Failed to import {name}: {e}") from e
dependency_name, extra_name = optional_dependency
raise_optional_dependency_error(
name,
dependency_name=dependency_name,
extra_name=extra_name,
cause=e,
)

value = getattr(module, name)
globals()[name] = value
return value
19 changes: 19 additions & 0 deletions src/agents/extensions/memory/_optional_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import NoReturn


def raise_optional_dependency_error(
export_name: str,
*,
dependency_name: str,
extra_name: str,
cause: ImportError | None = None,
) -> NoReturn:
error = ImportError(
f"{export_name} requires the '{dependency_name}' extra. "
f"Install it with: pip install openai-agents[{extra_name}]"
)
if cause is None:
raise error
raise error from cause
11 changes: 8 additions & 3 deletions src/agents/extensions/memory/dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@
import time
from typing import Any, Final, Literal

from ._optional_imports import raise_optional_dependency_error

try:
from dapr.aio.clients import DaprClient
from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions
except ImportError as e:
raise ImportError(
"DaprSession requires the 'dapr' package. Install it with: pip install dapr"
) from e
raise_optional_dependency_error(
"DaprSession",
dependency_name="dapr",
extra_name="dapr",
cause=e,
)

from ...items import TResponseInputItem
from ...logger import logger
Expand Down
12 changes: 8 additions & 4 deletions src/agents/extensions/memory/mongodb_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from datetime import datetime, timezone
from typing import Any

from ._optional_imports import raise_optional_dependency_error

try:
from importlib.metadata import version as _get_version

Expand All @@ -49,10 +51,12 @@
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.driver_info import DriverInfo
except ImportError as e:
raise ImportError(
"MongoDBSession requires the 'pymongo' package (>=4.14). "
"Install it with: pip install openai-agents[mongodb]"
) from e
raise_optional_dependency_error(
"MongoDBSession",
dependency_name="mongodb",
extra_name="mongodb",
cause=e,
)

from ...items import TResponseInputItem
from ...memory.session import SessionABC
Expand Down
11 changes: 8 additions & 3 deletions src/agents/extensions/memory/redis_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@
import time
from typing import Any

from ._optional_imports import raise_optional_dependency_error

try:
import redis.asyncio as redis
from redis.asyncio import Redis
except ImportError as e:
raise ImportError(
"RedisSession requires the 'redis' package. Install it with: pip install redis"
) from e
raise_optional_dependency_error(
"RedisSession",
dependency_name="redis",
extra_name="redis",
cause=e,
)

from ...items import TResponseInputItem
from ...memory.session import SessionABC
Expand Down
154 changes: 154 additions & 0 deletions tests/extensions/memory/test_memory_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from __future__ import annotations

import importlib.abc
import sys
from types import ModuleType

import pytest

_PACKAGE_EXPORTS: tuple[tuple[str, str, str, str, str], ...] = (
(
"EncryptedSession",
"agents.extensions.memory.encrypt_session",
"agents.extensions.memory.encrypt_session",
"cryptography",
"encrypt",
),
("RedisSession", "agents.extensions.memory.redis_session", "redis.asyncio", "redis", "redis"),
(
"SQLAlchemySession",
"agents.extensions.memory.sqlalchemy_session",
"agents.extensions.memory.sqlalchemy_session",
"sqlalchemy",
"sqlalchemy",
),
("DaprSession", "agents.extensions.memory.dapr_session", "dapr.aio.clients", "dapr", "dapr"),
(
"DAPR_CONSISTENCY_EVENTUAL",
"agents.extensions.memory.dapr_session",
"dapr.aio.clients",
"dapr",
"dapr",
),
(
"DAPR_CONSISTENCY_STRONG",
"agents.extensions.memory.dapr_session",
"dapr.aio.clients",
"dapr",
"dapr",
),
(
"MongoDBSession",
"agents.extensions.memory.mongodb_session",
"pymongo.asynchronous.collection",
"mongodb",
"mongodb",
),
)

_DIRECT_MODULE_IMPORTS: tuple[tuple[str, str, str, str], ...] = (
("agents.extensions.memory.redis_session", "redis.asyncio", "redis", "redis"),
("agents.extensions.memory.dapr_session", "dapr.aio.clients", "dapr", "dapr"),
(
"agents.extensions.memory.mongodb_session",
"pymongo.asynchronous.collection",
"mongodb",
"mongodb",
),
)


class _BrokenImportFinder(importlib.abc.MetaPathFinder):
def __init__(self, broken_module: str, error_cls: type[ImportError]) -> None:
self._broken_module = broken_module
self._error_cls = error_cls

def find_spec(
self,
fullname: str,
path: object | None,
target: ModuleType | None = None,
) -> None:
if fullname == self._broken_module:
raise self._error_cls("simulated dependency import failure")
return None


def _reset_package_imports(
monkeypatch: pytest.MonkeyPatch,
memory_module: ModuleType,
symbol: str,
module_name: str,
broken_module: str,
) -> None:
monkeypatch.delitem(memory_module.__dict__, symbol, raising=False)
_reset_loaded_module(monkeypatch, module_name)
_reset_loaded_module(monkeypatch, broken_module)


def _reset_loaded_module(monkeypatch: pytest.MonkeyPatch, module_name: str) -> None:
monkeypatch.delitem(sys.modules, module_name, raising=False)
parent_name, short_name = module_name.rsplit(".", 1)
parent_module = sys.modules.get(parent_name)
if parent_module is not None:
monkeypatch.delitem(parent_module.__dict__, short_name, raising=False)


def _reset_module_imports(
monkeypatch: pytest.MonkeyPatch,
module_name: str,
broken_module: str,
) -> None:
_reset_loaded_module(monkeypatch, module_name)
_reset_loaded_module(monkeypatch, broken_module)


@pytest.mark.parametrize(
("symbol", "module_name", "broken_module", "dependency_name", "extra_name"),
_PACKAGE_EXPORTS,
)
def test_memory_package_imports_point_to_optional_extra(
monkeypatch: pytest.MonkeyPatch,
symbol: str,
module_name: str,
broken_module: str,
dependency_name: str,
extra_name: str,
) -> None:
import agents.extensions.memory as memory_module

_reset_package_imports(monkeypatch, memory_module, symbol, module_name, broken_module)
finder = _BrokenImportFinder(broken_module, ModuleNotFoundError)
monkeypatch.setattr(sys, "meta_path", [finder, *sys.meta_path])

with pytest.raises(ImportError) as exc_info:
getattr(memory_module, symbol)

assert f"requires the '{dependency_name}' extra" in str(exc_info.value)
assert f"openai-agents[{extra_name}]" in str(exc_info.value)
assert isinstance(exc_info.value.__cause__, ImportError)


@pytest.mark.parametrize(
("module_name", "broken_module", "dependency_name", "extra_name"),
_DIRECT_MODULE_IMPORTS,
)
@pytest.mark.parametrize("error_cls", [ImportError, ModuleNotFoundError])
def test_memory_direct_module_imports_point_to_optional_extra(
monkeypatch: pytest.MonkeyPatch,
module_name: str,
broken_module: str,
dependency_name: str,
extra_name: str,
error_cls: type[ImportError],
) -> None:
_reset_module_imports(monkeypatch, module_name, broken_module)
finder = _BrokenImportFinder(broken_module, error_cls)
monkeypatch.setattr(sys, "meta_path", [finder, *sys.meta_path])

with pytest.raises(ImportError) as exc_info:
__import__(module_name)

assert f"requires the '{dependency_name}' extra" in str(exc_info.value)
assert f"openai-agents[{extra_name}]" in str(exc_info.value)
assert isinstance(exc_info.value.__cause__, ImportError)
Loading