@@ -318,8 +318,9 @@ def add_graph_edge(self, edge: graph.Edge):
318318 if len (existing_edges ):
319319 _ , end = self .get_node_range (existing_edges [- 1 ])
320320 else :
321- # TODO add the edge after the graph is instantiated
322- raise ValidationError (f"Existing graph `add_edge` call not found in { ENTRYPOINT } " )
321+ # find the instantiation of `StateGraph`
322+ graph_instance = asttools .find_method_calls (self .get_run_method (), 'StateGraph' )[0 ]
323+ _ , end = self .get_node_range (graph_instance )
323324
324325 source , target = edge .source .name , edge .target .name
325326 # wrap the node names in quotes if they are not special nodes
@@ -364,8 +365,11 @@ def add_graph_node(self, node_config: Union[AgentConfig, TaskConfig]):
364365 if len (existing_nodes ):
365366 _ , end = self .get_node_range (existing_nodes [- 1 ])
366367 else :
367- # TODO add the node after the graph is instantiated
368- raise ValidationError (f"Existing graph `add_node` call not found in { ENTRYPOINT } " )
368+ # find the instantiation of `StateGraph`
369+ graph_instance = asttools .find_method_calls (self .get_run_method (), 'StateGraph' )[0 ]
370+ _ , end = self .get_node_range (graph_instance )
371+
372+ # node is always either an Agent or a Task so we can make this assumption
369373 code = f"""
370374 self.graph.add_node("{ node_config .name } ", self.{ node_config .name } )"""
371375 self .edit_node_range (end , end , code )
0 commit comments