
Commit c9a1e18

added query by query evaluation
1 parent 5af8d6a commit c9a1e18

11 files changed

Lines changed: 262 additions & 1104 deletions


README.md

Lines changed: 29 additions & 20 deletions
````diff
@@ -52,50 +52,59 @@ python data_preprocess.py
 ```
 ### Prompt Generation
 Select examples with masked question similarity:
-```
+```bash
 python generate_question.py \
 --data_type spider \
 --split test \
 --tokenizer gpt-3.5-turbo \
 --max_seq_len 4096 \
+--max_ans_len 200 \
 --prompt_repr SQL \
---k_shot 9 \
+--k_shot 3 \
 --example_type QA \
---selector_type EUCDISQUESTIONMASK
-```
-Select examples considering both question similarity and query similarity:
-```
-python generate_question.py \
---data_type spider \
---split test \
---tokenizer gpt-3.5-turbo \
---max_seq_len 4096 \
---selector_type EUCDISMASKPRESKLSIMTHR \
---pre_test_result [your_pre_generated_queries_file] \
---prompt_repr SQL \
---k_shot 9 \
---example_type QA
+--selector_type EUCDISQUESTIONMASK
 ```

 ### Calling the LLM
+
+#### Using OpenAI Models
 Without voting:
-```
+```bash
 python ask_llm.py \
---openai_api_key [your_openai_api_key] \
+--openai_api_key [your_openai_api_key] \
 --model gpt-4 \
 --question [prompt_dir]
 ```
 With self-consistency voting:
-```
+```bash
 python ask_llm.py \
---openai_api_key [your_openai_api_key] \
+--openai_api_key [your_openai_api_key] \
 --model gpt-4 \
 --question [prompt_dir] \
 --n 5 \
 --db_dir ./dataset/spider/database \
 --temperature 1.0
 ```

+#### Using Ollama/Local Models
+```bash
+python ask_llm.py \
+--model {model_name} \
+--question ./dataset/process/SPIDER-TEST_SQL_3-SHOT_EUCDISQUESTIONMASK_QA-EXAMPLE_CTX-200_ANS-4096 \
+--n 1 \
+--temperature 0.7 \
+--openai_api_key %OLLAMA_API_KEY% \
+--openai_api_base %OLLAMA_BASE_URL%
+```
+
+**Note:** The `ask_llm.py` script now performs automatic evaluation during execution. The script will:
+- Generate SQL queries and save them to `[prompt_dir]\RESULTS_MODEL-{model}.txt`
+- Evaluate each query against the gold standard in real-time
+- Save evaluation results to `results/eval_{model}.txt` (viewable in real-time as the script runs)
+- Display running accuracy after each question and final accuracy at the end
+
+You do NOT need to run `evaluation.py` separately.
+
 ### Running Example
 ```
 bash run_dail_sql_mini.sh [your_openai_api_key]
````

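The new `#### Using Ollama/Local Models` path works because `--openai_api_base` simply re-points the existing OpenAI client at any OpenAI-compatible server. A minimal sketch of that idea (not code from this commit), assuming the pre-1.0 `openai` package that `ask_llm.py` already uses (it catches `openai.error.InvalidRequestError`) and Ollama's default OpenAI-compatible endpoint at `http://localhost:11434/v1`; the model name `codellama:7b` is only an illustrative placeholder:

```python
# Hypothetical, minimal illustration of the --openai_api_base idea; not part of this commit.
# Assumes openai<1.0 and a local Ollama server exposing its OpenAI-compatible API
# at the default http://localhost:11434/v1.
import openai

openai.api_key = "ollama"                       # Ollama ignores the key, but the client requires one
openai.api_base = "http://localhost:11434/v1"   # the value that --openai_api_base forwards to init_chatgpt

response = openai.ChatCompletion.create(
    model="codellama:7b",                       # any model already pulled into the local Ollama instance
    messages=[{"role": "user", "content": "Translate to SQL: how many singers are there?"}],
    temperature=0.7,
    n=1,
)
print(response["choices"][0]["message"]["content"])
```

Because only the base URL and key change, the rest of the pipeline (prompt files, voting, evaluation) stays the same.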
ask_llm.py

Lines changed: 145 additions & 46 deletions
````diff
@@ -11,19 +11,22 @@

 from utils.post_process import process_duplication, get_sqls

+# MODIFICATION: Import the evaluation function and other necessary modules
+from eval.exec_eval import eval_exec_match
+import asyncio
+
 QUESTION_FILE = "questions.json"


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--question", type=str)
-    parser.add_argument("--openai_api_key", type=str)
+    parser.add_argument("--question", type=str, required=True)
+    parser.add_argument("--openai_api_key", type=str, required=True)
     parser.add_argument("--openai_group_id", type=str, default="org-ktBefi7n9aK7sZjwc2R9G1Wo")
+    parser.add_argument("--openai_api_base", type=str, default="", help="Custom API base URL for Ollama or other OpenAI-compatible APIs")
     parser.add_argument("--model", type=str, choices=[LLM.TEXT_DAVINCI_003,
                                                       LLM.GPT_35_TURBO,
                                                       LLM.GPT_35_TURBO_0613,
-                                                      # LLM.TONG_YI_QIAN_WEN,
-                                                      LLM.GPT_35_TURBO_16K,
                                                       LLM.GPT_4,
                                                       LLM.OLLAMA_CODELLAMA_7B,
                                                       LLM.OLLAMA_DEEPSEEK_CODER_6_7B],
@@ -33,8 +36,8 @@
     parser.add_argument("--temperature", type=float, default=0)
     parser.add_argument("--mini_index_path", type=str, default="")
     parser.add_argument("--batch_size", type=int, default=1)
-    parser.add_argument("--n", type=int, default=5, help="Size of self-consistent set")
-    parser.add_argument("--db_dir", type=str, default="dataset/spider/database")
+    parser.add_argument("--n", type=int, default=1, help="Size of self-consistent set")
+    parser.add_argument("--db_dir", type=str, default="dataset/spider/database", help="Path to the database directory")
     args = parser.parse_args()

     # check args (Ollama path currently supports only batch_size==1)
@@ -44,65 +47,79 @@
         f"{args.model} doesn't support batch_size > 1"

     questions_json = json.load(open(os.path.join(args.question, QUESTION_FILE), "r"))
-    questions = [_["prompt"] for _ in questions_json["questions"]]
-    db_ids = [_["db_id"] for _ in questions_json["questions"]]
+
+    # MODIFICATION: We need the full question objects, not just the prompts
+    all_questions_data = questions_json["questions"]

     # init openai api
-    init_chatgpt(args.openai_api_key, args.openai_group_id, args.model)
+    init_chatgpt(args.openai_api_key, args.openai_group_id, args.model, args.openai_api_base)

     if args.start_index == 0:
         mode = "w"
     else:
         mode = "a"

-    # sanitize model name for filesystem (e.g., Windows disallows ":")
     safe_model = args.model.replace(":", "_").replace("/", "_")

     if args.mini_index_path:
         mini_index = json.load(open(args.mini_index_path, 'r'))
-        questions = [questions[i] for i in mini_index]
+        # MODIFICATION: Filter the full data objects
+        all_questions_data = [all_questions_data[i] for i in mini_index]
         out_file = f"{args.question}/RESULTS_MODEL-{safe_model}_MINI.txt"
     else:
         out_file = f"{args.question}/RESULTS_MODEL-{safe_model}.txt"

-    question_loader = DataLoader(questions, batch_size=args.batch_size, shuffle=False, drop_last=False)
+    # MODIFICATION: Create evaluation results file path
+    eval_out_file = os.path.join("results", f"eval_{safe_model}.txt")
+
+    # The DataLoader will now handle dictionaries
+    question_loader = DataLoader(all_questions_data, batch_size=args.batch_size, shuffle=False, drop_last=False)
+
+    # MODIFICATION: Add counters for live evaluation
+    total_questions = 0
+    correct_predictions = 0

     token_cnt = 0
-    with open(out_file, mode) as f:
-        for i, batch in enumerate(tqdm(question_loader)):
-            if i < args.start_index:
+    with open(out_file, mode) as f, open(eval_out_file, mode) as eval_f:
+        # MODIFICATION: Enumerate provides an index starting from 0
+        for i, batch_data in enumerate(tqdm(question_loader)):
+
+            # The DataLoader might return lists of values for each key if batch_size > 1
+            # We need to reconstruct the list of dicts
+            batch_prompts = batch_data['prompt']
+
+            current_batch_index = i * args.batch_size
+            if current_batch_index < args.start_index:
                 continue
-            if i >= args.end_index:
+            if current_batch_index >= args.end_index:
                 break
+
             try:
-                res = ask_llm(args.model, batch, args.temperature, args.n)
+                res = ask_llm(args.model, batch_prompts, args.temperature, args.n)
             except openai.error.InvalidRequestError:
-                print(f"The {i}-th question has too much tokens! Return \"SELECT\" instead")
-                res = ""
+                print(f"The question batch starting at index {current_batch_index} has too many tokens! Returning empty string.")
+                res = {"response": ["" for _ in batch_prompts], "total_tokens": 0}

-            # parse result
             token_cnt += res["total_tokens"]
+
+            # Process each item in the batch
+            final_sqls_for_batch = []
             if args.n == 1:
                 for sql in res["response"]:
-                    # remove \n and extra spaces
                     sql = sql.replace("```", " ")
-                    # keep only the content starting from first SELECT if present
                     idx = sql.upper().find("SELECT")
                     if idx != -1:
                         sql = sql[idx:]
                     sql = " ".join(sql.replace("\n", " ").split())
                     sql = process_duplication(sql)
-                    # python version should >= 3.8
-                    if sql.startswith("SELECT"):
-                        f.write(sql + "\n")
-                    elif sql.startswith(" "):
-                        f.write("SELECT" + sql + "\n")
-                    else:
-                        f.write("SELECT " + sql + "\n")
-            else:
-                results = []
-                cur_db_ids = db_ids[i * args.batch_size: i * args.batch_size + len(batch)]
-                for sqls, db_id in zip(res["response"], cur_db_ids):
+                    if not sql.upper().startswith("SELECT"):
+                        sql = "SELECT " + sql
+                    final_sqls_for_batch.append(sql)
+            else: # Self-consistency voting
+                results_for_voting = []
+                db_ids_batch = batch_data['db_id']
+                for j in range(len(batch_prompts)):
+                    sqls = res["response"][j] # res["response"] is a list of lists if n > 1
                     processed_sqls = []
                     for sql in sqls:
                         sql = sql.replace("```", " ")
@@ -111,19 +128,101 @@
                             sql = sql[idx:]
                         sql = " ".join(sql.replace("\n", " ").split())
                         sql = process_duplication(sql)
-                        if sql.startswith("SELECT"):
-                            pass
-                        elif sql.startswith(" "):
-                            sql = "SELECT" + sql
-                        else:
+                        if not sql.upper().startswith("SELECT"):
                             sql = "SELECT " + sql
                         processed_sqls.append(sql)
-                    result = {
-                        'db_id': db_id,
+
+                    results_for_voting.append({
+                        'db_id': db_ids_batch[j],
                         'p_sqls': processed_sqls
-                    }
-                    final_sqls = get_sqls([result], args.n, args.db_dir)
-
-                    for sql in final_sqls:
-                        f.write(sql + "\n")
-
+                    })
+
+                final_sqls_for_batch = get_sqls(results_for_voting, args.n, args.db_dir)
+
+            # MODIFICATION: Live evaluation for each predicted SQL in the batch
+            for j, predicted_sql in enumerate(final_sqls_for_batch):
+                item_index = current_batch_index + j
+
+                # Write to file first to save the prediction
+                f.write(predicted_sql + "\n")
+
+                # Get corresponding gold query and db_id from the batch_data
+                gold_response = batch_data['response'][j]
+                gold_sql = "SELECT " + gold_response
+                db_id = batch_data['db_id'][j]
+
+                db_path = os.path.join(args.db_dir, db_id, db_id + ".sqlite")
+
+                # Perform evaluation
+                try:
+                    # eval_exec_match is not async, so we don't need to run it in an event loop here.
+                    # It handles its own asyncio.run call internally.
+                    exec_score = eval_exec_match(
+                        db=db_path,
+                        p_str=predicted_sql,
+                        g_str=gold_sql,
+                        plug_value=False,
+                        keep_distinct=False,
+                        progress_bar_for_each_datapoint=False
+                    )
+                except Exception as e:
+                    print(f"Error evaluating question {item_index}: {e}")
+                    exec_score = 0 # Consider it incorrect if evaluation fails
+
+                total_questions += 1
+                if exec_score == 1:
+                    correct_predictions += 1
+                    result_msg = f"Question {item_index} - CORRECT"
+                    print(result_msg)
+                    eval_f.write(result_msg + "\n")
+                else:
+                    result_msg = f"Question {item_index} - INCORRECT"
+                    gold_msg = f" - Gold: {gold_sql}"
+                    pred_msg = f" - Pred: {predicted_sql}"
+                    print(result_msg)
+                    print(gold_msg)
+                    print(pred_msg)
+                    eval_f.write(result_msg + "\n")
+                    eval_f.write(gold_msg + "\n")
+                    eval_f.write(pred_msg + "\n")
+
+                # Calculate and print running accuracy
+                if total_questions > 0:
+                    running_accuracy = (correct_predictions / total_questions) * 100
+                    accuracy_msg = f"Running Accuracy: {running_accuracy:.2f}% ({correct_predictions}/{total_questions})"
+                    print(accuracy_msg)
+                    eval_f.write(accuracy_msg + "\n")
+
+            # Ensure the file is written to disk after each batch
+            f.flush()
+            eval_f.flush()
+
+    # MODIFICATION: Print and save final results
+    separator = "\n" + "="*20
+    header = " FINAL RESULTS "
+    print(separator)
+    print(header)
+    print("="*20)
+
+    with open(eval_out_file, "a") as eval_f:
+        eval_f.write(separator + "\n")
+        eval_f.write(header + "\n")
+        eval_f.write("="*20 + "\n")
+
+        if total_questions > 0:
+            final_accuracy = (correct_predictions / total_questions) * 100
+            total_msg = f"Total Questions Evaluated: {total_questions}"
+            correct_msg = f"Correct Predictions: {correct_predictions}"
+            accuracy_msg = f"Final Execution Accuracy: {final_accuracy:.2f}%"
+
+            print(total_msg)
+            print(correct_msg)
+            print(accuracy_msg)
+
+            eval_f.write(total_msg + "\n")
+            eval_f.write(correct_msg + "\n")
+            eval_f.write(accuracy_msg + "\n")
+        else:
+            no_eval_msg = "No questions were evaluated."
+            print(no_eval_msg)
+            eval_f.write(no_eval_msg + "\n")
````

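The live scoring added here is execution accuracy: the predicted SQL and the gold SQL are both run against the question's SQLite database and count as a match when they return the same result. Below is a minimal sketch of that idea using `sqlite3`, assuming the Spider directory layout expected by `--db_dir`; it is not the repo's `eval_exec_match`, which additionally handles row ordering, `DISTINCT`, and value plugging (the database name in the usage comment is only illustrative).

```python
# Minimal, self-contained sketch of execution-match scoring -- an approximation of
# what eval_exec_match computes, not a drop-in replacement for it.
import sqlite3
from collections import Counter

def naive_exec_match(db_path: str, predicted_sql: str, gold_sql: str) -> int:
    """Return 1 if both queries run and produce the same multiset of rows, else 0."""
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        pred_rows = cur.execute(predicted_sql).fetchall()
        gold_rows = cur.execute(gold_sql).fetchall()
        conn.close()
    except sqlite3.Error:
        return 0  # a prediction that fails to execute is scored as incorrect

    return int(Counter(pred_rows) == Counter(gold_rows))

# Usage with the Spider layout expected by --db_dir (database name is illustrative):
# score = naive_exec_match(
#     "dataset/spider/database/concert_singer/concert_singer.sqlite",
#     "SELECT count(*) FROM singer",
#     "SELECT count(*) FROM singer",
# )
```

In the script itself, `eval_exec_match` supplies this 1/0 score per question, which drives the running accuracy printed after each question and the final accuracy written to `results/eval_{model}.txt`.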