Skip to content

Commit 88dcd6c

Browse files
Fix: Enhance robustness and clarity in kg-solver components (OpenSPG#548)
This commit addresses several issues identified during a code analysis, focusing on improving the robustness, error handling, and clarity of various components within the KAG solver. The following changes have been made: 1. **KAGIterativePlanner**: * Corrected `is_static()` method to return `False`, aligning with its iterative behavior. 2. **KAGRetrievedResponse**: * Removed a misleading note about an f-string formatting error from the `to_string()` method's docstring, as the error was not present in the code. 3. **KAGStaticPlanner**: * Improved `finish_judger` error handling: If the LLM call to judge the answer fails, it now logs a warning and returns `False` (treating the answer as potentially bad) instead of defaulting to `True`. 4. **ChunkRetrievedExecutor**: * Clarified schema name: Changed the `name` field in its schema dictionary from "Retriever" to "ChunkRetriever" to better differentiate it from other retriever executors like `KagHybridExecutor`. 5. **PyBasedMathExecutor**: * Added a configurable timeout (defaulting to 5 seconds) to the `subprocess.run()` call within the `run_py_code` function. This prevents indefinite hangs from long-running or stuck Python scripts generated by the LLM. Includes handling for `subprocess.TimeoutExpired`. 6. **DefaultStaticPlanningPrompt**: * Enhanced `parse_response` method: Implemented more robust JSON decoding and structural validation for the LLM-generated DAG plan. It now raises more descriptive `ValueError` exceptions, including details of the malformed data, when `KeyError` or `TypeError` occurs during task creation from the DAG, aiding in debugging. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
1 parent 0228eba commit 88dcd6c

6 files changed

Lines changed: 63 additions & 20 deletions

File tree

kag/solver/executor/math/py_based_math_executor.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,36 @@
2929

3030

3131
def run_py_code(python_code: str, **kwargs):
32+
# Default timeout in seconds
33+
default_timeout = 5
34+
# Allow timeout to be passed via kwargs if needed for more flexibility
35+
timeout_duration = kwargs.get("timeout", default_timeout)
36+
3237
with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
3338
temp_file.write(python_code.encode("utf-8"))
3439
temp_file_path = temp_file.name
3540

41+
stdout_value = None
42+
stderr_value = None
43+
3644
try:
3745
python_executable = sys.executable
3846
result = subprocess.run(
39-
[python_executable, temp_file_path], capture_output=True, text=True
47+
[python_executable, temp_file_path],
48+
capture_output=True,
49+
text=True,
50+
timeout=timeout_duration # Added timeout
4051
)
52+
stdout_value = result.stdout
53+
stderr_value = result.stderr
54+
except subprocess.TimeoutExpired as e:
55+
stderr_value = f"Code execution timed out after {timeout_duration} seconds: {e}"
56+
except Exception as e: # Catch other potential errors during subprocess.run
57+
stderr_value = f"An unexpected error occurred during code execution: {e}"
4158
finally:
4259
os.remove(temp_file_path)
4360

44-
stdout_value = result.stdout
45-
stderr_value = result.stderr
46-
if len(stderr_value) > 0:
61+
if stderr_value: # If there's any error (timeout or other execution error)
4762
return None, stderr_value, python_code
4863
return stdout_value, None, python_code
4964

kag/solver/executor/retriever/local_knowledge_base/chunk_retrieved_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def schema(self) -> dict:
8787
dict: Schema definition in OpenAI Function format
8888
"""
8989
return {
90-
"name": "Retriever",
90+
"name": "ChunkRetriever", # Changed from "Retriever"
9191
"description": "Retrieve relevant knowledge from the local knowledge base.",
9292
"parameters": {
9393
"query": {

kag/solver/executor/retriever/local_knowledge_base/kag_retriever/kag_hybrid_executor.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,6 @@ def to_string(self) -> str:
133133
134134
Returns:
135135
str: Formatted string containing task description and sub-question results
136-
137-
Note:
138-
Contains formatting error: "task: f{self.retrieved_task}"
139-
should be corrected to "task: {self.retrieved_task}"
140136
"""
141137
refer_docs = self.to_reference_list()
142138
for doc in refer_docs:

kag/solver/planner/kag_iterative_planner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def invoke(self, query, **kwargs) -> List[Task]:
8080
**kwargs,
8181
)
8282

