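"""Schema linking for the Spider validation set.

For each entry, ask an LLM to rank the schema's tables by relevance to the
question, reorder the schema text accordingly, then rank and reorder the
columns within each table. Results are written incrementally to a JSON file
so a run can be resumed from any index.
"""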
from concurrent.futures import ThreadPoolExecutor
import threading
import json
from datasets import load_dataset
from tqdm import tqdm
# import helpers
from eval.scripts.helpers import reorder_tables, reorder_columns, extract_json_object, extract_list_object, ask_chatgpt
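# As used below: reorder_tables/reorder_columns rewrite the schema text in the
# given order, extract_list_object/extract_json_object parse the LLM response,
# and ask_chatgpt sends a prompt to the chat model.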
# output path
output_path = 'data/validation_sql_ranked.json'
context_dataset = load_dataset("richardr1126/spider-context-validation", split="validation")
index = int(input('Enter the index to start at: '))
# Global lock for writing to file
lock = threading.Lock()
def process_entry(i):
    dataset_entry = context_dataset[i]
    # First stage: ask the LLM to rank the schema's tables by relevance
    prompt = f"""
Given the database schema and question, perform the following actions:
1 - Rank all the tables by how likely they are to be used in the SQL for the
question, from the most relevant to the least relevant. A table whose name or
columns match words in the question is highly relevant and must be placed
ahead.
2 - Check that you have considered all the tables.
3 - Output a list object in the order from step 1. Your output must contain
all the tables. The format should be like:
[
"table_1", "table_2", ...
]
Schema:
{dataset_entry['db_info']}
Question:
### {dataset_entry['question']}
"""
    # print(prompt)
    response = ask_chatgpt(prompt)
    # print(response)
    ranked_tables = extract_list_object(response)
    # print(ranked_tables)
    reordered_schema = reorder_tables(dataset_entry['db_info'], ranked_tables)
    # print(reordered_schema)
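    # Second stage: rank the columns within each table of the reordered schema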
    col_reorder_prompt = f"""
Given the database tables and question, perform the following actions:
1 - Rank the columns in each table by how likely they are to be used in the
SQL. A column that matches words in the question, or that is a foreign key, is
highly relevant and must be placed ahead. Output the columns from the most
relevant to the least relevant, and explain why you chose each column.
2 - Output a JSON object that contains all the columns in each table according
to your explanation. The format should be like:
{{
"table_1": ["column_1", "column_2", ......],
"table_2": ["column_1", "column_2", ......],
"table_3": ["column_1", "column_2", ......],
......
}}
Schema:
{reordered_schema}
Question:
### {dataset_entry['question']}
"""
    response = ask_chatgpt(col_reorder_prompt)
    # print(response)
    ranked_cols = extract_json_object(response)
    # print(ranked_cols)
    reordered_schema = reorder_columns(reordered_schema, ranked_cols)
    # print(reordered_schema)
    if not reordered_schema:
        # Reordering produced an empty schema; retry this entry
        process_entry(i)
        return
    output_entry = {
        "index": i,  # For debugging purposes
        "db_id": dataset_entry["db_id"],
        "question": dataset_entry["question"],
        "db_info": reordered_schema,
        "ground_truth": dataset_entry["ground_truth"],
    }
    # Record the result and rewrite the JSON file under a single lock, so the
    # in-memory update and the file write happen atomically
    with lock:
        output_dataset[i] = output_entry
        with open(output_path, 'w') as f:
            json.dump(output_dataset, f, indent=2, ensure_ascii=False)
# Resume from the existing JSON file, or start fresh if it doesn't exist
try:
    with open(output_path, 'r') as f:
        output_dataset = json.load(f)
except FileNotFoundError:
    output_dataset = [None] * len(context_dataset)  # Pre-allocate list with Nones
with ThreadPoolExecutor(max_workers=8) as executor:  # Adjust max_workers to your system's capabilities
    list(tqdm(executor.map(process_entry, range(index, len(context_dataset))), total=len(context_dataset) - index))