-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathbase.py
More file actions
258 lines (211 loc) · 9.33 KB
/
base.py
File metadata and controls
258 lines (211 loc) · 9.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import json
from langgraph.graph.message import add_messages
from typing_extensions import Annotated, TypedDict
from utils.llm.chains import (
document_suitability_chain,
profile_extraction_chain,
query_enrichment_chain,
query_maker_chain,
question_gate_chain,
)
from utils.llm.retrieval import search_tables
# 노드 식별자 정의
QUESTION_GATE = "question_gate"
EVALUATE_DOCUMENT_SUITABILITY = "evaluate_document_suitability"
GET_TABLE_INFO = "get_table_info"
TOOL = "tool"
TABLE_FILTER = "table_filter"
QUERY_MAKER = "query_maker"
PROFILE_EXTRACTION = "profile_extraction"
CONTEXT_ENRICHMENT = "context_enrichment"
# 상태 타입 정의 (추가 상태 정보와 메시지들을 포함)
class QueryMakerState(TypedDict):
messages: Annotated[list, add_messages]
user_database_env: str
searched_tables: dict[str, dict[str, str]]
document_suitability: dict
best_practice_query: str
question_profile: dict
generated_query: str
retriever_name: str
top_n: int
device: str
question_gate_result: dict
# 다이얼렉트 정보
dialect_name: str
supports_ilike: bool
dialect_hints: list[str]
# 노드 함수: QUESTION_GATE 노드
def question_gate_node(state: QueryMakerState):
"""
사용자의 질문이 SQL로 답변 가능한지 판별하고, 구조화된 결과를 반환하는 게이트 노드입니다.
- question_gate_chain 으로 적합성을 판정하여
`question_gate_result`를 설정합니다.
Args:
state (QueryMakerState): 그래프 상태
Returns:
QueryMakerState: 게이트 판정 결과가 반영된 상태
"""
question_text = state["messages"][0].content
suitability = question_gate_chain.invoke({"question": question_text})
state["question_gate_result"] = {
"reason": getattr(suitability, "reason", ""),
"missing_entities": getattr(suitability, "missing_entities", []),
"requires_data_science": getattr(suitability, "requires_data_science", False),
}
return state
# 노드 함수: PROFILE_EXTRACTION 노드
def profile_extraction_node(state: QueryMakerState):
"""
자연어 쿼리로부터 질문 유형(PROFILE)을 추출하는 노드입니다.
이 노드는 주어진 자연어 쿼리에서 질문의 특성을 분석하여, 해당 질문이 시계열 분석, 집계 함수 사용, 조건 필터 필요 여부,
그룹화, 정렬/순위, 기간 비교 등 다양한 특성을 갖는지 여부를 추출합니다.
추출된 정보는 `QuestionProfile` 모델에 맞춰 저장됩니다. `QuestionProfile` 모델의 필드는 다음과 같습니다:
- `is_timeseries`: 시계열 분석 필요 여부
- `is_aggregation`: 집계 함수 필요 여부
- `has_filter`: 조건 필터 필요 여부
- `is_grouped`: 그룹화 필요 여부
- `has_ranking`: 정렬/순위 필요 여부
- `has_temporal_comparison`: 기간 비교 포함 여부
- `intent_type`: 질문의 주요 의도 유형
"""
result = profile_extraction_chain.invoke({"question": state["messages"][0].content})
state["question_profile"] = result
print("profile_extraction_node : ", result)
return state
# 노드 함수: CONTEXT_ENRICHMENT 노드
def context_enrichment_node(state: QueryMakerState):
"""
주어진 질문과 관련된 메타데이터를 기반으로 질문을 풍부하게 만드는 노드입니다.
이 함수는 `refined_question`, `profiles`, `related_tables` 정보를 이용하여 자연어 질문을 보강합니다.
보강 과정에서는 질문의 의도를 유지하면서, 추가적인 세부 정보를 제공하거나 잘못된 용어를 수정합니다.
주요 작업:
- 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다.
- 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안").
- 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: '미국' → 'USA').
- 보강된 질문을 출력합니다.
Args:
state (QueryMakerState): 쿼리와 관련된 상태 정보를 담고 있는 객체.
상태 객체는 `messages`, `question_profile`, `searched_tables` 등의 정보를 포함합니다.
Returns:
QueryMakerState: 보강된 질문이 포함된 상태 객체.
Example:
Given the refined question "What are the total sales in the last month?",
the function would enrich it with additional information such as:
- Ensuring the time period is specified correctly.
- Correcting any column names if necessary.
- Returning the enriched version of the question.
"""
searched_tables = state["searched_tables"]
searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2)
# question_profile이 BaseModel인 경우 model_dump() 사용, dict인 경우 그대로 사용
if hasattr(state["question_profile"], "model_dump"):
question_profile = state["question_profile"].model_dump()
else:
question_profile = state["question_profile"]
question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2)
# 초기 사용자 입력 사용
refined_question = state["messages"][0].content
enriched_text = query_enrichment_chain.invoke(
input={
"refined_question": refined_question,
"profiles": question_profile_json,
"related_tables": searched_tables_json,
}
)
state["messages"].append(enriched_text)
print("After context enrichment : ", enriched_text.content)
return state
def get_table_info_node(state: QueryMakerState):
# retriever_name과 top_n을 이용하여 검색 수행
documents_dict = search_tables(
query=state["messages"][0].content,
retriever_name=state["retriever_name"],
top_n=state["top_n"],
device=state["device"],
)
state["searched_tables"] = documents_dict
return state
# 노드 함수: DOCUMENT_SUITABILITY 노드
def document_suitability_node(state: QueryMakerState):
"""
GET_TABLE_INFO에서 수집된 테이블 후보들에 대해 문서 적합성 점수를 계산하는 노드입니다.
질문(`messages[0].content`)과 `searched_tables`(테이블→칼럼 설명 맵)를 입력으로
프롬프트 체인(`document_suitability_chain`)을 호출하고, 결과 딕셔너리를
`document_suitability` 상태 키에 저장합니다.
Returns:
QueryMakerState: 문서 적합성 평가 결과가 포함된 상태
"""
# 관련 테이블이 없으면 즉시 반환
if not state.get("searched_tables"):
state["document_suitability"] = {}
return state
res = document_suitability_chain.invoke(
{
"question": state["messages"][0].content,
"tables": state["searched_tables"],
}
)
items = (
res.get("results", [])
if isinstance(res, dict)
else getattr(res, "results", None)
or (res.model_dump().get("results", []) if hasattr(res, "model_dump") else [])
)
normalized = {}
for x in items:
d = (
x.model_dump()
if hasattr(x, "model_dump")
else (
x
if isinstance(x, dict)
else {
"table_name": getattr(x, "table_name", ""),
"score": getattr(x, "score", 0),
"reason": getattr(x, "reason", ""),
"matched_columns": getattr(x, "matched_columns", []),
"missing_entities": getattr(x, "missing_entities", []),
}
)
)
t = d.get("table_name")
if not t:
continue
normalized[t] = {
"score": float(d.get("score", 0)),
"reason": d.get("reason", ""),
"matched_columns": d.get("matched_columns", []),
"missing_entities": d.get("missing_entities", []),
}
state["document_suitability"] = normalized
return state
# 노드 함수: QUERY_MAKER 노드
def query_maker_node(state: QueryMakerState):
# 사용자 원 질문 + (있다면) 컨텍스트 보강 결과를 하나의 문자열로 결합
parts = [state["messages"][0].content]
if len(state["messages"]) > 1:
last_msg = state["messages"][-1]
last_content = (
last_msg.content if hasattr(last_msg, "content") else str(last_msg)
)
if isinstance(last_content, str) and last_content.strip():
parts.append(last_content)
combined_input = "\n\n---\n\n".join(parts)
searched_tables_json = json.dumps(
state["searched_tables"], ensure_ascii=False, indent=2
)
res = query_maker_chain.invoke(
input={
"user_input": combined_input,
"user_database_env": state["user_database_env"],
"searched_tables": searched_tables_json,
# 다이얼렉트 변수 전달
"dialect_name": state.get("dialect_name", ""),
"supports_ilike": state.get("supports_ilike", False),
"dialect_hints": ", ".join(state.get("dialect_hints", [])),
}
)
state["generated_query"] = res
state["messages"].append(res)
return state