-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathquery_executor.py
More file actions
137 lines (115 loc) · 4.74 KB
/
query_executor.py
File metadata and controls
137 lines (115 loc) · 4.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Lang2SQL 쿼리 실행을 위한 공용 모듈입니다.
이 모듈은 CLI와 Streamlit 인터페이스에서 공통으로 사용할 수 있는
쿼리 실행 함수를 제공합니다.
"""
import logging
from typing import Any, Dict, Optional, Union
from langchain_core.messages import HumanMessage
from utils.llm.graph_utils.basic_graph import builder as basic_builder
from utils.llm.graph_utils.enriched_graph import builder as enriched_builder
from utils.llm.llm_response_parser import LLMResponseParser
logger = logging.getLogger(__name__)
def execute_query(
*,
query: str,
database_env: str,
retriever_name: str = "기본",
top_n: int = 5,
device: str = "cpu",
use_enriched_graph: bool = False,
session_state: Optional[Union[Dict[str, Any], Any]] = None,
) -> Dict[str, Any]:
"""
자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 공용 함수입니다.
이 함수는 Lang2SQL 파이프라인(graph)을 사용하여 사용자의 자연어 질문을
SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다.
CLI와 Streamlit 인터페이스에서 공통으로 사용할 수 있습니다.
Args:
query (str): 사용자가 입력한 자연어 기반 질문.
database_env (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod").
retriever_name (str, optional): 테이블 검색기 이름. 기본값은 "기본".
top_n (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5.
device (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu".
use_enriched_graph (bool, optional): 확장된 그래프 사용 여부. 기본값은 False.
session_state (Optional[Union[Dict[str, Any], Any]], optional): Streamlit 세션 상태 (Streamlit에서만 사용).
Returns:
Dict[str, Any]: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리:
- "generated_query": 생성된 SQL 쿼리 (`AIMessage`)
- "messages": 전체 LLM 응답 메시지 목록
- "searched_tables": 참조된 테이블 목록 등 추가 정보
"""
logger.info("Processing query: %s", query)
# 그래프 선택
if use_enriched_graph:
graph_type = "enriched"
graph_builder = enriched_builder
else:
graph_type = "basic"
graph_builder = basic_builder
logger.info("Using %s graph", graph_type)
# 그래프 선택 및 컴파일
if session_state is not None:
# Streamlit 환경: 세션 상태에서 그래프 재사용
graph = session_state.get("graph")
if graph is None:
graph = graph_builder.compile()
session_state["graph"] = graph
else:
# CLI 환경: 매번 새로운 그래프 컴파일
graph = graph_builder.compile()
# 그래프 실행
res = graph.invoke(
input={
"messages": [HumanMessage(content=query)],
"user_database_env": database_env,
"best_practice_query": "",
"retriever_name": retriever_name,
"top_n": top_n,
"device": device,
# 다이얼렉트 정보 주입 (있다면 세션에서, 없으면 기본값)
"dialect_name": (
session_state.get("selected_dialect_option", {}).get("name")
if session_state is not None
else database_env
),
"supports_ilike": (
bool(
session_state.get("selected_dialect_option", {}).get(
"supports_ilike", False
)
)
if session_state is not None
else False
),
"dialect_hints": (
session_state.get("selected_dialect_option", {}).get("hints", [])
if session_state is not None
else []
),
}
)
return res
def extract_sql_from_result(res: Dict[str, Any]) -> Optional[str]:
"""
Lang2SQL 실행 결과에서 SQL 쿼리를 추출합니다.
Args:
res (Dict[str, Any]): execute_query 함수의 반환 결과
Returns:
Optional[str]: 추출된 SQL 쿼리 문자열. 추출 실패 시 None
"""
generated_query = res.get("generated_query")
if not generated_query:
logger.error("생성된 쿼리가 없습니다.")
return None
query_text = (
generated_query.content
if hasattr(generated_query, "content")
else str(generated_query)
)
try:
sql = LLMResponseParser.extract_sql(query_text)
return sql
except ValueError:
logger.error("SQL을 추출할 수 없습니다.")
return None