55import os
66from collections import defaultdict
77from dataclasses import dataclass , field
8+ from functools import cache
89from itertools import chain
910from typing import TYPE_CHECKING
1011
4748 from codeflash .languages .python .context .unused_definition_remover import UsageInfo
4849
4950
@cache
def get_jedi_project(project_root_path: str) -> object:
    """Build the ``jedi.Project`` rooted at *project_root_path*, memoized per path string.

    ``functools.cache`` keys on the exact argument value, so repeated calls with
    the same string return the same project object, while two different
    spellings of one directory produce two separate cache entries.

    Args:
        project_root_path: Project root directory, passed as a string.

    Returns:
        The object produced by ``jedi.Project(path=project_root_path)``.
    """
    # Function-scope import: jedi is loaded on the first call, not when this module is imported.
    from jedi import Project

    return Project(path=project_root_path)
56+
57+
5058@dataclass
5159class FileContextCache :
5260 original_module : cst .Module
@@ -102,15 +110,22 @@ def get_code_optimization_context(
102110 testgen_token_limit : int = TESTGEN_CONTEXT_TOKEN_LIMIT ,
103111 call_graph : DependencyResolver | None = None ,
104112) -> CodeOptimizationContext :
113+ project_root_path = project_root_path .resolve ()
114+ jedi_project = get_jedi_project (str (project_root_path ))
115+
105116 # Get FunctionSource representation of helpers of FTO
106117 fto_input = {function_to_optimize .file_path : {function_to_optimize .qualified_name }}
107118 if call_graph is not None :
108119 helpers_of_fto_dict , helpers_of_fto_list = call_graph .get_callees (fto_input )
109120 else :
110- helpers_of_fto_dict , helpers_of_fto_list = get_function_sources_from_jedi (fto_input , project_root_path )
121+ helpers_of_fto_dict , helpers_of_fto_list = get_function_sources_from_jedi (
122+ fto_input , project_root_path , jedi_project = jedi_project
123+ )
111124
112125 # Add function to optimize into helpers of FTO dict, as they'll be processed together
113- fto_as_function_source = get_function_to_optimize_as_function_source (function_to_optimize , project_root_path )
126+ fto_as_function_source = get_function_to_optimize_as_function_source (
127+ function_to_optimize , project_root_path , jedi_project = jedi_project
128+ )
114129 helpers_of_fto_dict [function_to_optimize .file_path ].add (fto_as_function_source )
115130
116131 # Format data to search for helpers of helpers using get_function_sources_from_jedi
@@ -124,7 +139,7 @@ def get_code_optimization_context(
124139 qualified_names .update ({f"{ qn .rsplit ('.' , 1 )[0 ]} .__init__" for qn in qualified_names if "." in qn })
125140
126141 helpers_of_helpers_dict , helpers_of_helpers_list = get_function_sources_from_jedi (
127- helpers_of_fto_qualified_names_dict , project_root_path
142+ helpers_of_fto_qualified_names_dict , project_root_path , jedi_project = jedi_project
128143 )
129144
130145 # Extract all code contexts in a single pass (one CST parse per file)
@@ -133,11 +148,14 @@ def get_code_optimization_context(
133148 final_read_writable_code = all_ctx .read_writable
134149
135150 # Ensure the target file is first in the code blocks so the LLM knows which file to optimize
136- target_relative = function_to_optimize .file_path .resolve ().relative_to (project_root_path .resolve ())
137- target_blocks = [cs for cs in final_read_writable_code .code_strings if cs .file_path == target_relative ]
138- other_blocks = [cs for cs in final_read_writable_code .code_strings if cs .file_path != target_relative ]
139- if target_blocks :
140- final_read_writable_code .code_strings = target_blocks + other_blocks
151+ try :
152+ target_relative = function_to_optimize .file_path .resolve ().relative_to (project_root_path )
153+ target_blocks = [cs for cs in final_read_writable_code .code_strings if cs .file_path == target_relative ]
154+ other_blocks = [cs for cs in final_read_writable_code .code_strings if cs .file_path != target_relative ]
155+ if target_blocks :
156+ final_read_writable_code .code_strings = target_blocks + other_blocks
157+ except ValueError :
158+ pass
141159
142160 read_only_code_markdown = all_ctx .read_only
143161
@@ -434,13 +452,13 @@ def re_extract_from_cache(
434452) -> CodeStringsMarkdown :
435453 """Re-extract context from cached modules without file I/O or CST parsing."""
436454 result = CodeStringsMarkdown ()
437- for cache in file_caches :
455+ for file_cache in file_caches :
438456 try :
439457 pruned = parse_code_and_prune_cst (
440- cache .cleaned_module ,
458+ file_cache .cleaned_module ,
441459 code_context_type ,
442- cache .fto_names ,
443- cache .hoh_names ,
460+ file_cache .fto_names ,
461+ file_cache .hoh_names ,
444462 remove_docstrings = remove_docstrings ,
445463 )
446464 except ValueError :
@@ -450,24 +468,25 @@ def re_extract_from_cache(
450468 code = ast .unparse (ast .parse (pruned .code ))
451469 else :
452470 code = add_needed_imports_from_module (
453- src_module_code = cache .original_module ,
471+ src_module_code = file_cache .original_module ,
454472 dst_module_code = pruned ,
455- src_path = cache .file_path ,
456- dst_path = cache .file_path ,
473+ src_path = file_cache .file_path ,
474+ dst_path = file_cache .file_path ,
457475 project_root = project_root_path ,
458- helper_functions = cache .helper_functions ,
476+ helper_functions = file_cache .helper_functions ,
459477 )
460- result .code_strings .append (CodeString (code = code , file_path = cache .relative_path ))
478+ result .code_strings .append (CodeString (code = code , file_path = file_cache .relative_path ))
461479 return result
462480
463481
464482def get_function_to_optimize_as_function_source (
465- function_to_optimize : FunctionToOptimize , project_root_path : Path
483+ function_to_optimize : FunctionToOptimize , project_root_path : Path , * , jedi_project : object | None = None
466484) -> FunctionSource :
467485 import jedi
468486
469487 # Use jedi to find function to optimize
470- script = jedi .Script (path = function_to_optimize .file_path , project = jedi .Project (path = project_root_path ))
488+ project = jedi_project if jedi_project is not None else get_jedi_project (str (project_root_path .resolve ()))
489+ script = jedi .Script (path = function_to_optimize .file_path , project = project )
471490
472491 # Get all names in the file
473492 names = script .get_names (all_scopes = True , definitions = True , references = False )
@@ -498,22 +517,36 @@ def get_function_to_optimize_as_function_source(
498517
499518
500519def get_function_sources_from_jedi (
501- file_path_to_qualified_function_names : dict [Path , set [str ]], project_root_path : Path
520+ file_path_to_qualified_function_names : dict [Path , set [str ]],
521+ project_root_path : Path ,
522+ * ,
523+ jedi_project : object | None = None ,
502524) -> tuple [dict [Path , set [FunctionSource ]], list [FunctionSource ]]:
503525 import jedi
504526
527+ project_root_path = project_root_path .resolve ()
528+ project = jedi_project if jedi_project is not None else get_jedi_project (str (project_root_path ))
505529 file_path_to_function_source = defaultdict (set )
506530 function_source_list : list [FunctionSource ] = []
507531 for file_path , qualified_function_names in file_path_to_qualified_function_names .items ():
508- script = jedi .Script (path = file_path , project = jedi . Project ( path = project_root_path ) )
532+ script = jedi .Script (path = file_path , project = project )
509533 file_refs = script .get_names (all_scopes = True , definitions = False , references = True )
510534
535+ # Pre-group references by their parent function's qualified name for O(1) lookup
536+ refs_by_parent : dict [str , list [Name ]] = defaultdict (list )
537+ for ref in file_refs :
538+ if not ref .full_name :
539+ continue
540+ try :
541+ parent = ref .parent ()
542+ if parent is None or parent .type != "function" :
543+ continue
544+ refs_by_parent [get_qualified_name (parent .module_name , parent .full_name )].append (ref )
545+ except (AttributeError , ValueError ):
546+ continue
547+
511548 for qualified_function_name in qualified_function_names :
512- names = [
513- ref
514- for ref in file_refs
515- if ref .full_name and belongs_to_function_qualified (ref , qualified_function_name )
516- ]
549+ names = refs_by_parent .get (qualified_function_name , [])
517550 for name in names :
518551 try :
519552 definitions : list [Name ] = name .goto (follow_imports = True , follow_builtin_imports = False )
@@ -1103,7 +1136,7 @@ def _resolve_imported_class_reference(
11031136 module_name , class_name = resolved_name .rsplit ("." , 1 )
11041137 try :
11051138 script_code = f"from { module_name } import { class_name } "
1106- script = jedi .Script (script_code , project = jedi . Project ( path = project_root_path ))
1139+ script = jedi .Script (script_code , project = get_jedi_project ( str ( project_root_path ) ))
11071140 definitions = script .goto (1 , len (f"from { module_name } import " ) + len (class_name ), follow_imports = True )
11081141 except Exception :
11091142 return None
@@ -1263,7 +1296,7 @@ def extract_parameter_type_constructors(
12631296 def append_type_context (type_name : str , module_name : str , * , transitive : bool = False ) -> None :
12641297 try :
12651298 script_code = f"from { module_name } import { type_name } "
1266- script = jedi .Script (script_code , project = jedi . Project ( path = project_root_path ))
1299+ script = jedi .Script (script_code , project = get_jedi_project ( str ( project_root_path ) ))
12671300 definitions = script .goto (1 , len (f"from { module_name } import " ) + len (type_name ), follow_imports = True )
12681301 if not definitions :
12691302 return
@@ -1429,7 +1462,7 @@ def extract_class_and_bases(
14291462 continue
14301463 try :
14311464 test_code = f"import { module_name } "
1432- script = jedi .Script (test_code , project = jedi . Project ( path = project_root_path ))
1465+ script = jedi .Script (test_code , project = get_jedi_project ( str ( project_root_path ) ))
14331466 completions = script .goto (1 , len (test_code ))
14341467
14351468 if not completions :
0 commit comments