Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions src/fromager/commands/list_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from rich.table import Table

from fromager import clickext, context
from fromager.hooks import GLOBAL_HOOK_NAMES
from fromager.overrides import OVERRIDE_HOOK_NAMES
from fromager.packagesettings import PatchMap


Expand Down Expand Up @@ -63,31 +65,15 @@ def list_overrides(
variants = sorted(wkctx.settings.all_variants())
variant_names = [str(v) for v in variants]
export_data = []
all_hook_names = GLOBAL_HOOK_NAMES + OVERRIDE_HOOK_NAMES

for name in overridden_packages:
pbi = wkctx.settings.package_build_info(name)
ps = wkctx.settings.package_setting(name)

plugin_hooks: list[str] = []
if pbi.plugin:
for hook in [
# from hooks.py
"post_build",
"post_bootstrap",
"prebuilt_wheel",
# from overrides.py, found by searching for find_override_method
"download_source",
"get_resolver_provider",
"prepare_source",
"build_sdist",
"build_wheel",
"get_build_requirements",
"get_build_sdist_requirements",
"get_build_wheel_requirements",
"expected_source_archive_name",
"expected_source_directory_name",
"add_extra_metadata_to_wheels",
]:
for hook in all_hook_names:
if hasattr(pbi.plugin, hook):
plugin_hooks.append(hook)
plugin_hooks_str = ", ".join(plugin_hooks)
Expand Down
7 changes: 7 additions & 0 deletions src/fromager/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@

_mgrs: dict[str, hook.HookManager] = {}

# Event callbacks that run for every package (e.g., after build, after bootstrap).
GLOBAL_HOOK_NAMES: tuple[str, ...] = (
"post_bootstrap",
"post_build",
"prebuilt_wheel",
)

logger = logging.getLogger(__name__)


Expand Down
17 changes: 17 additions & 0 deletions src/fromager/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@

_mgr: extension.ExtensionManager | None = None

# Hooks that per-package plugins can implement to override default build behavior.
OVERRIDE_HOOK_NAMES: tuple[str, ...] = (
"add_extra_metadata_to_wheels",
"build_sdist",
"build_wheel",
"download_source",
"expected_source_archive_name",
"expected_source_directory_name",
"get_build_backend_dependencies",
"get_build_sdist_dependencies",
"get_build_system_dependencies",
"get_install_dependencies_of_sdist",
"get_resolver_provider",
"prepare_source",
"update_extra_environ",
)


def _get_extensions() -> extension.ExtensionManager:
global _mgr
Expand Down
87 changes: 87 additions & 0 deletions tests/test_override_hook_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Verify that hook name constants stay in sync with actual usage.

Uses Python's Abstract Syntax Tree (AST) module to parse source files and
find every string literal passed to hook-dispatch functions, then checks
that the declared constants match.
"""

import ast
import pathlib

from fromager.hooks import GLOBAL_HOOK_NAMES
from fromager.overrides import OVERRIDE_HOOK_NAMES

SRC_DIR = pathlib.Path(__file__).parent.parent / "src" / "fromager"


def _called_function_name(node: ast.Call) -> str | None:
"""Return the simple name of the called function, or None."""
if isinstance(node.func, ast.Name):
return node.func.id
if isinstance(node.func, ast.Attribute):
return node.func.attr
return None


def _collect_string_arg(
source_files: list[pathlib.Path],
func_names: set[str],
arg_index: int,
) -> set[str]:
"""Find every string literal passed at ``arg_index`` to calls of ``func_names``.

Scans the AST of each file for calls like ``func("hook_name", ...)``
and returns the set of string values found at the given position.
"""
found: set[str] = set()
for path in source_files:
tree = ast.parse(path.read_text(), filename=str(path))
for node in ast.walk(tree):
if not isinstance(node, ast.Call):
continue
if _called_function_name(node) not in func_names:
continue
if len(node.args) <= arg_index:
continue
arg = node.args[arg_index]
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
found.add(arg.value)
return found


def test_override_hook_names_match_usage() -> None:
"""OVERRIDE_HOOK_NAMES must list every hook passed to
find_override_method / find_and_invoke across the source tree."""
source_files = [
p
for p in SRC_DIR.rglob("*.py")
if p.name != "overrides.py" # skip the forwarding call (uses a variable)
]
used = _collect_string_arg(
source_files,
{"find_and_invoke", "find_override_method"},
arg_index=1,
)
registered = set(OVERRIDE_HOOK_NAMES)
missing = used - registered
extra = registered - used
assert not missing, (
f"Hooks used in source but missing from OVERRIDE_HOOK_NAMES: {missing}"
)
assert not extra, f"Hooks in OVERRIDE_HOOK_NAMES but not used in source: {extra}"


def test_global_hook_names_match_usage() -> None:
"""GLOBAL_HOOK_NAMES must list every hook passed to _get_hooks in hooks.py."""
used = _collect_string_arg(
[SRC_DIR / "hooks.py"],
{"_get_hooks"},
arg_index=0,
)
registered = set(GLOBAL_HOOK_NAMES)
missing = used - registered
extra = registered - used
assert not missing, (
f"Hooks used in hooks.py but missing from GLOBAL_HOOK_NAMES: {missing}"
)
assert not extra, f"Hooks in GLOBAL_HOOK_NAMES but not used in hooks.py: {extra}"
Loading