DAMO-ConvAI/star/data_systhesis/preprocess.ipynb

276 lines
7.4 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4ceb7b34-bb25-4dbe-9530-19f842c0920b",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from utils.schema import *\n",
"from utils.reverse_logic import *\n",
"from utils.reverse_sql import *\n",
"from utils.reverse_middle import *\n",
"import os, sqlite3\n",
"import random\n",
"import copy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7a23f470-b294-4807-8974-0bd1fefb2a7e",
"metadata": {},
"outputs": [],
"source": [
"# Load the text-to-SQL examples and the raw schema descriptions from raw_data/.\n",
"with open('raw_data/text_to_sql_data.json','r') as data_file:\n",
"    data = json.load(data_file)\n",
"\n",
"with open('raw_data/tables.json','r') as tables_file:\n",
"    db = json.load(tables_file)\n",
"\n",
"# The same tables file is parsed a second time by the project helper; its three\n",
"# return values are kept under the names the later cells look up by db_id.\n",
"schemas, db_names, tables = get_schemas_from_json('raw_data/tables.json')\n",
"\n",
"# Index the raw schema entries by database id (a repeated id keeps the last entry).\n",
"db_dict = {entry['db_id']: entry for entry in db}\n",
"\n",
"# Per-database lookup tables derived from db_dict:\n",
"#   db_reverse[db]['column'] -> [entry[0], entry[1], original_entry[1]] per column,\n",
"#                               where entry comes from column_names and\n",
"#                               original_entry from column_names_original\n",
"#   db_reverse[db]['table']  -> [table_names value, table_names_original value]\n",
"#   db_c_ori2pre[db]         -> original column name -> column_names name\n",
"#   db_t_ori2pre[db]         -> original table name  -> table_names name\n",
"db_reverse = {}\n",
"db_c_ori2pre = {}\n",
"db_t_ori2pre = {}\n",
"for db_name, spec in db_dict.items():\n",
"    col_pairs = list(zip(spec['column_names'], spec['column_names_original']))\n",
"    tab_pairs = list(zip(spec['table_names'], spec['table_names_original']))\n",
"    db_reverse[db_name] = {\n",
"        'column': [[pre[0], pre[1], ori[1]] for pre, ori in col_pairs],\n",
"        'table': [[pre, ori] for pre, ori in tab_pairs],\n",
"    }\n",
"    # Keyed by the bare original column name, so equally named columns of\n",
"    # different tables share one slot and the last one seen wins.\n",
"    db_c_ori2pre[db_name] = {ori[1]: pre[1] for pre, ori in col_pairs}\n",
"    db_t_ori2pre[db_name] = {ori: pre for pre, ori in tab_pairs}\n",
"# db_aug[db][entry[0]] -> list of (column index, column name, column type, key tag),\n",
"# grouping the columns of each database by the first field of their column_names\n",
"# entry. The key tag is the string 'primary' or the string 'None'.\n",
"db_aug = {}\n",
"for db_name, spec in db_dict.items():\n",
"    cols_by_group = {}\n",
"    for index, col in enumerate(spec['column_names']):\n",
"        # NOTE(review): primary_keys is assumed to list column indices (Spider\n",
"        # tables.json layout), so membership is tested on the index. The previous\n",
"        # test used the column name and therefore never produced 'primary'.\n",
"        key_tag = 'primary' if index in spec['primary_keys'] else 'None'\n",
"        cols_by_group.setdefault(col[0], []).append(\n",
"            (index, col[1], spec['column_types'][index], key_tag)\n",
"        )\n",
"    db_aug[db_name] = cols_by_group"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "67f791a5-4dae-4ca2-a8c0-ac6c510e39d1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8314\n",
"10191\n",
"27207\n",
"34627\n",
"2719\n",
"3395\n",
"8120\n",
"8739\n",
"8740\n",
"9142\n",
"9143\n",
"9897\n",
"9898\n",
"10113\n",
"10114\n",
"10115\n",
"11053\n",
"11054\n",
"11055\n",
"11417\n",
"11418\n",
"12356\n",
"12357\n",
"14995\n",
"14996\n",
"18075\n",
"18076\n",
"18077\n",
"18108\n",
"18109\n",
"18110\n",
"18111\n",
"18164\n",
"18165\n",
"19200\n",
"19201\n",
"20287\n",
"20435\n",
"22637\n",
"26839\n",
"26840\n",
"27898\n",
"27899\n",
"27920\n",
"27921\n",
"28075\n",
"28076\n",
"28105\n",
"28106\n",
"28107\n",
"28108\n",
"29514\n",
"29530\n",
"30916\n",
"30955\n",
"31184\n",
"31218\n",
"31345\n",
"31512\n",
"32400\n",
"33131\n",
"34321\n"
]
}
],
"source": [
"# Pass 1: parse every query with get_sql and keep the examples whose WHERE\n",
"# conditions pass the filter below. Each kept record holds the parsed structure,\n",
"# the query, the question and a copy of every field of the database's schema entry.\n",
"final = []\n",
"for index, item in enumerate(data):\n",
"    db_id = item['db_id']\n",
"    # These three databases are excluded up front.\n",
"    if db_id in ['formula_1','store_1','scholar']:\n",
"        continue\n",
"    sql = item['query']\n",
"    schema = Schema(schemas[db_id], tables[db_id])\n",
"    try:\n",
"        temp = {'struct': get_sql(schema, sql)}\n",
"        # Reject the example when any WHERE condition carries a tuple in slot 3\n",
"        # (presumably a column compared against another column - TODO confirm).\n",
"        flag = not any(\n",
"            cond not in ('and', 'or') and isinstance(cond[3], tuple)\n",
"            for cond in temp['struct']['where']\n",
"        )\n",
"\n",
"        temp['query'] = item['query']\n",
"        temp['question'] = item['question']\n",
"        temp.update(db_dict[db_id])\n",
"        if flag:\n",
"            final.append(temp)\n",
"    except Exception:\n",
"        # Best effort: report the index of the failing example and carry on.\n",
"        # Catching Exception instead of everything keeps a kernel interrupt working.\n",
"        print(index)\n",
"# Pass 2: run each parsed query through the project's reverse / translate_logic /\n",
"# translate_sql helpers and keep one record per distinct regenerated SQL string.\n",
"data_final = []\n",
"sql_all = []\n",
"seen_sql = set()  # constant-time membership mirror of sql_all\n",
"for index, item in enumerate(final):\n",
"    try:\n",
"        from_clause = item['struct']['from']\n",
"        # Skip FROM clauses that list several tables but no join condition.\n",
"        if len(from_clause['table_units']) > 1 and len(from_clause['conds']) == 0:\n",
"            continue\n",
"        a = reverse(item['struct'], item['column_names'])\n",
"        logic = translate_logic(a, 0, item['db_id'], db_reverse)\n",
"        sql = translate_sql(a, 0, item['db_id'], db_reverse)\n",
"        temp = {\n",
"            'db_id': item['db_id'],\n",
"            'query_ori': item['query'],\n",
"            'struct': item['struct'],\n",
"            'logic': logic,\n",
"            'query': sql,\n",
"            'middle': a,\n",
"            'question': item['question'],\n",
"        }\n",
"        # First occurrence of a regenerated query wins; later duplicates are dropped.\n",
"        if sql not in seen_sql:\n",
"            seen_sql.add(sql)\n",
"            sql_all.append(sql)\n",
"            data_final.append(temp)\n",
"    except Exception:\n",
"        # Best effort: report the index (into final) of the failing record and carry on.\n",
"        # Catching Exception instead of everything keeps a kernel interrupt working.\n",
"        print(index)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "83e86a90-3871-4b2e-ac6f-c32f93b19731",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"15281"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(data_final)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "097df3ae-b9fc-4f38-95ca-3ffc22355cbc",
"metadata": {},
"outputs": [],
"source": [
"# Write three views of data_final under preprocessed/.\n",
"# Create the folder first so a missing directory cannot fail the writes below.\n",
"os.makedirs('preprocessed', exist_ok=True)\n",
"\n",
"# question / regenerated-SQL pairs only\n",
"data_raw = [\n",
"    {'question': item['question'], 'query': item['query']}\n",
"    for item in data_final\n",
"]\n",
"# regenerated SQL -> logic form\n",
"data_prep = {item['query']: item['logic'] for item in data_final}\n",
"\n",
"with open('preprocessed/question_sql.json','w') as f:\n",
"    json.dump(data_raw,f)\n",
"with open('preprocessed/logic.json','w') as f:\n",
"    json.dump(data_prep,f)\n",
"# full records, including the parsed structure and the intermediate form\n",
"with open('preprocessed/alldata.json','w') as f:\n",
"    json.dump(data_final,f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4af6facf-5c80-45e3-aba3-6ad377c7860f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}