From 484170104fbb410304acce2972d1e456753f5044 Mon Sep 17 00:00:00 2001 From: Dvir Dukhan <12258836+DvirDukhan@users.noreply.github.com> Date: Thu, 28 May 2026 08:51:59 +0300 Subject: [PATCH] refactor(analyzers): extract TreeSitterAnalyzer base class (T15 #663) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- api/analyzers/javascript/analyzer.py | 84 ++++----------- api/analyzers/kotlin/analyzer.py | 61 +++++------ api/analyzers/python/analyzer.py | 70 ++++-------- api/analyzers/tree_sitter_base.py | 107 +++++++++++++++++++ tests/analyzers/__init__.py | 0 tests/analyzers/fixtures/multilang/sample.js | 2 + tests/analyzers/fixtures/multilang/sample.kt | 2 + tests/analyzers/fixtures/multilang/sample.py | 7 ++ tests/analyzers/test_tree_sitter_base.py | 77 +++++++++++++ 9 files changed, 267 insertions(+), 143 deletions(-) create mode 100644 api/analyzers/tree_sitter_base.py create mode 100644 tests/analyzers/__init__.py create mode 100644 tests/analyzers/fixtures/multilang/sample.js create mode 100644 tests/analyzers/fixtures/multilang/sample.kt create mode 100644 tests/analyzers/fixtures/multilang/sample.py create mode 100644 tests/analyzers/test_tree_sitter_base.py diff --git a/api/analyzers/javascript/analyzer.py b/api/analyzers/javascript/analyzer.py index abc2879f..d5833cb9 100644 --- a/api/analyzers/javascript/analyzer.py +++ b/api/analyzers/javascript/analyzer.py @@ -3,10 +3,8 @@ from pathlib import Path from typing import Optional -from multilspy import SyncLanguageServer from ...entities.entity import Entity -from ...entities.file import File -from ..analyzer import AbstractAnalyzer +from ..tree_sitter_base import TreeSitterAnalyzer import tree_sitter_javascript as tsjs from tree_sitter import Language, Node @@ -15,13 +13,28 @@ logger = logging.getLogger('code_graph') -class JavaScriptAnalyzer(AbstractAnalyzer): +class JavaScriptAnalyzer(TreeSitterAnalyzer): """Analyzer for JavaScript source files using tree-sitter. Extracts functions, classes, and methods from JavaScript code. Resolves class inheritance (extends) and function/method call references. """ + entity_node_types = { + 'function_declaration': "Function", + 'class_declaration': "Class", + 'method_definition': "Method", + } + type_definition_node_types = ('class_declaration',) + callable_definition_node_types = ( + 'function_declaration', + 'method_definition', + 'class_declaration', + ) + callable_exclude_node_types = ('class_declaration',) + type_resolution_keys = ("base_class",) + method_resolution_keys = ("call",) + def __init__(self) -> None: """Initialize the JavaScript analyzer with the tree-sitter JS grammar.""" super().__init__(Language(tsjs.language())) @@ -33,26 +46,6 @@ def add_dependencies(self, path: Path, files: list[Path]) -> None: """ pass - def get_entity_label(self, node: Node) -> str: - """Return the graph label for a given AST node type. - - Args: - node: A tree-sitter AST node representing a JavaScript entity. - - Returns: - One of 'Function', 'Class', or 'Method'. - - Raises: - ValueError: If the node type is not a recognised entity. - """ - if node.type == 'function_declaration': - return "Function" - elif node.type == 'class_declaration': - return "Class" - elif node.type == 'method_definition': - return "Method" - raise ValueError(f"Unknown entity type: {node.type}") - def get_entity_name(self, node: Node) -> str: """Extract the declared name from a JavaScript entity node. @@ -92,10 +85,6 @@ def get_entity_docstring(self, node: Node) -> Optional[str]: return None raise ValueError(f"Unknown entity type: {node.type}") - def get_entity_types(self) -> list[str]: - """Return the tree-sitter node types recognised as JavaScript entities.""" - return ['function_declaration', 'class_declaration', 'method_definition'] - def add_symbols(self, entity: Entity) -> None: """Extract symbols (references) from a JavaScript entity. @@ -128,45 +117,12 @@ def is_dependency(self, file_path: str) -> bool: """ return "node_modules" in Path(file_path).parts - def resolve_path(self, file_path: str, path: Path) -> str: - """Resolve an import path relative to the project root.""" - return file_path - - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - """Resolve a type reference to its class declaration entity.""" - res = [] - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_declaration']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) - return res - - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - """Resolve a call expression to the target function or method entity.""" - res = [] + def _extract_call_target(self, node: Node) -> Optional[Node]: + """Extract the callable target from a JavaScript call expression.""" if node.type == 'call_expression': func_node = node.child_by_field_name('function') if func_node and func_node.type == 'member_expression': func_node = func_node.child_by_field_name('property') if func_node: node = func_node - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - method_dec = self.find_parent(resolved_node, ['function_declaration', 'method_definition', 'class_declaration']) - if method_dec and method_dec.type == 'class_declaration': - continue - if method_dec in file.entities: - res.append(file.entities[method_dec]) - return res - - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: - """Dispatch symbol resolution based on the symbol category. - - Routes ``base_class`` symbols to type resolution and ``call`` symbols - to method resolution. - """ - if key == "base_class": - return self.resolve_type(files, lsp, file_path, path, symbol) - elif key == "call": - return self.resolve_method(files, lsp, file_path, path, symbol) - else: - raise ValueError(f"Unknown key {key}") + return node diff --git a/api/analyzers/kotlin/analyzer.py b/api/analyzers/kotlin/analyzer.py index 3758c302..208a51af 100644 --- a/api/analyzers/kotlin/analyzer.py +++ b/api/analyzers/kotlin/analyzer.py @@ -2,7 +2,7 @@ from ...entities.entity import Entity from ...entities.file import File from typing import Optional -from ..analyzer import AbstractAnalyzer +from ..tree_sitter_base import TreeSitterAnalyzer from multilspy import SyncLanguageServer @@ -12,7 +12,27 @@ import logging logger = logging.getLogger('code_graph') -class KotlinAnalyzer(AbstractAnalyzer): +class KotlinAnalyzer(TreeSitterAnalyzer): + entity_node_types = { + 'class_declaration': "Class", + 'object_declaration': "Object", + 'function_declaration': "Function", + } + type_definition_node_types = ('class_declaration', 'object_declaration') + callable_definition_node_types = ( + 'function_declaration', + 'class_declaration', + 'object_declaration', + ) + callable_exclude_node_types = ('class_declaration', 'object_declaration') + type_resolution_keys = ( + "implement_interface", + "base_class", + "parameters", + "return_type", + ) + method_resolution_keys = ("call",) + def __init__(self) -> None: super().__init__(Language(tskotlin.language())) @@ -44,7 +64,7 @@ def get_entity_name(self, node: Node) -> str: if child.type == 'identifier': return child.text.decode('utf-8') raise ValueError(f"Cannot extract name from entity type: {node.type}") - + def get_entity_docstring(self, node: Node) -> Optional[str]: if node.type in ['class_declaration', 'object_declaration', 'function_declaration']: # Check for KDoc comment (/** ... */) before the node @@ -54,14 +74,11 @@ def get_entity_docstring(self, node: Node) -> Optional[str]: if comment_text.startswith('/**'): return comment_text return None - raise ValueError(f"Unknown entity type: {node.type}") + raise ValueError(f"Unknown entity type: {node.type}") - def get_entity_types(self) -> list[str]: - return ['class_declaration', 'object_declaration', 'function_declaration'] - def _get_delegation_types(self, entity: Entity) -> list[tuple]: """Extract type identifiers from delegation specifiers in order. - + Returns list of (node, is_constructor_invocation) tuples. constructor_invocation indicates a superclass; plain user_type indicates an interface. """ @@ -91,25 +108,25 @@ def add_symbols(self, entity: Entity) -> None: entity.add_symbol("base_class", node) else: entity.add_symbol("implement_interface", node) - + elif entity.node.type == 'object_declaration': types = self._get_delegation_types(entity) for node, _ in types: entity.add_symbol("implement_interface", node) - + elif entity.node.type == 'function_declaration': # Find function calls captures = self._captures("(call_expression) @reference.call", entity.node) if 'reference.call' in captures: for caller in captures['reference.call']: entity.add_symbol("call", caller) - + # Find parameters with types captures = self._captures("(parameter (user_type (identifier) @parameter))", entity.node) if 'parameter' in captures: for parameter in captures['parameter']: entity.add_symbol("parameters", parameter) - + # Find return type captures = self._captures("(function_declaration (user_type (identifier) @return_type))", entity.node) if 'return_type' in captures: @@ -120,18 +137,6 @@ def is_dependency(self, file_path: str) -> bool: # Check if file is in a dependency directory (e.g., build, .gradle cache) return "build/" in file_path or ".gradle/" in file_path or "/cache/" in file_path - def resolve_path(self, file_path: str, path: Path) -> str: - # For Kotlin, just return the file path as-is for now - return file_path - - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_declaration', 'object_declaration']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) - return res - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: res = [] # For call expressions, we need to extract the function name @@ -147,11 +152,3 @@ def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ res.append(file.entities[method_dec]) break return res - - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: - if key in ["implement_interface", "base_class", "parameters", "return_type"]: - return self.resolve_type(files, lsp, file_path, path, symbol) - elif key in ["call"]: - return self.resolve_method(files, lsp, file_path, path, symbol) - else: - raise ValueError(f"Unknown key {key}") diff --git a/api/analyzers/python/analyzer.py b/api/analyzers/python/analyzer.py index 7a991202..7757ff60 100644 --- a/api/analyzers/python/analyzer.py +++ b/api/analyzers/python/analyzer.py @@ -1,12 +1,12 @@ import os import subprocess -from multilspy import SyncLanguageServer from pathlib import Path import tomllib -from ...entities import * from typing import Optional -from ..analyzer import AbstractAnalyzer + +from ...entities.entity import Entity +from ..tree_sitter_base import TreeSitterAnalyzer import tree_sitter_python as tspython from tree_sitter import Language, Node @@ -14,10 +14,19 @@ import logging logger = logging.getLogger('code_graph') -class PythonAnalyzer(AbstractAnalyzer): +class PythonAnalyzer(TreeSitterAnalyzer): + entity_node_types = { + 'class_definition': "Class", + 'function_definition': "Function", + } + type_definition_node_types = ('class_definition',) + callable_definition_node_types = ('function_definition', 'class_definition') + type_resolution_keys = ("base_class", "parameters", "return_type") + method_resolution_keys = ("call",) + def __init__(self) -> None: super().__init__(Language(tspython.language())) - + def add_dependencies(self, path: Path, files: list[Path]): if Path(f"{path}/venv").is_dir(): return @@ -40,18 +49,11 @@ def add_dependencies(self, path: Path, files: list[Path]): for requirement in requirements: files.extend(Path(f"{path}/venv/lib/").rglob(f"**/site-packages/{requirement}/*.py")) - def get_entity_label(self, node: Node) -> str: - if node.type == 'class_definition': - return "Class" - elif node.type == 'function_definition': - return "Function" - raise ValueError(f"Unknown entity type: {node.type}") - def get_entity_name(self, node: Node) -> str: if node.type in ['class_definition', 'function_definition']: return node.child_by_field_name('name').text.decode('utf-8') raise ValueError(f"Unknown entity type: {node.type}") - + def get_entity_docstring(self, node: Node) -> Optional[str]: if node.type in ['class_definition', 'function_definition']: body = node.child_by_field_name('body') @@ -59,11 +61,8 @@ def get_entity_docstring(self, node: Node) -> Optional[str]: docstring_node = body.children[0].child(0) return docstring_node.text.decode('utf-8') return None - raise ValueError(f"Unknown entity type: {node.type}") - - def get_entity_types(self) -> list[str]: - return ['class_definition', 'function_definition'] - + raise ValueError(f"Unknown entity type: {node.type}") + def add_symbols(self, entity: Entity) -> None: if entity.node.type == 'class_definition': superclasses = entity.node.child_by_field_name("superclasses") @@ -88,37 +87,14 @@ def add_symbols(self, entity: Entity) -> None: def is_dependency(self, file_path: str) -> bool: return "venv" in file_path - def resolve_path(self, file_path: str, path: Path) -> str: - return file_path - - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path, node: Node) -> list[Entity]: - res = [] + def _extract_type_target(self, node: Node) -> Optional[Node]: if node.type == 'attribute': - node = node.child_by_field_name('attribute') - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_definition']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) - return res + return node.child_by_field_name('attribute') + return node - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] + def _extract_call_target(self, node: Node) -> Optional[Node]: if node.type == 'call': node = node.child_by_field_name('function') - if node.type == 'attribute': + if node and node.type == 'attribute': node = node.child_by_field_name('attribute') - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - method_dec = self.find_parent(resolved_node, ['function_definition', 'class_definition']) - if not method_dec: - continue - if method_dec in file.entities: - res.append(file.entities[method_dec]) - return res - - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: - if key in ["base_class", "parameters", "return_type"]: - return self.resolve_type(files, lsp, file_path, path, symbol) - elif key in ["call"]: - return self.resolve_method(files, lsp, file_path, path, symbol) - else: - raise ValueError(f"Unknown key {key}") + return node diff --git a/api/analyzers/tree_sitter_base.py b/api/analyzers/tree_sitter_base.py new file mode 100644 index 00000000..459ccc38 --- /dev/null +++ b/api/analyzers/tree_sitter_base.py @@ -0,0 +1,107 @@ +"""Shared base class for tree-sitter-backed analyzers.""" + +from pathlib import Path +from typing import Optional + +from multilspy import SyncLanguageServer +from tree_sitter import Node + +from api.entities.entity import Entity +from api.entities.file import File + +from .analyzer import AbstractAnalyzer + + +class TreeSitterAnalyzer(AbstractAnalyzer): + """Base implementation for analyzers that use tree-sitter plus LSP resolution. + + Subclasses declare the node types they treat as graph entities and the symbol + keys that resolve to type or callable definitions. Language-specific AST + normalization can be implemented by overriding the target-extraction hooks. + """ + + entity_node_types: dict[str, str] = {} + type_definition_node_types: tuple[str, ...] = () + callable_definition_node_types: tuple[str, ...] = () + callable_exclude_node_types: tuple[str, ...] = () + type_resolution_keys: tuple[str, ...] = () + method_resolution_keys: tuple[str, ...] = () + + def resolve_path(self, file_path: str, path: Path) -> str: + """Resolve an LSP path into the key used by the analyzed file map.""" + return file_path + + def get_entity_types(self) -> list[str]: + """Return the tree-sitter node types recognized as graph entities.""" + return list(self.entity_node_types.keys()) + + def get_entity_label(self, node: Node) -> str: + """Return the graph label for an entity node declared by the subclass.""" + try: + return self.entity_node_types[node.type] + except KeyError as exc: + raise ValueError(f"Unknown entity type: {node.type}") from exc + + def resolve_symbol( + self, + files: dict[Path, File], + lsp: SyncLanguageServer, + file_path: Path, + path: Path, + key: str, + symbol: Node, + ) -> list[Entity]: + """Dispatch a captured symbol to type or callable resolution.""" + if key in self.type_resolution_keys: + return self.resolve_type(files, lsp, file_path, path, symbol) + if key in self.method_resolution_keys: + return self.resolve_method(files, lsp, file_path, path, symbol) + raise ValueError(f"Unknown key {key}") + + def _extract_call_target(self, node: Node) -> Optional[Node]: + """Normalize a call symbol before resolving it to a callable definition.""" + return node + + def _extract_type_target(self, node: Node) -> Optional[Node]: + """Normalize a type symbol before resolving it to a type definition.""" + return node + + def resolve_type( + self, + files: dict[Path, File], + lsp: SyncLanguageServer, + file_path: Path, + path: Path, + node: Node, + ) -> list[Entity]: + """Resolve a type reference to matching type-definition entities.""" + res = [] + target = self._extract_type_target(node) + if target is None: + return res + for file, resolved_node in self.resolve(files, lsp, file_path, path, target): + type_dec = self.find_parent(resolved_node, self.type_definition_node_types) + if type_dec in file.entities: + res.append(file.entities[type_dec]) + return res + + def resolve_method( + self, + files: dict[Path, File], + lsp: SyncLanguageServer, + file_path: Path, + path: Path, + node: Node, + ) -> list[Entity]: + """Resolve a call reference to matching callable-definition entities.""" + res = [] + target = self._extract_call_target(node) + if target is None: + return res + for file, resolved_node in self.resolve(files, lsp, file_path, path, target): + method_dec = self.find_parent(resolved_node, self.callable_definition_node_types) + if method_dec and method_dec.type in self.callable_exclude_node_types: + continue + if method_dec in file.entities: + res.append(file.entities[method_dec]) + return res diff --git a/tests/analyzers/__init__.py b/tests/analyzers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/analyzers/fixtures/multilang/sample.js b/tests/analyzers/fixtures/multilang/sample.js new file mode 100644 index 00000000..77782571 --- /dev/null +++ b/tests/analyzers/fixtures/multilang/sample.js @@ -0,0 +1,2 @@ +class JsThing { run() { return jsHelper(); } } +function jsHelper() { return 1; } diff --git a/tests/analyzers/fixtures/multilang/sample.kt b/tests/analyzers/fixtures/multilang/sample.kt new file mode 100644 index 00000000..11d6e4a3 --- /dev/null +++ b/tests/analyzers/fixtures/multilang/sample.kt @@ -0,0 +1,2 @@ +class KtThing { fun run(): Int = ktHelper() } +fun ktHelper(): Int = 1 diff --git a/tests/analyzers/fixtures/multilang/sample.py b/tests/analyzers/fixtures/multilang/sample.py new file mode 100644 index 00000000..3277d993 --- /dev/null +++ b/tests/analyzers/fixtures/multilang/sample.py @@ -0,0 +1,7 @@ +class PyThing: + def run(self): + return py_helper() + + +def py_helper(): + return 1 diff --git a/tests/analyzers/test_tree_sitter_base.py b/tests/analyzers/test_tree_sitter_base.py new file mode 100644 index 00000000..e88976f9 --- /dev/null +++ b/tests/analyzers/test_tree_sitter_base.py @@ -0,0 +1,77 @@ +from collections import Counter +from pathlib import Path + +from api.analyzers.javascript.analyzer import JavaScriptAnalyzer +from api.analyzers.kotlin.analyzer import KotlinAnalyzer +from api.analyzers.python.analyzer import PythonAnalyzer +from api.analyzers.source_analyzer import SourceAnalyzer, analyzers +from api.entities.file import File + + +class MockGraph: + def __init__(self): + self._next_id = 1 + self.files = [] + self.entities = {} + self.edges = [] + + def add_file(self, file): + file.id = self._next_id + self._next_id += 1 + self.files.append(file) + + def add_entity(self, label, name, doc, path, src_start, src_end, props): + entity_id = self._next_id + self._next_id += 1 + self.entities[entity_id] = { + "label": label, + "name": name, + "doc": doc, + "path": path, + "src_start": src_start, + "src_end": src_end, + "props": props, + } + return entity_id + + def connect_entities(self, rel, src, dest, props=None): + self.edges.append((rel, src, dest, props)) + + +def test_tree_sitter_subclasses_expose_expected_entity_node_types(): + assert list(PythonAnalyzer.entity_node_types) == [ + 'class_definition', + 'function_definition', + ] + assert list(JavaScriptAnalyzer.entity_node_types) == [ + 'function_declaration', + 'class_declaration', + 'method_definition', + ] + assert list(KotlinAnalyzer.entity_node_types) == [ + 'class_declaration', + 'object_declaration', + 'function_declaration', + ] + + +def test_tree_sitter_multilanguage_fixture_graph_counts(): + source_analyzer = SourceAnalyzer() + graph = MockGraph() + fixture_dir = Path(__file__).parent / "fixtures" / "multilang" + + for file_path in sorted(fixture_dir.iterdir()): + analyzer = analyzers[file_path.suffix] + tree = analyzer.parser.parse(file_path.read_bytes()) + file = File(file_path, tree) + graph.add_file(file) + source_analyzer.create_hierarchy(file, analyzer, graph) + + assert len(graph.files) == 3 + assert len(graph.entities) == 9 + assert len(graph.files) + len(graph.entities) == 12 + assert len(graph.edges) == 9 + assert Counter(entity["label"] for entity in graph.entities.values()) == Counter( + {"Class": 3, "Function": 4, "Method": 2} + ) + assert Counter(edge[0] for edge in graph.edges) == Counter({"DEFINES": 9})