@@ -11,19 +11,22 @@
 
 from utils.post_process import process_duplication, get_sqls
 
+# MODIFICATION: Import the evaluation function and other necessary modules
+from eval.exec_eval import eval_exec_match
+import asyncio
+
 QUESTION_FILE = "questions.json"
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--question", type=str)
-    parser.add_argument("--openai_api_key", type=str)
+    parser.add_argument("--question", type=str, required=True)
+    parser.add_argument("--openai_api_key", type=str, required=True)
     parser.add_argument("--openai_group_id", type=str, default="org-ktBefi7n9aK7sZjwc2R9G1Wo")
+    parser.add_argument("--openai_api_base", type=str, default="", help="Custom API base URL for Ollama or other OpenAI-compatible APIs")
     parser.add_argument("--model", type=str, choices=[LLM.TEXT_DAVINCI_003,
                                                       LLM.GPT_35_TURBO,
                                                       LLM.GPT_35_TURBO_0613,
-                                                      # LLM.TONG_YI_QIAN_WEN,
-                                                      LLM.GPT_35_TURBO_16K,
                                                       LLM.GPT_4,
                                                       LLM.OLLAMA_CODELLAMA_7B,
                                                       LLM.OLLAMA_DEEPSEEK_CODER_6_7B],
@@ -33,8 +36,8 @@
     parser.add_argument("--temperature", type=float, default=0)
     parser.add_argument("--mini_index_path", type=str, default="")
     parser.add_argument("--batch_size", type=int, default=1)
-    parser.add_argument("--n", type=int, default=5, help="Size of self-consistent set")
-    parser.add_argument("--db_dir", type=str, default="dataset/spider/database")
+    parser.add_argument("--n", type=int, default=1, help="Size of self-consistent set")
+    parser.add_argument("--db_dir", type=str, default="dataset/spider/database", help="Path to the database directory")
     args = parser.parse_args()
 
     # check args (Ollama path currently supports only batch_size==1)
@@ -44,65 +47,79 @@
         f"{args.model} doesn't support batch_size > 1"
 
     questions_json = json.load(open(os.path.join(args.question, QUESTION_FILE), "r"))
-    questions = [_["prompt"] for _ in questions_json["questions"]]
-    db_ids = [_["db_id"] for _ in questions_json["questions"]]
+
+    # MODIFICATION: We need the full question objects, not just the prompts
+    all_questions_data = questions_json["questions"]
 
     # init openai api
-    init_chatgpt(args.openai_api_key, args.openai_group_id, args.model)
+    init_chatgpt(args.openai_api_key, args.openai_group_id, args.model, args.openai_api_base)
 
     if args.start_index == 0:
         mode = "w"
     else:
         mode = "a"
 
-    # sanitize model name for filesystem (e.g., Windows disallows ":")
     safe_model = args.model.replace(":", "_").replace("/", "_")
 
     if args.mini_index_path:
         mini_index = json.load(open(args.mini_index_path, 'r'))
-        questions = [questions[i] for i in mini_index]
+        # MODIFICATION: Filter the full data objects
+        all_questions_data = [all_questions_data[i] for i in mini_index]
        out_file = f"{args.question}/RESULTS_MODEL-{safe_model}_MINI.txt"
     else:
         out_file = f"{args.question}/RESULTS_MODEL-{safe_model}.txt"
 
-    question_loader = DataLoader(questions, batch_size=args.batch_size, shuffle=False, drop_last=False)
+    # MODIFICATION: Create evaluation results file path
+    eval_out_file = os.path.join("results", f"eval_{safe_model}.txt")
+
+    # The DataLoader will now handle dictionaries
+    question_loader = DataLoader(all_questions_data, batch_size=args.batch_size, shuffle=False, drop_last=False)
+
+    # MODIFICATION: Add counters for live evaluation
+    total_questions = 0
+    correct_predictions = 0
 
     token_cnt = 0
-    with open(out_file, mode) as f:
-        for i, batch in enumerate(tqdm(question_loader)):
-            if i < args.start_index:
+    with open(out_file, mode) as f, open(eval_out_file, mode) as eval_f:
+        # MODIFICATION: Enumerate provides an index starting from 0
+        for i, batch_data in enumerate(tqdm(question_loader)):
+
+            # The default collate turns the list of question dicts into a dict
+            # of lists (one batched list of values per key), so we index by key
+            batch_prompts = batch_data['prompt']
+
+            current_batch_index = i * args.batch_size
+            if current_batch_index < args.start_index:
                 continue
-            if i >= args.end_index:
+            if current_batch_index >= args.end_index:
                 break
+
             try:
-                res = ask_llm(args.model, batch, args.temperature, args.n)
+                res = ask_llm(args.model, batch_prompts, args.temperature, args.n)
             except openai.error.InvalidRequestError:
-                print(f"The {i}-th question has too much tokens! Return \"SELECT\" instead")
-                res = ""
+                print(f"The question batch starting at index {current_batch_index} has too many tokens! Returning empty string.")
+                res = {"response": ["" for _ in batch_prompts], "total_tokens": 0}
 
-            # parse result
             token_cnt += res["total_tokens"]
+
+            # Process each item in the batch
+            final_sqls_for_batch = []
             if args.n == 1:
                 for sql in res["response"]:
-                    # remove \n and extra spaces
                     sql = sql.replace("```", " ")
-                    # keep only the content starting from first SELECT if present
                     idx = sql.upper().find("SELECT")
                     if idx != -1:
                         sql = sql[idx:]
                     sql = " ".join(sql.replace("\n", " ").split())
                     sql = process_duplication(sql)
-                    # python version should >= 3.8
-                    if sql.startswith("SELECT"):
-                        f.write(sql + "\n")
-                    elif sql.startswith(" "):
-                        f.write("SELECT" + sql + "\n")
-                    else:
-                        f.write("SELECT " + sql + "\n")
-            else:
-                results = []
-                cur_db_ids = db_ids[i * args.batch_size: i * args.batch_size + len(batch)]
-                for sqls, db_id in zip(res["response"], cur_db_ids):
+                    if not sql.upper().startswith("SELECT"):
+                        sql = "SELECT " + sql
+                    final_sqls_for_batch.append(sql)
+            else:  # Self-consistency voting
+                results_for_voting = []
+                db_ids_batch = batch_data['db_id']
+                for j in range(len(batch_prompts)):
+                    sqls = res["response"][j]  # res["response"] is a list of lists if n > 1
                     processed_sqls = []
                     for sql in sqls:
                         sql = sql.replace("```", " ")
@@ -111,19 +128,101 @@
                             sql = sql[idx:]
                         sql = " ".join(sql.replace("\n", " ").split())
                         sql = process_duplication(sql)
-                        if sql.startswith("SELECT"):
-                            pass
-                        elif sql.startswith(" "):
-                            sql = "SELECT" + sql
-                        else:
+                        if not sql.upper().startswith("SELECT"):
                             sql = "SELECT " + sql
                         processed_sqls.append(sql)
-                    result = {
-                        'db_id': db_id,
+
+                    results_for_voting.append({
+                        'db_id': db_ids_batch[j],
                         'p_sqls': processed_sqls
-                    }
-                    final_sqls = get_sqls([result], args.n, args.db_dir)
-
-                    for sql in final_sqls:
-                        f.write(sql + "\n")
-
+                    })
+
+                final_sqls_for_batch = get_sqls(results_for_voting, args.n, args.db_dir)
+
+            # MODIFICATION: Live evaluation for each predicted SQL in the batch
+            for j, predicted_sql in enumerate(final_sqls_for_batch):
+                item_index = current_batch_index + j
+
+                # Write to file first to save the prediction
+                f.write(predicted_sql + "\n")
+
+                # Get corresponding gold query and db_id from the batch_data
+                gold_response = batch_data['response'][j]
+                gold_sql = "SELECT " + gold_response
+                db_id = batch_data['db_id'][j]
+
+                db_path = os.path.join(args.db_dir, db_id, db_id + ".sqlite")
+
+                # Perform evaluation
+                try:
+                    # eval_exec_match is not async, so there is no need for an event
+                    # loop here; it handles its own asyncio.run call internally.
+                    exec_score = eval_exec_match(
+                        db=db_path,
+                        p_str=predicted_sql,
+                        g_str=gold_sql,
+                        plug_value=False,
+                        keep_distinct=False,
+                        progress_bar_for_each_datapoint=False
+                    )
+                except Exception as e:
+                    print(f"Error evaluating question {item_index}: {e}")
+                    exec_score = 0  # Consider it incorrect if evaluation fails
+
+                total_questions += 1
+                if exec_score == 1:
+                    correct_predictions += 1
+                    result_msg = f"Question {item_index} - CORRECT"
+                    print(result_msg)
+                    eval_f.write(result_msg + "\n")
+                else:
+                    result_msg = f"Question {item_index} - INCORRECT"
+                    gold_msg = f"  - Gold: {gold_sql}"
+                    pred_msg = f"  - Pred: {predicted_sql}"
+                    print(result_msg)
+                    print(gold_msg)
+                    print(pred_msg)
+                    eval_f.write(result_msg + "\n")
+                    eval_f.write(gold_msg + "\n")
+                    eval_f.write(pred_msg + "\n")
+
+                # Calculate and print running accuracy
+                if total_questions > 0:
+                    running_accuracy = (correct_predictions / total_questions) * 100
+                    accuracy_msg = f"Running Accuracy: {running_accuracy:.2f}% ({correct_predictions}/{total_questions})"
+                    print(accuracy_msg)
+                    eval_f.write(accuracy_msg + "\n")
+
+            # Ensure both files are written to disk after each batch
+            f.flush()
+            eval_f.flush()
+
+    # MODIFICATION: Print and save final results
+    separator = "\n" + "=" * 20
+    header = " FINAL RESULTS "
+    print(separator)
+    print(header)
+    print("=" * 20)
+
+    with open(eval_out_file, "a") as eval_f:
+        eval_f.write(separator + "\n")
+        eval_f.write(header + "\n")
+        eval_f.write("=" * 20 + "\n")
+
+        if total_questions > 0:
+            final_accuracy = (correct_predictions / total_questions) * 100
+            total_msg = f"Total Questions Evaluated: {total_questions}"
+            correct_msg = f"Correct Predictions: {correct_predictions}"
+            accuracy_msg = f"Final Execution Accuracy: {final_accuracy:.2f}%"
+
+            print(total_msg)
+            print(correct_msg)
+            print(accuracy_msg)
+
+            eval_f.write(total_msg + "\n")
+            eval_f.write(correct_msg + "\n")
+            eval_f.write(accuracy_msg + "\n")
+        else:
+            no_eval_msg = "No questions were evaluated."
+            print(no_eval_msg)
+            eval_f.write(no_eval_msg + "\n")
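
A note on the batch handling above: assuming the DataLoader here is torch's (its import is not shown in this diff), the default collate_fn turns a list of dicts into a single dict whose values are batched per key, and string fields come back as plain Python lists. That is why the loop reads batch_data['prompt'] and batch_data['db_id'][j] rather than iterating over per-question dicts. A minimal sketch of that behaviour, with made-up records:

    # Sketch only: assumes torch's DataLoader; record contents are illustrative.
    from torch.utils.data import DataLoader

    data = [
        {'db_id': 'concert_singer', 'prompt': 'q0', 'response': 'r0'},
        {'db_id': 'pets_1', 'prompt': 'q1', 'response': 'r1'},
    ]
    loader = DataLoader(data, batch_size=2, shuffle=False)
    batch = next(iter(loader))
    print(batch['prompt'])    # ['q0', 'q1'] -- a dict of lists, not a list of dicts
    print(batch['db_id'][1])  # 'pets_1'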
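For a quick smoke test of the live-evaluation path, an invocation along these lines should work; the script name, question directory, and model tag are placeholders (none of them appear in the diff), and the base URL assumes Ollama's OpenAI-compatible endpoint:

    # Hypothetical invocation; adjust the script name, paths, and model choice.
    python ask_llm.py \
        --question dataset/process/QUESTIONS_DIR \
        --openai_api_key dummy-key \
        --openai_api_base http://localhost:11434/v1 \
        --model codellama-7b \
        --n 1 \
        --batch_size 1 \
        --db_dir dataset/spider/database

Two caveats worth noting: eval_out_file points into results/, which this change never creates, so that directory must exist before the run; and in the InvalidRequestError fallback, res["response"] is filled with one empty string per prompt, which matches the n == 1 branch but not the voting branch (that branch expects a list of candidate SQLs per question), so the fallback only degrades gracefully when --n is 1.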