83+
def is_static(self):
84+
return False
85+
8386
async def ainvoke(self, query, **kwargs) -> List[Task]:
8487
"""Asynchronously generates task plan using LLM.
8588

kag/solver/planner/kag_static_planner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Unless required by applicable law or agreed to in writing, software distributed under the License
1010
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
1111
# or implied.
12+
import logging
1213
import re
1314
from typing import List
1415

@@ -86,11 +87,10 @@ async def finish_judger(self, query: str, answer: str):
8687
return False
8788
return True
8889
except Exception as e:
89-
print(f"Failed to run finish_judger, info: {e}")
90-
import traceback
91-
92-
traceback.print_exc()
93-
return True
90+
# import logging # Make sure logging is imported if not already at the top of the file
91+
logger = logging.getLogger(__name__) # Get a logger instance
92+
logger.warning(f"LLM call failed in finish_judger for query '{query}'. Error: {e}", exc_info=True)
93+
return False # Treat as potentially bad answer
9494

9595
async def query_rewrite(self, task: Task, **kwargs):
9696
"""Performs asynchronous query rewriting using LLM and context.

kag/solver/prompt/static_planning_prompt.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,38 @@ def template_variables(self) -> List[str]:
144144

145145
def parse_response(self, response: str, **kwargs):
146146
if isinstance(response, str):
147-
response = json.loads(response)
148-
if not isinstance(response, dict):
149-
raise ValueError(f"response should be a dict, but got {type(response)}")
150-
if "output" in response:
151-
response = response["output"]
152-
return Task.create_tasks_from_dag(response)
147+
try:
148+
response_json = json.loads(response)
149+
except json.JSONDecodeError as e:
150+
raise ValueError(f"Failed to decode LLM response as JSON: {e}. Response: {response}")
151+
elif isinstance(response, dict):
152+
response_json = response # If it's already a dict (e.g. from direct LLM client parsing)
153+
else:
154+
raise ValueError(f"LLM response is not a JSON string or a dictionary. Got type: {type(response)}. Response: {response}")
155+
156+
if not isinstance(response_json, dict):
157+
# This case might be redundant if json.loads already ensures a dict or list,
158+
# but good for safety if the initial response could be a non-dict JSON type.
159+
raise ValueError(f"Parsed LLM response should be a dict, but got {type(response_json)}. Response: {response_json}")
160+
161+
# Handle if the LLM wraps the DAG in an "output" key, as per original logic
162+
actual_dag_data = response_json.get("output", response_json)
163+
164+
if not isinstance(actual_dag_data, dict):
165+
raise ValueError(f"The core plan data (after handling potential 'output' key) is not a dictionary. Got type: {type(actual_dag_data)}. Data: {actual_dag_data}")
166+
167+
try:
168+
return Task.create_tasks_from_dag(actual_dag_data)
169+
except (KeyError, TypeError) as e:
170+
error_message = (
171+
f"LLM response for static planning was malformed. Error: {e}. "
172+
f"Each task in the DAG dictionary must define 'executor', 'arguments', and 'dependent_task_ids'. "
173+
f"Problematic DAG data: {actual_dag_data}"
174+
)
175+
raise ValueError(error_message)
176+
except Exception as e: # Catch any other unexpected errors from create_tasks_from_dag
177+
error_message = (
178+
f"An unexpected error occurred while creating tasks from DAG. Error: {e}. "
179+
f"Problematic DAG data: {actual_dag_data}"
180+
)
181+
raise ValueError(error_message)

0 commit comments

Comments
 (0)