Skip to content

Commit e0f9421

Browse files
authored
[test]: implement AST scanning for operator classes and validate registry entries (#500)
1 parent c336376 commit e0f9421

1 file changed

Lines changed: 158 additions & 0 deletions

File tree

test/cpu_only/test_register.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,164 @@ def _iter_annotation_types(ann):
337337
if errors:
338338
pytest.fail("\n".join(errors), pytrace=False)
339339

340+
def _ast_has_register_decorator(node: "ast.ClassDef") -> bool:
341+
"""Check whether *node* carries ``@OPERATOR_REGISTRY.register()``."""
342+
import ast
343+
for deco in node.decorator_list:
344+
# @OPERATOR_REGISTRY.register()
345+
if (isinstance(deco, ast.Call)
346+
and isinstance(deco.func, ast.Attribute)
347+
and deco.func.attr == "register"
348+
and isinstance(deco.func.value, ast.Name)
349+
and deco.func.value.id == "OPERATOR_REGISTRY"):
350+
return True
351+
# @OPERATOR_REGISTRY.register (no parentheses)
352+
if (isinstance(deco, ast.Attribute)
353+
and deco.attr == "register"
354+
and isinstance(deco.value, ast.Name)
355+
and deco.value.id == "OPERATOR_REGISTRY"):
356+
return True
357+
return False
358+
359+
360+
def _ast_base_names(node: "ast.ClassDef"):
361+
"""Return the set of simple base-class names for *node*."""
362+
import ast
363+
names = set()
364+
for base in node.bases:
365+
if isinstance(base, ast.Name):
366+
names.add(base.id)
367+
elif isinstance(base, ast.Attribute):
368+
names.add(base.attr)
369+
return names
370+
371+
372+
def _scan_operator_classes(operators_dir):
373+
"""
374+
Two-pass AST scan of ``dataflow/operators/``.
375+
376+
Pass 1 — collect intermediate ABC names (class names ending with ``ABC``
377+
that inherit from ``OperatorABC`` or another intermediate ABC).
378+
379+
Pass 2 — collect every *concrete* operator class, i.e. a class that
380+
either carries ``@OPERATOR_REGISTRY.register()`` **or** inherits from
381+
``OperatorABC`` / an intermediate ABC, while its own name does **not**
382+
end with ``ABC``.
383+
384+
Returns
385+
-------
386+
dict {class_name: (rel_path, has_decorator, has_base)}
387+
"""
388+
import ast
389+
from pathlib import Path
390+
391+
operators_dir = Path(operators_dir)
392+
project_root = operators_dir.parent.parent
393+
394+
file_trees = []
395+
for py_file in sorted(operators_dir.rglob("*.py")):
396+
if py_file.name == "__init__.py":
397+
continue
398+
try:
399+
source = py_file.read_text(encoding="utf-8")
400+
tree = ast.parse(source)
401+
except (SyntaxError, UnicodeDecodeError):
402+
continue
403+
rel = py_file.relative_to(project_root).as_posix()
404+
file_trees.append((rel, tree))
405+
406+
# --- pass 1: intermediate ABCs ---
407+
operator_bases = {"OperatorABC"}
408+
changed = True
409+
while changed:
410+
changed = False
411+
for _rel, tree in file_trees:
412+
for node in ast.walk(tree):
413+
if not isinstance(node, ast.ClassDef):
414+
continue
415+
if not node.name.endswith("ABC"):
416+
continue
417+
if node.name in operator_bases:
418+
continue
419+
if _ast_base_names(node) & operator_bases:
420+
operator_bases.add(node.name)
421+
changed = True
422+
423+
# --- pass 2: concrete operator classes ---
424+
result = {}
425+
for rel, tree in file_trees:
426+
for node in ast.walk(tree):
427+
if not isinstance(node, ast.ClassDef):
428+
continue
429+
if node.name.endswith("ABC"):
430+
continue
431+
has_deco = _ast_has_register_decorator(node)
432+
has_base = bool(_ast_base_names(node) & operator_bases)
433+
if has_deco or has_base:
434+
result[node.name] = (rel, has_deco, has_base)
435+
return result
436+
437+
438+
@pytest.mark.cpu
439+
def test_no_operator_missing_from_lazyload():
440+
"""
441+
AST-scan ``dataflow/operators/`` for concrete operator classes (identified
442+
by ``@OPERATOR_REGISTRY.register()`` decorator **or** inheritance from
443+
``OperatorABC`` / intermediate ABCs), then verify every one of them is
444+
present in the registry after ``_get_all()``.
445+
446+
Catches two failure modes:
447+
A. Decorator present, but class not listed in ``__init__.py``
448+
``TYPE_CHECKING`` block → LazyLoad never loads the file.
449+
B. Inherits from ``OperatorABC`` but has **neither** the decorator
450+
**nor** a LazyLoad entry → completely invisible to the framework.
451+
"""
452+
from pathlib import Path
453+
import dataflow
454+
455+
operators_dir = Path(dataflow.__file__).parent / "operators"
456+
ast_classes = _scan_operator_classes(operators_dir)
457+
458+
assert ast_classes, (
459+
"AST scan found zero concrete operator classes — check scan logic."
460+
)
461+
print(f"\n[AST] Found {len(ast_classes)} concrete operator classes")
462+
463+
# --- trigger full LazyLoad, snapshot registry ---
464+
OPERATOR_REGISTRY._get_all()
465+
registered = set(OPERATOR_REGISTRY.get_obj_map().keys())
466+
print(f"[Registry] {len(registered)} operators registered after _get_all()")
467+
468+
# --- diff ---
469+
missing = {
470+
name: info for name, info in ast_classes.items()
471+
if name not in registered
472+
}
473+
474+
if missing:
475+
lines = []
476+
for name, (path, has_deco, has_base) in sorted(missing.items()):
477+
if has_deco and not has_base:
478+
reason = "has @register but missing from __init__.py TYPE_CHECKING"
479+
elif has_base and not has_deco:
480+
reason = ("inherits OperatorABC but MISSING @OPERATOR_REGISTRY.register() "
481+
"AND __init__.py TYPE_CHECKING entry")
482+
else:
483+
reason = "has @register but missing from __init__.py TYPE_CHECKING"
484+
lines.append(f" - {name} -> {path}\n reason: {reason}")
485+
detail = "\n".join(lines)
486+
pytest.fail(
487+
f"\n{len(missing)} operator class(es) defined but NOT in the registry:\n\n"
488+
f"{detail}\n\n"
489+
f"Fix: 1) add @OPERATOR_REGISTRY.register() on the class (if missing),\n"
490+
f" 2) add the import to the corresponding __init__.py "
491+
f"`if TYPE_CHECKING:` block.",
492+
pytrace=False,
493+
)
494+
495+
print(f"[PASS] All {len(ast_classes)} concrete operator classes are in the registry.")
496+
497+
340498
if __name__ == "__main__":
341499
# 全局table,看所有注册的算子的str名称和对应的module路径
342500
# 获得所有算子的类名2class映射

0 commit comments

Comments
 (0)