diff --git a/build_scripts/gen_api_md.py b/build_scripts/gen_api_md.py index 890d5cbd3..700b6ed12 100644 --- a/build_scripts/gen_api_md.py +++ b/build_scripts/gen_api_md.py @@ -22,6 +22,11 @@ API_JSON_DIR = Path("doc/_api") API_MD_DIR = Path("doc/api") +# Modules excluded from generated API docs (internal implementation details) +EXCLUDED_MODULES = { + "pyrit.backend", +} + def render_params(params: list[dict]) -> str: """Render parameter list as a markdown table.""" @@ -88,6 +93,30 @@ def render_signature(member: dict) -> str: return f"({sig})" +def _escape_docstring_examples(text: str) -> str: + """Wrap doctest-style examples (>>> lines) in code fences.""" + lines = text.split("\n") + result: list[str] = [] + in_example = False + for line in lines: + stripped = line.strip() + if stripped.startswith(">>>") and not in_example: + in_example = True + result.append("```python") + result.append(line) + elif in_example and stripped.startswith((">>>", "...")): + result.append(line) + elif in_example: + result.append("```") + in_example = False + result.append(line) + else: + result.append(line) + if in_example: + result.append("```") + return "\n".join(result) + + def render_function(func: dict, heading_level: str = "###") -> str: """Render a function as markdown.""" name = func["name"] @@ -97,18 +126,13 @@ def render_function(func: dict, heading_level: str = "###") -> str: ret = func.get("returns_annotation", "") ret_str = f" → {ret}" if ret else "" - # Use heading for name, code block for full signature if long - full_sig = f"{prefix}{name}{sig}{ret_str}" - if len(full_sig) > 80: - parts = [f"{heading_level} {prefix}{name}\n"] - parts.append(f"```python\n{prefix}{name}{sig}{ret_str}\n```\n") - else: - parts = [f"{heading_level} `{full_sig}`\n"] + parts = [f"{heading_level} `{prefix}{name}`\n"] + parts.append(f"```python\n{prefix}{name}{sig}{ret_str}\n```\n") ds = func.get("docstring", {}) if ds: if ds.get("text"): - parts.append(ds["text"] + "\n") + parts.append(_escape_docstring_examples(ds["text"]) + "\n") params_table = render_params(ds.get("params", [])) if params_table: parts.append(params_table + "\n") @@ -128,11 +152,13 @@ def render_class(cls: dict) -> str: bases = cls.get("bases", []) bases_str = f"({', '.join(bases)})" if bases else "" - parts = [f"## `class {name}{bases_str}`\n"] + parts = [f"## `{name}`\n"] + if bases_str: + parts.append(f"Bases: `{bases_str[1:-1]}`\n") ds = cls.get("docstring", {}) if ds and ds.get("text"): - parts.append(ds["text"] + "\n") + parts.append(_escape_docstring_examples(ds["text"]) + "\n") # __init__ init = cls.get("init") @@ -151,6 +177,16 @@ def render_class(cls: dict) -> str: return "\n".join(parts) +def render_alias(alias: dict) -> str: + """Render an alias as markdown.""" + name = alias["name"] + target = alias.get("target", "") + parts = [f"### `{name}`\n"] + if target: + parts.append(f"Alias of `{target}`.\n") + return "\n".join(parts) + + def render_module(data: dict) -> str: """Render a full module page.""" mod_name = data["name"] @@ -162,10 +198,8 @@ def render_module(data: dict) -> str: members = data.get("members", []) - # Separate classes and functions classes = [m for m in members if m.get("kind") == "class"] functions = [m for m in members if m.get("kind") == "function"] - aliases = [m for m in members if m.get("kind") == "alias"] if functions: parts.append("## Functions\n") @@ -176,89 +210,181 @@ def render_module(data: dict) -> str: return "\n".join(parts) -def split_aggregate_json(api_json_dir: Path) -> None: - """Split aggregate JSON files that contain nested submodules into individual files. +def _build_definition_index( + data: dict, + index: dict | None = None, + name_to_modules: dict[str, list[str]] | None = None, +) -> tuple[dict, dict[str, list[str]]]: + """Build a flat lookup from fully-qualified name to member definition. + + Also builds a reverse lookup mapping each short member name to the list of + module paths where it is defined, so imports can be distinguished from native + definitions. + """ + if index is None: + index = {} + if name_to_modules is None: + name_to_modules = {} + mod_name = data.get("name", "") + for member in data.get("members", []): + kind = member.get("kind", "") + name = member.get("name", "") + if kind in ("class", "function") and name: + fqn = f"{mod_name}.{name}" if mod_name else name + index[fqn] = member + name_to_modules.setdefault(name, []).append(mod_name) + if kind == "module": + _build_definition_index(member, index, name_to_modules) + return index, name_to_modules + + +def _resolve_aliases(modules: list[dict], definition_index: dict, name_to_modules: dict[str, list[str]]) -> None: + """Replace bare alias entries with the full definition they point to. + + Aliases whose targets resolve to a class or function in the definition index + are swapped in-place so they render with full documentation. Unresolvable + aliases that appear to reference a pyrit class (capitalized name with a + pyrit target) are kept as minimal class stubs. Aliases pointing outside the + pyrit namespace are dropped. + + Also removes classes/functions that griffe reports as direct members but are + actually imported from a different pyrit module (the same short name is + defined in another module in the index). + """ + for module in modules: + mod_name = module.get("name", "") + resolved_members: list[dict] = [] + for member in module.get("members", []): + kind = member.get("kind", "") + name = member.get("name", "") + + if kind == "alias": + target = member.get("target", "") + if not target.startswith("pyrit."): + continue # External import (stdlib, third-party) – skip + if target in definition_index: + defn = definition_index[target].copy() + defn["name"] = name + resolved_members.append(defn) + elif name and name[0].isupper(): + resolved_members.append({"name": name, "kind": "class"}) + elif kind in ("class", "function"): + # Keep only if this module's tree contains a definition. + # A member defined in this module or its children is native; + # appearances in unrelated modules are just imports. + defining_modules = name_to_modules.get(name, []) + is_native = not defining_modules or any( + m == mod_name or m.startswith(mod_name + ".") for m in defining_modules + ) + if is_native: + resolved_members.append(member) + else: + resolved_members.append(member) + + module["members"] = resolved_members + + +def _expand_module(module: dict) -> list[dict]: + """Recursively expand pure-aggregate modules into their children. + + A pure-aggregate module has only submodule members and no direct public API + (classes, functions, aliases). Its children are returned instead, recursing + further if a child is also a pure aggregate. + """ + members = module.get("members", []) + has_api = any(m.get("kind") in ("class", "function", "alias") for m in members) + submodules = [m for m in members if m.get("kind") == "module"] + + if has_api or not submodules: + # Module has its own API, or is a leaf – keep it (filter empty later) + return [module] + + # Pure aggregate – recurse into children + result: list[dict] = [] + for sub in submodules: + result.extend(_expand_module(sub)) + return result + + +def collect_top_level_modules(api_json_dir: Path) -> list[dict]: + """Collect top-level modules from aggregate JSON files. When pydoc2json.py runs with --submodules, it produces a single JSON file - (e.g. pyrit_all.json) whose members are submodules. This function recursively - splits those nested submodules into individual JSON files so that each - submodule gets its own API reference page. + (e.g. pyrit_all.json) whose members are submodules. We only generate pages + for the public packages users import from, not for deeply nested internal + submodules whose content is re-exported by the parent. + + Pure-aggregate modules (those with only submodule members) are recursively + expanded so their children with real API surface get their own pages. """ + modules: list[dict] = [] for jf in sorted(api_json_dir.glob("*.json")): data = json.loads(jf.read_text(encoding="utf-8")) - _split_submodules(data, jf.name, api_json_dir) + modules.extend(_expand_module(data)) - -def _split_submodules(data: dict, source_name: str, api_json_dir: Path) -> None: - """Recursively extract and write submodule members to individual JSON files.""" - for member in data.get("members", []): - if member.get("kind") != "module": - continue - sub_name = member["name"] - sub_path = api_json_dir / f"{sub_name}.json" - if not sub_path.exists(): - sub_path.write_text(json.dumps(member, indent=2, default=str), encoding="utf-8") - print(f"Split {sub_name} from {source_name}") - # Recurse into nested submodules - _split_submodules(member, source_name, api_json_dir) + # Drop excluded and empty modules + return [ + m + for m in modules + if not any(m.get("name", "").startswith(ex) for ex in EXCLUDED_MODULES) + and any(member.get("kind") in ("class", "function", "alias") for member in m.get("members", [])) + ] def main() -> None: API_MD_DIR.mkdir(parents=True, exist_ok=True) - # Split aggregate JSON files (e.g. pyrit_all.json) into per-module files - split_aggregate_json(API_JSON_DIR) - - # Exclude aggregate files that only contain submodules (no direct classes/functions) json_files = sorted(API_JSON_DIR.glob("*.json")) if not json_files: print("No JSON files found in", API_JSON_DIR) return - # Collect module data, skipping pure-aggregate files - modules = [] + modules = collect_top_level_modules(API_JSON_DIR) + + # Build a lookup of all definitions and resolve aliases to their targets + definition_index: dict = {} + name_to_modules: dict[str, list[str]] = {} for jf in json_files: data = json.loads(jf.read_text(encoding="utf-8")) + _build_definition_index(data, definition_index, name_to_modules) + _resolve_aliases(modules, definition_index, name_to_modules) + + # Generate per-module pages + for data in modules: + mod_name = data["name"] + slug = mod_name.replace(".", "_") + md_path = API_MD_DIR / f"{slug}.md" + content = render_module(data) members = data.get("members", []) - # Skip files whose members are all submodules (aggregates like pyrit_all.json) - non_module_members = [m for m in members if m.get("kind") != "module"] - if not non_module_members and any(m.get("kind") == "module" for m in members): - continue - modules.append(data) + rendered_count = sum(1 for m in members if m.get("kind") in ("class", "function")) + md_path.write_text(content, encoding="utf-8") + print(f"Written {md_path} ({rendered_count} members)") # Generate index page index_parts = ["# API Reference\n"] for data in modules: mod_name = data["name"] members = data.get("members", []) - member_count = len(members) slug = mod_name.replace(".", "_") - classes = [m["name"] for m in members if m.get("kind") == "class"][:8] - preview = ", ".join(f"`{c}`" for c in classes) - if len(classes) < member_count: - preview += f" ... ({member_count} total)" + + classes = [f"`{m['name']}`" for m in members if m.get("kind") == "class"] + functions = [f"`{m['name']}()`" for m in members if m.get("kind") == "function"] + rendered_count = len(classes) + len(functions) + preview_items = (classes + functions)[:8] + preview = ", ".join(preview_items) + if rendered_count > len(preview_items): + preview += f" ... ({rendered_count} total)" + index_parts.append(f"## [{mod_name}]({slug}.md)\n") if preview: index_parts.append(preview + "\n") + else: + index_parts.append("_No public API members detected._\n") index_path = API_MD_DIR / "index.md" index_path.write_text("\n".join(index_parts), encoding="utf-8") print(f"Written {index_path}") - # Generate per-module pages - for data in modules: - mod_name = data["name"] - members = data.get("members", []) - # Skip modules with no members and no meaningful docstring - ds_text = (data.get("docstring") or {}).get("text", "") - if not members and len(ds_text) < 50: - continue - slug = mod_name.replace(".", "_") - md_path = API_MD_DIR / f"{slug}.md" - content = render_module(data) - md_path.write_text(content, encoding="utf-8") - print(f"Written {md_path} ({len(members)} members)") - if __name__ == "__main__": main() diff --git a/build_scripts/pydoc2json.py b/build_scripts/pydoc2json.py index 8a053869e..1ae100f3b 100644 --- a/build_scripts/pydoc2json.py +++ b/build_scripts/pydoc2json.py @@ -141,6 +141,33 @@ def class_to_dict(cls: griffe.Class) -> dict: return result +def _resolve_alias_from_source(target_path: str) -> dict | None: + """Try to resolve an unresolvable alias by loading the target .py file directly. + + When griffe cannot resolve an alias (e.g. due to missing __init__.py in + namespace packages), fall back to parsing the individual source file and + extracting the class or function definition. + """ + parts = target_path.rsplit(".", 1) + if len(parts) != 2: + return None + module_path, member_name = parts + source_file = Path(module_path.replace(".", "/") + ".py") + if not source_file.exists(): + return None + try: + code = source_file.read_text(encoding="utf-8") + file_mod = griffe.visit(module_path, code=code, filepath=source_file) + member = file_mod.members.get(member_name) + if isinstance(member, griffe.Class): + return class_to_dict(member) + if isinstance(member, griffe.Function): + return function_to_dict(member) + except Exception: + pass + return None + + def module_to_dict(mod: griffe.Module, include_submodules: bool = False) -> dict: """Convert a griffe Module to a structured dict.""" result = { @@ -167,8 +194,13 @@ def module_to_dict(mod: griffe.Module, include_submodules: bool = False) -> dict elif isinstance(target, griffe.Function): result["members"].append(function_to_dict(target)) except Exception: - # Unresolvable alias — just record the name - result["members"].append({"name": name, "kind": "alias", "target": str(member.target_path)}) + # Griffe cannot resolve (e.g. namespace package) — try source file + resolved = _resolve_alias_from_source(str(member.target_path)) + if resolved: + resolved["name"] = name + result["members"].append(resolved) + else: + result["members"].append({"name": name, "kind": "alias", "target": str(member.target_path)}) elif isinstance(member, griffe.Module) and include_submodules: result["members"].append(module_to_dict(member, include_submodules=True)) except Exception as e: diff --git a/doc/myst.yml b/doc/myst.yml index c1f28dd7e..aaac88d9f 100644 --- a/doc/myst.yml +++ b/doc/myst.yml @@ -167,46 +167,32 @@ project: children: - file: api/pyrit_analytics.md - file: api/pyrit_auth.md + - file: api/pyrit_cli_frontend_core.md + - file: api/pyrit_cli_pyrit_backend.md + - file: api/pyrit_cli_pyrit_scan.md + - file: api/pyrit_cli_pyrit_shell.md - file: api/pyrit_common.md + - file: api/pyrit_datasets.md - file: api/pyrit_embedding.md - file: api/pyrit_exceptions.md - file: api/pyrit_executor_attack.md - children: - - file: api/pyrit_executor_attack_single_turn.md - - file: api/pyrit_executor_attack_multi_turn.md - - file: api/pyrit_executor_attack_core.md - - file: api/pyrit_executor_attack_component.md - - file: api/pyrit_executor_attack_printer.md - file: api/pyrit_executor_benchmark.md - file: api/pyrit_executor_core.md - file: api/pyrit_executor_promptgen.md - children: - - file: api/pyrit_executor_promptgen_core.md - - file: api/pyrit_executor_promptgen_fuzzer.md - file: api/pyrit_executor_workflow.md - file: api/pyrit_identifiers.md - file: api/pyrit_memory.md - file: api/pyrit_message_normalizer.md - file: api/pyrit_models.md - children: - - file: api/pyrit_models_seeds.md - file: api/pyrit_prompt_converter.md - file: api/pyrit_prompt_normalizer.md - file: api/pyrit_prompt_target.md - file: api/pyrit_registry.md - children: - - file: api/pyrit_registry_class_registries.md - - file: api/pyrit_registry_instance_registries.md - file: api/pyrit_scenario.md - children: - - file: api/pyrit_scenario_core.md - - file: api/pyrit_scenario_scenarios_airt.md - - file: api/pyrit_scenario_scenarios_foundry.md - - file: api/pyrit_scenario_scenarios_garak.md - file: api/pyrit_score.md - file: api/pyrit_setup.md - children: - - file: api/pyrit_setup_initializers.md + - file: api/pyrit_show_versions.md + - file: api/pyrit_ui.md - file: blog/README.md children: - file: blog/2025_06_06.md diff --git a/pyrit/cli/banner.py b/pyrit/cli/_banner.py similarity index 99% rename from pyrit/cli/banner.py rename to pyrit/cli/_banner.py index 243267ff5..859cb107a 100644 --- a/pyrit/cli/banner.py +++ b/pyrit/cli/_banner.py @@ -23,7 +23,7 @@ from enum import Enum from typing import Optional -from pyrit.cli.banner_assets import BRAILLE_RACCOON, PYRIT_LETTERS, PYRIT_WIDTH, RACCOON_TAIL +from pyrit.cli._banner_assets import BRAILLE_RACCOON, PYRIT_LETTERS, PYRIT_WIDTH, RACCOON_TAIL class ColorRole(Enum): diff --git a/pyrit/cli/banner_assets.py b/pyrit/cli/_banner_assets.py similarity index 100% rename from pyrit/cli/banner_assets.py rename to pyrit/cli/_banner_assets.py diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index c07857bca..3f84eff34 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -20,7 +20,8 @@ if TYPE_CHECKING: from pyrit.models.scenario_result import ScenarioResult -from pyrit.cli import banner, frontend_core +from pyrit.cli import _banner as banner +from pyrit.cli import frontend_core class PyRITShell(cmd.Cmd): diff --git a/tests/unit/cli/test_banner.py b/tests/unit/cli/test_banner.py index e6a028afc..d7667b1a4 100644 --- a/tests/unit/cli/test_banner.py +++ b/tests/unit/cli/test_banner.py @@ -4,7 +4,7 @@ import os from unittest.mock import MagicMock, patch -from pyrit.cli.banner import ( +from pyrit.cli._banner import ( ANSI_COLORS, DARK_THEME, LIGHT_THEME, @@ -167,7 +167,7 @@ def test_no_animation_returns_static(self) -> None: assert "Python Risk Identification Tool" in result def test_no_animation_when_not_tty(self) -> None: - with patch("pyrit.cli.banner.can_animate", return_value=False): + with patch("pyrit.cli._banner.can_animate", return_value=False): result = play_animation() assert "Python Risk Identification Tool" in result @@ -176,10 +176,10 @@ def test_animation_writes_frames_to_stdout(self) -> None: mock_stdout.isatty.return_value = True with ( - patch("pyrit.cli.banner.can_animate", return_value=True), - patch("pyrit.cli.banner._detect_theme", return_value=DARK_THEME), - patch("pyrit.cli.banner.time.sleep"), - patch("pyrit.cli.banner.sys.stdout", mock_stdout), + patch("pyrit.cli._banner.can_animate", return_value=True), + patch("pyrit.cli._banner._detect_theme", return_value=DARK_THEME), + patch("pyrit.cli._banner.time.sleep"), + patch("pyrit.cli._banner.sys.stdout", mock_stdout), ): result = play_animation() @@ -203,10 +203,10 @@ def sleep_then_interrupt(duration: float) -> None: raise KeyboardInterrupt with ( - patch("pyrit.cli.banner.can_animate", return_value=True), - patch("pyrit.cli.banner._detect_theme", return_value=DARK_THEME), - patch("pyrit.cli.banner.time.sleep", side_effect=sleep_then_interrupt), - patch("pyrit.cli.banner.sys.stdout", mock_stdout), + patch("pyrit.cli._banner.can_animate", return_value=True), + patch("pyrit.cli._banner._detect_theme", return_value=DARK_THEME), + patch("pyrit.cli._banner.time.sleep", side_effect=sleep_then_interrupt), + patch("pyrit.cli._banner.sys.stdout", mock_stdout), ): result = play_animation() diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index e85ed354a..c70aa43a5 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -9,7 +9,8 @@ from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch -from pyrit.cli import banner, pyrit_shell +from pyrit.cli import _banner as banner +from pyrit.cli import pyrit_shell class TestPyRITShell: @@ -43,7 +44,7 @@ def test_prompt_and_intro(self): # Verify that cmdloop calls play_animation and passes the result as intro with ( - patch("pyrit.cli.banner.play_animation", return_value="TEST_BANNER") as mock_play, + patch("pyrit.cli._banner.play_animation", return_value="TEST_BANNER") as mock_play, patch("cmd.Cmd.cmdloop") as mock_cmdloop, ): shell.cmdloop() @@ -58,7 +59,7 @@ def test_cmdloop_honors_explicit_intro(self): shell = pyrit_shell.PyRITShell(context=mock_context) - with patch("pyrit.cli.banner.play_animation") as mock_play, patch("cmd.Cmd.cmdloop") as mock_cmdloop: + with patch("pyrit.cli._banner.play_animation") as mock_play, patch("cmd.Cmd.cmdloop") as mock_cmdloop: shell.cmdloop(intro="Custom intro") mock_play.assert_not_called()