@@ -277,19 +277,10 @@ def _expand_stripped_to_children(
277277 class_bases , class_methods = _build_inheritance_info (tree )
278278
279279 for parent_cls , method_name in unmatched :
280- # parent must actually define this method
281280 if method_name not in class_methods .get (parent_cls , set ()):
282281 continue
283- for cls in class_bases :
284- if cls == parent_cls :
285- continue
286- if method_name in class_methods .get (cls , set ()):
287- continue
288- if (
289- _find_method_definition (cls , method_name , class_bases , class_methods )
290- == parent_cls
291- and (cls , method_name ) in all_failing_tests
292- ):
282+ for cls in _find_all_inheritors (parent_cls , method_name , class_bases , class_methods ):
283+ if (cls , method_name ) in all_failing_tests :
293284 result .add ((cls , method_name ))
294285
295286 return result
@@ -331,16 +322,7 @@ def _consolidate_to_parent(
331322 new_error_messages = dict (error_messages ) if error_messages else {}
332323
333324 for (parent , method_name ), failing_children in groups .items ():
334- # Find ALL classes that inherit this method from parent
335- all_inheritors : set [str ] = set ()
336- for cls_name in class_bases :
337- if cls_name == parent :
338- continue
339- # Skip if this class defines the method itself
340- if method_name in class_methods .get (cls_name , set ()):
341- continue
342- if _find_method_definition (cls_name , method_name , class_bases , class_methods ) == parent :
343- all_inheritors .add (cls_name )
325+ all_inheritors = _find_all_inheritors (parent , method_name , class_bases , class_methods )
344326
345327 if all_inheritors and failing_children >= all_inheritors :
346328 # All inheritors fail → mark on parent instead
@@ -400,6 +382,20 @@ def _is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bo
400382 return True
401383
402384
385+ def _method_removal_range (
386+ func_node : ast .FunctionDef | ast .AsyncFunctionDef , lines : list [str ]
387+ ) -> range :
388+ """Line range covering an entire method including decorators and a preceding COMMENT line."""
389+ first = (
390+ func_node .decorator_list [0 ].lineno - 1
391+ if func_node .decorator_list
392+ else func_node .lineno - 1
393+ )
394+ if first > 0 and lines [first - 1 ].strip ().startswith ("#" ) and COMMENT in lines [first - 1 ]:
395+ first -= 1
396+ return range (first , func_node .end_lineno )
397+
398+
403399def _build_inheritance_info (tree : ast .Module ) -> tuple [dict , dict ]:
404400 """
405401 Build inheritance information from AST.
@@ -455,6 +451,20 @@ def _find_method_definition(
455451 return None
456452
457453
454+ def _find_all_inheritors (
455+ parent : str , method_name : str , class_bases : dict , class_methods : dict
456+ ) -> set [str ]:
457+ """Find all classes that inherit *method_name* from *parent* (not overriding it)."""
458+ return {
459+ cls
460+ for cls in class_bases
461+ if cls != parent
462+ and method_name not in class_methods .get (cls , set ())
463+ and _find_method_definition (cls , method_name , class_bases , class_methods )
464+ == parent
465+ }
466+
467+
458468def remove_expected_failures (
459469 contents : str , tests_to_remove : set [tuple [str , str ]]
460470) -> str :
@@ -490,15 +500,7 @@ def remove_expected_failures(
490500 remove_entire_method = _is_super_call_only (item )
491501
492502 if remove_entire_method :
493- first_line = item .lineno - 1
494- if item .decorator_list :
495- first_line = item .decorator_list [0 ].lineno - 1
496- if first_line > 0 :
497- prev_line = lines [first_line - 1 ].strip ()
498- if prev_line .startswith ("#" ) and COMMENT in prev_line :
499- first_line -= 1
500- for i in range (first_line , item .end_lineno ):
501- lines_to_remove .add (i )
503+ lines_to_remove .update (_method_removal_range (item , lines ))
502504 else :
503505 for dec in item .decorator_list :
504506 dec_line = dec .lineno - 1
@@ -662,13 +664,7 @@ def strip_reasonless_expected_failures(
662664 # exists only to apply the decorator; without it
663665 # the override is pointless and blocks parent
664666 # consolidation)
665- first_line = item .decorator_list [0 ].lineno - 1
666- if first_line > 0 :
667- prev = lines [first_line - 1 ].strip ()
668- if prev .startswith ("#" ) and COMMENT in prev :
669- first_line -= 1
670- for i in range (first_line , item .end_lineno ):
671- lines_to_remove .add (i )
667+ lines_to_remove .update (_method_removal_range (item , lines ))
672668 else :
673669 lines_to_remove .add (dec_line )
674670
0 commit comments