Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .github/instructions/models.instructions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
---
applyTo: "pyrit/models/**"
---

# `pyrit.models` Guidelines

## Import Boundary

`pyrit.models` is the canonical data layer. Files in `pyrit/models/` may
import only from:

- the standard library
- `pydantic`
- `pyrit.common.deprecation`
- other `pyrit.models.*` submodules

If a helper needs another `pyrit.*` package, it does not belong on a model —
put it in that package as a free function or static helper.

The CI test `tests/unit/models/test_import_boundary.py` enforces this using an
allowlist of known transitional violations, each tagged with the phase that
removes it. The list must shrink monotonically: removing an import from source
without also removing its allowlist entry fails the test, and adding a new
unlisted import also fails the test.
88 changes: 87 additions & 1 deletion pyrit/common/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

from __future__ import annotations

import importlib
import warnings
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Callable, Iterable


def print_deprecation_message(
Expand Down Expand Up @@ -41,3 +42,88 @@ def print_deprecation_message(
DeprecationWarning,
stacklevel=3,
)


def deprecated_kwarg(
values: Any,
*,
old_name: str,
new_name: str,
removed_in: str,
model: str,
) -> Any:
"""
Promote a deprecated kwarg to its new name and emit a DeprecationWarning.

Designed for use inside a Pydantic ``@model_validator(mode="before")``. If
``values`` is not a dict (e.g. Pydantic passed a model instance), it is
returned unchanged. If ``old_name`` is present, it is popped; its value is
assigned to ``new_name`` only when ``new_name`` is not already set.

Args:
values: The pre-validation values dict from a Pydantic validator.
old_name: The deprecated kwarg name.
new_name: The replacement kwarg name.
removed_in: The version in which ``old_name`` will be removed.
model: A label for the model receiving the kwarg, used in the warning.

Returns:
The (possibly modified) ``values`` argument.
"""
if not isinstance(values, dict):
return values
if old_name in values:
old_value = values.pop(old_name)
if new_name not in values:
values[new_name] = old_value
warnings.warn(
f"The '{old_name}' argument to {model} is deprecated and will be "
f"removed in {removed_in}. Use '{new_name}' instead.",
DeprecationWarning,
stacklevel=3,
)
return values


def module_deprecation_getattr(
*,
old_module: str,
target_module: str,
names: Iterable[str],
removed_in: str,
) -> Callable[[str], Any]:
"""
Build a module-level ``__getattr__`` that re-exports names from ``target_module``.

Each name in ``names`` is resolved from ``target_module`` on first access,
with a one-time ``DeprecationWarning`` per name. Attribute access for names
outside the configured set raises ``AttributeError``. Intended for use as
``__getattr__ = module_deprecation_getattr(...)`` in a shim module's
``__init__.py`` or top-level file.

Args:
old_module: The fully-qualified name of the deprecated module (the shim).
target_module: The fully-qualified name of the module to forward to.
names: The names to expose via the shim.
removed_in: The version in which the shim will be removed.

Returns:
A ``__getattr__`` function suitable for module-level assignment.
"""
name_set = frozenset(names)
warned: set[str] = set()

def __getattr__(name: str) -> Any: # noqa: N807 - module __getattr__ hook must use this name
if name not in name_set:
raise AttributeError(f"module {old_module!r} has no attribute {name!r}")
if name not in warned:
warned.add(name)
print_deprecation_message(
old_item=f"{old_module}.{name}",
new_item=f"{target_module}.{name}",
removed_in=removed_in,
)
module = importlib.import_module(target_module)
return getattr(module, name)

return __getattr__
10 changes: 9 additions & 1 deletion pyrit/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Public model exports for PyRIT core data structures and helpers."""
"""
Public model exports for PyRIT core data structures and helpers.

``pyrit.models`` is the canonical data layer. Files in this package must
import only from the standard library, ``pydantic``,
``pyrit.common.deprecation``, and other ``pyrit.models.*`` submodules. The
CI test ``tests/unit/models/test_import_boundary.py`` enforces this. See
``.github/instructions/models.instructions.md`` for the rule.
"""

from pyrit.models.attack_result import AttackOutcome, AttackResult, AttackResultT
from pyrit.models.chat_message import (
Expand Down
162 changes: 161 additions & 1 deletion tests/unit/common/test_deprecation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import types
import warnings

from pyrit.common.deprecation import print_deprecation_message
from pyrit.common.deprecation import (
deprecated_kwarg,
module_deprecation_getattr,
print_deprecation_message,
)


def _old_func():
Expand Down Expand Up @@ -59,3 +65,157 @@ def test_deprecation_warning_mixed_types():
assert len(w) == 1
assert "_OldClass" in str(w[0].message)
assert "some.new.path" in str(w[0].message)


# --- deprecated_kwarg ----------------------------------------------------


def test_deprecated_kwarg_promotes_old_to_new():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = deprecated_kwarg(
{"old": 42},
old_name="old",
new_name="new",
removed_in="9.9",
model="ExampleModel",
)
assert result == {"new": 42}
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "old" in str(w[0].message)
assert "new" in str(w[0].message)
assert "ExampleModel" in str(w[0].message)
assert "9.9" in str(w[0].message)


def test_deprecated_kwarg_noop_when_only_new_set():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = deprecated_kwarg(
{"new": 42},
old_name="old",
new_name="new",
removed_in="9.9",
model="ExampleModel",
)
assert result == {"new": 42}
assert len(w) == 0


def test_deprecated_kwarg_does_not_overwrite_new_when_both_set():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = deprecated_kwarg(
{"old": 42, "new": 7},
old_name="old",
new_name="new",
removed_in="9.9",
model="ExampleModel",
)
assert result == {"new": 7}
assert len(w) == 1


def test_deprecated_kwarg_passes_through_non_dict():
sentinel = object()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = deprecated_kwarg(
sentinel,
old_name="old",
new_name="new",
removed_in="9.9",
model="ExampleModel",
)
assert result is sentinel
assert len(w) == 0


# --- module_deprecation_getattr ------------------------------------------


def _make_target_module(name: str) -> types.ModuleType:
module = types.ModuleType(name)
module.exposed_value = 123
module.another_value = "hello"
sys.modules[name] = module
return module


def test_module_deprecation_getattr_resolves_and_warns_once():
target = _make_target_module("pyrit_tests_target_module_for_deprecation")
try:
getter = module_deprecation_getattr(
old_module="legacy.module",
target_module=target.__name__,
names=["exposed_value", "another_value"],
removed_in="9.9",
)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
first = getter("exposed_value")
second = getter("exposed_value") # repeat: no new warning
other = getter("another_value") # different name: warns once

assert first == 123
assert second == 123
assert other == "hello"
assert len(w) == 2
messages = [str(item.message) for item in w]
assert any("legacy.module.exposed_value" in m for m in messages)
assert any("legacy.module.another_value" in m for m in messages)
for item in w:
assert issubclass(item.category, DeprecationWarning)
assert "9.9" in str(item.message)
assert target.__name__ in str(item.message)
finally:
sys.modules.pop(target.__name__, None)


def test_module_deprecation_getattr_raises_for_unknown_name():
target = _make_target_module("pyrit_tests_target_module_for_deprecation_unknown")
try:
getter = module_deprecation_getattr(
old_module="legacy.module",
target_module=target.__name__,
names=["exposed_value"],
removed_in="9.9",
)
try:
getter("does_not_exist")
except AttributeError as exc:
assert "legacy.module" in str(exc)
assert "does_not_exist" in str(exc)
else:
raise AssertionError("Expected AttributeError")
finally:
sys.modules.pop(target.__name__, None)


def test_module_deprecation_getattr_warnings_isolated_per_factory():
"""Each call to the factory has its own one-time-warning state."""
target = _make_target_module("pyrit_tests_target_module_for_deprecation_isolated")
try:
getter_a = module_deprecation_getattr(
old_module="legacy.module.a",
target_module=target.__name__,
names=["exposed_value"],
removed_in="9.9",
)
getter_b = module_deprecation_getattr(
old_module="legacy.module.b",
target_module=target.__name__,
names=["exposed_value"],
removed_in="9.9",
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
getter_a("exposed_value")
getter_a("exposed_value") # no warning
getter_b("exposed_value") # warns once (separate factory)
getter_b("exposed_value") # no warning
assert len(w) == 2
finally:
sys.modules.pop(target.__name__, None)
Loading
Loading