-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathdatahub.py
More file actions
167 lines (131 loc) · 5.16 KB
/
datahub.py
File metadata and controls
167 lines (131 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import re
from typing import List, Dict, Optional, TypeVar, Callable, Iterable, Any
from langchain.schema import Document
from data_utils.datahub_source import DatahubMetadataFetcher
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
T = TypeVar("T")
R = TypeVar("R")
def parallel_process(
items: Iterable[T],
process_fn: Callable[[T], R],
max_workers: int = 8,
desc: Optional[str] = None,
show_progress: bool = True,
) -> List[R]:
"""병렬 처리를 위한 유틸리티 함수"""
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(process_fn, item) for item in items]
if show_progress:
futures = tqdm(futures, desc=desc)
return [future.result() for future in futures]
def set_gms_server(gms_server: str):
try:
os.environ["DATAHUB_SERVER"] = gms_server
fetcher = DatahubMetadataFetcher(gms_server=gms_server)
except ValueError as e:
raise ValueError(f"GMS 서버 설정 실패: {str(e)}")
def _get_fetcher():
gms_server = os.getenv("DATAHUB_SERVER")
if not gms_server:
raise ValueError("GMS 서버가 설정되지 않았습니다.")
return DatahubMetadataFetcher(gms_server=gms_server)
def _process_urn(urn: str, fetcher: DatahubMetadataFetcher) -> tuple[str, str]:
table_name = fetcher.get_table_name(urn)
table_description = fetcher.get_table_description(urn)
return (table_name, table_description)
def _process_column_info(
urn: str, table_name: str, fetcher: DatahubMetadataFetcher
) -> Optional[List[Dict[str, str]]]:
if fetcher.get_table_name(urn) == table_name:
return fetcher.get_column_names_and_descriptions(urn)
return None
def _get_table_info(max_workers: int = 8) -> Dict[str, str]:
fetcher = _get_fetcher()
urns = fetcher.get_urns()
table_info = {}
results = parallel_process(
urns,
lambda urn: _process_urn(urn, fetcher),
max_workers=max_workers,
desc="테이블 정보 수집 중",
)
for table_name, table_description in results:
if table_name and table_description:
table_info[table_name] = table_description
return table_info
def _get_column_info(
table_name: str, urn_table_mapping: Dict[str, str], max_workers: int = 8
) -> List[Dict[str, str]]:
target_urn = urn_table_mapping.get(table_name)
if not target_urn:
return []
fetcher = _get_fetcher()
column_info = fetcher.get_column_names_and_descriptions(target_urn)
return column_info
def _extract_dataset_name_from_urn(urn: str) -> Optional[str]:
"""URN 문자열에서 데이터셋 이름(예: delta.default.stg_gh_events)만 추출.
지원 패턴:
- dataset URN: urn:li:dataset:(urn:li:dataPlatform:dbt,delta.default.stg_gh_events,PROD)
- schemaField URN: urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:dbt,delta.default.stg_gh_events,PROD),event_id)
"""
match = re.search(
r"urn:li:dataset:\(urn:li:dataPlatform:[^,]+,([^,]+),[^)]+\)", urn
)
if match:
return match.group(1)
return None
def get_info_from_db(max_workers: int = 8) -> List[Document]:
table_info = _get_table_info(max_workers=max_workers)
fetcher = _get_fetcher()
urns = list(fetcher.get_urns())
urn_table_mapping = {}
display_name_by_table = {}
for urn in urns:
original_name = fetcher.get_table_name(urn)
if original_name:
urn_table_mapping[original_name] = urn
parsed_name = _extract_dataset_name_from_urn(urn)
if parsed_name:
display_name_by_table[original_name] = parsed_name
def process_table_info(item: tuple[str, str, str]) -> str:
original_table_name, table_description, display_table_name = item
# 컬럼 조회는 기존 테이블 이름으로 수행 (urn_table_mapping과 일치)
column_info = _get_column_info(
original_table_name, urn_table_mapping, max_workers=max_workers
)
column_info_str = "\n".join(
[
f"{col['column_name']}: {col['column_description']}"
for col in column_info
]
)
used_name = display_table_name or original_table_name
return f"{used_name}: {table_description}\nColumns:\n {column_info_str}"
# 표시용 이름을 세 번째 파라미터로 함께 전달
items_with_display = [
(
name,
desc,
display_name_by_table.get(name, name),
)
for name, desc in table_info.items()
]
table_info_str_list = parallel_process(
items_with_display,
process_table_info,
max_workers=max_workers,
desc="컬럼 정보 수집 중",
)
return [Document(page_content=info) for info in table_info_str_list]
def get_metadata_from_db() -> List[Dict]:
fetcher = _get_fetcher()
urns = list(fetcher.get_urns())
metadata = []
total = len(urns)
for idx, urn in enumerate(urns, 1):
print(f"[{idx}/{total}] Processing URN: {urn}")
table_metadata = fetcher.build_table_metadata(urn)
metadata.append(table_metadata)
return metadata