diff --git a/src/fromager/commands/list_overrides.py b/src/fromager/commands/list_overrides.py index 7d46227f..dca1ca5f 100644 --- a/src/fromager/commands/list_overrides.py +++ b/src/fromager/commands/list_overrides.py @@ -9,6 +9,8 @@ from rich.table import Table from fromager import clickext, context +from fromager.hooks import GLOBAL_HOOK_NAMES +from fromager.overrides import OVERRIDE_HOOK_NAMES from fromager.packagesettings import PatchMap @@ -63,6 +65,7 @@ def list_overrides( variants = sorted(wkctx.settings.all_variants()) variant_names = [str(v) for v in variants] export_data = [] + all_hook_names = GLOBAL_HOOK_NAMES + OVERRIDE_HOOK_NAMES for name in overridden_packages: pbi = wkctx.settings.package_build_info(name) @@ -70,24 +73,7 @@ def list_overrides( plugin_hooks: list[str] = [] if pbi.plugin: - for hook in [ - # from hooks.py - "post_build", - "post_bootstrap", - "prebuilt_wheel", - # from overrides.py, found by searching for find_override_method - "download_source", - "get_resolver_provider", - "prepare_source", - "build_sdist", - "build_wheel", - "get_build_requirements", - "get_build_sdist_requirements", - "get_build_wheel_requirements", - "expected_source_archive_name", - "expected_source_directory_name", - "add_extra_metadata_to_wheels", - ]: + for hook in all_hook_names: if hasattr(pbi.plugin, hook): plugin_hooks.append(hook) plugin_hooks_str = ", ".join(plugin_hooks) diff --git a/src/fromager/hooks.py b/src/fromager/hooks.py index 5acabbec..f5096d9e 100644 --- a/src/fromager/hooks.py +++ b/src/fromager/hooks.py @@ -16,6 +16,13 @@ _mgrs: dict[str, hook.HookManager] = {} +# Event callbacks that run for every package (e.g., after build, after bootstrap). +GLOBAL_HOOK_NAMES: tuple[str, ...] = ( + "post_bootstrap", + "post_build", + "prebuilt_wheel", +) + logger = logging.getLogger(__name__) diff --git a/src/fromager/overrides.py b/src/fromager/overrides.py index 2a95b1a3..7d4d02a0 100644 --- a/src/fromager/overrides.py +++ b/src/fromager/overrides.py @@ -17,6 +17,23 @@ _mgr: extension.ExtensionManager | None = None +# Hooks that per-package plugins can implement to override default build behavior. +OVERRIDE_HOOK_NAMES: tuple[str, ...] = ( + "add_extra_metadata_to_wheels", + "build_sdist", + "build_wheel", + "download_source", + "expected_source_archive_name", + "expected_source_directory_name", + "get_build_backend_dependencies", + "get_build_sdist_dependencies", + "get_build_system_dependencies", + "get_install_dependencies_of_sdist", + "get_resolver_provider", + "prepare_source", + "update_extra_environ", +) + def _get_extensions() -> extension.ExtensionManager: global _mgr diff --git a/tests/test_override_hook_names.py b/tests/test_override_hook_names.py new file mode 100644 index 00000000..cd4094ae --- /dev/null +++ b/tests/test_override_hook_names.py @@ -0,0 +1,87 @@ +"""Verify that hook name constants stay in sync with actual usage. + +Uses Python's Abstract Syntax Tree (AST) module to parse source files and +find every string literal passed to hook-dispatch functions, then checks +that the declared constants match. +""" + +import ast +import pathlib + +from fromager.hooks import GLOBAL_HOOK_NAMES +from fromager.overrides import OVERRIDE_HOOK_NAMES + +SRC_DIR = pathlib.Path(__file__).parent.parent / "src" / "fromager" + + +def _called_function_name(node: ast.Call) -> str | None: + """Return the simple name of the called function, or None.""" + if isinstance(node.func, ast.Name): + return node.func.id + if isinstance(node.func, ast.Attribute): + return node.func.attr + return None + + +def _collect_string_arg( + source_files: list[pathlib.Path], + func_names: set[str], + arg_index: int, +) -> set[str]: + """Find every string literal passed at ``arg_index`` to calls of ``func_names``. + + Scans the AST of each file for calls like ``func("hook_name", ...)`` + and returns the set of string values found at the given position. + """ + found: set[str] = set() + for path in source_files: + tree = ast.parse(path.read_text(), filename=str(path)) + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + if _called_function_name(node) not in func_names: + continue + if len(node.args) <= arg_index: + continue + arg = node.args[arg_index] + if isinstance(arg, ast.Constant) and isinstance(arg.value, str): + found.add(arg.value) + return found + + +def test_override_hook_names_match_usage() -> None: + """OVERRIDE_HOOK_NAMES must list every hook passed to + find_override_method / find_and_invoke across the source tree.""" + source_files = [ + p + for p in SRC_DIR.rglob("*.py") + if p.name != "overrides.py" # skip the forwarding call (uses a variable) + ] + used = _collect_string_arg( + source_files, + {"find_and_invoke", "find_override_method"}, + arg_index=1, + ) + registered = set(OVERRIDE_HOOK_NAMES) + missing = used - registered + extra = registered - used + assert not missing, ( + f"Hooks used in source but missing from OVERRIDE_HOOK_NAMES: {missing}" + ) + assert not extra, f"Hooks in OVERRIDE_HOOK_NAMES but not used in source: {extra}" + + +def test_global_hook_names_match_usage() -> None: + """GLOBAL_HOOK_NAMES must list every hook passed to _get_hooks in hooks.py.""" + used = _collect_string_arg( + [SRC_DIR / "hooks.py"], + {"_get_hooks"}, + arg_index=0, + ) + registered = set(GLOBAL_HOOK_NAMES) + missing = used - registered + extra = registered - used + assert not missing, ( + f"Hooks used in hooks.py but missing from GLOBAL_HOOK_NAMES: {missing}" + ) + assert not extra, f"Hooks in GLOBAL_HOOK_NAMES but not used in hooks.py: {extra}"