Skip to content

Commit f49536f

Browse files
author
Adam Seering
committed
feat(bird): enhance dataset loader for multi-engine support
- Updates `load_dataset_from_bird_format` to handle multi-dialect golden SQL dictionaries. - Implements dialect mapping (e.g., mapping BIRD's `googlesql` to `spanner_gsql`). - Filters input dialects to ensure prompts are only generated for engines with available golden data.
1 parent 7ddd09d commit f49536f

1 file changed

Lines changed: 32 additions & 11 deletions

File tree

evalbench/dataset/dataset.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,36 +153,57 @@ def load_dataset_from_bird_format(dataset: Sequence[dict], config):
153153
dataset_str = str(dataset_config).split("/")[-1].replace(".json", "")
154154
dialects = config["dialects"]
155155
query_type = "dql"
156-
for item in dataset:
156+
for i, item in enumerate(dataset):
157157
# Add "ifs" to handle cases where some keys are missing from (or appear in a different format in) the BIRD evaluation dataset
158158
if "question_id" not in item and "id" in item:
159159
item["question_id"] = item["id"]
160160
if "question" not in item and "other" in item:
161161
item["question"] = item["other"]["question"]
162162
if "evidence" not in item and "other" in item:
163163
item["evidence"] = item["other"]["evidence"]
164-
if "question" not in item and "other" in item:
165-
item["question"] = item["other"]["question"]
166164
if "db_id" not in item:
167165
item["db_id"] = dataset_str
168-
if "SQL" not in item:
169-
if dialects[0] in item["golden_sql"]:
170-
item["SQL"] = item["golden_sql"][dialects[0]]
171-
else:
172-
item["SQL"] = ""
166+
167+
# Map BIRD dialects to EvalBench dialects in the golden_sql dict
168+
bird_golden = item.get("golden_sql", {})
169+
eb_golden = {}
170+
# standard mappings
171+
for d in ["postgres", "mysql", "sqlite"]:
172+
if d in bird_golden:
173+
eb_golden[d] = bird_golden[d]
174+
175+
# Spanner GSQL -> googlesql
176+
if "googlesql" in bird_golden:
177+
eb_golden["spanner_gsql"] = bird_golden["googlesql"]
178+
elif "sqlite" in bird_golden: # Fallback to sqlite if others missing
179+
eb_golden["spanner_gsql"] = bird_golden["sqlite"]
180+
181+
# Spanner PG -> postgres
182+
if "postgres" in bird_golden:
183+
eb_golden["spanner_pg"] = bird_golden["postgres"]
184+
elif "sqlite" in bird_golden:
185+
eb_golden["spanner_pg"] = bird_golden["sqlite"]
186+
187+
# filter input.dialects to only those we have golden SQL for
188+
config_dialects = config.get("dialects", [])
189+
input_dialects = [d for d in config_dialects if d in eb_golden]
190+
191+
if i == 0:
192+
print(f"DEBUG BIRD: id={item['question_id']} bird_dialects={list(bird_golden.keys())} config_dialects={config_dialects} -> input_dialects={input_dialects}")
193+
173194
if "difficulty" not in item and "tags" in item:
174195
item["difficulty"] = item["tags"]
175196

176-
if item["SQL"]:
197+
if input_dialects:
177198
eval_input = EvalInputRequest(
178199
id=item["question_id"],
179200
nl_prompt="".join([item["question"], item["evidence"]]).replace(
180201
"`", '"'
181202
),
182203
query_type=query_type,
183204
database=item["db_id"],
184-
dialects=config["dialects"],
185-
golden_sql=item["SQL"],
205+
dialects=input_dialects,
206+
golden_sql=eb_golden,
186207
eval_query="",
187208
setup_sql="",
188209
cleanup_sql="",

0 commit comments

Comments
 (0)