diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 3e93ce163..a1eaa7513 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -577,7 +577,7 @@ def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.M def collect_existing_class_names(tree: ast.Module) -> set[str]: class_names = set() - stack = [tree] + stack: list[ast.AST] = [tree] while stack: node = stack.pop() @@ -586,27 +586,14 @@ def collect_existing_class_names(tree: ast.Module) -> set[str]: class_names.add(node.name) # Only traverse nodes that can contain ClassDef nodes - if isinstance( - node, - ( - ast.Module, - ast.ClassDef, - ast.FunctionDef, - ast.AsyncFunctionDef, - ast.If, - ast.For, - ast.AsyncFor, - ast.While, - ast.With, - ast.AsyncWith, - ast.Try, - ast.ExceptHandler, - ), - ): - stack.extend(getattr(node, "body", [])) - stack.extend(getattr(node, "orelse", [])) - stack.extend(getattr(node, "finalbody", [])) - stack.extend(getattr(node, "handlers", [])) + if hasattr(node, "body"): + stack.extend(node.body) + if hasattr(node, "orelse"): + stack.extend(node.orelse) + if hasattr(node, "finalbody"): + stack.extend(node.finalbody) + if hasattr(node, "handlers"): + stack.extend(node.handlers) return class_names