Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/agents/extensions/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __getattr__(name: str) -> Any:
from .encrypt_session import EncryptedSession # noqa: F401

return EncryptedSession
except ModuleNotFoundError as e:
except ImportError as e:
raise ImportError(
"EncryptedSession requires the 'cryptography' extra. "
"Install it with: pip install openai-agents[encrypt]"
Expand All @@ -53,7 +53,7 @@ def __getattr__(name: str) -> Any:
from .redis_session import RedisSession # noqa: F401

return RedisSession
except ModuleNotFoundError as e:
except ImportError as e:
raise ImportError(
"RedisSession requires the 'redis' extra. "
"Install it with: pip install openai-agents[redis]"
Expand All @@ -64,7 +64,7 @@ def __getattr__(name: str) -> Any:
from .sqlalchemy_session import SQLAlchemySession # noqa: F401

return SQLAlchemySession
except ModuleNotFoundError as e:
except ImportError as e:
raise ImportError(
"SQLAlchemySession requires the 'sqlalchemy' extra. "
"Install it with: pip install openai-agents[sqlalchemy]"
Expand All @@ -75,23 +75,23 @@ def __getattr__(name: str) -> Any:
from .advanced_sqlite_session import AdvancedSQLiteSession # noqa: F401

return AdvancedSQLiteSession
except ModuleNotFoundError as e:
except ImportError as e:
raise ImportError(f"Failed to import AdvancedSQLiteSession: {e}") from e

if name == "AsyncSQLiteSession":
try:
from .async_sqlite_session import AsyncSQLiteSession # noqa: F401

return AsyncSQLiteSession
except ModuleNotFoundError as e:
except ImportError as e:
raise ImportError(f"Failed to import AsyncSQLiteSession: {e}") from e

if name == "DaprSession":
try:
from .dapr_session import DaprSession # noqa: F401

return DaprSession
except ModuleNotFoundError as e:
except ImportError as e:
raise ImportError(
"DaprSession requires the 'dapr' extra. "
"Install it with: pip install openai-agents[dapr]"
Expand All @@ -102,7 +102,7 @@ def __getattr__(name: str) -> Any:
from .dapr_session import DAPR_CONSISTENCY_EVENTUAL # noqa: F401

return DAPR_CONSISTENCY_EVENTUAL
except ModuleNotFoundError as e:
except ImportError as e:
raise ImportError(
"DAPR_CONSISTENCY_EVENTUAL requires the 'dapr' extra. "
"Install it with: pip install openai-agents[dapr]"
Expand All @@ -113,7 +113,7 @@ def __getattr__(name: str) -> Any:
from .dapr_session import DAPR_CONSISTENCY_STRONG # noqa: F401

return DAPR_CONSISTENCY_STRONG
except ModuleNotFoundError as e:
except ImportError as e:
raise ImportError(
"DAPR_CONSISTENCY_STRONG requires the 'dapr' extra. "
"Install it with: pip install openai-agents[dapr]"
Expand All @@ -124,7 +124,7 @@ def __getattr__(name: str) -> Any:
from .mongodb_session import MongoDBSession # noqa: F401

return MongoDBSession
except ModuleNotFoundError as e:
except ImportError as e:
raise ImportError(
"MongoDBSession requires the 'mongodb' extra. "
"Install it with: pip install openai-agents[mongodb]"
Expand Down
95 changes: 95 additions & 0 deletions tests/extensions/memory/test_memory_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

import importlib
import importlib.abc
import sys
from types import ModuleType

import pytest

# (symbol, broken_module, extra_name) tuples covering the extras-backed
# lazy exports in `agents.extensions.memory.__init__`. Each entry asserts
# that the package-level `__getattr__` produces a helpful
# `pip install openai-agents[<extra>]` message even when the backing
# module re-raises the dependency failure as a plain `ImportError`
# (as `redis_session`, `dapr_session`, and `mongodb_session` do).
_EXTRA_EXPORTS: tuple[tuple[str, str, str], ...] = (
("RedisSession", "agents.extensions.memory.redis_session", "redis"),
("DaprSession", "agents.extensions.memory.dapr_session", "dapr"),
(
"DAPR_CONSISTENCY_EVENTUAL",
"agents.extensions.memory.dapr_session",
"dapr",
),
(
"DAPR_CONSISTENCY_STRONG",
"agents.extensions.memory.dapr_session",
"dapr",
),
("MongoDBSession", "agents.extensions.memory.mongodb_session", "mongodb"),
("EncryptedSession", "agents.extensions.memory.encrypt_session", "encrypt"),
(
"SQLAlchemySession",
"agents.extensions.memory.sqlalchemy_session",
"sqlalchemy",
),
)


class _BrokenMemoryModuleFinder(importlib.abc.MetaPathFinder):
def __init__(self, broken_module: str, error_cls: type[ImportError]) -> None:
self._broken_module = broken_module
self._error_cls = error_cls

def find_spec(
self,
fullname: str,
path: object | None,
target: ModuleType | None = None,
) -> None:
if fullname == self._broken_module:
raise self._error_cls("simulated dependency import failure")
return None


def _reset_memory_imports(
monkeypatch: pytest.MonkeyPatch,
memory_module: ModuleType,
broken_module: str,
symbol: str,
) -> None:
monkeypatch.delitem(sys.modules, broken_module, raising=False)
short = broken_module.rsplit(".", 1)[-1]
monkeypatch.delitem(memory_module.__dict__, short, raising=False)
monkeypatch.delitem(memory_module.__dict__, symbol, raising=False)


@pytest.mark.parametrize(
("symbol", "broken_module", "extra"),
_EXTRA_EXPORTS,
)
@pytest.mark.parametrize("error_cls", [ImportError, ModuleNotFoundError])
def test_memory_extras_error_message_points_to_install_extra(
monkeypatch: pytest.MonkeyPatch,
symbol: str,
broken_module: str,
extra: str,
error_cls: type[ImportError],
) -> None:
"""Lazy memory exports must surface the `openai-agents[<extra>]` hint
regardless of whether the backing module raises `ImportError` or
`ModuleNotFoundError`. Backing modules like `redis_session` re-raise
`ImportError`, which used to bypass the outer wrapper's
`except ModuleNotFoundError`."""

import agents.extensions.memory as memory_module

_reset_memory_imports(monkeypatch, memory_module, broken_module, symbol)
finder = _BrokenMemoryModuleFinder(broken_module, error_cls)
monkeypatch.setattr(sys, "meta_path", [finder, *sys.meta_path])

with pytest.raises(ImportError) as exc_info:
getattr(memory_module, symbol)

assert f"openai-agents[{extra}]" in str(exc_info.value)
assert exc_info.value.__cause__ is not None
Loading