276 lines
7.4 KiB
Plaintext
276 lines
7.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "4ceb7b34-bb25-4dbe-9530-19f842c0920b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import json\n",
|
|
"from utils.schema import *\n",
|
|
"from utils.reverse_logic import *\n",
|
|
"from utils.reverse_sql import *\n",
|
|
"from utils.reverse_middle import *\n",
|
|
"import os, sqlite3\n",
|
|
"import random\n",
|
|
"import copy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "7a23f470-b294-4807-8974-0bd1fefb2a7e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Load the text-to-SQL examples and the Spider-style schema descriptions.\n",
"with open('raw_data/text_to_sql_data.json', 'r', encoding='utf-8') as f:\n",
"    data = json.load(f)\n",
"\n",
"with open('raw_data/tables.json', 'r', encoding='utf-8') as f:\n",
"    db = json.load(f)\n",
"# db_id -> its tables.json entry (a later duplicate db_id overwrites an earlier one).\n",
"db_dict = {item['db_id']: item for item in db}\n",
"schemas, db_names, tables = get_schemas_from_json('raw_data/tables.json')\n",
"\n",
"# Per-database name lookups built from tables.json:\n",
"#   db_reverse[db_id]['column'] -> [column_names[i][0], column_names[i][1], column_names_original[i][1]] per column\n",
"#   db_reverse[db_id]['table']  -> [table_names[i], table_names_original[i]] per table\n",
"#   db_c_ori2pre / db_t_ori2pre -> original column / table identifier mapped to its column_names / table_names name\n",
"db_reverse = {}\n",
"db_c_ori2pre = {}\n",
"db_t_ori2pre = {}\n",
"for k, v in db_dict.items():\n",
"    column = []\n",
"    db_c_ori2pre[k] = {}\n",
"    for o, c in zip(v['column_names'], v['column_names_original']):\n",
"        column.append([o[0], o[1], c[1]])\n",
"        db_c_ori2pre[k][c[1]] = o[1]\n",
"    table = []\n",
"    db_t_ori2pre[k] = {}\n",
"    for o, c in zip(v['table_names'], v['table_names_original']):\n",
"        table.append([o, c])\n",
"        db_t_ori2pre[k][c] = o\n",
"    db_reverse[k] = {'column': column, 'table': table}\n",
"\n",
"# db_aug[db_id][column_names[i][0]] -> [(i, column_name, column_type, 'primary' or 'None'), ...]\n",
"# (column_names[i][0] is the owning table index in the Spider tables.json format).\n",
"# In that format 'primary_keys' lists column INDICES, so membership must be tested by\n",
"# index, not by column name (a name never equals an index, which tagged every column 'None').\n",
"db_aug = {}\n",
"for k, v in db_dict.items():\n",
"    pk_indices = set()\n",
"    for pk in v['primary_keys']:\n",
"        # Tolerate composite keys written as nested lists of indices.\n",
"        if isinstance(pk, list):\n",
"            pk_indices.update(pk)\n",
"        else:\n",
"            pk_indices.add(pk)\n",
"    col_temp = {}\n",
"    for index, item in enumerate(v['column_names']):\n",
"        role = 'primary' if index in pk_indices else 'None'\n",
"        col_temp.setdefault(item[0], []).append((index, item[1], v['column_types'][index], role))\n",
"    db_aug[k] = col_temp"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "67f791a5-4dae-4ca2-a8c0-ac6c510e39d1",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"8314\n",
|
|
"10191\n",
|
|
"27207\n",
|
|
"34627\n",
|
|
"2719\n",
|
|
"3395\n",
|
|
"8120\n",
|
|
"8739\n",
|
|
"8740\n",
|
|
"9142\n",
|
|
"9143\n",
|
|
"9897\n",
|
|
"9898\n",
|
|
"10113\n",
|
|
"10114\n",
|
|
"10115\n",
|
|
"11053\n",
|
|
"11054\n",
|
|
"11055\n",
|
|
"11417\n",
|
|
"11418\n",
|
|
"12356\n",
|
|
"12357\n",
|
|
"14995\n",
|
|
"14996\n",
|
|
"18075\n",
|
|
"18076\n",
|
|
"18077\n",
|
|
"18108\n",
|
|
"18109\n",
|
|
"18110\n",
|
|
"18111\n",
|
|
"18164\n",
|
|
"18165\n",
|
|
"19200\n",
|
|
"19201\n",
|
|
"20287\n",
|
|
"20435\n",
|
|
"22637\n",
|
|
"26839\n",
|
|
"26840\n",
|
|
"27898\n",
|
|
"27899\n",
|
|
"27920\n",
|
|
"27921\n",
|
|
"28075\n",
|
|
"28076\n",
|
|
"28105\n",
|
|
"28106\n",
|
|
"28107\n",
|
|
"28108\n",
|
|
"29514\n",
|
|
"29530\n",
|
|
"30916\n",
|
|
"30955\n",
|
|
"31184\n",
|
|
"31218\n",
|
|
"31345\n",
|
|
"31512\n",
|
|
"32400\n",
|
|
"33131\n",
|
|
"34321\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# Databases skipped entirely.\n",
"# NOTE(review): the reason is not recorded in this notebook; presumably their queries fail to parse - confirm.\n",
"SKIP_DB_IDS = {'formula_1', 'store_1', 'scholar'}\n",
"\n",
"# Stage 1: parse each query with the schema-aware parser and attach its tables.json entry.\n",
"# Failures are reported as (index, error) and skipped.\n",
"final = []\n",
"for index, item in enumerate(data):\n",
"    db_id = item['db_id']\n",
"    if db_id in SKIP_DB_IDS:\n",
"        continue\n",
"    schema = Schema(schemas[db_id], tables[db_id])\n",
"    # 'except Exception' instead of a bare except so interrupting the kernel still stops the loop.\n",
"    try:\n",
"        sql_label = get_sql(schema, item['query'])\n",
"        # Exclude queries having a WHERE condition whose slot 3 is a tuple\n",
"        # (presumably a column reference rather than a literal - confirm against utils.schema).\n",
"        # The plain strings 'and' / 'or' in the list are connectives, not conditions.\n",
"        has_tuple_value = any(\n",
"            isinstance(cond[3], tuple)\n",
"            for cond in sql_label['where']\n",
"            if cond != 'and' and cond != 'or'\n",
"        )\n",
"        if not has_tuple_value:\n",
"            temp = {'struct': sql_label, 'query': item['query'], 'question': item['question']}\n",
"            temp.update(db_dict[db_id])\n",
"            final.append(temp)\n",
"    except Exception as exc:\n",
"        print(index, repr(exc))\n",
"\n",
"# Stage 2: convert each parsed query into the intermediate form, its logic rendering and a\n",
"# regenerated SQL string; keep only the first example for each regenerated SQL string.\n",
"data_final = []\n",
"sql_all = set()  # set membership keeps de-duplication linear (a list made it quadratic)\n",
"for index, item in enumerate(final):\n",
"    struct = item['struct']\n",
"    try:\n",
"        # Skip multi-table FROM clauses that carry no join conditions.\n",
"        if len(struct['from']['table_units']) > 1 and len(struct['from']['conds']) == 0:\n",
"            continue\n",
"        a = reverse(struct, item['column_names'])\n",
"        logic = translate_logic(a, 0, item['db_id'], db_reverse)\n",
"        sql = translate_sql(a, 0, item['db_id'], db_reverse)\n",
"    except Exception as exc:\n",
"        print(index, repr(exc))\n",
"        continue\n",
"    if sql in sql_all:\n",
"        continue\n",
"    sql_all.add(sql)\n",
"    data_final.append({\n",
"        'db_id': item['db_id'],\n",
"        'query_ori': item['query'],\n",
"        'struct': struct,\n",
"        'logic': logic,\n",
"        'query': sql,\n",
"        'middle': a,\n",
"        'question': item['question'],\n",
"    })"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "83e86a90-3871-4b2e-ac6f-c32f93b19731",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"15281"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Number of examples kept after parsing, filtering and de-duplication.\n",
"len(data_final)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "097df3ae-b9fc-4f38-95ca-3ffc22355cbc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Write three views of the processed examples under preprocessed/.\n",
"# The output directory may not exist in a fresh checkout; open(..., 'w') does not create it.\n",
"os.makedirs('preprocessed', exist_ok=True)\n",
"\n",
"# question_sql.json: list of {'question', 'query'} pairs.\n",
"data_raw = [{'question': item['question'], 'query': item['query']} for item in data_final]\n",
"# logic.json: regenerated SQL string -> logic form (data_final is de-duplicated by query upstream).\n",
"data_prep = {item['query']: item['logic'] for item in data_final}\n",
"\n",
"with open('preprocessed/question_sql.json', 'w', encoding='utf-8') as f:\n",
"    json.dump(data_raw, f)\n",
"with open('preprocessed/logic.json', 'w', encoding='utf-8') as f:\n",
"    json.dump(data_prep, f)\n",
"# alldata.json: every field kept for each example.\n",
"with open('preprocessed/alldata.json', 'w', encoding='utf-8') as f:\n",
"    json.dump(data_final, f)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4af6facf-5c80-45e3-aba3-6ad377c7860f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|