-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy path__init__.py
More file actions
256 lines (221 loc) · 8.45 KB
/
__init__.py
File metadata and controls
256 lines (221 loc) · 8.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
Lang2SQL CLI 프로그램입니다.
이 프로그램은 Datahub GMS 서버 URL을 설정하고, 필요 시 Streamlit 인터페이스를 실행합니다.
명령어 예시: lang2sql --datahub_server http://localhost:8080 --run-streamlit
"""
import logging
import subprocess
import click
from llm_utils.check_server import CheckServer
from llm_utils.tools import set_gms_server
from version import __version__
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
@click.group()
@click.version_option(version=__version__)
@click.pass_context
@click.option(
"--datahub_server",
default="http://localhost:8080",
help=(
"Datahub GMS 서버의 URL을 설정합니다. "
"기본값은 'http://localhost:8080'이며, "
"운영 환경 또는 테스트 환경에 맞게 변경할 수 있습니다."
),
)
@click.option(
"--run-streamlit",
is_flag=True,
help=(
"이 옵션을 지정하면 CLI 실행 시 Streamlit 애플리케이션을 바로 실행합니다. "
"별도의 명령어 입력 없이 웹 인터페이스를 띄우고 싶을 때 사용합니다."
),
)
@click.option(
"-p",
"--port",
type=int,
default=8501,
help=(
"Streamlit 서버가 바인딩될 포트 번호를 지정합니다. "
"기본 포트는 8501이며, 포트 충돌을 피하거나 여러 인스턴스를 실행할 때 변경할 수 있습니다."
),
)
# pylint: disable=redefined-outer-name
def cli(
ctx: click.Context,
datahub_server: str,
run_streamlit: bool,
port: int,
) -> None:
"""
Datahub GMS 서버 URL을 설정하고, Streamlit 애플리케이션을 실행할 수 있는 CLI 명령 그룹입니다.
이 함수는 다음 역할을 수행합니다:
- 전달받은 'datahub_server' URL을 바탕으로 GMS 서버 연결을 설정합니다.
- 설정 과정 중 오류가 발생하면 오류 메시지를 출력하고 프로그램을 종료합니다.
- '--run-streamlit' 옵션이 활성화된 경우, 지정된 포트에서 Streamlit 웹 앱을 즉시 실행합니다.
매개변수:
ctx (click.Context): 명령어 실행 컨텍스트 객체입니다.
datahub_server (str): 설정할 Datahub GMS 서버의 URL입니다.
run_streamlit (bool): Streamlit 앱을 실행할지 여부를 나타내는 플래그입니다.
port (int): Streamlit 서버가 바인딩될 포트 번호입니다.
주의:
'set_gms_server' 함수에서 ValueError가 발생할 경우, 프로그램은 비정상 종료(exit code 1)합니다.
"""
logger.info(
"Initialization started: GMS server = %s, run_streamlit = %s, port = %d",
datahub_server,
run_streamlit,
port,
)
if CheckServer.is_gms_server_healthy(url=datahub_server):
set_gms_server(datahub_server)
logger.info("GMS server URL successfully set: %s", datahub_server)
else:
logger.error("GMS server health check failed. URL: %s", datahub_server)
ctx.exit(1)
if run_streamlit:
run_streamlit_command(port)
def run_streamlit_command(port: int) -> None:
"""
지정된 포트에서 Streamlit 애플리케이션을 실행하는 함수입니다.
이 함수는 subprocess를 통해 'streamlit run' 명령어를 실행하여
'interface/streamlit_app.py' 파일을 웹 서버 형태로 구동합니다.
사용자가 지정한 포트 번호를 Streamlit 서버의 포트로 설정합니다.
매개변수:
port (int): Streamlit 서버가 바인딩될 포트 번호입니다.
주의:
- Streamlit이 시스템에 설치되어 있어야 정상 동작합니다.
- subprocess 호출 실패 시 예외가 발생할 수 있습니다.
"""
logger.info("Starting Streamlit application on port %d...", port)
try:
subprocess.run(
[
"streamlit",
"run",
"interface/streamlit_app.py",
"--server.port",
str(port),
],
check=True,
)
logger.info("Streamlit application started successfully.")
except subprocess.CalledProcessError as e:
logger.error("Failed to start Streamlit application: %s", e)
raise
@cli.command(name="run-streamlit")
@click.option(
"-p",
"--port",
type=int,
default=8501,
help=(
"Streamlit 애플리케이션이 바인딩될 포트 번호를 지정합니다. "
"기본 포트는 8501이며, 필요 시 포트 충돌을 피하거나 "
"여러 인스턴스를 동시에 실행할 때 다른 포트 번호를 설정할 수 있습니다."
),
)
def run_streamlit_cli_command(port: int) -> None:
"""
CLI 명령어를 통해 Streamlit 애플리케이션을 실행하는 함수입니다.
이 명령은 'interface/streamlit_app.py' 파일을 Streamlit 서버로 구동하며,
사용자가 지정한 포트 번호를 바인딩하여 웹 인터페이스를 제공합니다.
매개변수:
port (int): Streamlit 서버가 사용할 포트 번호입니다. 기본값은 8501입니다.
주의:
- Streamlit이 시스템에 설치되어 있어야 정상적으로 실행됩니다.
- Streamlit 실행에 실패할 경우 subprocess 호출에서 예외가 발생할 수 있습니다.
"""
logger.info("Executing 'run-streamlit' command on port %d...", port)
run_streamlit_command(port)
@cli.command(name="query")
@click.argument("question", type=str)
@click.option(
"--database-env",
default="clickhouse",
help="사용할 데이터베이스 환경 (기본값: clickhouse)",
)
@click.option(
"--retriever-name",
default="기본",
help="테이블 검색기 이름 (기본값: 기본)",
)
@click.option(
"--top-n",
type=int,
default=5,
help="검색된 상위 테이블 수 제한 (기본값: 5)",
)
@click.option(
"--device",
default="cpu",
help="LLM 실행에 사용할 디바이스 (기본값: cpu)",
)
@click.option(
"--use-enriched-graph",
is_flag=True,
help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부",
)
@click.option(
"--use-simplified-graph",
is_flag=True,
help="단순화된 그래프(QUERY_REFINER 제거) 사용 여부",
)
def query_command(
question: str,
database_env: str,
retriever_name: str,
top_n: int,
device: str,
use_enriched_graph: bool,
use_simplified_graph: bool,
) -> None:
"""
자연어 질문을 SQL 쿼리로 변환하여 출력하는 명령어입니다.
이 명령은 사용자가 입력한 자연어 질문을 받아서 SQL 쿼리로 변환하고,
생성된 SQL 쿼리만을 표준 출력으로 출력합니다.
매개변수:
question (str): SQL로 변환할 자연어 질문
database_env (str): 사용할 데이터베이스 환경
retriever_name (str): 테이블 검색기 이름
top_n (int): 검색된 상위 테이블 수 제한
device (str): LLM 실행에 사용할 디바이스
use_enriched_graph (bool): 확장된 그래프 사용 여부
예시:
lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" --use-enriched-graph
"""
try:
from llm_utils.query_executor import execute_query, extract_sql_from_result
# 공용 함수를 사용하여 쿼리 실행
res = execute_query(
query=question,
database_env=database_env,
retriever_name=retriever_name,
top_n=top_n,
device=device,
use_enriched_graph=use_enriched_graph,
use_simplified_graph=use_simplified_graph,
)
# SQL 추출 및 출력
sql = extract_sql_from_result(res)
if sql:
print(sql)
else:
# SQL 추출 실패 시 원본 쿼리 텍스트 출력
generated_query = res.get("generated_query")
if generated_query:
query_text = (
generated_query.content
if hasattr(generated_query, "content")
else str(generated_query)
)
print(query_text)
except Exception as e:
logger.error("쿼리 처리 중 오류 발생: %s", e)
raise