Skip to content

Commit 3acbf0c

Browse files
committed
Refactor automark and its test
1 parent eb9738d commit 3acbf0c

File tree

2 files changed

+300
-765
lines changed

2 files changed

+300
-765
lines changed

scripts/update_lib/cmd_auto_mark.py

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -277,19 +277,10 @@ def _expand_stripped_to_children(
277277
class_bases, class_methods = _build_inheritance_info(tree)
278278

279279
for parent_cls, method_name in unmatched:
280-
# parent must actually define this method
281280
if method_name not in class_methods.get(parent_cls, set()):
282281
continue
283-
for cls in class_bases:
284-
if cls == parent_cls:
285-
continue
286-
if method_name in class_methods.get(cls, set()):
287-
continue
288-
if (
289-
_find_method_definition(cls, method_name, class_bases, class_methods)
290-
== parent_cls
291-
and (cls, method_name) in all_failing_tests
292-
):
282+
for cls in _find_all_inheritors(parent_cls, method_name, class_bases, class_methods):
283+
if (cls, method_name) in all_failing_tests:
293284
result.add((cls, method_name))
294285

295286
return result
@@ -331,16 +322,7 @@ def _consolidate_to_parent(
331322
new_error_messages = dict(error_messages) if error_messages else {}
332323

333324
for (parent, method_name), failing_children in groups.items():
334-
# Find ALL classes that inherit this method from parent
335-
all_inheritors: set[str] = set()
336-
for cls_name in class_bases:
337-
if cls_name == parent:
338-
continue
339-
# Skip if this class defines the method itself
340-
if method_name in class_methods.get(cls_name, set()):
341-
continue
342-
if _find_method_definition(cls_name, method_name, class_bases, class_methods) == parent:
343-
all_inheritors.add(cls_name)
325+
all_inheritors = _find_all_inheritors(parent, method_name, class_bases, class_methods)
344326

345327
if all_inheritors and failing_children >= all_inheritors:
346328
# All inheritors fail → mark on parent instead
@@ -400,6 +382,20 @@ def _is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bo
400382
return True
401383

402384

385+
def _method_removal_range(
386+
func_node: ast.FunctionDef | ast.AsyncFunctionDef, lines: list[str]
387+
) -> range:
388+
"""Line range covering an entire method including decorators and a preceding COMMENT line."""
389+
first = (
390+
func_node.decorator_list[0].lineno - 1
391+
if func_node.decorator_list
392+
else func_node.lineno - 1
393+
)
394+
if first > 0 and lines[first - 1].strip().startswith("#") and COMMENT in lines[first - 1]:
395+
first -= 1
396+
return range(first, func_node.end_lineno)
397+
398+
403399
def _build_inheritance_info(tree: ast.Module) -> tuple[dict, dict]:
404400
"""
405401
Build inheritance information from AST.
@@ -455,6 +451,20 @@ def _find_method_definition(
455451
return None
456452

457453

454+
def _find_all_inheritors(
455+
parent: str, method_name: str, class_bases: dict, class_methods: dict
456+
) -> set[str]:
457+
"""Find all classes that inherit *method_name* from *parent* (not overriding it)."""
458+
return {
459+
cls
460+
for cls in class_bases
461+
if cls != parent
462+
and method_name not in class_methods.get(cls, set())
463+
and _find_method_definition(cls, method_name, class_bases, class_methods)
464+
== parent
465+
}
466+
467+
458468
def remove_expected_failures(
459469
contents: str, tests_to_remove: set[tuple[str, str]]
460470
) -> str:
@@ -490,15 +500,7 @@ def remove_expected_failures(
490500
remove_entire_method = _is_super_call_only(item)
491501

492502
if remove_entire_method:
493-
first_line = item.lineno - 1
494-
if item.decorator_list:
495-
first_line = item.decorator_list[0].lineno - 1
496-
if first_line > 0:
497-
prev_line = lines[first_line - 1].strip()
498-
if prev_line.startswith("#") and COMMENT in prev_line:
499-
first_line -= 1
500-
for i in range(first_line, item.end_lineno):
501-
lines_to_remove.add(i)
503+
lines_to_remove.update(_method_removal_range(item, lines))
502504
else:
503505
for dec in item.decorator_list:
504506
dec_line = dec.lineno - 1
@@ -662,13 +664,7 @@ def strip_reasonless_expected_failures(
662664
# exists only to apply the decorator; without it
663665
# the override is pointless and blocks parent
664666
# consolidation)
665-
first_line = item.decorator_list[0].lineno - 1
666-
if first_line > 0:
667-
prev = lines[first_line - 1].strip()
668-
if prev.startswith("#") and COMMENT in prev:
669-
first_line -= 1
670-
for i in range(first_line, item.end_lineno):
671-
lines_to_remove.add(i)
667+
lines_to_remove.update(_method_removal_range(item, lines))
672668
else:
673669
lines_to_remove.add(dec_line)
674670

0 commit comments

Comments
 (0)