add: r2sql repo
This commit is contained in:
parent d3eba969df
commit 2659e4dd89

@@ -0,0 +1,145 @@
# R²SQL

The PyTorch implementation of the paper [Dynamic Hybrid Relation Network for Cross-Domain Context-Dependent Semantic Parsing](https://arxiv.org/pdf/2101.01686) (AAAI 2021).

## Requirements

The model is tested with Python 3.6 and the following requirements:

```
torch==1.0.0
transformers==2.10.0
sqlparse
pymysql
progressbar
nltk
numpy
six
spacy
```

All experiments on the SParC and CoSQL datasets were run on an NVIDIA V100 GPU with 32GB of GPU memory.

* Tip: a 16GB GPU may run out of memory.

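To quickly confirm that your environment matches these pins before launching the scripts, a small check like the following can be used (an illustrative sketch, not part of the repository):

```python
# Minimal environment sanity check (illustrative; not part of the repo).
import torch
import transformers

print("torch:", torch.__version__)                 # pinned: 1.0.0
print("transformers:", transformers.__version__)   # pinned: 2.10.0
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # The reported experiments used a 32GB V100; 16GB cards may run out of memory.
    props = torch.cuda.get_device_properties(0)
    print("GPU:", props.name, "%.0f GB" % (props.total_memory / 1024 ** 3))
```
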
## Setup

The SParC and CoSQL experiments live in two separate folders, so you need to download the corresponding dataset from [[SParC](https://yale-lily.github.io/spider) | [CoSQL](https://yale-lily.github.io/cosql)] into the `{sparc|cosql}/data` folder of each.

Other related data files can be downloaded from [EditSQL](https://github.com/ryanzhumich/editsql/tree/master/data).

Then, download the database SQLite files from [[here](https://drive.google.com/file/d/1a828mkHcgyQCBgVla0jGxKJ58aV8RsYK/view?usp=sharing)] into `data/database`.

Download the pretrained BERT model from [[here](https://drive.google.com/file/d/1f_LEWVgrtZLRuoiExJa5fNzTS8-WcAX9/view?usp=sharing)] and save it as `model/bert/data/annotated_wikisql_and_PyTorch_bert_param/pytorch_model_uncased_L-12_H-768_A-12.bin`.

Download the GloVe embeddings file (`glove.840B.300d.txt`) and change `GLOVE_PATH` to your own path in all scripts.

Download the reranker models from [[SParC reranker](https://drive.google.com/file/d/1cA106xgSx6KeonOxD2sZ06Eolptxt_OG/view?usp=sharing) | [CoSQL reranker](https://drive.google.com/file/d/1UURYw15T6zORcYRTvP51MYkzaxNmvRIU/view?usp=sharing)] and save them as `submit_models/reranker_roberta.pt`; in addition, the roberta-base model can be downloaded from [here](https://drive.google.com/file/d/1LkTe-Z0AFg2dAAWgUKuCLEhSmtW-CWXh/view?usp=sharing) into `./[sparc|cosql]/local_param/`.

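As a quick way to confirm that the downloads above ended up in the expected locations, a check along these lines can help (an illustrative sketch; `GLOVE_PATH` is a placeholder for your own path):

```python
# Illustrative check that the files described in Setup are in place (not part of the repo).
import os

GLOVE_PATH = "/path/to/glove.840B.300d.txt"  # change this, as in the training scripts

expected = [
    "data/database",
    "model/bert/data/annotated_wikisql_and_PyTorch_bert_param/"
    "pytorch_model_uncased_L-12_H-768_A-12.bin",
    "submit_models/reranker_roberta.pt",
    GLOVE_PATH,
]

for path in expected:
    print("[{}] {}".format("ok" if os.path.exists(path) else "MISSING", path))
```
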
## Usage

Train the model from scratch:

```bash
./sparc_train.sh
```

Test the model with a concrete checkpoint:

```bash
./sparc_test.sh
```

The dev prediction file will then appear in the `results` folder, named like `save_%d_predictions.json`.

Get the evaluation result from the prediction file:

```bash
./sparc_evaluate.sh
```

The final result will appear in the `results` folder, named `*.eval`.

The CoSQL experiments can be reproduced in the same way.

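Since the prediction files follow the `save_%d_predictions.json` naming convention, a small helper like the one below (illustrative only, not part of the repo) can locate the predictions of the newest checkpoint in `results`:

```python
# Illustrative helper: find the prediction file with the highest checkpoint id in `results`.
import glob
import re

def latest_prediction_file(results_dir="results"):
    best_id, best_path = -1, None
    for path in glob.glob("{}/save_*_predictions.json".format(results_dir)):
        match = re.search(r"save_(\d+)_predictions\.json$", path)
        if match and int(match.group(1)) > best_id:
            best_id, best_path = int(match.group(1)), path
    return best_path

print(latest_prediction_file())
```
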
---

You can download our trained checkpoints and results here:

* SParC: [[log](https://drive.google.com/file/d/19ySQ_4x3R-T0cML2uJQBaYI2EyTlPr1G/view?usp=sharing) | [results](https://drive.google.com/file/d/12-kTEnNJKKblPDx5UIz5W0lVvf_sWpyS/view?usp=sharing)]
* CoSQL: [[log](https://drive.google.com/file/d/1QaxM8AUu3cQUXIZvCgoqW115tZCcEppl/view?usp=sharing) | [results](https://drive.google.com/file/d/1fCTRagV46gvEKU5XPje0Um69rMkEAztU/view?usp=sharing)]

### Reranker

If you want to train your own reranker model, you can download the training data from here:

* SParC: [[reranker training data](https://drive.google.com/file/d/1XEiYUmDsVGouCO6NZS1yyMkUDxvWgCZ9/view?usp=sharing)]
* CoSQL: [[reranker training data](https://drive.google.com/file/d/1mzjywnMiABOTHYC9BWOoUOn4HnokcX8i/view?usp=sharing)]

Then you can train, test, and predict with it:

train:

```bash
python -m reranker.main --train --batch_size 64 --epoches 50
```

test:

```bash
python -m reranker.main --test --batch_size 64
```

predict:

```bash
python -m reranker.predict
```

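If you prefer to drive these three steps from Python instead of the shell, a thin wrapper around the exact commands above could look like this (a sketch only; it adds nothing beyond the documented flags):

```python
# Thin wrapper around the documented reranker commands (illustrative sketch).
import subprocess

def run(cmd):
    print(">>", " ".join(cmd))
    subprocess.run(cmd, check=True)

run(["python", "-m", "reranker.main", "--train", "--batch_size", "64", "--epoches", "50"])
run(["python", "-m", "reranker.main", "--test", "--batch_size", "64"])
run(["python", "-m", "reranker.predict"])
```
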
## Improvements

We have improved the original version (described in the paper) and obtained further performance gains :partying_face:!

Compared with the original version, we made the following improvements:

* Added a self-ensemble strategy for prediction, which uses checkpoints from different epochs to produce the final result (a conceptual sketch follows this list). To make this strategy easy to apply, we removed the task-related representation from the Reranker module.
* Removed the decay function in DCRI; we found DCRI to be unstable with the decay function, so we let DCRI degenerate into vanilla cross attention.
* Replaced the BERT-based model with a RoBERTa-based model in the Reranker module.

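The ensemble evaluation script simply passes several `save_*_predictions.json` files to `postprocess_eval.py`; conceptually, self-ensembling combines the per-checkpoint predictions for each example. The sketch below illustrates the idea with a simple majority vote (the data layout and tie-breaking here are assumptions for illustration; the actual merging lives in the evaluation code):

```python
# Conceptual sketch of self-ensembling over several checkpoints (illustrative only).
from collections import Counter

def ensemble(per_checkpoint_predictions):
    """per_checkpoint_predictions: one list of predicted SQL strings per checkpoint."""
    ensembled = []
    for candidates in zip(*per_checkpoint_predictions):
        # Majority vote across checkpoints; ties keep the candidate seen first.
        ensembled.append(Counter(candidates).most_common(1)[0][0])
    return ensembled

preds_14 = ["select name from singer", "select count ( * ) from concert"]
preds_9 = ["select name from singer", "select count ( * ) from singer"]
preds_6 = ["select name from singer", "select count ( * ) from concert"]
print(ensemble([preds_14, preds_9, preds_6]))
```
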
The final performance comparison on dev is as follows (QM = question match accuracy, IM = interaction match accuracy):

<table>
  <tr><th></th><th colspan="2">SParC</th><th colspan="2">CoSQL</th></tr>
  <tr><td></td><td>QM</td><td>IM</td><td>QM</td><td>IM</td></tr>
  <tr><td>EditSQL</td><td>47.2</td><td>29.5</td><td>39.9</td><td>12.3</td></tr>
  <tr><td>R²SQL v1 (original paper)</td><td>54.1</td><td>35.2</td><td>45.7</td><td>19.5</td></tr>
  <tr><td>R²SQL v2 (this repo)</td><td>54.0</td><td>35.2</td><td>46.3</td><td>19.5</td></tr>
  <tr><td>R²SQL v2 + ensemble</td><td>55.1</td><td>36.8</td><td>47.3</td><td>20.9</td></tr>
</table>

## Citation

Please star this repo and cite our paper if you use it in your work.

## Acknowledgments

This implementation is based on ["Editing-Based SQL Query Generation for Cross-Domain Context-Dependent Questions"](https://github.com/ryanzhumich/editsql) (EMNLP 2019).

@@ -0,0 +1 @@
save_* filter=lfs diff=lfs merge=lfs -text
@@ -0,0 +1,4 @@
*pyc
data/*
logs_spider_editsql/*
processed_data_spider_removefrom/*
@@ -0,0 +1,4 @@
#! /bin/bash

# 3. get evaluation result
python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file results/save_14_predictions.json --remove_from
@@ -0,0 +1,4 @@
#! /bin/bash

# 3. get evaluation result
python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file results/save_14_predictions.json,results/save_9_predictions.json,results/save_6_predictions.json,results/save_3_predictions.json --remove_from
@@ -0,0 +1,36 @@
|
|||
#! /bin/bash
|
||||
|
||||
GLOVE_PATH="/dev/shm/glove_embeddings.pkl" # you need to change this
|
||||
LOGDIR="log/cosql"
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
|
||||
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
|
||||
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
|
||||
--embedding_filename=$GLOVE_PATH \
|
||||
--data_directory="processed_data_cosql_removefrom" \
|
||||
--input_key="utterance" \
|
||||
--state_positional_embeddings=1 \
|
||||
--discourse_level_lstm=1 \
|
||||
--use_utterance_attention=1 \
|
||||
--use_previous_query=1 \
|
||||
--use_query_attention=1 \
|
||||
--use_copy_switch=1 \
|
||||
--use_schema_encoder=1 \
|
||||
--use_schema_attention=1 \
|
||||
--use_encoder_attention=1 \
|
||||
--use_bert=1 \
|
||||
--bert_type_abb=uS \
|
||||
--fine_tune_bert=1 \
|
||||
--use_schema_self_attention=1 \
|
||||
--use_schema_encoder_2=1 \
|
||||
--interaction_level=1 \
|
||||
--reweight_batch=1 \
|
||||
--freeze=1 \
|
||||
--logdir=$LOGDIR \
|
||||
--evaluate=1 \
|
||||
--evaluate_split="valid" \
|
||||
--use_predicted_queries=1 \
|
||||
--save_file="$LOGDIR/save_14"
|
||||
|
||||
# 3. get evaluation result
|
||||
#python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from
|
|
@@ -0,0 +1,125 @@
|
|||
#! /bin/bash
|
||||
|
||||
GLOVE_PATH="/dev/shm/glove_embeddings.pkl" # you need to change this
|
||||
LOGDIR="log/cosql"
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
|
||||
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
|
||||
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
|
||||
--embedding_filename=$GLOVE_PATH \
|
||||
--data_directory="processed_data_cosql_removefrom" \
|
||||
--input_key="utterance" \
|
||||
--state_positional_embeddings=1 \
|
||||
--discourse_level_lstm=1 \
|
||||
--use_utterance_attention=1 \
|
||||
--use_previous_query=1 \
|
||||
--use_query_attention=1 \
|
||||
--use_copy_switch=1 \
|
||||
--use_schema_encoder=1 \
|
||||
--use_schema_attention=1 \
|
||||
--use_encoder_attention=1 \
|
||||
--use_bert=1 \
|
||||
--bert_type_abb=uS \
|
||||
--fine_tune_bert=1 \
|
||||
--use_schema_self_attention=1 \
|
||||
--use_schema_encoder_2=1 \
|
||||
--interaction_level=1 \
|
||||
--reweight_batch=1 \
|
||||
--freeze=1 \
|
||||
--logdir=$LOGDIR \
|
||||
--evaluate=1 \
|
||||
--evaluate_split="valid" \
|
||||
--use_predicted_queries=1 \
|
||||
--save_file="$LOGDIR/save_3"
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
|
||||
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
|
||||
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
|
||||
--embedding_filename=$GLOVE_PATH \
|
||||
--data_directory="processed_data_cosql_removefrom" \
|
||||
--input_key="utterance" \
|
||||
--state_positional_embeddings=1 \
|
||||
--discourse_level_lstm=1 \
|
||||
--use_utterance_attention=1 \
|
||||
--use_previous_query=1 \
|
||||
--use_query_attention=1 \
|
||||
--use_copy_switch=1 \
|
||||
--use_schema_encoder=1 \
|
||||
--use_schema_attention=1 \
|
||||
--use_encoder_attention=1 \
|
||||
--use_bert=1 \
|
||||
--bert_type_abb=uS \
|
||||
--fine_tune_bert=1 \
|
||||
--use_schema_self_attention=1 \
|
||||
--use_schema_encoder_2=1 \
|
||||
--interaction_level=1 \
|
||||
--reweight_batch=1 \
|
||||
--freeze=1 \
|
||||
--logdir=$LOGDIR \
|
||||
--evaluate=1 \
|
||||
--evaluate_split="valid" \
|
||||
--use_predicted_queries=1 \
|
||||
--save_file="$LOGDIR/save_6"
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
|
||||
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
|
||||
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
|
||||
--embedding_filename=$GLOVE_PATH \
|
||||
--data_directory="processed_data_cosql_removefrom" \
|
||||
--input_key="utterance" \
|
||||
--state_positional_embeddings=1 \
|
||||
--discourse_level_lstm=1 \
|
||||
--use_utterance_attention=1 \
|
||||
--use_previous_query=1 \
|
||||
--use_query_attention=1 \
|
||||
--use_copy_switch=1 \
|
||||
--use_schema_encoder=1 \
|
||||
--use_schema_attention=1 \
|
||||
--use_encoder_attention=1 \
|
||||
--use_bert=1 \
|
||||
--bert_type_abb=uS \
|
||||
--fine_tune_bert=1 \
|
||||
--use_schema_self_attention=1 \
|
||||
--use_schema_encoder_2=1 \
|
||||
--interaction_level=1 \
|
||||
--reweight_batch=1 \
|
||||
--freeze=1 \
|
||||
--logdir=$LOGDIR \
|
||||
--evaluate=1 \
|
||||
--evaluate_split="valid" \
|
||||
--use_predicted_queries=1 \
|
||||
--save_file="$LOGDIR/save_9"
|
||||
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
|
||||
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
|
||||
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
|
||||
--embedding_filename=$GLOVE_PATH \
|
||||
--data_directory="processed_data_cosql_removefrom" \
|
||||
--input_key="utterance" \
|
||||
--state_positional_embeddings=1 \
|
||||
--discourse_level_lstm=1 \
|
||||
--use_utterance_attention=1 \
|
||||
--use_previous_query=1 \
|
||||
--use_query_attention=1 \
|
||||
--use_copy_switch=1 \
|
||||
--use_schema_encoder=1 \
|
||||
--use_schema_attention=1 \
|
||||
--use_encoder_attention=1 \
|
||||
--use_bert=1 \
|
||||
--bert_type_abb=uS \
|
||||
--fine_tune_bert=1 \
|
||||
--use_schema_self_attention=1 \
|
||||
--use_schema_encoder_2=1 \
|
||||
--interaction_level=1 \
|
||||
--reweight_batch=1 \
|
||||
--freeze=1 \
|
||||
--logdir=$LOGDIR \
|
||||
--evaluate=1 \
|
||||
--evaluate_split="valid" \
|
||||
--use_predicted_queries=1 \
|
||||
--save_file="$LOGDIR/save_14"
|
||||
|
||||
|
||||
# 3. get evaluation result
|
||||
#python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from
|
|
@@ -0,0 +1,40 @@
|
|||
#! /bin/bash
|
||||
|
||||
# 1. preprocess dataset by the following. It will produce data/cosql_data_removefrom/
|
||||
python3 preprocess.py --dataset=cosql --remove_from
|
||||
|
||||
# 2. train and evaluate.
|
||||
# the result (models, logs, prediction outputs) are saved in $LOGDIR
|
||||
|
||||
LOGDIR='log/cosql'
|
||||
GLOVE_PATH="/dev/shm/glove_embeddings.pkl" # you need to change this
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
|
||||
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
|
||||
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
|
||||
--embedding_filename=$GLOVE_PATH \
|
||||
--data_directory="processed_data_cosql_removefrom" \
|
||||
--input_key="utterance" \
|
||||
--state_positional_embeddings=1 \
|
||||
--discourse_level_lstm=1 \
|
||||
--use_utterance_attention=1 \
|
||||
--use_previous_query=1 \
|
||||
--use_query_attention=1 \
|
||||
--use_copy_switch=1 \
|
||||
--use_schema_encoder=1 \
|
||||
--use_schema_attention=1 \
|
||||
--use_encoder_attention=1 \
|
||||
--use_bert=1 \
|
||||
--bert_type_abb=uS \
|
||||
--fine_tune_bert=1 \
|
||||
--use_schema_self_attention=1 \
|
||||
--use_schema_encoder_2=1 \
|
||||
--interaction_level=1 \
|
||||
--reweight_batch=1 \
|
||||
--freeze=1 \
|
||||
--train=1 \
|
||||
--logdir=$LOGDIR \
|
||||
--evaluate=1 \
|
||||
--evaluate_split="valid" \
|
||||
--use_predicted_queries=1 \
|
||||
--seed=484
|
|
@@ -0,0 +1,262 @@
|
|||
"""Code for identifying and anonymizing entities in NL and SQL."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from . import util
|
||||
|
||||
ENTITY_NAME = "ENTITY"
|
||||
CONSTANT_NAME = "CONSTANT"
|
||||
TIME_NAME = "TIME"
|
||||
SEPARATOR = "#"
|
||||
|
||||
|
||||
def timeval(string):
|
||||
"""Returns the numeric version of a time.
|
||||
|
||||
Inputs:
|
||||
string (str): String representing a time.
|
||||
|
||||
Returns:
|
||||
String representing the absolute time.
|
||||
"""
|
||||
if string.endswith("am") or string.endswith(
|
||||
"pm") and string[:-2].isdigit():
|
||||
numval = int(string[:-2])
|
||||
if len(string) == 3 or len(string) == 4:
|
||||
numval *= 100
|
||||
if string.endswith("pm"):
|
||||
numval += 1200
|
||||
return str(numval)
|
||||
return ""
|
||||
|
||||
|
||||
def is_time(string):
|
||||
"""Returns whether a string represents a time.
|
||||
|
||||
Inputs:
|
||||
string (str): String to check.
|
||||
|
||||
Returns:
|
||||
Whether the string represents a time.
|
||||
"""
|
||||
if string.endswith("am") or string.endswith("pm"):
|
||||
if string[:-2].isdigit():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def deanonymize(sequence, ent_dict, key):
|
||||
"""Deanonymizes a sequence.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): List of tokens to deanonymize.
|
||||
ent_dict (dict str->(dict str->str)): Maps from tokens to the entity dictionary.
|
||||
key (str): The key to use, in this case either natural language or SQL.
|
||||
|
||||
Returns:
|
||||
Deanonymized sequence of tokens.
|
||||
"""
|
||||
new_sequence = []
|
||||
for token in sequence:
|
||||
if token in ent_dict:
|
||||
new_sequence.extend(ent_dict[token][key])
|
||||
else:
|
||||
new_sequence.append(token)
|
||||
|
||||
return new_sequence
|
||||
|
||||
|
||||
class Anonymizer:
|
||||
"""Anonymization class for keeping track of entities in this domain and
|
||||
scripts for anonymizing/deanonymizing.
|
||||
|
||||
Members:
|
||||
anonymization_map (list of dict (str->str)): Containing entities from
|
||||
the anonymization file.
|
||||
entity_types (list of str): All entities in the anonymization file.
|
||||
keys (set of str): Possible keys (types of text handled); in this case it should be
|
||||
one for natural language and another for SQL.
|
||||
entity_set (set of str): entity_types as a set.
|
||||
"""
|
||||
def __init__(self, filename):
|
||||
self.anonymization_map = []
|
||||
self.entity_types = []
|
||||
self.keys = set()
|
||||
|
||||
pairs = [json.loads(line) for line in open(filename).readlines()]
|
||||
for pair in pairs:
|
||||
for key in pair:
|
||||
if key != "type":
|
||||
self.keys.add(key)
|
||||
self.anonymization_map.append(pair)
|
||||
if pair["type"] not in self.entity_types:
|
||||
self.entity_types.append(pair["type"])
|
||||
|
||||
self.entity_types.append(ENTITY_NAME)
|
||||
self.entity_types.append(CONSTANT_NAME)
|
||||
self.entity_types.append(TIME_NAME)
|
||||
|
||||
self.entity_set = set(self.entity_types)
|
||||
|
||||
def get_entity_type_from_token(self, token):
|
||||
"""Gets the type of an entity given an anonymized token.
|
||||
|
||||
Inputs:
|
||||
token (str): The entity token.
|
||||
|
||||
Returns:
|
||||
str, representing the type of the entity.
|
||||
"""
|
||||
        # Anonymized tokens follow the pattern TYPE#id, so keep only the part
        # before the separator.
|
||||
colon_loc = token.index(SEPARATOR)
|
||||
entity_type = token[:colon_loc]
|
||||
assert entity_type in self.entity_set
|
||||
|
||||
return entity_type
|
||||
|
||||
def is_anon_tok(self, token):
|
||||
"""Returns whether a token is an anonymized token or not.
|
||||
|
||||
Input:
|
||||
token (str): The token to check.
|
||||
|
||||
Returns:
|
||||
bool, whether the token is an anonymized token.
|
||||
"""
|
||||
return token.split(SEPARATOR)[0] in self.entity_set
|
||||
|
||||
def get_anon_id(self, token):
|
||||
"""Gets the entity index (unique ID) for a token.
|
||||
|
||||
Input:
|
||||
token (str): The token to get the index from.
|
||||
|
||||
Returns:
|
||||
int, the token ID if it is an anonymized token; otherwise -1.
|
||||
"""
|
||||
if self.is_anon_tok(token):
|
||||
return self.entity_types.index(token.split(SEPARATOR)[0])
|
||||
else:
|
||||
return -1
|
||||
|
||||
def anonymize(self,
|
||||
sequence,
|
||||
tok_to_entity_dict,
|
||||
key,
|
||||
add_new_anon_toks=False):
|
||||
"""Anonymizes a sequence.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): Sequence to anonymize.
|
||||
tok_to_entity_dict (dict): Existing dictionary mapping from anonymized
|
||||
tokens to entities.
|
||||
key (str): Which kind of text this is (natural language or SQL)
|
||||
add_new_anon_toks (bool): Whether to add new entities to tok_to_entity_dict.
|
||||
|
||||
Returns:
|
||||
list of str, the anonymized sequence.
|
||||
"""
|
||||
# Sort the token-tok-entity dict by the length of the modality.
|
||||
sorted_dict = sorted(tok_to_entity_dict.items(),
|
||||
key=lambda k: len(k[1][key]))[::-1]
|
||||
|
||||
anonymized_sequence = copy.deepcopy(sequence)
|
||||
|
||||
if add_new_anon_toks:
|
||||
type_counts = {}
|
||||
for entity_type in self.entity_types:
|
||||
type_counts[entity_type] = 0
|
||||
for token in tok_to_entity_dict:
|
||||
entity_type = self.get_entity_type_from_token(token)
|
||||
type_counts[entity_type] += 1
|
||||
|
||||
# First find occurrences of things in the anonymization dictionary.
|
||||
for token, modalities in sorted_dict:
|
||||
our_modality = modalities[key]
|
||||
|
||||
# Check if this key's version of the anonymized thing is in our
|
||||
# sequence.
|
||||
while util.subsequence(our_modality, anonymized_sequence):
|
||||
found = False
|
||||
for startidx in range(
|
||||
len(anonymized_sequence) - len(our_modality) + 1):
|
||||
if anonymized_sequence[startidx:startidx +
|
||||
len(our_modality)] == our_modality:
|
||||
anonymized_sequence = anonymized_sequence[:startidx] + [
|
||||
token] + anonymized_sequence[startidx + len(our_modality):]
|
||||
found = True
|
||||
break
|
||||
assert found, "Thought " \
|
||||
+ str(our_modality) + " was in [" \
|
||||
+ str(anonymized_sequence) + "] but could not find it"
|
||||
|
||||
# Now add new keys if they are present.
|
||||
if add_new_anon_toks:
|
||||
|
||||
# For every span in the sequence, check whether it is in the anon map
|
||||
# for this modality
|
||||
sorted_anon_map = sorted(self.anonymization_map,
|
||||
key=lambda k: len(k[key]))[::-1]
|
||||
|
||||
for pair in sorted_anon_map:
|
||||
our_modality = pair[key]
|
||||
|
||||
token_type = pair["type"]
|
||||
new_token = token_type + SEPARATOR + \
|
||||
str(type_counts[token_type])
|
||||
|
||||
while util.subsequence(our_modality, anonymized_sequence):
|
||||
found = False
|
||||
for startidx in range(
|
||||
len(anonymized_sequence) - len(our_modality) + 1):
|
||||
if anonymized_sequence[startidx:startidx + \
|
||||
len(our_modality)] == our_modality:
|
||||
if new_token not in tok_to_entity_dict:
|
||||
type_counts[token_type] += 1
|
||||
tok_to_entity_dict[new_token] = pair
|
||||
|
||||
anonymized_sequence = anonymized_sequence[:startidx] + [
|
||||
new_token] + anonymized_sequence[startidx + len(our_modality):]
|
||||
found = True
|
||||
break
|
||||
assert found, "Thought " \
|
||||
+ str(our_modality) + " was in [" \
|
||||
+ str(anonymized_sequence) + "] but could not find it"
|
||||
|
||||
# Also replace integers with constants
|
||||
for index, token in enumerate(anonymized_sequence):
|
||||
if token.isdigit() or is_time(token):
|
||||
if token.isdigit():
|
||||
entity_type = CONSTANT_NAME
|
||||
value = new_token
|
||||
if is_time(token):
|
||||
entity_type = TIME_NAME
|
||||
value = timeval(token)
|
||||
|
||||
# First try to find the constant in the entity dictionary already,
|
||||
# and get the name if it's found.
|
||||
new_token = ""
|
||||
new_dict = {}
|
||||
found = False
|
||||
for entity, value in tok_to_entity_dict.items():
|
||||
if value[key][0] == token:
|
||||
new_token = entity
|
||||
new_dict = value
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
new_token = entity_type + SEPARATOR + \
|
||||
str(type_counts[entity_type])
|
||||
new_dict = {}
|
||||
for tempkey in self.keys:
|
||||
new_dict[tempkey] = [token]
|
||||
|
||||
tok_to_entity_dict[new_token] = new_dict
|
||||
type_counts[entity_type] += 1
|
||||
|
||||
anonymized_sequence[index] = new_token
|
||||
|
||||
return anonymized_sequence
|
|
@@ -0,0 +1,350 @@
|
|||
# TODO: review this entire file and make it much simpler.
|
||||
|
||||
import copy
|
||||
from . import snippets as snip
|
||||
from . import sql_util
|
||||
from . import vocabulary as vocab
|
||||
|
||||
|
||||
class UtteranceItem():
|
||||
def __init__(self, interaction, index):
|
||||
self.interaction = interaction
|
||||
self.utterance_index = index
|
||||
|
||||
def __str__(self):
|
||||
return str(self.interaction.utterances[self.utterance_index])
|
||||
|
||||
def histories(self, maximum):
|
||||
if maximum > 0:
|
||||
history_seqs = []
|
||||
for utterance in self.interaction.utterances[:self.utterance_index]:
|
||||
history_seqs.append(utterance.input_seq_to_use)
|
||||
|
||||
if len(history_seqs) > maximum:
|
||||
history_seqs = history_seqs[-maximum:]
|
||||
|
||||
return history_seqs
|
||||
return []
|
||||
|
||||
def input_sequence(self):
|
||||
return self.interaction.utterances[self.utterance_index].input_seq_to_use
|
||||
|
||||
def previous_query(self):
|
||||
if self.utterance_index == 0:
|
||||
return []
|
||||
return self.interaction.utterances[self.utterance_index -
|
||||
1].anonymized_gold_query
|
||||
|
||||
def anonymized_gold_query(self):
|
||||
return self.interaction.utterances[self.utterance_index].anonymized_gold_query
|
||||
|
||||
def snippets(self):
|
||||
return self.interaction.utterances[self.utterance_index].available_snippets
|
||||
|
||||
def original_gold_query(self):
|
||||
return self.interaction.utterances[self.utterance_index].original_gold_query
|
||||
|
||||
def contained_entities(self):
|
||||
return self.interaction.utterances[self.utterance_index].contained_entities
|
||||
|
||||
def original_gold_queries(self):
|
||||
return [
|
||||
q[0] for q in self.interaction.utterances[self.utterance_index].all_gold_queries]
|
||||
|
||||
def gold_tables(self):
|
||||
return [
|
||||
q[1] for q in self.interaction.utterances[self.utterance_index].all_gold_queries]
|
||||
|
||||
def gold_query(self):
|
||||
return self.interaction.utterances[self.utterance_index].gold_query_to_use + [
|
||||
vocab.EOS_TOK]
|
||||
|
||||
def gold_edit_sequence(self):
|
||||
return self.interaction.utterances[self.utterance_index].gold_edit_sequence
|
||||
|
||||
def gold_table(self):
|
||||
return self.interaction.utterances[self.utterance_index].gold_sql_results
|
||||
|
||||
def all_snippets(self):
|
||||
return self.interaction.snippets
|
||||
|
||||
def within_limits(self,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf')):
|
||||
return self.interaction.utterances[self.utterance_index].length_valid(
|
||||
max_input_length, max_output_length)
|
||||
|
||||
def expand_snippets(self, sequence):
|
||||
# Remove the EOS
|
||||
if sequence[-1] == vocab.EOS_TOK:
|
||||
sequence = sequence[:-1]
|
||||
|
||||
# First remove the snippets
|
||||
no_snippets_sequence = self.interaction.expand_snippets(sequence)
|
||||
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
|
||||
return no_snippets_sequence
|
||||
|
||||
def flatten_sequence(self, sequence):
|
||||
# Remove the EOS
|
||||
if sequence[-1] == vocab.EOS_TOK:
|
||||
sequence = sequence[:-1]
|
||||
|
||||
# First remove the snippets
|
||||
no_snippets_sequence = self.interaction.expand_snippets(sequence)
|
||||
|
||||
# Deanonymize
|
||||
deanon_sequence = self.interaction.deanonymize(
|
||||
no_snippets_sequence, "sql")
|
||||
return deanon_sequence
|
||||
|
||||
|
||||
class UtteranceBatch():
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def start(self):
|
||||
self.index = 0
|
||||
|
||||
def next(self):
|
||||
item = self.items[self.index]
|
||||
self.index += 1
|
||||
return item
|
||||
|
||||
def done(self):
|
||||
return self.index >= len(self.items)
|
||||
|
||||
class PredUtteranceItem():
|
||||
def __init__(self,
|
||||
input_sequence,
|
||||
interaction_item,
|
||||
previous_query,
|
||||
index,
|
||||
available_snippets):
|
||||
self.input_seq_to_use = input_sequence
|
||||
self.interaction_item = interaction_item
|
||||
self.index = index
|
||||
self.available_snippets = available_snippets
|
||||
self.prev_pred_query = previous_query
|
||||
|
||||
def input_sequence(self):
|
||||
return self.input_seq_to_use
|
||||
|
||||
def histories(self, maximum):
|
||||
        if maximum == 0:
            # No history requested; `histories` is not defined yet at this point.
            return []
|
||||
histories = []
|
||||
for utterance in self.interaction_item.processed_utterances[:self.index]:
|
||||
histories.append(utterance.input_sequence())
|
||||
if len(histories) > maximum:
|
||||
histories = histories[-maximum:]
|
||||
return histories
|
||||
|
||||
def snippets(self):
|
||||
return self.available_snippets
|
||||
|
||||
def previous_query(self):
|
||||
return self.prev_pred_query
|
||||
|
||||
def flatten_sequence(self, sequence):
|
||||
return self.interaction_item.flatten_sequence(sequence)
|
||||
|
||||
def remove_snippets(self, sequence):
|
||||
return sql_util.fix_parentheses(
|
||||
self.interaction_item.expand_snippets(sequence))
|
||||
|
||||
def set_predicted_query(self, query):
|
||||
self.anonymized_pred_query = query
|
||||
|
||||
# Mocks an Interaction item, but allows for the parameters to be updated during
|
||||
# the process
|
||||
|
||||
|
||||
class InteractionItem():
|
||||
def __init__(self,
|
||||
interaction,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf'),
|
||||
nl_to_sql_dict={},
|
||||
maximum_length=float('inf')):
|
||||
if maximum_length != float('inf'):
|
||||
self.interaction = copy.deepcopy(interaction)
|
||||
self.interaction.utterances = self.interaction.utterances[:maximum_length]
|
||||
else:
|
||||
self.interaction = interaction
|
||||
self.processed_utterances = []
|
||||
self.snippet_bank = []
|
||||
self.identifier = self.interaction.identifier
|
||||
|
||||
self.max_input_length = max_input_length
|
||||
self.max_output_length = max_output_length
|
||||
|
||||
self.nl_to_sql_dict = nl_to_sql_dict
|
||||
|
||||
self.index = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self.interaction)
|
||||
|
||||
def __str__(self):
|
||||
s = "Utterances, gold queries, and predictions:\n"
|
||||
for i, utterance in enumerate(self.interaction.utterances):
|
||||
s += " ".join(utterance.input_seq_to_use) + "\n"
|
||||
pred_utterance = self.processed_utterances[i]
|
||||
s += " ".join(pred_utterance.gold_query()) + "\n"
|
||||
s += " ".join(pred_utterance.anonymized_query()) + "\n"
|
||||
s += "\n"
|
||||
s += "Snippets:\n"
|
||||
for snippet in self.snippet_bank:
|
||||
s += str(snippet) + "\n"
|
||||
|
||||
return s
|
||||
|
||||
def start_interaction(self):
|
||||
assert len(self.snippet_bank) == 0
|
||||
assert len(self.processed_utterances) == 0
|
||||
assert self.index == 0
|
||||
|
||||
def next_utterance(self):
|
||||
utterance = self.interaction.utterances[self.index]
|
||||
self.index += 1
|
||||
|
||||
available_snippets = self.available_snippets(snippet_keep_age=1)
|
||||
|
||||
return PredUtteranceItem(utterance.input_seq_to_use,
|
||||
self,
|
||||
self.processed_utterances[-1].anonymized_pred_query if len(self.processed_utterances) > 0 else [],
|
||||
self.index - 1,
|
||||
available_snippets)
|
||||
|
||||
def done(self):
|
||||
return len(self.processed_utterances) == len(self.interaction)
|
||||
|
||||
def finish(self):
|
||||
self.snippet_bank = []
|
||||
self.processed_utterances = []
|
||||
self.index = 0
|
||||
|
||||
def utterance_within_limits(self, utterance_item):
|
||||
return utterance_item.within_limits(self.max_input_length,
|
||||
self.max_output_length)
|
||||
|
||||
def available_snippets(self, snippet_keep_age):
|
||||
return [
|
||||
snippet for snippet in self.snippet_bank if snippet.index <= snippet_keep_age]
|
||||
|
||||
def gold_utterances(self):
|
||||
utterances = []
|
||||
for i, utterance in enumerate(self.interaction.utterances):
|
||||
utterances.append(UtteranceItem(self.interaction, i))
|
||||
return utterances
|
||||
|
||||
def get_schema(self):
|
||||
return self.interaction.schema
|
||||
|
||||
def add_utterance(
|
||||
self,
|
||||
utterance,
|
||||
predicted_sequence,
|
||||
snippets=None,
|
||||
previous_snippets=[],
|
||||
simple=False):
|
||||
if not snippets:
|
||||
self.add_snippets(
|
||||
predicted_sequence,
|
||||
previous_snippets=previous_snippets, simple=simple)
|
||||
else:
|
||||
for snippet in snippets:
|
||||
snippet.assign_id(len(self.snippet_bank))
|
||||
self.snippet_bank.append(snippet)
|
||||
|
||||
for snippet in self.snippet_bank:
|
||||
snippet.increase_age()
|
||||
self.processed_utterances.append(utterance)
|
||||
|
||||
def add_snippets(self, sequence, previous_snippets=[], simple=False):
|
||||
if sequence:
|
||||
if simple:
|
||||
snippets = sql_util.get_subtrees_simple(
|
||||
sequence, oldsnippets=previous_snippets)
|
||||
else:
|
||||
snippets = sql_util.get_subtrees(
|
||||
sequence, oldsnippets=previous_snippets)
|
||||
for snippet in snippets:
|
||||
snippet.assign_id(len(self.snippet_bank))
|
||||
self.snippet_bank.append(snippet)
|
||||
|
||||
for snippet in self.snippet_bank:
|
||||
snippet.increase_age()
|
||||
|
||||
def expand_snippets(self, sequence):
|
||||
return sql_util.fix_parentheses(
|
||||
snip.expand_snippets(
|
||||
sequence, self.snippet_bank))
|
||||
|
||||
def remove_snippets(self, sequence):
|
||||
if sequence[-1] == vocab.EOS_TOK:
|
||||
sequence = sequence[:-1]
|
||||
|
||||
no_snippets_sequence = self.expand_snippets(sequence)
|
||||
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
|
||||
return no_snippets_sequence
|
||||
|
||||
def flatten_sequence(self, sequence, gold_snippets=False):
|
||||
if sequence[-1] == vocab.EOS_TOK:
|
||||
sequence = sequence[:-1]
|
||||
|
||||
if gold_snippets:
|
||||
no_snippets_sequence = self.interaction.expand_snippets(sequence)
|
||||
else:
|
||||
no_snippets_sequence = self.expand_snippets(sequence)
|
||||
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
|
||||
|
||||
deanon_sequence = self.interaction.deanonymize(
|
||||
no_snippets_sequence, "sql")
|
||||
return deanon_sequence
|
||||
|
||||
def gold_query(self, index):
|
||||
return self.interaction.utterances[index].gold_query_to_use + [
|
||||
vocab.EOS_TOK]
|
||||
|
||||
def original_gold_query(self, index):
|
||||
return self.interaction.utterances[index].original_gold_query
|
||||
|
||||
def gold_table(self, index):
|
||||
return self.interaction.utterances[index].gold_sql_results
|
||||
|
||||
|
||||
class InteractionBatch():
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def start(self):
|
||||
self.timestep = 0
|
||||
self.current_interactions = []
|
||||
|
||||
def get_next_utterance_batch(self, snippet_keep_age, use_gold=False):
|
||||
items = []
|
||||
self.current_interactions = []
|
||||
for interaction in self.items:
|
||||
if self.timestep < len(interaction):
|
||||
utterance_item = interaction.original_utterances(
|
||||
snippet_keep_age, use_gold)[self.timestep]
|
||||
self.current_interactions.append(interaction)
|
||||
items.append(utterance_item)
|
||||
|
||||
self.timestep += 1
|
||||
return UtteranceBatch(items)
|
||||
|
||||
def done(self):
|
||||
finished = True
|
||||
for interaction in self.items:
|
||||
if self.timestep < len(interaction):
|
||||
finished = False
|
||||
return finished
|
||||
return finished
|
|
@@ -0,0 +1,376 @@
|
|||
""" Utility functions for loading and processing ATIS data."""
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
|
||||
from . import anonymization as anon
|
||||
from . import atis_batch
|
||||
from . import dataset_split as ds
|
||||
|
||||
from .interaction import load_function
|
||||
from .entities import NLtoSQLDict
|
||||
from .atis_vocab import ATISVocabulary
|
||||
|
||||
ENTITIES_FILENAME = 'data/entities.txt'
|
||||
ANONYMIZATION_FILENAME = 'data/anonymization.txt'
|
||||
|
||||
class ATISDataset():
|
||||
""" Contains the ATIS data. """
|
||||
def __init__(self, params):
|
||||
self.anonymizer = None
|
||||
if params.anonymize:
|
||||
self.anonymizer = anon.Anonymizer(ANONYMIZATION_FILENAME)
|
||||
|
||||
if not os.path.exists(params.data_directory):
|
||||
os.mkdir(params.data_directory)
|
||||
|
||||
self.entities_dictionary = NLtoSQLDict(ENTITIES_FILENAME)
|
||||
|
||||
database_schema = None
|
||||
if params.database_schema_filename:
|
||||
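            # 'removefrom' data uses fully qualified surface forms (table.column and table.*);
            # otherwise plain column and table names are used.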
if 'removefrom' not in params.data_directory:
|
||||
database_schema, column_names_surface_form, column_names_embedder_input = self.read_database_schema_simple(params.database_schema_filename)
|
||||
else:
|
||||
database_schema, column_names_surface_form, column_names_embedder_input = self.read_database_schema(params.database_schema_filename)
|
||||
|
||||
int_load_function = load_function(params,
|
||||
self.entities_dictionary,
|
||||
self.anonymizer,
|
||||
database_schema=database_schema)
|
||||
|
||||
def collapse_list(the_list):
|
||||
""" Collapses a list of list into a single list."""
|
||||
return [s for i in the_list for s in i]
|
||||
|
||||
if 'atis' not in params.data_directory:
|
||||
self.train_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_train_filename),
|
||||
params.raw_train_filename,
|
||||
int_load_function)
|
||||
self.valid_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_validation_filename),
|
||||
params.raw_validation_filename,
|
||||
int_load_function)
|
||||
|
||||
train_input_seqs = collapse_list(self.train_data.get_ex_properties(lambda i: i.input_seqs()))
|
||||
valid_input_seqs = collapse_list(self.valid_data.get_ex_properties(lambda i: i.input_seqs()))
|
||||
|
||||
all_input_seqs = train_input_seqs + valid_input_seqs
|
||||
|
||||
self.input_vocabulary = ATISVocabulary(
|
||||
all_input_seqs,
|
||||
os.path.join(params.data_directory, params.input_vocabulary_filename),
|
||||
params,
|
||||
is_input='input',
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None)
|
||||
|
||||
self.output_vocabulary_schema = ATISVocabulary(
|
||||
column_names_embedder_input,
|
||||
os.path.join(params.data_directory, 'schema_'+params.output_vocabulary_filename),
|
||||
params,
|
||||
is_input='schema',
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None)
|
||||
|
||||
train_output_seqs = collapse_list(self.train_data.get_ex_properties(lambda i: i.output_seqs()))
|
||||
valid_output_seqs = collapse_list(self.valid_data.get_ex_properties(lambda i: i.output_seqs()))
|
||||
all_output_seqs = train_output_seqs + valid_output_seqs
|
||||
|
||||
sql_keywords = ['.', 't1', 't2', '=', 'select', 'as', 'join', 'on', ')', '(', 'where', 't3', 'by', ',', 'group', 'distinct', 't4', 'and', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 't5', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!', 'union', 'between', 't6', '-', 't7', '+', '/']
|
||||
sql_keywords += ['count', 'from', 'value', 'order']
|
||||
sql_keywords += ['group_by', 'order_by', 'limit_value', '!=']
|
||||
|
||||
# skip column_names_surface_form but keep sql_keywords
|
||||
skip_tokens = list(set(column_names_surface_form) - set(sql_keywords))
|
||||
|
||||
if params.data_directory == 'processed_data_sparc_removefrom_test':
|
||||
all_output_seqs = []
|
||||
out_vocab_ordered = ['select', 'value', ')', '(', 'where', '=', ',', 'count', 'group_by', 'order_by', 'limit_value', 'desc', '>', 'distinct', 'avg', 'and', 'having', '<', 'in', 'max', 'sum', 'asc', 'like', 'not', 'or', 'min', 'intersect', 'except', '!=', 'union', 'between', '-', '+']
|
||||
for i in range(len(out_vocab_ordered)):
|
||||
all_output_seqs.append(out_vocab_ordered[:i+1])
|
||||
|
||||
self.output_vocabulary = ATISVocabulary(
|
||||
all_output_seqs,
|
||||
os.path.join(params.data_directory, params.output_vocabulary_filename),
|
||||
params,
|
||||
is_input='output',
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None,
|
||||
skip=skip_tokens)
|
||||
else:
|
||||
self.train_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_train_filename),
|
||||
params.raw_train_filename,
|
||||
int_load_function)
|
||||
if params.train:
|
||||
self.valid_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_validation_filename),
|
||||
params.raw_validation_filename,
|
||||
int_load_function)
|
||||
if params.evaluate or params.attention:
|
||||
self.dev_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_dev_filename),
|
||||
params.raw_dev_filename,
|
||||
int_load_function)
|
||||
if params.enable_testing:
|
||||
self.test_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_test_filename),
|
||||
params.raw_test_filename,
|
||||
int_load_function)
|
||||
|
||||
train_input_seqs = []
|
||||
train_input_seqs = collapse_list(
|
||||
self.train_data.get_ex_properties(
|
||||
lambda i: i.input_seqs()))
|
||||
|
||||
self.input_vocabulary = ATISVocabulary(
|
||||
train_input_seqs,
|
||||
os.path.join(params.data_directory, params.input_vocabulary_filename),
|
||||
params,
|
||||
is_input='input',
|
||||
min_occur=2,
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None)
|
||||
|
||||
train_output_seqs = collapse_list(
|
||||
self.train_data.get_ex_properties(
|
||||
lambda i: i.output_seqs()))
|
||||
|
||||
self.output_vocabulary = ATISVocabulary(
|
||||
train_output_seqs,
|
||||
os.path.join(params.data_directory, params.output_vocabulary_filename),
|
||||
params,
|
||||
is_input='output',
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None)
|
||||
|
||||
self.output_vocabulary_schema = None
|
||||
|
||||
def read_database_schema_simple(self, database_schema_filename):
|
||||
with open(database_schema_filename, "r") as f:
|
||||
database_schema = json.load(f)
|
||||
|
||||
database_schema_dict = {}
|
||||
column_names_surface_form = []
|
||||
column_names_embedder_input = []
|
||||
for table_schema in database_schema:
|
||||
db_id = table_schema['db_id']
|
||||
database_schema_dict[db_id] = table_schema
|
||||
|
||||
column_names = table_schema['column_names']
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
column_name_surface_form = column_name
|
||||
column_names_surface_form.append(column_name_surface_form.lower())
|
||||
|
||||
for table_name in table_names_original:
|
||||
column_names_surface_form.append(table_name.lower())
|
||||
|
||||
for i, (table_id, column_name) in enumerate(column_names):
|
||||
column_name_embedder_input = column_name
|
||||
column_names_embedder_input.append(column_name_embedder_input.split())
|
||||
|
||||
for table_name in table_names:
|
||||
column_names_embedder_input.append(table_name.split())
|
||||
|
||||
database_schema = database_schema_dict
|
||||
|
||||
return database_schema, column_names_surface_form, column_names_embedder_input
|
||||
|
||||
def read_database_schema(self, database_schema_filename):
|
||||
with open(database_schema_filename, "r") as f:
|
||||
database_schema = json.load(f)
|
||||
|
||||
database_schema_dict = {}
|
||||
column_names_surface_form = []
|
||||
column_names_embedder_input = []
|
||||
for table_schema in database_schema:
|
||||
db_id = table_schema['db_id']
|
||||
database_schema_dict[db_id] = table_schema
|
||||
|
||||
column_names = table_schema['column_names']
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
if table_id >= 0:
|
||||
table_name = table_names_original[table_id]
|
||||
column_name_surface_form = '{}.{}'.format(table_name,column_name)
|
||||
else:
|
||||
column_name_surface_form = column_name
|
||||
column_names_surface_form.append(column_name_surface_form.lower())
|
||||
|
||||
# also add table_name.*
|
||||
for table_name in table_names_original:
|
||||
column_names_surface_form.append('{}.*'.format(table_name.lower()))
|
||||
|
||||
for i, (table_id, column_name) in enumerate(column_names):
|
||||
if table_id >= 0:
|
||||
table_name = table_names[table_id]
|
||||
column_name_embedder_input = table_name + ' . ' + column_name
|
||||
else:
|
||||
column_name_embedder_input = column_name
|
||||
column_names_embedder_input.append(column_name_embedder_input.split())
|
||||
|
||||
for table_name in table_names:
|
||||
column_name_embedder_input = table_name + ' . *'
|
||||
column_names_embedder_input.append(column_name_embedder_input.split())
|
||||
|
||||
database_schema = database_schema_dict
|
||||
|
||||
return database_schema, column_names_surface_form, column_names_embedder_input
|
||||
|
||||
def get_all_utterances(self,
|
||||
dataset,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf')):
|
||||
""" Returns all utterances in a dataset."""
|
||||
items = []
|
||||
for interaction in dataset.examples:
|
||||
for i, utterance in enumerate(interaction.utterances):
|
||||
if utterance.length_valid(max_input_length, max_output_length):
|
||||
items.append(atis_batch.UtteranceItem(interaction, i))
|
||||
return items
|
||||
|
||||
def get_all_interactions(self,
|
||||
dataset,
|
||||
max_interaction_length=float('inf'),
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf'),
|
||||
sorted_by_length=False):
|
||||
"""Gets all interactions in a dataset that fit the criteria.
|
||||
|
||||
Inputs:
|
||||
dataset (ATISDatasetSplit): The dataset to use.
|
||||
max_interaction_length (int): Maximum interaction length to keep.
|
||||
max_input_length (int): Maximum input sequence length to keep.
|
||||
max_output_length (int): Maximum output sequence length to keep.
|
||||
sorted_by_length (bool): Whether to sort the examples by interaction length.
|
||||
"""
|
||||
ints = [
|
||||
atis_batch.InteractionItem(
|
||||
interaction,
|
||||
max_input_length,
|
||||
max_output_length,
|
||||
self.entities_dictionary,
|
||||
max_interaction_length) for interaction in dataset.examples]
|
||||
if sorted_by_length:
|
||||
return sorted(ints, key=lambda x: len(x))[::-1]
|
||||
else:
|
||||
return ints
|
||||
|
||||
# This defines a standard way of training: each example is an utterance, and
|
||||
# the batch can contain unrelated utterances.
|
||||
def get_utterance_batches(self,
|
||||
batch_size,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf'),
|
||||
randomize=True):
|
||||
"""Gets batches of utterances in the data.
|
||||
|
||||
Inputs:
|
||||
batch_size (int): Batch size to use.
|
||||
max_input_length (int): Maximum length of input to keep.
|
||||
max_output_length (int): Maximum length of output to use.
|
||||
randomize (bool): Whether to randomize the ordering.
|
||||
"""
|
||||
# First, get all interactions and the positions of the utterances that are
|
||||
# possible in them.
|
||||
items = self.get_all_utterances(self.train_data,
|
||||
max_input_length,
|
||||
max_output_length)
|
||||
if randomize:
|
||||
random.shuffle(items)
|
||||
|
||||
batches = []
|
||||
|
||||
current_batch_items = []
|
||||
for item in items:
|
||||
if len(current_batch_items) >= batch_size:
|
||||
batches.append(atis_batch.UtteranceBatch(current_batch_items))
|
||||
current_batch_items = []
|
||||
current_batch_items.append(item)
|
||||
batches.append(atis_batch.UtteranceBatch(current_batch_items))
|
||||
|
||||
assert sum([len(batch) for batch in batches]) == len(items)
|
||||
|
||||
return batches
|
||||
|
||||
def get_interaction_batches(self,
|
||||
batch_size,
|
||||
max_interaction_length=float('inf'),
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf'),
|
||||
randomize=True):
|
||||
"""Gets batches of interactions in the data.
|
||||
|
||||
Inputs:
|
||||
batch_size (int): Batch size to use.
|
||||
max_interaction_length (int): Maximum length of interaction to keep
|
||||
max_input_length (int): Maximum length of input to keep.
|
||||
max_output_length (int): Maximum length of output to keep.
|
||||
randomize (bool): Whether to randomize the ordering.
|
||||
"""
|
||||
items = self.get_all_interactions(self.train_data,
|
||||
max_interaction_length,
|
||||
max_input_length,
|
||||
max_output_length,
|
||||
sorted_by_length=not randomize)
|
||||
if randomize:
|
||||
random.shuffle(items)
|
||||
|
||||
batches = []
|
||||
current_batch_items = []
|
||||
for item in items:
|
||||
if len(current_batch_items) >= batch_size:
|
||||
batches.append(
|
||||
atis_batch.InteractionBatch(current_batch_items))
|
||||
current_batch_items = []
|
||||
current_batch_items.append(item)
|
||||
batches.append(atis_batch.InteractionBatch(current_batch_items))
|
||||
|
||||
assert sum([len(batch) for batch in batches]) == len(items)
|
||||
|
||||
return batches
|
||||
|
||||
def get_random_utterances(self,
|
||||
num_samples,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf')):
|
||||
"""Gets a random selection of utterances in the data.
|
||||
|
||||
Inputs:
|
||||
            num_samples (int): Number of random utterances to get.
|
||||
max_input_length (int): Limit of input length.
|
||||
max_output_length (int): Limit on output length.
|
||||
"""
|
||||
items = self.get_all_utterances(self.train_data,
|
||||
max_input_length,
|
||||
max_output_length)
|
||||
random.shuffle(items)
|
||||
return items[:num_samples]
|
||||
|
||||
def get_random_interactions(self,
|
||||
num_samples,
|
||||
max_interaction_length=float('inf'),
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf')):
|
||||
|
||||
"""Gets a random selection of interactions in the data.
|
||||
|
||||
Inputs:
|
||||
            num_samples (int): Number of random interactions to get.
|
||||
max_input_length (int): Limit of input length.
|
||||
max_output_length (int): Limit on output length.
|
||||
"""
|
||||
items = self.get_all_interactions(self.train_data,
|
||||
max_interaction_length,
|
||||
max_input_length,
|
||||
max_output_length)
|
||||
random.shuffle(items)
|
||||
return items[:num_samples]
|
||||
|
||||
|
||||
def num_utterances(dataset):
|
||||
"""Returns the total number of utterances in the dataset."""
|
||||
return sum([len(interaction) for interaction in dataset.examples])
|
|
@@ -0,0 +1,74 @@
|
|||
"""Gets and stores vocabulary for the ATIS data."""
|
||||
|
||||
from . import snippets
|
||||
from .vocabulary import Vocabulary, UNK_TOK, DEL_TOK, EOS_TOK
|
||||
|
||||
INPUT_FN_TYPES = [UNK_TOK, DEL_TOK, EOS_TOK]
|
||||
OUTPUT_FN_TYPES = [UNK_TOK, EOS_TOK]
|
||||
|
||||
MIN_INPUT_OCCUR = 1
|
||||
MIN_OUTPUT_OCCUR = 1
|
||||
|
||||
class ATISVocabulary():
|
||||
""" Stores the vocabulary for the ATIS data.
|
||||
|
||||
Attributes:
|
||||
raw_vocab (Vocabulary): Vocabulary object.
|
||||
tokens (set of str): Set of all of the strings in the vocabulary.
|
||||
inorder_tokens (list of str): List of all tokens, with a strict and
|
||||
unchanging order.
|
||||
"""
|
||||
def __init__(self,
|
||||
token_sequences,
|
||||
filename,
|
||||
params,
|
||||
is_input='input',
|
||||
min_occur=1,
|
||||
anonymizer=None,
|
||||
skip=None):
|
||||
|
||||
if is_input=='input':
|
||||
functional_types = INPUT_FN_TYPES
|
||||
elif is_input=='output':
|
||||
functional_types = OUTPUT_FN_TYPES
|
||||
elif is_input=='schema':
|
||||
functional_types = [UNK_TOK]
|
||||
else:
|
||||
functional_types = []
|
||||
|
||||
self.raw_vocab = Vocabulary(
|
||||
token_sequences,
|
||||
filename,
|
||||
functional_types=functional_types,
|
||||
min_occur=min_occur,
|
||||
ignore_fn=lambda x: snippets.is_snippet(x) or (
|
||||
anonymizer and anonymizer.is_anon_tok(x)) or (skip and x in skip) )
|
||||
self.tokens = set(self.raw_vocab.token_to_id.keys())
|
||||
self.inorder_tokens = self.raw_vocab.id_to_token
|
||||
|
||||
assert len(self.inorder_tokens) == len(self.raw_vocab)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.raw_vocab)
|
||||
|
||||
def token_to_id(self, token):
|
||||
""" Maps from a token to a unique ID.
|
||||
|
||||
Inputs:
|
||||
token (str): The token to look up.
|
||||
|
||||
Returns:
|
||||
int, uniquely identifying the token.
|
||||
"""
|
||||
return self.raw_vocab.token_to_id[token]
|
||||
|
||||
def id_to_token(self, identifier):
|
||||
""" Maps from a unique integer to an identifier.
|
||||
|
||||
Inputs:
|
||||
identifier (int): The unique ID.
|
||||
|
||||
Returns:
|
||||
string, representing the token.
|
||||
"""
|
||||
return self.raw_vocab.id_to_token[identifier]
|
|
@@ -0,0 +1,56 @@
|
|||
""" Utility functions for loading and processing ATIS data.
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
|
||||
class DatasetSplit:
|
||||
"""Stores a split of the ATIS dataset.
|
||||
|
||||
Attributes:
|
||||
examples (list of Interaction): Stores the examples in the split.
|
||||
"""
|
||||
def __init__(self, processed_filename, raw_filename, load_function):
|
||||
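        # If a processed (pickled) version already exists, load it directly;
        # otherwise process the raw examples with load_function and cache the result.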
if os.path.exists(processed_filename):
|
||||
print("Loading preprocessed data from " + processed_filename)
|
||||
with open(processed_filename, 'rb') as infile:
|
||||
self.examples = pickle.load(infile)
|
||||
else:
|
||||
print(
|
||||
"Loading raw data from " +
|
||||
raw_filename +
|
||||
" and writing to " +
|
||||
processed_filename)
|
||||
|
||||
infile = open(raw_filename, 'rb')
|
||||
examples_from_file = pickle.load(infile)
|
||||
assert isinstance(examples_from_file, list), raw_filename + \
|
||||
" does not contain a list of examples"
|
||||
infile.close()
|
||||
|
||||
self.examples = []
|
||||
for example in examples_from_file:
|
||||
obj, keep = load_function(example)
|
||||
|
||||
if keep:
|
||||
self.examples.append(obj)
|
||||
|
||||
|
||||
print("Loaded " + str(len(self.examples)) + " examples")
|
||||
outfile = open(processed_filename, 'wb')
|
||||
pickle.dump(self.examples, outfile)
|
||||
outfile.close()
|
||||
|
||||
def get_ex_properties(self, function):
|
||||
""" Applies some function to the examples in the dataset.
|
||||
|
||||
Inputs:
|
||||
function: (lambda Interaction -> T): Function to apply to all
|
||||
examples.
|
||||
|
||||
Returns
|
||||
list of the return value of the function
|
||||
"""
|
||||
elems = []
|
||||
for example in self.examples:
|
||||
elems.append(function(example))
|
||||
return elems
|
|
@@ -0,0 +1,63 @@
|
|||
""" Classes for keeping track of the entities in a natural language string. """
|
||||
import json
|
||||
|
||||
|
||||
class NLtoSQLDict:
|
||||
"""
|
||||
Entity dict file should contain, on each line, a JSON dictionary with
|
||||
"input" and "output" keys specifying the string for the input and output
|
||||
pairs. The idea is that the existence of the key in an input sequence
|
||||
likely corresponds to the existence of the value in the output sequence.
|
||||
|
||||
The entity_dict should map keys (input strings) to a list of values (output
|
||||
strings) where this property holds. This allows keys to map to multiple
|
||||
output strings (e.g. for times).
|
||||
"""
|
||||
def __init__(self, entity_dict_filename):
|
||||
self.entity_dict = {}
|
||||
|
||||
pairs = [json.loads(line)
|
||||
for line in open(entity_dict_filename).readlines()]
|
||||
for pair in pairs:
|
||||
input_seq = pair["input"]
|
||||
output_seq = pair["output"]
|
||||
if input_seq not in self.entity_dict:
|
||||
self.entity_dict[input_seq] = []
|
||||
self.entity_dict[input_seq].append(output_seq)
|
||||
|
||||
def get_sql_entities(self, tokenized_nl_string):
|
||||
"""
|
||||
Gets the output-side entities which correspond to the input entities in
|
||||
the input sequence.
|
||||
        Inputs:
            tokenized_nl_string (list of str): List of tokens in the input string.
        Outputs:
            list of output strings.
|
||||
"""
|
||||
assert len(tokenized_nl_string) > 0
|
||||
flat_input_string = " ".join(tokenized_nl_string)
|
||||
entities = []
|
||||
|
||||
# See if any input strings are in our input sequence, and add the
|
||||
# corresponding output strings if so.
|
||||
for entry, values in self.entity_dict.items():
|
||||
in_middle = " " + entry + " " in flat_input_string
|
||||
|
||||
leftspace = " " + entry
|
||||
at_end = leftspace in flat_input_string and flat_input_string.endswith(
|
||||
leftspace)
|
||||
|
||||
rightspace = entry + " "
|
||||
at_beginning = rightspace in flat_input_string and flat_input_string.startswith(
|
||||
rightspace)
|
||||
if in_middle or at_end or at_beginning:
|
||||
for out_string in values:
|
||||
entities.append(out_string)
|
||||
|
||||
# Also add any integers in the input string (these aren't in the entity)
|
||||
# dict.
|
||||
for token in tokenized_nl_string:
|
||||
if token.isnumeric():
|
||||
entities.append(token)
|
||||
|
||||
return entities
|
|
@@ -0,0 +1,312 @@
|
|||
""" Contains the class for an interaction in ATIS. """
|
||||
|
||||
from . import anonymization as anon
|
||||
from . import sql_util
|
||||
from .snippets import expand_snippets
|
||||
from .utterance import Utterance, OUTPUT_KEY, ANON_INPUT_KEY
|
||||
|
||||
import torch
|
||||
|
||||
class Schema:
|
||||
def __init__(self, table_schema, simple=False):
|
||||
if simple:
|
||||
self.helper1(table_schema)
|
||||
else:
|
||||
self.helper2(table_schema)
|
||||
|
||||
def helper1(self, table_schema):
|
||||
self.table_schema = table_schema
|
||||
column_names = table_schema['column_names']
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original)
|
||||
|
||||
column_keep_index = []
|
||||
|
||||
self.column_names_surface_form = []
|
||||
self.column_names_surface_form_to_id = {}
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
column_name_surface_form = column_name
|
||||
column_name_surface_form = column_name_surface_form.lower()
|
||||
if column_name_surface_form not in self.column_names_surface_form_to_id:
|
||||
self.column_names_surface_form.append(column_name_surface_form)
|
||||
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
|
||||
column_keep_index.append(i)
|
||||
|
||||
column_keep_index_2 = []
|
||||
for i, table_name in enumerate(table_names_original):
|
||||
column_name_surface_form = table_name.lower()
|
||||
if column_name_surface_form not in self.column_names_surface_form_to_id:
|
||||
self.column_names_surface_form.append(column_name_surface_form)
|
||||
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
|
||||
column_keep_index_2.append(i)
|
||||
|
||||
self.column_names_embedder_input = []
|
||||
self.column_names_embedder_input_to_id = {}
|
||||
for i, (table_id, column_name) in enumerate(column_names):
|
||||
column_name_embedder_input = column_name
|
||||
if i in column_keep_index:
|
||||
self.column_names_embedder_input.append(column_name_embedder_input)
|
||||
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
|
||||
|
||||
for i, table_name in enumerate(table_names):
|
||||
column_name_embedder_input = table_name
|
||||
if i in column_keep_index_2:
|
||||
self.column_names_embedder_input.append(column_name_embedder_input)
|
||||
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
|
||||
|
||||
max_id_1 = max(v for k,v in self.column_names_surface_form_to_id.items())
|
||||
max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items())
|
||||
assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1
|
||||
|
||||
self.num_col = len(self.column_names_surface_form)
|
||||
|
||||
def helper2(self, table_schema):
|
||||
self.table_schema = table_schema
|
||||
column_names = table_schema['column_names']
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original)
|
||||
|
||||
column_keep_index = []
|
||||
|
||||
self.column_names_surface_form = []
|
||||
self.column_names_surface_form_to_id = {}
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
if table_id >= 0:
|
||||
table_name = table_names_original[table_id]
|
||||
column_name_surface_form = '{}.{}'.format(table_name,column_name)
|
||||
else:
|
||||
column_name_surface_form = column_name
|
||||
column_name_surface_form = column_name_surface_form.lower()
|
||||
if column_name_surface_form not in self.column_names_surface_form_to_id:
|
||||
self.column_names_surface_form.append(column_name_surface_form)
|
||||
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
|
||||
column_keep_index.append(i)
|
||||
|
||||
start_i = len(self.column_names_surface_form_to_id)
|
||||
for i, table_name in enumerate(table_names_original):
|
||||
column_name_surface_form = '{}.*'.format(table_name.lower())
|
||||
self.column_names_surface_form.append(column_name_surface_form)
|
||||
self.column_names_surface_form_to_id[column_name_surface_form] = i + start_i
|
||||
|
||||
self.column_names_embedder_input = []
|
||||
self.column_names_embedder_input_to_id = {}
|
||||
for i, (table_id, column_name) in enumerate(column_names):
|
||||
if table_id >= 0:
|
||||
table_name = table_names[table_id]
|
||||
column_name_embedder_input = table_name + ' . ' + column_name
|
||||
else:
|
||||
column_name_embedder_input = column_name
|
||||
if i in column_keep_index:
|
||||
self.column_names_embedder_input.append(column_name_embedder_input)
|
||||
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
|
||||
|
||||
start_i = len(self.column_names_embedder_input_to_id)
|
||||
for i, table_name in enumerate(table_names):
|
||||
column_name_embedder_input = table_name + ' . *'
|
||||
self.column_names_embedder_input.append(column_name_embedder_input)
|
||||
self.column_names_embedder_input_to_id[column_name_embedder_input] = i + start_i
|
||||
|
||||
assert len(self.column_names_surface_form) == len(self.column_names_surface_form_to_id) == len(self.column_names_embedder_input) == len(self.column_names_embedder_input_to_id)
|
||||
|
||||
max_id_1 = max(v for k,v in self.column_names_surface_form_to_id.items())
|
||||
max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items())
|
||||
assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1
|
||||
|
||||
self.num_col = len(self.column_names_surface_form)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_col
|
||||
|
||||
def in_vocabulary(self, column_name, surface_form=False):
|
||||
if surface_form:
|
||||
return column_name in self.column_names_surface_form_to_id
|
||||
else:
|
||||
return column_name in self.column_names_embedder_input_to_id
|
||||
|
||||
def column_name_embedder_bow(self, column_name, surface_form=False, column_name_token_embedder=None):
|
||||
assert self.in_vocabulary(column_name, surface_form)
|
||||
if surface_form:
|
||||
column_name_id = self.column_names_surface_form_to_id[column_name]
|
||||
column_name_embedder_input = self.column_names_embedder_input[column_name_id]
|
||||
else:
|
||||
column_name_embedder_input = column_name
|
||||
|
||||
column_name_embeddings = [column_name_token_embedder(token) for token in column_name_embedder_input.split()]
|
||||
column_name_embeddings = torch.stack(column_name_embeddings, dim=0)
|
||||
return torch.mean(column_name_embeddings, dim=0)
|
||||
|
||||
def set_column_name_embeddings(self, column_name_embeddings):
|
||||
self.column_name_embeddings = column_name_embeddings
|
||||
assert len(self.column_name_embeddings) == self.num_col
|
||||
|
||||
def column_name_embedder(self, column_name, surface_form=False):
|
||||
assert self.in_vocabulary(column_name, surface_form)
|
||||
if surface_form:
|
||||
column_name_id = self.column_names_surface_form_to_id[column_name]
|
||||
else:
|
||||
column_name_id = self.column_names_embedder_input_to_id[column_name]
|
||||
|
||||
return self.column_name_embeddings[column_name_id]
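A minimal sketch of what the default (non-simple) mode builds for a toy schema, assuming the surrounding module can be imported; the table and column names are invented:

```python
# Illustrative toy schema in the Spider/SParC table format.
toy_schema = {
    "column_names":          [[-1, "*"], [0, "city"], [0, "code"]],
    "column_names_original": [[-1, "*"], [0, "city"], [0, "code"]],
    "table_names":           ["airports"],
    "table_names_original":  ["airports"],
}
schema = Schema(toy_schema)  # simple=False, so helper2 builds qualified names
print(schema.column_names_surface_form)
# ['*', 'airports.city', 'airports.code', 'airports.*']
print(schema.column_names_embedder_input)
# ['*', 'airports . city', 'airports . code', 'airports . *']
print(len(schema))  # 4
```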
|
||||
|
||||
class Interaction:
|
||||
""" ATIS interaction class.
|
||||
|
||||
Attributes:
|
||||
utterances (list of Utterance): The utterances in the interaction.
|
||||
snippets (list of Snippet): The snippets that appear through the interaction.
|
||||
anon_tok_to_ent (dict): Maps anonymization tokens to the entities they replaced.
|
||||
identifier (str): Unique identifier for the interaction in the dataset.
|
||||
"""
|
||||
def __init__(self,
|
||||
utterances,
|
||||
schema,
|
||||
snippets,
|
||||
anon_tok_to_ent,
|
||||
identifier,
|
||||
params):
|
||||
self.utterances = utterances
|
||||
self.schema = schema
|
||||
self.snippets = snippets
|
||||
self.anon_tok_to_ent = anon_tok_to_ent
|
||||
self.identifier = identifier
|
||||
|
||||
# Ensure that each utterance's input and output sequences, when remapped
|
||||
# without anonymization or snippets, are the same as the original
|
||||
# version.
|
||||
for i, utterance in enumerate(self.utterances):
|
||||
deanon_input = self.deanonymize(utterance.input_seq_to_use,
|
||||
ANON_INPUT_KEY)
|
||||
assert deanon_input == utterance.original_input_seq, "Anonymized sequence [" \
|
||||
+ " ".join(utterance.input_seq_to_use) + "] is not the same as [" \
|
||||
+ " ".join(utterance.original_input_seq) + "] when deanonymized (is [" \
|
||||
+ " ".join(deanon_input) + "] instead)"
|
||||
desnippet_gold = self.expand_snippets(utterance.gold_query_to_use)
|
||||
deanon_gold = self.deanonymize(desnippet_gold, OUTPUT_KEY)
|
||||
assert deanon_gold == utterance.original_gold_query, \
|
||||
"Anonymized and/or snippet'd query " \
|
||||
+ " ".join(utterance.gold_query_to_use) + " is not the same as " \
|
||||
+ " ".join(utterance.original_gold_query)
|
||||
|
||||
def __str__(self):
|
||||
string = "Utterances:\n"
|
||||
for utterance in self.utterances:
|
||||
string += str(utterance) + "\n"
|
||||
string += "Anonymization dictionary:\n"
|
||||
for ent_tok, deanon in self.anon_tok_to_ent.items():
|
||||
string += ent_tok + "\t" + str(deanon) + "\n"
|
||||
|
||||
return string
|
||||
|
||||
def __len__(self):
|
||||
return len(self.utterances)
|
||||
|
||||
def deanonymize(self, sequence, key):
|
||||
""" Deanonymizes a predicted query or an input utterance.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): The sequence to deanonymize.
|
||||
key (str): The key in the anonymization table, e.g. NL or SQL.
|
||||
"""
|
||||
return anon.deanonymize(sequence, self.anon_tok_to_ent, key)
|
||||
|
||||
def expand_snippets(self, sequence):
|
||||
""" Expands snippets for a sequence.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): A SQL query.
|
||||
|
||||
"""
|
||||
return expand_snippets(sequence, self.snippets)
|
||||
|
||||
def input_seqs(self):
|
||||
in_seqs = []
|
||||
for utterance in self.utterances:
|
||||
in_seqs.append(utterance.input_seq_to_use)
|
||||
return in_seqs
|
||||
|
||||
def output_seqs(self):
|
||||
out_seqs = []
|
||||
for utterance in self.utterances:
|
||||
out_seqs.append(utterance.gold_query_to_use)
|
||||
return out_seqs
|
||||
|
||||
def load_function(parameters,
|
||||
nl_to_sql_dict,
|
||||
anonymizer,
|
||||
database_schema=None):
|
||||
def fn(interaction_example):
|
||||
keep = False
|
||||
|
||||
raw_utterances = interaction_example["interaction"]
|
||||
|
||||
if "database_id" in interaction_example:
|
||||
database_id = interaction_example["database_id"]
|
||||
interaction_id = interaction_example["interaction_id"]
|
||||
identifier = database_id + '/' + str(interaction_id)
|
||||
else:
|
||||
identifier = interaction_example["id"]
|
||||
|
||||
schema = None
|
||||
if database_schema:
|
||||
if 'removefrom' not in parameters.data_directory:
|
||||
schema = Schema(database_schema[database_id], simple=True)
|
||||
else:
|
||||
schema = Schema(database_schema[database_id])
|
||||
|
||||
snippet_bank = []
|
||||
|
||||
utterance_examples = []
|
||||
|
||||
anon_tok_to_ent = {}
|
||||
|
||||
for utterance in raw_utterances:
|
||||
available_snippets = [
|
||||
snippet for snippet in snippet_bank if snippet.index <= 1]
|
||||
|
||||
proc_utterance = Utterance(utterance,
|
||||
available_snippets,
|
||||
nl_to_sql_dict,
|
||||
parameters,
|
||||
anon_tok_to_ent,
|
||||
anonymizer)
|
||||
keep_utterance = proc_utterance.keep
|
||||
|
||||
if schema:
|
||||
assert keep_utterance
|
||||
|
||||
if keep_utterance:
|
||||
keep = True
|
||||
utterance_examples.append(proc_utterance)
|
||||
|
||||
# Update the snippet bank, and age each snippet in it.
|
||||
if parameters.use_snippets:
|
||||
if 'atis' in parameters.data_directory:
|
||||
snippets = sql_util.get_subtrees(
|
||||
proc_utterance.anonymized_gold_query,
|
||||
proc_utterance.available_snippets)
|
||||
else:
|
||||
snippets = sql_util.get_subtrees_simple(
|
||||
proc_utterance.anonymized_gold_query,
|
||||
proc_utterance.available_snippets)
|
||||
|
||||
for snippet in snippets:
|
||||
snippet.assign_id(len(snippet_bank))
|
||||
snippet_bank.append(snippet)
|
||||
|
||||
for snippet in snippet_bank:
|
||||
snippet.increase_age()
|
||||
|
||||
interaction = Interaction(utterance_examples,
|
||||
schema,
|
||||
snippet_bank,
|
||||
anon_tok_to_ent,
|
||||
identifier,
|
||||
parameters)
|
||||
|
||||
return interaction, keep
|
||||
|
||||
return fn
|
|
@ -0,0 +1,105 @@
|
|||
""" Contains the Snippet class and methods for handling snippets.
|
||||
|
||||
Attributes:
|
||||
SNIPPET_PREFIX: string prefix for snippets.
|
||||
"""
|
||||
|
||||
SNIPPET_PREFIX = "SNIPPET_"
|
||||
|
||||
|
||||
def is_snippet(token):
|
||||
""" Determines whether a token is a snippet or not.
|
||||
|
||||
Inputs:
|
||||
token (str): The token to check.
|
||||
|
||||
Returns:
|
||||
bool, indicating whether it's a snippet.
|
||||
"""
|
||||
return token.startswith(SNIPPET_PREFIX)
|
||||
|
||||
def expand_snippets(sequence, snippets):
|
||||
""" Given a sequence and a list of snippets, expand the snippets in the sequence.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): Query containing snippet references.
|
||||
snippets (list of Snippet): List of available snippets.
|
||||
|
||||
Returns:
list of str, representing the expanded sequence.
|
||||
"""
|
||||
snippet_id_to_snippet = {}
|
||||
for snippet in snippets:
|
||||
assert snippet.name not in snippet_id_to_snippet
|
||||
snippet_id_to_snippet[snippet.name] = snippet
|
||||
expanded_seq = []
|
||||
for token in sequence:
|
||||
if token in snippet_id_to_snippet:
|
||||
expanded_seq.extend(snippet_id_to_snippet[token].sequence)
|
||||
else:
|
||||
assert not is_snippet(token)
|
||||
expanded_seq.append(token)
|
||||
|
||||
return expanded_seq
|
||||
|
||||
def snippet_index(token):
|
||||
""" Returns the index of a snippet.
|
||||
|
||||
Inputs:
|
||||
token (str): The snippet to check.
|
||||
|
||||
Returns:
|
||||
integer, the index of the snippet.
|
||||
"""
|
||||
assert is_snippet(token)
|
||||
return int(token.split("_")[-1])
|
||||
|
||||
|
||||
class Snippet():
|
||||
""" Contains a snippet. """
|
||||
def __init__(self,
|
||||
sequence,
|
||||
startpos,
|
||||
sql,
|
||||
age=0):
|
||||
self.sequence = sequence
|
||||
self.startpos = startpos
|
||||
self.sql = sql
|
||||
|
||||
# TODO: age vs. index?
|
||||
self.age = age
|
||||
self.index = 0
|
||||
|
||||
self.name = ""
|
||||
self.embedding = None
|
||||
|
||||
self.endpos = self.startpos + len(self.sequence)
|
||||
assert self.endpos < len(self.sql), "End position of snippet is " + str(
|
||||
self.endpos) + " which is greater than length of SQL (" + str(len(self.sql)) + ")"
|
||||
assert self.sequence == self.sql[self.startpos:self.endpos], \
|
||||
"Value of snippet (" + " ".join(self.sequence) + ") " \
|
||||
"is not the same as SQL at the same positions (" \
|
||||
+ " ".join(self.sql[self.startpos:self.endpos]) + ")"
|
||||
|
||||
def __str__(self):
|
||||
return self.name + "\t" + \
|
||||
str(self.age) + "\t" + " ".join(self.sequence)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence)
|
||||
|
||||
def increase_age(self):
|
||||
""" Ages a snippet by one. """
|
||||
self.index += 1
|
||||
|
||||
def assign_id(self, number):
|
||||
""" Assigns the name of the snippet to be the prefix + the number. """
|
||||
self.name = SNIPPET_PREFIX + str(number)
|
||||
|
||||
def set_embedding(self, embedding):
|
||||
""" Sets the embedding of the snippet.
|
||||
|
||||
Inputs:
|
||||
embedding (dy.Expression)
|
||||
|
||||
"""
|
||||
self.embedding = embedding
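A minimal sketch of the snippet lifecycle with invented token lists:

```python
sql = ["SELECT", "name", "FROM", "city", "WHERE", "pop", ">", "100", ";"]
snip = Snippet(sequence=["pop", ">", "100"], startpos=5, sql=sql)
snip.assign_id(0)  # name becomes "SNIPPET_0"
assert is_snippet(snip.name) and snippet_index(snip.name) == 0

query_with_ref = ["SELECT", "name", "FROM", "city", "WHERE", "SNIPPET_0"]
print(expand_snippets(query_with_ref, [snip]))
# ['SELECT', 'name', 'FROM', 'city', 'WHERE', 'pop', '>', '100']
```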
|
|
@ -0,0 +1,445 @@
|
|||
import copy
|
||||
import pymysql
|
||||
import random
|
||||
import signal
|
||||
import sqlparse
|
||||
from . import util
|
||||
|
||||
from .snippets import Snippet
|
||||
from sqlparse import tokens as token_types
|
||||
from sqlparse import sql as sql_types
|
||||
|
||||
interesting_selects = ["DISTINCT", "MAX", "MIN", "count"]
|
||||
ignored_subtrees = [["1", "=", "1"]]
|
||||
|
||||
# strip_whitespace_front
|
||||
# Strips whitespace and punctuation from the front of a SQL token list.
|
||||
#
|
||||
# Inputs:
|
||||
# token_list: the token list.
|
||||
#
|
||||
# Outputs:
|
||||
# new token list.
|
||||
|
||||
|
||||
def strip_whitespace_front(token_list):
|
||||
new_token_list = []
|
||||
found_valid = False
|
||||
|
||||
for token in token_list:
|
||||
if not (token.is_whitespace or token.ttype ==
|
||||
token_types.Punctuation) or found_valid:
|
||||
found_valid = True
|
||||
new_token_list.append(token)
|
||||
|
||||
return new_token_list
|
||||
|
||||
# strip_whitespace
|
||||
# Strips whitespace from a token list.
|
||||
#
|
||||
# Inputs:
|
||||
# token_list: the token list.
|
||||
#
|
||||
# Outputs:
|
||||
# new token list with no whitespace/punctuation surrounding.
|
||||
|
||||
|
||||
def strip_whitespace(token_list):
|
||||
subtokens = strip_whitespace_front(token_list)
|
||||
subtokens = strip_whitespace_front(subtokens[::-1])[::-1]
|
||||
return subtokens
|
||||
|
||||
# token_list_to_seq
|
||||
# Converts a Token list to a sequence of strings, stripping out surrounding
|
||||
# punctuation and all whitespace.
|
||||
#
|
||||
# Inputs:
|
||||
# token_list: the list of tokens.
|
||||
#
|
||||
# Outputs:
|
||||
# list of strings
|
||||
|
||||
|
||||
def token_list_to_seq(token_list):
|
||||
subtokens = strip_whitespace(token_list)
|
||||
|
||||
seq = []
|
||||
flat = sqlparse.sql.TokenList(subtokens).flatten()
|
||||
for i, token in enumerate(flat):
|
||||
strip_token = str(token).strip()
|
||||
if len(strip_token) > 0:
|
||||
seq.append(strip_token)
|
||||
if len(seq) > 0:
|
||||
if seq[0] == "(" and seq[-1] == ")":
|
||||
seq = seq[1:-1]
|
||||
|
||||
return seq
|
||||
|
||||
# TODO: clean this up
|
||||
# find_subtrees
|
||||
# Finds subtrees for a subsequence of SQL.
|
||||
#
|
||||
# Inputs:
|
||||
# sequence: sequence of SQL tokens.
|
||||
# current_subtrees: current list of subtrees.
|
||||
#
|
||||
# Optional inputs:
|
||||
# where_parent: whether the parent of the current sequence was a where clause
|
||||
# keep_conj_subtrees: whether to look for a conjunction in this sequence and
|
||||
# keep its arguments
|
||||
|
||||
|
||||
def find_subtrees(sequence,
|
||||
current_subtrees,
|
||||
where_parent=False,
|
||||
keep_conj_subtrees=False):
|
||||
# If the parent of the subsequence was a WHERE clause, keep everything in the
|
||||
# sequence except for the beginning WHERE and any surrounding parentheses.
|
||||
if where_parent:
|
||||
# Strip out the beginning WHERE, and any punctuation or whitespace at the
|
||||
# beginning or end of the token list.
|
||||
seq = token_list_to_seq(sequence.tokens[1:])
|
||||
if len(seq) > 0 and seq not in current_subtrees:
|
||||
current_subtrees.append(seq)
|
||||
|
||||
# If the current sequence has subtokens, i.e. if it's a node that can be
|
||||
# expanded, check for a conjunction in its subtrees, and expand its subtrees.
|
||||
# Also check for any SELECT statements and keep track of what follows.
|
||||
if sequence.is_group:
|
||||
if keep_conj_subtrees:
|
||||
subtokens = strip_whitespace(sequence.tokens)
|
||||
|
||||
# Check if there is a conjunction in the subsequence. If so, keep the
|
||||
# children. Also make sure you don't split where AND is used within a
|
||||
# child -- the subtokens sequence won't treat those ANDs differently (a
|
||||
# bit hacky but it works)
|
||||
has_and = False
|
||||
for i, token in enumerate(subtokens):
|
||||
if token.value == "OR" or token.value == "AND":
|
||||
has_and = True
|
||||
break
|
||||
|
||||
if has_and:
|
||||
and_subtrees = []
|
||||
current_subtree = []
|
||||
for i, token in enumerate(subtokens):
|
||||
if token.value == "OR" or (token.value == "AND" and i - 4 >= 0 and i - 4 < len(
|
||||
subtokens) and subtokens[i - 4].value != "BETWEEN"):
|
||||
and_subtrees.append(current_subtree)
|
||||
current_subtree = []
|
||||
else:
|
||||
current_subtree.append(token)
|
||||
and_subtrees.append(current_subtree)
|
||||
|
||||
for subtree in and_subtrees:
|
||||
seq = token_list_to_seq(subtree)
|
||||
if len(seq) > 0 and seq[0] == "WHERE":
|
||||
seq = seq[1:]
|
||||
if seq not in current_subtrees:
|
||||
current_subtrees.append(seq)
|
||||
|
||||
in_select = False
|
||||
select_toks = []
|
||||
for i, token in enumerate(sequence.tokens):
|
||||
# Mark whether this current token is a WHERE.
|
||||
is_where = (isinstance(token, sql_types.Where))
|
||||
|
||||
# If you are in a SELECT, start recording what follows until you hit a
|
||||
# FROM
|
||||
if token.value == "SELECT":
|
||||
in_select = True
|
||||
elif in_select:
|
||||
select_toks.append(token)
|
||||
if token.value == "FROM":
|
||||
in_select = False
|
||||
|
||||
seq = []
|
||||
if len(sequence.tokens) > i + 2:
|
||||
seq = token_list_to_seq(
|
||||
select_toks + [sequence.tokens[i + 2]])
|
||||
|
||||
if seq not in current_subtrees and len(
|
||||
seq) > 0 and seq[0] in interesting_selects:
|
||||
current_subtrees.append(seq)
|
||||
|
||||
select_toks = []
|
||||
|
||||
# Recursively find subtrees in the children of the node.
|
||||
find_subtrees(token,
|
||||
current_subtrees,
|
||||
is_where,
|
||||
where_parent or keep_conj_subtrees)
|
||||
|
||||
# get_subtrees
|
||||
|
||||
|
||||
def get_subtrees(sql, oldsnippets=[]):
|
||||
parsed = sqlparse.parse(" ".join(sql))[0]
|
||||
|
||||
subtrees = []
|
||||
find_subtrees(parsed, subtrees)
|
||||
|
||||
final_subtrees = []
|
||||
for subtree in subtrees:
|
||||
if subtree not in ignored_subtrees:
|
||||
final_version = []
|
||||
keep = True
|
||||
|
||||
parens_counts = 0
|
||||
for i, token in enumerate(subtree):
|
||||
if token == ".":
|
||||
newtoken = final_version[-1] + "." + subtree[i + 1]
|
||||
final_version = final_version[:-1] + [newtoken]
|
||||
keep = False
|
||||
elif keep:
|
||||
final_version.append(token)
|
||||
else:
|
||||
keep = True
|
||||
|
||||
if token == "(":
|
||||
parens_counts -= 1
|
||||
elif token == ")":
|
||||
parens_counts += 1
|
||||
|
||||
if parens_counts == 0:
|
||||
final_subtrees.append(final_version)
|
||||
|
||||
snippets = []
|
||||
sql = [str(tok) for tok in sql]
|
||||
for subtree in final_subtrees:
|
||||
startpos = -1
|
||||
for i in range(len(sql) - len(subtree) + 1):
|
||||
if sql[i:i + len(subtree)] == subtree:
|
||||
startpos = i
|
||||
if startpos >= 0 and startpos + len(subtree) < len(sql):
|
||||
age = 0
|
||||
for prevsnippet in oldsnippets:
|
||||
if prevsnippet.sequence == subtree:
|
||||
age = prevsnippet.age + 1
|
||||
snippet = Snippet(subtree, startpos, sql, age=age)
|
||||
snippets.append(snippet)
|
||||
|
||||
return snippets
|
||||
|
||||
|
||||
def get_subtrees_simple(sql, oldsnippets=[]):
|
||||
sql_string = " ".join(sql)
|
||||
format_sql = sqlparse.format(sql_string, reindent=True)
|
||||
|
||||
# get subtrees
|
||||
subtrees = []
|
||||
for sub_sql in format_sql.split('\n'):
|
||||
sub_sql = sub_sql.replace('(', ' ( ').replace(')', ' ) ').replace(',', ' , ')
|
||||
|
||||
subtree = sub_sql.strip().split()
|
||||
if len(subtree) > 1:
|
||||
subtrees.append(subtree)
|
||||
|
||||
final_subtrees = subtrees
|
||||
|
||||
snippets = []
|
||||
sql = [str(tok) for tok in sql]
|
||||
for subtree in final_subtrees:
|
||||
startpos = -1
|
||||
for i in range(len(sql) - len(subtree) + 1):
|
||||
if sql[i:i + len(subtree)] == subtree:
|
||||
startpos = i
|
||||
|
||||
if startpos >= 0 and startpos + len(subtree) <= len(sql):
|
||||
age = 0
|
||||
for prevsnippet in oldsnippets:
|
||||
if prevsnippet.sequence == subtree:
|
||||
age = prevsnippet.age + 1
|
||||
new_sql = sql + [';']
|
||||
snippet = Snippet(subtree, startpos, new_sql, age=age)
|
||||
snippets.append(snippet)
|
||||
|
||||
return snippets
|
||||
|
||||
|
||||
conjunctions = {"AND", "OR", "WHERE"}
|
||||
|
||||
|
||||
def get_all_in_parens(sequence):
|
||||
if sequence[-1] == ";":
|
||||
sequence = sequence[:-1]
|
||||
|
||||
if not "(" in sequence:
|
||||
return []
|
||||
|
||||
if sequence[0] == "(" and sequence[-1] == ")":
|
||||
in_parens = sequence[1:-1]
|
||||
return [in_parens] + get_all_in_parens(in_parens)
|
||||
else:
|
||||
paren_subseqs = []
|
||||
current_seq = []
|
||||
num_parens = 0
|
||||
in_parens = False
|
||||
for token in sequence:
|
||||
if in_parens:
|
||||
current_seq.append(token)
|
||||
if token == ")":
|
||||
num_parens -= 1
|
||||
if num_parens == 0:
|
||||
in_parens = False
|
||||
paren_subseqs.append(current_seq)
|
||||
current_seq = []
|
||||
elif token == "(":
|
||||
in_parens = True
|
||||
current_seq.append(token)
|
||||
if token == "(":
|
||||
num_parens += 1
|
||||
|
||||
all_subseqs = []
|
||||
for subseq in paren_subseqs:
|
||||
all_subseqs.extend(get_all_in_parens(subseq))
|
||||
return all_subseqs
|
||||
|
||||
|
||||
def split_by_conj(sequence):
|
||||
num_parens = 0
|
||||
current_seq = []
|
||||
subsequences = []
|
||||
|
||||
for token in sequence:
|
||||
if num_parens == 0:
|
||||
if token in conjunctions:
|
||||
subsequences.append(current_seq)
|
||||
current_seq = []
|
||||
break
|
||||
current_seq.append(token)
|
||||
if token == "(":
|
||||
num_parens += 1
|
||||
elif token == ")":
|
||||
num_parens -= 1
|
||||
|
||||
assert num_parens >= 0
|
||||
|
||||
return subsequences
|
||||
|
||||
|
||||
def get_sql_snippets(sequence):
|
||||
# First, get all subsequences of the sequence that are surrounded by
|
||||
# parentheses.
|
||||
all_in_parens = get_all_in_parens(sequence)
|
||||
all_subseq = []
|
||||
|
||||
# Then for each one, split the sequence on conjunctions (AND/OR).
|
||||
for seq in all_in_parens:
|
||||
subsequences = split_by_conj(seq)
|
||||
all_subseq.append(seq)
|
||||
all_subseq.extend(subsequences)
|
||||
|
||||
# Finally, also get "interesting" selects
|
||||
|
||||
for i, seq in enumerate(all_subseq):
|
||||
print(str(i) + "\t" + " ".join(seq))
|
||||
exit()
|
||||
|
||||
# add_snippets_to_query
|
||||
|
||||
|
||||
def add_snippets_to_query(snippets, ignored_entities, query, prob_align=1.):
|
||||
query_copy = copy.copy(query)
|
||||
|
||||
# Replace the longest snippets first, so sort by length descending.
|
||||
sorted_snippets = sorted(snippets, key=lambda s: len(s.sequence))[::-1]
|
||||
|
||||
for snippet in sorted_snippets:
|
||||
ignore = False
|
||||
snippet_seq = snippet.sequence
|
||||
|
||||
# TODO: continue here
|
||||
# If it contains an ignored entity, then don't use it.
|
||||
for entity in ignored_entities:
|
||||
ignore = ignore or util.subsequence(entity, snippet_seq)
|
||||
|
||||
# If no NL entities were found in the snippet, check whether the snippet
# is a subsequence of the gold query.
|
||||
if not ignore:
|
||||
snippet_length = len(snippet_seq)
|
||||
|
||||
# Iterate through gold sequence to see if it's a subsequence.
|
||||
for start_idx in range(len(query_copy) - snippet_length + 1):
|
||||
if query_copy[start_idx:start_idx +
|
||||
snippet_length] == snippet_seq:
|
||||
align = random.random() < prob_align
|
||||
|
||||
if align:
|
||||
prev_length = len(query_copy)
|
||||
|
||||
# At the start position of the snippet, replace with an
|
||||
# identifier.
|
||||
query_copy[start_idx] = snippet.name
|
||||
|
||||
# Then cut out the indices which were collapsed into
|
||||
# the snippet.
|
||||
query_copy = query_copy[:start_idx + 1] + \
|
||||
query_copy[start_idx + snippet_length:]
|
||||
|
||||
# Make sure the length is as expected
|
||||
assert len(query_copy) == prev_length - \
|
||||
(snippet_length - 1)
|
||||
|
||||
return query_copy
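A minimal sketch of the substitution this performs, reusing the Snippet class from snippets.py; the tokens are invented, and the default prob_align=1.0 means the replacement always happens:

```python
query = ["SELECT", "name", "FROM", "city", "WHERE", "pop", ">", "100"]
snip = Snippet(["pop", ">", "100"], startpos=5, sql=query + [";"])
snip.assign_id(3)
print(add_snippets_to_query([snip], ignored_entities=[], query=query))
# ['SELECT', 'name', 'FROM', 'city', 'WHERE', 'SNIPPET_3']
```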
|
||||
|
||||
|
||||
def execution_results(query, username, password, timeout=3):
|
||||
connection = pymysql.connect(user=username, password=password)
|
||||
|
||||
class TimeoutException(Exception):
|
||||
pass
|
||||
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutException
|
||||
|
||||
signal.signal(signal.SIGALRM, timeout_handler)
|
||||
|
||||
syntactic = True
|
||||
semantic = True
|
||||
|
||||
table = []
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
signal.alarm(timeout)
|
||||
try:
|
||||
cursor.execute("SET sql_mode='IGNORE_SPACE';")
|
||||
cursor.execute("use atis3;")
|
||||
cursor.execute(query)
|
||||
table = cursor.fetchall()
|
||||
cursor.close()
|
||||
except TimeoutException:
|
||||
signal.alarm(0)
|
||||
cursor.close()
|
||||
except pymysql.err.ProgrammingError:
|
||||
syntactic = False
|
||||
semantic = False
|
||||
cursor.close()
|
||||
except pymysql.err.InternalError:
|
||||
semantic = False
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
signal.alarm(0)
|
||||
signal.alarm(0)
|
||||
cursor.close()
|
||||
signal.alarm(0)
|
||||
|
||||
connection.close()
|
||||
|
||||
return (syntactic, semantic, sorted(table))
|
||||
|
||||
|
||||
def executable(query, username, password, timeout=2):
|
||||
return execution_results(query, username, password, timeout)[1]
|
||||
|
||||
|
||||
def fix_parentheses(sequence):
|
||||
num_left = sequence.count("(")
|
||||
num_right = sequence.count(")")
|
||||
|
||||
if num_right < num_left:
|
||||
fixed_sequence = sequence[:-1] + \
|
||||
[")" for _ in range(num_left - num_right)] + [sequence[-1]]
|
||||
return fixed_sequence
|
||||
|
||||
return sequence
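A quick sanity check on an invented token list; missing closing parentheses are inserted just before the final token:

```python
print(fix_parentheses(["count", "(", "*", ";"]))
# ['count', '(', '*', ')', ';']
```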
|
|
@ -0,0 +1,83 @@
|
|||
"""Tokenizers for natural language SQL queries, and lambda calculus."""
|
||||
import nltk
|
||||
nltk.data.path.append('./nltk_data/')
|
||||
import sqlparse
|
||||
|
||||
def nl_tokenize(string):
|
||||
"""Tokenizes a natural language string into tokens.
|
||||
|
||||
Inputs:
|
||||
string: the string to tokenize.
|
||||
Outputs:
|
||||
a list of tokens.
|
||||
|
||||
Assumes data is space-separated (this is true of ZC07 data in ATIS2/3).
|
||||
"""
|
||||
return nltk.word_tokenize(string)
|
||||
|
||||
def sql_tokenize(string):
|
||||
""" Tokenizes a SQL statement into tokens.
|
||||
|
||||
Inputs:
|
||||
string: string to tokenize.
|
||||
|
||||
Outputs:
|
||||
a list of tokens.
|
||||
"""
|
||||
tokens = []
|
||||
statements = sqlparse.parse(string)
|
||||
|
||||
# SQLparse gives you a list of statements.
|
||||
for statement in statements:
|
||||
# Flatten the tokens in each statement and add to the tokens list.
|
||||
flat_tokens = sqlparse.sql.TokenList(statement.tokens).flatten()
|
||||
for token in flat_tokens:
|
||||
strip_token = str(token).strip()
|
||||
if len(strip_token) > 0:
|
||||
tokens.append(strip_token)
|
||||
|
||||
newtokens = []
|
||||
keep = True
|
||||
for i, token in enumerate(tokens):
|
||||
if token == ".":
|
||||
newtoken = newtokens[-1] + "." + tokens[i + 1]
|
||||
newtokens = newtokens[:-1] + [newtoken]
|
||||
keep = False
|
||||
elif keep:
|
||||
newtokens.append(token)
|
||||
else:
|
||||
keep = True
|
||||
|
||||
return newtokens
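A small example of the dot-merging step above (output shown for a typical sqlparse version; exact token splits can vary slightly):

```python
print(sql_tokenize("SELECT city.name FROM city"))
# ['SELECT', 'city.name', 'FROM', 'city']
```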
|
||||
|
||||
def lambda_tokenize(string):
|
||||
""" Tokenizes a lambda-calculus statement into tokens.
|
||||
|
||||
Inputs:
|
||||
string: a lambda-calculus string
|
||||
|
||||
Outputs:
|
||||
a list of tokens.
|
||||
"""
|
||||
|
||||
space_separated = string.split(" ")
|
||||
|
||||
new_tokens = []
|
||||
|
||||
# Separate the string by spaces, then separate based on existence of ( or
|
||||
# ).
|
||||
for token in space_separated:
|
||||
tokens = []
|
||||
|
||||
current_token = ""
|
||||
for char in token:
|
||||
if char == ")" or char == "(":
|
||||
tokens.append(current_token)
|
||||
tokens.append(char)
|
||||
current_token = ""
|
||||
else:
|
||||
current_token += char
|
||||
tokens.append(current_token)
|
||||
new_tokens.extend([tok for tok in tokens if tok])
|
||||
|
||||
return new_tokens
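For example, parentheses are split off even when they are not space-separated:

```python
print(lambda_tokenize("(lambda $0 (flight $0))"))
# ['(', 'lambda', '$0', '(', 'flight', '$0', ')', ')']
```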
|
|
@ -0,0 +1,16 @@
|
|||
"""Contains various utility functions."""
|
||||
def subsequence(first_sequence, second_sequence):
|
||||
"""
|
||||
Returns whether the first sequence is a subsequence of the second sequence.
|
||||
|
||||
Inputs:
|
||||
first_sequence (list): A sequence.
|
||||
second_sequence (list): Another sequence.
|
||||
|
||||
Returns:
|
||||
Boolean indicating whether first_sequence is a subsequence of second_sequence.
|
||||
"""
|
||||
for startidx in range(len(second_sequence) - len(first_sequence) + 1):
|
||||
if second_sequence[startidx:startidx + len(first_sequence)] == first_sequence:
|
||||
return True
|
||||
return False
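For example, the match must be contiguous:

```python
print(subsequence(["flight", "number"], ["show", "the", "flight", "number"]))  # True
print(subsequence(["number", "flight"], ["show", "the", "flight", "number"]))  # False
```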
|
|
@ -0,0 +1,115 @@
|
|||
""" Contains the Utterance class. """
|
||||
|
||||
from . import sql_util
|
||||
from . import tokenizers
|
||||
|
||||
ANON_INPUT_KEY = "cleaned_nl"
|
||||
OUTPUT_KEY = "sql"
|
||||
|
||||
class Utterance:
|
||||
""" Utterance class. """
|
||||
def process_input_seq(self,
|
||||
anonymize,
|
||||
anonymizer,
|
||||
anon_tok_to_ent):
|
||||
assert not anon_tok_to_ent or anonymize
|
||||
assert not anonymize or anonymizer
|
||||
|
||||
if anonymize:
|
||||
assert anonymizer
|
||||
|
||||
self.input_seq_to_use = anonymizer.anonymize(
|
||||
self.original_input_seq, anon_tok_to_ent, ANON_INPUT_KEY, add_new_anon_toks=True)
|
||||
else:
|
||||
self.input_seq_to_use = self.original_input_seq
|
||||
|
||||
def process_gold_seq(self,
|
||||
output_sequences,
|
||||
nl_to_sql_dict,
|
||||
available_snippets,
|
||||
anonymize,
|
||||
anonymizer,
|
||||
anon_tok_to_ent):
|
||||
# Get entities in the input sequence:
|
||||
# anonymized entity types
|
||||
# othe recognized entities (this includes "flight")
|
||||
entities_in_input = [
|
||||
[tok] for tok in self.input_seq_to_use if tok in anon_tok_to_ent]
|
||||
entities_in_input.extend(
|
||||
nl_to_sql_dict.get_sql_entities(
|
||||
self.input_seq_to_use))
|
||||
|
||||
# Get the shortest gold query (this is what we use to train)
|
||||
shortest_gold_and_results = min(output_sequences,
|
||||
key=lambda x: len(x[0]))
|
||||
|
||||
# Tokenize and anonymize it if necessary.
|
||||
self.original_gold_query = shortest_gold_and_results[0]
|
||||
self.gold_sql_results = shortest_gold_and_results[1]
|
||||
|
||||
self.contained_entities = entities_in_input
|
||||
|
||||
# Keep track of all gold queries and the resulting tables so that we can
|
||||
# give credit if it predicts a different correct sequence.
|
||||
self.all_gold_queries = output_sequences
|
||||
|
||||
self.anonymized_gold_query = self.original_gold_query
|
||||
if anonymize:
|
||||
self.anonymized_gold_query = anonymizer.anonymize(
|
||||
self.original_gold_query, anon_tok_to_ent, OUTPUT_KEY, add_new_anon_toks=False)
|
||||
|
||||
# Add snippets to it.
|
||||
self.gold_query_to_use = sql_util.add_snippets_to_query(
|
||||
available_snippets, entities_in_input, self.anonymized_gold_query)
|
||||
|
||||
def __init__(self,
|
||||
example,
|
||||
available_snippets,
|
||||
nl_to_sql_dict,
|
||||
params,
|
||||
anon_tok_to_ent={},
|
||||
anonymizer=None):
|
||||
# Get output and input sequences from the dictionary representation.
|
||||
output_sequences = example[OUTPUT_KEY]
|
||||
self.original_input_seq = tokenizers.nl_tokenize(example[params.input_key])
|
||||
self.available_snippets = available_snippets
|
||||
self.keep = False
|
||||
|
||||
# pruned_output_sequences = []
|
||||
# for sequence in output_sequences:
|
||||
# if len(sequence[0]) > 3:
|
||||
# pruned_output_sequences.append(sequence)
|
||||
|
||||
# output_sequences = pruned_output_sequences
|
||||
if len(output_sequences) > 0 and len(self.original_input_seq) > 0:
|
||||
# Only keep this example if there is at least one output sequence.
|
||||
self.keep = True
|
||||
if len(output_sequences) == 0 or len(self.original_input_seq) == 0:
|
||||
return
|
||||
|
||||
# Process the input sequence
|
||||
self.process_input_seq(params.anonymize,
|
||||
anonymizer,
|
||||
anon_tok_to_ent)
|
||||
|
||||
# Process the gold sequence
|
||||
self.process_gold_seq(output_sequences,
|
||||
nl_to_sql_dict,
|
||||
self.available_snippets,
|
||||
params.anonymize,
|
||||
anonymizer,
|
||||
anon_tok_to_ent)
|
||||
|
||||
def __str__(self):
|
||||
string = "Original input: " + " ".join(self.original_input_seq) + "\n"
|
||||
string += "Modified input: " + " ".join(self.input_seq_to_use) + "\n"
|
||||
string += "Original output: " + " ".join(self.original_gold_query) + "\n"
|
||||
string += "Modified output: " + " ".join(self.gold_query_to_use) + "\n"
|
||||
string += "Snippets:\n"
|
||||
for snippet in self.available_snippets:
|
||||
string += str(snippet) + "\n"
|
||||
return string
|
||||
|
||||
def length_valid(self, input_limit, output_limit):
|
||||
return (len(self.input_seq_to_use) < input_limit \
|
||||
and len(self.gold_query_to_use) < output_limit)
|
|
@ -0,0 +1,98 @@
|
|||
"""Contains class and methods for storing and computing a vocabulary from text."""
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
|
||||
# Special sequencing tokens.
|
||||
UNK_TOK = "_UNK" # Replaces out-of-vocabulary words.
|
||||
EOS_TOK = "_EOS" # Appended to the end of a sequence to indicate its end.
|
||||
DEL_TOK = ";"
|
||||
|
||||
|
||||
class Vocabulary:
|
||||
"""Vocabulary class: stores information about words in a corpus.
|
||||
|
||||
Members:
|
||||
functional_types (list of str): Functional vocabulary words, such as EOS.
|
||||
max_size (int): The maximum size of vocabulary to keep.
|
||||
min_occur (int): The minimum number of times a word should occur to keep it.
|
||||
id_to_token (list of str): Ordered list of word types.
|
||||
token_to_id (dict str->int): Maps from each unique word type to its index.
|
||||
"""
|
||||
def get_vocab(self, sequences, ignore_fn):
|
||||
"""Gets vocabulary from a list of sequences.
|
||||
|
||||
Inputs:
|
||||
sequences (list of list of str): Sequences from which to compute the vocabulary.
|
||||
ignore_fn (lambda str: bool): Function used to tell whether to ignore a
|
||||
token during computation of the vocabulary.
|
||||
|
||||
Returns:
|
||||
list of str, representing the unique word types in the vocabulary.
|
||||
"""
|
||||
type_counts = {}
|
||||
|
||||
for sequence in sequences:
|
||||
for token in sequence:
|
||||
if not ignore_fn(token):
|
||||
if token not in type_counts:
|
||||
type_counts[token] = 0
|
||||
type_counts[token] += 1
|
||||
|
||||
# Create sorted list of tokens, by their counts. Reverse so it is in order of
|
||||
# most frequent to least frequent.
|
||||
sorted_type_counts = sorted(sorted(type_counts.items()),
|
||||
key=operator.itemgetter(1))[::-1]
|
||||
|
||||
sorted_types = [typecount[0]
|
||||
for typecount in sorted_type_counts if typecount[1] >= self.min_occur]
|
||||
|
||||
# Append the necessary functional tokens.
|
||||
sorted_types = self.functional_types + sorted_types
|
||||
|
||||
# Cut off if vocab_size is set (nonnegative)
|
||||
if self.max_size >= 0:
|
||||
vocab = sorted_types[:min(self.max_size, len(sorted_types))]  # min() so the vocabulary is actually truncated to max_size
|
||||
else:
|
||||
vocab = sorted_types
|
||||
|
||||
return vocab
|
||||
|
||||
def __init__(self,
|
||||
sequences,
|
||||
filename,
|
||||
functional_types=None,
|
||||
max_size=-1,
|
||||
min_occur=0,
|
||||
ignore_fn=lambda x: False):
|
||||
self.functional_types = functional_types
|
||||
self.max_size = max_size
|
||||
self.min_occur = min_occur
|
||||
|
||||
vocab = self.get_vocab(sequences, ignore_fn)
|
||||
|
||||
self.id_to_token = []
|
||||
self.token_to_id = {}
|
||||
|
||||
for i, word_type in enumerate(vocab):
|
||||
self.id_to_token.append(word_type)
|
||||
self.token_to_id[word_type] = i
|
||||
|
||||
# Load the previous vocab, if it exists.
|
||||
if os.path.exists(filename):
|
||||
infile = open(filename, 'rb')
|
||||
loaded_vocab = pickle.load(infile)
|
||||
infile.close()
|
||||
|
||||
print("Loaded vocabulary from " + str(filename))
|
||||
if loaded_vocab.id_to_token != self.id_to_token \
|
||||
or loaded_vocab.token_to_id != self.token_to_id:
|
||||
print("Loaded vocabulary is different than generated vocabulary.")
|
||||
else:
|
||||
print("Writing vocabulary to " + str(filename))
|
||||
outfile = open(filename, 'wb')
|
||||
pickle.dump(self, outfile)
|
||||
outfile.close()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.id_to_token)
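A minimal sketch of building a vocabulary from toy sequences. The cache path is a throwaway temp file, and the functional tokens are passed explicitly because the default is None:

```python
import os
import tempfile

sequences = [["show", "me", "flights"], ["show", "flights", "to", "denver"]]
cache_path = os.path.join(tempfile.mkdtemp(), "vocab.pkl")  # throwaway cache file

vocab = Vocabulary(sequences, cache_path,
                   functional_types=[UNK_TOK, EOS_TOK],
                   min_occur=2)
print(vocab.id_to_token)  # ['_UNK', '_EOS', 'show', 'flights']
print(len(vocab))         # 4
```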
|
|
@ -0,0 +1,866 @@
|
|||
################################
|
||||
# val: number(float)/string(str)/sql(dict)
|
||||
# col_unit: (agg_id, col_id, isDistinct(bool))
|
||||
# val_unit: (unit_op, col_unit1, col_unit2)
|
||||
# table_unit: (table_type, col_unit/sql)
|
||||
# cond_unit: (not_op, op_id, val_unit, val1, val2)
|
||||
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
|
||||
# sql {
|
||||
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
|
||||
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
|
||||
# 'where': condition
|
||||
# 'groupBy': [col_unit1, col_unit2, ...]
|
||||
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
|
||||
# 'having': condition
|
||||
# 'limit': None/limit value
|
||||
# 'intersect': None/sql
|
||||
# 'except': None/sql
|
||||
# 'union': None/sql
|
||||
# }
|
||||
################################
|
||||
|
||||
import os, sys
|
||||
import json
|
||||
import sqlite3
|
||||
import traceback
|
||||
import argparse
|
||||
|
||||
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
|
||||
|
||||
# Flag to disable value evaluation
|
||||
DISABLE_VALUE = True
|
||||
# Flag to disable distinct in select evaluation
|
||||
DISABLE_DISTINCT = True
|
||||
|
||||
|
||||
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
|
||||
JOIN_KEYWORDS = ('join', 'on', 'as')
|
||||
|
||||
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
UNIT_OPS = ('none', '-', '+', "*", '/')
|
||||
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
|
||||
TABLE_TYPE = {
|
||||
'sql': "sql",
|
||||
'table_unit': "table_unit",
|
||||
}
|
||||
|
||||
COND_OPS = ('and', 'or')
|
||||
SQL_OPS = ('intersect', 'union', 'except')
|
||||
ORDER_OPS = ('desc', 'asc')
|
||||
|
||||
|
||||
HARDNESS = {
|
||||
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
|
||||
"component2": ('except', 'union', 'intersect')
|
||||
}
|
||||
|
||||
|
||||
def condition_has_or(conds):
|
||||
return 'or' in conds[1::2]
|
||||
|
||||
|
||||
def condition_has_like(conds):
|
||||
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
|
||||
|
||||
|
||||
def condition_has_sql(conds):
|
||||
for cond_unit in conds[::2]:
|
||||
val1, val2 = cond_unit[3], cond_unit[4]
|
||||
if val1 is not None and type(val1) is dict:
|
||||
return True
|
||||
if val2 is not None and type(val2) is dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def val_has_op(val_unit):
|
||||
return val_unit[0] != UNIT_OPS.index('none')
|
||||
|
||||
|
||||
def has_agg(unit):
|
||||
return unit[0] != AGG_OPS.index('none')
|
||||
|
||||
|
||||
def accuracy(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def recall(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def F1(acc, rec):
|
||||
if (acc + rec) == 0:
|
||||
return 0
|
||||
return (2. * acc * rec) / (acc + rec)
|
||||
|
||||
|
||||
def get_scores(count, pred_total, label_total):
|
||||
if pred_total != label_total:
|
||||
return 0,0,0
|
||||
elif count == pred_total:
|
||||
return 1,1,1
|
||||
return 0,0,0
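In other words, each component is scored all-or-nothing. A small worked check with invented counts:

```python
print(get_scores(count=2, pred_total=2, label_total=2))  # (1, 1, 1): component matches exactly
print(get_scores(count=1, pred_total=2, label_total=2))  # (0, 0, 0): no partial credit
print(F1(1, 1))                                          # 1.0
```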
|
||||
|
||||
|
||||
def eval_sel(pred, label):
|
||||
pred_sel = pred['select'][1]
|
||||
label_sel = label['select'][1]
|
||||
label_wo_agg = [unit[1] for unit in label_sel]
|
||||
pred_total = len(pred_sel)
|
||||
label_total = len(label_sel)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_sel:
|
||||
if unit in label_sel:
|
||||
cnt += 1
|
||||
label_sel.remove(unit)
|
||||
if unit[1] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[1])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_where(pred, label):
|
||||
pred_conds = [unit for unit in pred['where'][::2]]
|
||||
label_conds = [unit for unit in label['where'][::2]]
|
||||
label_wo_agg = [unit[2] for unit in label_conds]
|
||||
pred_total = len(pred_conds)
|
||||
label_total = len(label_conds)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_conds:
|
||||
if unit in label_conds:
|
||||
cnt += 1
|
||||
label_conds.remove(unit)
|
||||
if unit[2] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[2])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_group(pred, label):
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
pred_total = len(pred_cols)
|
||||
label_total = len(label_cols)
|
||||
cnt = 0
|
||||
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
|
||||
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
|
||||
for col in pred_cols:
|
||||
if col in label_cols:
|
||||
cnt += 1
|
||||
label_cols.remove(col)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_having(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['groupBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['groupBy']) > 0:
|
||||
label_total = 1
|
||||
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
if pred_total == label_total == 1 \
|
||||
and pred_cols == label_cols \
|
||||
and pred['having'] == label['having']:
|
||||
cnt = 1
|
||||
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_order(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['orderBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['orderBy']) > 0:
|
||||
label_total = 1
|
||||
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
|
||||
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
|
||||
cnt = 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_and_or(pred, label):
|
||||
pred_ao = pred['where'][1::2]
|
||||
label_ao = label['where'][1::2]
|
||||
pred_ao = set(pred_ao)
|
||||
label_ao = set(label_ao)
|
||||
|
||||
if pred_ao == label_ao:
|
||||
return 1,1,1
|
||||
return len(pred_ao),len(label_ao),0
|
||||
|
||||
|
||||
def get_nestedSQL(sql):
|
||||
nested = []
|
||||
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
|
||||
if type(cond_unit[3]) is dict:
|
||||
nested.append(cond_unit[3])
|
||||
if type(cond_unit[4]) is dict:
|
||||
nested.append(cond_unit[4])
|
||||
if sql['intersect'] is not None:
|
||||
nested.append(sql['intersect'])
|
||||
if sql['except'] is not None:
|
||||
nested.append(sql['except'])
|
||||
if sql['union'] is not None:
|
||||
nested.append(sql['union'])
|
||||
return nested
|
||||
|
||||
|
||||
def eval_nested(pred, label):
|
||||
label_total = 0
|
||||
pred_total = 0
|
||||
cnt = 0
|
||||
if pred is not None:
|
||||
pred_total += 1
|
||||
if label is not None:
|
||||
label_total += 1
|
||||
if pred is not None and label is not None:
|
||||
cnt += Evaluator().eval_exact_match(pred, label)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_IUEN(pred, label):
|
||||
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
|
||||
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
|
||||
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
|
||||
label_total = lt1 + lt2 + lt3
|
||||
pred_total = pt1 + pt2 + pt3
|
||||
cnt = cnt1 + cnt2 + cnt3
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def get_keywords(sql):
|
||||
res = set()
|
||||
if len(sql['where']) > 0:
|
||||
res.add('where')
|
||||
if len(sql['groupBy']) > 0:
|
||||
res.add('group')
|
||||
if len(sql['having']) > 0:
|
||||
res.add('having')
|
||||
if len(sql['orderBy']) > 0:
|
||||
res.add(sql['orderBy'][0])
|
||||
res.add('order')
|
||||
if sql['limit'] is not None:
|
||||
res.add('limit')
|
||||
if sql['except'] is not None:
|
||||
res.add('except')
|
||||
if sql['union'] is not None:
|
||||
res.add('union')
|
||||
if sql['intersect'] is not None:
|
||||
res.add('intersect')
|
||||
|
||||
# or keyword
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
if len([token for token in ao if token == 'or']) > 0:
|
||||
res.add('or')
|
||||
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
# not keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
|
||||
res.add('not')
|
||||
|
||||
# in keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
|
||||
res.add('in')
|
||||
|
||||
# like keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
|
||||
res.add('like')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def eval_keywords(pred, label):
|
||||
pred_keywords = get_keywords(pred)
|
||||
label_keywords = get_keywords(label)
|
||||
pred_total = len(pred_keywords)
|
||||
label_total = len(label_keywords)
|
||||
cnt = 0
|
||||
|
||||
for k in pred_keywords:
|
||||
if k in label_keywords:
|
||||
cnt += 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def count_agg(units):
|
||||
return len([unit for unit in units if has_agg(unit)])
|
||||
|
||||
|
||||
def count_component1(sql):
|
||||
count = 0
|
||||
if len(sql['where']) > 0:
|
||||
count += 1
|
||||
if len(sql['groupBy']) > 0:
|
||||
count += 1
|
||||
if len(sql['orderBy']) > 0:
|
||||
count += 1
|
||||
if sql['limit'] is not None:
|
||||
count += 1
|
||||
if len(sql['from']['table_units']) > 0: # JOIN
|
||||
count += len(sql['from']['table_units']) - 1
|
||||
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
count += len([token for token in ao if token == 'or'])
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def count_component2(sql):
|
||||
nested = get_nestedSQL(sql)
|
||||
return len(nested)
|
||||
|
||||
|
||||
def count_others(sql):
|
||||
count = 0
|
||||
# number of aggregation
|
||||
agg_count = count_agg(sql['select'][1])
|
||||
agg_count += count_agg(sql['where'][::2])
|
||||
agg_count += count_agg(sql['groupBy'])
|
||||
if len(sql['orderBy']) > 0:
|
||||
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
|
||||
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
|
||||
agg_count += count_agg(sql['having'])
|
||||
if agg_count > 1:
|
||||
count += 1
|
||||
|
||||
# number of select columns
|
||||
if len(sql['select'][1]) > 1:
|
||||
count += 1
|
||||
|
||||
# number of where conditions
|
||||
if len(sql['where']) > 1:
|
||||
count += 1
|
||||
|
||||
# number of group by clauses
|
||||
if len(sql['groupBy']) > 1:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""A simple evaluator"""
|
||||
def __init__(self):
|
||||
self.partial_scores = None
|
||||
|
||||
def eval_hardness(self, sql):
|
||||
count_comp1_ = count_component1(sql)
|
||||
count_comp2_ = count_component2(sql)
|
||||
count_others_ = count_others(sql)
|
||||
|
||||
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
|
||||
return "easy"
|
||||
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
|
||||
return "medium"
|
||||
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
|
||||
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
|
||||
return "hard"
|
||||
else:
|
||||
return "extra"
|
||||
|
||||
def eval_exact_match(self, pred, label):
|
||||
partial_scores = self.eval_partial_match(pred, label)
|
||||
self.partial_scores = partial_scores
|
||||
|
||||
for _, score in partial_scores.items():
|
||||
if score['f1'] != 1:
|
||||
return 0
|
||||
if len(label['from']['table_units']) > 0:
|
||||
label_tables = sorted(label['from']['table_units'])
|
||||
pred_tables = sorted(pred['from']['table_units'])
|
||||
return label_tables == pred_tables
|
||||
return 1
|
||||
|
||||
def eval_partial_match(self, pred, label):
|
||||
res = {}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_group(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_having(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_order(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_and_or(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_IUEN(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_keywords(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def isValidSQL(sql, db):
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(sql)
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def print_scores(scores, etype):
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all']
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
|
||||
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
|
||||
counts = [scores[level]['count'] for level in levels]
|
||||
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
print('===================== EXECUTION ACCURACY =====================')
|
||||
this_scores = [scores[level]['exec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
print('\n====================== EXACT MATCHING ACCURACY =====================')
|
||||
exact_scores = [scores[level]['exact'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
|
||||
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING F1 --------------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
|
||||
def evaluate(gold, predict, db_dir, etype, kmaps):
|
||||
with open(gold) as f:
|
||||
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
|
||||
with open(predict) as f:
|
||||
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
|
||||
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
|
||||
evaluator = Evaluator()
|
||||
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all']
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
entries = []
|
||||
scores = {}
|
||||
|
||||
for level in levels:
|
||||
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
|
||||
scores[level]['exec'] = 0
|
||||
for type_ in partial_types:
|
||||
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
|
||||
|
||||
eval_err_num = 0
|
||||
for p, g in zip(plist, glist):
|
||||
p_str = p[0]
|
||||
g_str, db = g
|
||||
db_name = db
|
||||
db = os.path.join(db_dir, db, db + ".sqlite")
|
||||
schema = Schema(get_schema(db))
|
||||
g_sql = get_sql(schema, g_str)
|
||||
hardness = evaluator.eval_hardness(g_sql)
|
||||
scores[hardness]['count'] += 1
|
||||
scores['all']['count'] += 1
|
||||
|
||||
try:
|
||||
p_sql = get_sql(schema, p_str)
|
||||
except:
|
||||
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
|
||||
p_sql = {
|
||||
"except": None,
|
||||
"from": {
|
||||
"conds": [],
|
||||
"table_units": []
|
||||
},
|
||||
"groupBy": [],
|
||||
"having": [],
|
||||
"intersect": None,
|
||||
"limit": None,
|
||||
"orderBy": [],
|
||||
"select": [
|
||||
False,
|
||||
[]
|
||||
],
|
||||
"union": None,
|
||||
"where": []
|
||||
}
|
||||
eval_err_num += 1
|
||||
print("eval_err_num:{}".format(eval_err_num))
|
||||
|
||||
# rebuild sql for value evaluation
|
||||
kmap = kmaps[db_name]
|
||||
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
|
||||
g_sql = rebuild_sql_val(g_sql)
|
||||
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
|
||||
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
|
||||
p_sql = rebuild_sql_val(p_sql)
|
||||
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
|
||||
if exec_score:
|
||||
scores[hardness]['exec'] += 1
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
|
||||
partial_scores = evaluator.partial_scores
|
||||
if exact_score == 0:
|
||||
print("{} pred: {}".format(hardness,p_str))
|
||||
print("{} gold: {}".format(hardness,g_str))
|
||||
print("")
|
||||
scores[hardness]['exact'] += exact_score
|
||||
scores['all']['exact'] += exact_score
|
||||
for type_ in partial_types:
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores[hardness]['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores[hardness]['partial'][type_]['rec_count'] += 1
|
||||
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores['all']['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores['all']['partial'][type_]['rec_count'] += 1
|
||||
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
|
||||
entries.append({
|
||||
'predictSQL': p_str,
|
||||
'goldSQL': g_str,
|
||||
'hardness': hardness,
|
||||
'exact': exact_score,
|
||||
'partial': partial_scores
|
||||
})
|
||||
|
||||
for level in levels:
|
||||
if scores[level]['count'] == 0:
|
||||
continue
|
||||
if etype in ["all", "exec"]:
|
||||
scores[level]['exec'] /= scores[level]['count']
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
scores[level]['exact'] /= scores[level]['count']
|
||||
for type_ in partial_types:
|
||||
if scores[level]['partial'][type_]['acc_count'] == 0:
|
||||
scores[level]['partial'][type_]['acc'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
|
||||
scores[level]['partial'][type_]['acc_count'] * 1.0
|
||||
if scores[level]['partial'][type_]['rec_count'] == 0:
|
||||
scores[level]['partial'][type_]['rec'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
|
||||
scores[level]['partial'][type_]['rec_count'] * 1.0
|
||||
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
|
||||
scores[level]['partial'][type_]['f1'] = 1
|
||||
else:
|
||||
scores[level]['partial'][type_]['f1'] = \
|
||||
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
|
||||
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
|
||||
|
||||
print_scores(scores, etype)
|
||||
|
||||
|
||||
def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||
"""
|
||||
return 1 if the values between prediction and gold are matching
|
||||
in the corresponding index. Currently it does not support multiple col_unit (pairs).
|
||||
"""
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(p_str)
|
||||
p_res = cursor.fetchall()
|
||||
except:
|
||||
return False
|
||||
|
||||
cursor.execute(g_str)
|
||||
q_res = cursor.fetchall()
|
||||
|
||||
def res_map(res, val_units):
|
||||
rmap = {}
|
||||
for idx, val_unit in enumerate(val_units):
|
||||
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
|
||||
rmap[key] = [r[idx] for r in res]
|
||||
return rmap
|
||||
|
||||
p_val_units = [unit[1] for unit in pred['select'][1]]
|
||||
q_val_units = [unit[1] for unit in gold['select'][1]]
|
||||
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
|
||||
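# Illustrative note (not part of the original script): res_map() keys each result
# column by the select val_unit that produced it, so two queries returning the same
# columns in a different order, e.g.
#   SELECT name, age FROM singer   vs.   SELECT age, name FROM singer
# still count as an execution match, while any difference in the returned rows
# (including row order) makes the comparison fail.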
|
||||
|
||||
# Rebuild SQL functions for value evaluation
|
||||
def rebuild_cond_unit_val(cond_unit):
|
||||
if cond_unit is None or not DISABLE_VALUE:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
if type(val1) is not dict:
|
||||
val1 = None
|
||||
else:
|
||||
val1 = rebuild_sql_val(val1)
|
||||
if type(val2) is not dict:
|
||||
val2 = None
|
||||
else:
|
||||
val2 = rebuild_sql_val(val2)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_val(condition):
|
||||
if condition is None or not DISABLE_VALUE:
|
||||
return condition
|
||||
|
||||
res = []
|
||||
for idx, it in enumerate(condition):
|
||||
if idx % 2 == 0:
|
||||
res.append(rebuild_cond_unit_val(it))
|
||||
else:
|
||||
res.append(it)
|
||||
return res
|
||||
|
||||
|
||||
def rebuild_sql_val(sql):
|
||||
if sql is None or not DISABLE_VALUE:
|
||||
return sql
|
||||
|
||||
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
|
||||
sql['having'] = rebuild_condition_val(sql['having'])
|
||||
sql['where'] = rebuild_condition_val(sql['where'])
|
||||
sql['intersect'] = rebuild_sql_val(sql['intersect'])
|
||||
sql['except'] = rebuild_sql_val(sql['except'])
|
||||
sql['union'] = rebuild_sql_val(sql['union'])
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
# Rebuild SQL functions for foreign key evaluation
|
||||
def build_valid_col_units(table_units, schema):
|
||||
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
|
||||
prefixs = [col_id[:-2] for col_id in col_ids]
|
||||
valid_col_units= []
|
||||
for value in schema.idMap.values():
|
||||
if '.' in value and value[:value.index('.')] in prefixs:
|
||||
valid_col_units.append(value)
|
||||
return valid_col_units
|
||||
|
||||
|
||||
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
|
||||
if col_unit is None:
|
||||
return col_unit
|
||||
|
||||
agg_id, col_id, distinct = col_unit
|
||||
if col_id in kmap and col_id in valid_col_units:
|
||||
col_id = kmap[col_id]
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return agg_id, col_id, distinct
|
||||
|
||||
|
||||
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
|
||||
if val_unit is None:
|
||||
return val_unit
|
||||
|
||||
unit_op, col_unit1, col_unit2 = val_unit
|
||||
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
|
||||
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
|
||||
return unit_op, col_unit1, col_unit2
|
||||
|
||||
|
||||
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
|
||||
if table_unit is None:
|
||||
return table_unit
|
||||
|
||||
table_type, col_unit_or_sql = table_unit
|
||||
if isinstance(col_unit_or_sql, tuple):
|
||||
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
|
||||
return table_type, col_unit_or_sql
|
||||
|
||||
|
||||
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
|
||||
if cond_unit is None:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_col(valid_col_units, condition, kmap):
|
||||
for idx in range(len(condition)):
|
||||
if idx % 2 == 0:
|
||||
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
|
||||
return condition
|
||||
|
||||
|
||||
def rebuild_select_col(valid_col_units, sel, kmap):
|
||||
if sel is None:
|
||||
return sel
|
||||
distinct, _list = sel
|
||||
new_list = []
|
||||
for it in _list:
|
||||
agg_id, val_unit = it
|
||||
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return distinct, new_list
|
||||
|
||||
|
||||
def rebuild_from_col(valid_col_units, from_, kmap):
|
||||
if from_ is None:
|
||||
return from_
|
||||
|
||||
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
|
||||
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
|
||||
return from_
|
||||
|
||||
|
||||
def rebuild_group_by_col(valid_col_units, group_by, kmap):
|
||||
if group_by is None:
|
||||
return group_by
|
||||
|
||||
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
|
||||
|
||||
|
||||
def rebuild_order_by_col(valid_col_units, order_by, kmap):
|
||||
if order_by is None or len(order_by) == 0:
|
||||
return order_by
|
||||
|
||||
direction, val_units = order_by
|
||||
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
|
||||
return direction, new_val_units
|
||||
|
||||
|
||||
def rebuild_sql_col(valid_col_units, sql, kmap):
|
||||
if sql is None:
|
||||
return sql
|
||||
|
||||
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
|
||||
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
|
||||
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
|
||||
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
|
||||
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
|
||||
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
|
||||
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
|
||||
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
|
||||
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def build_foreign_key_map(entry):
|
||||
cols_orig = entry["column_names_original"]
|
||||
tables_orig = entry["table_names_original"]
|
||||
|
||||
# rebuild cols corresponding to idmap in Schema
|
||||
cols = []
|
||||
for col_orig in cols_orig:
|
||||
if col_orig[0] >= 0:
|
||||
t = tables_orig[col_orig[0]]
|
||||
c = col_orig[1]
|
||||
cols.append("__" + t.lower() + "." + c.lower() + "__")
|
||||
else:
|
||||
cols.append("__all__")
|
||||
|
||||
def keyset_in_list(k1, k2, k_list):
|
||||
for k_set in k_list:
|
||||
if k1 in k_set or k2 in k_set:
|
||||
return k_set
|
||||
new_k_set = set()
|
||||
k_list.append(new_k_set)
|
||||
return new_k_set
|
||||
|
||||
foreign_key_list = []
|
||||
foreign_keys = entry["foreign_keys"]
|
||||
for fkey in foreign_keys:
|
||||
key1, key2 = fkey
|
||||
key_set = keyset_in_list(key1, key2, foreign_key_list)
|
||||
key_set.add(key1)
|
||||
key_set.add(key2)
|
||||
|
||||
foreign_key_map = {}
|
||||
for key_set in foreign_key_list:
|
||||
sorted_list = sorted(list(key_set))
|
||||
midx = sorted_list[0]
|
||||
for idx in sorted_list:
|
||||
foreign_key_map[cols[idx]] = cols[midx]
|
||||
|
||||
return foreign_key_map
|
||||
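# Worked example (illustrative, not part of the original script): with
#   foreign_keys = [[3, 8], [8, 12]]
# the three column ids end up in a single key set and every member is mapped to the
# column of the smallest id:
#   {cols[3]: cols[3], cols[8]: cols[3], cols[12]: cols[3]}
# so columns linked through a chain of foreign keys are treated as the same column
# during exact-match evaluation.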
|
||||
|
||||
def build_foreign_key_map_from_json(table):
|
||||
with open(table) as f:
|
||||
data = json.load(f)
|
||||
tables = {}
|
||||
for entry in data:
|
||||
tables[entry['db_id']] = build_foreign_key_map(entry)
|
||||
return tables
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--gold', dest='gold', type=str)
|
||||
parser.add_argument('--pred', dest='pred', type=str)
|
||||
parser.add_argument('--db', dest='db', type=str)
|
||||
parser.add_argument('--table', dest='table', type=str)
|
||||
parser.add_argument('--etype', dest='etype', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
gold = args.gold
|
||||
pred = args.pred
|
||||
db_dir = args.db
|
||||
table = args.table
|
||||
etype = args.etype
|
||||
|
||||
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
|
||||
|
||||
kmaps = build_foreign_key_map_from_json(table)
|
||||
|
||||
evaluate(gold, pred, db_dir, etype, kmaps)
|
|
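# Usage sketch (file and path names are placeholders; not part of the original script):
#   python evaluation.py --gold gold.txt --pred pred.txt \
#       --db data/database --table tables.json --etype match
# --etype may be "match", "exec", or "all", matching the assert above.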
@@ -0,0 +1,938 @@
|
|||
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
|
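# Worked example (illustrative, not part of the original script): for a toy schema
# with a singer(age) table, the query
#   SELECT count(*) FROM singer WHERE age > 20
# is parsed by get_sql() into a dict of roughly this shape (ids follow Schema.idMap):
#   {
#     'select': (False, [(3, (0, (0, '__all__', False), None))]),            # count(*)
#     'from':   {'table_units': [('table_unit', '__singer__')], 'conds': []},
#     'where':  [(False, 3, (0, (0, '__singer.age__', False), None), 20.0, None)],
#     'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#     'intersect': None, 'union': None, 'except': None
#   }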
||||
|
||||
import os, sys
|
||||
import json
|
||||
import sqlite3
|
||||
import traceback
|
||||
import argparse
|
||||
|
||||
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
|
||||
|
||||
# Flag to disable value evaluation
|
||||
DISABLE_VALUE = True
|
||||
# Flag to disable distinct in select evaluation
|
||||
DISABLE_DISTINCT = True
|
||||
|
||||
|
||||
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
|
||||
JOIN_KEYWORDS = ('join', 'on', 'as')
|
||||
|
||||
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
UNIT_OPS = ('none', '-', '+', "*", '/')
|
||||
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
|
||||
TABLE_TYPE = {
|
||||
'sql': "sql",
|
||||
'table_unit': "table_unit",
|
||||
}
|
||||
|
||||
COND_OPS = ('and', 'or')
|
||||
SQL_OPS = ('intersect', 'union', 'except')
|
||||
ORDER_OPS = ('desc', 'asc')
|
||||
|
||||
|
||||
HARDNESS = {
|
||||
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
|
||||
"component2": ('except', 'union', 'intersect')
|
||||
}
|
||||
|
||||
|
||||
def condition_has_or(conds):
|
||||
return 'or' in conds[1::2]
|
||||
|
||||
|
||||
def condition_has_like(conds):
|
||||
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
|
||||
|
||||
|
||||
def condition_has_sql(conds):
|
||||
for cond_unit in conds[::2]:
|
||||
val1, val2 = cond_unit[3], cond_unit[4]
|
||||
if val1 is not None and type(val1) is dict:
|
||||
return True
|
||||
if val2 is not None and type(val2) is dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def val_has_op(val_unit):
|
||||
return val_unit[0] != UNIT_OPS.index('none')
|
||||
|
||||
|
||||
def has_agg(unit):
|
||||
return unit[0] != AGG_OPS.index('none')
|
||||
|
||||
|
||||
def accuracy(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def recall(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def F1(acc, rec):
|
||||
if (acc + rec) == 0:
|
||||
return 0
|
||||
return (2. * acc * rec) / (acc + rec)
|
||||
|
||||
|
||||
def get_scores(count, pred_total, label_total):
|
||||
if pred_total != label_total:
|
||||
return 0,0,0
|
||||
elif count == pred_total:
|
||||
return 1,1,1
|
||||
return 0,0,0
|
||||
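# Illustrative behaviour (not part of the original script): component scoring is
# all-or-nothing, e.g.
#   get_scores(2, 2, 2) -> (1, 1, 1)   # every predicted unit matched
#   get_scores(1, 2, 2) -> (0, 0, 0)   # one mismatch zeroes the whole component
#   get_scores(2, 2, 3) -> (0, 0, 0)   # differing totals also score zero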
|
||||
|
||||
def eval_sel(pred, label):
|
||||
pred_sel = pred['select'][1]
|
||||
label_sel = label['select'][1]
|
||||
label_wo_agg = [unit[1] for unit in label_sel]
|
||||
pred_total = len(pred_sel)
|
||||
label_total = len(label_sel)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_sel:
|
||||
if unit in label_sel:
|
||||
cnt += 1
|
||||
label_sel.remove(unit)
|
||||
if unit[1] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[1])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_where(pred, label):
|
||||
pred_conds = [unit for unit in pred['where'][::2]]
|
||||
label_conds = [unit for unit in label['where'][::2]]
|
||||
label_wo_agg = [unit[2] for unit in label_conds]
|
||||
pred_total = len(pred_conds)
|
||||
label_total = len(label_conds)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_conds:
|
||||
if unit in label_conds:
|
||||
cnt += 1
|
||||
label_conds.remove(unit)
|
||||
if unit[2] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[2])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_group(pred, label):
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
pred_total = len(pred_cols)
|
||||
label_total = len(label_cols)
|
||||
cnt = 0
|
||||
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
|
||||
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
|
||||
for col in pred_cols:
|
||||
if col in label_cols:
|
||||
cnt += 1
|
||||
label_cols.remove(col)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_having(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['groupBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['groupBy']) > 0:
|
||||
label_total = 1
|
||||
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
if pred_total == label_total == 1 \
|
||||
and pred_cols == label_cols \
|
||||
and pred['having'] == label['having']:
|
||||
cnt = 1
|
||||
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_order(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['orderBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['orderBy']) > 0:
|
||||
label_total = 1
|
||||
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
|
||||
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
|
||||
cnt = 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_and_or(pred, label):
|
||||
pred_ao = pred['where'][1::2]
|
||||
label_ao = label['where'][1::2]
|
||||
pred_ao = set(pred_ao)
|
||||
label_ao = set(label_ao)
|
||||
|
||||
if pred_ao == label_ao:
|
||||
return 1,1,1
|
||||
return len(pred_ao),len(label_ao),0
|
||||
|
||||
|
||||
def get_nestedSQL(sql):
|
||||
nested = []
|
||||
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
|
||||
if type(cond_unit[3]) is dict:
|
||||
nested.append(cond_unit[3])
|
||||
if type(cond_unit[4]) is dict:
|
||||
nested.append(cond_unit[4])
|
||||
if sql['intersect'] is not None:
|
||||
nested.append(sql['intersect'])
|
||||
if sql['except'] is not None:
|
||||
nested.append(sql['except'])
|
||||
if sql['union'] is not None:
|
||||
nested.append(sql['union'])
|
||||
return nested
|
||||
|
||||
|
||||
def eval_nested(pred, label):
|
||||
label_total = 0
|
||||
pred_total = 0
|
||||
cnt = 0
|
||||
if pred is not None:
|
||||
pred_total += 1
|
||||
if label is not None:
|
||||
label_total += 1
|
||||
if pred is not None and label is not None:
|
||||
cnt += Evaluator().eval_exact_match(pred, label)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_IUEN(pred, label):
|
||||
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
|
||||
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
|
||||
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
|
||||
label_total = lt1 + lt2 + lt3
|
||||
pred_total = pt1 + pt2 + pt3
|
||||
cnt = cnt1 + cnt2 + cnt3
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def get_keywords(sql):
|
||||
res = set()
|
||||
if len(sql['where']) > 0:
|
||||
res.add('where')
|
||||
if len(sql['groupBy']) > 0:
|
||||
res.add('group')
|
||||
if len(sql['having']) > 0:
|
||||
res.add('having')
|
||||
if len(sql['orderBy']) > 0:
|
||||
res.add(sql['orderBy'][0])
|
||||
res.add('order')
|
||||
if sql['limit'] is not None:
|
||||
res.add('limit')
|
||||
if sql['except'] is not None:
|
||||
res.add('except')
|
||||
if sql['union'] is not None:
|
||||
res.add('union')
|
||||
if sql['intersect'] is not None:
|
||||
res.add('intersect')
|
||||
|
||||
# or keyword
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
if len([token for token in ao if token == 'or']) > 0:
|
||||
res.add('or')
|
||||
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
# not keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
|
||||
res.add('not')
|
||||
|
||||
# in keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
|
||||
res.add('in')
|
||||
|
||||
# like keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
|
||||
res.add('like')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def eval_keywords(pred, label):
|
||||
pred_keywords = get_keywords(pred)
|
||||
label_keywords = get_keywords(label)
|
||||
pred_total = len(pred_keywords)
|
||||
label_total = len(label_keywords)
|
||||
cnt = 0
|
||||
|
||||
for k in pred_keywords:
|
||||
if k in label_keywords:
|
||||
cnt += 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def count_agg(units):
|
||||
return len([unit for unit in units if has_agg(unit)])
|
||||
|
||||
|
||||
def count_component1(sql):
|
||||
count = 0
|
||||
if len(sql['where']) > 0:
|
||||
count += 1
|
||||
if len(sql['groupBy']) > 0:
|
||||
count += 1
|
||||
if len(sql['orderBy']) > 0:
|
||||
count += 1
|
||||
if sql['limit'] is not None:
|
||||
count += 1
|
||||
if len(sql['from']['table_units']) > 0: # JOIN
|
||||
count += len(sql['from']['table_units']) - 1
|
||||
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
count += len([token for token in ao if token == 'or'])
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def count_component2(sql):
|
||||
nested = get_nestedSQL(sql)
|
||||
return len(nested)
|
||||
|
||||
|
||||
def count_others(sql):
|
||||
count = 0
|
||||
# number of aggregation
|
||||
agg_count = count_agg(sql['select'][1])
|
||||
agg_count += count_agg(sql['where'][::2])
|
||||
agg_count += count_agg(sql['groupBy'])
|
||||
if len(sql['orderBy']) > 0:
|
||||
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
|
||||
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
|
||||
agg_count += count_agg(sql['having'])
|
||||
if agg_count > 1:
|
||||
count += 1
|
||||
|
||||
# number of select columns
|
||||
if len(sql['select'][1]) > 1:
|
||||
count += 1
|
||||
|
||||
# number of where conditions
|
||||
if len(sql['where']) > 1:
|
||||
count += 1
|
||||
|
||||
# number of group by clauses
|
||||
if len(sql['groupBy']) > 1:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""A simple evaluator"""
|
||||
def __init__(self):
|
||||
self.partial_scores = None
|
||||
|
||||
def eval_hardness(self, sql):
|
||||
count_comp1_ = count_component1(sql)
|
||||
count_comp2_ = count_component2(sql)
|
||||
count_others_ = count_others(sql)
|
||||
|
||||
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
|
||||
return "easy"
|
||||
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
|
||||
return "medium"
|
||||
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
|
||||
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
|
||||
return "hard"
|
||||
else:
|
||||
return "extra"
|
||||
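# Illustrative hardness assignments under the rules above (toy schema; not part of
# the original script):
#   SELECT count(*) FROM singer                                   -> easy
#   SELECT name, age FROM singer WHERE age > 20                   -> medium
#   SELECT name FROM singer WHERE age > 20 ORDER BY age LIMIT 1   -> hard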
|
||||
def eval_exact_match(self, pred, label):
|
||||
partial_scores = self.eval_partial_match(pred, label)
|
||||
self.partial_scores = partial_scores
|
||||
|
||||
for key, score in partial_scores.items():
|
||||
if score['f1'] != 1:
|
||||
return 0
|
||||
|
||||
if len(label['from']['table_units']) > 0:
|
||||
label_tables = sorted(label['from']['table_units'])
|
||||
pred_tables = sorted(pred['from']['table_units'])
|
||||
return label_tables == pred_tables
|
||||
return 1
|
||||
|
||||
def eval_partial_match(self, pred, label):
|
||||
res = {}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_group(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_having(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_order(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_and_or(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_IUEN(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_keywords(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def isValidSQL(sql, db):
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(sql)
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def print_scores(scores, etype):
|
||||
turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all', "joint_all"]
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
|
||||
print("{:20} {:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
|
||||
counts = [scores[level]['count'] for level in levels]
|
||||
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
print('===================== EXECUTION ACCURACY =====================')
|
||||
this_scores = [scores[level]['exec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
print('\n====================== EXACT MATCHING ACCURACY =====================')
|
||||
exact_scores = [scores[level]['exact'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
|
||||
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING F1 --------------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
|
||||
print("\n\n{:20} {:20} {:20} {:20} {:20} {:20}".format("", *turns))
|
||||
counts = [scores[turn]['count'] for turn in turns]
|
||||
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
print('===================== TURN EXECUTION ACCURACY =====================')
|
||||
this_scores = [scores[turn]['exec'] for turn in turns]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
print('\n====================== TURN EXACT MATCHING ACCURACY =====================')
|
||||
exact_scores = [scores[turn]['exact'] for turn in turns]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
|
||||
|
||||
def evaluate(gold, predict, db_dir, etype, kmaps):
|
||||
with open(gold) as f:
|
||||
glist = []
|
||||
gseq_one = []
|
||||
for l in f.readlines():
|
||||
if len(l.strip()) == 0:
|
||||
glist.append(gseq_one)
|
||||
gseq_one = []
|
||||
else:
|
||||
lstrip = l.strip().split('\t')
|
||||
gseq_one.append(lstrip)
|
||||
#glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
|
||||
with open(predict) as f:
|
||||
plist = []
|
||||
pseq_one = []
|
||||
for l in f.readlines():
|
||||
if len(l.strip()) == 0:
|
||||
plist.append(pseq_one)
|
||||
pseq_one = []
|
||||
else:
|
||||
pseq_one.append(l.strip().split('\t'))
|
||||
#plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
# plist = [[("select product_type_code from products group by product_type_code order by count ( * ) desc limit value", "orchestra")]]
|
||||
# glist = [[("SELECT product_type_code FROM Products GROUP BY product_type_code ORDER BY count(*) DESC LIMIT 1", "customers_and_orders")]]
|
||||
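# Input format note (illustrative, inferred from the parsing above; not part of the
# original script): here the gold and prediction files are grouped into interactions,
# one query per line ("SQL<TAB>db_id" for gold), with a blank line separating
# consecutive interactions; the "turn i" and "joint_all" scores are computed over
# these groups.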
evaluator = Evaluator()
|
||||
|
||||
turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
entries = []
|
||||
scores = {}
|
||||
|
||||
for turn in turns:
|
||||
scores[turn] = {'count': 0, 'exact': 0.}
|
||||
scores[turn]['exec'] = 0
|
||||
|
||||
for level in levels:
|
||||
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
|
||||
scores[level]['exec'] = 0
|
||||
for type_ in partial_types:
|
||||
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
|
||||
|
||||
eval_err_num = 0
|
||||
for p, g in zip(plist, glist):
|
||||
scores['joint_all']['count'] += 1
|
||||
turn_scores = {"exec": [], "exact": []}
|
||||
for idx, pg in enumerate(zip(p, g)):
|
||||
p, g = pg
|
||||
p_str = p[0]
|
||||
p_str = p_str.replace("value", "1")
|
||||
g_str, db = g
|
||||
db_name = db
|
||||
db = os.path.join(db_dir, db, db + ".sqlite")
|
||||
schema = Schema(get_schema(db))
|
||||
g_sql = get_sql(schema, g_str)
|
||||
hardness = evaluator.eval_hardness(g_sql)
|
||||
if idx > 3:
|
||||
idx = ">4"
|
||||
else:
|
||||
idx += 1
|
||||
turn_id = "turn " + str(idx)
|
||||
scores[turn_id]['count'] += 1
|
||||
scores[hardness]['count'] += 1
|
||||
scores['all']['count'] += 1
|
||||
|
||||
try:
|
||||
p_sql = get_sql(schema, p_str)
|
||||
except:
|
||||
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
|
||||
p_sql = {
|
||||
"except": None,
|
||||
"from": {
|
||||
"conds": [],
|
||||
"table_units": []
|
||||
},
|
||||
"groupBy": [],
|
||||
"having": [],
|
||||
"intersect": None,
|
||||
"limit": None,
|
||||
"orderBy": [],
|
||||
"select": [
|
||||
False,
|
||||
[]
|
||||
],
|
||||
"union": None,
|
||||
"where": []
|
||||
}
|
||||
eval_err_num += 1
|
||||
print("eval_err_num:{}".format(eval_err_num))
|
||||
|
||||
# rebuild sql for value evaluation
|
||||
kmap = kmaps[db_name]
|
||||
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
|
||||
g_sql = rebuild_sql_val(g_sql)
|
||||
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
|
||||
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
|
||||
p_sql = rebuild_sql_val(p_sql)
|
||||
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
|
||||
if exec_score:
|
||||
scores[hardness]['exec'] += 1
|
||||
scores[turn_id]['exec'] += 1
|
||||
turn_scores['exec'].append(1)
|
||||
else:
|
||||
turn_scores['exec'].append(0)
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
|
||||
partial_scores = evaluator.partial_scores
|
||||
if exact_score == 0:
|
||||
turn_scores['exact'].append(0)
|
||||
print("{} pred: {}".format(hardness,p_str))
|
||||
print("{} gold: {}".format(hardness,g_str))
|
||||
print("")
|
||||
else:
|
||||
turn_scores['exact'].append(1)
|
||||
scores[turn_id]['exact'] += exact_score
|
||||
scores[hardness]['exact'] += exact_score
|
||||
scores['all']['exact'] += exact_score
|
||||
for type_ in partial_types:
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores[hardness]['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores[hardness]['partial'][type_]['rec_count'] += 1
|
||||
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores['all']['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores['all']['partial'][type_]['rec_count'] += 1
|
||||
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
|
||||
entries.append({
|
||||
'predictSQL': p_str,
|
||||
'goldSQL': g_str,
|
||||
'hardness': hardness,
|
||||
'exact': exact_score,
|
||||
'partial': partial_scores
|
||||
})
|
||||
|
||||
if all(v == 1 for v in turn_scores["exec"]):
|
||||
scores['joint_all']['exec'] += 1
|
||||
|
||||
if all(v == 1 for v in turn_scores["exact"]):
|
||||
scores['joint_all']['exact'] += 1
|
||||
|
||||
for turn in turns:
|
||||
if scores[turn]['count'] == 0:
|
||||
continue
|
||||
if etype in ["all", "exec"]:
|
||||
scores[turn]['exec'] /= scores[turn]['count']
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
scores[turn]['exact'] /= scores[turn]['count']
|
||||
|
||||
for level in levels:
|
||||
if scores[level]['count'] == 0:
|
||||
continue
|
||||
if etype in ["all", "exec"]:
|
||||
scores[level]['exec'] /= scores[level]['count']
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
scores[level]['exact'] /= scores[level]['count']
|
||||
for type_ in partial_types:
|
||||
if scores[level]['partial'][type_]['acc_count'] == 0:
|
||||
scores[level]['partial'][type_]['acc'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
|
||||
scores[level]['partial'][type_]['acc_count'] * 1.0
|
||||
if scores[level]['partial'][type_]['rec_count'] == 0:
|
||||
scores[level]['partial'][type_]['rec'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
|
||||
scores[level]['partial'][type_]['rec_count'] * 1.0
|
||||
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
|
||||
scores[level]['partial'][type_]['f1'] = 1
|
||||
else:
|
||||
scores[level]['partial'][type_]['f1'] = \
|
||||
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
|
||||
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
|
||||
|
||||
print_scores(scores, etype)
|
||||
|
||||
|
||||
def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||
"""
|
||||
return 1 if the values between prediction and gold are matching
|
||||
in the corresponding index. Currently it does not support multiple col_unit (pairs).
|
||||
"""
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(p_str)
|
||||
p_res = cursor.fetchall()
|
||||
except:
|
||||
return False
|
||||
|
||||
cursor.execute(g_str)
|
||||
q_res = cursor.fetchall()
|
||||
|
||||
def res_map(res, val_units):
|
||||
rmap = {}
|
||||
for idx, val_unit in enumerate(val_units):
|
||||
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
|
||||
rmap[key] = [r[idx] for r in res]
|
||||
return rmap
|
||||
|
||||
p_val_units = [unit[1] for unit in pred['select'][1]]
|
||||
q_val_units = [unit[1] for unit in gold['select'][1]]
|
||||
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
|
||||
|
||||
|
||||
# Rebuild SQL functions for value evaluation
|
||||
def rebuild_cond_unit_val(cond_unit):
|
||||
if cond_unit is None or not DISABLE_VALUE:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
if type(val1) is not dict:
|
||||
val1 = None
|
||||
else:
|
||||
val1 = rebuild_sql_val(val1)
|
||||
if type(val2) is not dict:
|
||||
val2 = None
|
||||
else:
|
||||
val2 = rebuild_sql_val(val2)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_val(condition):
|
||||
if condition is None or not DISABLE_VALUE:
|
||||
return condition
|
||||
|
||||
res = []
|
||||
for idx, it in enumerate(condition):
|
||||
if idx % 2 == 0:
|
||||
res.append(rebuild_cond_unit_val(it))
|
||||
else:
|
||||
res.append(it)
|
||||
return res
|
||||
|
||||
|
||||
def rebuild_sql_val(sql):
|
||||
if sql is None or not DISABLE_VALUE:
|
||||
return sql
|
||||
|
||||
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
|
||||
sql['having'] = rebuild_condition_val(sql['having'])
|
||||
sql['where'] = rebuild_condition_val(sql['where'])
|
||||
sql['intersect'] = rebuild_sql_val(sql['intersect'])
|
||||
sql['except'] = rebuild_sql_val(sql['except'])
|
||||
sql['union'] = rebuild_sql_val(sql['union'])
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
# Rebuild SQL functions for foreign key evaluation
|
||||
def build_valid_col_units(table_units, schema):
|
||||
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
|
||||
prefixs = [col_id[:-2] for col_id in col_ids]
|
||||
valid_col_units= []
|
||||
for value in schema.idMap.values():
|
||||
if '.' in value and value[:value.index('.')] in prefixs:
|
||||
valid_col_units.append(value)
|
||||
return valid_col_units
|
||||
|
||||
|
||||
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
|
||||
if col_unit is None:
|
||||
return col_unit
|
||||
|
||||
agg_id, col_id, distinct = col_unit
|
||||
if col_id in kmap and col_id in valid_col_units:
|
||||
col_id = kmap[col_id]
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return agg_id, col_id, distinct
|
||||
|
||||
|
||||
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
|
||||
if val_unit is None:
|
||||
return val_unit
|
||||
|
||||
unit_op, col_unit1, col_unit2 = val_unit
|
||||
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
|
||||
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
|
||||
return unit_op, col_unit1, col_unit2
|
||||
|
||||
|
||||
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
|
||||
if table_unit is None:
|
||||
return table_unit
|
||||
|
||||
table_type, col_unit_or_sql = table_unit
|
||||
if isinstance(col_unit_or_sql, tuple):
|
||||
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
|
||||
return table_type, col_unit_or_sql
|
||||
|
||||
|
||||
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
|
||||
if cond_unit is None:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_col(valid_col_units, condition, kmap):
|
||||
for idx in range(len(condition)):
|
||||
if idx % 2 == 0:
|
||||
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
|
||||
return condition
|
||||
|
||||
|
||||
def rebuild_select_col(valid_col_units, sel, kmap):
|
||||
if sel is None:
|
||||
return sel
|
||||
distinct, _list = sel
|
||||
new_list = []
|
||||
for it in _list:
|
||||
agg_id, val_unit = it
|
||||
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return distinct, new_list
|
||||
|
||||
|
||||
def rebuild_from_col(valid_col_units, from_, kmap):
|
||||
if from_ is None:
|
||||
return from_
|
||||
|
||||
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
|
||||
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
|
||||
return from_
|
||||
|
||||
|
||||
def rebuild_group_by_col(valid_col_units, group_by, kmap):
|
||||
if group_by is None:
|
||||
return group_by
|
||||
|
||||
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
|
||||
|
||||
|
||||
def rebuild_order_by_col(valid_col_units, order_by, kmap):
|
||||
if order_by is None or len(order_by) == 0:
|
||||
return order_by
|
||||
|
||||
direction, val_units = order_by
|
||||
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
|
||||
return direction, new_val_units
|
||||
|
||||
|
||||
def rebuild_sql_col(valid_col_units, sql, kmap):
|
||||
if sql is None:
|
||||
return sql
|
||||
|
||||
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
|
||||
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
|
||||
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
|
||||
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
|
||||
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
|
||||
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
|
||||
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
|
||||
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
|
||||
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def build_foreign_key_map(entry):
|
||||
cols_orig = entry["column_names_original"]
|
||||
tables_orig = entry["table_names_original"]
|
||||
|
||||
# rebuild cols corresponding to idmap in Schema
|
||||
cols = []
|
||||
for col_orig in cols_orig:
|
||||
if col_orig[0] >= 0:
|
||||
t = tables_orig[col_orig[0]]
|
||||
c = col_orig[1]
|
||||
cols.append("__" + t.lower() + "." + c.lower() + "__")
|
||||
else:
|
||||
cols.append("__all__")
|
||||
|
||||
def keyset_in_list(k1, k2, k_list):
|
||||
for k_set in k_list:
|
||||
if k1 in k_set or k2 in k_set:
|
||||
return k_set
|
||||
new_k_set = set()
|
||||
k_list.append(new_k_set)
|
||||
return new_k_set
|
||||
|
||||
foreign_key_list = []
|
||||
foreign_keys = entry["foreign_keys"]
|
||||
for fkey in foreign_keys:
|
||||
key1, key2 = fkey
|
||||
key_set = keyset_in_list(key1, key2, foreign_key_list)
|
||||
key_set.add(key1)
|
||||
key_set.add(key2)
|
||||
|
||||
foreign_key_map = {}
|
||||
for key_set in foreign_key_list:
|
||||
sorted_list = sorted(list(key_set))
|
||||
midx = sorted_list[0]
|
||||
for idx in sorted_list:
|
||||
foreign_key_map[cols[idx]] = cols[midx]
|
||||
|
||||
return foreign_key_map
|
||||
|
||||
|
||||
def build_foreign_key_map_from_json(table):
|
||||
with open(table) as f:
|
||||
data = json.load(f)
|
||||
tables = {}
|
||||
for entry in data:
|
||||
tables[entry['db_id']] = build_foreign_key_map(entry)
|
||||
return tables
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--gold', dest='gold', type=str)
|
||||
parser.add_argument('--pred', dest='pred', type=str)
|
||||
parser.add_argument('--db', dest='db', type=str)
|
||||
parser.add_argument('--table', dest='table', type=str)
|
||||
parser.add_argument('--etype', dest='etype', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
gold = args.gold
|
||||
pred = args.pred
|
||||
db_dir = args.db
|
||||
table = args.table
|
||||
etype = args.etype
|
||||
|
||||
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
|
||||
|
||||
kmaps = build_foreign_key_map_from_json(table)
|
||||
|
||||
evaluate(gold, pred, db_dir, etype, kmaps)
|
3 file diffs suppressed because they are too large
|
@@ -0,0 +1,53 @@
|
|||
import json
|
||||
import sys
|
||||
|
||||
predictions = [json.loads(line) for line in open(sys.argv[1]).readlines() if line]
|
||||
|
||||
string_count = 0.
|
||||
sem_count = 0.
|
||||
syn_count = 0.
|
||||
table_count = 0.
|
||||
strict_table_count = 0.
|
||||
|
||||
precision_denom = 0.
|
||||
precision = 0.
|
||||
recall_denom = 0.
|
||||
recall = 0.
|
||||
f1_score = 0.
|
||||
f1_denom = 0.
|
||||
|
||||
time = 0.
|
||||
|
||||
for prediction in predictions:
|
||||
if prediction["correct_string"]:
|
||||
string_count += 1.
|
||||
if prediction["semantic"]:
|
||||
sem_count += 1.
|
||||
if prediction["syntactic"]:
|
||||
syn_count += 1.
|
||||
if prediction["correct_table"]:
|
||||
table_count += 1.
|
||||
if prediction["strict_correct_table"]:
|
||||
strict_table_count += 1.
|
||||
if prediction["gold_tables"] !="[[]]":
|
||||
precision += prediction["table_prec"]
|
||||
precision_denom += 1
|
||||
if prediction["pred_table"] != "[]":
|
||||
recall += prediction["table_rec"]
|
||||
recall_denom += 1
|
||||
|
||||
if prediction["gold_tables"] != "[[]]":
|
||||
f1_score += prediction["table_f1"]
|
||||
f1_denom += 1
|
||||
|
||||
num_p = len(predictions)
|
||||
print("string precision: " + str(string_count / num_p))
|
||||
print("% semantic: " + str(sem_count / num_p))
|
||||
print("% syntactic: " + str(syn_count / num_p))
|
||||
print("table prec: " + str(table_count / num_p))
|
||||
print("strict table prec: " + str(strict_table_count / num_p))
|
||||
print("table row prec: " + str(precision / precision_denom))
|
||||
print("table row recall: " + str(recall / recall_denom))
|
||||
print("table row f1: " + str(f1_score / f1_denom))
|
||||
print("inference time: " + str(time / num_p))
|
||||
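# Expected input (illustrative; field values are made up, not part of the original
# script): one JSON object per line, e.g.
#   {"correct_string": true, "semantic": true, "syntactic": true,
#    "correct_table": false, "strict_correct_table": false,
#    "gold_tables": "[[...]]", "pred_table": "[...]",
#    "table_prec": 1.0, "table_rec": 0.5, "table_f1": 0.67}
# Note that `time` is never accumulated above, so the reported inference time is
# always 0 in this script as committed.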
|
|
@@ -0,0 +1,569 @@
|
|||
################################
# Assumptions:
#   1. sql is correct
#   2. only table name has alias
#   3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
|
||||
from nltk import word_tokenize
|
||||
|
||||
|
||||
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
|
||||
JOIN_KEYWORDS = ('join', 'on', 'as')
|
||||
|
||||
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
UNIT_OPS = ('none', '-', '+', "*", '/')
|
||||
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
|
||||
TABLE_TYPE = {
|
||||
'sql': "sql",
|
||||
'table_unit': "table_unit",
|
||||
}
|
||||
|
||||
COND_OPS = ('and', 'or')
|
||||
SQL_OPS = ('intersect', 'union', 'except')
|
||||
ORDER_OPS = ('desc', 'asc')
|
||||
|
||||
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
Simple schema which maps table&column to a unique identifier
|
||||
"""
|
||||
def __init__(self, schema):
|
||||
self._schema = schema
|
||||
self._idMap = self._map(self._schema)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def idMap(self):
|
||||
return self._idMap
|
||||
|
||||
def _map(self, schema):
|
||||
idMap = {'*': "__all__"}
|
||||
id = 1
|
||||
for key, vals in schema.items():
|
||||
for val in vals:
|
||||
idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
|
||||
id += 1
|
||||
|
||||
for key in schema:
|
||||
idMap[key.lower()] = "__" + key.lower() + "__"
|
||||
id += 1
|
||||
|
||||
return idMap
|
||||
|
||||
|
||||
def get_schema(db):
|
||||
"""
|
||||
Get database's schema, which is a dict with table name as key
|
||||
and list of column names as value
|
||||
:param db: database path
|
||||
:return: schema dict
|
||||
"""
|
||||
|
||||
schema = {}
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# fetch table names
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
tables = [str(table[0].lower()) for table in cursor.fetchall()]
|
||||
|
||||
# fetch table info
|
||||
for table in tables:
|
||||
cursor.execute("PRAGMA table_info({})".format(table))
|
||||
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
|
||||
|
||||
return schema
|
||||
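# Usage sketch (illustrative; the database path and its contents are assumptions,
# not part of the original script):
#   schema_dict = get_schema("data/database/concert_singer/concert_singer.sqlite")
#   # e.g. {'singer': ['singer_id', 'name', 'age'], ...}
#   schema = Schema(schema_dict)
#   schema.idMap['singer.name']   # -> '__singer.name__'
#   schema.idMap['*']             # -> '__all__'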
|
||||
|
||||
def get_schema_from_json(fpath):
|
||||
with open(fpath) as f:
|
||||
data = json.load(f)
|
||||
|
||||
schema = {}
|
||||
for entry in data:
|
||||
table = str(entry['table'].lower())
|
||||
cols = [str(col['column_name'].lower()) for col in entry['col_data']]
|
||||
schema[table] = cols
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def tokenize(string):
|
||||
string = str(string)
|
||||
string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem??
|
||||
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
|
||||
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
|
||||
|
||||
# keep string value as token
|
||||
vals = {}
|
||||
for i in range(len(quote_idxs)-1, -1, -2):
|
||||
qidx1 = quote_idxs[i-1]
|
||||
qidx2 = quote_idxs[i]
|
||||
val = string[qidx1: qidx2+1]
|
||||
key = "__val_{}_{}__".format(qidx1, qidx2)
|
||||
string = string[:qidx1] + key + string[qidx2+1:]
|
||||
vals[key] = val
|
||||
|
||||
toks = [word.lower() for word in word_tokenize(string)]
|
||||
# replace with string value token
|
||||
for i in range(len(toks)):
|
||||
if toks[i] in vals:
|
||||
toks[i] = vals[toks[i]]
|
||||
|
||||
# find if there exists !=, >=, <=
|
||||
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
|
||||
eq_idxs.reverse()
|
||||
prefix = ('!', '>', '<')
|
||||
for eq_idx in eq_idxs:
|
||||
pre_tok = toks[eq_idx-1]
|
||||
if pre_tok in prefix:
|
||||
toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
|
||||
|
||||
return toks
|
||||
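# Illustrative example (not part of the original script):
#   tokenize("SELECT name FROM singer WHERE age != 20")
#   # -> ['select', 'name', 'from', 'singer', 'where', 'age', '!=', '20']
# Quoted string values are kept as single tokens, and a '!', '=' pair produced by
# word_tokenize (if it splits the operator) is merged back into '!=' above.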
|
||||
|
||||
def scan_alias(toks):
|
||||
"""Scan the index of 'as' and build the map for all alias"""
|
||||
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
|
||||
alias = {}
|
||||
for idx in as_idxs:
|
||||
alias[toks[idx+1]] = toks[idx-1]
|
||||
return alias
|
||||
|
||||
|
||||
def get_tables_with_alias(schema, toks):
|
||||
tables = scan_alias(toks)
|
||||
for key in schema:
|
||||
assert key not in tables, "Alias {} has the same name in table".format(key)
|
||||
tables[key] = key
|
||||
return tables
|
||||
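# Illustrative example (not part of the original script): for the tokens of
#   "... FROM singer AS t1 JOIN concert AS t2 ..."
# scan_alias() yields {'t1': 'singer', 't2': 'concert'}, and get_tables_with_alias()
# additionally maps every real table name to itself, so later column lookups can use
# either the alias or the table name.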
|
||||
|
||||
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, column id
|
||||
"""
|
||||
tok = toks[start_idx]
|
||||
if tok == "*":
|
||||
return start_idx + 1, schema.idMap[tok]
|
||||
|
||||
if '.' in tok: # if token is a composite
|
||||
alias, col = tok.split('.')
|
||||
key = tables_with_alias[alias] + "." + col
|
||||
return start_idx+1, schema.idMap[key]
|
||||
|
||||
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
|
||||
|
||||
for alias in default_tables:
|
||||
table = tables_with_alias[alias]
|
||||
if tok in schema.schema[table]:
|
||||
key = table + "." + tok
|
||||
return start_idx+1, schema.idMap[key]
|
||||
|
||||
assert False, "Error col: {}".format(tok)
|
||||
|
||||
|
||||
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, (agg_op id, col_id)
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
isDistinct = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] in AGG_OPS:
|
||||
agg_id = AGG_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
assert idx < len_ and toks[idx] == '('
|
||||
idx += 1
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
assert idx < len_ and toks[idx] == ')'
|
||||
idx += 1
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
agg_id = AGG_OPS.index("none")
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
|
||||
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
col_unit1 = None
|
||||
col_unit2 = None
|
||||
unit_op = UNIT_OPS.index('none')
|
||||
|
||||
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
if idx < len_ and toks[idx] in UNIT_OPS:
|
||||
unit_op = UNIT_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (unit_op, col_unit1, col_unit2)
|
||||
|
||||
|
||||
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
|
||||
"""
|
||||
:returns next idx, table id, table name
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
key = tables_with_alias[toks[idx]]
|
||||
|
||||
if idx + 1 < len_ and toks[idx+1] == "as":
|
||||
idx += 3
|
||||
else:
|
||||
idx += 1
|
||||
|
||||
return idx, schema.idMap[key], key
|
||||
|
||||
|
||||
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] == 'select':
|
||||
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
elif "\"" in toks[idx]: # token is a string value
|
||||
val = toks[idx]
|
||||
idx += 1
|
||||
else:
|
||||
try:
|
||||
val = float(toks[idx])
|
||||
idx += 1
|
||||
except:
|
||||
end_idx = idx
|
||||
while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
|
||||
and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
|
||||
end_idx += 1
|
||||
|
||||
idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
|
||||
idx = end_idx
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1
|
||||
|
||||
return idx, val
|
||||
|
||||
|
||||
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
conds = []
|
||||
|
||||
while idx < len_:
|
||||
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
not_op = False
|
||||
if toks[idx] == 'not':
|
||||
not_op = True
|
||||
idx += 1
|
||||
|
||||
assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
|
||||
op_id = WHERE_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
val1 = val2 = None
|
||||
if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
|
||||
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
|
||||
assert toks[idx] == 'and'
|
||||
idx += 1
|
||||
idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
|
||||
else: # normal case: single value
|
||||
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
|
||||
val2 = None
|
||||
|
||||
conds.append((not_op, op_id, val_unit, val1, val2))
|
||||
|
||||
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
|
||||
break
|
||||
|
||||
if idx < len_ and toks[idx] in COND_OPS:
|
||||
conds.append(toks[idx])
|
||||
idx += 1 # skip and/or
|
||||
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
assert toks[idx] == 'select', "'select' not found"
|
||||
idx += 1
|
||||
isDistinct = False
|
||||
if idx < len_ and toks[idx] == 'distinct':
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
val_units = []
|
||||
|
||||
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
|
||||
agg_id = AGG_OPS.index("none")
|
||||
if toks[idx] in AGG_OPS:
|
||||
agg_id = AGG_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
val_units.append((agg_id, val_unit))
|
||||
if idx < len_ and toks[idx] == ',':
|
||||
idx += 1 # skip ','
|
||||
|
||||
return idx, (isDistinct, val_units)
|
||||
|
||||
|
||||
def parse_from(toks, start_idx, tables_with_alias, schema):
|
||||
"""
|
||||
Assume in the from clause, all table units are combined with join
|
||||
"""
|
||||
assert 'from' in toks[start_idx:], "'from' not found"
|
||||
|
||||
len_ = len(toks)
|
||||
idx = toks.index('from', start_idx) + 1
|
||||
default_tables = []
|
||||
table_units = []
|
||||
conds = []
|
||||
|
||||
while idx < len_:
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] == 'select':
|
||||
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
table_units.append((TABLE_TYPE['sql'], sql))
|
||||
else:
|
||||
if idx < len_ and toks[idx] == 'join':
|
||||
idx += 1 # skip join
|
||||
idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
|
||||
table_units.append((TABLE_TYPE['table_unit'],table_unit))
|
||||
default_tables.append(table_name)
|
||||
if idx < len_ and toks[idx] == "on":
|
||||
idx += 1 # skip on
|
||||
idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
if len(conds) > 0:
|
||||
conds.append('and')
|
||||
conds.extend(this_conds)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1
|
||||
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
break
|
||||
|
||||
return idx, table_units, conds, default_tables
|
||||
|
||||
|
||||
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx >= len_ or toks[idx] != 'where':
|
||||
return idx, []
|
||||
|
||||
idx += 1
|
||||
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
col_units = []
|
||||
|
||||
if idx >= len_ or toks[idx] != 'group':
|
||||
return idx, col_units
|
||||
|
||||
idx += 1
|
||||
assert toks[idx] == 'by'
|
||||
idx += 1
|
||||
|
||||
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
col_units.append(col_unit)
|
||||
if idx < len_ and toks[idx] == ',':
|
||||
idx += 1 # skip ','
|
||||
else:
|
||||
break
|
||||
|
||||
return idx, col_units
|
||||
|
||||
|
||||
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
val_units = []
|
||||
order_type = 'asc' # default type is 'asc'
|
||||
|
||||
if idx >= len_ or toks[idx] != 'order':
|
||||
return idx, val_units
|
||||
|
||||
idx += 1
|
||||
assert toks[idx] == 'by'
|
||||
idx += 1
|
||||
|
||||
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
val_units.append(val_unit)
|
||||
if idx < len_ and toks[idx] in ORDER_OPS:
|
||||
order_type = toks[idx]
|
||||
idx += 1
|
||||
if idx < len_ and toks[idx] == ',':
|
||||
idx += 1 # skip ','
|
||||
else:
|
||||
break
|
||||
|
||||
return idx, (order_type, val_units)
|
||||
|
||||
|
||||
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx >= len_ or toks[idx] != 'having':
|
||||
return idx, []
|
||||
|
||||
idx += 1
|
||||
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_limit(toks, start_idx):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx < len_ and toks[idx] == 'limit':
|
||||
idx += 2
|
||||
# make the limit value usable even when it is not an integer token; fall back to 1 rather than failing
|
||||
if type(toks[idx-1]) != int:
|
||||
return idx, 1
|
||||
|
||||
return idx, int(toks[idx-1])
|
||||
|
||||
return idx, None
|
||||
|
||||
|
||||
def parse_sql(toks, start_idx, tables_with_alias, schema):
|
||||
isBlock = False # indicate whether this is a block of sql/sub-sql
|
||||
len_ = len(toks)
|
||||
idx = start_idx
|
||||
|
||||
sql = {}
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
# parse from clause in order to get default tables
|
||||
from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
|
||||
sql['from'] = {'table_units': table_units, 'conds': conds}
|
||||
# select clause
|
||||
_, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
|
||||
idx = from_end_idx
|
||||
sql['select'] = select_col_units
|
||||
# where clause
|
||||
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['where'] = where_conds
|
||||
# group by clause
|
||||
idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['groupBy'] = group_col_units
|
||||
# having clause
|
||||
idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['having'] = having_conds
|
||||
# order by clause
|
||||
idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['orderBy'] = order_col_units
|
||||
# limit clause
|
||||
idx, limit_val = parse_limit(toks, idx)
|
||||
sql['limit'] = limit_val
|
||||
|
||||
idx = skip_semicolon(toks, idx)
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
idx = skip_semicolon(toks, idx)
|
||||
|
||||
# intersect/union/except clause
|
||||
for op in SQL_OPS: # initialize IUE
|
||||
sql[op] = None
|
||||
if idx < len_ and toks[idx] in SQL_OPS:
|
||||
sql_op = toks[idx]
|
||||
idx += 1
|
||||
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
sql[sql_op] = IUE_sql
|
||||
return idx, sql
|
||||
|
||||
|
||||
def load_data(fpath):
|
||||
with open(fpath) as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def get_sql(schema, query):
|
||||
toks = tokenize(query)
|
||||
tables_with_alias = get_tables_with_alias(schema.schema, toks)
|
||||
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def skip_semicolon(toks, start_idx):
|
||||
idx = start_idx
|
||||
while idx < len(toks) and toks[idx] == ";":
|
||||
idx += 1
|
||||
return idx
|
|
@@ -0,0 +1,58 @@
|
|||
"""Contains the logging class."""
|
||||
|
||||
class Logger():
|
||||
"""Attributes:
|
||||
|
||||
fileptr (file): File pointer for input/output.
|
||||
lines (list of str): The lines read from the log.
|
||||
"""
|
||||
def __init__(self, filename, option):
|
||||
self.fileptr = open(filename, option)
|
||||
if option == "r":
|
||||
self.lines = self.fileptr.readlines()
|
||||
else:
|
||||
self.lines = []
|
||||
|
||||
def put(self, string):
|
||||
"""Writes to the file."""
|
||||
self.fileptr.write(string + "\n")
|
||||
self.fileptr.flush()
|
||||
|
||||
def close(self):
|
||||
"""Closes the logger."""
|
||||
self.fileptr.close()
|
||||
|
||||
def findlast(self, identifier, default=0.):
|
||||
"""Finds the last line in the log with a certain value."""
|
||||
for line in self.lines[::-1]:
|
||||
if line.lower().startswith(identifier):
|
||||
string = line.strip().split("\t")[1]
|
||||
if string.replace(".", "").isdigit():
|
||||
return float(string)
|
||||
elif string.lower() == "true":
|
||||
return True
|
||||
elif string.lower() == "false":
|
||||
return False
|
||||
else:
|
||||
return string
|
||||
return default
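# Minimal sketch (not the Logger class itself) of the value coercion performed
# above on the tab-separated field of a matching log line:
def coerce(string):
    if string.replace(".", "").isdigit():
        return float(string)
    if string.lower() == "true":
        return True
    if string.lower() == "false":
        return False
    return string

assert coerce("12.5") == 12.5 and coerce("true") is True and coerce("best") == "best"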
|
||||
|
||||
def contains(self, string):
|
||||
"""Dtermines whether the string is present in the log."""
|
||||
for line in self.lines[::-1]:
|
||||
if string.lower() in line.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
def findlast_log_before(self, before_str):
|
||||
"""Finds the last entry in the log before another entry."""
|
||||
loglines = []
|
||||
in_line = False
|
||||
for line in self.lines[::-1]:
|
||||
if line.startswith(before_str):
|
||||
in_line = True
|
||||
elif in_line:
|
||||
loglines.append(line)
|
||||
if line.strip() == "" and in_line:
|
||||
return "".join(loglines[::-1])
|
||||
return "".join(loglines[::-1])
|
|
@@ -0,0 +1,80 @@
|
|||
"""Contains classes for computing and keeping track of attention distributions.
|
||||
"""
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from . import torch_utils
|
||||
|
||||
class AttentionResult(namedtuple('AttentionResult',
|
||||
('scores',
|
||||
'distribution',
|
||||
'vector'))):
|
||||
"""Stores the result of an attention calculation."""
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
"""Attention mechanism class. Stores parameters for and computes attention.
|
||||
|
||||
Attributes:
|
||||
transform_query (bool): Whether or not to transform the query being
|
||||
passed in with a weight transformation before computing attention.
transform_key (bool): Whether or not to transform the key being
passed in with a weight transformation before computing attention.
transform_value (bool): Whether or not to transform the value being
passed in with a weight transformation before computing attention.
|
||||
key_size (int): The size of the key vectors.
|
||||
value_size (int): The size of the value vectors.
|
||||
|
||||
query_weights (dy.Parameters): Weights for transforming the query.
|
||||
key_weights (dy.Parameters): Weights for transforming the key.
|
||||
value_weights (dy.Parameters): Weights for transforming the value.
|
||||
"""
|
||||
def __init__(self, query_size, key_size, value_size):
|
||||
super().__init__()
|
||||
self.key_size = key_size
|
||||
self.value_size = value_size
|
||||
|
||||
self.query_weights = torch_utils.add_params((query_size, self.key_size), "weights-attention-q")
|
||||
|
||||
def transform_arguments(self, query, keys, values):
|
||||
""" Transforms the query/key/value inputs before attention calculations.
|
||||
|
||||
Arguments:
|
||||
query (dy.Expression): Vector representing the query (e.g., hidden state.)
|
||||
keys (list of dy.Expression): List of vectors representing the key
|
||||
values.
|
||||
values (list of dy.Expression): List of vectors representing the values.
|
||||
|
||||
Returns:
|
||||
triple of dy.Expression, where the first represents the (transformed)
|
||||
query, the second represents the (transformed and concatenated)
|
||||
keys, and the third represents the (transformed and concatenated)
|
||||
values.
|
||||
"""
|
||||
assert len(keys) == len(values)
|
||||
|
||||
all_keys = torch.stack(keys, dim=1)
|
||||
all_values = torch.stack(values, dim=1)
|
||||
|
||||
assert all_keys.size()[0] == self.key_size, "Expected key size of " + str(self.key_size) + " but got " + str(all_keys.size()[0])
|
||||
assert all_values.size()[0] == self.value_size
|
||||
|
||||
query = torch_utils.linear_layer(query, self.query_weights)
|
||||
|
||||
return query, all_keys, all_values
|
||||
|
||||
def forward(self, query, keys, values=None):
|
||||
if not values:
|
||||
values = keys
|
||||
|
||||
query_t, keys_t, values_t = self.transform_arguments(query, keys, values)
|
||||
|
||||
scores = torch.t(torch.mm(query_t,keys_t)) # len(key) x len(query)
|
||||
|
||||
distribution = F.softmax(scores, dim=0) # len(key) x len(query)
|
||||
|
||||
context_vector = torch.mm(values_t, distribution).squeeze() # value_size x len(query)
|
||||
|
||||
return AttentionResult(scores, distribution, context_vector)
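# Shape sketch (illustrative only, not part of the original file; dimensions are
# made up) of the matrix algebra in forward() above: scores are keys^T . query,
# the softmax over the key axis gives the attention distribution, and the context
# vector is a weighted sum of the values.
import torch
import torch.nn.functional as F

key_size, value_size, n_keys = 4, 6, 5
query_t = torch.randn(1, key_size)            # transformed query, 1 x key_size
keys_t = torch.randn(key_size, n_keys)        # stacked keys, key_size x n_keys
values_t = torch.randn(value_size, n_keys)    # stacked values, value_size x n_keys

scores = torch.t(torch.mm(query_t, keys_t))                  # n_keys x 1
distribution = F.softmax(scores, dim=0)                      # sums to 1 over keys
context_vector = torch.mm(values_t, distribution).squeeze()  # value_size
assert context_vector.shape == (value_size,)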
|
|
@@ -0,0 +1,97 @@
|
|||
import heapq
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import Callable, Dict, Tuple, List
|
||||
import torch
|
||||
|
||||
class BeamSearchState:
|
||||
def __init__(self, sequence=[], log_probability=0, state={}):
|
||||
self.sequence = sequence
|
||||
self.log_probability = log_probability
|
||||
self.state = state
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.log_probability < other.log_probability
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.log_probability}: {' '.join(self.sequence)}"
|
||||
|
||||
class BeamSearch:
|
||||
"""
|
||||
Implements the beam search algorithm for decoding the most likely sequences.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
is_end_of_sequence: Callable[[int, str], bool],
|
||||
max_steps: int = 50,
|
||||
beam_size: int = 10,
|
||||
per_node_beam_size: int = None,
|
||||
) -> None:
|
||||
self.is_end_of_sequence = is_end_of_sequence
|
||||
self.max_steps = max_steps
|
||||
self.beam_size = beam_size
|
||||
self.per_node_beam_size = per_node_beam_size
|
||||
if not per_node_beam_size:
|
||||
self.per_node_beam_size = beam_size
|
||||
|
||||
self.beam = []
|
||||
|
||||
def append(self, log_probability, new_state):
|
||||
if len(self.beam) < self.beam_size:
|
||||
heapq.heappush(self.beam, (log_probability, new_state))
|
||||
else:
|
||||
heapq.heapreplace(self.beam, (log_probability, new_state))
|
||||
|
||||
def get_largest(self, beam):
|
||||
return list(heapq.nlargest(1, beam))[0]
|
||||
|
||||
def search(
|
||||
self,
|
||||
start_state: Dict,
|
||||
step_function: Callable[[Dict], Tuple[List[str], torch.Tensor, Tuple]],
|
||||
append_token_function: Callable[[Tuple, Tuple, str, float], Tuple]
|
||||
) -> Tuple[List, float, Tuple]:
|
||||
state = BeamSearchState(state=start_state)
|
||||
heapq.heappush(self.beam, (0, state))
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
current_beam = self.beam
|
||||
|
||||
self.beam = []
|
||||
found_non_terminal = False
|
||||
|
||||
for (log_probability, state) in current_beam:
|
||||
if self.is_end_of_sequence(state.state, self.max_steps, state.sequence):
|
||||
self.append(log_probability, state)
|
||||
|
||||
else:
|
||||
found_non_terminal = True
|
||||
tokens, token_probabilities, extra_state = step_function(state.state)
|
||||
|
||||
tps = np.array(token_probabilities)
|
||||
top_indices = heapq.nlargest(self.per_node_beam_size, range(len(tps)), tps.take)
|
||||
top_tokens = [tokens[i] for i in top_indices]
|
||||
top_probabilities = [token_probabilities[i] for i in top_indices]
|
||||
sequence_length = len(state.sequence)
|
||||
for i in range(len(top_tokens)):
|
||||
if top_probabilities[i] > 1e-6:
|
||||
#lp = ((log_probability * sequence_length) + math.log(top_probabilities[i])) / (sequence_length + 1.0)
|
||||
lp = log_probability + math.log(top_probabilities[i])
|
||||
t = top_tokens[i]
|
||||
new_state = append_token_function(state.state, extra_state, t, lp)
|
||||
new_sequence = state.sequence.copy()
|
||||
new_sequence.append(t)
|
||||
|
||||
ns = BeamSearchState(sequence=new_sequence, log_probability=lp, state=new_state)
|
||||
|
||||
self.append(lp, ns)
|
||||
|
||||
if not found_non_terminal:
|
||||
done = True
|
||||
|
||||
output = self.get_largest(self.beam)[1]
|
||||
|
||||
return output.sequence, output.log_probability, output.state, self.beam
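# Minimal sketch (not part of the original file) of how BeamSearch.append keeps
# at most beam_size hypotheses with a min-heap: once the heap is full, heapreplace
# pops the current minimum (worst kept hypothesis) and pushes the new one. The
# toy log-probabilities and string states below are made up.
import heapq

beam_size = 3
beam = []
for log_p, hyp in [(-0.1, 'a'), (-2.3, 'b'), (-0.7, 'c'), (-0.4, 'd')]:
    if len(beam) < beam_size:
        heapq.heappush(beam, (log_p, hyp))
    else:
        heapq.heapreplace(beam, (log_p, hyp))

# 'b' (log-prob -2.3) was evicted; the best hypothesis is recovered as in get_largest.
assert heapq.nlargest(1, beam)[0] == (-0.1, 'a')
assert sorted(h for _, h in beam) == ['a', 'c', 'd']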
|
||||
|
||||
|
|
@@ -0,0 +1,336 @@
|
|||
""" Decoder for the SQL generation problem."""
|
||||
|
||||
from collections import namedtuple
|
||||
import copy
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from . import torch_utils
|
||||
|
||||
from .token_predictor import PredictionInput, PredictionInputWithSchema, PredictionStepInputWithSchema
|
||||
from .beam_search import BeamSearch
|
||||
import data_util.snippets as snippet_handler
|
||||
from . import embedder
|
||||
from data_util.vocabulary import EOS_TOK, UNK_TOK
|
||||
import math
|
||||
|
||||
def flatten_distribution(distribution_map, probabilities):
|
||||
""" Flattens a probability distribution given a map of "unique" values.
|
||||
All entries in distribution_map that share the same value should receive the sum
of their probabilities.
|
||||
|
||||
Arguments:
|
||||
distribution_map (list of str): List of values to get the probability for.
|
||||
probabilities (np.ndarray): Probabilities corresponding to the values in
|
||||
distribution_map.
|
||||
|
||||
Returns:
|
||||
list, np.ndarray of the same size where probabilities for duplicates
|
||||
in distribution_map are given the sum of the probabilities in probabilities.
|
||||
"""
|
||||
assert len(distribution_map) == len(probabilities)
|
||||
if len(distribution_map) != len(set(distribution_map)):
|
||||
idx_first_dup = 0
|
||||
seen_set = set()
|
||||
for i, tok in enumerate(distribution_map):
|
||||
if tok in seen_set:
|
||||
idx_first_dup = i
|
||||
break
|
||||
seen_set.add(tok)
|
||||
new_dist_map = distribution_map[:idx_first_dup] + list(
|
||||
set(distribution_map) - set(distribution_map[:idx_first_dup]))
|
||||
assert len(new_dist_map) == len(set(new_dist_map))
|
||||
new_probs = np.array(
|
||||
probabilities[:idx_first_dup] \
|
||||
+ [0. for _ in range(len(set(distribution_map)) \
|
||||
- idx_first_dup)])
|
||||
assert len(new_probs) == len(new_dist_map)
|
||||
|
||||
for i, token_name in enumerate(
|
||||
distribution_map[idx_first_dup:]):
|
||||
if token_name not in new_dist_map:
|
||||
new_dist_map.append(token_name)
|
||||
|
||||
new_index = new_dist_map.index(token_name)
|
||||
new_probs[new_index] += probabilities[i +
|
||||
idx_first_dup]
|
||||
new_probs = new_probs.tolist()
|
||||
else:
|
||||
new_dist_map = distribution_map
|
||||
new_probs = probabilities
|
||||
|
||||
assert len(new_dist_map) == len(new_probs)
|
||||
|
||||
return new_dist_map, new_probs
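# Illustrative re-implementation sketch of the behaviour documented above (not
# the function itself): duplicated tokens collapse into a single entry whose
# probability is the sum over all duplicates.
from collections import OrderedDict

def flatten_sketch(tokens, probs):
    merged = OrderedDict()
    for tok, p in zip(tokens, probs):
        merged[tok] = merged.get(tok, 0.0) + p
    return list(merged.keys()), list(merged.values())

toks, probs = flatten_sketch(['select', 'from', 'select'], [0.5, 0.25, 0.25])
assert toks == ['select', 'from'] and probs == [0.75, 0.25]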
|
||||
|
||||
class SQLPrediction(namedtuple('SQLPrediction',
|
||||
('predictions',
|
||||
'sequence',
|
||||
'probability',
|
||||
'beam'))):
|
||||
"""Contains prediction for a sequence."""
|
||||
__slots__ = ()
|
||||
|
||||
def __str__(self):
|
||||
return str(self.probability) + "\t" + " ".join(self.sequence)
|
||||
|
||||
Count = 0
|
||||
|
||||
class SequencePredictorWithSchema(torch.nn.Module):
|
||||
""" Predicts a sequence.
|
||||
|
||||
Attributes:
|
||||
lstms (list of dy.RNNBuilder): The RNN used.
|
||||
token_predictor (TokenPredictor): Used to actually predict tokens.
|
||||
"""
|
||||
def __init__(self,
|
||||
params,
|
||||
input_size,
|
||||
output_embedder,
|
||||
column_name_token_embedder,
|
||||
token_predictor):
|
||||
super().__init__()
|
||||
|
||||
self.lstms = torch_utils.create_multilayer_lstm_params(params.decoder_num_layers, input_size, params.decoder_state_size, "LSTM-d")
|
||||
self.token_predictor = token_predictor
|
||||
self.output_embedder = output_embedder
|
||||
self.column_name_token_embedder = column_name_token_embedder
|
||||
self.start_token_embedding = torch_utils.add_params((params.output_embedding_size,), "y-0")
|
||||
|
||||
self.input_size = input_size
|
||||
self.params = params
|
||||
|
||||
self.params.use_turn_num = False
|
||||
|
||||
if self.params.use_turn_num:
|
||||
assert params.decoder_num_layers <= 2
|
||||
state_size = params.decoder_state_size
|
||||
self.layer1_c_0 = nn.Linear(state_size+5, state_size)
|
||||
self.layer1_h_0 = nn.Linear(state_size+5, state_size)
|
||||
if params.decoder_num_layers == 2:
|
||||
self.layer2_c_0 = nn.Linear(state_size+5, state_size)
|
||||
self.layer2_h_0 = nn.Linear(state_size+5, state_size)
|
||||
|
||||
|
||||
def _initialize_decoder_lstm(self, encoder_state, utterance_index):
|
||||
decoder_lstm_states = []
|
||||
#print('utterance_index', utterance_index)
|
||||
#print('self.params.decoder_num_layers', self.params.decoder_num_layers)
|
||||
if self.params.use_turn_num:
|
||||
if utterance_index >= 4:
|
||||
utterance_index = 4
|
||||
turn_number_embedding = torch.zeros([1, 5]).cuda()
|
||||
turn_number_embedding[0, utterance_index] = 1
|
||||
#print('turn_number_embedding', turn_number_embedding)
|
||||
|
||||
for i, lstm in enumerate(self.lstms):
|
||||
encoder_layer_num = 0
|
||||
if len(encoder_state[0]) > 1:
|
||||
encoder_layer_num = i
|
||||
|
||||
# check which one is h_0, which is c_0
|
||||
c_0 = encoder_state[0][encoder_layer_num].view(1,-1)
|
||||
h_0 = encoder_state[1][encoder_layer_num].view(1,-1)
|
||||
|
||||
if self.params.use_turn_num:
|
||||
if i == 0:
|
||||
c_0 = self.layer1_c_0( torch.cat([c_0, turn_number_embedding], -1) )
|
||||
h_0 = self.layer1_h_0( torch.cat([h_0, turn_number_embedding], -1) )
|
||||
else:
|
||||
c_0 = self.layer2_c_0( torch.cat([c_0, turn_number_embedding], -1) )
|
||||
h_0 = self.layer2_h_0( torch.cat([h_0, turn_number_embedding], -1) )
|
||||
|
||||
decoder_lstm_states.append((h_0, c_0))
|
||||
return decoder_lstm_states
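# Sketch (not part of the original file) of the turn-number one-hot built above
# when use_turn_num is enabled; the .cuda() call is dropped here so the snippet
# runs on CPU. The utterance index is capped at 4, giving 5 possible turn slots.
import torch

utterance_index = min(7, 4)                  # e.g. turn 7 is clipped to 4
turn_number_embedding = torch.zeros([1, 5])
turn_number_embedding[0, utterance_index] = 1
assert turn_number_embedding.tolist() == [[0.0, 0.0, 0.0, 0.0, 1.0]]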
|
||||
|
||||
def get_output_token_embedding(self, output_token, input_schema, snippets):
|
||||
if self.params.use_snippets and snippet_handler.is_snippet(output_token):
|
||||
output_token_embedding = embedder.bow_snippets(output_token, snippets, self.output_embedder, input_schema)
|
||||
else:
|
||||
if input_schema:
|
||||
assert self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
|
||||
if self.output_embedder.in_vocabulary(output_token):
|
||||
output_token_embedding = self.output_embedder(output_token)
|
||||
else:
|
||||
output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True)
|
||||
else:
|
||||
output_token_embedding = self.output_embedder(output_token)
|
||||
return output_token_embedding
|
||||
|
||||
def get_decoder_input(self, output_token_embedding, prediction):
|
||||
if self.params.use_schema_attention and self.params.use_query_attention:
|
||||
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector, prediction.query_attention_results.vector], dim=0)
|
||||
elif self.params.use_schema_attention:
|
||||
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector], dim=0)
|
||||
else:
|
||||
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector], dim=0)
|
||||
return decoder_input
|
||||
|
||||
def forward(self,
|
||||
final_encoder_state,
|
||||
encoder_states,
|
||||
schema_states,
|
||||
max_generation_length,
|
||||
utterance_index=None,
|
||||
snippets=None,
|
||||
gold_sequence=None,
|
||||
input_sequence=None,
|
||||
previous_queries=None,
|
||||
previous_query_states=None,
|
||||
input_schema=None,
|
||||
dropout_amount=0.):
|
||||
""" Generates a sequence. """
global Count
|
||||
index = 0
|
||||
|
||||
context_vector_size = self.input_size - self.params.output_embedding_size
|
||||
|
||||
# Decoder states: just the initialized decoder.
|
||||
# Current input to decoder: phi(start_token) ; zeros the size of the
|
||||
# context vector
|
||||
predictions = []
|
||||
sequence = []
|
||||
probability = 1.
|
||||
|
||||
decoder_states = self._initialize_decoder_lstm(final_encoder_state, utterance_index)
|
||||
|
||||
if self.start_token_embedding.is_cuda:
|
||||
decoder_input = torch.cat([self.start_token_embedding, torch.cuda.FloatTensor(context_vector_size).fill_(0)], dim=0)
|
||||
else:
|
||||
decoder_input = torch.cat([self.start_token_embedding, torch.zeros(context_vector_size)], dim=0)
|
||||
|
||||
beam_search = BeamSearch(is_end_of_sequence=self.is_end_of_sequence, max_steps=max_generation_length, beam_size=10)
|
||||
prediction_step_input = PredictionStepInputWithSchema(
|
||||
index=index,
|
||||
decoder_input=decoder_input,
|
||||
decoder_states=decoder_states,
|
||||
encoder_states=encoder_states,
|
||||
schema_states=schema_states,
|
||||
snippets=snippets,
|
||||
gold_sequence=gold_sequence,
|
||||
input_sequence=input_sequence,
|
||||
previous_queries=previous_queries,
|
||||
previous_query_states=previous_query_states,
|
||||
input_schema=input_schema,
|
||||
dropout_amount=dropout_amount,
|
||||
predictions=predictions,
|
||||
)
|
||||
|
||||
sequence, probability, final_state, beam = beam_search.search(start_state=prediction_step_input, step_function=self.beam_search_step_function, append_token_function=self.beam_search_append_token)
|
||||
|
||||
sum_p = 0
|
||||
for x in beam:
|
||||
# print('math.exp(x[0])', math.exp(x[0]))
|
||||
#print('x[1]', x[1].sequence)
|
||||
sum_p += math.exp(x[0])
|
||||
#print('sum_p', sum_p)
|
||||
assert sum_p <= 1.2
|
||||
|
||||
'''if len(sequence) <= 2:
|
||||
print('beam', beam)
|
||||
for x in beam:
|
||||
print('x[0]', x[0])
|
||||
print('x[1]', x[1].sequence)
|
||||
print('='*20)
|
||||
Count+=1
|
||||
assert Count <= 10'''
|
||||
|
||||
return SQLPrediction(final_state.predictions,
|
||||
sequence,
|
||||
probability,
|
||||
beam)
|
||||
|
||||
def beam_search_step_function(self, prediction_step_input):
|
||||
# get decoded token probabilities
|
||||
prediction, tokens, token_probabilities, decoder_states = self.decode_next_token(prediction_step_input)
|
||||
return tokens, token_probabilities, (prediction, decoder_states)
|
||||
|
||||
def beam_search_append_token(self, prediction_step_input, step_function_output, token_to_append, token_log_probability):
|
||||
output_token_embedding = self.get_output_token_embedding(token_to_append, prediction_step_input.input_schema, prediction_step_input.snippets)
|
||||
decoder_input = self.get_decoder_input(output_token_embedding, step_function_output[0])
|
||||
|
||||
predictions = prediction_step_input.predictions.copy()
|
||||
predictions.append(step_function_output[0])
|
||||
new_state = prediction_step_input._replace(
|
||||
index=prediction_step_input.index+1,
|
||||
decoder_input=decoder_input,
|
||||
predictions=predictions,
|
||||
decoder_states=step_function_output[1],
|
||||
)
|
||||
|
||||
return new_state
|
||||
|
||||
def is_end_of_sequence(self, prediction_step_input, max_generation_length, sequence):
|
||||
if prediction_step_input.gold_sequence:
|
||||
return prediction_step_input.index >= len(prediction_step_input.gold_sequence)
|
||||
else:
|
||||
return prediction_step_input.index >= max_generation_length or (len(sequence) > 0 and sequence[-1] == EOS_TOK)
|
||||
|
||||
def decode_next_token(self, prediction_step_input):
|
||||
(index,
|
||||
decoder_input,
|
||||
decoder_states,
|
||||
encoder_states,
|
||||
schema_states,
|
||||
snippets,
|
||||
gold_sequence,
|
||||
input_sequence,
|
||||
previous_queries,
|
||||
previous_query_states,
|
||||
input_schema,
|
||||
dropout_amount,
|
||||
predictions) = prediction_step_input
|
||||
|
||||
_, decoder_state, decoder_states = torch_utils.forward_one_multilayer(self.lstms, decoder_input, decoder_states, dropout_amount)
|
||||
prediction_input = PredictionInputWithSchema(decoder_state=decoder_state,
|
||||
input_hidden_states=encoder_states,
|
||||
schema_states=schema_states,
|
||||
snippets=snippets,
|
||||
input_sequence=input_sequence,
|
||||
previous_queries=previous_queries,
|
||||
previous_query_states=previous_query_states,
|
||||
input_schema=input_schema)
|
||||
|
||||
prediction = self.token_predictor(prediction_input, dropout_amount=dropout_amount)
|
||||
|
||||
#gold_sequence = None
|
||||
|
||||
if gold_sequence:
|
||||
decoded_token = gold_sequence[index]
|
||||
distribution_map = [decoded_token]
|
||||
probabilities = [1.0]
|
||||
|
||||
else:
|
||||
assert prediction.scores.dim() == 1
|
||||
probabilities = F.softmax(prediction.scores, dim=0).cpu().data.numpy().tolist()
|
||||
|
||||
distribution_map = prediction.aligned_tokens
|
||||
assert len(probabilities) == len(distribution_map)
|
||||
|
||||
if self.params.use_previous_query and self.params.use_copy_switch and len(previous_queries) > 0:
|
||||
assert prediction.query_scores.dim() == 1
|
||||
query_token_probabilities = F.softmax(prediction.query_scores, dim=0).cpu().data.numpy().tolist()
|
||||
|
||||
query_token_distribution_map = prediction.query_tokens
|
||||
|
||||
assert len(query_token_probabilities) == len(query_token_distribution_map)
|
||||
|
||||
copy_switch = prediction.copy_switch.cpu().data.numpy()
|
||||
|
||||
# Merge the two
|
||||
probabilities = ((np.array(probabilities) * (1 - copy_switch)).tolist() +
|
||||
(np.array(query_token_probabilities) * copy_switch).tolist()
|
||||
)
|
||||
distribution_map = distribution_map + query_token_distribution_map
|
||||
assert len(probabilities) == len(distribution_map)
|
||||
|
||||
# Get a new probabilities and distribution_map consolidating duplicates
|
||||
distribution_map, probabilities = flatten_distribution(distribution_map, probabilities)
|
||||
|
||||
# Modify the probability distribution so that the UNK token can never be produced
|
||||
probabilities[distribution_map.index(UNK_TOK)] = 0.
|
||||
|
||||
return prediction, distribution_map, probabilities, decoder_states
|
|
@@ -0,0 +1,100 @@
|
|||
""" Embedder for tokens. """
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import data_util.snippets as snippet_handler
|
||||
import data_util.vocabulary as vocabulary_handler
|
||||
|
||||
class Embedder(torch.nn.Module):
|
||||
""" Embeds tokens. """
|
||||
def __init__(self, embedding_size, name="", initializer=None, vocabulary=None, num_tokens=-1, anonymizer=None, freeze=False, use_unk=True):
|
||||
super().__init__()
|
||||
|
||||
if vocabulary:
|
||||
assert num_tokens < 0, "Specified a vocabulary but also set number of tokens to " + \
|
||||
str(num_tokens)
|
||||
self.in_vocabulary = lambda token: token in vocabulary.tokens
|
||||
self.vocab_token_lookup = lambda token: vocabulary.token_to_id(token)
|
||||
if use_unk:
|
||||
self.unknown_token_id = vocabulary.token_to_id(vocabulary_handler.UNK_TOK)
|
||||
else:
|
||||
self.unknown_token_id = -1
|
||||
self.vocabulary_size = len(vocabulary)
|
||||
else:
|
||||
def check_vocab(index):
|
||||
""" Makes sure the index is in the vocabulary."""
|
||||
assert index < num_tokens, "Passed token ID " + \
|
||||
str(index) + "; expecting something less than " + str(num_tokens)
|
||||
return index < num_tokens
|
||||
self.in_vocabulary = check_vocab
|
||||
self.vocab_token_lookup = lambda x: x
|
||||
self.unknown_token_id = num_tokens  # deliberately out of range: looking up an
# unknown token should fail in check_vocab before this id is ever used
|
||||
self.vocabulary_size = num_tokens
|
||||
|
||||
self.anonymizer = anonymizer
|
||||
|
||||
emb_name = name + "-tokens"
|
||||
print("Creating token embedder called " + emb_name + " of size " + str(self.vocabulary_size) + " x " + str(embedding_size))
|
||||
|
||||
if initializer is not None:
|
||||
word_embeddings_tensor = torch.FloatTensor(initializer)
|
||||
self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(word_embeddings_tensor, freeze=freeze)
|
||||
else:
|
||||
init_tensor = torch.empty(self.vocabulary_size, embedding_size).uniform_(-0.1, 0.1)
|
||||
self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False)
|
||||
|
||||
if self.anonymizer:
|
||||
emb_name = name + "-entities"
|
||||
entity_size = len(self.anonymizer.entity_types)
|
||||
print("Creating entity embedder called " + emb_name + " of size " + str(entity_size) + " x " + str(embedding_size))
|
||||
init_tensor = torch.empty(entity_size, embedding_size).uniform_(-0.1, 0.1)
|
||||
self.entity_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False)
|
||||
|
||||
|
||||
def forward(self, token):
|
||||
assert isinstance(token, int) or not snippet_handler.is_snippet(token), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets"
|
||||
|
||||
if self.in_vocabulary(token):
|
||||
index_list = torch.LongTensor([self.vocab_token_lookup(token)])
|
||||
if self.token_embedding_matrix.weight.is_cuda:
|
||||
index_list = index_list.cuda()
|
||||
return self.token_embedding_matrix(index_list).squeeze()
|
||||
elif self.anonymizer and self.anonymizer.is_anon_tok(token):
|
||||
index_list = torch.LongTensor([self.anonymizer.get_anon_id(token)])
|
||||
if self.token_embedding_matrix.weight.is_cuda:
|
||||
index_list = index_list.cuda()
|
||||
return self.entity_embedding_matrix(index_list).squeeze()
|
||||
else:
|
||||
index_list = torch.LongTensor([self.unknown_token_id])
|
||||
if self.token_embedding_matrix.weight.is_cuda:
|
||||
index_list = index_list.cuda()
|
||||
return self.token_embedding_matrix(index_list).squeeze()
|
||||
|
||||
|
||||
def bow_snippets(token, snippets, output_embedder, input_schema):
|
||||
""" Bag of words embedding for snippets"""
|
||||
assert snippet_handler.is_snippet(token) and snippets
|
||||
|
||||
snippet_sequence = []
|
||||
for snippet in snippets:
|
||||
if snippet.name == token:
|
||||
snippet_sequence = snippet.sequence
|
||||
break
|
||||
assert snippet_sequence
|
||||
|
||||
if input_schema:
|
||||
snippet_embeddings = []
|
||||
for output_token in snippet_sequence:
|
||||
assert output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
|
||||
if output_embedder.in_vocabulary(output_token):
|
||||
snippet_embeddings.append(output_embedder(output_token))
|
||||
else:
|
||||
snippet_embeddings.append(input_schema.column_name_embedder(output_token, surface_form=True))
|
||||
else:
|
||||
snippet_embeddings = [output_embedder(subtoken) for subtoken in snippet_sequence]
|
||||
|
||||
snippet_embeddings = torch.stack(snippet_embeddings, dim=0) # len(snippet_sequence) x emb_size
|
||||
return torch.mean(snippet_embeddings, dim=0) # emb_size
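# Shape sketch of the bag-of-words averaging above (toy sizes, not part of the
# original file; each token embedding is assumed to be a 1-D tensor of size emb_size):
import torch

emb_size = 8
token_embeddings = [torch.randn(emb_size) for _ in range(3)]
snippet_embedding = torch.mean(torch.stack(token_embeddings, dim=0), dim=0)
assert snippet_embedding.shape == (emb_size,)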
|
||||
|
|
@@ -0,0 +1,60 @@
|
|||
""" Contains code for encoding an input sequence. """
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .torch_utils import create_multilayer_lstm_params, encode_sequence
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
""" Encodes an input sequence. """
|
||||
def __init__(self, num_layers, input_size, state_size):
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
self.forward_lstms = create_multilayer_lstm_params(self.num_layers, input_size, state_size / 2, "LSTM-ef")
|
||||
self.backward_lstms = create_multilayer_lstm_params(self.num_layers, input_size, state_size / 2, "LSTM-eb")
|
||||
|
||||
def forward(self, sequence, embedder, dropout_amount=0.):
|
||||
""" Encodes a sequence forward and backward.
|
||||
Inputs:
|
||||
sequence (list of str): The sequence of tokens to encode.
embedder (callable): Embedding function for tokens in the
sequence.
|
||||
dropout_amount (float, optional): The amount of dropout to apply.
|
||||
|
||||
Returns:
|
||||
(list of dy.Expression, list of dy.Expression), list of dy.Expression,
|
||||
where the first pair is the (final cell memories, final cell states) of
|
||||
all layers, and the second list is a list of the final layer's cell
|
||||
state for all tokens in the sequence.
|
||||
"""
|
||||
forward_state, forward_outputs = encode_sequence(
|
||||
sequence,
|
||||
self.forward_lstms,
|
||||
embedder,
|
||||
dropout_amount=dropout_amount)
|
||||
|
||||
backward_state, backward_outputs = encode_sequence(
|
||||
sequence[::-1],
|
||||
self.backward_lstms,
|
||||
embedder,
|
||||
dropout_amount=dropout_amount)
|
||||
|
||||
cell_memories = []
|
||||
hidden_states = []
|
||||
for i in range(self.num_layers):
|
||||
cell_memories.append(torch.cat([forward_state[0][i], backward_state[0][i]], dim=0))
|
||||
hidden_states.append(torch.cat([forward_state[1][i], backward_state[1][i]], dim=0))
|
||||
|
||||
assert len(forward_outputs) == len(backward_outputs)
|
||||
|
||||
backward_outputs = backward_outputs[::-1]
|
||||
|
||||
final_outputs = []
|
||||
for i in range(len(sequence)):
|
||||
final_outputs.append(torch.cat([forward_outputs[i], backward_outputs[i]], dim=0))
|
||||
|
||||
return (cell_memories, hidden_states), final_outputs
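# Dimension sketch (toy sizes, not part of the original file) of the bidirectional
# concatenation above: each direction contributes state_size / 2 per token, and the
# reversed backward outputs are zipped with the forward outputs to give state_size
# per token.
import torch

state_size, seq_len = 8, 5
forward_outputs = [torch.randn(state_size // 2) for _ in range(seq_len)]
backward_outputs = [torch.randn(state_size // 2) for _ in range(seq_len)][::-1]
final_outputs = [torch.cat([f, b], dim=0) for f, b in zip(forward_outputs, backward_outputs)]
assert all(o.shape == (state_size,) for o in final_outputs)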
|
|
@@ -0,0 +1,394 @@
|
|||
""" Class for the Sequence to sequence model for ATIS."""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from . import torch_utils
|
||||
from . import utils_bert
|
||||
|
||||
from data_util.vocabulary import DEL_TOK, UNK_TOK
|
||||
|
||||
from .encoder import Encoder
|
||||
from .embedder import Embedder
|
||||
from .token_predictor import construct_token_predictor
|
||||
|
||||
import numpy as np
|
||||
|
||||
from data_util.atis_vocab import ATISVocabulary
|
||||
|
||||
def get_token_indices(token, index_to_token):
|
||||
""" Maps from a gold token (string) to a list of indices.
|
||||
|
||||
Inputs:
|
||||
token (string): String to look up.
|
||||
index_to_token (list of tokens): Ordered list of tokens.
|
||||
|
||||
Returns:
|
||||
list of int, representing the indices of the token in the probability
|
||||
distribution.
|
||||
"""
|
||||
if token in index_to_token:
|
||||
if len(set(index_to_token)) == len(index_to_token): # no duplicates
|
||||
return [index_to_token.index(token)]
|
||||
else:
|
||||
indices = []
|
||||
for index, other_token in enumerate(index_to_token):
|
||||
if token == other_token:
|
||||
indices.append(index)
|
||||
assert len(indices) == len(set(indices))
|
||||
return indices
|
||||
else:
|
||||
return [index_to_token.index(UNK_TOK)]
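# Illustrative example (hypothetical token list, not part of the original file):
# when the output vocabulary ordering contains duplicates, a gold token maps to
# every index where it occurs, which is what the duplicate branch above computes.
index_to_token = ['select', 'from', 'select', 'where']
indices = [i for i, t in enumerate(index_to_token) if t == 'select']
assert indices == [0, 2]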
|
||||
|
||||
def flatten_utterances(utterances):
|
||||
""" Gets a flat sequence from a sequence of utterances.
|
||||
|
||||
Inputs:
|
||||
utterances (list of list of str): Utterances to concatenate.
|
||||
|
||||
Returns:
|
||||
list of str, representing the flattened sequence with separating
|
||||
delimiter tokens.
|
||||
"""
|
||||
sequence = []
|
||||
for i, utterance in enumerate(utterances):
|
||||
sequence.extend(utterance)
|
||||
if i < len(utterances) - 1:
|
||||
sequence.append(DEL_TOK)
|
||||
|
||||
return sequence
|
||||
|
||||
def encode_snippets_with_states(snippets, states):
|
||||
""" Encodes snippets by using previous query states instead.
|
||||
|
||||
Inputs:
|
||||
snippets (list of Snippet): Input snippets.
|
||||
states (list of dy.Expression): Previous hidden states to use.
|
||||
TODO: should this be dy.Expression or vector values?
|
||||
"""
|
||||
for snippet in snippets:
|
||||
snippet.set_embedding(torch.cat([states[snippet.startpos],states[snippet.endpos]], dim=0))
|
||||
return snippets
|
||||
|
||||
def load_word_embeddings(input_vocabulary, output_vocabulary, output_vocabulary_schema, params):
|
||||
print(output_vocabulary.inorder_tokens)
|
||||
print()
|
||||
|
||||
def read_glove_embedding(embedding_filename, embedding_size):
|
||||
glove_embeddings = {}
|
||||
|
||||
with open(embedding_filename) as f:
|
||||
cnt = 1
|
||||
for line in f:
|
||||
cnt += 1
|
||||
if params.debug or not params.train:
|
||||
if cnt == 1000:
|
||||
print('Read 1000 word embeddings')
|
||||
break
|
||||
l_split = line.split()
|
||||
word = " ".join(l_split[0:len(l_split) - embedding_size])
|
||||
embedding = np.array([float(val) for val in l_split[-embedding_size:]])
|
||||
glove_embeddings[word] = embedding
|
||||
|
||||
return glove_embeddings
|
||||
|
||||
print('Loading Glove Embedding from', params.embedding_filename)
|
||||
glove_embedding_size = 300
|
||||
# glove_embeddings = read_glove_embedding(params.embedding_filename, glove_embedding_size)
|
||||
import pickle as pkl
|
||||
# pkl.dump(glove_embeddings, open("glove_embeddings.pkl", "wb"))
|
||||
# exit()
|
||||
glove_embeddings = pkl.load(open("glove_embeddings.pkl", "rb"))
|
||||
print('Done')
|
||||
|
||||
input_embedding_size = glove_embedding_size
|
||||
|
||||
def create_word_embeddings(vocab):
|
||||
vocabulary_embeddings = np.zeros((len(vocab), glove_embedding_size), dtype=np.float32)
|
||||
vocabulary_tokens = vocab.inorder_tokens
|
||||
|
||||
glove_oov = 0
|
||||
para_oov = 0
|
||||
for token in vocabulary_tokens:
|
||||
token_id = vocab.token_to_id(token)
|
||||
if token in glove_embeddings:
|
||||
vocabulary_embeddings[token_id][:glove_embedding_size] = glove_embeddings[token]
|
||||
else:
|
||||
glove_oov += 1
|
||||
|
||||
print('Glove OOV:', glove_oov, 'Para OOV', para_oov, 'Total', len(vocab))
|
||||
|
||||
return vocabulary_embeddings
|
||||
|
||||
input_vocabulary_embeddings = create_word_embeddings(input_vocabulary)
|
||||
output_vocabulary_embeddings = create_word_embeddings(output_vocabulary)
|
||||
output_vocabulary_schema_embeddings = None
|
||||
if output_vocabulary_schema:
|
||||
output_vocabulary_schema_embeddings = create_word_embeddings(output_vocabulary_schema)
|
||||
del glove_embeddings
|
||||
return input_vocabulary_embeddings, output_vocabulary_embeddings, output_vocabulary_schema_embeddings, input_embedding_size
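# Parsing sketch for a single GloVe line as done in read_glove_embedding above
# (toy 3-dimensional vector, not part of the original file): the word is re-joined
# from the leading fields, which guards against keys that themselves contain spaces.
embedding_size = 3
l_split = "new york 0.1 -0.2 0.3".split()
word = " ".join(l_split[0:len(l_split) - embedding_size])
embedding = [float(val) for val in l_split[-embedding_size:]]
assert word == "new york" and embedding == [0.1, -0.2, 0.3]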
|
||||
|
||||
class ATISModel(torch.nn.Module):
|
||||
""" Sequence-to-sequence model for predicting a SQL query given an utterance
|
||||
and an interaction prefix.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
input_vocabulary,
|
||||
output_vocabulary,
|
||||
output_vocabulary_schema,
|
||||
anonymizer):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
|
||||
if params.use_bert:
|
||||
self.model_bert, self.tokenizer, self.bert_config = utils_bert.get_bert(params)
|
||||
|
||||
if 'atis' not in params.data_directory:
|
||||
if params.use_bert:
|
||||
input_vocabulary_embeddings, output_vocabulary_embeddings, output_vocabulary_schema_embeddings, input_embedding_size = load_word_embeddings(input_vocabulary, output_vocabulary, output_vocabulary_schema, params)
|
||||
|
||||
# Create the output embeddings
|
||||
self.output_embedder = Embedder(params.output_embedding_size,
|
||||
name="output-embedding",
|
||||
initializer=output_vocabulary_embeddings,
|
||||
vocabulary=output_vocabulary,
|
||||
anonymizer=anonymizer,
|
||||
freeze=False)
|
||||
self.column_name_token_embedder = None
|
||||
else:
|
||||
input_vocabulary_embeddings, output_vocabulary_embeddings, output_vocabulary_schema_embeddings, input_embedding_size = load_word_embeddings(input_vocabulary, output_vocabulary, output_vocabulary_schema, params)
|
||||
|
||||
params.input_embedding_size = input_embedding_size
|
||||
self.params.input_embedding_size = input_embedding_size
|
||||
|
||||
# Create the input embeddings
|
||||
self.input_embedder = Embedder(params.input_embedding_size,
|
||||
name="input-embedding",
|
||||
initializer=input_vocabulary_embeddings,
|
||||
vocabulary=input_vocabulary,
|
||||
anonymizer=anonymizer,
|
||||
freeze=params.freeze)
|
||||
|
||||
# Create the output embeddings
|
||||
self.output_embedder = Embedder(params.output_embedding_size,
|
||||
name="output-embedding",
|
||||
initializer=output_vocabulary_embeddings,
|
||||
vocabulary=output_vocabulary,
|
||||
anonymizer=anonymizer,
|
||||
freeze=False)
|
||||
|
||||
self.column_name_token_embedder = Embedder(params.input_embedding_size,
|
||||
name="schema-embedding",
|
||||
initializer=output_vocabulary_schema_embeddings,
|
||||
vocabulary=output_vocabulary_schema,
|
||||
anonymizer=anonymizer,
|
||||
freeze=params.freeze)
|
||||
else:
|
||||
# Create the input embeddings
|
||||
self.input_embedder = Embedder(params.input_embedding_size,
|
||||
name="input-embedding",
|
||||
vocabulary=input_vocabulary,
|
||||
anonymizer=anonymizer,
|
||||
freeze=False)
|
||||
|
||||
# Create the output embeddings
|
||||
self.output_embedder = Embedder(params.output_embedding_size,
|
||||
name="output-embedding",
|
||||
vocabulary=output_vocabulary,
|
||||
anonymizer=anonymizer,
|
||||
freeze=False)
|
||||
|
||||
self.column_name_token_embedder = None
|
||||
|
||||
# Create the encoder
|
||||
encoder_input_size = params.input_embedding_size
|
||||
encoder_output_size = params.encoder_state_size
|
||||
if params.use_bert:
|
||||
encoder_input_size = self.bert_config.hidden_size
|
||||
|
||||
if params.discourse_level_lstm:
|
||||
encoder_input_size += params.encoder_state_size / 2
|
||||
|
||||
self.utterance_encoder = Encoder(params.encoder_num_layers, encoder_input_size, encoder_output_size)
|
||||
|
||||
# Positional embedder for utterances
|
||||
attention_key_size = params.encoder_state_size
|
||||
self.schema_attention_key_size = attention_key_size
|
||||
if params.state_positional_embeddings:
|
||||
attention_key_size += params.positional_embedding_size
|
||||
self.positional_embedder = Embedder(
|
||||
params.positional_embedding_size,
|
||||
name="positional-embedding",
|
||||
num_tokens=params.maximum_utterances)
|
||||
|
||||
self.utterance_attention_key_size = attention_key_size
|
||||
|
||||
|
||||
# Create the discourse-level LSTM parameters
|
||||
if params.discourse_level_lstm:
|
||||
self.discourse_lstms = torch_utils.create_multilayer_lstm_params(1, params.encoder_state_size, params.encoder_state_size / 2, "LSTM-t")
|
||||
self.initial_discourse_state = torch_utils.add_params(tuple([params.encoder_state_size / 2]), "V-turn-state-0")
|
||||
|
||||
# Snippet encoder
|
||||
final_snippet_size = 0
|
||||
if params.use_snippets and not params.previous_decoder_snippet_encoding:
|
||||
snippet_encoding_size = int(params.encoder_state_size / 2)
|
||||
final_snippet_size = params.encoder_state_size
|
||||
if params.snippet_age_embedding:
|
||||
snippet_encoding_size -= int(
|
||||
params.snippet_age_embedding_size / 4)
|
||||
self.snippet_age_embedder = Embedder(
|
||||
params.snippet_age_embedding_size,
|
||||
name="snippet-age-embedding",
|
||||
num_tokens=params.max_snippet_age_embedding)
|
||||
final_snippet_size = params.encoder_state_size + params.snippet_age_embedding_size / 2
|
||||
|
||||
|
||||
self.snippet_encoder = Encoder(params.snippet_num_layers,
|
||||
params.output_embedding_size,
|
||||
snippet_encoding_size)
|
||||
|
||||
# Previous query Encoder
|
||||
if params.use_previous_query:
|
||||
self.query_encoder = Encoder(params.encoder_num_layers, params.output_embedding_size, params.encoder_state_size)
|
||||
|
||||
self.final_snippet_size = final_snippet_size
|
||||
self.dropout = 0.
|
||||
|
||||
def _encode_snippets(self, previous_query, snippets, input_schema):
|
||||
""" Computes a single vector representation for each snippet.
|
||||
|
||||
Inputs:
|
||||
previous_query (list of str): Previous query in the interaction.
|
||||
snippets (list of Snippet): Snippets extracted from the previous query.
|
||||
|
||||
Returns:
|
||||
list of Snippets, where the embedding is set to a vector.
|
||||
"""
|
||||
startpoints = [snippet.startpos for snippet in snippets]
|
||||
endpoints = [snippet.endpos for snippet in snippets]
|
||||
assert len(startpoints) == 0 or min(startpoints) >= 0
|
||||
if input_schema:
|
||||
assert len(endpoints) == 0 or max(endpoints) <= len(previous_query)
|
||||
else:
|
||||
assert len(endpoints) == 0 or max(endpoints) < len(previous_query)
|
||||
|
||||
snippet_embedder = lambda query_token: self.get_query_token_embedding(query_token, input_schema)
|
||||
if previous_query and snippets:
|
||||
_, previous_outputs = self.snippet_encoder(
|
||||
previous_query, snippet_embedder, dropout_amount=self.dropout)
|
||||
assert len(previous_outputs) == len(previous_query)
|
||||
|
||||
for snippet in snippets:
|
||||
if input_schema:
|
||||
embedding = torch.cat([previous_outputs[snippet.startpos],previous_outputs[snippet.endpos-1]], dim=0)
|
||||
else:
|
||||
embedding = torch.cat([previous_outputs[snippet.startpos],previous_outputs[snippet.endpos]], dim=0)
|
||||
if self.params.snippet_age_embedding:
|
||||
embedding = torch.cat([embedding, self.snippet_age_embedder(min(snippet.age, self.params.max_snippet_age_embedding - 1))], dim=0)
|
||||
snippet.set_embedding(embedding)
|
||||
|
||||
return snippets
|
||||
|
||||
def _initialize_discourse_states(self):
|
||||
discourse_state = self.initial_discourse_state
|
||||
|
||||
discourse_lstm_states = []
|
||||
for lstm in self.discourse_lstms:
|
||||
hidden_size = lstm.weight_hh.size()[1]
|
||||
if lstm.weight_hh.is_cuda:
|
||||
h_0 = torch.cuda.FloatTensor(1,hidden_size).fill_(0)
|
||||
c_0 = torch.cuda.FloatTensor(1,hidden_size).fill_(0)
|
||||
else:
|
||||
h_0 = torch.zeros(1,hidden_size)
|
||||
c_0 = torch.zeros(1,hidden_size)
|
||||
discourse_lstm_states.append((h_0, c_0))
|
||||
|
||||
return discourse_state, discourse_lstm_states
|
||||
|
||||
def _add_positional_embeddings(self, hidden_states, utterances, group=False):
|
||||
grouped_states = []
|
||||
|
||||
start_index = 0
|
||||
for utterance in utterances:
|
||||
grouped_states.append(hidden_states[start_index:start_index + len(utterance)])
|
||||
start_index += len(utterance)
|
||||
assert len(hidden_states) == sum([len(seq) for seq in grouped_states]) == sum([len(utterance) for utterance in utterances])
|
||||
|
||||
new_states = []
|
||||
flat_sequence = []
|
||||
|
||||
num_utterances_to_keep = min(self.params.maximum_utterances, len(utterances))
|
||||
for i, (states, utterance) in enumerate(zip(
|
||||
grouped_states[-num_utterances_to_keep:], utterances[-num_utterances_to_keep:])):
|
||||
positional_sequence = []
|
||||
index = num_utterances_to_keep - i - 1
|
||||
|
||||
for state in states:
|
||||
positional_sequence.append(torch.cat([state, self.positional_embedder(index)], dim=0))
|
||||
|
||||
assert len(positional_sequence) == len(utterance), \
|
||||
"Expected utterance and state sequence length to be the same, " \
|
||||
+ "but they were " + str(len(utterance)) \
|
||||
+ " and " + str(len(positional_sequence))
|
||||
|
||||
if group:
|
||||
new_states.append(positional_sequence)
|
||||
else:
|
||||
new_states.extend(positional_sequence)
|
||||
flat_sequence.extend(utterance)
|
||||
|
||||
return new_states, flat_sequence
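# Index sketch (not part of the original file) for the utterance-recency positions
# used above: with maximum_utterances = 3 and 5 utterances in the history, only the
# last 3 are kept, and the positional index counts back from the current utterance
# (0 = current, 2 = oldest kept).
num_utterances_to_keep = 3
kept = list(range(5))[-num_utterances_to_keep:]   # utterances 2, 3, 4
positions = [num_utterances_to_keep - i - 1 for i, _ in enumerate(kept)]
assert positions == [2, 1, 0]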
|
||||
|
||||
def build_optim(self):
|
||||
params_trainer = []
|
||||
params_bert_trainer = []
|
||||
for name, param in self.named_parameters():
|
||||
if param.requires_grad:
|
||||
if 'model_bert' in name:
|
||||
params_bert_trainer.append(param)
|
||||
else:
|
||||
params_trainer.append(param)
|
||||
self.trainer = torch.optim.Adam(params_trainer, lr=self.params.initial_learning_rate)
|
||||
if self.params.fine_tune_bert:
|
||||
self.bert_trainer = torch.optim.Adam(params_bert_trainer, lr=self.params.lr_bert)
|
||||
|
||||
def set_dropout(self, value):
|
||||
""" Sets the dropout to a specified value.
|
||||
|
||||
Inputs:
|
||||
value (float): Value to set dropout to.
|
||||
"""
|
||||
self.dropout = value
|
||||
|
||||
def set_learning_rate(self, value):
|
||||
""" Sets the learning rate for the trainer.
|
||||
|
||||
Inputs:
|
||||
value (float): The new learning rate.
|
||||
"""
|
||||
for param_group in self.trainer.param_groups:
|
||||
param_group['lr'] = value
|
||||
|
||||
def save(self, filename):
|
||||
""" Saves the model to the specified filename.
|
||||
|
||||
Inputs:
|
||||
filename (str): The filename to save to.
|
||||
"""
|
||||
torch.save(self.state_dict(), filename)
|
||||
|
||||
def load(self, filename):
|
||||
""" Loads saved parameters into the parameter collection.
|
||||
|
||||
Inputs:
|
||||
filename (str): Name of file containing parameters.
|
||||
"""
|
||||
self.load_state_dict(torch.load(filename))
|
||||
print("Loaded model from file " + filename)
|
||||
|
|
@@ -0,0 +1,740 @@
|
|||
""" Class for the Sequence to sequence model for ATIS."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from . import torch_utils
|
||||
|
||||
import data_util.snippets as snippet_handler
|
||||
import data_util.sql_util
|
||||
import data_util.vocabulary as vocab
|
||||
from data_util.vocabulary import EOS_TOK, UNK_TOK
|
||||
import data_util.tokenizers
|
||||
|
||||
from .token_predictor import construct_token_predictor
|
||||
from .attention import Attention
|
||||
from .model import ATISModel, encode_snippets_with_states, get_token_indices
|
||||
from data_util.utterance import ANON_INPUT_KEY
|
||||
|
||||
from .encoder import Encoder
|
||||
from .transformer_layer.transformer_attention import TransformerAttention
|
||||
|
||||
from .decoder import SequencePredictorWithSchema
|
||||
|
||||
from . import utils_bert
|
||||
|
||||
import data_util.atis_batch
|
||||
from postprocess_eval import read_schema, postprocess_one
|
||||
from eval_scripts.process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
|
||||
|
||||
|
||||
LIMITED_INTERACTIONS = {"raw/atis2/12-1.1/ATIS2/TEXT/TRAIN/SRI/QS0/1": 22,
|
||||
"raw/atis3/17-1.1/ATIS3/SP_TRN/MIT/8K7/5": 14,
|
||||
"raw/atis2/12-1.1/ATIS2/TEXT/TEST/NOV92/770/5": -1}
|
||||
|
||||
END_OF_INTERACTION = {"quit", "exit", "done"}
|
||||
|
||||
|
||||
class SchemaInteractionATISModel(ATISModel):
|
||||
""" Interaction ATIS model, where an interaction is processed all at once.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
params,
|
||||
input_vocabulary,
|
||||
output_vocabulary,
|
||||
output_vocabulary_schema,
|
||||
anonymizer):
|
||||
ATISModel.__init__(
|
||||
self,
|
||||
params,
|
||||
input_vocabulary,
|
||||
output_vocabulary,
|
||||
output_vocabulary_schema,
|
||||
anonymizer)
|
||||
|
||||
if self.params.use_schema_encoder:
|
||||
# Create the schema encoder
|
||||
schema_encoder_num_layer = 1
|
||||
schema_encoder_input_size = params.input_embedding_size
|
||||
schema_encoder_state_size = params.encoder_state_size
|
||||
if params.use_bert:
|
||||
schema_encoder_input_size = self.bert_config.hidden_size
|
||||
|
||||
self.schema_encoder = Encoder(schema_encoder_num_layer, schema_encoder_input_size, schema_encoder_state_size)
|
||||
|
||||
# self-attention
|
||||
if self.params.use_schema_self_attention:
|
||||
self.schema2schema_attention_module = Attention(self.schema_attention_key_size, self.schema_attention_key_size, self.schema_attention_key_size)
|
||||
|
||||
# utterance level attention
|
||||
if self.params.use_utterance_attention:
|
||||
self.utterance_attention_module = Attention(self.params.encoder_state_size, self.params.encoder_state_size, self.params.encoder_state_size)
|
||||
|
||||
# Use attention module between input_hidden_states and schema_states
|
||||
# schema_states: self.schema_attention_key_size x len(schema)
|
||||
# input_hidden_states: self.utterance_attention_key_size x len(input)
|
||||
if params.use_encoder_attention:
|
||||
self.utterance2schema_attention_module = Attention(self.schema_attention_key_size, self.utterance_attention_key_size, self.utterance_attention_key_size)
|
||||
self.schema2utterance_attention_module = Attention(self.utterance_attention_key_size, self.schema_attention_key_size, self.schema_attention_key_size)
|
||||
|
||||
new_attention_key_size = self.schema_attention_key_size + self.utterance_attention_key_size
|
||||
|
||||
if self.params.use_transformer_relation:
|
||||
self.transoformer_attention_key_size = self.schema_attention_key_size
|
||||
self.transoformer_attention_module = TransformerAttention(self.transoformer_attention_key_size, self.transoformer_attention_key_size)
|
||||
|
||||
new_attention_key_size += self.transoformer_attention_key_size
|
||||
|
||||
self.schema_attention_key_size = new_attention_key_size
|
||||
self.utterance_attention_key_size = new_attention_key_size
|
||||
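# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# The encoder attention configured above runs in both directions
# (utterance -> schema and schema -> utterance), and the attended vectors are
# later concatenated onto the original states in predict_turn, which is why
# the attention key sizes grow here. The model uses learned Attention (and,
# optionally, relation-aware transformer attention) modules; this sketch shows
# the same shape bookkeeping with plain dot-product attention on toy tensors:
import torch
import torch.nn.functional as F

_d = 4                                                   # toy state size
_utt = torch.randn(_d, 6)                                # d x len(input)
_sch = torch.randn(_d, 3)                                # d x len(schema)

_u2s = _utt @ F.softmax(_utt.t() @ _sch, dim=0)          # d x len(schema)
_s2u = _sch @ F.softmax(_sch.t() @ _utt, dim=0)          # d x len(input)

_new_schema_states = torch.cat([_sch, _u2s], dim=0)      # 2d x len(schema)
_new_utterance_states = torch.cat([_utt, _s2u], dim=0)   # 2d x len(input)
# ---------------------------------------------------------------------------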
|
||||
if self.params.use_schema_encoder_2:
|
||||
self.schema_encoder_2 = Encoder(schema_encoder_num_layer, self.schema_attention_key_size, self.schema_attention_key_size)
|
||||
self.utterance_encoder_2 = Encoder(params.encoder_num_layers, self.utterance_attention_key_size, self.utterance_attention_key_size)
|
||||
|
||||
|
||||
self.token_predictor = construct_token_predictor(params,
|
||||
output_vocabulary,
|
||||
self.utterance_attention_key_size,
|
||||
self.schema_attention_key_size,
|
||||
self.final_snippet_size,
|
||||
anonymizer)
|
||||
|
||||
# Use schema_attention in decoder
|
||||
if params.use_schema_attention and params.use_query_attention:
|
||||
decoder_input_size = params.output_embedding_size + self.utterance_attention_key_size + self.schema_attention_key_size + params.encoder_state_size
|
||||
elif params.use_schema_attention:
|
||||
decoder_input_size = params.output_embedding_size + self.utterance_attention_key_size + self.schema_attention_key_size
|
||||
else:
|
||||
decoder_input_size = params.output_embedding_size + self.utterance_attention_key_size
|
||||
|
||||
self.decoder = SequencePredictorWithSchema(params, decoder_input_size, self.output_embedder, self.column_name_token_embedder, self.token_predictor)
|
||||
|
||||
table_schema_path = params.table_schema_path
|
||||
self.database_schema = read_schema(table_schema_path)
|
||||
|
||||
def predict_turn(self,
|
||||
utterance_final_state,
|
||||
input_hidden_states,
|
||||
schema_states,
|
||||
max_generation_length,
|
||||
gold_query=None,
|
||||
snippets=None,
|
||||
input_sequence=None,
|
||||
previous_queries=None,
|
||||
previous_query_states=None,
|
||||
input_schema=None,
|
||||
feed_gold_tokens=False,
|
||||
training=False,
|
||||
unflat_sequence=None):
|
||||
""" Gets a prediction for a single turn -- calls decoder and updates loss, etc.
|
||||
|
||||
TODO: this can probably be split into two methods, one that just predicts
|
||||
and another that computes the loss.
|
||||
"""
|
||||
predicted_sequence = []
|
||||
fed_sequence = []
|
||||
loss = None
|
||||
token_accuracy = 0.
|
||||
|
||||
if self.params.use_encoder_attention:
|
||||
schema_attention = self.utterance2schema_attention_module(torch.stack(schema_states,dim=0), input_hidden_states).vector # input_value_size x len(schema)
|
||||
utterance_attention = self.schema2utterance_attention_module(torch.stack(input_hidden_states,dim=0), schema_states).vector # schema_value_size x len(input)
|
||||
|
||||
if schema_attention.dim() == 1:
|
||||
schema_attention = schema_attention.unsqueeze(1)
|
||||
if utterance_attention.dim() == 1:
|
||||
utterance_attention = utterance_attention.unsqueeze(1)
|
||||
|
||||
if self.params.use_transformer_relation:
|
||||
utterance_transoformer_attention, schema_transoformer_attention = self.transoformer_attention_module(torch.stack(input_hidden_states,dim=0 ), torch.stack(schema_states,dim=0), input_sequence, unflat_sequence, input_schema, self.schema, dropout_amount=self.dropout)
|
||||
schema_attention = torch.cat([schema_transoformer_attention, schema_attention], dim=0)
|
||||
utterance_attention = torch.cat([utterance_transoformer_attention, utterance_attention], dim=0)
|
||||
|
||||
new_schema_states = torch.cat([torch.stack(schema_states, dim=1), schema_attention], dim=0) # (input_value_size+schema_value_size) x len(schema)
|
||||
schema_states = list(torch.split(new_schema_states, split_size_or_sections=1, dim=1))
|
||||
schema_states = [schema_state.squeeze() for schema_state in schema_states]
|
||||
|
||||
new_input_hidden_states = torch.cat([torch.stack(input_hidden_states, dim=1), utterance_attention], dim=0) # (input_value_size+schema_value_size) x len(input)
|
||||
input_hidden_states = list(torch.split(new_input_hidden_states, split_size_or_sections=1, dim=1))
|
||||
input_hidden_states = [input_hidden_state.squeeze() for input_hidden_state in input_hidden_states]
|
||||
|
||||
# bi-lstm over schema_states and input_hidden_states (embedder is an identity function)
|
||||
if self.params.use_schema_encoder_2:
|
||||
final_schema_state, schema_states = self.schema_encoder_2(schema_states, lambda x: x, dropout_amount=self.dropout)
|
||||
final_utterance_state, input_hidden_states = self.utterance_encoder_2(input_hidden_states, lambda x: x, dropout_amount=self.dropout)
|
||||
|
||||
if feed_gold_tokens:
|
||||
decoder_results = self.decoder(utterance_final_state,
|
||||
input_hidden_states,
|
||||
schema_states,
|
||||
max_generation_length,
|
||||
gold_sequence=gold_query,
|
||||
input_sequence=input_sequence,
|
||||
previous_queries=previous_queries,
|
||||
previous_query_states=previous_query_states,
|
||||
input_schema=input_schema,
|
||||
snippets=snippets,
|
||||
dropout_amount=self.dropout)
|
||||
|
||||
all_scores = []
|
||||
all_alignments = []
|
||||
for prediction in decoder_results.predictions:
|
||||
scores = F.softmax(prediction.scores, dim=0)
|
||||
alignments = prediction.aligned_tokens
|
||||
if self.params.use_previous_query and self.params.use_copy_switch and len(previous_queries) > 0:
|
||||
query_scores = F.softmax(prediction.query_scores, dim=0)
|
||||
copy_switch = prediction.copy_switch
|
||||
scores = torch.cat([scores * (1 - copy_switch), query_scores * copy_switch], dim=0)
|
||||
alignments = alignments + prediction.query_tokens
|
||||
|
||||
all_scores.append(scores)
|
||||
all_alignments.append(alignments)
|
||||
|
||||
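# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# The copy switch used above mixes the generation distribution with a
# distribution over previous-query tokens:
#     p(tok) = (1 - switch) * p_generate(tok)   (+)   switch * p_copy(tok)
# so the concatenated scores still sum to one. Toy example:
import torch

_p_gen = torch.tensor([0.6, 0.3, 0.1])    # over vocabulary/schema candidates
_p_copy = torch.tensor([0.8, 0.2])        # over previous-query tokens
_switch = torch.tensor(0.25)
_mixed = torch.cat([_p_gen * (1 - _switch), _p_copy * _switch], dim=0)
assert torch.isclose(_mixed.sum(), torch.tensor(1.0))
# ---------------------------------------------------------------------------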
# Compute the loss
|
||||
gold_sequence = gold_query
|
||||
|
||||
loss = torch_utils.compute_loss(gold_sequence, all_scores, all_alignments, get_token_indices)
|
||||
if not training:
|
||||
predicted_sequence = torch_utils.get_seq_from_scores(all_scores, all_alignments)
|
||||
token_accuracy = torch_utils.per_token_accuracy(gold_sequence, predicted_sequence)
|
||||
fed_sequence = gold_sequence
|
||||
else:
|
||||
decoder_results = self.decoder(utterance_final_state,
|
||||
input_hidden_states,
|
||||
schema_states,
|
||||
max_generation_length,
|
||||
input_sequence=input_sequence,
|
||||
previous_queries=previous_queries,
|
||||
previous_query_states=previous_query_states,
|
||||
input_schema=input_schema,
|
||||
snippets=snippets,
|
||||
dropout_amount=self.dropout)
|
||||
predicted_sequence = decoder_results.sequence
|
||||
fed_sequence = predicted_sequence
|
||||
|
||||
decoder_states = [pred.decoder_state for pred in decoder_results.predictions]
|
||||
|
||||
# fed_sequence contains EOS, which we don't need when encoding snippets.
|
||||
# also ignore the first state, as it contains the BEG encoding.
|
||||
|
||||
for token, state in zip(fed_sequence[:-1], decoder_states[1:]):
|
||||
if snippet_handler.is_snippet(token):
|
||||
snippet_length = 0
|
||||
for snippet in snippets:
|
||||
if snippet.name == token:
|
||||
snippet_length = len(snippet.sequence)
|
||||
break
|
||||
assert snippet_length > 0
|
||||
decoder_states.extend([state for _ in range(snippet_length)])
|
||||
else:
|
||||
decoder_states.append(state)
|
||||
|
||||
return (predicted_sequence,
|
||||
loss,
|
||||
token_accuracy,
|
||||
decoder_states,
|
||||
decoder_results)
|
||||
|
||||
def encode_schema_bow_simple(self, input_schema):
|
||||
schema_states = []
|
||||
for column_name in input_schema.column_names_embedder_input:
|
||||
schema_states.append(input_schema.column_name_embedder_bow(column_name, surface_form=False, column_name_token_embedder=self.column_name_token_embedder))
|
||||
input_schema.set_column_name_embeddings(schema_states)
|
||||
return schema_states
|
||||
|
||||
def encode_schema_self_attention(self, schema_states):
|
||||
schema_self_attention = self.schema2schema_attention_module(torch.stack(schema_states,dim=0), schema_states).vector
|
||||
if schema_self_attention.dim() == 1:
|
||||
schema_self_attention = schema_self_attention.unsqueeze(1)
|
||||
residual_schema_states = list(torch.split(schema_self_attention, split_size_or_sections=1, dim=1))
|
||||
residual_schema_states = [schema_state.squeeze() for schema_state in residual_schema_states]
|
||||
|
||||
new_schema_states = [schema_state+residual_schema_state for schema_state, residual_schema_state in zip(schema_states, residual_schema_states)]
|
||||
|
||||
return new_schema_states
|
||||
|
||||
def encode_schema(self, input_schema, dropout=False):
|
||||
schema_states = []
|
||||
for column_name_embedder_input in input_schema.column_names_embedder_input:
|
||||
tokens = column_name_embedder_input.split()
|
||||
|
||||
if dropout:
|
||||
final_schema_state_one, schema_states_one = self.schema_encoder(tokens, self.column_name_token_embedder, dropout_amount=self.dropout)
|
||||
else:
|
||||
final_schema_state_one, schema_states_one = self.schema_encoder(tokens, self.column_name_token_embedder)
|
||||
|
||||
# final_schema_state_one: 1 means hidden_states instead of cell_memories, -1 means last layer
|
||||
schema_states.append(final_schema_state_one[1][-1])
|
||||
|
||||
input_schema.set_column_name_embeddings(schema_states)
|
||||
|
||||
# self-attention over schema_states
|
||||
if self.params.use_schema_self_attention:
|
||||
schema_states = self.encode_schema_self_attention(schema_states)
|
||||
|
||||
return schema_states
|
||||
|
||||
def get_bert_encoding(self, input_sequence, input_schema, discourse_state, dropout):
|
||||
utterance_states, schema_token_states = utils_bert.get_bert_encoding(self.bert_config, self.model_bert, self.tokenizer, input_sequence, input_schema, bert_input_version=self.params.bert_input_version, num_out_layers_n=1, num_out_layers_h=1)
|
||||
|
||||
if self.params.discourse_level_lstm:
|
||||
utterance_token_embedder = lambda x: torch.cat([x, discourse_state], dim=0)
|
||||
else:
|
||||
utterance_token_embedder = lambda x: x
|
||||
|
||||
if dropout:
|
||||
final_utterance_state, utterance_states = self.utterance_encoder(
|
||||
utterance_states,
|
||||
utterance_token_embedder,
|
||||
dropout_amount=self.dropout)
|
||||
else:
|
||||
final_utterance_state, utterance_states = self.utterance_encoder(
|
||||
utterance_states,
|
||||
utterance_token_embedder)
|
||||
|
||||
schema_states = []
|
||||
for schema_token_states1 in schema_token_states:
|
||||
if dropout:
|
||||
final_schema_state_one, schema_states_one = self.schema_encoder(schema_token_states1, lambda x: x, dropout_amount=self.dropout)
|
||||
else:
|
||||
final_schema_state_one, schema_states_one = self.schema_encoder(schema_token_states1, lambda x: x)
|
||||
|
||||
# final_schema_state_one: 1 means hidden_states instead of cell_memories, -1 means last layer
|
||||
schema_states.append(final_schema_state_one[1][-1])
|
||||
|
||||
input_schema.set_column_name_embeddings(schema_states)
|
||||
|
||||
# self-attention over schema_states
|
||||
if self.params.use_schema_self_attention:
|
||||
schema_states = self.encode_schema_self_attention(schema_states)
|
||||
|
||||
return final_utterance_state, utterance_states, schema_states
|
||||
|
||||
def get_query_token_embedding(self, output_token, input_schema):
|
||||
if input_schema:
|
||||
if not (self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)):
|
||||
output_token = 'value'
|
||||
if self.output_embedder.in_vocabulary(output_token):
|
||||
output_token_embedding = self.output_embedder(output_token)
|
||||
else:
|
||||
output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True)
|
||||
else:
|
||||
output_token_embedding = self.output_embedder(output_token)
|
||||
return output_token_embedding
|
||||
|
||||
def get_utterance_attention(self, final_utterance_states_c, final_utterance_states_h, final_utterance_state, num_utterances_to_keep):
|
||||
# self-attention between utterance_states
|
||||
final_utterance_states_c.append(final_utterance_state[0][0])
|
||||
final_utterance_states_h.append(final_utterance_state[1][0])
|
||||
final_utterance_states_c = final_utterance_states_c[-num_utterances_to_keep:]
|
||||
final_utterance_states_h = final_utterance_states_h[-num_utterances_to_keep:]
|
||||
|
||||
attention_result = self.utterance_attention_module(final_utterance_states_c[-1], final_utterance_states_c)
|
||||
final_utterance_state_attention_c = final_utterance_states_c[-1] + attention_result.vector.squeeze()
|
||||
|
||||
attention_result = self.utterance_attention_module(final_utterance_states_h[-1], final_utterance_states_h)
|
||||
final_utterance_state_attention_h = final_utterance_states_h[-1] + attention_result.vector.squeeze()
|
||||
|
||||
final_utterance_state = ([final_utterance_state_attention_c],[final_utterance_state_attention_h])
|
||||
|
||||
return final_utterance_states_c, final_utterance_states_h, final_utterance_state
|
||||
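# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# get_utterance_attention() above lets the newest final utterance state attend
# over the last few kept turns and adds the attended vector back as a residual.
# The same idea with plain dot-product attention on toy tensors (the model
# uses a learned Attention module instead):
import torch
import torch.nn.functional as F

_turn_states = [torch.randn(5) for _ in range(3)]   # kept per-turn final states
_query = _turn_states[-1]                           # newest turn attends over history
_keys = torch.stack(_turn_states, dim=1)            # 5 x num_turns
_weights = F.softmax(_keys.t() @ _query, dim=0)     # num_turns
_new_final_state = _query + _keys @ _weights        # residual update
# ---------------------------------------------------------------------------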
|
||||
def get_previous_queries(self, previous_queries, previous_query_states, previous_query, input_schema):
|
||||
previous_queries.append(previous_query)
|
||||
num_queries_to_keep = min(self.params.maximum_queries, len(previous_queries))
|
||||
previous_queries = previous_queries[-num_queries_to_keep:]
|
||||
|
||||
query_token_embedder = lambda query_token: self.get_query_token_embedding(query_token, input_schema)
|
||||
_, previous_outputs = self.query_encoder(previous_query, query_token_embedder, dropout_amount=self.dropout)
|
||||
assert len(previous_outputs) == len(previous_query)
|
||||
previous_query_states.append(previous_outputs)
|
||||
previous_query_states = previous_query_states[-num_queries_to_keep:]
|
||||
|
||||
return previous_queries, previous_query_states
|
||||
|
||||
def train_step(self, interaction, max_generation_length, snippet_alignment_probability=1.):
|
||||
""" Trains the interaction-level model on a single interaction.
|
||||
|
||||
Inputs:
|
||||
interaction (Interaction): The interaction to train on.
|
||||
max_generation_length (int): Maximum length of the query to generate.
|
||||
snippet_alignment_probability (float): The probability that a snippet will
|
||||
be used in constructing the gold sequence.
|
||||
"""
|
||||
# assert self.params.discourse_level_lstm
|
||||
|
||||
losses = []
|
||||
total_gold_tokens = 0
|
||||
|
||||
input_hidden_states = []
|
||||
input_sequences = []
|
||||
|
||||
final_utterance_states_c = []
|
||||
final_utterance_states_h = []
|
||||
|
||||
previous_query_states = []
|
||||
previous_queries = []
|
||||
|
||||
decoder_states = []
|
||||
|
||||
discourse_state = None
|
||||
if self.params.discourse_level_lstm:
|
||||
discourse_state, discourse_lstm_states = self._initialize_discourse_states()
|
||||
discourse_states = []
|
||||
|
||||
# Schema and schema embeddings
|
||||
input_schema = interaction.get_schema()
|
||||
schema_states = []
|
||||
|
||||
if input_schema and not self.params.use_bert:
|
||||
schema_states = self.encode_schema_bow_simple(input_schema)
|
||||
|
||||
identifier = interaction.identifier
|
||||
if len(identifier.split('/')) == 2:
|
||||
database_id, interaction_id = identifier.split('/')
|
||||
else:
|
||||
database_id = 'atis'
|
||||
interaction_id = identifier
|
||||
|
||||
self.schema = self.database_schema[database_id]
|
||||
|
||||
for utterance_index, utterance in enumerate(interaction.gold_utterances()):
|
||||
if interaction.identifier in LIMITED_INTERACTIONS and utterance_index > LIMITED_INTERACTIONS[interaction.identifier]:
|
||||
break
|
||||
|
||||
input_sequence = utterance.input_sequence()
|
||||
|
||||
available_snippets = utterance.snippets()
|
||||
previous_query = utterance.previous_query()
|
||||
|
||||
# Get the gold query: reconstruct if the alignment probability is less than one
|
||||
if snippet_alignment_probability < 1.:
|
||||
gold_query = sql_util.add_snippets_to_query(
|
||||
available_snippets,
|
||||
utterance.contained_entities(),
|
||||
utterance.anonymized_gold_query(),
|
||||
prob_align=snippet_alignment_probability) + [vocab.EOS_TOK]
|
||||
else:
|
||||
gold_query = utterance.gold_query()
|
||||
|
||||
# Encode the utterance, and update the discourse-level states
|
||||
if not self.params.use_bert:
|
||||
if self.params.discourse_level_lstm:
|
||||
utterance_token_embedder = lambda token: torch.cat([self.input_embedder(token), discourse_state], dim=0)
|
||||
else:
|
||||
utterance_token_embedder = self.input_embedder
|
||||
final_utterance_state, utterance_states = self.utterance_encoder(
|
||||
input_sequence,
|
||||
utterance_token_embedder,
|
||||
dropout_amount=self.dropout)
|
||||
else:
|
||||
final_utterance_state, utterance_states, schema_states = self.get_bert_encoding(input_sequence, input_schema, discourse_state, dropout=True)
|
||||
|
||||
input_hidden_states.extend(utterance_states)
|
||||
input_sequences.append(input_sequence)
|
||||
|
||||
num_utterances_to_keep = min(self.params.maximum_utterances, len(input_sequences))
|
||||
|
||||
# final_utterance_state[1][0] is the first layer's hidden states at the last time step (concat forward lstm and backward lstm)
|
||||
if self.params.discourse_level_lstm:
|
||||
_, discourse_state, discourse_lstm_states = torch_utils.forward_one_multilayer(self.discourse_lstms, final_utterance_state[1][0], discourse_lstm_states, self.dropout)
|
||||
|
||||
if self.params.use_utterance_attention:
|
||||
final_utterance_states_c, final_utterance_states_h, final_utterance_state = self.get_utterance_attention(final_utterance_states_c, final_utterance_states_h, final_utterance_state, num_utterances_to_keep)
|
||||
|
||||
if self.params.state_positional_embeddings:
|
||||
utterance_states, flat_sequence = self._add_positional_embeddings(input_hidden_states, input_sequences)
|
||||
else:
|
||||
flat_sequence = []
|
||||
for utt in input_sequences[-num_utterances_to_keep:]:
|
||||
flat_sequence.extend(utt)
|
||||
|
||||
snippets = None
|
||||
if self.params.use_snippets:
|
||||
if self.params.previous_decoder_snippet_encoding:
|
||||
snippets = encode_snippets_with_states(available_snippets, decoder_states)
|
||||
else:
|
||||
snippets = self._encode_snippets(previous_query, available_snippets, input_schema)
|
||||
|
||||
if self.params.use_previous_query and len(previous_query) > 0:
|
||||
previous_queries, previous_query_states = self.get_previous_queries(previous_queries, previous_query_states, previous_query, input_schema)
|
||||
|
||||
if len(gold_query) <= max_generation_length and len(previous_query) <= max_generation_length:
|
||||
prediction = self.predict_turn(final_utterance_state,
|
||||
utterance_states,
|
||||
schema_states,
|
||||
max_generation_length,
|
||||
gold_query=gold_query,
|
||||
snippets=snippets,
|
||||
input_sequence=flat_sequence,
|
||||
previous_queries=previous_queries,
|
||||
previous_query_states=previous_query_states,
|
||||
input_schema=input_schema,
|
||||
feed_gold_tokens=True,
|
||||
training=True)
|
||||
loss = prediction[1]
|
||||
decoder_states = prediction[3]
|
||||
total_gold_tokens += len(gold_query)
|
||||
losses.append(loss)
|
||||
print('<<', loss.item())
|
||||
else:
|
||||
# Break if previous decoder snippet encoding -- because the previous
|
||||
# sequence was too long to run the decoder.
|
||||
if self.params.previous_decoder_snippet_encoding:
|
||||
break
|
||||
continue
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if losses:
|
||||
average_loss = torch.sum(torch.stack(losses)) / total_gold_tokens
|
||||
if average_loss.item() > 80.0:
|
||||
print("Loss is too large", average_loss.item())
|
||||
loss_scalar = 0.0
|
||||
else:
|
||||
|
||||
# Renormalize so the effect is normalized by the batch size.
|
||||
normalized_loss = average_loss
|
||||
if self.params.reweight_batch:
|
||||
normalized_loss = len(losses) * average_loss / float(self.params.batch_size)
|
||||
|
||||
normalized_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=10, norm_type=2)
|
||||
self.trainer.step()
|
||||
if self.params.fine_tune_bert:
|
||||
self.bert_trainer.step()
|
||||
self.zero_grad()
|
||||
|
||||
loss_scalar = normalized_loss.item()
|
||||
else:
|
||||
loss_scalar = 0.
|
||||
|
||||
return loss_scalar
|
||||
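# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# The tail of train_step() above averages the accumulated turn losses over the
# number of gold tokens, optionally reweights by batch size, clips the global
# gradient norm to 10, and steps the optimizers. The same bookkeeping on a toy
# model (learning rate, token count, and batch size are made up):
import torch

_model = torch.nn.Linear(4, 1)
_opt = torch.optim.Adam(_model.parameters(), lr=1e-3)

_losses = [_model(torch.randn(4)).abs().sum() for _ in range(3)]
_total_gold_tokens = 12
_average_loss = torch.sum(torch.stack(_losses)) / _total_gold_tokens
_normalized_loss = len(_losses) * _average_loss / 16.0   # reweight_batch case

_normalized_loss.backward()
torch.nn.utils.clip_grad_norm_(_model.parameters(), max_norm=10, norm_type=2)
_opt.step()
_opt.zero_grad()
# ---------------------------------------------------------------------------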
|
||||
def predict_with_predicted_queries(self, interaction, max_generation_length, syntax_restrict=True):
|
||||
""" Predicts an interaction, using the predicted queries to get snippets."""
|
||||
# assert self.params.discourse_level_lstm
|
||||
|
||||
syntax_restrict=False
|
||||
|
||||
predictions = []
|
||||
|
||||
input_hidden_states = []
|
||||
input_sequences = []
|
||||
|
||||
final_utterance_states_c = []
|
||||
final_utterance_states_h = []
|
||||
|
||||
previous_query_states = []
|
||||
previous_queries = []
|
||||
|
||||
discourse_state = None
|
||||
if self.params.discourse_level_lstm:
|
||||
discourse_state, discourse_lstm_states = self._initialize_discourse_states()
|
||||
discourse_states = []
|
||||
|
||||
# Schema and schema embeddings
|
||||
input_schema = interaction.get_schema()
|
||||
schema_states = []
|
||||
|
||||
if input_schema and not self.params.use_bert:
|
||||
schema_states = self.encode_schema_bow_simple(input_schema)
|
||||
|
||||
interaction.start_interaction()
|
||||
while not interaction.done():
|
||||
utterance = interaction.next_utterance()
|
||||
|
||||
available_snippets = utterance.snippets()
|
||||
previous_query = utterance.previous_query()
|
||||
|
||||
input_sequence = utterance.input_sequence()
|
||||
|
||||
if not self.params.use_bert:
|
||||
if self.params.discourse_level_lstm:
|
||||
utterance_token_embedder = lambda token: torch.cat([self.input_embedder(token), discourse_state], dim=0)
|
||||
else:
|
||||
utterance_token_embedder = self.input_embedder
|
||||
final_utterance_state, utterance_states = self.utterance_encoder(
|
||||
input_sequence,
|
||||
utterance_token_embedder)
|
||||
else:
|
||||
final_utterance_state, utterance_states, schema_states = self.get_bert_encoding(input_sequence, input_schema, discourse_state, dropout=False)
|
||||
|
||||
input_hidden_states.extend(utterance_states)
|
||||
input_sequences.append(input_sequence)
|
||||
|
||||
num_utterances_to_keep = min(self.params.maximum_utterances, len(input_sequences))
|
||||
|
||||
if self.params.discourse_level_lstm:
|
||||
_, discourse_state, discourse_lstm_states = torch_utils.forward_one_multilayer(self.discourse_lstms, final_utterance_state[1][0], discourse_lstm_states)
|
||||
|
||||
if self.params.use_utterance_attention:
|
||||
final_utterance_states_c, final_utterance_states_h, final_utterance_state = self.get_utterance_attention(final_utterance_states_c, final_utterance_states_h, final_utterance_state, num_utterances_to_keep)
|
||||
|
||||
if self.params.state_positional_embeddings:
|
||||
utterance_states, flat_sequence = self._add_positional_embeddings(input_hidden_states, input_sequences)
|
||||
else:
|
||||
flat_sequence = []
|
||||
for utt in input_sequences[-num_utterances_to_keep:]:
|
||||
flat_sequence.extend(utt)
|
||||
|
||||
snippets = None
|
||||
if self.params.use_snippets:
|
||||
snippets = self._encode_snippets(previous_query, available_snippets, input_schema)
|
||||
|
||||
if self.params.use_previous_query and len(previous_query) > 0:
|
||||
previous_queries, previous_query_states = self.get_previous_queries(previous_queries, previous_query_states, previous_query, input_schema)
|
||||
|
||||
identifier = interaction.identifier
|
||||
if len(identifier.split('/')) == 2:
|
||||
database_id, interaction_id = identifier.split('/')
|
||||
else:
|
||||
database_id = 'atis'
|
||||
interaction_id = identifier
|
||||
|
||||
self.schema = self.database_schema[database_id]
|
||||
|
||||
results = self.predict_turn(final_utterance_state,
|
||||
utterance_states,
|
||||
schema_states,
|
||||
max_generation_length,
|
||||
input_sequence=flat_sequence,
|
||||
previous_queries=previous_queries,
|
||||
previous_query_states=previous_query_states,
|
||||
input_schema=input_schema,
|
||||
snippets=snippets)
|
||||
|
||||
predicted_sequence = results[0]
|
||||
predictions.append(results)
|
||||
|
||||
# Update things necessary for using predicted queries
|
||||
anonymized_sequence = utterance.remove_snippets(predicted_sequence)
|
||||
if EOS_TOK in anonymized_sequence:
|
||||
anonymized_sequence = anonymized_sequence[:-1] # Remove _EOS
|
||||
else:
|
||||
anonymized_sequence = ['select', '*', 'from', 't1']
|
||||
|
||||
if not syntax_restrict:
|
||||
utterance.set_predicted_query(interaction.remove_snippets(predicted_sequence))
|
||||
if input_schema:
|
||||
# on SParC
|
||||
interaction.add_utterance(utterance, anonymized_sequence, previous_snippets=utterance.snippets(), simple=True)
|
||||
else:
|
||||
# on ATIS
|
||||
interaction.add_utterance(utterance, anonymized_sequence, previous_snippets=utterance.snippets(), simple=False)
|
||||
else:
|
||||
utterance.set_predicted_query(utterance.previous_query())
|
||||
interaction.add_utterance(utterance, utterance.previous_query(), previous_snippets=utterance.snippets())
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
def predict_with_gold_queries(self, interaction, max_generation_length, feed_gold_query=False):
|
||||
""" Predicts SQL queries for an interaction.
|
||||
|
||||
Inputs:
|
||||
interaction (Interaction): Interaction to predict for.
|
||||
feed_gold_query (bool): Whether or not to feed the gold token to the
|
||||
generation step.
|
||||
"""
|
||||
# assert self.params.discourse_level_lstm
|
||||
|
||||
predictions = []
|
||||
|
||||
input_hidden_states = []
|
||||
input_sequences = []
|
||||
|
||||
final_utterance_states_c = []
|
||||
final_utterance_states_h = []
|
||||
|
||||
previous_query_states = []
|
||||
previous_queries = []
|
||||
|
||||
decoder_states = []
|
||||
|
||||
discourse_state = None
|
||||
if self.params.discourse_level_lstm:
|
||||
discourse_state, discourse_lstm_states = self._initialize_discourse_states()
|
||||
discourse_states = []
|
||||
|
||||
# Schema and schema embeddings
|
||||
input_schema = interaction.get_schema()
|
||||
schema_states = []
|
||||
if input_schema and not self.params.use_bert:
|
||||
schema_states = self.encode_schema_bow_simple(input_schema)
|
||||
|
||||
|
||||
for utterance in interaction.gold_utterances():
|
||||
input_sequence = utterance.input_sequence()
|
||||
|
||||
available_snippets = utterance.snippets()
|
||||
previous_query = utterance.previous_query()
|
||||
|
||||
# Encode the utterance, and update the discourse-level states
|
||||
if not self.params.use_bert:
|
||||
if self.params.discourse_level_lstm:
|
||||
utterance_token_embedder = lambda token: torch.cat([self.input_embedder(token), discourse_state], dim=0)
|
||||
else:
|
||||
utterance_token_embedder = self.input_embedder
|
||||
final_utterance_state, utterance_states = self.utterance_encoder(
|
||||
input_sequence,
|
||||
utterance_token_embedder,
|
||||
dropout_amount=self.dropout)
|
||||
else:
|
||||
final_utterance_state, utterance_states, schema_states = self.get_bert_encoding(input_sequence, input_schema, discourse_state, dropout=True)
|
||||
|
||||
input_hidden_states.extend(utterance_states)
|
||||
input_sequences.append(input_sequence)
|
||||
|
||||
num_utterances_to_keep = min(self.params.maximum_utterances, len(input_sequences))
|
||||
|
||||
if self.params.discourse_level_lstm:
|
||||
_, discourse_state, discourse_lstm_states = torch_utils.forward_one_multilayer(self.discourse_lstms, final_utterance_state[1][0], discourse_lstm_states, self.dropout)
|
||||
|
||||
if self.params.use_utterance_attention:
|
||||
final_utterance_states_c, final_utterance_states_h, final_utterance_state = self.get_utterance_attention(final_utterance_states_c, final_utterance_states_h, final_utterance_state, num_utterances_to_keep)
|
||||
|
||||
if self.params.state_positional_embeddings:
|
||||
utterance_states, flat_sequence = self._add_positional_embeddings(input_hidden_states, input_sequences)
|
||||
else:
|
||||
flat_sequence = []
|
||||
for utt in input_sequences[-num_utterances_to_keep:]:
|
||||
flat_sequence.extend(utt)
|
||||
|
||||
snippets = None
|
||||
if self.params.use_snippets:
|
||||
if self.params.previous_decoder_snippet_encoding:
|
||||
snippets = encode_snippets_with_states(available_snippets, decoder_states)
|
||||
else:
|
||||
snippets = self._encode_snippets(previous_query, available_snippets, input_schema)
|
||||
|
||||
if self.params.use_previous_query and len(previous_query) > 0:
|
||||
previous_queries, previous_query_states = self.get_previous_queries(previous_queries, previous_query_states, previous_query, input_schema)
|
||||
|
||||
identifier = interaction.identifier
|
||||
if len(identifier.split('/')) == 2:
|
||||
database_id, interaction_id = identifier.split('/')
|
||||
else:
|
||||
database_id = 'atis'
|
||||
interaction_id = identifier
|
||||
|
||||
self.schema = self.database_schema[database_id]
|
||||
|
||||
prediction = self.predict_turn(final_utterance_state,
|
||||
utterance_states,
|
||||
schema_states,
|
||||
max_generation_length,
|
||||
gold_query=utterance.gold_query(),
|
||||
snippets=snippets,
|
||||
input_sequence=flat_sequence,
|
||||
previous_queries=previous_queries,
|
||||
previous_query_states=previous_query_states,
|
||||
input_schema=input_schema,
|
||||
feed_gold_tokens=feed_gold_query)
|
||||
|
||||
decoder_states = prediction[3]
|
||||
predictions.append(prediction)
|
||||
|
||||
return predictions
|
|
@@ -0,0 +1,476 @@
"""Predicts a token."""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from . import torch_utils
|
||||
|
||||
from .attention import Attention, AttentionResult
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
class PredictionInput(namedtuple('PredictionInput',
|
||||
('decoder_state',
|
||||
'input_hidden_states',
|
||||
'snippets',
|
||||
'input_sequence'))):
|
||||
""" Inputs to the token predictor. """
|
||||
__slots__ = ()
|
||||
|
||||
class PredictionStepInputWithSchema(namedtuple('PredictionStepInputWithSchema',
|
||||
('index',
|
||||
'decoder_input',
|
||||
'decoder_states',
|
||||
'encoder_states',
|
||||
'schema_states',
|
||||
'snippets',
|
||||
'gold_sequence',
|
||||
'input_sequence',
|
||||
'previous_queries',
|
||||
'previous_query_states',
|
||||
'input_schema',
|
||||
'dropout_amount',
|
||||
'predictions'))):
|
||||
""" Inputs to the next token predictor. """
|
||||
__slots__ = ()
|
||||
|
||||
class PredictionInputWithSchema(namedtuple('PredictionInputWithSchema',
|
||||
('decoder_state',
|
||||
'input_hidden_states',
|
||||
'schema_states',
|
||||
'snippets',
|
||||
'input_sequence',
|
||||
'previous_queries',
|
||||
'previous_query_states',
|
||||
'input_schema'))):
|
||||
""" Inputs to the token predictor. """
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class TokenPrediction(namedtuple('TokenPrediction',
|
||||
('scores',
|
||||
'aligned_tokens',
|
||||
'utterance_attention_results',
|
||||
'schema_attention_results',
|
||||
'query_attention_results',
|
||||
'copy_switch',
|
||||
'query_scores',
|
||||
'query_tokens',
|
||||
'decoder_state'))):
|
||||
|
||||
"""A token prediction."""
|
||||
__slots__ = ()
|
||||
|
||||
def score_snippets(snippets, scorer):
|
||||
""" Scores snippets given a scorer.
|
||||
|
||||
Inputs:
|
||||
snippets (list of Snippet): The snippets to score.
|
||||
scorer (torch.Tensor): Vector against which to score the snippets.

Returns:
torch.Tensor, list of str, where the first is the scores and the second
is the names of the snippets that were scored.
|
||||
"""
|
||||
snippet_expressions = [snippet.embedding for snippet in snippets]
|
||||
all_snippet_embeddings = torch.stack(snippet_expressions, dim=1)
|
||||
|
||||
scores = torch.t(torch.mm(torch.t(scorer), all_snippet_embeddings))
|
||||
|
||||
if scores.size()[0] != len(snippets):
|
||||
raise ValueError("Got " + str(scores.size()[0]) + " scores for " + str(len(snippets)) + " snippets")
|
||||
|
||||
return scores, [snippet.name for snippet in snippets]
|
||||
|
||||
def score_schema_tokens(input_schema, schema_states, scorer):
|
||||
# schema_states: emd_dim x num_tokens
|
||||
scores = torch.t(torch.mm(torch.t(scorer), schema_states)) # num_tokens x 1
|
||||
if scores.size()[0] != len(input_schema):
|
||||
raise ValueError("Got " + str(scores.size()[0]) + " scores for " + str(len(input_schema)) + " schema tokens")
|
||||
return scores, input_schema.column_names_surface_form
|
||||
|
||||
def score_query_tokens(previous_query, previous_query_states, scorer):
|
||||
scores = torch.t(torch.mm(torch.t(scorer), previous_query_states)) # num_tokens x 1
|
||||
if scores.size()[0] != len(previous_query):
|
||||
raise ValueError("Got " + str(scores.size()[0]) + " scores for " + str(len(previous_query)) + " query tokens")
|
||||
return scores, previous_query
|
||||
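# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# The three scoring helpers above share one pattern: a column-vector scorer
# (emb_dim x 1) is multiplied against a matrix of candidate embeddings
# (emb_dim x num_candidates), giving one score per candidate:
import torch

_emb_dim, _num_candidates = 5, 4
_scorer = torch.randn(_emb_dim, 1)
_candidates = torch.randn(_emb_dim, _num_candidates)
_scores = torch.t(torch.mm(torch.t(_scorer), _candidates))   # num_candidates x 1
assert _scores.size()[0] == _num_candidates
# ---------------------------------------------------------------------------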
|
||||
class TokenPredictor(torch.nn.Module):
|
||||
""" Predicts a token given a (decoder) state.
|
||||
|
||||
Attributes:
|
||||
vocabulary (Vocabulary): A vocabulary object for the output.
|
||||
attention_module (Attention): An attention module.
|
||||
state_transform_weights (torch.nn.Parameter): Transforms the input state
before predicting a token.
vocabulary_weights (torch.nn.Parameter): Final layer weights.
vocabulary_biases (torch.nn.Parameter): Final layer biases.
|
||||
"""
|
||||
|
||||
def __init__(self, params, vocabulary, attention_key_size):
|
||||
super().__init__()
|
||||
self.params = params
|
||||
self.vocabulary = vocabulary
|
||||
self.attention_module = Attention(params.decoder_state_size, attention_key_size, attention_key_size)
|
||||
self.state_transform_weights = torch_utils.add_params((params.decoder_state_size + attention_key_size, params.decoder_state_size), "weights-state-transform")
|
||||
self.vocabulary_weights = torch_utils.add_params((params.decoder_state_size, len(vocabulary)), "weights-vocabulary")
|
||||
self.vocabulary_biases = torch_utils.add_params(tuple([len(vocabulary)]), "biases-vocabulary")
|
||||
|
||||
def _get_intermediate_state(self, state, dropout_amount=0.):
|
||||
intermediate_state = torch.tanh(torch_utils.linear_layer(state, self.state_transform_weights))
|
||||
return F.dropout(intermediate_state, dropout_amount)
|
||||
|
||||
def _score_vocabulary_tokens(self, state):
|
||||
scores = torch.t(torch_utils.linear_layer(state, self.vocabulary_weights, self.vocabulary_biases))
|
||||
|
||||
if scores.size()[0] != len(self.vocabulary.inorder_tokens):
|
||||
raise ValueError("Got " + str(scores.size()[0]) + " scores for " + str(len(self.vocabulary.inorder_tokens)) + " vocabulary items")
|
||||
|
||||
return scores, self.vocabulary.inorder_tokens
|
||||
|
||||
def forward(self, prediction_input, dropout_amount=0.):
|
||||
decoder_state = prediction_input.decoder_state
|
||||
input_hidden_states = prediction_input.input_hidden_states
|
||||
|
||||
attention_results = self.attention_module(decoder_state, input_hidden_states)
|
||||
|
||||
state_and_attn = torch.cat([decoder_state, attention_results.vector], dim=0)
|
||||
|
||||
intermediate_state = self._get_intermediate_state(state_and_attn, dropout_amount=dropout_amount)
|
||||
vocab_scores, vocab_tokens = self._score_vocabulary_tokens(intermediate_state)
|
||||
|
||||
return TokenPrediction(vocab_scores, vocab_tokens, attention_results, None, None, None, None, None, decoder_state)
|
||||
|
||||
|
||||
class SchemaTokenPredictor(TokenPredictor):
|
||||
""" Token predictor that also predicts snippets.
|
||||
|
||||
Attributes:
|
||||
snippet_weights (dy.Parameter): Weights for scoring snippets against some
|
||||
state.
|
||||
"""
|
||||
|
||||
def __init__(self, params, vocabulary, utterance_attention_key_size, schema_attention_key_size, snippet_size):
|
||||
TokenPredictor.__init__(self, params, vocabulary, utterance_attention_key_size)
|
||||
if params.use_snippets:
|
||||
if snippet_size <= 0:
|
||||
raise ValueError("Snippet size must be greater than zero; was " + str(snippet_size))
|
||||
self.snippet_weights = torch_utils.add_params((params.decoder_state_size, snippet_size), "weights-snippet")
|
||||
|
||||
if params.use_schema_attention:
|
||||
self.utterance_attention_module = self.attention_module
|
||||
self.schema_attention_module = Attention(params.decoder_state_size, schema_attention_key_size, schema_attention_key_size)
|
||||
|
||||
if self.params.use_query_attention:
|
||||
self.query_attention_module = Attention(params.decoder_state_size, params.encoder_state_size, params.encoder_state_size)
|
||||
self.start_query_attention_vector = torch_utils.add_params((params.encoder_state_size,), "start_query_attention_vector")
|
||||
|
||||
if params.use_schema_attention and self.params.use_query_attention:
|
||||
self.state_transform_weights = torch_utils.add_params((params.decoder_state_size + utterance_attention_key_size + schema_attention_key_size + params.encoder_state_size, params.decoder_state_size), "weights-state-transform")
|
||||
elif params.use_schema_attention:
|
||||
self.state_transform_weights = torch_utils.add_params((params.decoder_state_size + utterance_attention_key_size + schema_attention_key_size, params.decoder_state_size), "weights-state-transform")
|
||||
|
||||
# Use lstm schema encoder
|
||||
self.schema_token_weights = torch_utils.add_params((params.decoder_state_size, schema_attention_key_size), "weights-schema-token")
|
||||
|
||||
if self.params.use_previous_query:
|
||||
self.query_token_weights = torch_utils.add_params((params.decoder_state_size, self.params.encoder_state_size), "weights-query-token")
|
||||
|
||||
if self.params.use_copy_switch:
|
||||
if self.params.use_query_attention:
|
||||
self.state2copyswitch_transform_weights = torch_utils.add_params((params.decoder_state_size + utterance_attention_key_size + schema_attention_key_size + params.encoder_state_size, 1), "weights-state-transform")
|
||||
else:
|
||||
self.state2copyswitch_transform_weights = torch_utils.add_params((params.decoder_state_size + utterance_attention_key_size + schema_attention_key_size, 1), "weights-state-transform")
|
||||
|
||||
def _get_snippet_scorer(self, state):
|
||||
scorer = torch.t(torch_utils.linear_layer(state, self.snippet_weights))
|
||||
return scorer
|
||||
|
||||
def _get_schema_token_scorer(self, state):
|
||||
scorer = torch.t(torch_utils.linear_layer(state, self.schema_token_weights))
|
||||
return scorer
|
||||
|
||||
def _get_query_token_scorer(self, state):
|
||||
scorer = torch.t(torch_utils.linear_layer(state, self.query_token_weights))
|
||||
return scorer
|
||||
|
||||
def _get_copy_switch(self, state):
|
||||
copy_switch = torch.sigmoid(torch_utils.linear_layer(state, self.state2copyswitch_transform_weights))
|
||||
return copy_switch.squeeze()
|
||||
|
||||
def forward(self, prediction_input, dropout_amount=0.):
|
||||
decoder_state = prediction_input.decoder_state
|
||||
input_hidden_states = prediction_input.input_hidden_states
|
||||
snippets = prediction_input.snippets
|
||||
|
||||
input_schema = prediction_input.input_schema
|
||||
schema_states = prediction_input.schema_states
|
||||
|
||||
if self.params.use_schema_attention:
|
||||
schema_attention_results = self.schema_attention_module(decoder_state, schema_states)
|
||||
utterance_attention_results = self.utterance_attention_module(decoder_state, input_hidden_states)
|
||||
else:
|
||||
utterance_attention_results = self.attention_module(decoder_state, input_hidden_states)
|
||||
schema_attention_results = None
|
||||
|
||||
query_attention_results = None
|
||||
if self.params.use_query_attention:
|
||||
previous_query_states = prediction_input.previous_query_states
|
||||
if len(previous_query_states) > 0:
|
||||
query_attention_results = self.query_attention_module(decoder_state, previous_query_states[-1])
|
||||
else:
|
||||
query_attention_results = self.start_query_attention_vector
|
||||
query_attention_results = AttentionResult(None, None, query_attention_results)
|
||||
|
||||
if self.params.use_schema_attention and self.params.use_query_attention:
|
||||
state_and_attn = torch.cat([decoder_state, utterance_attention_results.vector, schema_attention_results.vector, query_attention_results.vector], dim=0)
|
||||
elif self.params.use_schema_attention:
|
||||
state_and_attn = torch.cat([decoder_state, utterance_attention_results.vector, schema_attention_results.vector], dim=0)
|
||||
else:
|
||||
state_and_attn = torch.cat([decoder_state, utterance_attention_results.vector], dim=0)
|
||||
|
||||
intermediate_state = self._get_intermediate_state(state_and_attn, dropout_amount=dropout_amount)
|
||||
vocab_scores, vocab_tokens = self._score_vocabulary_tokens(intermediate_state)
|
||||
|
||||
final_scores = vocab_scores
|
||||
aligned_tokens = []
|
||||
aligned_tokens.extend(vocab_tokens)
|
||||
|
||||
if self.params.use_snippets and snippets:
|
||||
snippet_scores, snippet_tokens = score_snippets(snippets, self._get_snippet_scorer(intermediate_state))
|
||||
final_scores = torch.cat([final_scores, snippet_scores], dim=0)
|
||||
aligned_tokens.extend(snippet_tokens)
|
||||
|
||||
schema_states = torch.stack(schema_states, dim=1)
|
||||
schema_scores, schema_tokens = score_schema_tokens(input_schema, schema_states, self._get_schema_token_scorer(intermediate_state))
|
||||
|
||||
final_scores = torch.cat([final_scores, schema_scores], dim=0)
|
||||
aligned_tokens.extend(schema_tokens)
|
||||
|
||||
# Previous Queries
|
||||
previous_queries = prediction_input.previous_queries
|
||||
previous_query_states = prediction_input.previous_query_states
|
||||
|
||||
copy_switch = None
|
||||
query_scores = None
|
||||
query_tokens = None
|
||||
if self.params.use_previous_query and len(previous_queries) > 0:
|
||||
if self.params.use_copy_switch:
|
||||
copy_switch = self._get_copy_switch(state_and_attn)
|
||||
for turn, (previous_query, previous_query_state) in enumerate(zip(previous_queries, previous_query_states)):
|
||||
assert len(previous_query) == len(previous_query_state)
|
||||
previous_query_state = torch.stack(previous_query_state, dim=1)
|
||||
query_scores, query_tokens = score_query_tokens(previous_query, previous_query_state, self._get_query_token_scorer(intermediate_state))
|
||||
query_scores = query_scores.squeeze()
|
||||
|
||||
final_scores = final_scores.squeeze()
|
||||
|
||||
return TokenPrediction(final_scores, aligned_tokens, utterance_attention_results, schema_attention_results, query_attention_results, copy_switch, query_scores, query_tokens, decoder_state)
|
||||
|
||||
|
||||
class SnippetTokenPredictor(TokenPredictor):
|
||||
""" Token predictor that also predicts snippets.
|
||||
|
||||
Attributes:
|
||||
snippet_weights (dy.Parameter): Weights for scoring snippets against some
|
||||
state.
|
||||
"""
|
||||
|
||||
def __init__(self, params, vocabulary, attention_key_size, snippet_size):
|
||||
TokenPredictor.__init__(self, params, vocabulary, attention_key_size)
|
||||
if snippet_size <= 0:
|
||||
raise ValueError("Snippet size must be greater than zero; was " + str(snippet_size))
|
||||
|
||||
self.snippet_weights = torch_utils.add_params((params.decoder_state_size, snippet_size), "weights-snippet")
|
||||
|
||||
def _get_snippet_scorer(self, state):
|
||||
scorer = torch.t(torch_utils.linear_layer(state, self.snippet_weights))
|
||||
|
||||
return scorer
|
||||
|
||||
def forward(self, prediction_input, dropout_amount=0.):
|
||||
decoder_state = prediction_input.decoder_state
|
||||
input_hidden_states = prediction_input.input_hidden_states
|
||||
snippets = prediction_input.snippets
|
||||
|
||||
attention_results = self.attention_module(decoder_state,
|
||||
input_hidden_states)
|
||||
|
||||
state_and_attn = torch.cat([decoder_state, attention_results.vector], dim=0)
|
||||
|
||||
|
||||
intermediate_state = self._get_intermediate_state(
|
||||
state_and_attn, dropout_amount=dropout_amount)
|
||||
vocab_scores, vocab_tokens = self._score_vocabulary_tokens(
|
||||
intermediate_state)
|
||||
|
||||
final_scores = vocab_scores
|
||||
aligned_tokens = []
|
||||
aligned_tokens.extend(vocab_tokens)
|
||||
|
||||
if snippets:
|
||||
snippet_scores, snippet_tokens = score_snippets(
|
||||
snippets,
|
||||
self._get_snippet_scorer(intermediate_state))
|
||||
|
||||
final_scores = torch.cat([final_scores, snippet_scores], dim=0)
|
||||
aligned_tokens.extend(snippet_tokens)
|
||||
|
||||
final_scores = final_scores.squeeze()
|
||||
|
||||
return TokenPrediction(final_scores, aligned_tokens, attention_results, None, None, None, None, None, decoder_state)
|
||||
|
||||
|
||||
class AnonymizationTokenPredictor(TokenPredictor):
|
||||
""" Token predictor that also predicts anonymization tokens.
|
||||
|
||||
Attributes:
|
||||
anonymizer (Anonymizer): The anonymization object.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, params, vocabulary, attention_key_size, anonymizer):
|
||||
TokenPredictor.__init__(self, params, vocabulary, attention_key_size)
|
||||
if not anonymizer:
|
||||
raise ValueError("Expected an anonymizer, but was None")
|
||||
self.anonymizer = anonymizer
|
||||
|
||||
def _score_anonymized_tokens(self,
|
||||
input_sequence,
|
||||
attention_scores):
|
||||
scores = []
|
||||
tokens = []
|
||||
for i, token in enumerate(input_sequence):
|
||||
if self.anonymizer.is_anon_tok(token):
|
||||
scores.append(attention_scores[i])
|
||||
tokens.append(token)
|
||||
|
||||
if len(scores) > 0:
|
||||
if len(scores) != len(tokens):
|
||||
raise ValueError("Got " + str(len(scores)) + " scores for "
|
||||
+ str(len(tokens)) + " anonymized tokens")
|
||||
|
||||
anonymized_scores = torch.cat(scores, dim=0)
|
||||
if anonymized_scores.dim() == 1:
|
||||
anonymized_scores = anonymized_scores.unsqueeze(1)
|
||||
return anonymized_scores, tokens
|
||||
else:
|
||||
return None, []
|
||||
|
||||
def forward(self,
|
||||
prediction_input,
|
||||
dropout_amount=0.):
|
||||
decoder_state = prediction_input.decoder_state
|
||||
input_hidden_states = prediction_input.input_hidden_states
|
||||
input_sequence = prediction_input.input_sequence
|
||||
assert input_sequence
|
||||
|
||||
attention_results = self.attention_module(decoder_state,
|
||||
input_hidden_states)
|
||||
|
||||
state_and_attn = torch.cat([decoder_state, attention_results.vector], dim=0)
|
||||
|
||||
intermediate_state = self._get_intermediate_state(
|
||||
state_and_attn, dropout_amount=dropout_amount)
|
||||
vocab_scores, vocab_tokens = self._score_vocabulary_tokens(
|
||||
intermediate_state)
|
||||
|
||||
final_scores = vocab_scores
|
||||
aligned_tokens = []
|
||||
aligned_tokens.extend(vocab_tokens)
|
||||
|
||||
anonymized_scores, anonymized_tokens = self._score_anonymized_tokens(
|
||||
input_sequence,
|
||||
attention_results.scores)
|
||||
|
||||
if anonymized_scores is not None:
|
||||
final_scores = torch.cat([final_scores, anonymized_scores], dim=0)
|
||||
aligned_tokens.extend(anonymized_tokens)
|
||||
|
||||
final_scores = final_scores.squeeze()
|
||||
|
||||
return TokenPrediction(final_scores, aligned_tokens, attention_results, None, None, None, None, None, decoder_state)
|
||||
|
||||
|
||||
# For Atis
|
||||
class SnippetAnonymizationTokenPredictor(SnippetTokenPredictor, AnonymizationTokenPredictor):
|
||||
""" Token predictor that both anonymizes and scores snippets."""
|
||||
|
||||
def __init__(self, params, vocabulary, attention_key_size, snippet_size, anonymizer):
|
||||
AnonymizationTokenPredictor.__init__(self, params, vocabulary, attention_key_size, anonymizer)
|
||||
SnippetTokenPredictor.__init__(self, params, vocabulary, attention_key_size, snippet_size)
|
||||
|
||||
def forward(self, prediction_input, dropout_amount=0.):
|
||||
decoder_state = prediction_input.decoder_state
|
||||
assert prediction_input.input_sequence
|
||||
|
||||
snippets = prediction_input.snippets
|
||||
|
||||
input_hidden_states = prediction_input.input_hidden_states
|
||||
|
||||
attention_results = self.attention_module(decoder_state,
|
||||
prediction_input.input_hidden_states)
|
||||
|
||||
state_and_attn = torch.cat([decoder_state, attention_results.vector], dim=0)
|
||||
|
||||
intermediate_state = self._get_intermediate_state(state_and_attn, dropout_amount=dropout_amount)
|
||||
|
||||
# Vocabulary tokens
|
||||
final_scores, vocab_tokens = self._score_vocabulary_tokens(intermediate_state)
|
||||
|
||||
aligned_tokens = []
|
||||
aligned_tokens.extend(vocab_tokens)
|
||||
|
||||
# Snippets
|
||||
if snippets:
|
||||
snippet_scores, snippet_tokens = score_snippets(
|
||||
snippets,
|
||||
self._get_snippet_scorer(intermediate_state))
|
||||
|
||||
final_scores = torch.cat([final_scores, snippet_scores], dim=0)
|
||||
aligned_tokens.extend(snippet_tokens)
|
||||
|
||||
# Anonymized tokens
|
||||
anonymized_scores, anonymized_tokens = self._score_anonymized_tokens(
|
||||
prediction_input.input_sequence,
|
||||
attention_results.scores)
|
||||
|
||||
if anonymized_scores is not None:
|
||||
final_scores = torch.cat([final_scores, anonymized_scores], dim=0)
|
||||
aligned_tokens.extend(anonymized_tokens)
|
||||
|
||||
final_scores = final_scores.squeeze()
|
||||
|
||||
return TokenPrediction(final_scores, aligned_tokens, attention_results, None, None, None, None, None, decoder_state)
|
||||
|
||||
def construct_token_predictor(params,
|
||||
vocabulary,
|
||||
utterance_attention_key_size,
|
||||
schema_attention_key_size,
|
||||
snippet_size,
|
||||
anonymizer=None):
|
||||
""" Constructs a token predictor given the parameters.
|
||||
|
||||
Inputs:
|
||||
params (dictionary): Contains the command line parameters/hyperparameters.
vocabulary (Vocabulary): Vocabulary object for output generation.
utterance_attention_key_size (int): The size of the utterance attention keys.
schema_attention_key_size (int): The size of the schema attention keys.
snippet_size (int): The size of the snippet embeddings.
anonymizer (Anonymizer): An anonymization object.
|
||||
"""
|
||||
|
||||
|
||||
if not anonymizer and not params.previous_decoder_snippet_encoding:
|
||||
print('using SchemaTokenPredictor')
|
||||
return SchemaTokenPredictor(params, vocabulary, utterance_attention_key_size, schema_attention_key_size, snippet_size)
|
||||
elif params.use_snippets and anonymizer and not params.previous_decoder_snippet_encoding:
|
||||
print('using SnippetAnonymizationTokenPredictor')
|
||||
return SnippetAnonymizationTokenPredictor(params,
|
||||
vocabulary,
|
||||
utterance_attention_key_size,
|
||||
snippet_size,
|
||||
anonymizer)
|
||||
else:
|
||||
print('Unknown token_predictor')
|
||||
exit()
|
|
@@ -0,0 +1,217 @@
"""Contains various utility functions for Dynet models."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def linear_layer(exp, weights, biases=None):
|
||||
# exp: input as size_1 or 1 x size_1
|
||||
# weight: size_1 x size_2
|
||||
# bias: size_2
|
||||
if exp.dim() == 1:
|
||||
exp = torch.unsqueeze(exp, 0)
|
||||
assert exp.size()[1] == weights.size()[0]
|
||||
if biases is not None:
|
||||
assert weights.size()[1] == biases.size()[0]
|
||||
result = torch.mm(exp, weights) + biases
|
||||
else:
|
||||
result = torch.mm(exp, weights)
|
||||
return result
|
||||
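# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# linear_layer() above accepts either a 1-D input of size in_dim or a
# 1 x in_dim matrix and returns a 1 x out_dim result:
import torch

_x = torch.randn(3)            # in_dim = 3
_W = torch.randn(3, 2)         # in_dim x out_dim
_b = torch.randn(2)            # out_dim
assert linear_layer(_x, _W, _b).size() == (1, 2)
# ---------------------------------------------------------------------------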
|
||||
|
||||
def compute_loss(gold_seq,
|
||||
scores,
|
||||
index_to_token_maps,
|
||||
gold_tok_to_id,
|
||||
noise=0.00000001):
|
||||
""" Computes the loss of a gold sequence given scores.
|
||||
|
||||
Inputs:
|
||||
gold_seq (list of str): A sequence of gold tokens.
|
||||
scores (list of dy.Expression): Expressions representing the scores of
|
||||
potential output tokens for each token in gold_seq.
|
||||
index_to_token_maps (list of dict str->list of int): Maps from index in the
|
||||
sequence to a dictionary mapping from a string to a set of integers.
|
||||
gold_tok_to_id (lambda (str, str)->list of int): Maps from the gold token
|
||||
and some lookup function to the indices in the probability distribution
|
||||
where the gold token occurs.
|
||||
noise (float, optional): The amount of noise to add to the loss.
|
||||
|
||||
Returns:
|
||||
dy.Expression representing the sum of losses over the sequence.
|
||||
"""
|
||||
assert len(gold_seq) == len(scores) == len(index_to_token_maps)
|
||||
|
||||
losses = []
|
||||
for i, gold_tok in enumerate(gold_seq):
|
||||
score = scores[i]
|
||||
token_map = index_to_token_maps[i]
|
||||
|
||||
gold_indices = gold_tok_to_id(gold_tok, token_map)
|
||||
assert len(gold_indices) > 0
|
||||
noise_i = noise
|
||||
if len(gold_indices) == 1:
|
||||
noise_i = 0
|
||||
|
||||
probdist = score
|
||||
prob_of_tok = noise_i + torch.sum(probdist[gold_indices])
|
||||
losses.append(-torch.log(prob_of_tok))
|
||||
|
||||
return torch.sum(torch.stack(losses))
|
||||
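# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# compute_loss() above sums -log p(gold token) over the sequence: each
# scores[i] is already a probability distribution over candidate tokens, and
# gold_tok_to_id picks the indices of the gold token in that distribution.
# Toy usage (aligned token lists and probabilities are made up):
import torch

_gold_seq = ['select', '_EOS']
_aligned = [['select', 'from', '_EOS'], ['select', 'from', '_EOS']]
_scores = [torch.tensor([0.7, 0.2, 0.1]), torch.tensor([0.1, 0.1, 0.8])]

def _gold_tok_to_id(tok, tok_list):
    return [i for i, t in enumerate(tok_list) if t == tok]

_loss = compute_loss(_gold_seq, _scores, _aligned, _gold_tok_to_id)
# equals -(log 0.7 + log 0.8)
# ---------------------------------------------------------------------------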
|
||||
|
||||
def get_seq_from_scores(scores, index_to_token_maps):
|
||||
"""Gets the argmax sequence from a set of scores.
|
||||
|
||||
Inputs:
|
||||
scores (list of dy.Expression): Sequences of output scores.
|
||||
index_to_token_maps (list of list of str): For each output token, maps
|
||||
the index in the probability distribution to a string.
|
||||
|
||||
Returns:
|
||||
list of str, representing the argmax sequence.
|
||||
"""
|
||||
seq = []
|
||||
for score, tok_map in zip(scores, index_to_token_maps):
|
||||
# score_numpy_list = score.cpu().detach().numpy()
|
||||
score_numpy_list = score.cpu().data.numpy()
|
||||
assert score.size()[0] == len(tok_map) == len(list(score_numpy_list))
|
||||
seq.append(tok_map[np.argmax(score_numpy_list)])
|
||||
return seq
|
||||
|
||||
def per_token_accuracy(gold_seq, pred_seq):
|
||||
""" Returns the per-token accuracy comparing two strings (recall).
|
||||
|
||||
Inputs:
|
||||
gold_seq (list of str): A list of gold tokens.
|
||||
pred_seq (list of str): A list of predicted tokens.
|
||||
|
||||
Returns:
|
||||
float, representing the accuracy.
|
||||
"""
|
||||
num_correct = 0
|
||||
for i, gold_token in enumerate(gold_seq):
|
||||
if i < len(pred_seq) and pred_seq[i] == gold_token:
|
||||
num_correct += 1
|
||||
|
||||
return float(num_correct) / len(gold_seq)
|
||||
|
||||
def forward_one_multilayer(rnns, lstm_input, layer_states, dropout_amount=0.):
|
||||
""" Goes forward for one multilayer RNN cell step.
|
||||
|
||||
Inputs:
|
||||
lstm_input (dy.Expression): Some input to the step.
|
||||
layer_states (list of dy.RNNState): The states of each layer in the cell.
|
||||
dropout_amount (float, optional): The amount of dropout to apply, in
|
||||
between the layers.
|
||||
|
||||
Returns:
|
||||
(list of dy.Expression, list of dy.Expression), dy.Expression, (list of dy.RNNState),
|
||||
representing (each layer's cell memory, each layer's cell hidden state),
|
||||
the final hidden state, and (each layer's updated RNNState).
|
||||
"""
|
||||
num_layers = len(layer_states)
|
||||
new_states = []
|
||||
cell_states = []
|
||||
hidden_states = []
|
||||
state = lstm_input
|
||||
for i in range(num_layers):
|
||||
# view as (1, input_size)
|
||||
layer_h, layer_c = rnns[i](torch.unsqueeze(state,0), layer_states[i])
|
||||
new_states.append((layer_h, layer_c))
|
||||
|
||||
layer_h = layer_h.squeeze()
|
||||
layer_c = layer_c.squeeze()
|
||||
|
||||
state = layer_h
|
||||
if i < num_layers - 1:
|
||||
# In both DyNet and PyTorch,
|
||||
# p stands for the probability of an element being zeroed, i.e. p=1 switches off all activations.
|
||||
state = F.dropout(state, p=dropout_amount)
|
||||
|
||||
cell_states.append(layer_c)
|
||||
hidden_states.append(layer_h)
|
||||
|
||||
return (cell_states, hidden_states), state, new_states
|
||||
|
||||
|
||||
def encode_sequence(sequence, rnns, embedder, dropout_amount=0.):
|
||||
""" Encodes a sequence given RNN cells and an embedding function.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): The sequence to encode.
rnns (list of nn.LSTMCell): The stacked LSTM cells to use.
embedder (callable, str -> torch.Tensor): Function that embeds string tokens into word vectors.
|
||||
dropout_amount (float, optional): The amount of dropout to apply.
|
||||
|
||||
Returns:
|
||||
(list of dy.Expression, list of dy.Expression), list of dy.Expression,
|
||||
where the first pair is the (final cell memories, final cell states) of
|
||||
all layers, and the second list is a list of the final layer's cell
|
||||
state for all tokens in the sequence.
|
||||
"""
|
||||
|
||||
batch_size = 1
|
||||
layer_states = []
|
||||
for rnn in rnns:
|
||||
hidden_size = rnn.weight_hh.size()[1]
|
||||
|
||||
# h_0 of shape (batch, hidden_size)
|
||||
# c_0 of shape (batch, hidden_size)
|
||||
if rnn.weight_hh.is_cuda:
|
||||
h_0 = torch.cuda.FloatTensor(batch_size,hidden_size).fill_(0)
|
||||
c_0 = torch.cuda.FloatTensor(batch_size,hidden_size).fill_(0)
|
||||
else:
|
||||
h_0 = torch.zeros(batch_size,hidden_size)
|
||||
c_0 = torch.zeros(batch_size,hidden_size)
|
||||
|
||||
layer_states.append((h_0, c_0))
|
||||
|
||||
outputs = []
|
||||
for token in sequence:
|
||||
rnn_input = embedder(token)
|
||||
(cell_states, hidden_states), output, layer_states = forward_one_multilayer(rnns,rnn_input,layer_states,dropout_amount)
|
||||
|
||||
outputs.append(output)
|
||||
|
||||
return (cell_states, hidden_states), outputs
|
||||
|
||||
def create_multilayer_lstm_params(num_layers, in_size, state_size, name=""):
|
||||
""" Adds a multilayer LSTM to the model parameters.
|
||||
|
||||
Inputs:
|
||||
num_layers (int): Number of layers to create.
|
||||
in_size (int): The input size to the first layer.
|
||||
state_size (int): The size of the states.
|
||||
model (dy.ParameterCollection): The parameter collection for the model.
|
||||
name (str, optional): The name of the multilayer LSTM.
|
||||
"""
|
||||
lstm_layers = []
|
||||
for i in range(num_layers):
|
||||
layer_name = name + "-" + str(i)
|
||||
# print("LSTM " + layer_name + ": " + str(in_size) + " x " + str(state_size) + "; default Dynet initialization of hidden weights")
|
||||
lstm_layer = torch.nn.LSTMCell(input_size=int(in_size), hidden_size=int(state_size), bias=True)
|
||||
lstm_layers.append(lstm_layer)
|
||||
in_size = state_size
|
||||
return torch.nn.ModuleList(lstm_layers)
|
||||
|
||||
def add_params(size, name=""):
|
||||
""" Adds parameters to the model.
|
||||
|
||||
Inputs:
|
||||
model (dy.ParameterCollection): The parameter collection for the model.
|
||||
size (tuple of int): The size to create.
|
||||
name (str, optional): The name of the parameters.
|
||||
"""
|
||||
# if len(size) == 1:
|
||||
# print("vector " + name + ": " + str(size[0]) + "; uniform in [-0.1, 0.1]")
|
||||
# else:
|
||||
# print("matrix " + name + ": " + str(size[0]) + " x " + str(size[1]) + "; uniform in [-0.1, 0.1]")
|
||||
|
||||
size_int = tuple([int(ss) for ss in size])
|
||||
return torch.nn.Parameter(torch.empty(size_int).uniform_(-0.1, 0.1))
|
|
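The file above collects PyTorch ports of sequence utilities. As a rough, illustrative sketch only (not taken from the committed files), the multilayer-LSTM helpers can be wired together like this; the toy vocabulary and `embedder` below are hypothetical stand-ins:

```python
import torch
import torch.nn as nn

# Hypothetical toy vocabulary and embedder, just to exercise the helpers above.
vocab = {"select": 0, "count": 1, "(": 2, "*": 3, ")": 4}
emb = nn.Embedding(len(vocab), 300)

def embedder(token):
    # maps a string token to a 300-dim vector
    return emb(torch.tensor(vocab[token]))

rnns = create_multilayer_lstm_params(num_layers=2, in_size=300, state_size=300)
(final_cells, final_hiddens), outputs = encode_sequence(
    ["select", "count", "(", "*", ")"], rnns, embedder, dropout_amount=0.1)
# `outputs` holds one top-layer hidden vector per input token;
# `final_cells`/`final_hiddens` hold the last step's state for every layer.
print(len(outputs), outputs[0].shape)   # 5 torch.Size([300])
```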
@ -0,0 +1,16 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.eps = eps
|
||||
self.gamma = nn.Parameter(torch.ones(hidden_size) * 0.2)
|
||||
self.beta = nn.Parameter(torch.zeros(hidden_size))
|
||||
|
||||
def forward(self, x):
|
||||
mean = x.mean(-1, keepdim=True)
|
||||
std = x.std(-1, keepdim=True)
|
||||
return self.gamma * (x-mean) / (std+self.eps) + self.beta
|
|
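Illustrative usage of the custom LayerNorm above (a sketch, not taken from the committed files). Note two quirks of this implementation: gamma is initialised to 0.2 rather than 1, and it divides by (std + eps) instead of sqrt(var + eps) as `nn.LayerNorm` does:

```python
import torch

ln = LayerNorm(hidden_size=768)
x = torch.randn(2, 5, 768)          # [batch, seq_len, hidden]
y = ln(x)                           # normalised over the last dimension
print(y.shape)                      # torch.Size([2, 5, 768])
```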
@ -0,0 +1,65 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
"""
|
||||
Each head is a self-attention operation.
|
||||
self-attention refers to https://arxiv.org/pdf/1706.03762.pdf
|
||||
"""
|
||||
def __init__(self, hidden_size, heads_num):
|
||||
super(MultiHeadedAttention, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
self.per_head_size = hidden_size // heads_num
|
||||
|
||||
self.linear_layers = nn.ModuleList([
|
||||
nn.Linear(hidden_size, hidden_size) for _ in range(3)
|
||||
])
|
||||
|
||||
self.final_linear = nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, key, value, query, mask, dropout):
|
||||
"""
|
||||
Args:
|
||||
key: [batch_size x seq_length x hidden_size]
|
||||
value: [batch_size x seq_length x hidden_size]
|
||||
query: [batch_size x seq_length x hidden_size]
|
||||
mask: [batch_size x 1 x seq_length x seq_length]
|
||||
Returns:
|
||||
output: [batch_size x seq_length x hidden_size]
|
||||
"""
|
||||
batch_size, seq_length, hidden_size = key.size()
|
||||
heads_num = self.heads_num
|
||||
per_head_size = self.per_head_size
|
||||
|
||||
def shape(x):
|
||||
return x. \
|
||||
contiguous(). \
|
||||
view(batch_size, seq_length, heads_num, per_head_size). \
|
||||
transpose(1, 2)
|
||||
|
||||
def unshape(x):
|
||||
return x. \
|
||||
transpose(1, 2). \
|
||||
contiguous(). \
|
||||
view(batch_size, seq_length, hidden_size)
|
||||
|
||||
|
||||
query, key, value = [l(x). \
|
||||
view(batch_size, -1, heads_num, per_head_size). \
|
||||
transpose(1, 2) \
|
||||
for l, x in zip(self.linear_layers, (query, key, value))
|
||||
]
|
||||
|
||||
scores = torch.matmul(query, key.transpose(-2, -1))
|
||||
scores = scores / math.sqrt(float(per_head_size))
|
||||
scores = scores + mask
|
||||
probs = nn.Softmax(dim=-1)(scores)
|
||||
probs = F.dropout(probs, p=dropout)
|
||||
output = unshape(torch.matmul(probs, value))
|
||||
output = self.final_linear(output)
|
||||
|
||||
return output
|
|
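A minimal sketch (not taken from the committed files) of calling the MultiHeadedAttention block above with the shapes its docstring documents. Since the mask is added to the scores, 0 keeps a position visible and a large negative value masks it out; the all-zeros mask below is a hypothetical "attend to everything" case:

```python
import torch

attn = MultiHeadedAttention(hidden_size=300, heads_num=6)   # 6 heads of size 50
x = torch.randn(2, 7, 300)                                  # [batch, seq_len, hidden]
mask = torch.zeros(2, 1, 7, 7)                              # additive mask: nothing masked
out = attn(key=x, value=x, query=x, mask=mask, dropout=0.1)
print(out.shape)                                            # torch.Size([2, 7, 300])
```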
@ -0,0 +1,19 @@
|
|||
import torch.nn as nn
|
||||
import math
|
||||
import torch
|
||||
|
||||
def gelu(x):
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
""" Feed Forward Layer """
|
||||
def __init__(self, hidden_size, feedforward_size):
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.linear_1 = nn.Linear(hidden_size, feedforward_size)
|
||||
self.linear_2 = nn.Linear(feedforward_size, hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
inter = gelu(self.linear_1(x))
|
||||
output = self.linear_2(inter)
|
||||
return output
|
|
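For reference, a small sketch (not taken from the committed files) of the feed-forward block above: it applies linear -> gelu -> linear independently at every position, and `gelu` here is the exact erf-based form rather than the tanh approximation:

```python
import torch

ffn = PositionwiseFeedForward(hidden_size=300, feedforward_size=300)
x = torch.randn(2, 7, 300)
print(ffn(x).shape)                       # torch.Size([2, 7, 300])
print(float(gelu(torch.tensor(0.0))))     # 0.0 -- gelu(0) is exactly zero
```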
@ -0,0 +1,122 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
|
||||
class RatMultiHeadedAttention(nn.Module):
|
||||
"""
|
||||
Each head is a self-attention operation.
|
||||
self-attention refers to https://arxiv.org/pdf/1706.03762.pdf
|
||||
"""
|
||||
def __init__(self, hidden_size, heads_num, relationship_number):
|
||||
super(RatMultiHeadedAttention, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
self.per_head_size = hidden_size // heads_num
|
||||
self.linear_layers = nn.ModuleList([
|
||||
nn.Linear(hidden_size, hidden_size) for _ in range(3)
|
||||
])
|
||||
|
||||
if relationship_number != -1:
|
||||
assert relationship_number == self.per_head_size
|
||||
self.relationship_K_parameter = nn.Parameter(torch.nn.init.uniform_(torch.Tensor(heads_num, self.per_head_size)))
|
||||
self.relationship_V_parameter = nn.Parameter(torch.nn.init.uniform_(torch.Tensor(heads_num, self.per_head_size)))
|
||||
# cross multi-head
|
||||
#self.relationship_K_parameter = nn.Parameter(torch.nn.init.uniform_(torch.Tensor(1, self.per_head_size)))
|
||||
#self.relationship_V_parameter = nn.Parameter(torch.nn.init.uniform_(torch.Tensor(1, self.per_head_size)))
|
||||
|
||||
|
||||
|
||||
self.final_linear = nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, key, value, query, relationship_matrix, dropout):
|
||||
"""
|
||||
Args:
|
||||
key: [batch_size x seq_length x hidden_size]
|
||||
value: [batch_size x seq_length x hidden_size]
|
||||
query: [batch_size x seq_length x hidden_size]
|
||||
Returns:
|
||||
output: [batch_size x seq_length x hidden_size]
|
||||
"""
|
||||
batch_size, seq_length, hidden_size = key.size()
|
||||
heads_num = self.heads_num
|
||||
per_head_size = self.per_head_size
|
||||
|
||||
def shape(x):
|
||||
return x. \
|
||||
contiguous(). \
|
||||
view(batch_size, seq_length, heads_num, per_head_size). \
|
||||
transpose(1, 2)
|
||||
|
||||
def unshape(x):
|
||||
return x. \
|
||||
transpose(1, 2). \
|
||||
contiguous(). \
|
||||
view(batch_size, seq_length, hidden_size)
|
||||
|
||||
# torch.Size([1, 30, 43, 10]) torch.Size([1, 30, 43, 10]) torch.Size([1, 30, 43, 10])
|
||||
query, key, value = [l(x). \
|
||||
view(batch_size, -1, heads_num, per_head_size). \
|
||||
transpose(1, 2) \
|
||||
for l, x in zip(self.linear_layers, (query, key, value))
|
||||
]
|
||||
#self.relationship_K_parameter [30, 10]
|
||||
|
||||
|
||||
|
||||
# relation ship torch.Size([1, 43, 43, 10])
|
||||
relationship_matrix = relationship_matrix.unsqueeze(3).repeat(1, 1, 1, heads_num, 1)
|
||||
# ↑ torch.Size([1, 43, 43, 30, 10])
|
||||
|
||||
kk = self.relationship_K_parameter.unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, seq_length, 1, 1)
|
||||
#kk = self.relationship_K_parameter.unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, seq_length,1,1)
|
||||
# ↑ torch.Size([1, 43, 43, 30, 10])
|
||||
vv = self.relationship_V_parameter.unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, seq_length, 1, 1)
|
||||
#vv = self.relationship_V_parameter.unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, seq_length,1,1)
|
||||
# ↑ torch.Size([1, 43, 43, 30, 10])
|
||||
|
||||
relationship_K = relationship_matrix * kk
|
||||
relationship_V = relationship_matrix * vv
|
||||
#print(relationship_K[0,17,7,:,:])
|
||||
#print(relationship_V[0,17,7,:,:])
|
||||
|
||||
# [1, 43, 43 ,30 ,10] -> [1, 30, 43, 43, 10]
|
||||
relationship_K = relationship_K.transpose(1, 3)
|
||||
relationship_V = relationship_V.transpose(1, 3)
|
||||
|
||||
#[1, 30, 43, 10] -> [1, 30, 43, 43, 10]
|
||||
query = query.unsqueeze(3).repeat(1, 1, 1, seq_length, 1)
|
||||
key = key.unsqueeze(2).repeat(1, 1, seq_length, 1, 1)
|
||||
|
||||
#↓ [1, 30, 43, 43]
|
||||
#scores = (query*(relationship_K)).sum(dim=-1)
|
||||
scores = (query*(key+relationship_K)).sum(dim=-1)
|
||||
|
||||
scores = scores / math.sqrt(float(per_head_size))
|
||||
|
||||
probs = nn.Softmax(dim=-1)(scores)
|
||||
|
||||
"""
|
||||
save_dct = {}
|
||||
for i in range(30):
|
||||
attention_map = probs[0, i, :, :]
|
||||
save_dct[i] = attention_map.cpu().numpy()
|
||||
import pickle as pkl
|
||||
#print(save_dct)
|
||||
pkl.dump(save_dct, open("./save_att_relation.pkl", "wb"))
|
||||
|
||||
|
||||
print("finish")
|
||||
time.sleep(5)
|
||||
"""
|
||||
#probs = F.dropout(probs, p=dropout)
|
||||
|
||||
probs = probs.unsqueeze(4).repeat(1, 1, 1, 1, per_head_size)
|
||||
value = value.unsqueeze(2).repeat(1, 1, seq_length, 1, 1)
|
||||
|
||||
output = (probs*(value+relationship_V)).sum(dim=-2)
|
||||
|
||||
output = unshape(output)
|
||||
output = self.final_linear(output)
|
||||
return output
|
|
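A shape-only sketch (not taken from the committed files) of the relation-aware attention above. Scores are computed as q·(k + r_K) and outputs as Σ p·(v + r_V), where r_K/r_V are per-head relation embeddings gated by the 0/1 relationship matrix; the sizes below are hypothetical but respect the assertion relationship_number == hidden_size // heads_num:

```python
import torch

hidden, heads, rel = 300, 30, 10          # rel must equal hidden // heads
attn = RatMultiHeadedAttention(hidden, heads, rel)

seq_len = 5
x = torch.randn(1, seq_len, hidden)
relationship_matrix = torch.zeros(1, seq_len, seq_len, rel)
relationship_matrix[0, 0, 1, 0] = 1.0     # e.g. a "same table" relation between items 0 and 1

out = attn(x, x, x, relationship_matrix, dropout=0.0)   # (key, value, query, relations, dropout)
print(out.shape)                                        # torch.Size([1, 5, 300])
```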
@ -0,0 +1,47 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
from .layer_norm import LayerNorm
|
||||
from .position_ffn import PositionwiseFeedForward
|
||||
from .rat_multi_headed_attn import RatMultiHeadedAttention
|
||||
|
||||
class RATTransoformer(nn.Module):
|
||||
"""
|
||||
Transformer layer mainly consists of two parts:
|
||||
multi-headed self-attention and feed forward layer.
|
||||
"""
|
||||
def __init__(self, input_size, state_size, relationship_number):
|
||||
super(RATTransoformer, self).__init__()
|
||||
input_size = int(input_size)
|
||||
assert input_size == state_size
|
||||
|
||||
heads_num = 30
|
||||
|
||||
self.self_attn = RatMultiHeadedAttention(state_size, heads_num, relationship_number)
|
||||
#self.layer_norm_1 = LayerNorm(state_size)
|
||||
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
#self.layer_norm_2 = LayerNorm(state_size)
|
||||
|
||||
|
||||
def forward(self, hidden, relationship_matrix, dropout_amount=0):
|
||||
"""
|
||||
Args:
|
||||
hidden: [batch_size x seq_length x emb_size]
|
||||
Returns:
|
||||
output: [batch_size x seq_length x hidden_size]
|
||||
"""
|
||||
assert dropout_amount == 0
|
||||
batch_size = 1
|
||||
seq_length = hidden.size()[1]
|
||||
|
||||
inter = self.self_attn(hidden, hidden, hidden, relationship_matrix, dropout_amount)
|
||||
output = self.feed_forward(inter)
|
||||
|
||||
# inter = F.dropout(self.self_attn(hidden, hidden, hidden, relationship_matrix, dropout_amount), p=dropout_amount)
|
||||
#inter = self.layer_norm_1(inter + hidden)
|
||||
# output = F.dropout(self.feed_forward(inter), p=dropout_amount)
|
||||
#output = self.layer_norm_2(output + inter)
|
||||
|
||||
return output
|
|
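A quick sketch (not taken from the committed files) of the wrapper above. It hard-codes heads_num = 30, so state_size must be divisible by 30 and relationship_number must equal state_size // 30; with 300-dimensional states and 10 relation channels, 300 // 30 == 10 satisfies the assertion:

```python
import torch

layer = RATTransoformer(input_size=300, state_size=300, relationship_number=10)
hidden = torch.randn(1, 5, 300)                  # [batch, seq_len, state_size]
relationship_matrix = torch.zeros(1, 5, 5, 10)   # [batch, seq_len, seq_len, relations]
out = layer(hidden, relationship_matrix)         # dropout_amount defaults to 0
print(out.shape)                                 # torch.Size([1, 5, 300])
```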
@ -0,0 +1,182 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib as mpl
|
||||
mpl.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from .layer_norm import LayerNorm
|
||||
from .position_ffn import PositionwiseFeedForward
|
||||
from .multi_headed_attn import MultiHeadedAttention
|
||||
from .rat_transformer_layer import RATTransoformer
|
||||
|
||||
class TransformerAttention(nn.Module):
|
||||
def __init__(self, input_size, state_size):
|
||||
super(TransformerAttention, self).__init__()
|
||||
input_size = int(input_size)
|
||||
assert input_size == state_size
|
||||
self.relationship_number = 10
|
||||
self.rat_transoformer = RATTransoformer(input_size, state_size, relationship_number=self.relationship_number)
|
||||
|
||||
|
||||
def forward(self, utterance_states, schema_states, input_sequence, unflat_sequence, input_schema, schema, dropout_amount=0):
|
||||
"""
|
||||
Args:
|
||||
utterance_states : [utterance_len x emb_size]
|
||||
schema_states : [schema_len x emb_size]
|
||||
input_sequence : utterance_len
|
||||
input_schema : schema_len
|
||||
|
||||
Returns:
|
||||
utterance_output: [emb_size x utterance_len]
|
||||
schema_output: [emb_size x schema_len]
|
||||
"""
|
||||
|
||||
assert utterance_states.size()[1] == 350
|
||||
utterance_states = utterance_states[:,:300]
|
||||
|
||||
print('input_sequence', input_sequence)
|
||||
#print('input_schema.column_names_surface_form', input_schema.column_names_surface_form)
|
||||
|
||||
if True:
|
||||
#print('foreign_keys', schema['foreign_keys'])
|
||||
#print('input_schema.column_names_surface_form', input_schema.column_names_surface_form)
|
||||
|
||||
utterance_len = utterance_states.size()[0]
|
||||
schema_len = schema_states.size()[0]
|
||||
#print(utterance_len, schema_len)
|
||||
assert len(input_sequence) == utterance_len
|
||||
assert len(input_schema.column_names_surface_form) == schema_len
|
||||
|
||||
relationship_matrix = torch.zeros([ utterance_len+schema_len, utterance_len+schema_len,self.relationship_number])
|
||||
|
||||
def get_name(s):
|
||||
if s == '*':
|
||||
return '', ''
|
||||
table_name, column_num = s.split('.')
|
||||
|
||||
if column_num == '*':
|
||||
column_num = ''
|
||||
return table_name, column_num
|
||||
|
||||
def is_ngram_match(s1, s2, n):
|
||||
vis = set()
|
||||
for i in range(len(s1)-n+1):
|
||||
vis.add( s1[i:i+n] )
|
||||
for i in range(len(s2)-n+1):
|
||||
if s2[i:i+n] in vis:
|
||||
return True
|
||||
return False
|
||||
|
||||
''' 0. same table
|
||||
1. foreign key
|
||||
2. word and table: exact match
|
||||
3-5. word and table: n-gram 3-5
|
||||
6. word and column: exact match
|
||||
7-9. word and column: n-gram 3-5'''
|
||||
for i in range(schema_len):
|
||||
table_i, column_i = get_name(input_schema.column_names_surface_form[i])
|
||||
idx = i+utterance_len
|
||||
for j in range(schema_len):
|
||||
table_j, column_j = get_name(input_schema.column_names_surface_form[j])
|
||||
jdx = j+utterance_len
|
||||
|
||||
if table_i == table_j:
|
||||
relationship_matrix[idx][jdx][0] = 1
|
||||
|
||||
for i, j in schema['foreign_keys']:
|
||||
idx = i+utterance_len
|
||||
jdx = j+utterance_len
|
||||
|
||||
|
||||
relationship_matrix[idx][jdx][1] = 1
|
||||
relationship_matrix[jdx][idx][1] = 1
|
||||
|
||||
#print(idx, jdx)
|
||||
#print(input_schema.column_names_surface_form)
|
||||
#print(input_schema.column_names_surface_form[i], input_schema.column_names_surface_form[j])
|
||||
|
||||
|
||||
for i in range(schema_len):
|
||||
table, column = get_name(input_schema.column_names_surface_form[i])
|
||||
idx = i+utterance_len
|
||||
for j in range(utterance_len):
|
||||
word = input_sequence[j].lower()
|
||||
jdx = j
|
||||
|
||||
#word and table
|
||||
no_match = True
|
||||
if word == table:
|
||||
relationship_matrix[idx][jdx][2] = 1
|
||||
relationship_matrix[jdx][idx][2] = 1
|
||||
|
||||
for n in [5,4,3]:
|
||||
if no_match and is_ngram_match(word, table, n):
|
||||
relationship_matrix[idx][jdx][n] = 1
|
||||
relationship_matrix[jdx][idx][n] = 1
|
||||
|
||||
#word and column
|
||||
no_match = True
|
||||
if word == column:
|
||||
relationship_matrix[idx][jdx][6] = 1
|
||||
relationship_matrix[jdx][idx][6] = 1
|
||||
|
||||
for n in [5,4,3]:
|
||||
if no_match and is_ngram_match(word, column, n):
|
||||
relationship_matrix[idx][jdx][4+n] = 1
|
||||
relationship_matrix[jdx][idx][4+n] = 1
|
||||
|
||||
|
||||
#print(relationship_matrix)
|
||||
#print(unflat_sequence)
|
||||
|
||||
# start = 0
|
||||
# for i in unflat_sequence:
|
||||
# start = len(i) + start
|
||||
# relationship_matrix[0:start][0:start][:] *= 0.98
|
||||
|
||||
relationship_matrix = relationship_matrix.unsqueeze(0).cuda() #[1, all_len, all_len, 17]
|
||||
|
||||
"""
|
||||
print(relationship_matrix.shape)
|
||||
matrix = torch.zeros_like(relationship_matrix[:,:,:,0]).cuda()
|
||||
print(matrix.shape)
|
||||
for i in range(relationship_matrix.shape[-1]):
|
||||
matrix += relationship_matrix[:,:,:,i]
|
||||
#matrix = matrix.squeeze(0).cpu().numpy()
|
||||
label = input_sequence + input_schema.column_names_surface_form
|
||||
matrix = matrix.squeeze(0).cpu().numpy()
|
||||
save_matrix(matrix, label, 0)
|
||||
"""
|
||||
utterance_and_schema_states = torch.cat([utterance_states, schema_states], dim=0).unsqueeze(0) #[1, all_len, emb_size]
|
||||
utterance_and_schema_output = self.rat_transoformer(utterance_and_schema_states, relationship_matrix, 0) #[1, all_len, emb_size]
|
||||
#print(utterance_and_schema_output.shape)
|
||||
utterance_output = utterance_and_schema_output[0,:utterance_len,:]
|
||||
schema_output = utterance_and_schema_output[0,utterance_len:,:]
|
||||
|
||||
else:
|
||||
utterance_output = utterance_states
|
||||
schema_output = schema_states
|
||||
utterance_output = torch.transpose(utterance_output, 1, 0)
|
||||
schema_output = torch.transpose(schema_output, 1, 0)
|
||||
|
||||
return utterance_output, schema_output
|
||||
|
||||
|
||||
def save_matrix(matrix, label, i):
|
||||
"""
|
||||
matrix: n x n
|
||||
label: len(str) == n
|
||||
"""
|
||||
plt.rcParams["figure.figsize"] = (1000,1000)
|
||||
plt.matshow(matrix, cmap=plt.cm.Reds)
|
||||
plt.xticks(np.arange(len(label)), label, rotation=90, fontsize=16)
|
||||
plt.yticks(np.arange(len(label)), label, fontsize=16)
|
||||
plt.margins(2)
|
||||
plt.subplots_adjust()
|
||||
plt.savefig('relation_' + str(i), dpi=300)
|
||||
exit()
|
|
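For clarity, the relation channels built inside TransformerAttention.forward follow the layout in its inline comment: 0 same table, 1 foreign key, 2 word/table exact match, 3-5 word/table character n-gram match (n = 3..5), 6 word/column exact match, 7-9 word/column n-gram match. The n-gram test is the nested `is_ngram_match` helper, re-stated below as a standalone sketch (not taken from the committed files) so it can be checked in isolation:

```python
def is_ngram_match(s1, s2, n):
    # True if the two strings share any character n-gram of length n
    grams = {s1[i:i + n] for i in range(len(s1) - n + 1)}
    return any(s2[i:i + n] in grams for i in range(len(s2) - n + 1))

assert is_ngram_match("students", "student_id", 5)   # both contain "stude"
assert not is_ngram_match("city", "country", 4)      # no shared 4-gram
```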
@ -0,0 +1,455 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
from .layer_norm import LayerNorm
|
||||
from .position_ffn import PositionwiseFeedForward
|
||||
from .multi_headed_attn import MultiHeadedAttention
|
||||
from ..encoder import Encoder as Encoder2
|
||||
|
||||
def gelu(x):
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
class oldestTransformerLayer(nn.Module): #oldest
|
||||
"""
|
||||
Transformer layer mainly consists of two parts:
|
||||
multi-headed self-attention and feed forward layer.
|
||||
"""
|
||||
def __init__(self, num_layers, input_size, state_size):
|
||||
super(oldestTransformerLayer, self).__init__()
|
||||
# num_layers is not used
|
||||
print('input_size',input_size, type(input_size))
|
||||
print('state_size',state_size, type(state_size))
|
||||
input_size = int(input_size)
|
||||
|
||||
heads_num = -1
|
||||
for i in range(4, 10):
|
||||
if state_size % i == 0:
|
||||
heads_num = i
|
||||
assert heads_num != -1
|
||||
|
||||
self.first_linear = nn.Linear(input_size, state_size)
|
||||
|
||||
self.self_attn = MultiHeadedAttention(state_size, heads_num)
|
||||
self.layer_norm_1 = LayerNorm(state_size)
|
||||
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
self.layer_norm_2 = LayerNorm(state_size)
|
||||
|
||||
self._1_self_attn = MultiHeadedAttention(state_size, heads_num)
|
||||
self._1_layer_norm_1 = LayerNorm(state_size)
|
||||
self._1_feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
self._1_layer_norm_2 = LayerNorm(state_size)
|
||||
self._2_self_attn = MultiHeadedAttention(state_size, heads_num)
|
||||
self._2_layer_norm_1 = LayerNorm(state_size)
|
||||
self._2_feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
self._2_layer_norm_2 = LayerNorm(state_size)
|
||||
'''self._3_self_attn = MultiHeadedAttention(state_size, heads_num)
|
||||
self._3_layer_norm_1 = LayerNorm(state_size)
|
||||
self._3_feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
self._3_layer_norm_2 = LayerNorm(state_size)
|
||||
self._4_self_attn = MultiHeadedAttention(state_size, heads_num)
|
||||
self._4_layer_norm_1 = LayerNorm(state_size)
|
||||
self._4_feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
self._4_layer_norm_2 = LayerNorm(state_size)'''
|
||||
|
||||
self.X = Encoder2(num_layers, input_size, state_size)
|
||||
|
||||
'''self.cell_memories_linear = nn.Linear(1, input_size)
|
||||
self.hidden_states_linear = nn.Linear(1, input_size)'''
|
||||
|
||||
self.last_linear = nn.Linear(state_size*2, state_size)
|
||||
|
||||
|
||||
def forward(self, sequence, embedder, dropout_amount=0):
|
||||
"""
|
||||
Args:
|
||||
hidden: [batch_size x seq_length x emb_size]
|
||||
mask: [batch_size x 1 x seq_length x seq_length]
|
||||
|
||||
Returns:
|
||||
output: [batch_size x seq_length x hidden_size]
|
||||
"""
|
||||
|
||||
batch_size = 1
|
||||
seq_length = len(sequence)
|
||||
hidden = []
|
||||
|
||||
for token in sequence:
|
||||
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
'''hidden.append(self.cell_memories_linear( torch.ones([1,1,1]).cuda() ))
|
||||
hidden.append(self.hidden_states_linear( torch.ones([1,1,1]).cuda() ))'''
|
||||
|
||||
hidden = torch.cat(hidden, 1)
|
||||
mask = torch.zeros([batch_size, 1, seq_length, seq_length]).cuda()
|
||||
|
||||
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
|
||||
|
||||
inter = F.dropout(self._1_self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
|
||||
inter = self._1_layer_norm_1(inter + hidden)
|
||||
output = F.dropout(self._1_feed_forward(inter), p=dropout_amount)
|
||||
hidden = self._1_layer_norm_2(output + inter)
|
||||
|
||||
inter = F.dropout(self._2_self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
|
||||
inter = self._2_layer_norm_1(inter + hidden)
|
||||
output = F.dropout(self._2_feed_forward(inter), p=dropout_amount)
|
||||
hidden = self._2_layer_norm_2(output + inter)
|
||||
|
||||
'''inter = F.dropout(self._3_self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
|
||||
inter = self._3_layer_norm_1(inter + hidden)
|
||||
output = F.dropout(self._3_feed_forward(inter), p=dropout_amount)
|
||||
hidden = self._3_layer_norm_2(output + inter)
|
||||
|
||||
inter = F.dropout(self._4_self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
|
||||
inter = self._4_layer_norm_1(inter + hidden)
|
||||
output = F.dropout(self._4_feed_forward(inter), p=dropout_amount)
|
||||
hidden = self._4_layer_norm_2(output + inter)'''
|
||||
|
||||
|
||||
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
|
||||
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
|
||||
|
||||
final_outputs = []
|
||||
|
||||
#assert len(sequence) + 2 == output.size()[1]
|
||||
|
||||
for i in range(len(sequence)):
|
||||
x = output[0,i,]
|
||||
final_outputs.append(x)
|
||||
|
||||
#cell_memories = [final_outputs[-2]]
|
||||
#hidden_states = [final_outputs[-1]]
|
||||
#final_outputs = final_outputs[:-2]
|
||||
|
||||
#return (cell_memories, hidden_states), final_outputs
|
||||
|
||||
x, final_outputs2 = self.X(sequence, embedder, dropout_amount)
|
||||
|
||||
#print('<', final_outputs[0].mean().item(), final_outputs[0].std().item(), '>', '<', final_outputs2[0].mean().item(), final_outputs2[0].std().item(), '>')
|
||||
|
||||
for i in range(len(sequence)):
|
||||
final_outputs[i] = F.dropout(gelu(self.last_linear(torch.cat( [final_outputs2[i], final_outputs[i]]))), p=dropout_amount)
|
||||
|
||||
|
||||
|
||||
return x, final_outputs
|
||||
|
||||
class NewNewTransformerLayer(nn.Module): #Newnew
|
||||
"""
|
||||
Transformer layer mainly consists of two parts:
|
||||
multi-headed self-attention and feed forward layer.
|
||||
"""
|
||||
def __init__(self, num_layers, input_size, _state_size):
|
||||
super(NewNewTransformerLayer, self).__init__()
|
||||
# num_layers is not used
|
||||
print('input_size', input_size, type(input_size))
|
||||
print('_state_size', _state_size, type(_state_size))
|
||||
if _state_size == 650:
|
||||
state_size = 324
|
||||
else:
|
||||
state_size = _state_size//2
|
||||
input_size = int(input_size)
|
||||
|
||||
heads_num = -1
|
||||
for i in range(4, 10):
|
||||
if state_size % i == 0:
|
||||
heads_num = i
|
||||
assert heads_num != -1
|
||||
|
||||
self.first_linear = nn.Linear(input_size, state_size)
|
||||
|
||||
self.self_attn = MultiHeadedAttention(state_size, heads_num)
|
||||
self.layer_norm_1 = LayerNorm(state_size)
|
||||
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
self.layer_norm_2 = LayerNorm(state_size)
|
||||
|
||||
self.cell_memories_linear = nn.Linear(state_size, state_size)
|
||||
self.hidden_states_linear = nn.Linear(state_size, state_size)
|
||||
|
||||
self.X = Encoder2(num_layers, input_size, _state_size-state_size)
|
||||
|
||||
def forward(self, sequence, embedder, dropout_amount=0):
|
||||
"""
|
||||
Args:
|
||||
hidden: [batch_size x seq_length x emb_size]
|
||||
mask: [batch_size x 1 x seq_length x seq_length]
|
||||
|
||||
Returns:
|
||||
output: [batch_size x seq_length x hidden_size]
|
||||
"""
|
||||
|
||||
batch_size = 1
|
||||
seq_length = len(sequence)
|
||||
hidden = []
|
||||
|
||||
for token in sequence:
|
||||
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
|
||||
hidden = torch.cat(hidden, 1)
|
||||
mask = torch.zeros([batch_size, 1, seq_length, seq_length]).cuda()
|
||||
|
||||
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
|
||||
|
||||
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
|
||||
inter = self.layer_norm_1(inter + hidden)
|
||||
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
|
||||
output = self.layer_norm_2(output + inter)
|
||||
|
||||
final_outputs = []
|
||||
|
||||
assert len(sequence) == output.size()[1]
|
||||
|
||||
for i in range(len(sequence)):
|
||||
x = output[0,i,]
|
||||
final_outputs.append(x)
|
||||
|
||||
x = output[0].mean(dim=0)
|
||||
cell_memories = self.cell_memories_linear(x)
|
||||
cell_memories = [F.dropout(gelu(cell_memories), p=dropout_amount)]
|
||||
hidden_states = self.hidden_states_linear(x)
|
||||
hidden_states = [F.dropout(gelu(hidden_states), p=dropout_amount)]
|
||||
|
||||
'''print('hidden_states[0]', hidden_states[0].size())
|
||||
print('cell_memories[0]', cell_memories[0].size())
|
||||
print('final_outputs[0]', final_outputs[0].size())'''
|
||||
|
||||
x, final_outputs2 = self.X(sequence, embedder, dropout_amount)
|
||||
|
||||
cell_memories2 = x[0]
|
||||
hidden_states2 = x[1]
|
||||
|
||||
for i in range(len(sequence)):
|
||||
final_outputs[i] = torch.cat( [final_outputs[i], final_outputs2[i]])
|
||||
|
||||
cell_memories[0] = torch.cat( [cell_memories[0], cell_memories2[0]])
|
||||
hidden_states[0] = torch.cat( [hidden_states[0], hidden_states2[0]])
|
||||
|
||||
'''print('hidden_states[0]', hidden_states[0].size())
|
||||
print('cell_memories[0]', cell_memories[0].size())
|
||||
print('final_outputs[0]', final_outputs[0].size())'''
|
||||
|
||||
return (cell_memories, hidden_states), final_outputs
|
||||
|
||||
class TransformerLayer(nn.Module): #New
|
||||
"""
|
||||
Transformer layer mainly consists of two parts:
|
||||
multi-headed self-attention and feed forward layer.
|
||||
"""
|
||||
def __init__(self, num_layers, input_size, state_size):
|
||||
super(TransformerLayer, self).__init__()
|
||||
# num_layers is not used
|
||||
print('input_size',input_size, type(input_size))
|
||||
print('state_size',state_size, type(state_size))
|
||||
input_size = int(input_size)
|
||||
|
||||
heads_num = -1
|
||||
for i in range(4, 10):
|
||||
if state_size % i == 0:
|
||||
heads_num = i
|
||||
assert heads_num != -1
|
||||
|
||||
self.first_linear = nn.Linear(input_size, state_size)
|
||||
|
||||
self.self_attn = MultiHeadedAttention(state_size, heads_num)
|
||||
self.layer_norm_1 = LayerNorm(state_size)
|
||||
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
self.layer_norm_2 = LayerNorm(state_size)
|
||||
|
||||
self.X = Encoder2(num_layers, input_size, state_size)
|
||||
self.last_linear = nn.Linear(state_size*2, state_size)
|
||||
|
||||
def forward(self, sequence, embedder, dropout_amount=0):
|
||||
"""
|
||||
Args:
|
||||
hidden: [batch_size x seq_length x emb_size]
|
||||
mask: [batch_size x 1 x seq_length x seq_length]
|
||||
|
||||
Returns:
|
||||
output: [batch_size x seq_length x hidden_size]
|
||||
"""
|
||||
|
||||
batch_size = 1
|
||||
seq_length = len(sequence)
|
||||
hidden = []
|
||||
|
||||
for token in sequence:
|
||||
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
|
||||
hidden = torch.cat(hidden, 1)
|
||||
mask = torch.zeros([batch_size, 1, seq_length, seq_length]).cuda()
|
||||
|
||||
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
|
||||
|
||||
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
|
||||
inter = self.layer_norm_1(inter + hidden)
|
||||
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
|
||||
output = self.layer_norm_2(output + inter)
|
||||
|
||||
final_outputs = []
|
||||
|
||||
assert len(sequence) == output.size()[1]
|
||||
|
||||
for i in range(len(sequence)):
|
||||
x = output[0,i,]
|
||||
final_outputs.append(x)
|
||||
|
||||
x, final_outputs2 = self.X(sequence, embedder, dropout_amount)
|
||||
|
||||
for i in range(len(sequence)):
|
||||
#final_outputs[i] = F.dropout(final_outputs2[i]+final_outputs[i], p=dropout_amount)
|
||||
final_outputs[i] = F.dropout(gelu(self.last_linear(torch.cat( [final_outputs2[i], final_outputs[i]]))), p=dropout_amount)
|
||||
|
||||
return x, final_outputs
|
||||
|
||||
class Old2only2TransformerLayer(nn.Module): #Old2 2020 05 24 01.47
|
||||
"""
|
||||
Transformer layer mainly consists of two parts:
|
||||
multi-headed self-attention and feed forward layer.
|
||||
"""
|
||||
def __init__(self, num_layers, input_size, state_size):
|
||||
super(Old2only2TransformerLayer, self).__init__()
|
||||
# num_layers is not used
|
||||
print('input_size',input_size, type(input_size))
|
||||
print('state_size',state_size, type(state_size))
|
||||
input_size = int(input_size)
|
||||
|
||||
heads_num = -1
|
||||
for i in range(4, 10):
|
||||
if state_size % i == 0:
|
||||
heads_num = i
|
||||
assert heads_num != -1
|
||||
|
||||
self.first_linear = nn.Linear(input_size, state_size)
|
||||
|
||||
self.self_attn = MultiHeadedAttention(state_size, heads_num)
|
||||
self.layer_norm_1 = LayerNorm(state_size)
|
||||
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
self.layer_norm_2 = LayerNorm(state_size)
|
||||
|
||||
#self.cell_memories_linear = nn.Linear(state_size, state_size)
|
||||
#self.hidden_states_linear = nn.Linear(state_size, state_size)
|
||||
|
||||
#self.last_linear = nn.Linear(state_size+state_size, state_size)
|
||||
|
||||
def forward(self, sequence, embedder, dropout_amount=0):
|
||||
"""
|
||||
Args:
|
||||
hidden: [batch_size x seq_length x emb_size]
|
||||
mask: [batch_size x 1 x seq_length x seq_length]
|
||||
|
||||
Returns:
|
||||
output: [batch_size x seq_length x hidden_size]
|
||||
"""
|
||||
|
||||
batch_size = 1
|
||||
seq_length = len(sequence)
|
||||
hidden = []
|
||||
|
||||
for token in sequence:
|
||||
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
hidden = torch.cat(hidden, 1)
|
||||
mask = torch.zeros([batch_size, 1, seq_length, seq_length]).cuda()
|
||||
|
||||
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
|
||||
|
||||
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
|
||||
inter = self.layer_norm_1(inter + hidden)
|
||||
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
|
||||
output = self.layer_norm_2(output + inter)
|
||||
|
||||
final_outputs = []
|
||||
|
||||
assert len(sequence) == output.size()[1]
|
||||
|
||||
#output = F.dropout(gelu( self.last_linear(torch.cat([output, hidden], dim=-1)) ), p=dropout_amount)
|
||||
|
||||
for i in range(len(sequence)):
|
||||
x = output[0,i,]
|
||||
final_outputs.append(x)
|
||||
|
||||
#x = output[0].mean(dim=0)
|
||||
#cell_memories = self.cell_memories_linear(x)
|
||||
#cell_memories = [F.dropout(gelu(cell_memories), p=dropout_amount)]
|
||||
#hidden_states = self.hidden_states_linear(x)
|
||||
#hidden_states = [F.dropout(gelu(hidden_states), p=dropout_amount)]
|
||||
|
||||
|
||||
#return (cell_memories, hidden_states), final_outputs
|
||||
return None, final_outputs
|
||||
|
||||
class OldTransformerLayer(nn.Module):
|
||||
"""
|
||||
Transformer layer mainly consists of two parts:
|
||||
multi-headed self-attention and feed forward layer.
|
||||
"""
|
||||
def __init__(self, num_layers, input_size, state_size):
|
||||
super(OldTransformerLayer, self).__init__()
|
||||
# num_layers is not used
|
||||
print('input_size',input_size, type(input_size))
|
||||
print('state_size',state_size, type(state_size))
|
||||
input_size = int(input_size)
|
||||
|
||||
heads_num = -1
|
||||
for i in range(4, 10):
|
||||
if state_size % i == 0:
|
||||
heads_num = i
|
||||
assert heads_num != -1
|
||||
|
||||
self.first_linear = nn.Linear(input_size, state_size)
|
||||
|
||||
self.self_attn = MultiHeadedAttention(state_size, heads_num)
|
||||
self.layer_norm_1 = LayerNorm(state_size)
|
||||
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
|
||||
self.layer_norm_2 = LayerNorm(state_size)
|
||||
|
||||
self.cell_memories_linear = nn.Linear(1, input_size)
|
||||
self.hidden_states_linear = nn.Linear(1, input_size)
|
||||
|
||||
|
||||
def forward(self, sequence, embedder, dropout_amount=0):
|
||||
"""
|
||||
Args:
|
||||
hidden: [batch_size x seq_length x emb_size]
|
||||
mask: [batch_size x 1 x seq_length x seq_length]
|
||||
|
||||
Returns:
|
||||
output: [batch_size x seq_length x hidden_size]
|
||||
"""
|
||||
|
||||
batch_size = 1
|
||||
seq_length = len(sequence)
|
||||
hidden = []
|
||||
|
||||
for token in sequence:
|
||||
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
hidden.append(self.cell_memories_linear( torch.ones([1,1,1]).cuda() ))
|
||||
hidden.append(self.hidden_states_linear( torch.ones([1,1,1]).cuda() ))
|
||||
|
||||
hidden = torch.cat(hidden, 1)
|
||||
mask = torch.zeros([batch_size, 1, seq_length+2, seq_length+2]).cuda()
|
||||
|
||||
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
|
||||
|
||||
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
|
||||
inter = self.layer_norm_1(inter + hidden)
|
||||
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
|
||||
output = self.layer_norm_2(output + inter)
|
||||
|
||||
final_outputs = []
|
||||
|
||||
assert len(sequence) + 2 == output.size()[1]
|
||||
|
||||
for i in range(len(sequence)+2):
|
||||
x = output[0,i,]
|
||||
final_outputs.append(x)
|
||||
|
||||
cell_memories = [final_outputs[-2]]
|
||||
hidden_states = [final_outputs[-1]]
|
||||
final_outputs = final_outputs[:-2]
|
||||
|
||||
return (cell_memories, hidden_states), final_outputs
|
|
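All of the TransformerLayer variants above pick the number of attention heads with the same loop, which ends up keeping the largest divisor of state_size between 4 and 9. A standalone re-statement of that selection (a sketch, not taken from the committed files):

```python
def pick_heads_num(state_size):
    # mirrors the loop in the __init__ methods above: the last divisor found wins
    heads_num = -1
    for i in range(4, 10):
        if state_size % i == 0:
            heads_num = i
    assert heads_num != -1, "state_size has no divisor in 4..9"
    return heads_num

print(pick_heads_num(300))   # 6
print(pick_heads_num(324))   # 9 (the size substituted when _state_size == 650)
```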
@ -0,0 +1,398 @@
|
|||
# modified from https://github.com/naver/sqlova
|
||||
|
||||
import os, json
|
||||
import random as rd
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .bert import tokenization as tokenization
|
||||
from .bert.modeling import BertConfig, BertModel
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def get_bert(params):
|
||||
BERT_PT_PATH = './model/bert/data/annotated_wikisql_and_PyTorch_bert_param'
|
||||
map_bert_type_abb = {'uS': 'uncased_L-12_H-768_A-12',
|
||||
'uL': 'uncased_L-24_H-1024_A-16',
|
||||
'cS': 'cased_L-12_H-768_A-12',
|
||||
'cL': 'cased_L-24_H-1024_A-16',
|
||||
'mcS': 'multi_cased_L-12_H-768_A-12'}
|
||||
bert_type = map_bert_type_abb[params.bert_type_abb]
|
||||
if params.bert_type_abb in ('cS', 'cL', 'mcS'):
|
||||
do_lower_case = False
|
||||
else:
|
||||
do_lower_case = True
|
||||
no_pretraining = False
|
||||
|
||||
bert_config_file = os.path.join(BERT_PT_PATH, f'bert_config_{bert_type}.json')
|
||||
vocab_file = os.path.join(BERT_PT_PATH, f'vocab_{bert_type}.txt')
|
||||
init_checkpoint = os.path.join(BERT_PT_PATH, f'pytorch_model_{bert_type}.bin')
|
||||
|
||||
print('bert_config_file', bert_config_file)
|
||||
print('vocab_file', vocab_file)
|
||||
print('init_checkpoint', init_checkpoint)
|
||||
|
||||
bert_config = BertConfig.from_json_file(bert_config_file)
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=vocab_file, do_lower_case=do_lower_case)
|
||||
bert_config.print_status()
|
||||
|
||||
model_bert = BertModel(bert_config)
|
||||
if no_pretraining:
|
||||
pass
|
||||
else:
|
||||
model_bert.load_state_dict(torch.load(init_checkpoint, map_location='cpu'))
|
||||
print("Load pre-trained parameters.")
|
||||
model_bert.to(device)
|
||||
|
||||
return model_bert, tokenizer, bert_config
|
||||
|
||||
def generate_inputs(tokenizer, nlu1_tok, hds1):
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
|
||||
t_to_tt_idx_hds1 = []
|
||||
|
||||
tokens.append("[CLS]")
|
||||
i_st_nlu = len(tokens) # to use it later
|
||||
|
||||
segment_ids.append(0)
|
||||
for token in nlu1_tok:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
i_ed_nlu = len(tokens)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
i_hds = []
|
||||
for i, hds11 in enumerate(hds1):
|
||||
i_st_hd = len(tokens)
|
||||
t_to_tt_idx_hds11 = []
|
||||
sub_tok = []
|
||||
for sub_tok1 in hds11.split():
|
||||
t_to_tt_idx_hds11.append(len(sub_tok))
|
||||
sub_tok += tokenizer.tokenize(sub_tok1)
|
||||
t_to_tt_idx_hds1.append(t_to_tt_idx_hds11)
|
||||
tokens += sub_tok
|
||||
|
||||
i_ed_hd = len(tokens)
|
||||
i_hds.append((i_st_hd, i_ed_hd))
|
||||
segment_ids += [1] * len(sub_tok)
|
||||
if i < len(hds1)-1:
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
elif i == len(hds1)-1:
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
else:
|
||||
raise EnvironmentError
|
||||
|
||||
i_nlu = (i_st_nlu, i_ed_nlu)
|
||||
|
||||
return tokens, segment_ids, i_nlu, i_hds, t_to_tt_idx_hds1
|
||||
|
||||
def gen_l_hpu(i_hds):
|
||||
"""
|
||||
# Treat columns as if they were a batch of natural language utterances with batch size = (# of columns) * (batch size)
|
||||
i_hds = [(17, 18), (19, 21), (22, 23), (24, 25), (26, 29), (30, 34)])
|
||||
"""
|
||||
l_hpu = []
|
||||
for i_hds1 in i_hds:
|
||||
for i_hds11 in i_hds1:
|
||||
l_hpu.append(i_hds11[1] - i_hds11[0])
|
||||
|
||||
return l_hpu
|
||||
|
||||
def get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length):
|
||||
"""
|
||||
Here, the input is further tokenized by the WordPiece (WP) tokenizer and fed into BERT.
|
||||
|
||||
INPUT
|
||||
:param model_bert:
|
||||
:param tokenizer: WordPiece tokenizer
|
||||
:param nlu: Question
|
||||
:param nlu_t: CoreNLP tokenized nlu.
|
||||
:param hds: Headers
|
||||
:param hs_t: None or 1st-level tokenized headers
|
||||
:param max_seq_length: max input token length
|
||||
|
||||
OUTPUT
|
||||
tokens: BERT input tokens
|
||||
nlu_tt: WP-tokenized input natural language questions
|
||||
orig_to_tok_index: map the index of 1st-level-token to the index of 2nd-level-token
|
||||
tok_to_orig_index: inverse map.
|
||||
|
||||
"""
|
||||
|
||||
l_n = []
|
||||
l_hs = [] # The length of columns for each batch
|
||||
|
||||
input_ids = []
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
input_mask = []
|
||||
|
||||
i_nlu = [] # index to retrieve the position of contextual vector later.
|
||||
i_hds = []
|
||||
|
||||
doc_tokens = []
|
||||
nlu_tt = []
|
||||
|
||||
t_to_tt_idx = []
|
||||
tt_to_t_idx = []
|
||||
|
||||
t_to_tt_idx_hds = []
|
||||
|
||||
for b, nlu_t1 in enumerate(nlu_t):
|
||||
hds1 = hds[b]
|
||||
l_hs.append(len(hds1))
|
||||
|
||||
# 1. 2nd tokenization using WordPiece
|
||||
tt_to_t_idx1 = [] # each entry gives the index of the 1st-level token (here, CoreNLP) that the sub-token belongs to.
|
||||
t_to_tt_idx1 = [] # orig_to_tok_idx[i] = start index of i-th-1st-level-token in all_tokens.
|
||||
nlu_tt1 = [] # all_doc_tokens[ orig_to_tok_idx[i] ] returns the first sub-token segment of the i-th 1st-level token
|
||||
for (i, token) in enumerate(nlu_t1):
|
||||
t_to_tt_idx1.append(
|
||||
len(nlu_tt1)) # indicates the start position of the original 'white-space' token among the sub-tokens.
|
||||
sub_tokens = tokenizer.tokenize(token)
|
||||
for sub_token in sub_tokens:
|
||||
tt_to_t_idx1.append(i)
|
||||
nlu_tt1.append(sub_token) # all_doc_tokens are further tokenized using WordPiece tokenizer
|
||||
nlu_tt.append(nlu_tt1)
|
||||
tt_to_t_idx.append(tt_to_t_idx1)
|
||||
t_to_tt_idx.append(t_to_tt_idx1)
|
||||
|
||||
l_n.append(len(nlu_tt1))
|
||||
|
||||
# [CLS] nlu [SEP] col1 [SEP] col2 [SEP] ...col-n [SEP]
|
||||
# 2. Generate BERT inputs & indices.
|
||||
tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, hds1)
|
||||
|
||||
assert len(t_to_tt_idx_hds1) == len(hds1)
|
||||
|
||||
t_to_tt_idx_hds.append(t_to_tt_idx_hds1)
|
||||
|
||||
input_ids1 = tokenizer.convert_tokens_to_ids(tokens1)
|
||||
|
||||
# Input masks
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask1 = [1] * len(input_ids1)
|
||||
|
||||
# 3. Zero-pad up to the sequence length.
|
||||
if len(nlu_t) == 1:
|
||||
max_seq_length = len(input_ids1)
|
||||
while len(input_ids1) < max_seq_length:
|
||||
input_ids1.append(0)
|
||||
input_mask1.append(0)
|
||||
segment_ids1.append(0)
|
||||
|
||||
assert len(input_ids1) == max_seq_length
|
||||
assert len(input_mask1) == max_seq_length
|
||||
assert len(segment_ids1) == max_seq_length
|
||||
|
||||
input_ids.append(input_ids1)
|
||||
tokens.append(tokens1)
|
||||
segment_ids.append(segment_ids1)
|
||||
input_mask.append(input_mask1)
|
||||
|
||||
i_nlu.append(i_nlu1)
|
||||
i_hds.append(i_hds1)
|
||||
|
||||
# Convert to tensor
|
||||
all_input_ids = torch.tensor(input_ids, dtype=torch.long).to(device)
|
||||
all_input_mask = torch.tensor(input_mask, dtype=torch.long).to(device)
|
||||
all_segment_ids = torch.tensor(segment_ids, dtype=torch.long).to(device)
|
||||
|
||||
# 4. Generate BERT output.
|
||||
all_encoder_layer, pooled_output = model_bert(all_input_ids, all_segment_ids, all_input_mask)
|
||||
|
||||
# 5. generate l_hpu from i_hds
|
||||
l_hpu = gen_l_hpu(i_hds)
|
||||
|
||||
assert len(set(l_n)) == 1 and len(set(i_nlu)) == 1
|
||||
assert l_n[0] == i_nlu[0][1] - i_nlu[0][0]
|
||||
|
||||
return all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \
|
||||
l_n, l_hpu, l_hs, \
|
||||
nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds
|
||||
|
||||
def get_wemb_n(i_nlu, l_n, hS, num_hidden_layers, all_encoder_layer, num_out_layers_n):
|
||||
"""
|
||||
Get the representation of each tokens.
|
||||
"""
|
||||
bS = len(l_n)
|
||||
l_n_max = max(l_n)
|
||||
# print('wemb_n: [bS, l_n_max, hS * num_out_layers_n] = ', bS, l_n_max, hS * num_out_layers_n)
|
||||
wemb_n = torch.zeros([bS, l_n_max, hS * num_out_layers_n]).to(device)
|
||||
for b in range(bS):
|
||||
# [B, max_len, dim]
|
||||
# Fill zero for non-exist part.
|
||||
l_n1 = l_n[b]
|
||||
i_nlu1 = i_nlu[b]
|
||||
for i_noln in range(num_out_layers_n):
|
||||
i_layer = num_hidden_layers - 1 - i_noln
|
||||
st = i_noln * hS
|
||||
ed = (i_noln + 1) * hS
|
||||
wemb_n[b, 0:(i_nlu1[1] - i_nlu1[0]), st:ed] = all_encoder_layer[i_layer][b, i_nlu1[0]:i_nlu1[1], :]
|
||||
return wemb_n
|
||||
|
||||
def get_wemb_h(i_hds, l_hpu, l_hs, hS, num_hidden_layers, all_encoder_layer, num_out_layers_h):
|
||||
"""
|
||||
As if
|
||||
[ [table-1-col-1-tok1, t1-c1-t2, ...],
|
||||
[t1-c2-t1, t1-c2-t2, ...].
|
||||
...
|
||||
[t2-c1-t1, ...,]
|
||||
]
|
||||
"""
|
||||
bS = len(l_hs)
|
||||
l_hpu_max = max(l_hpu)
|
||||
num_of_all_hds = sum(l_hs)
|
||||
wemb_h = torch.zeros([num_of_all_hds, l_hpu_max, hS * num_out_layers_h]).to(device)
|
||||
# print('wemb_h: [num_of_all_hds, l_hpu_max, hS * num_out_layers_h] = ', wemb_h.size())
|
||||
b_pu = -1
|
||||
for b, i_hds1 in enumerate(i_hds):
|
||||
for b1, i_hds11 in enumerate(i_hds1):
|
||||
b_pu += 1
|
||||
for i_nolh in range(num_out_layers_h):
|
||||
i_layer = num_hidden_layers - 1 - i_nolh
|
||||
st = i_nolh * hS
|
||||
ed = (i_nolh + 1) * hS
|
||||
wemb_h[b_pu, 0:(i_hds11[1] - i_hds11[0]), st:ed] \
|
||||
= all_encoder_layer[i_layer][b, i_hds11[0]:i_hds11[1],:]
|
||||
|
||||
|
||||
return wemb_h
|
||||
|
||||
def get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length, num_out_layers_n=1, num_out_layers_h=1):
|
||||
|
||||
# get contextual output of all tokens from bert
|
||||
all_encoder_layer, pooled_output, tokens, i_nlu, i_hds,\
|
||||
l_n, l_hpu, l_hs, \
|
||||
nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds = get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length)
|
||||
# all_encoder_layer: BERT outputs from all layers.
|
||||
# pooled_output: output of [CLS] vec.
|
||||
# tokens: BERT input tokens
|
||||
# i_nlu: start and end indices of question in tokens
|
||||
# i_hds: start and end indices of headers
|
||||
|
||||
# get the wemb
|
||||
wemb_n = get_wemb_n(i_nlu, l_n, bert_config.hidden_size, bert_config.num_hidden_layers, all_encoder_layer,
|
||||
num_out_layers_n)
|
||||
|
||||
wemb_h = get_wemb_h(i_hds, l_hpu, l_hs, bert_config.hidden_size, bert_config.num_hidden_layers, all_encoder_layer,
|
||||
num_out_layers_h)
|
||||
|
||||
return wemb_n, wemb_h, l_n, l_hpu, l_hs, \
|
||||
nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds
|
||||
|
||||
def prepare_input(tokenizer, input_sequence, input_schema, max_seq_length):
|
||||
nlu_t = []
|
||||
hds = []
|
||||
|
||||
nlu_t1 = input_sequence
|
||||
all_hds = input_schema.column_names_embedder_input
|
||||
|
||||
nlu_tt1 = []
|
||||
for (i, token) in enumerate(nlu_t1):
|
||||
nlu_tt1 += tokenizer.tokenize(token)
|
||||
|
||||
current_hds1 = []
|
||||
for hds1 in all_hds:
|
||||
new_hds1 = current_hds1 + [hds1]
|
||||
tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, new_hds1)
|
||||
if len(segment_ids1) > max_seq_length:
|
||||
nlu_t.append(nlu_t1)
|
||||
hds.append(current_hds1)
|
||||
current_hds1 = [hds1]
|
||||
else:
|
||||
current_hds1 = new_hds1
|
||||
|
||||
if len(current_hds1) > 0:
|
||||
nlu_t.append(nlu_t1)
|
||||
hds.append(current_hds1)
|
||||
|
||||
return nlu_t, hds
|
||||
|
||||
def prepare_input_v2(tokenizer, input_sequence, input_schema):
|
||||
nlu_t = []
|
||||
hds = []
|
||||
max_seq_length = 0
|
||||
|
||||
nlu_t1 = input_sequence
|
||||
all_hds = input_schema.column_names_embedder_input
|
||||
|
||||
nlu_tt1 = []
|
||||
for (i, token) in enumerate(nlu_t1):
|
||||
nlu_tt1 += tokenizer.tokenize(token)
|
||||
|
||||
current_hds1 = []
|
||||
current_table = ''
|
||||
for hds1 in all_hds:
|
||||
hds1_table = hds1.split('.')[0].strip()
|
||||
if hds1_table == current_table:
|
||||
current_hds1.append(hds1)
|
||||
else:
|
||||
tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, current_hds1)
|
||||
max_seq_length = max(max_seq_length, len(segment_ids1))
|
||||
|
||||
nlu_t.append(nlu_t1)
|
||||
hds.append(current_hds1)
|
||||
current_hds1 = [hds1]
|
||||
current_table = hds1_table
|
||||
|
||||
if len(current_hds1) > 0:
|
||||
tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, current_hds1)
|
||||
max_seq_length = max(max_seq_length, len(segment_ids1))
|
||||
nlu_t.append(nlu_t1)
|
||||
hds.append(current_hds1)
|
||||
|
||||
return nlu_t, hds, max_seq_length
|
||||
|
||||
def get_bert_encoding(bert_config, model_bert, tokenizer, input_sequence, input_schema, bert_input_version='v1', max_seq_length=512, num_out_layers_n=1, num_out_layers_h=1):
|
||||
if bert_input_version == 'v1':
|
||||
nlu_t, hds = prepare_input(tokenizer, input_sequence, input_schema, max_seq_length)
|
||||
elif bert_input_version == 'v2':
|
||||
nlu_t, hds, max_seq_length = prepare_input_v2(tokenizer, input_sequence, input_schema)
|
||||
|
||||
wemb_n, wemb_h, l_n, l_hpu, l_hs, nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length, num_out_layers_n, num_out_layers_h)
|
||||
|
||||
t_to_tt_idx = t_to_tt_idx[0]
|
||||
assert len(t_to_tt_idx) == len(input_sequence)
|
||||
assert sum(len(t_to_tt_idx_hds1) for t_to_tt_idx_hds1 in t_to_tt_idx_hds) == len(input_schema.column_names_embedder_input)
|
||||
|
||||
assert list(wemb_h.size())[0] == len(input_schema.column_names_embedder_input)
|
||||
|
||||
utterance_states = []
|
||||
for i in range(len(t_to_tt_idx)):
|
||||
start = t_to_tt_idx[i]
|
||||
if i == len(t_to_tt_idx)-1:
|
||||
end = l_n[0]
|
||||
else:
|
||||
end = t_to_tt_idx[i+1]
|
||||
utterance_states.append(torch.mean(wemb_n[:,start:end,:], dim=[0,1]))
|
||||
assert len(utterance_states) == len(input_sequence)
|
||||
|
||||
schema_token_states = []
|
||||
cnt = -1
|
||||
for t_to_tt_idx_hds1 in t_to_tt_idx_hds:
|
||||
for t_to_tt_idx_hds11 in t_to_tt_idx_hds1:
|
||||
cnt += 1
|
||||
schema_token_states1 = []
|
||||
for i in range(len(t_to_tt_idx_hds11)):
|
||||
start = t_to_tt_idx_hds11[i]
|
||||
if i == len(t_to_tt_idx_hds11)-1:
|
||||
end = l_hpu[cnt]
|
||||
else:
|
||||
end = t_to_tt_idx_hds11[i+1]
|
||||
schema_token_states1.append(torch.mean(wemb_h[cnt,start:end,:], dim=0))
|
||||
assert len(schema_token_states1) == len(input_schema.column_names_embedder_input[cnt].split())
|
||||
schema_token_states.append(schema_token_states1)
|
||||
|
||||
assert len(schema_token_states) == len(input_schema.column_names_embedder_input)
|
||||
|
||||
return utterance_states, schema_token_states
|
|
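The BERT wrapper above packs each example as `[CLS] question [SEP] col1 [SEP] col2 [SEP] ...`. A toy sketch (not taken from the committed files) of `generate_inputs`, with a hypothetical whitespace "tokenizer" standing in for the real WordPiece `FullTokenizer`:

```python
class ToyTokenizer:
    def tokenize(self, text):
        # hypothetical stand-in: the real code uses WordPiece sub-tokenization
        return text.lower().split()

tokens, segment_ids, i_nlu, i_hds, t_to_tt_idx_hds = generate_inputs(
    ToyTokenizer(),
    ["show", "all", "students"],            # question, already tokenized
    ["student . name", "student . age"])    # header strings
print(tokens)
# ['[CLS]', 'show', 'all', 'students', '[SEP]', 'student', '.', 'name', '[SEP]',
#  'student', '.', 'age', '[SEP]']
print(segment_ids)   # [0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1]
print(i_nlu, i_hds)  # (1, 4) [(5, 8), (9, 12)]
```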
@ -0,0 +1,609 @@
|
|||
"""Basic model training and evaluation functions."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import random
|
||||
import sys
|
||||
import json
|
||||
import progressbar
|
||||
import model.torch_utils
|
||||
import data_util.sql_util
|
||||
import torch
|
||||
|
||||
def write_prediction(fileptr,
|
||||
identifier,
|
||||
input_seq,
|
||||
probability,
|
||||
prediction,
|
||||
flat_prediction,
|
||||
gold_query,
|
||||
flat_gold_queries,
|
||||
gold_tables,
|
||||
index_in_interaction,
|
||||
database_username,
|
||||
database_password,
|
||||
database_timeout,
|
||||
compute_metrics=True,
|
||||
beam=None):
|
||||
pred_obj = {}
|
||||
pred_obj["identifier"] = identifier
|
||||
if len(identifier.split('/')) == 2:
|
||||
database_id, interaction_id = identifier.split('/')
|
||||
else:
|
||||
database_id = 'atis'
|
||||
interaction_id = identifier
|
||||
pred_obj["database_id"] = database_id
|
||||
pred_obj["interaction_id"] = interaction_id
|
||||
|
||||
pred_obj["input_seq"] = input_seq
|
||||
pred_obj["probability"] = probability
|
||||
pred_obj["prediction"] = prediction
|
||||
pred_obj["flat_prediction"] = flat_prediction
|
||||
pred_obj["gold_query"] = gold_query
|
||||
pred_obj["flat_gold_queries"] = flat_gold_queries
|
||||
pred_obj["index_in_interaction"] = index_in_interaction
|
||||
pred_obj["gold_tables"] = str(gold_tables)
|
||||
|
||||
pred_obj["beam"] = beam
|
||||
|
||||
# Now compute the metrics we want.
|
||||
|
||||
if compute_metrics:
|
||||
# First metric: whether flat predicted query is in the gold query set.
|
||||
correct_string = " ".join(flat_prediction) in [
|
||||
" ".join(q) for q in flat_gold_queries]
|
||||
pred_obj["correct_string"] = correct_string
|
||||
|
||||
# Database metrics
|
||||
if not correct_string:
|
||||
syntactic, semantic, pred_table = sql_util.execution_results(
|
||||
" ".join(flat_prediction), database_username, database_password, database_timeout)
|
||||
pred_table = sorted(pred_table)
|
||||
best_prec = 0.
|
||||
best_rec = 0.
|
||||
best_f1 = 0.
|
||||
|
||||
for gold_table in gold_tables:
|
||||
num_overlap = float(len(set(pred_table) & set(gold_table)))
|
||||
|
||||
if len(set(gold_table)) > 0:
|
||||
prec = num_overlap / len(set(gold_table))
|
||||
else:
|
||||
prec = 1.
|
||||
|
||||
if len(set(pred_table)) > 0:
|
||||
rec = num_overlap / len(set(pred_table))
|
||||
else:
|
||||
rec = 1.
|
||||
|
||||
if prec > 0. and rec > 0.:
|
||||
f1 = (2 * (prec * rec)) / (prec + rec)
|
||||
else:
|
||||
f1 = 1.
|
||||
|
||||
best_prec = max(best_prec, prec)
|
||||
best_rec = max(best_rec, rec)
|
||||
best_f1 = max(best_f1, f1)
|
||||
|
||||
else:
|
||||
syntactic = True
|
||||
semantic = True
|
||||
pred_table = []
|
||||
best_prec = 1.
|
||||
best_rec = 1.
|
||||
best_f1 = 1.
|
||||
|
||||
assert best_prec <= 1.
|
||||
assert best_rec <= 1.
|
||||
assert best_f1 <= 1.
|
||||
pred_obj["syntactic"] = syntactic
|
||||
pred_obj["semantic"] = semantic
|
||||
correct_table = (pred_table in gold_tables) or correct_string
|
||||
pred_obj["correct_table"] = correct_table
|
||||
pred_obj["strict_correct_table"] = correct_table and syntactic
|
||||
pred_obj["pred_table"] = str(pred_table)
|
||||
pred_obj["table_prec"] = best_prec
|
||||
pred_obj["table_rec"] = best_rec
|
||||
pred_obj["table_f1"] = best_f1
|
||||
|
||||
fileptr.write(json.dumps(pred_obj) + "\n")
|
||||
|
||||
class Metrics(Enum):
|
||||
"""Definitions of simple metrics to compute."""
|
||||
LOSS = 1
|
||||
TOKEN_ACCURACY = 2
|
||||
STRING_ACCURACY = 3
|
||||
CORRECT_TABLES = 4
|
||||
STRICT_CORRECT_TABLES = 5
|
||||
SEMANTIC_QUERIES = 6
|
||||
SYNTACTIC_QUERIES = 7
|
||||
|
||||
|
||||
def get_progressbar(name, size):
|
||||
"""Gets a progress bar object given a name and the total size.
|
||||
|
||||
Inputs:
|
||||
name (str): The name to display on the side.
|
||||
size (int): The maximum size of the progress bar.
|
||||
|
||||
"""
|
||||
return progressbar.ProgressBar(maxval=size,
|
||||
widgets=[name,
|
||||
progressbar.Bar('=', '[', ']'),
|
||||
' ',
|
||||
progressbar.Percentage(),
|
||||
' ',
|
||||
progressbar.ETA()])
|
||||
|
||||
|
||||
def train_epoch_with_utterances(batches,
|
||||
model,
|
||||
randomize=True):
|
||||
"""Trains model for a single epoch given batches of utterance data.
|
||||
|
||||
Inputs:
|
||||
        batches (UtteranceBatch): The batches to train on.
|
||||
        model (ATISModel): The model object.
|
||||
randomize (bool): Whether or not to randomize the order that the batches are seen.
|
||||
"""
|
||||
if randomize:
|
||||
random.shuffle(batches)
|
||||
progbar = get_progressbar("train ", len(batches))
|
||||
progbar.start()
|
||||
loss_sum = 0.
|
||||
|
||||
for i, batch in enumerate(batches):
|
||||
batch_loss = model.train_step(batch)
|
||||
loss_sum += batch_loss
|
||||
|
||||
progbar.update(i)
|
||||
|
||||
progbar.finish()
|
||||
|
||||
total_loss = loss_sum / len(batches)
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
def train_epoch_with_interactions(interaction_batches,
|
||||
params,
|
||||
model,
|
||||
randomize=True):
|
||||
"""Trains model for single epoch given batches of interactions.
|
||||
|
||||
Inputs:
|
||||
interaction_batches (list of InteractionBatch): The batches to train on.
|
||||
params (namespace): Parameters to run with.
|
||||
model (ATISModel): Model to train.
|
||||
randomize (bool): Whether or not to randomize the order that batches are seen.
|
||||
"""
|
||||
if randomize:
|
||||
random.shuffle(interaction_batches)
|
||||
progbar = get_progressbar("train ", len(interaction_batches))
|
||||
progbar.start()
|
||||
loss_sum = 0.
|
||||
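    # Interactions from these databases are skipped during training; presumably their
    # large schemas make the relation graph too big for GPU memory (see the OOM tip in the README).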
break_lst = ["baseball_1", "soccer_1", "formula_1", "cre_Drama_Workshop_Groups", "sakila_1"]
|
||||
#break_lst = ["baseball_1", "soccer_1", "formula_1", "cre_Drama_Workshop_Groups", "sakila_1", "behavior_monitoring"]
|
||||
for i, interaction_batch in enumerate(interaction_batches):
|
||||
assert len(interaction_batch) == 1
|
||||
interaction = interaction_batch.items[0]
|
||||
|
||||
|
||||
name = interaction.identifier.split('/')[0]
|
||||
if name in break_lst:
|
||||
continue
|
||||
|
||||
print(i, interaction.identifier)
|
||||
|
||||
# try:
|
||||
batch_loss = model.train_step(interaction, params.train_maximum_sql_length)
|
||||
# except RuntimeError:
|
||||
# torch.cuda.empty_cache()
|
||||
# continue
|
||||
print(batch_loss)
|
||||
|
||||
loss_sum += batch_loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
progbar.update(i)
|
||||
|
||||
progbar.finish()
|
||||
|
||||
total_loss = loss_sum / len(interaction_batches)
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
def update_sums(metrics,
|
||||
metrics_sums,
|
||||
predicted_sequence,
|
||||
flat_sequence,
|
||||
gold_query,
|
||||
original_gold_query,
|
||||
gold_forcing=False,
|
||||
loss=None,
|
||||
token_accuracy=0.,
|
||||
database_username="",
|
||||
database_password="",
|
||||
database_timeout=0,
|
||||
gold_table=None):
|
||||
"""" Updates summing for metrics in an aggregator.
|
||||
|
||||
TODO: don't use sums, just keep the raw value.
|
||||
"""
|
||||
if Metrics.LOSS in metrics:
|
||||
metrics_sums[Metrics.LOSS] += loss.item()
|
||||
if Metrics.TOKEN_ACCURACY in metrics:
|
||||
if gold_forcing:
|
||||
metrics_sums[Metrics.TOKEN_ACCURACY] += token_accuracy
|
||||
else:
|
||||
num_tokens_correct = 0.
|
||||
for j, token in enumerate(gold_query):
|
||||
if len(
|
||||
predicted_sequence) > j and predicted_sequence[j] == token:
|
||||
num_tokens_correct += 1
|
||||
metrics_sums[Metrics.TOKEN_ACCURACY] += num_tokens_correct / \
|
||||
len(gold_query)
|
||||
if Metrics.STRING_ACCURACY in metrics:
|
||||
metrics_sums[Metrics.STRING_ACCURACY] += int(
|
||||
flat_sequence == original_gold_query)
|
||||
|
||||
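    # Table-level metrics require executing the predicted SQL against the database,
    # so database credentials and a positive timeout must be supplied. Execution is
    # skipped when the prediction already string-matches the gold query.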
if Metrics.CORRECT_TABLES in metrics:
|
||||
assert database_username, "You did not provide a database username"
|
||||
assert database_password, "You did not provide a database password"
|
||||
assert database_timeout > 0, "Database timeout is 0 seconds"
|
||||
|
||||
# Evaluate SQL
|
||||
if flat_sequence != original_gold_query:
|
||||
syntactic, semantic, table = sql_util.execution_results(
|
||||
" ".join(flat_sequence), database_username, database_password, database_timeout)
|
||||
else:
|
||||
syntactic = True
|
||||
semantic = True
|
||||
table = gold_table
|
||||
|
||||
metrics_sums[Metrics.CORRECT_TABLES] += int(table == gold_table)
|
||||
if Metrics.SYNTACTIC_QUERIES in metrics:
|
||||
metrics_sums[Metrics.SYNTACTIC_QUERIES] += int(syntactic)
|
||||
if Metrics.SEMANTIC_QUERIES in metrics:
|
||||
metrics_sums[Metrics.SEMANTIC_QUERIES] += int(semantic)
|
||||
if Metrics.STRICT_CORRECT_TABLES in metrics:
|
||||
metrics_sums[Metrics.STRICT_CORRECT_TABLES] += int(
|
||||
table == gold_table and syntactic)
|
||||
|
||||
|
||||
def construct_averages(metrics_sums, total_num):
|
||||
""" Computes the averages for metrics.
|
||||
|
||||
Inputs:
|
||||
metrics_sums (dict Metric -> float): Sums for a metric.
|
||||
total_num (int): Number to divide by (average).
|
||||
"""
|
||||
metrics_averages = {}
|
||||
for metric, value in metrics_sums.items():
|
||||
metrics_averages[metric] = value / total_num
|
||||
        if metric != Metrics.LOSS:  # report everything except the loss as a percentage
|
||||
metrics_averages[metric] *= 100.
|
||||
|
||||
return metrics_averages
|
||||
|
||||
|
||||
def evaluate_utterance_sample(sample,
|
||||
model,
|
||||
max_generation_length,
|
||||
name="",
|
||||
gold_forcing=False,
|
||||
metrics=None,
|
||||
total_num=-1,
|
||||
database_username="",
|
||||
database_password="",
|
||||
database_timeout=0,
|
||||
write_results=False):
|
||||
"""Evaluates a sample of utterance examples.
|
||||
|
||||
Inputs:
|
||||
sample (list of Utterance): Examples to evaluate.
|
||||
model (ATISModel): Model to predict with.
|
||||
max_generation_length (int): Maximum length to generate.
|
||||
name (str): Name to log with.
|
||||
gold_forcing (bool): Whether to force the gold tokens during decoding.
|
||||
metrics (list of Metric): Metrics to evaluate with.
|
||||
total_num (int): Number to divide by when reporting results.
|
||||
database_username (str): Username to use for executing queries.
|
||||
database_password (str): Password to use when executing queries.
|
||||
database_timeout (float): Timeout on queries when executing.
|
||||
write_results (bool): Whether to write the results to a file.
|
||||
"""
|
||||
assert metrics
|
||||
|
||||
if total_num < 0:
|
||||
total_num = len(sample)
|
||||
|
||||
metrics_sums = {}
|
||||
for metric in metrics:
|
||||
metrics_sums[metric] = 0.
|
||||
|
||||
predictions_file = open(name + "_predictions.json", "w")
|
||||
print("Predicting with filename " + str(name) + "_predictions.json")
|
||||
progbar = get_progressbar(name, len(sample))
|
||||
progbar.start()
|
||||
|
||||
predictions = []
|
||||
for i, item in enumerate(sample):
|
||||
_, loss, predicted_seq = model.eval_step(
|
||||
item, max_generation_length, feed_gold_query=gold_forcing)
|
||||
loss = loss / len(item.gold_query())
|
||||
predictions.append(predicted_seq)
|
||||
|
||||
flat_sequence = item.flatten_sequence(predicted_seq)
|
||||
token_accuracy = torch_utils.per_token_accuracy(
|
||||
item.gold_query(), predicted_seq)
|
||||
|
||||
if write_results:
|
||||
write_prediction(
|
||||
predictions_file,
|
||||
identifier=item.interaction.identifier,
|
||||
input_seq=item.input_sequence(),
|
||||
probability=0,
|
||||
prediction=predicted_seq,
|
||||
flat_prediction=flat_sequence,
|
||||
gold_query=item.gold_query(),
|
||||
flat_gold_queries=item.original_gold_queries(),
|
||||
gold_tables=item.gold_tables(),
|
||||
index_in_interaction=item.utterance_index,
|
||||
database_username=database_username,
|
||||
database_password=database_password,
|
||||
database_timeout=database_timeout)
|
||||
|
||||
update_sums(metrics,
|
||||
metrics_sums,
|
||||
predicted_seq,
|
||||
flat_sequence,
|
||||
item.gold_query(),
|
||||
item.original_gold_queries()[0],
|
||||
gold_forcing,
|
||||
loss,
|
||||
token_accuracy,
|
||||
database_username=database_username,
|
||||
database_password=database_password,
|
||||
database_timeout=database_timeout,
|
||||
gold_table=item.gold_tables()[0])
|
||||
|
||||
progbar.update(i)
|
||||
|
||||
progbar.finish()
|
||||
predictions_file.close()
|
||||
|
||||
return construct_averages(metrics_sums, total_num), None
|
||||
|
||||
|
||||
def evaluate_interaction_sample(sample,
|
||||
model,
|
||||
max_generation_length,
|
||||
name="",
|
||||
gold_forcing=False,
|
||||
metrics=None,
|
||||
total_num=-1,
|
||||
database_username="",
|
||||
database_password="",
|
||||
database_timeout=0,
|
||||
use_predicted_queries=False,
|
||||
write_results=False,
|
||||
use_gpu=False,
|
||||
compute_metrics=False):
|
||||
""" Evaluates a sample of interactions. """
|
||||
predictions_file = open(name + "_predictions.json", "w")
|
||||
print("Predicting with file " + str(name + "_predictions.json"))
|
||||
metrics_sums = {}
|
||||
for metric in metrics:
|
||||
metrics_sums[metric] = 0.
|
||||
progbar = get_progressbar(name, len(sample))
|
||||
progbar.start()
|
||||
num_utterances = 0
|
||||
predictions = []
|
||||
|
||||
use_gpu = not ("--no_gpus" in sys.argv or "--no_gpus=1" in sys.argv)
|
||||
|
||||
model.eval()
|
||||
break_lst = ["baseball_1", "soccer_1", "formula_1", "cre_Drama_Workshop_Groups", "sakila_1"]
|
||||
for i, interaction in enumerate(sample):
|
||||
name = interaction.identifier.split('/')[0]
|
||||
if name in break_lst:
|
||||
continue
|
||||
try:
|
||||
with torch.no_grad():
|
||||
if use_predicted_queries:
|
||||
example_preds = model.predict_with_predicted_queries(
|
||||
interaction,
|
||||
max_generation_length)
|
||||
else:
|
||||
example_preds = model.predict_with_gold_queries(
|
||||
interaction,
|
||||
max_generation_length,
|
||||
feed_gold_query=gold_forcing)
|
||||
torch.cuda.empty_cache()
|
||||
except RuntimeError as exception:
|
||||
print("Failed on interaction: " + str(interaction.identifier))
|
||||
print(exception)
|
||||
print("\n\n")
|
||||
exit()
|
||||
|
||||
predictions.extend(example_preds)
|
||||
|
||||
assert len(example_preds) == len(
|
||||
interaction.interaction.utterances) or not example_preds
|
||||
for j, pred in enumerate(example_preds):
|
||||
num_utterances += 1
|
||||
|
||||
sequence, loss, token_accuracy, _, decoder_results = pred
|
||||
|
||||
|
||||
if use_predicted_queries:
|
||||
item = interaction.processed_utterances[j]
|
||||
original_utt = interaction.interaction.utterances[item.index]
|
||||
|
||||
gold_query = original_utt.gold_query_to_use
|
||||
original_gold_query = original_utt.original_gold_query
|
||||
|
||||
gold_table = original_utt.gold_sql_results
|
||||
gold_queries = [q[0] for q in original_utt.all_gold_queries]
|
||||
gold_tables = [q[1] for q in original_utt.all_gold_queries]
|
||||
index = item.index
|
||||
else:
|
||||
item = interaction.gold_utterances()[j]
|
||||
|
||||
gold_query = item.gold_query()
|
||||
original_gold_query = item.original_gold_query()
|
||||
|
||||
gold_table = item.gold_table()
|
||||
gold_queries = item.original_gold_queries()
|
||||
gold_tables = item.gold_tables()
|
||||
index = item.utterance_index
|
||||
if loss:
|
||||
loss = loss / len(gold_query)
|
||||
|
||||
flat_sequence = item.flatten_sequence(sequence)
|
||||
|
||||
if write_results:
|
||||
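                # Convert each beam entry into (score, flattened SQL). The score is negated
                # before sorting so that (assuming a higher score means a better hypothesis)
                # the best candidate ends up first in the written beam.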
ori_beam = decoder_results.beam
|
||||
beam = []
|
||||
for x in ori_beam:
|
||||
beam.append((-x[0], item.flatten_sequence(x[1].sequence)))
|
||||
beam.sort()
|
||||
write_prediction(
|
||||
predictions_file,
|
||||
identifier=interaction.identifier,
|
||||
input_seq=item.input_sequence(),
|
||||
probability=decoder_results.probability,
|
||||
prediction=sequence,
|
||||
flat_prediction=flat_sequence,
|
||||
gold_query=gold_query,
|
||||
flat_gold_queries=gold_queries,
|
||||
gold_tables=gold_tables,
|
||||
index_in_interaction=index,
|
||||
database_username=database_username,
|
||||
database_password=database_password,
|
||||
database_timeout=database_timeout,
|
||||
compute_metrics=compute_metrics,
|
||||
beam=beam)
|
||||
|
||||
update_sums(metrics,
|
||||
metrics_sums,
|
||||
sequence,
|
||||
flat_sequence,
|
||||
gold_query,
|
||||
original_gold_query,
|
||||
gold_forcing,
|
||||
loss,
|
||||
token_accuracy,
|
||||
database_username=database_username,
|
||||
database_password=database_password,
|
||||
database_timeout=database_timeout,
|
||||
gold_table=gold_table)
|
||||
|
||||
progbar.update(i)
|
||||
|
||||
progbar.finish()
|
||||
|
||||
if total_num < 0:
|
||||
total_num = num_utterances
|
||||
|
||||
predictions_file.close()
|
||||
return construct_averages(metrics_sums, total_num), predictions
|
||||
|
||||
|
||||
def evaluate_using_predicted_queries(sample,
|
||||
model,
|
||||
name="",
|
||||
gold_forcing=False,
|
||||
metrics=None,
|
||||
total_num=-1,
|
||||
database_username="",
|
||||
database_password="",
|
||||
database_timeout=0,
|
||||
snippet_keep_age=1):
|
||||
predictions_file = open(name + "_predictions.json", "w")
|
||||
print("Predicting with file " + str(name + "_predictions.json"))
|
||||
assert not gold_forcing
|
||||
metrics_sums = {}
|
||||
for metric in metrics:
|
||||
metrics_sums[metric] = 0.
|
||||
progbar = get_progressbar(name, len(sample))
|
||||
progbar.start()
|
||||
|
||||
num_utterances = 0
|
||||
predictions = []
|
||||
for i, item in enumerate(sample):
|
||||
int_predictions = []
|
||||
item.start_interaction()
|
||||
while not item.done():
|
||||
utterance = item.next_utterance(snippet_keep_age)
|
||||
|
||||
predicted_sequence, loss, _, probability = model.eval_step(
|
||||
utterance)
|
||||
int_predictions.append((utterance, predicted_sequence))
|
||||
|
||||
flat_sequence = utterance.flatten_sequence(predicted_sequence)
|
||||
|
||||
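            # Only feed the predicted query back into the interaction history if it
            # actually executes and the model is reasonably confident (probability
            # threshold 0.24); otherwise an empty query is used instead.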
if sql_util.executable(
|
||||
flat_sequence,
|
||||
username=database_username,
|
||||
password=database_password,
|
||||
timeout=database_timeout) and probability >= 0.24:
|
||||
utterance.set_pred_query(
|
||||
item.remove_snippets(predicted_sequence))
|
||||
item.add_utterance(utterance,
|
||||
item.remove_snippets(predicted_sequence),
|
||||
previous_snippets=utterance.snippets())
|
||||
else:
|
||||
                    # The prediction was not executable (or its probability was too low),
|
||||
                    # so fall back to an empty predicted query for the interaction history.
|
||||
seq = []
|
||||
utterance.set_pred_query(seq)
|
||||
item.add_utterance(
|
||||
utterance, seq, previous_snippets=utterance.snippets())
|
||||
|
||||
original_utt = item.interaction.utterances[utterance.index]
|
||||
write_prediction(
|
||||
predictions_file,
|
||||
identifier=item.interaction.identifier,
|
||||
input_seq=utterance.input_sequence(),
|
||||
probability=probability,
|
||||
prediction=predicted_sequence,
|
||||
flat_prediction=flat_sequence,
|
||||
gold_query=original_utt.gold_query_to_use,
|
||||
flat_gold_queries=[
|
||||
q[0] for q in original_utt.all_gold_queries],
|
||||
gold_tables=[
|
||||
q[1] for q in original_utt.all_gold_queries],
|
||||
index_in_interaction=utterance.index,
|
||||
database_username=database_username,
|
||||
database_password=database_password,
|
||||
database_timeout=database_timeout)
|
||||
|
||||
update_sums(metrics,
|
||||
metrics_sums,
|
||||
predicted_sequence,
|
||||
flat_sequence,
|
||||
original_utt.gold_query_to_use,
|
||||
original_utt.original_gold_query,
|
||||
gold_forcing,
|
||||
loss,
|
||||
token_accuracy=0,
|
||||
database_username=database_username,
|
||||
database_password=database_password,
|
||||
database_timeout=database_timeout,
|
||||
gold_table=original_utt.gold_sql_results)
|
||||
|
||||
predictions.append(int_predictions)
|
||||
progbar.update(i)
|
||||
|
||||
progbar.finish()
|
||||
|
||||
if total_num < 0:
|
||||
total_num = num_utterances
|
||||
predictions_file.close()
|
||||
|
||||
return construct_averages(metrics_sums, total_num), predictions
|
|
@@ -0,0 +1,167 @@
|
|||
import sys
|
||||
args = sys.argv
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
def interpret_args():
|
||||
""" Interprets the command line arguments, and returns a dictionary. """
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
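    # Note: with argparse, type=bool converts any non-empty string to True, so these
    # boolean options act as switches that are enabled by passing any value (e.g. --train 1).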
parser.add_argument("--no_gpus", type=bool, default=1)
|
||||
parser.add_argument("--seed", type=int, default=141)
|
||||
### Data parameters
|
||||
parser.add_argument(
|
||||
'--raw_train_filename',
|
||||
type=str,
|
||||
default='../atis_data/data/resplit/processed/train_with_tables.pkl')
|
||||
parser.add_argument(
|
||||
'--raw_dev_filename',
|
||||
type=str,
|
||||
default='../atis_data/data/resplit/processed/dev_with_tables.pkl')
|
||||
parser.add_argument(
|
||||
'--raw_validation_filename',
|
||||
type=str,
|
||||
default='../atis_data/data/resplit/processed/valid_with_tables.pkl')
|
||||
parser.add_argument(
|
||||
'--raw_test_filename',
|
||||
type=str,
|
||||
default='../atis_data/data/resplit/processed/test_with_tables.pkl')
|
||||
|
||||
parser.add_argument('--data_directory', type=str, default='processed_data')
|
||||
|
||||
parser.add_argument('--processed_train_filename', type=str, default='train.pkl')
|
||||
parser.add_argument('--processed_dev_filename', type=str, default='dev.pkl')
|
||||
parser.add_argument('--processed_validation_filename', type=str, default='validation.pkl')
|
||||
parser.add_argument('--processed_test_filename', type=str, default='test.pkl')
|
||||
|
||||
parser.add_argument('--database_schema_filename', type=str, default=None)
|
||||
parser.add_argument('--embedding_filename', type=str, default=None)
|
||||
|
||||
parser.add_argument('--input_vocabulary_filename', type=str, default='input_vocabulary.pkl')
|
||||
parser.add_argument('--output_vocabulary_filename',
|
||||
type=str,
|
||||
default='output_vocabulary.pkl')
|
||||
|
||||
parser.add_argument('--input_key', type=str, default='nl_with_dates')
|
||||
|
||||
parser.add_argument('--anonymize', type=bool, default=False)
|
||||
parser.add_argument('--anonymization_scoring', type=bool, default=False)
|
||||
parser.add_argument('--use_snippets', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--use_previous_query', type=bool, default=False)
|
||||
parser.add_argument('--maximum_queries', type=int, default=1)
|
||||
parser.add_argument('--use_copy_switch', type=bool, default=False)
|
||||
parser.add_argument('--use_query_attention', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--use_utterance_attention', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--freeze', type=bool, default=False)
|
||||
parser.add_argument('--scheduler', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--use_bert', type=bool, default=False)
|
||||
parser.add_argument("--bert_type_abb", type=str, help="Type of BERT model to load. e.g.) uS, uL, cS, cL, and mcS")
|
||||
parser.add_argument("--bert_input_version", type=str, default='v1')
|
||||
parser.add_argument('--fine_tune_bert', type=bool, default=False)
|
||||
parser.add_argument('--lr_bert', default=1e-5, type=float, help='BERT model learning rate.')
|
||||
|
||||
### Debugging/logging parameters
|
||||
parser.add_argument('--logdir', type=str, default='logs')
|
||||
parser.add_argument('--deterministic', type=bool, default=False)
|
||||
parser.add_argument('--num_train', type=int, default=-1)
|
||||
|
||||
parser.add_argument('--logfile', type=str, default='log.txt')
|
||||
parser.add_argument('--results_file', type=str, default='results.txt')
|
||||
|
||||
### Model architecture
|
||||
parser.add_argument('--input_embedding_size', type=int, default=300)
|
||||
parser.add_argument('--output_embedding_size', type=int, default=300)
|
||||
|
||||
parser.add_argument('--encoder_state_size', type=int, default=300)
|
||||
parser.add_argument('--decoder_state_size', type=int, default=300)
|
||||
|
||||
parser.add_argument('--encoder_num_layers', type=int, default=1)
|
||||
parser.add_argument('--decoder_num_layers', type=int, default=2)
|
||||
parser.add_argument('--snippet_num_layers', type=int, default=1)
|
||||
|
||||
parser.add_argument('--maximum_utterances', type=int, default=5)
|
||||
parser.add_argument('--state_positional_embeddings', type=bool, default=False)
|
||||
parser.add_argument('--positional_embedding_size', type=int, default=50)
|
||||
|
||||
parser.add_argument('--snippet_age_embedding', type=bool, default=False)
|
||||
parser.add_argument('--snippet_age_embedding_size', type=int, default=64)
|
||||
parser.add_argument('--max_snippet_age_embedding', type=int, default=4)
|
||||
parser.add_argument('--previous_decoder_snippet_encoding', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--discourse_level_lstm', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--use_schema_attention', type=bool, default=False)
|
||||
parser.add_argument('--use_encoder_attention', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--use_schema_encoder', type=bool, default=False)
|
||||
parser.add_argument('--use_schema_self_attention', type=bool, default=False)
|
||||
parser.add_argument('--use_schema_encoder_2', type=bool, default=False)
|
||||
|
||||
### Training parameters
|
||||
parser.add_argument('--batch_size', type=int, default=16)
|
||||
parser.add_argument('--train_maximum_sql_length', type=int, default=200)
|
||||
parser.add_argument('--train_evaluation_size', type=int, default=100)
|
||||
|
||||
parser.add_argument('--dropout_amount', type=float, default=0.5)
|
||||
|
||||
parser.add_argument('--initial_patience', type=float, default=10.)
|
||||
parser.add_argument('--patience_ratio', type=float, default=1.01)
|
||||
|
||||
parser.add_argument('--initial_learning_rate', type=float, default=0.001)
|
||||
parser.add_argument('--learning_rate_ratio', type=float, default=0.8)
|
||||
|
||||
parser.add_argument('--interaction_level', type=bool, default=False)
|
||||
parser.add_argument('--reweight_batch', type=bool, default=False)
|
||||
|
||||
### Setting
|
||||
parser.add_argument('--train', type=bool, default=False)
|
||||
parser.add_argument('--debug', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--evaluate', type=bool, default=False)
|
||||
parser.add_argument('--attention', type=bool, default=False)
|
||||
parser.add_argument('--save_file', type=str, default="")
|
||||
parser.add_argument('--enable_testing', type=bool, default=False)
|
||||
parser.add_argument('--use_predicted_queries', type=bool, default=False)
|
||||
parser.add_argument('--evaluate_split', type=str, default='dev')
|
||||
parser.add_argument('--evaluate_with_gold_forcing', type=bool, default=False)
|
||||
parser.add_argument('--eval_maximum_sql_length', type=int, default=1000)
|
||||
parser.add_argument('--results_note', type=str, default='')
|
||||
parser.add_argument('--compute_metrics', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--reference_results', type=str, default='')
|
||||
|
||||
parser.add_argument('--interactive', type=bool, default=False)
|
||||
|
||||
parser.add_argument('--database_username', type=str, default="aviarmy")
|
||||
parser.add_argument('--database_password', type=str, default="aviarmy")
|
||||
parser.add_argument('--database_timeout', type=int, default=2)
|
||||
|
||||
parser.add_argument('--use_transformer_relation', type=bool, default=True)
|
||||
parser.add_argument('--table_schema_path', type=str, default="data/cosql/tables.json")
|
||||
    ## replace it for colab
|
||||
# parser.add_argument('--table_schema_path', type=str, default="data/table.json")
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.logdir):
|
||||
os.makedirs(args.logdir)
|
||||
|
||||
if not (args.train or args.evaluate or args.interactive or args.attention):
|
||||
raise ValueError('You need to be training or evaluating')
|
||||
if args.enable_testing and not args.evaluate:
|
||||
raise ValueError('You should evaluate the model if enabling testing')
|
||||
|
||||
if args.train:
|
||||
args_file = args.logdir + '/args.log'
|
||||
if os.path.exists(args_file):
|
||||
raise ValueError('Warning: arguments already exist in ' + str(args_file))
|
||||
with open(args_file, 'w') as infile:
|
||||
infile.write(str(args))
|
||||
|
||||
return args
|
|
@@ -0,0 +1,614 @@
|
|||
import os
|
||||
import json
|
||||
import copy
|
||||
import torch
|
||||
import sqlparse
|
||||
import argparse
|
||||
import subprocess
|
||||
import traceback  # needed for traceback.print_exc() in gen_from
|
||||
from collections import defaultdict
|
||||
from reranker.predict import Ranker
|
||||
import numpy
|
||||
numpy.random.seed()
|
||||
torch.manual_seed(1)
|
||||
|
||||
#from eval_scripts.process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
|
||||
|
||||
stop_word = set({'=', 'is', 'and', 'desc', '!=', 'select', ')', '_UNK', 'max', '0', 'none', 'order', 'like', 'join', '>', 'on', 'count', 'or', 'from', 'group_by', 'limit', '<=', 'sum', 'group', '(', '_EOS', 'intersect', 'order_by', 'limit_value', '-', 'except', 'where', 'distinct', 'avg', '<', 'min', 'asc', '+', 'in', 'union', 'between', 'exists', '/', 'having', 'not', 'as', '>=', 'value', ',', '*'})
|
||||
|
||||
#ranker = Ranker("./cosql_train_48.pt")
|
||||
ranker = Ranker("./submit_models/cosql_reranker_roberta.pt")
|
||||
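# Depth-first search over the foreign-key graph: returns the list of
# (table, (col_a, col_b)) edges leading from `start` to `end`, or None if the
# two tables are not connected.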
def find_shortest_path(start, end, graph):
|
||||
stack = [[start, []]]
|
||||
visited = list()
|
||||
while len(stack) > 0:
|
||||
ele, history = stack.pop()
|
||||
if ele == end:
|
||||
return history
|
||||
for node in graph[ele]:
|
||||
if node[0] not in visited:
|
||||
stack.append((node[0], history + [(node[0], node[1])]))
|
||||
visited.append(node[0])
|
||||
|
||||
|
||||
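# Rebuild the FROM clause for a set of candidate tables: each table gets an alias
# T1, T2, ... and tables are joined along the foreign-key paths found above. For a
# hypothetical schema where tables A and B share a foreign key, two candidate tables
# would yield something like "from A as T1 join B as T2 on T1.x = T2.y".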
def gen_from(candidate_tables, schema):
|
||||
if len(candidate_tables) <= 1:
|
||||
if len(candidate_tables) == 1:
|
||||
ret = "from {}".format(schema["table_names_original"][list(candidate_tables)[0]])
|
||||
else:
|
||||
ret = "from {}".format(schema["table_names_original"][0])
|
||||
return {}, ret
|
||||
|
||||
table_alias_dict = {}
|
||||
uf_dict = {}
|
||||
for t in candidate_tables:
|
||||
uf_dict[t] = -1
|
||||
idx = 1
|
||||
graph = defaultdict(list)
|
||||
for acol, bcol in schema["foreign_keys"]:
|
||||
t1 = schema["column_names"][acol][0]
|
||||
t2 = schema["column_names"][bcol][0]
|
||||
graph[t1].append((t2, (acol, bcol)))
|
||||
graph[t2].append((t1, (bcol, acol)))
|
||||
candidate_tables = list(candidate_tables)
|
||||
start = candidate_tables[0]
|
||||
table_alias_dict[start] = idx
|
||||
idx += 1
|
||||
ret = "from {} as T1".format(schema["table_names_original"][start])
|
||||
try:
|
||||
for end in candidate_tables[1:]:
|
||||
if end in table_alias_dict:
|
||||
continue
|
||||
path = find_shortest_path(start, end, graph)
|
||||
prev_table = start
|
||||
if not path:
|
||||
table_alias_dict[end] = idx
|
||||
idx += 1
|
||||
ret = "{} join {} as T{}".format(ret, schema["table_names_original"][end],
|
||||
table_alias_dict[end],
|
||||
)
|
||||
continue
|
||||
for node, (acol, bcol) in path:
|
||||
if node in table_alias_dict:
|
||||
prev_table = node
|
||||
continue
|
||||
table_alias_dict[node] = idx
|
||||
idx += 1
|
||||
ret = "{} join {} as T{} on T{}.{} = T{}.{}".format(ret, schema["table_names_original"][node],
|
||||
table_alias_dict[node],
|
||||
table_alias_dict[prev_table],
|
||||
schema["column_names_original"][acol][1],
|
||||
table_alias_dict[node],
|
||||
schema["column_names_original"][bcol][1])
|
||||
prev_table = node
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print("db:{}".format(schema["db_id"]))
|
||||
return table_alias_dict, ret
|
||||
return table_alias_dict, ret
|
||||
|
||||
|
||||
def normalize_space(format_sql):
|
||||
format_sql_1 = [' '.join(sub_sql.strip().replace(',',' , ').replace('(',' ( ').replace(')',' ) ').split())
|
||||
for sub_sql in format_sql.split('\n')]
|
||||
format_sql_1 = '\n'.join(format_sql_1)
|
||||
format_sql_2 = format_sql_1.replace('\njoin',' join').replace(',\n',', ').replace(' where','\nwhere').replace(' intersect', '\nintersect').replace('union ', 'union\n').replace('\nand',' and').replace('order by t2 .\nstart desc', 'order by t2 . start desc')
|
||||
return format_sql_2
|
||||
|
||||
|
||||
def get_candidate_tables(format_sql, schema):
|
||||
candidate_tables = []
|
||||
|
||||
tokens = format_sql.split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if '.' in token:
|
||||
table_name = token.split('.')[0]
|
||||
candidate_tables.append(table_name)
|
||||
|
||||
candidate_tables = list(set(candidate_tables))
|
||||
candidate_tables.sort()
|
||||
table_names_original = [table_name.lower() for table_name in schema['table_names_original']]
|
||||
candidate_tables_id = [table_names_original.index(table_name) for table_name in candidate_tables]
|
||||
|
||||
assert -1 not in candidate_tables_id
|
||||
table_names_original = schema['table_names_original']
|
||||
|
||||
return candidate_tables_id, table_names_original
|
||||
|
||||
|
||||
def get_surface_form_orig(format_sql_2, schema):
|
||||
column_names_surface_form = []
|
||||
column_names_surface_form_original = []
|
||||
|
||||
column_names_original = schema['column_names_original']
|
||||
table_names_original = schema['table_names_original']
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
if table_id >= 0:
|
||||
table_name = table_names_original[table_id]
|
||||
column_name_surface_form = '{}.{}'.format(table_name,column_name)
|
||||
else:
|
||||
# this is just *
|
||||
column_name_surface_form = column_name
|
||||
column_names_surface_form.append(column_name_surface_form.lower())
|
||||
column_names_surface_form_original.append(column_name_surface_form)
|
||||
|
||||
# also add table_name.*
|
||||
for table_name in table_names_original:
|
||||
column_names_surface_form.append('{}.*'.format(table_name.lower()))
|
||||
column_names_surface_form_original.append('{}.*'.format(table_name))
|
||||
|
||||
assert len(column_names_surface_form) == len(column_names_surface_form_original)
|
||||
for surface_form, surface_form_original in zip(column_names_surface_form, column_names_surface_form_original):
|
||||
format_sql_2 = format_sql_2.replace(surface_form, surface_form_original)
|
||||
|
||||
return format_sql_2
|
||||
|
||||
|
||||
def add_from_clase(sub_sql, from_clause):
|
||||
select_right_sub_sql = []
|
||||
left_sub_sql = []
|
||||
left = True
|
||||
num_left_parathesis = 0 # in select_right_sub_sql
|
||||
num_right_parathesis = 0 # in select_right_sub_sql
|
||||
tokens = sub_sql.split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token == 'select':
|
||||
left = False
|
||||
if left:
|
||||
left_sub_sql.append(token)
|
||||
continue
|
||||
select_right_sub_sql.append(token)
|
||||
if token == '(':
|
||||
num_left_parathesis += 1
|
||||
elif token == ')':
|
||||
num_right_parathesis += 1
|
||||
|
||||
def remove_missing_tables_from_select(select_statement):
|
||||
tokens = select_statement.split(',')
|
||||
|
||||
stop_idx = -1
|
||||
for i in range(len(tokens)):
|
||||
idx = len(tokens) - 1 - i
|
||||
token = tokens[idx]
|
||||
if '.*' in token and 'count ' not in token:
|
||||
pass
|
||||
else:
|
||||
stop_idx = idx+1
|
||||
break
|
||||
|
||||
if stop_idx > 0:
|
||||
new_select_statement = ','.join(tokens[:stop_idx]).strip()
|
||||
else:
|
||||
new_select_statement = select_statement
|
||||
|
||||
return new_select_statement
|
||||
|
||||
    if num_left_parathesis >= num_right_parathesis:
|
||||
sub_sqls = []
|
||||
sub_sqls.append(remove_missing_tables_from_select(sub_sql))
|
||||
sub_sqls.append(from_clause)
|
||||
else:
|
||||
assert num_left_parathesis < num_right_parathesis
|
||||
select_sub_sql = []
|
||||
right_sub_sql = []
|
||||
for i in range(len(select_right_sub_sql)):
|
||||
token_idx = len(select_right_sub_sql)-1-i
|
||||
token = select_right_sub_sql[token_idx]
|
||||
if token == ')':
|
||||
num_right_parathesis -= 1
|
||||
if num_right_parathesis == num_left_parathesis:
|
||||
select_sub_sql = select_right_sub_sql[:token_idx]
|
||||
right_sub_sql = select_right_sub_sql[token_idx:]
|
||||
break
|
||||
|
||||
sub_sqls = []
|
||||
|
||||
if len(left_sub_sql) > 0:
|
||||
sub_sqls.append(' '.join(left_sub_sql))
|
||||
if len(select_sub_sql) > 0:
|
||||
new_select_statement = remove_missing_tables_from_select(' '.join(select_sub_sql))
|
||||
sub_sqls.append(new_select_statement)
|
||||
|
||||
sub_sqls.append(from_clause)
|
||||
|
||||
if len(right_sub_sql) > 0:
|
||||
sub_sqls.append(' '.join(right_sub_sql))
|
||||
|
||||
return sub_sqls
|
||||
|
||||
|
||||
def postprocess_single(format_sql_2, schema, start_alias_id=0):
|
||||
candidate_tables_id, table_names_original = get_candidate_tables(format_sql_2, schema)
|
||||
format_sql_2 = get_surface_form_orig(format_sql_2, schema)
|
||||
|
||||
if len(candidate_tables_id) == 0:
|
||||
final_sql = format_sql_2.replace('\n', ' ')
|
||||
elif len(candidate_tables_id) == 1:
|
||||
# easy case
|
||||
table_name = table_names_original[candidate_tables_id[0]]
|
||||
from_clause = 'from {}'.format(table_name)
|
||||
format_sql_3 = []
|
||||
for sub_sql in format_sql_2.split('\n'):
|
||||
if 'select' in sub_sql:
|
||||
format_sql_3 += add_from_clase(sub_sql, from_clause)
|
||||
else:
|
||||
format_sql_3.append(sub_sql)
|
||||
final_sql = ' '.join(format_sql_3).replace('{}.'.format(table_name), '')
|
||||
else:
|
||||
# more than 1 candidate_tables
|
||||
table_alias_dict, ret = gen_from(candidate_tables_id, schema)
|
||||
|
||||
from_clause = ret
|
||||
for i in range(len(table_alias_dict)):
|
||||
from_clause = from_clause.replace('T{}'.format(i+1), 'T{}'.format(i+1+start_alias_id))
|
||||
|
||||
table_name_to_alias = {}
|
||||
for table_id, alias_id in table_alias_dict.items():
|
||||
table_name = table_names_original[table_id]
|
||||
alias = 'T{}'.format(alias_id+start_alias_id)
|
||||
table_name_to_alias[table_name] = alias
|
||||
start_alias_id = start_alias_id + len(table_alias_dict)
|
||||
|
||||
format_sql_3 = []
|
||||
for sub_sql in format_sql_2.split('\n'):
|
||||
if 'select' in sub_sql:
|
||||
format_sql_3 += add_from_clase(sub_sql, from_clause)
|
||||
else:
|
||||
format_sql_3.append(sub_sql)
|
||||
format_sql_3 = ' '.join(format_sql_3)
|
||||
|
||||
for table_name, alias in table_name_to_alias.items():
|
||||
format_sql_3 = format_sql_3.replace('{}.'.format(table_name), '{}.'.format(alias))
|
||||
|
||||
final_sql = format_sql_3
|
||||
|
||||
for i in range(5):
|
||||
final_sql = final_sql.replace('select count ( T{}.* ) '.format(i), 'select count ( * ) ')
|
||||
final_sql = final_sql.replace('count ( T{}.* ) from '.format(i), 'count ( * ) from ')
|
||||
final_sql = final_sql.replace('order by count ( T{}.* ) '.format(i), 'order by count ( * ) ')
|
||||
final_sql = final_sql.replace('having count ( T{}.* ) '.format(i), 'having count ( * ) ')
|
||||
|
||||
return final_sql, start_alias_id
|
||||
|
||||
|
||||
def postprocess_nested(format_sql_2, schema):
|
||||
candidate_tables_id, table_names_original = get_candidate_tables(format_sql_2, schema)
|
||||
if len(candidate_tables_id) == 1:
|
||||
format_sql_2 = get_surface_form_orig(format_sql_2, schema)
|
||||
# easy case
|
||||
table_name = table_names_original[candidate_tables_id[0]]
|
||||
from_clause = 'from {}'.format(table_name)
|
||||
format_sql_3 = []
|
||||
for sub_sql in format_sql_2.split('\n'):
|
||||
if 'select' in sub_sql:
|
||||
format_sql_3 += add_from_clase(sub_sql, from_clause)
|
||||
else:
|
||||
format_sql_3.append(sub_sql)
|
||||
final_sql = ' '.join(format_sql_3).replace('{}.'.format(table_name), '')
|
||||
else:
|
||||
# case 1: easy case, except / union / intersect
|
||||
# case 2: nested queries in condition
|
||||
final_sql = []
|
||||
|
||||
num_keywords = format_sql_2.count('except') + format_sql_2.count('union') + format_sql_2.count('intersect')
|
||||
num_select = format_sql_2.count('select')
|
||||
|
||||
def postprocess_subquery(sub_query_one, schema, start_alias_id_1):
|
||||
num_select = sub_query_one.count('select ')
|
||||
final_sub_sql = []
|
||||
sub_query = []
|
||||
for sub_sql in sub_query_one.split('\n'):
|
||||
if 'select' in sub_sql:
|
||||
if len(sub_query) > 0:
|
||||
sub_query = '\n'.join(sub_query)
|
||||
sub_query, start_alias_id_1 = postprocess_single(sub_query, schema, start_alias_id_1)
|
||||
final_sub_sql.append(sub_query)
|
||||
sub_query = []
|
||||
sub_query.append(sub_sql)
|
||||
else:
|
||||
sub_query.append(sub_sql)
|
||||
if len(sub_query) > 0:
|
||||
sub_query = '\n'.join(sub_query)
|
||||
sub_query, start_alias_id_1 = postprocess_single(sub_query, schema, start_alias_id_1)
|
||||
final_sub_sql.append(sub_query)
|
||||
|
||||
final_sub_sql = ' '.join(final_sub_sql)
|
||||
return final_sub_sql, False, start_alias_id_1
|
||||
|
||||
start_alias_id = 0
|
||||
sub_query = []
|
||||
for sub_sql in format_sql_2.split('\n'):
|
||||
if 'except' in sub_sql or 'union' in sub_sql or 'intersect' in sub_sql:
|
||||
sub_query = '\n'.join(sub_query)
|
||||
sub_query, _, start_alias_id = postprocess_subquery(sub_query, schema, start_alias_id)
|
||||
final_sql.append(sub_query)
|
||||
final_sql.append(sub_sql)
|
||||
sub_query = []
|
||||
else:
|
||||
sub_query.append(sub_sql)
|
||||
if len(sub_query) > 0:
|
||||
sub_query = '\n'.join(sub_query)
|
||||
sub_query, _, start_alias_id = postprocess_subquery(sub_query, schema, start_alias_id)
|
||||
final_sql.append(sub_query)
|
||||
|
||||
final_sql = ' '.join(final_sql)
|
||||
|
||||
    # special case: selecting from a subquery
|
||||
final_sql = final_sql.replace('select count ( * ) (', 'select count ( * ) from (')
|
||||
|
||||
return final_sql
|
||||
|
||||
|
||||
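# Turn a flat model prediction back into executable SQL: normalize special tokens
# (group_by -> group by, limit_value -> limit 1, value -> 1), pretty-print with
# sqlparse, then reinsert FROM clauses via postprocess_nested or postprocess_single
# depending on whether the query contains more than one SELECT.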
def postprocess_one(pred_sql, schema):
|
||||
pred_sql = pred_sql.replace('group_by', 'group by').replace('order_by', 'order by').replace('limit_value', 'limit 1').replace('_EOS', '').replace(' value ',' 1 ').replace('distinct', '').strip(',').strip()
|
||||
if pred_sql.endswith('value'):
|
||||
pred_sql = pred_sql[:-len('value')] + '1'
|
||||
|
||||
try:
|
||||
format_sql = sqlparse.format(pred_sql, reindent=True)
|
||||
except:
|
||||
if len(pred_sql) == 0:
|
||||
return "None"
|
||||
return pred_sql
|
||||
format_sql_2 = normalize_space(format_sql)
|
||||
|
||||
num_select = format_sql_2.count('select')
|
||||
|
||||
if num_select > 1:
|
||||
final_sql = postprocess_nested(format_sql_2, schema)
|
||||
else:
|
||||
final_sql, _ = postprocess_single(format_sql_2, schema)
|
||||
if final_sql == "":
|
||||
return "None"
|
||||
return final_sql
|
||||
|
||||
def postprocess(predictions, database_schema, remove_from=False):
|
||||
import math
|
||||
use_reranker = True
|
||||
correct = 0
|
||||
total = 0
|
||||
postprocess_sqls = {}
|
||||
utterances = []
|
||||
Count = 0
|
||||
score_vis = dict()
|
||||
if os.path.exists('./score_vis.json'):
|
||||
with open('./score_vis.json', 'r') as f:
|
||||
score_vis = json.load(f)
|
||||
|
||||
for pred in predictions:
|
||||
Count += 1
|
||||
if Count % 10 == 0:
|
||||
print('Count', Count, "score_vis", len(score_vis))
|
||||
db_id = pred['database_id']
|
||||
schema = database_schema[db_id]
|
||||
try:
|
||||
return_match = pred['return_match']
|
||||
except:
|
||||
return_match = ''
|
||||
if db_id not in postprocess_sqls:
|
||||
postprocess_sqls[db_id] = []
|
||||
|
||||
interaction_id = pred['interaction_id']
|
||||
turn_id = pred['index_in_interaction']
|
||||
total += 1
|
||||
|
||||
if turn_id == 0:
|
||||
utterances = []
|
||||
|
||||
question = ' '.join(pred['input_seq'])
|
||||
pred_sql_str = ' '.join(pred['flat_prediction'])
|
||||
|
||||
utterances.append(question)
|
||||
|
||||
beam_sql_strs = []
|
||||
for score, beam_sql_str in pred['beam']:
|
||||
beam_sql_strs.append( (score, ' '.join(beam_sql_str)))
|
||||
|
||||
gold_sql_str = ' '.join(pred['flat_gold_queries'][0])
|
||||
if pred_sql_str == gold_sql_str:
|
||||
correct += 1
|
||||
|
||||
postprocess_sql = pred_sql_str
|
||||
if remove_from:
|
||||
postprocess_sql = postprocess_one(pred_sql_str, schema)
|
||||
sqls = []
|
||||
key_idx = dict()
|
||||
for i in range(len(beam_sql_strs)):
|
||||
sql = postprocess_one(beam_sql_strs[i][1], schema)
|
||||
key = '&'.join(utterances)+'#'+sql
|
||||
if key not in score_vis:
|
||||
if key not in key_idx:
|
||||
key_idx[key]=len(sqls)
|
||||
sqls.append(sql.replace("value", "1"))
|
||||
if use_reranker and len(sqls) > 0:
|
||||
score_list = ranker.get_score_batch(utterances, sqls)
|
||||
|
||||
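        # Combine beam-search and reranker scores: the reranker probability, offset by
        # a tiny rank-dependent epsilon (presumably to break ties), is turned into
        # -log(p) and added to the original beam score, so smaller totals are better.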
for i in range(len(beam_sql_strs)):
|
||||
score = beam_sql_strs[i][0]
|
||||
sql = postprocess_one(beam_sql_strs[i][1], schema)
|
||||
if use_reranker:
|
||||
key = '&'.join(utterances)+'#'+sql
|
||||
old_score = score
|
||||
if key in score_vis:
|
||||
score = score_vis[key]
|
||||
else:
|
||||
score = score_list[key_idx[key]]
|
||||
score_vis[key] = score
|
||||
score += i * 1e-4
|
||||
score = -math.log(score)
|
||||
score += old_score
|
||||
beam_sql_strs[i] = (score, sql)
|
||||
assert beam_sql_strs[0][1] == postprocess_sql
|
||||
if use_reranker and Count % 100 == 0:
|
||||
with open('score_vis.json', 'w') as file:
|
||||
json.dump(score_vis, file)
|
||||
|
||||
gold_sql_str = postprocess_one(gold_sql_str, schema)
|
||||
postprocess_sqls[db_id].append((postprocess_sql, interaction_id, turn_id, beam_sql_strs, question, gold_sql_str))
|
||||
|
||||
print (correct, total, float(correct)/total)
|
||||
return postprocess_sqls
|
||||
|
||||
|
||||
def read_prediction(pred_file):
|
||||
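    # Ensemble mode: when several prediction files are given (comma-separated), the
    # top-5 beam from each file is concatenated into one larger beam per example.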
if len( pred_file.split(',') ) > 1:
|
||||
pred_files = pred_file.split(',')
|
||||
|
||||
sum_predictions = []
|
||||
|
||||
for pred_file in pred_files:
|
||||
print('Read prediction from', pred_file)
|
||||
predictions = []
|
||||
with open(pred_file) as f:
|
||||
for line in f:
|
||||
pred = json.loads(line)
|
||||
pred["beam"] = pred["beam"][:5]
|
||||
predictions.append(pred)
|
||||
print('Number of predictions', len(predictions))
|
||||
|
||||
if len(sum_predictions) == 0:
|
||||
sum_predictions = predictions
|
||||
else:
|
||||
assert len(sum_predictions) == len(predictions)
|
||||
for i in range(len(sum_predictions)):
|
||||
sum_predictions[i]['beam'] += predictions[i]['beam']
|
||||
|
||||
return sum_predictions
|
||||
else:
|
||||
print('Read prediction from', pred_file)
|
||||
predictions = []
|
||||
with open(pred_file) as f:
|
||||
for line in f:
|
||||
pred = json.loads(line)
|
||||
pred["beam"] = pred["beam"][:5]
|
||||
predictions.append(pred)
|
||||
print('Number of predictions', len(predictions))
|
||||
return predictions
|
||||
|
||||
def read_schema(table_schema_path):
|
||||
with open(table_schema_path) as f:
|
||||
database_schema = json.load(f)
|
||||
|
||||
database_schema_dict = {}
|
||||
for table_schema in database_schema:
|
||||
db_id = table_schema['db_id']
|
||||
database_schema_dict[db_id] = table_schema
|
||||
|
||||
return database_schema_dict
|
||||
|
||||
|
||||
def write_and_evaluate(postprocess_sqls, db_path, table_schema_path, gold_path, dataset):
|
||||
db_list = []
|
||||
if gold_path is not None:
|
||||
with open(gold_path) as f:
|
||||
for line in f:
|
||||
line_split = line.strip().split('\t')
|
||||
if len(line_split) != 2:
|
||||
continue
|
||||
db = line.strip().split('\t')[1]
|
||||
if db not in db_list:
|
||||
db_list.append(db)
|
||||
else:
|
||||
gold_path = './fake_gold.txt'
|
||||
with open(gold_path, "w") as file:
|
||||
db_list = list(postprocess_sqls.keys())
|
||||
last_id = None
|
||||
for db in db_list:
|
||||
for postprocess_sql, interaction_id, turn_id, beam_sql_strs, question, gold_str in postprocess_sqls[db]:
|
||||
if last_id is not None and last_id != str(interaction_id)+db:
|
||||
file.write('\n')
|
||||
last_id = str(interaction_id)+db
|
||||
file.write(gold_str+'\t'+db+'\n')
|
||||
|
||||
output_file = 'output_temp.txt'
|
||||
|
||||
if dataset == 'spider':
|
||||
with open(output_file, "w") as f:
|
||||
for db in db_list:
|
||||
for postprocess_sql, interaction_id, turn_id in postprocess_sqls[db]:
|
||||
f.write(postprocess_sql+'\n')
|
||||
|
||||
command = 'python3 eval_scripts/evaluation.py --db {} --table {} --etype match --gold {} --pred {}'.format(db_path,
|
||||
table_schema_path,
|
||||
gold_path,
|
||||
os.path.abspath(output_file))
|
||||
elif dataset in ['sparc', 'cosql']:
|
||||
cnt = 0
|
||||
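        # For SParC/CoSQL, each turn is written as three lines (the tab-separated beam
        # SQLs, their scores, and the user question), with a blank line separating
        # interactions; this file is then passed to eval_scripts/gen_final.py as --pred.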
with open(output_file, "w") as f:
|
||||
last_id = None
|
||||
for db in db_list:
|
||||
for postprocess_sql, interaction_id, turn_id, beam_sql_strs, question, gold_str in postprocess_sqls[db]:
|
||||
if last_id is not None and last_id != str(interaction_id)+db:
|
||||
f.write('\n')
|
||||
last_id = str(interaction_id) + db
|
||||
f.write('{}\n'.format( '\t'.join( [x[1] for x in beam_sql_strs] ) ))
|
||||
f.write('{}\n'.format( '\t'.join( [str(x[0]) for x in beam_sql_strs] )) )
|
||||
f.write('{}\n'.format( question ))
|
||||
cnt += 1
|
||||
"""
|
||||
predict_file = 'predicted_sql.txt'
|
||||
cnt = 0
|
||||
with open(predict_file, "w") as f:
|
||||
for db in db_list:
|
||||
for postprocess_sql, interaction_id, turn_id, beam_sql_strs, question, gold_str in postprocess_sqls[db]:
|
||||
|
||||
print(postprocess_sql)
|
||||
print(beam_sql_strs)
|
||||
print(question)
|
||||
print(gold_str)
|
||||
|
||||
if turn_id == 0 and cnt > 0:
|
||||
f.write('\n')
|
||||
f.write('{}\n'.format(postprocess_sql))
|
||||
cnt += 1
|
||||
"""
|
||||
|
||||
|
||||
command = 'python3 eval_scripts/gen_final.py --db {} --table {} --etype match --gold {} --pred {}'.format(db_path,table_schema_path,gold_path,os.path.abspath(output_file))
|
||||
|
||||
#command = 'python3 eval_scripts/evaluation_sqa.py --db {} --table {} --etype match --gold {} --pred {}'.format(db_path,table_schema_path,gold_path,os.path.abspath(output_file))
|
||||
#command += '; rm output_temp.txt'
|
||||
|
||||
print('begin command')
|
||||
return command
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset', choices=('spider', 'sparc', 'cosql'), default='sparc')
|
||||
parser.add_argument('--split', type=str, default='dev')
|
||||
parser.add_argument('--pred_file', type=str, default='')
|
||||
parser.add_argument('--remove_from', action='store_true', default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
db_path = 'data/database/'
|
||||
gold_path = None
|
||||
if args.dataset == 'spider':
|
||||
table_schema_path = 'data/spider/tables.json'
|
||||
if args.split == 'dev':
|
||||
gold_path = 'data/spider/dev_gold.sql'
|
||||
elif args.dataset == 'sparc':
|
||||
table_schema_path = 'data/sparc/tables.json'
|
||||
if args.split == 'dev':
|
||||
gold_path = 'data/sparc/dev_gold.txt'
|
||||
elif args.dataset == 'cosql':
|
||||
table_schema_path = 'data/cosql/tables.json'
|
||||
if args.split == 'dev':
|
||||
gold_path = 'data/cosql/dev_gold.txt'
|
||||
|
||||
pred_file = args.pred_file
|
||||
|
||||
database_schema = read_schema(table_schema_path)
|
||||
predictions = read_prediction(pred_file)
|
||||
postprocess_sqls = postprocess(predictions, database_schema, args.remove_from)
|
||||
|
||||
command = write_and_evaluate(postprocess_sqls, db_path, table_schema_path, gold_path, args.dataset)
|
||||
|
||||
print('command', command)
|
||||
|
||||
eval_output = subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True)
|
||||
if len( pred_file.split(',') ) > 1:
|
||||
pred_file = 'ensemble'
|
||||
else:
|
||||
pred_file = pred_file.split(',')[0]
|
||||
with open(pred_file+'.eval', 'w') as f:
|
||||
f.write(eval_output.decode("utf-8"))
|
||||
print('Eval result in', pred_file+'.eval')
|
|
@@ -0,0 +1,556 @@
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
import json
|
||||
import shutil
|
||||
import sqlparse
|
||||
from postprocess_eval import get_candidate_tables
|
||||
|
||||
|
||||
def write_interaction(interaction_list,split,output_dir):
|
||||
json_split = os.path.join(output_dir,split+'.json')
|
||||
pkl_split = os.path.join(output_dir,split+'.pkl')
|
||||
|
||||
with open(json_split, 'w') as outfile:
|
||||
for interaction in interaction_list:
|
||||
json.dump(interaction, outfile, indent = 4)
|
||||
outfile.write('\n')
|
||||
|
||||
new_objs = []
|
||||
for i, obj in enumerate(interaction_list):
|
||||
new_interaction = []
|
||||
for ut in obj["interaction"]:
|
||||
sql = ut["sql"]
|
||||
sqls = [sql]
|
||||
tok_sql_list = []
|
||||
for sql in sqls:
|
||||
results = []
|
||||
tokenized_sql = sql.split()
|
||||
tok_sql_list.append((tokenized_sql, results))
|
||||
ut["sql"] = tok_sql_list
|
||||
new_interaction.append(ut)
|
||||
obj["interaction"] = new_interaction
|
||||
new_objs.append(obj)
|
||||
|
||||
with open(pkl_split,'wb') as outfile:
|
||||
pickle.dump(new_objs, outfile)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def read_database_schema(database_schema_filename, schema_tokens, column_names, database_schemas_dict):
|
||||
with open(database_schema_filename) as f:
|
||||
database_schemas = json.load(f)
|
||||
|
||||
def get_schema_tokens(table_schema):
|
||||
column_names_surface_form = []
|
||||
column_names = []
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
if table_id >= 0:
|
||||
table_name = table_names_original[table_id]
|
||||
column_name_surface_form = '{}.{}'.format(table_name,column_name)
|
||||
else:
|
||||
# this is just *
|
||||
column_name_surface_form = column_name
|
||||
column_names_surface_form.append(column_name_surface_form.lower())
|
||||
column_names.append(column_name.lower())
|
||||
|
||||
# also add table_name.*
|
||||
for table_name in table_names_original:
|
||||
column_names_surface_form.append('{}.*'.format(table_name.lower()))
|
||||
|
||||
return column_names_surface_form, column_names
|
||||
|
||||
for table_schema in database_schemas:
|
||||
database_id = table_schema['db_id']
|
||||
database_schemas_dict[database_id] = table_schema
|
||||
schema_tokens[database_id], column_names[database_id] = get_schema_tokens(table_schema)
|
||||
|
||||
return schema_tokens, column_names, database_schemas_dict
|
||||
|
||||
|
||||
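# Strip FROM ... JOIN clauses from a formatted SQL string: table aliases (t1, t2, ...)
# are resolved back to real table names, and the function returns the rewritten
# lines together with the tables used by each clause.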
def remove_from_with_join(format_sql_2):
|
||||
used_tables_list = []
|
||||
format_sql_3 = []
|
||||
table_to_name = {}
|
||||
table_list = []
|
||||
old_table_to_name = {}
|
||||
old_table_list = []
|
||||
for sub_sql in format_sql_2.split('\n'):
|
||||
if 'select ' in sub_sql:
|
||||
# only replace alias: t1 -> table_name, t2 -> table_name, etc...
|
||||
if len(table_list) > 0:
|
||||
for i in range(len(format_sql_3)):
|
||||
for table, name in table_to_name.items():
|
||||
format_sql_3[i] = format_sql_3[i].replace(table, name)
|
||||
|
||||
old_table_list = table_list
|
||||
old_table_to_name = table_to_name
|
||||
table_to_name = {}
|
||||
table_list = []
|
||||
format_sql_3.append(sub_sql)
|
||||
elif sub_sql.startswith('from'):
|
||||
new_sub_sql = None
|
||||
sub_sql_tokens = sub_sql.split()
|
||||
for t_i, t in enumerate(sub_sql_tokens):
|
||||
if t == 'as':
|
||||
table_to_name[sub_sql_tokens[t_i+1]] = sub_sql_tokens[t_i-1]
|
||||
table_list.append(sub_sql_tokens[t_i-1])
|
||||
elif t == ')' and new_sub_sql is None:
|
||||
# new_sub_sql keeps some trailing parts after ')'
|
||||
new_sub_sql = ' '.join(sub_sql_tokens[t_i:])
|
||||
if len(table_list) > 0:
|
||||
# if it's a from clause with join
|
||||
if new_sub_sql is not None:
|
||||
format_sql_3.append(new_sub_sql)
|
||||
|
||||
used_tables_list.append(table_list)
|
||||
else:
|
||||
# if it's a from clause without join
|
||||
table_list = old_table_list
|
||||
table_to_name = old_table_to_name
|
||||
assert 'join' not in sub_sql
|
||||
if new_sub_sql is not None:
|
||||
sub_sub_sql = sub_sql[:-len(new_sub_sql)].strip()
|
||||
assert len(sub_sub_sql.split()) == 2
|
||||
used_tables_list.append([sub_sub_sql.split()[1]])
|
||||
format_sql_3.append(sub_sub_sql)
|
||||
format_sql_3.append(new_sub_sql)
|
||||
elif 'join' not in sub_sql:
|
||||
assert len(sub_sql.split()) == 2 or len(sub_sql.split()) == 1
|
||||
if len(sub_sql.split()) == 2:
|
||||
used_tables_list.append([sub_sql.split()[1]])
|
||||
|
||||
format_sql_3.append(sub_sql)
|
||||
else:
|
||||
print('bad from clause in remove_from_with_join')
|
||||
exit()
|
||||
else:
|
||||
format_sql_3.append(sub_sql)
|
||||
|
||||
if len(table_list) > 0:
|
||||
for i in range(len(format_sql_3)):
|
||||
for table, name in table_to_name.items():
|
||||
format_sql_3[i] = format_sql_3[i].replace(table, name)
|
||||
|
||||
used_tables = []
|
||||
for t in used_tables_list:
|
||||
for tt in t:
|
||||
used_tables.append(tt)
|
||||
used_tables = list(set(used_tables))
|
||||
|
||||
return format_sql_3, used_tables, used_tables_list
|
||||
|
||||
|
||||
def remove_from_without_join(format_sql_3, column_names, schema_tokens):
|
||||
format_sql_4 = []
|
||||
table_name = None
|
||||
for sub_sql in format_sql_3.split('\n'):
|
||||
if 'select ' in sub_sql:
|
||||
if table_name:
|
||||
for i in range(len(format_sql_4)):
|
||||
tokens = format_sql_4[i].split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token in column_names and tokens[ii-1] != '.':
|
||||
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
|
||||
if '{}.{}'.format(table_name,token) in schema_tokens:
|
||||
tokens[ii] = '{} . {}'.format(table_name,token)
|
||||
format_sql_4[i] = ' '.join(tokens)
|
||||
|
||||
format_sql_4.append(sub_sql)
|
||||
elif sub_sql.startswith('from'):
|
||||
sub_sql_tokens = sub_sql.split()
|
||||
if len(sub_sql_tokens) == 1:
|
||||
table_name = None
|
||||
elif len(sub_sql_tokens) == 2:
|
||||
table_name = sub_sql_tokens[1]
|
||||
else:
|
||||
print('bad from clause in remove_from_without_join')
|
||||
print(format_sql_3)
|
||||
exit()
|
||||
else:
|
||||
format_sql_4.append(sub_sql)
|
||||
|
||||
if table_name:
|
||||
for i in range(len(format_sql_4)):
|
||||
tokens = format_sql_4[i].split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token in column_names and tokens[ii-1] != '.':
|
||||
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
|
||||
if '{}.{}'.format(table_name,token) in schema_tokens:
|
||||
tokens[ii] = '{} . {}'.format(table_name,token)
|
||||
format_sql_4[i] = ' '.join(tokens)
|
||||
|
||||
return format_sql_4
|
||||
|
||||
|
||||
def add_table_name(format_sql_3, used_tables, column_names, schema_tokens):
|
||||
# If just one table used, easy case, replace all column_name -> table_name.column_name
|
||||
if len(used_tables) == 1:
|
||||
table_name = used_tables[0]
|
||||
format_sql_4 = []
|
||||
for sub_sql in format_sql_3.split('\n'):
|
||||
if sub_sql.startswith('from'):
|
||||
format_sql_4.append(sub_sql)
|
||||
continue
|
||||
|
||||
tokens = sub_sql.split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token in column_names and tokens[ii-1] != '.':
|
||||
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
|
||||
if '{}.{}'.format(table_name,token) in schema_tokens:
|
||||
tokens[ii] = '{} . {}'.format(table_name,token)
|
||||
format_sql_4.append(' '.join(tokens))
|
||||
return format_sql_4
|
||||
|
||||
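    # Resolve which of the used tables a bare column belongs to: returns the single
    # matching table, the placeholder 'table' if none matches, or None when the
    # column is ambiguous (in which case the token is left unqualified).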
def get_table_name_for(token):
|
||||
table_names = []
|
||||
for table_name in used_tables:
|
||||
if '{}.{}'.format(table_name, token) in schema_tokens:
|
||||
table_names.append(table_name)
|
||||
if len(table_names) == 0:
|
||||
return 'table'
|
||||
if len(table_names) > 1:
|
||||
return None
|
||||
else:
|
||||
return table_names[0]
|
||||
|
||||
format_sql_4 = []
|
||||
for sub_sql in format_sql_3.split('\n'):
|
||||
if sub_sql.startswith('from'):
|
||||
format_sql_4.append(sub_sql)
|
||||
continue
|
||||
|
||||
tokens = sub_sql.split()
|
||||
for ii, token in enumerate(tokens):
|
||||
# skip *
|
||||
if token == '*':
|
||||
continue
|
||||
if token in column_names and tokens[ii-1] != '.':
|
||||
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
|
||||
table_name = get_table_name_for(token)
|
||||
if table_name:
|
||||
tokens[ii] = '{} . {}'.format(table_name, token)
|
||||
format_sql_4.append(' '.join(tokens))
|
||||
|
||||
return format_sql_4
|
||||
|
||||
|
||||
def check_oov(format_sql_final, output_vocab, schema_tokens):
|
||||
for sql_tok in format_sql_final.split():
|
||||
if not (sql_tok in schema_tokens or sql_tok in output_vocab):
|
||||
print('OOV!', sql_tok)
|
||||
raise Exception('OOV')
|
||||
|
||||
|
||||
def normalize_space(format_sql):
|
||||
format_sql_1 = [' '.join(sub_sql.strip().replace(',',' , ').replace('.',' . ').replace('(',' ( ').replace(')',' ) ').split()) for sub_sql in format_sql.split('\n')]
|
||||
format_sql_1 = '\n'.join(format_sql_1)
|
||||
|
||||
format_sql_2 = format_sql_1.replace('\njoin',' join').replace(',\n',', ').replace(' where','\nwhere').replace(' intersect', '\nintersect').replace('\nand',' and').replace('order by t2 .\nstart desc', 'order by t2 . start desc')
|
||||
|
||||
format_sql_2 = format_sql_2.replace('select\noperator', 'select operator').replace('select\nconstructor', 'select constructor').replace('select\nstart', 'select start').replace('select\ndrop', 'select drop').replace('select\nwork', 'select work').replace('select\ngroup', 'select group').replace('select\nwhere_built', 'select where_built').replace('select\norder', 'select order').replace('from\noperator', 'from operator').replace('from\nforward', 'from forward').replace('from\nfor', 'from for').replace('from\ndrop', 'from drop').replace('from\norder', 'from order').replace('.\nstart', '. start').replace('.\norder', '. order').replace('.\noperator', '. operator').replace('.\nsets', '. sets').replace('.\nwhere_built', '. where_built').replace('.\nwork', '. work').replace('.\nconstructor', '. constructor').replace('.\ngroup', '. group').replace('.\nfor', '. for').replace('.\ndrop', '. drop').replace('.\nwhere', '. where')
|
||||
|
||||
format_sql_2 = format_sql_2.replace('group by', 'group_by').replace('order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
|
||||
return format_sql_2
|
||||
|
||||
|
||||
def normalize_final_sql(format_sql_5):
|
||||
format_sql_final = format_sql_5.replace('\n', ' ').replace(' . ', '.').replace('group by', 'group_by').replace('order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
|
||||
|
||||
# normalize two bad sqls
|
||||
if 't1' in format_sql_final or 't2' in format_sql_final or 't3' in format_sql_final or 't4' in format_sql_final:
|
||||
format_sql_final = format_sql_final.replace('t2.dormid', 'dorm.dormid')
|
||||
|
||||
# This is the failure case of remove_from_without_join()
|
||||
format_sql_final = format_sql_final.replace('select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by population desc limit_value', 'select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by city.population desc limit_value')
|
||||
|
||||
return format_sql_final
|
||||
|
||||
|
||||
def parse_sql(sql_string, db_id, column_names, output_vocab, schema_tokens, schema):
|
||||
format_sql = sqlparse.format(sql_string, reindent=True)
|
||||
format_sql_2 = normalize_space(format_sql)
|
||||
|
||||
num_from = sum([1 for sub_sql in format_sql_2.split('\n') if sub_sql.startswith('from')])
|
||||
num_select = format_sql_2.count('select ') + format_sql_2.count('select\n')
|
||||
|
||||
format_sql_3, used_tables, used_tables_list = remove_from_with_join(format_sql_2)
|
||||
|
||||
format_sql_3 = '\n'.join(format_sql_3)
|
||||
format_sql_4 = add_table_name(format_sql_3, used_tables, column_names, schema_tokens)
|
||||
|
||||
format_sql_4 = '\n'.join(format_sql_4)
|
||||
format_sql_5 = remove_from_without_join(format_sql_4, column_names, schema_tokens)
|
||||
|
||||
format_sql_5 = '\n'.join(format_sql_5)
|
||||
format_sql_final = normalize_final_sql(format_sql_5)
|
||||
|
||||
candidate_tables_id, table_names_original = get_candidate_tables(format_sql_final, schema)
|
||||
|
||||
failure = False
|
||||
if len(candidate_tables_id) != len(used_tables):
|
||||
failure = True
|
||||
|
||||
check_oov(format_sql_final, output_vocab, schema_tokens)
|
||||
|
||||
return format_sql_final
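For orientation, here is a small illustrative sketch (not taken from the repo) of the shape of `parse_sql`'s output when `remove_from` preprocessing is used: FROM/JOIN clauses are dropped, surviving columns are qualified with their table, and multi-word keywords are fused as in `normalize_final_sql` above.

```python
# Hypothetical before/after pair; the exact tokens depend on sqlparse and the schema.
raw_sql = "select name from singer where age > value"
parsed  = "select singer.name where singer.age > value"
# Likewise "group by" -> "group_by", "order by" -> "order_by", "limit value" -> "limit_value".
```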
|
||||
|
||||
|
||||
def read_spider_split(split_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
with open(split_json) as f:
|
||||
split_data = json.load(f)
|
||||
print('read_spider_split', split_json, len(split_data))
|
||||
|
||||
for i, ex in enumerate(split_data):
|
||||
db_id = ex['db_id']
|
||||
|
||||
final_sql = []
|
||||
skip = False
|
||||
for query_tok in ex['query_toks_no_value']:
|
||||
if query_tok != '.' and '.' in query_tok:
|
||||
# invalid sql; didn't use table alias in join
|
||||
final_sql += query_tok.replace('.',' . ').split()
|
||||
skip = True
|
||||
else:
|
||||
final_sql.append(query_tok)
|
||||
final_sql = ' '.join(final_sql)
|
||||
|
||||
if skip and 'train' in split_json:
|
||||
continue
|
||||
|
||||
if remove_from:
|
||||
final_sql_parse = parse_sql(final_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
|
||||
else:
|
||||
final_sql_parse = final_sql
|
||||
|
||||
final_utterance = ' '.join(ex['question_toks'])
|
||||
|
||||
if db_id not in interaction_list:
|
||||
interaction_list[db_id] = []
|
||||
|
||||
interaction = {}
|
||||
interaction['id'] = ''
|
||||
interaction['scenario'] = ''
|
||||
interaction['database_id'] = db_id
|
||||
interaction['interaction_id'] = len(interaction_list[db_id])
|
||||
interaction['final'] = {}
|
||||
interaction['final']['utterance'] = final_utterance
|
||||
interaction['final']['sql'] = final_sql_parse
|
||||
interaction['interaction'] = [{'utterance': final_utterance, 'sql': final_sql_parse}]
|
||||
interaction_list[db_id].append(interaction)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_data_json(split_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
with open(split_json) as f:
|
||||
split_data = json.load(f)
|
||||
print('read_data_json', split_json, len(split_data))
|
||||
|
||||
for interaction_data in split_data:
|
||||
db_id = interaction_data['database_id']
|
||||
final_sql = interaction_data['final']['query']
|
||||
final_utterance = interaction_data['final']['utterance']
|
||||
|
||||
if db_id not in interaction_list:
|
||||
interaction_list[db_id] = []
|
||||
|
||||
# no interaction_id in train
|
||||
if 'interaction_id' in interaction_data['interaction']:
|
||||
interaction_id = interaction_data['interaction']['interaction_id']
|
||||
else:
|
||||
interaction_id = len(interaction_list[db_id])
|
||||
|
||||
interaction = {}
|
||||
interaction['id'] = ''
|
||||
interaction['scenario'] = ''
|
||||
interaction['database_id'] = db_id
|
||||
interaction['interaction_id'] = interaction_id
|
||||
interaction['final'] = {}
|
||||
interaction['final']['utterance'] = final_utterance
|
||||
interaction['final']['sql'] = final_sql
|
||||
interaction['interaction'] = []
|
||||
|
||||
for turn in interaction_data['interaction']:
|
||||
turn_sql = []
|
||||
skip = False
|
||||
for query_tok in turn['query_toks_no_value']:
|
||||
if query_tok != '.' and '.' in query_tok:
|
||||
# invalid sql; didn't use table alias in join
|
||||
turn_sql += query_tok.replace('.',' . ').split()
|
||||
skip = True
|
||||
else:
|
||||
turn_sql.append(query_tok)
|
||||
turn_sql = ' '.join(turn_sql)
|
||||
|
||||
# Correct some human sql annotation error
|
||||
turn_sql = turn_sql.replace('select f_id from files as t1 join song as t2 on t1 . f_id = t2 . f_id', 'select t1 . f_id from files as t1 join song as t2 on t1 . f_id = t2 . f_id')
|
||||
turn_sql = turn_sql.replace('select name from climber mountain', 'select name from climber')
|
||||
turn_sql = turn_sql.replace('select sid from sailors as t1 join reserves as t2 on t1 . sid = t2 . sid join boats as t3 on t3 . bid = t2 . bid', 'select t1 . sid from sailors as t1 join reserves as t2 on t1 . sid = t2 . sid join boats as t3 on t3 . bid = t2 . bid')
|
||||
turn_sql = turn_sql.replace('select avg ( price ) from goods )', 'select avg ( price ) from goods')
|
||||
turn_sql = turn_sql.replace('select min ( annual_fuel_cost ) , from vehicles', 'select min ( annual_fuel_cost ) from vehicles')
|
||||
turn_sql = turn_sql.replace('select * from goods where price < ( select avg ( price ) from goods', 'select * from goods where price < ( select avg ( price ) from goods )')
|
||||
turn_sql = turn_sql.replace('select distinct id , price from goods where price < ( select avg ( price ) from goods', 'select distinct id , price from goods where price < ( select avg ( price ) from goods )')
|
||||
turn_sql = turn_sql.replace('select id from goods where price > ( select avg ( price ) from goods', 'select id from goods where price > ( select avg ( price ) from goods )')
|
||||
|
||||
if skip and 'train' in split_json:
|
||||
continue
|
||||
|
||||
if remove_from:
|
||||
try:
|
||||
turn_sql_parse = parse_sql(turn_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
|
||||
except Exception:
|
||||
print('skip turn: parse_sql failed')
|
||||
continue
|
||||
else:
|
||||
turn_sql_parse = turn_sql
|
||||
|
||||
if 'utterance_toks' in turn:
|
||||
turn_utterance = ' '.join(turn['utterance_toks']) # not lower()
|
||||
else:
|
||||
turn_utterance = turn['utterance']
|
||||
|
||||
interaction['interaction'].append({'utterance': turn_utterance, 'sql': turn_sql_parse})
|
||||
if len(interaction['interaction']) > 0:
|
||||
interaction_list[db_id].append(interaction)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_spider(spider_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
interaction_list = {}
|
||||
|
||||
train_json = os.path.join(spider_dir, 'train.json')
|
||||
interaction_list = read_spider_split(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
dev_json = os.path.join(spider_dir, 'dev.json')
|
||||
interaction_list = read_spider_split(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_sparc(sparc_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
interaction_list = {}
|
||||
|
||||
train_json = os.path.join(sparc_dir, 'train_no_value.json')
|
||||
interaction_list = read_data_json(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
dev_json = os.path.join(sparc_dir, 'dev_no_value.json')
|
||||
interaction_list = read_data_json(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_cosql(cosql_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
interaction_list = {}
|
||||
|
||||
train_json = os.path.join(cosql_dir, 'train.json')
|
||||
interaction_list = read_data_json(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
dev_json = os.path.join(cosql_dir, 'dev.json')
|
||||
interaction_list = read_data_json(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_db_split(data_dir):
|
||||
train_database = []
|
||||
with open(os.path.join(data_dir,'train_db_ids.txt')) as f:
|
||||
for line in f:
|
||||
train_database.append(line.strip())
|
||||
|
||||
dev_database = []
|
||||
with open(os.path.join(data_dir,'dev_db_ids.txt')) as f:
|
||||
for line in f:
|
||||
dev_database.append(line.strip())
|
||||
|
||||
return train_database, dev_database
|
||||
|
||||
|
||||
def preprocess(dataset, remove_from=False):
|
||||
# Output SQL token vocabulary
|
||||
output_vocab = ['_UNK', '_EOS', '.', 't1', 't2', '=', 'select', 'from', 'as', 'value', 'join', 'on', ')', '(', 'where', 't3', 'by', ',', 'count', 'group', 'order', 'distinct', 't4', 'and', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 't5', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!', 'union', 'between', 't6', '-', 't7', '+', '/']
|
||||
if remove_from:
|
||||
output_vocab = ['_UNK', '_EOS', '=', 'select', 'value', ')', '(', 'where', ',', 'count', 'group_by', 'order_by', 'distinct', 'and', 'limit_value', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!=', 'union', 'between', '-', '+', '/']
|
||||
print('size of output_vocab', len(output_vocab))
|
||||
print('output_vocab', output_vocab)
|
||||
print()
|
||||
|
||||
if dataset == 'spider':
|
||||
spider_dir = 'data/spider/'
|
||||
database_schema_filename = 'data/spider/tables.json'
|
||||
output_dir = 'data/spider_data'
|
||||
if remove_from:
|
||||
output_dir = 'data/spider_data_removefrom'
|
||||
train_database, dev_database = read_db_split(spider_dir)
|
||||
elif dataset == 'sparc':
|
||||
sparc_dir = 'data/sparc/'
|
||||
database_schema_filename = 'data/sparc/tables.json'
|
||||
output_dir = 'data/sparc_data'
|
||||
if remove_from:
|
||||
output_dir = 'data/sparc_data_removefrom'
|
||||
train_database, dev_database = read_db_split(sparc_dir)
|
||||
elif dataset == 'cosql':
|
||||
cosql_dir = 'data/cosql/'
|
||||
database_schema_filename = 'data/cosql/tables.json'
|
||||
output_dir = 'data/cosql_data'
|
||||
if remove_from:
|
||||
output_dir = 'data/cosql_data_removefrom'
|
||||
train_database, dev_database = read_db_split(cosql_dir)
|
||||
|
||||
if os.path.isdir(output_dir):
|
||||
shutil.rmtree(output_dir)
|
||||
os.mkdir(output_dir)
|
||||
|
||||
schema_tokens = {}
|
||||
column_names = {}
|
||||
database_schemas = {}
|
||||
|
||||
print('Reading database schema file')
|
||||
schema_tokens, column_names, database_schemas = read_database_schema(database_schema_filename, schema_tokens, column_names, database_schemas)
|
||||
num_database = len(schema_tokens)
|
||||
print('num_database', num_database, len(train_database), len(dev_database))
|
||||
print('total number of schema_tokens / databases:', len(schema_tokens))
|
||||
|
||||
output_database_schema_filename = os.path.join(output_dir, 'tables.json')
|
||||
with open(output_database_schema_filename, 'w') as outfile:
|
||||
json.dump([v for k,v in database_schemas.items()], outfile, indent=4)
|
||||
|
||||
if dataset == 'spider':
|
||||
interaction_list = read_spider(spider_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
elif dataset == 'sparc':
|
||||
interaction_list = read_sparc(sparc_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
elif dataset == 'cosql':
|
||||
interaction_list = read_cosql(cosql_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
print('interaction_list length', len(interaction_list))
|
||||
|
||||
train_interaction = []
|
||||
for database_id in interaction_list:
|
||||
if database_id not in dev_database:
|
||||
train_interaction += interaction_list[database_id]
|
||||
|
||||
dev_interaction = []
|
||||
for database_id in dev_database:
|
||||
dev_interaction += interaction_list[database_id]
|
||||
|
||||
print('train interaction: ', len(train_interaction))
|
||||
print('dev interaction: ', len(dev_interaction))
|
||||
|
||||
write_interaction(train_interaction, 'train', output_dir)
|
||||
write_interaction(dev_interaction, 'dev', output_dir)
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dataset", choices=('spider', 'sparc', 'cosql'), default='sparc')
|
||||
parser.add_argument('--remove_from', action='store_true', default=False)
|
||||
args = parser.parse_args()
|
||||
preprocess(args.dataset, args.remove_from)
|
|
@@ -0,0 +1,77 @@
|
|||
import torch
|
||||
from torch.utils import data
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
from transformers import RobertaTokenizer
|
||||
|
||||
from .utils import preprocess
|
||||
|
||||
class NL2SQL_Dataset(data.Dataset):
|
||||
def __init__(self, args, data_type="train", shuffle=True):
|
||||
self.data_type = data_type
|
||||
assert self.data_type in ["train", "valid", "test"]
|
||||
# tokenizer
|
||||
self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', max_len=128)
|
||||
|
||||
if self.data_type in ['train', 'valid']:
|
||||
self.file_path = os.path.join(args.data_path, 'rerank_train_data.json')
|
||||
else:
|
||||
self.file_path = os.path.join(args.data_path, 'rerank_dev_data.json')
|
||||
|
||||
with open(self.file_path, 'r') as f:
|
||||
self.data_lines = f.readlines()
|
||||
|
||||
print('load data from :', self.file_path)
|
||||
|
||||
# split to pos and neg data
|
||||
self.pos_data_all = []
|
||||
self.neg_data_all = []
|
||||
for i in range(len(self.data_lines)):
|
||||
content = json.loads(self.data_lines[i])
|
||||
if content['match']:
|
||||
self.pos_data_all.append(content)
|
||||
else:
|
||||
self.neg_data_all.append(content)
|
||||
|
||||
if self.data_type in ['train', 'valid']:
|
||||
# np.random.shuffle(self.pos_data_all)
|
||||
# np.random.shuffle(self.neg_data_all)
|
||||
# valid_range = int(len(self.pos_data_all + self.neg_data_all) * 0.9 )
|
||||
valid_range_pos = int(len(self.pos_data_all) * 0.9)
|
||||
valid_range_neg = int(len(self.neg_data_all) * 0.9)
|
||||
|
||||
self.train_pos_data = self.pos_data_all[:valid_range_pos]
|
||||
self.val_pos_data = self.pos_data_all[valid_range_pos:]
|
||||
|
||||
self.train_neg_data = self.neg_data_all[:valid_range_neg]
|
||||
self.val_neg_data = self.neg_data_all[valid_range_neg:]
|
||||
|
||||
def get_random_data(self, shuffle=True):
|
||||
if self.data_type == "train":
|
||||
max_train_num = min(len(self.train_pos_data), len(self.train_neg_data))
|
||||
np.random.shuffle(self.train_pos_data)
|
||||
np.random.shuffle(self.train_neg_data)
|
||||
data = self.train_pos_data[:max_train_num] + self.train_neg_data[:max_train_num]
|
||||
np.random.shuffle(data)
|
||||
elif self.data_type == "valid":
|
||||
max_val_num = min(len(self.val_pos_data), len(self.val_neg_data))
|
||||
data = self.val_pos_data[:max_val_num] + self.val_neg_data[:max_val_num]
|
||||
else:
|
||||
data = self.pos_data_all + self.neg_data_all
|
||||
return data
|
||||
|
||||
def __getitem__(self, idx):
|
||||
data_sample = self.get_random_data()
|
||||
data = data_sample[idx]
|
||||
turn_id = data['turn_id']
|
||||
utterances = data['utterances'][:turn_id + 1]
|
||||
label = data['match']
|
||||
sql = data['pred']
|
||||
tokens_tensor, attention_mask_tensor = preprocess(utterances, sql, self.tokenizer)
|
||||
label = torch.tensor(label)
|
||||
return tokens_tensor, attention_mask_tensor, label
|
||||
|
||||
def __len__(self):
|
||||
return len(self.get_random_data())
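A small consumption sketch (assuming only that `args.data_path` points at the rerank json files): each item yields a 128-token id tensor, an attention mask, and a 0/1 match label, with positives and negatives balanced on every reshuffle.

```python
from torch.utils.data import DataLoader

train_set = NL2SQL_Dataset(args, data_type="train")    # reads rerank_train_data.json under args.data_path
loader = DataLoader(train_set, batch_size=16)
tokens, attention_mask, label = next(iter(loader))     # shapes: [16, 128], [16, 128], [16]
```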
|
|
@@ -0,0 +1,187 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim.lr_scheduler import StepLR
|
||||
|
||||
import numpy as np
|
||||
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from .dataset import NL2SQL_Dataset
|
||||
from .model import ReRanker
|
||||
|
||||
from sklearn.metrics import confusion_matrix, accuracy_score
|
||||
|
||||
def parser_arg():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--seed", type=int, default=666)
|
||||
parser.add_argument("--gpu", type=str, default='0')
|
||||
parser.add_argument("--data_path", type=str, default="./reranker/data")
|
||||
parser.add_argument("--train", action="store_true")
|
||||
parser.add_argument("--test", action="store_true")
|
||||
parser.add_argument("--epoches", type=int, default=20)
|
||||
parser.add_argument("--batch_size", type=int, default=16)
|
||||
parser.add_argument("--cls_lr", default=1e-3)
|
||||
parser.add_argument("--bert_lr", default=5e-6)
|
||||
parser.add_argument("--threshold", default=0.5)
|
||||
parser.add_argument("--save_dir", default="./reranker/checkpoints")
|
||||
parser.add_argument("--base_model", default="roberta")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def init(args):
|
||||
"""
|
||||
guaranteed reproducible
|
||||
"""
|
||||
os.environ["CUDA_VISIBLE_DEVICES"]= str(args.gpu)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
if not os.path.exists(args.save_dir):
|
||||
os.mkdir(args.save_dir)
|
||||
print("make dir : ", args.save_dir)
|
||||
|
||||
def print_evalute(evalute_lst):
|
||||
all_cnt, all_acc, pos_cnt, pos_acc, neg_cnt, neg_acc = evalute_lst
|
||||
print("\n All Acc {}/{} ({:.2f}%) \t Pos Acc: {}/{} ({:.2f}%) \t Neg Acc: {}/{} ({:.2f}%) \n".format(all_acc, all_cnt, \
|
||||
100.* all_acc / all_cnt, \
|
||||
pos_acc, pos_cnt, \
|
||||
100. * pos_acc / pos_cnt, \
|
||||
neg_acc, neg_cnt, \
|
||||
100. * neg_acc / neg_cnt))
|
||||
return 100. * all_acc / all_cnt
|
||||
|
||||
def evaluate2(args, output_flat, target_flat):
|
||||
print("start evalute ...")
|
||||
pred = torch.gt(output_flat, args.threshold).float()
|
||||
pred = pred.cpu().numpy()
|
||||
target = target_flat.cpu().numpy()
|
||||
ret = confusion_matrix(target, pred)
|
||||
neg_recall, pos_recall = ret.diagonal() / ret.sum(axis=1)
|
||||
neg_acc, pos_acc = ret.diagonal() / ret.sum(axis=0)
|
||||
acc = accuracy_score(target, pred)
|
||||
print(" All Acc:\t {:.3f}%".format(100.0 * acc))
|
||||
print(" Neg Recall: \t {:.3f}% \t Pos Recall: \t {:.3f}% \n Neg Acc: \t {:.3f}% \t Pos Acc: \t {:.3f}% \n".format(100.0 * neg_recall, 100.0 * pos_recall, 100.0 * neg_acc, 100.0 * pos_acc))
|
||||
|
||||
|
||||
def evaluate(args, output, target, evalute_lst):
|
||||
all_cnt, all_acc, pos_cnt, pos_acc, neg_cnt, neg_acc = evalute_lst
|
||||
output = torch.gt(output, args.threshold).float()
|
||||
for idx in range(target.shape[0]):
|
||||
gt = target[idx]
|
||||
pred = output[idx]
|
||||
if gt == 1:
|
||||
pos_cnt += 1
|
||||
if gt == pred:
|
||||
pos_acc += 1
|
||||
elif gt == 0:
|
||||
neg_cnt += 1
|
||||
if gt == pred:
|
||||
neg_acc += 1
|
||||
all_acc = pos_acc + neg_acc
|
||||
all_cnt = pos_cnt + neg_cnt
|
||||
return [all_cnt, all_acc, pos_cnt, pos_acc, neg_cnt, neg_acc]
|
||||
|
||||
|
||||
def train(args, model, train_loader, epoch):
|
||||
model.train()
|
||||
criterion = nn.BCELoss()
|
||||
idx = 0
|
||||
evalute_lst = [0] * 6
|
||||
for batch_idx, (tokens, attention_mask, target) in enumerate(train_loader):
|
||||
tokens, attention_mask, target = tokens.cuda(), attention_mask.cuda(), target.cuda().float()
|
||||
model.zero_grad()
|
||||
output = model(input_ids=tokens, attention_mask=attention_mask)
|
||||
output = output.squeeze(dim=-1)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
model.cls_trainer.step()
|
||||
model.bert_trainer.step()
|
||||
idx += len(target)
|
||||
evalute_lst = evaluate(args, output, target, evalute_lst)
|
||||
if batch_idx % 50 == 0:
|
||||
print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
||||
epoch, idx, len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
|
||||
return print_evalute(evalute_lst)
|
||||
|
||||
def test(args, model, data_loader):
|
||||
model.eval()
|
||||
criterion = nn.BCELoss()
|
||||
test_loss = 0
|
||||
evalute_lst = [0] * 6
|
||||
output_lst = []
|
||||
target_lst = []
|
||||
with torch.no_grad():
|
||||
for tokens, attention_mask, target in data_loader:
|
||||
tokens, attention_mask, target = tokens.cuda(), attention_mask.cuda(), target.cuda().float()
|
||||
output = model(input_ids=tokens, attention_mask=attention_mask)
|
||||
output = output.squeeze(dim=-1)
|
||||
output_lst.append(output)
|
||||
target_lst.append(target)
|
||||
test_loss = criterion(output, target)
|
||||
evalute_lst = evaluate(args, output, target, evalute_lst)
|
||||
output_flat = torch.cat(output_lst)
|
||||
target_flat = torch.cat(target_lst)
|
||||
print('\n{} set: loss: {:.4f}\n'.format(data_loader.dataset.data_type, test_loss))
|
||||
acc = print_evalute(evalute_lst)
|
||||
return acc
|
||||
|
||||
def main():
|
||||
# init
|
||||
args = parser_arg()
|
||||
init(args)
|
||||
|
||||
# model define
|
||||
model = ReRanker(args, args.base_model)
|
||||
model = model.cuda()
|
||||
model.build_optim()
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = StepLR(model.cls_trainer, step_size=1)
|
||||
|
||||
|
||||
test_loader = DataLoader(
|
||||
NL2SQL_Dataset(args, data_type="test"),
|
||||
batch_size=args.batch_size
|
||||
)
|
||||
|
||||
if args.train:
|
||||
train_loader = DataLoader(
|
||||
NL2SQL_Dataset(args, data_type="train"),
|
||||
batch_size=args.batch_size
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
NL2SQL_Dataset(args, data_type="valid"),
|
||||
batch_size=args.batch_size
|
||||
)
|
||||
|
||||
print("train set num: ", len(train_loader.dataset))
|
||||
print("valid set num: ", len(valid_loader.dataset))
|
||||
print("test set num: ", len(test_loader.dataset))
|
||||
|
||||
for epoch in range(args.epoches):
|
||||
train_acc = train(args, model, train_loader, epoch)
|
||||
valid_acc = test(args, model, valid_loader)
|
||||
save_path = os.path.join(args.save_dir, "roberta_%s_train+%s_val+%s.pt" % (str(epoch), str(train_acc), str(valid_acc)))
|
||||
torch.save(model.state_dict(), save_path)
|
||||
print('save model %s...' % save_path)
|
||||
scheduler.step()
|
||||
test(args, model, test_loader)
|
||||
|
||||
if args.test:
|
||||
model_path = 'xxx.pt'
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
print("load model from ", model_path)
|
||||
test(args, model, test_loader)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
|
@@ -0,0 +1,37 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BertModel, RobertaModel
|
||||
|
||||
class ReRanker(nn.Module):
|
||||
def __init__(self, args=None, base_model="roberta"):
|
||||
super(ReRanker, self).__init__()
|
||||
self.args = args
|
||||
self.bert_model = RobertaModel.from_pretrained('./local_param/')
|
||||
self.cls_model = nn.Sequential(
|
||||
nn.Linear(768, 128),
|
||||
nn.Tanh(),
|
||||
nn.Linear(128, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def build_optim(self):
|
||||
params_cls_trainer = []
|
||||
params_bert_trainer = []
|
||||
for name, param in self.named_parameters():
|
||||
if param.requires_grad:
|
||||
if 'bert_model' in name:
|
||||
params_bert_trainer.append(param)
|
||||
else:
|
||||
params_cls_trainer.append(param)
|
||||
self.bert_trainer = torch.optim.Adam(params_bert_trainer, lr=self.args.bert_lr)
|
||||
self.cls_trainer = torch.optim.Adam(params_cls_trainer, lr=self.args.cls_lr)
|
||||
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# bert model
|
||||
x = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)[0]
|
||||
# get first [CLS] representation
|
||||
x = x[:, 0, :]
|
||||
# classification layer
|
||||
x = self.cls_model(x)
|
||||
return x
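For reference, a minimal shape sketch of the forward pass (assuming `./local_param/` holds the RoBERTa weights noted in the setup): token ids and masks of shape `[batch, seq_len]` map to one match probability per example.

```python
import torch

model = ReRanker()                                  # loads RobertaModel from ./local_param/
ids   = torch.randint(0, 100, (4, 128))
mask  = torch.ones(4, 128, dtype=torch.long)
probs = model(input_ids=ids, attention_mask=mask)   # shape [4, 1], values in (0, 1)
```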
|
|
@@ -0,0 +1,56 @@
|
|||
import torch
|
||||
|
||||
from transformers import BertTokenizer, RobertaTokenizer
|
||||
|
||||
from .model import ReRanker
|
||||
from .utils import preprocess
|
||||
|
||||
class Ranker:
|
||||
def __init__(self, model_path, base_model='roberta'):
|
||||
self.model = ReRanker(base_model=base_model)
|
||||
self.model.load_state_dict(torch.load(model_path))
self.model.eval()
|
||||
self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', max_len=512)
|
||||
print("load reranker model from ", model_path)
|
||||
|
||||
def get_score(self, utterances, sql):
|
||||
tokens_tensor, attention_mask_tensor = preprocess(utterances, sql, self.tokenizer)
|
||||
tokens_tensor = tokens_tensor.unsqueeze(0)
|
||||
attention_mask_tensor = attention_mask_tensor.unsqueeze(0)
|
||||
output = self.model(input_ids=tokens_tensor, attention_mask=attention_mask_tensor)
|
||||
output = output.squeeze(dim=-1)
|
||||
return output.item()
|
||||
|
||||
def get_score_batch(self, utterances, sqls):
|
||||
assert isinstance(sqls, list), "only support sql list input"
|
||||
tokens_tensor_batch, attention_mask_tensor_batch = preprocess(utterances, sqls[0], self.tokenizer)
|
||||
tokens_tensor_batch = tokens_tensor_batch.unsqueeze(0)
|
||||
attention_mask_tensor_batch = attention_mask_tensor_batch.unsqueeze(0)
|
||||
if len(sqls) > 1:
|
||||
for s in sqls[1:]:
|
||||
new_token, new_mask = preprocess(utterances, s, self.tokenizer)
|
||||
tokens_tensor_batch = torch.cat([tokens_tensor_batch, new_token.unsqueeze(0)], dim=0)
|
||||
attention_mask_tensor_batch = torch.cat([attention_mask_tensor_batch, new_mask.unsqueeze(0)])
|
||||
output = self.model(input_ids=tokens_tensor_batch, attention_mask=attention_mask_tensor_batch)
|
||||
output = output.view(-1)
|
||||
ret = []
|
||||
for i in output:
|
||||
ret.append(i.item())
|
||||
return ret
|
||||
|
||||
if __name__ == '__main__':
|
||||
utterances_1 = ["What are all the airlines ?", "Of these , which is Jetblue Airways ?", "What is the country corresponding it ?"]
|
||||
sql_1 = "select Country from airlines where Airline = 1 select from airlines select * from airlines where Airline = 1"
|
||||
utterances_2 = ["What are all the airlines ?", "Of these , which is Jetblue Airways ?", "What is the country corresponding it ?"]
|
||||
sql_2 = "select T2.Countryabbrev from airlines as T1 join airports as T2 where T1.Airline = 1"
|
||||
sql_lst = ["select *", "select T2.Countryabbrev from airlines as T1 join airports as T2 where T1.Airline = 1"]
|
||||
utterances_test = ['Show the document name for all documents .', 'Also show their description .']
|
||||
sql_test = ['select Document_Name , Document_Description from Documents', 'select Document_Name , Document_Description from Documents select Document_Name from Documents', 'select Document_Name , Document_Name , Document_Description from Documents', 'select Document_ID , Document_Description from Documents', 'select Document_Name from Documents']
|
||||
|
||||
ranker = Ranker('./submit_models/reranker_roberta.pt')  # path to a trained reranker checkpoint
|
||||
#print("utterance : ", utterances_1)
|
||||
#print("sql : ", sql_1)
|
||||
#print(ranker.get_score(utterances_1, sql_1))
|
||||
#print("utterance : ", utterances_2)
|
||||
#print("sql : ", sql_2)
|
||||
#print(ranker.get_score(utterances_2, sql_2))
|
||||
print(ranker.get_score_batch(utterances_test, sql_test))
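As a follow-up to the batch call above: the scores are plain floats in (0, 1), one per candidate SQL, so reranking is just an argmax over them.

```python
scores = ranker.get_score_batch(utterances_test, sql_test)
best = max(range(len(scores)), key=lambda i: scores[i])
print('best candidate:', sql_test[best], '(score {:.3f})'.format(scores[best]))
```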
|
|
@@ -0,0 +1,25 @@
|
|||
import torch
|
||||
|
||||
def preprocess(utterances, sql, tokenizer):
|
||||
text = ""
|
||||
for u in utterances:
|
||||
for t in u.split(' '):
|
||||
text = text + ' ' + t.strip()
|
||||
sql = sql.strip()
|
||||
sql = sql.replace(".", " ")
|
||||
sql = sql.replace("_", " ")
|
||||
l = []
|
||||
for char in sql:
|
||||
if char.isupper():
|
||||
l.append(' ')
|
||||
l.append(char)
|
||||
sql = ''.join(l)
|
||||
sql = ' '.join(sql.split())
|
||||
# input:
|
||||
# [CLS] utterance [SEP] sql [SEP]
|
||||
token_encoding = tokenizer.encode_plus(text, sql, max_length=128, pad_to_max_length=True)
|
||||
tokenized_token = token_encoding['input_ids']
|
||||
attention_mask = token_encoding['attention_mask']
|
||||
tokens_tensor = torch.tensor(tokenized_token)
|
||||
attention_mask_tensor = torch.tensor(attention_mask)
|
||||
return tokens_tensor, attention_mask_tensor
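A worked example of the encoding above (assuming the same `roberta-base` tokenizer used by the dataset and ranker code): the flattened utterances and the space-normalized SQL are packed into one fixed-length pair encoding.

```python
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base', max_len=128)
utterances = ["What are all the airlines ?", "Of these , which is Jetblue Airways ?"]
sql = "select Airline from airlines where Airline = 1"
tokens, mask = preprocess(utterances, sql, tokenizer)
print(tokens.shape, mask.shape)   # torch.Size([128]) torch.Size([128])
```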
|
|
@@ -0,0 +1,311 @@
|
|||
"""Contains a main function for training and/or evaluating a model."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from parse_args import interpret_args
|
||||
|
||||
import data_util
|
||||
from data_util import atis_data
|
||||
from model.schema_interaction_model import SchemaInteractionATISModel
|
||||
from logger import Logger
|
||||
from model.model import ATISModel
|
||||
from model_util import Metrics, evaluate_utterance_sample, evaluate_interaction_sample, \
|
||||
train_epoch_with_utterances, train_epoch_with_interactions, evaluate_using_predicted_queries
|
||||
|
||||
import torch
|
||||
import time
|
||||
|
||||
VALID_EVAL_METRICS = [Metrics.LOSS, Metrics.TOKEN_ACCURACY, Metrics.STRING_ACCURACY]
|
||||
TRAIN_EVAL_METRICS = [Metrics.LOSS, Metrics.TOKEN_ACCURACY, Metrics.STRING_ACCURACY]
|
||||
FINAL_EVAL_METRICS = [Metrics.STRING_ACCURACY, Metrics.TOKEN_ACCURACY]
|
||||
|
||||
def train(model, data, params):
|
||||
""" Trains a model.
|
||||
|
||||
Inputs:
|
||||
model (ATISModel): The model to train.
|
||||
data (ATISData): The data that is used to train.
|
||||
params (namespace): Training parameters.
|
||||
"""
|
||||
# Get the training batches.
|
||||
log = Logger(os.path.join(params.logdir, params.logfile), "a+")
|
||||
num_train_original = atis_data.num_utterances(data.train_data)
|
||||
eval_fn = evaluate_utterance_sample
|
||||
trainbatch_fn = data.get_utterance_batches
|
||||
trainsample_fn = data.get_random_utterances
|
||||
validsample_fn = data.get_all_utterances
|
||||
batch_size = params.batch_size
|
||||
if params.interaction_level:
|
||||
batch_size = 1
|
||||
eval_fn = evaluate_interaction_sample
|
||||
trainbatch_fn = data.get_interaction_batches
|
||||
trainsample_fn = data.get_random_interactions
|
||||
validsample_fn = data.get_all_interactions
|
||||
|
||||
maximum_output_length = params.train_maximum_sql_length
|
||||
train_batches = trainbatch_fn(batch_size,
|
||||
max_output_length=maximum_output_length,
|
||||
randomize=not params.deterministic)
|
||||
|
||||
if params.num_train >= 0:
|
||||
train_batches = train_batches[:params.num_train]
|
||||
|
||||
training_sample = trainsample_fn(params.train_evaluation_size,
|
||||
max_output_length=maximum_output_length)
|
||||
valid_examples = validsample_fn(data.valid_data,
|
||||
max_output_length=maximum_output_length)
|
||||
|
||||
num_train_examples = sum([len(batch) for batch in train_batches])
|
||||
num_steps_per_epoch = len(train_batches)
|
||||
|
||||
# Keeping track of things during training.
|
||||
epochs = 0
|
||||
patience = params.initial_patience
|
||||
learning_rate_coefficient = 1.
|
||||
previous_epoch_loss = float('inf')
|
||||
maximum_validation_accuracy = 0.
|
||||
maximum_string_accuracy = 0.
|
||||
|
||||
countdown = int(patience)
|
||||
|
||||
if params.scheduler:
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(model.trainer, mode='min')
|
||||
|
||||
keep_training = True
|
||||
while keep_training:
|
||||
log.put("Epoch:\t" + str(epochs))
|
||||
model.set_dropout(params.dropout_amount)
|
||||
|
||||
if not params.scheduler:
|
||||
model.set_learning_rate(learning_rate_coefficient * params.initial_learning_rate)
|
||||
|
||||
# Run a training step.
|
||||
if params.interaction_level:
|
||||
epoch_loss = train_epoch_with_interactions(
|
||||
train_batches,
|
||||
params,
|
||||
model,
|
||||
randomize=not params.deterministic)
|
||||
else:
|
||||
epoch_loss = train_epoch_with_utterances(
|
||||
train_batches,
|
||||
model,
|
||||
randomize=not params.deterministic)
|
||||
|
||||
log.put("train epoch loss:\t" + str(epoch_loss))
|
||||
|
||||
model.set_dropout(0.)
|
||||
|
||||
# Run an evaluation step on a sample of the training data.
|
||||
"""
|
||||
train_eval_results = eval_fn(training_sample,
|
||||
model,
|
||||
params.train_maximum_sql_length,
|
||||
name=os.path.join(params.logdir, "train-eval"),
|
||||
write_results=True,
|
||||
gold_forcing=True,
|
||||
metrics=TRAIN_EVAL_METRICS)[0]
|
||||
|
||||
for name, value in train_eval_results.items():
|
||||
log.put(
|
||||
"train final gold-passing " +
|
||||
name.name +
|
||||
":\t" +
|
||||
"%.2f" %
|
||||
value)
|
||||
"""
|
||||
# Run an evaluation step on the validation set.
|
||||
valid_eval_results = eval_fn(valid_examples,
|
||||
model,
|
||||
params.eval_maximum_sql_length,
|
||||
name=os.path.join(params.logdir, "valid-eval"),
|
||||
write_results=True,
|
||||
gold_forcing=True,
|
||||
metrics=VALID_EVAL_METRICS)[0]
|
||||
for name, value in valid_eval_results.items():
|
||||
log.put("valid gold-passing " + name.name + ":\t" + "%.2f" % value)
|
||||
|
||||
valid_loss = valid_eval_results[Metrics.LOSS]
|
||||
valid_token_accuracy = valid_eval_results[Metrics.TOKEN_ACCURACY]
|
||||
string_accuracy = valid_eval_results[Metrics.STRING_ACCURACY]
|
||||
|
||||
assert string_accuracy >= 20
|
||||
|
||||
if params.scheduler:
|
||||
scheduler.step(valid_loss)
|
||||
|
||||
if valid_loss > previous_epoch_loss:
|
||||
learning_rate_coefficient *= params.learning_rate_ratio
|
||||
log.put(
|
||||
"learning rate coefficient:\t" +
|
||||
str(learning_rate_coefficient))
|
||||
|
||||
previous_epoch_loss = valid_loss
|
||||
saved = False
|
||||
|
||||
if not saved and string_accuracy > maximum_string_accuracy:
|
||||
maximum_string_accuracy = string_accuracy
|
||||
patience = patience * params.patience_ratio
|
||||
countdown = int(patience)
|
||||
last_save_file = os.path.join(params.logdir, "save_" + str(epochs))
|
||||
model.save(last_save_file)
|
||||
|
||||
log.put(
|
||||
"maximum string accuracy:\t" +
|
||||
str(maximum_string_accuracy))
|
||||
log.put("patience:\t" + str(patience))
|
||||
log.put("save file:\t" + str(last_save_file))
|
||||
|
||||
if countdown <= 0:
|
||||
keep_training = False
|
||||
|
||||
countdown -= 1
|
||||
log.put("countdown:\t" + str(countdown))
|
||||
log.put("")
|
||||
|
||||
epochs += 1
|
||||
|
||||
log.put("Finished training!")
|
||||
log.close()
|
||||
|
||||
return last_save_file
|
||||
|
||||
|
||||
def evaluate(model, data, params, last_save_file, split):
|
||||
"""Evaluates a pretrained model on a dataset.
|
||||
|
||||
Inputs:
|
||||
model (ATISModel): Model class.
|
||||
data (ATISData): All of the data.
|
||||
params (namespace): Parameters for the model.
|
||||
last_save_file (str): Location where the model save file is.
|
||||
"""
|
||||
if last_save_file:
|
||||
model.load(last_save_file)
|
||||
else:
|
||||
if not params.save_file:
|
||||
raise ValueError(
|
||||
"Must provide a save file name if not training first.")
|
||||
model.load(params.save_file)
|
||||
|
||||
filename = split
|
||||
|
||||
if filename == 'dev':
|
||||
split = data.dev_data
|
||||
elif filename == 'train':
|
||||
split = data.train_data
|
||||
elif filename == 'test':
|
||||
split = data.test_data
|
||||
elif filename == 'valid':
|
||||
split = data.valid_data
|
||||
else:
|
||||
raise ValueError("Split not recognized: " + str(params.evaluate_split))
|
||||
|
||||
if params.use_predicted_queries:
|
||||
filename += "_use_predicted_queries"
|
||||
else:
|
||||
filename += "_use_gold_queries"
|
||||
|
||||
if filename.startswith('train'):
|
||||
full_name = os.path.join(params.logdir, filename) + params.results_note
|
||||
else:
|
||||
full_name = os.path.join("results", params.save_file.split('/')[-1]) + params.results_note
|
||||
|
||||
if params.interaction_level or params.use_predicted_queries:
|
||||
examples = data.get_all_interactions(split)
|
||||
if params.interaction_level:
|
||||
evaluate_interaction_sample(
|
||||
examples,
|
||||
model,
|
||||
name=full_name,
|
||||
metrics=FINAL_EVAL_METRICS,
|
||||
total_num=atis_data.num_utterances(split),
|
||||
database_username=params.database_username,
|
||||
database_password=params.database_password,
|
||||
database_timeout=params.database_timeout,
|
||||
use_predicted_queries=params.use_predicted_queries,
|
||||
max_generation_length=params.eval_maximum_sql_length,
|
||||
write_results=True,
|
||||
use_gpu=True,
|
||||
compute_metrics=params.compute_metrics)
|
||||
else:
|
||||
evaluate_using_predicted_queries(
|
||||
examples,
|
||||
model,
|
||||
name=full_name,
|
||||
metrics=FINAL_EVAL_METRICS,
|
||||
total_num=atis_data.num_utterances(split),
|
||||
database_username=params.database_username,
|
||||
database_password=params.database_password,
|
||||
database_timeout=params.database_timeout)
|
||||
else:
|
||||
examples = data.get_all_utterances(split)
|
||||
evaluate_utterance_sample(
|
||||
examples,
|
||||
model,
|
||||
name=full_name,
|
||||
gold_forcing=False,
|
||||
metrics=FINAL_EVAL_METRICS,
|
||||
total_num=atis_data.num_utterances(split),
|
||||
max_generation_length=params.eval_maximum_sql_length,
|
||||
database_username=params.database_username,
|
||||
database_password=params.database_password,
|
||||
database_timeout=params.database_timeout,
|
||||
write_results=True)
|
||||
|
||||
def init_env(params):
|
||||
if params.seed == 0:
|
||||
params.seed = int(time.time()) % 1000
|
||||
seed = params.seed
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
|
||||
def main():
|
||||
"""Main function that trains and/or evaluates a model."""
|
||||
params = interpret_args()
|
||||
init_env(params)
|
||||
# Prepare the dataset into the proper form.
|
||||
data = atis_data.ATISDataset(params)
|
||||
|
||||
# Construct the model object.
|
||||
if params.interaction_level:
|
||||
model_type = SchemaInteractionATISModel
|
||||
else:
|
||||
print('not implemented')
|
||||
exit()
|
||||
|
||||
model = model_type(
|
||||
params,
|
||||
data.input_vocabulary,
|
||||
data.output_vocabulary,
|
||||
data.output_vocabulary_schema,
|
||||
data.anonymizer if params.anonymize and params.anonymization_scoring else None)
|
||||
|
||||
model = model.cuda()
|
||||
model.build_optim()
|
||||
sys.stdout.flush()
|
||||
|
||||
last_save_file = ""
|
||||
if params.train:
|
||||
last_save_file = train(model, data, params)
|
||||
if params.evaluate and 'train' in params.evaluate_split:
|
||||
evaluate(model, data, params, last_save_file, split='train')
|
||||
if params.evaluate and 'valid' in params.evaluate_split:
|
||||
evaluate(model, data, params, last_save_file, split='valid')
|
||||
if params.evaluate and 'dev' in params.evaluate_split:
|
||||
evaluate(model, data, params, last_save_file, split='dev')
|
||||
if params.evaluate and 'test' in params.evaluate_split:
|
||||
evaluate(model, data, params, last_save_file, split='test')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -0,0 +1,9 @@
|
|||
torch==1.0.0
|
||||
sqlparse
|
||||
pymysql
|
||||
progressbar
|
||||
nltk
|
||||
numpy
|
||||
six
|
||||
spacy
|
||||
transformers==2.10.0
|
|
@@ -0,0 +1,779 @@
|
|||
{"sql": ["'2V'"], "type": "AIRLINE_CODE", "cleaned_nl": ["northeast", "express", "regional", "airlines"]}
|
||||
{"sql": ["'3J'"], "type": "AIRLINE_CODE", "cleaned_nl": ["air", "alliance"]}
|
||||
{"sql": ["'7V'"], "type": "AIRLINE_CODE", "cleaned_nl": ["alpha", "air"]}
|
||||
{"sql": ["'9E'"], "type": "AIRLINE_CODE", "cleaned_nl": ["express", "airlines", "i"]}
|
||||
{"sql": ["'9L'"], "type": "AIRLINE_CODE", "cleaned_nl": ["colgan", "air"]}
|
||||
{"sql": ["'9N'"], "type": "AIRLINE_CODE", "cleaned_nl": ["trans", "states", "airlines"]}
|
||||
{"sql": ["'9X'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ontario", "express"]}
|
||||
{"sql": ["'AA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american", "airlines"]}
|
||||
{"sql": ["'AA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american", "airline"]}
|
||||
{"sql": ["'AA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american"]}
|
||||
{"sql": ["'AC'"], "type": "AIRLINE_CODE", "cleaned_nl": ["air", "canada"]}
|
||||
{"sql": ["'AR'"], "type": "AIRLINE_CODE", "cleaned_nl": ["aerolineas", "argentinas"]}
|
||||
{"sql": ["'AS'"], "type": "AIRLINE_CODE", "cleaned_nl": ["alaska", "airlines"]}
|
||||
{"sql": ["'AT'"], "type": "AIRLINE_CODE", "cleaned_nl": ["royal", "air", "maroc"]}
|
||||
{"sql": ["'BA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["british", "airways"]}
|
||||
{"sql": ["'BE'"], "type": "AIRLINE_CODE", "cleaned_nl": ["braniff", "international", "airlines"]}
|
||||
{"sql": ["'CO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["continental", "airlines"]}
|
||||
{"sql": ["'CO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["continental"]}
|
||||
{"sql": ["'CP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["canadian", "airlines", "international"]}
|
||||
{"sql": ["'CP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["canadian"]}
|
||||
{"sql": ["'CP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["canadian", "airlines"]}
|
||||
{"sql": ["'DH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["atlantic", "coast", "airlines"]}
|
||||
{"sql": ["'DL'"], "type": "AIRLINE_CODE", "cleaned_nl": ["delta", "airlines"]}
|
||||
{"sql": ["'DL'"], "type": "AIRLINE_CODE", "cleaned_nl": ["delta"]}
|
||||
{"sql": ["'DL'"], "type": "AIRLINE_CODE", "cleaned_nl": ["delta", "airline"]}
|
||||
{"sql": ["'EV'"], "type": "AIRLINE_CODE", "cleaned_nl": ["atlantic", "southeast", "airlines"]}
|
||||
{"sql": ["'EA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ea"]}
|
||||
{"sql": ["'EA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["eastern"]}
|
||||
{"sql": ["'FF'"], "type": "AIRLINE_CODE", "cleaned_nl": ["tower", "air"]}
|
||||
{"sql": ["'GX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["air", "ontario"]}
|
||||
{"sql": ["'HP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["america", "west", "airlines"]}
|
||||
{"sql": ["'HP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["america", "west"]}
|
||||
{"sql": ["'HP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american", "west"]}
|
||||
{"sql": ["'HQ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["business", "express"]}
|
||||
{"sql": ["'KW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["carnival", "air", "lines"]}
|
||||
{"sql": ["'LH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["lufthansa", "german", "airlines"]}
|
||||
{"sql": ["'LH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["lufthansa"]}
|
||||
{"sql": ["'LH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["lufthansa", "airlines"]}
|
||||
{"sql": ["'MG'"], "type": "AIRLINE_CODE", "cleaned_nl": ["mgm", "grand", "air"]}
|
||||
{"sql": ["'NW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["northwest", "airlines"]}
|
||||
{"sql": ["'NW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["northwest", "airline"]}
|
||||
{"sql": ["'NW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["northwest"]}
|
||||
{"sql": ["'NX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["nationair"]}
|
||||
{"sql": ["'OE'"], "type": "AIRLINE_CODE", "cleaned_nl": ["westair", "airlines"]}
|
||||
{"sql": ["'OH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["comair"]}
|
||||
{"sql": ["'OK'"], "type": "AIRLINE_CODE", "cleaned_nl": ["czechoslovak", "airlines"]}
|
||||
{"sql": ["'OO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["sky", "west", "airlines"]}
|
||||
{"sql": ["'QD'"], "type": "AIRLINE_CODE", "cleaned_nl": ["grand", "airways"]}
|
||||
{"sql": ["'RP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["precision", "airlines"]}
|
||||
{"sql": ["'RZ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["trans", "world", "express"]}
|
||||
{"sql": ["'SN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["sabena", "belgian", "world", "airlines"]}
|
||||
{"sql": ["'SX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["christman", "air"]}
|
||||
{"sql": ["'TG'"], "type": "AIRLINE_CODE", "cleaned_nl": ["thai", "airways"]}
|
||||
{"sql": ["'TW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["trans", "world"]}
|
||||
{"sql": ["'TW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["twa"]}
|
||||
{"sql": ["'TZ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american", "trans", "air"]}
|
||||
{"sql": ["'UA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["united", "airlines"]}
|
||||
{"sql": ["'UA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["united"]}
|
||||
{"sql": ["'US'"], "type": "AIRLINE_CODE", "cleaned_nl": ["usair"]}
|
||||
{"sql": ["'US'"], "type": "AIRLINE_CODE", "cleaned_nl": ["us", "air"]}
|
||||
{"sql": ["'WN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["southwest", "airlines"]}
|
||||
{"sql": ["'WN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["southwest", "air"]}
|
||||
{"sql": ["'WN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["southwest"]}
|
||||
{"sql": ["'XJ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["mesaba", "aviation"]}
|
||||
{"sql": ["'YX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["midwest", "express", "airlines"]}
|
||||
{"sql": ["'YX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["midwest", "express"]}
|
||||
{"sql": ["'ZW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["air", "wisconsin"]}
|
||||
{"sql": ["'2V'"], "type": "AIRLINE_CODE", "cleaned_nl": ["2v"]}
|
||||
{"sql": ["'3J'"], "type": "AIRLINE_CODE", "cleaned_nl": ["3j"]}
|
||||
{"sql": ["'7V'"], "type": "AIRLINE_CODE", "cleaned_nl": ["7v"]}
|
||||
{"sql": ["'9E'"], "type": "AIRLINE_CODE", "cleaned_nl": ["9e"]}
|
||||
{"sql": ["'9L'"], "type": "AIRLINE_CODE", "cleaned_nl": ["9l"]}
|
||||
{"sql": ["'9N'"], "type": "AIRLINE_CODE", "cleaned_nl": ["9n"]}
|
||||
{"sql": ["'9X'"], "type": "AIRLINE_CODE", "cleaned_nl": ["9x"]}
|
||||
{"sql": ["'AA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["aa"]}
|
||||
{"sql": ["'AC'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ac"]}
|
||||
{"sql": ["'AR'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ar"]}
|
||||
{"sql": ["'AS'"], "type": "AIRLINE_CODE", "cleaned_nl": ["as"]}
|
||||
{"sql": ["'BA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ba"]}
|
||||
{"sql": ["'CO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["co"]}
|
||||
{"sql": ["'CP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["cp"]}
|
||||
{"sql": ["'DH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["dh"]}
|
||||
{"sql": ["'DL'"], "type": "AIRLINE_CODE", "cleaned_nl": ["dl"]}
|
||||
{"sql": ["'EV'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ev"]}
|
||||
{"sql": ["'FF'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ff"]}
|
||||
{"sql": ["'GX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["gx"]}
|
||||
{"sql": ["'HP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["hp"]}
|
||||
{"sql": ["'HQ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["hq"]}
|
||||
{"sql": ["'KW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["kw"]}
|
||||
{"sql": ["'LH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["lh"]}
|
||||
{"sql": ["'MG'"], "type": "AIRLINE_CODE", "cleaned_nl": ["mg"]}
|
||||
{"sql": ["'NW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["nw"]}
|
||||
{"sql": ["'NX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["nx"]}
|
||||
{"sql": ["'OE'"], "type": "AIRLINE_CODE", "cleaned_nl": ["oe"]}
|
||||
{"sql": ["'OH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["oh"]}
|
||||
{"sql": ["'OK'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ok"]}
|
||||
{"sql": ["'OO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["oo"]}
|
||||
{"sql": ["'QD'"], "type": "AIRLINE_CODE", "cleaned_nl": ["qd"]}
|
||||
{"sql": ["'RP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["rp"]}
|
||||
{"sql": ["'RZ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["rz"]}
|
||||
{"sql": ["'SN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["sn"]}
|
||||
{"sql": ["'SX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["sx"]}
|
||||
{"sql": ["'TG'"], "type": "AIRLINE_CODE", "cleaned_nl": ["tg"]}
|
||||
{"sql": ["'TW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["tw"]}
|
||||
{"sql": ["'TZ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["tz"]}
|
||||
{"sql": ["'UA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ua"]}
|
||||
{"sql": ["'US'"], "type": "AIRLINE_CODE", "cleaned_nl": ["us"]}
|
||||
{"sql": ["'WN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["wn"]}
|
||||
{"sql": ["'XJ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["xj"]}
|
||||
{"sql": ["'YX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["yx"]}
|
||||
{"sql": ["'ZW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["zw"]}
|
||||
{"sql": ["'NASHVILLE'"], "type": "CITY_NAME", "cleaned_nl": ["nashville"]}
|
||||
{"sql": ["'BOSTON'"], "type": "CITY_NAME", "cleaned_nl": ["boston"]}
|
||||
{"sql": ["'BURBANK'"], "type": "CITY_NAME", "cleaned_nl": ["burbank"]}
|
||||
{"sql": ["'BALTIMORE'"], "type": "CITY_NAME", "cleaned_nl": ["baltimore"]}
|
||||
{"sql": ["'CHICAGO'"], "type": "CITY_NAME", "cleaned_nl": ["chicago"]}
|
||||
{"sql": ["'CLEVELAND'"], "type": "CITY_NAME", "cleaned_nl": ["cleveland"]}
|
||||
{"sql": ["'CHARLOTTE'"], "type": "CITY_NAME", "cleaned_nl": ["charlotte"]}
|
||||
{"sql": ["'COLUMBUS'"], "type": "CITY_NAME", "cleaned_nl": ["columbus"]}
|
||||
{"sql": ["'CINCINNATI'"], "type": "CITY_NAME", "cleaned_nl": ["cincinnati"]}
|
||||
{"sql": ["'DENVER'"], "type": "CITY_NAME", "cleaned_nl": ["denver"]}
|
||||
{"sql": ["'DALLAS'"], "type": "CITY_NAME", "cleaned_nl": ["dallas"]}
|
||||
{"sql": ["'DETROIT'"], "type": "CITY_NAME", "cleaned_nl": ["detroit"]}
|
||||
{"sql": ["'FORT WORTH'"], "type": "CITY_NAME", "cleaned_nl": ["fort", "worth"]}
|
||||
{"sql": ["'HOUSTON'"], "type": "CITY_NAME", "cleaned_nl": ["houston"]}
|
||||
{"sql": ["'WESTCHESTER COUNTY'"], "type": "CITY_NAME", "cleaned_nl": ["westchester", "county"]}
|
||||
{"sql": ["'WESTCHESTER COUNTY'"], "type": "CITY_NAME", "cleaned_nl": ["westchester"]}
|
||||
{"sql": ["'INDIANAPOLIS'"], "type": "CITY_NAME", "cleaned_nl": ["indianapolis"]}
|
||||
{"sql": ["'NEWARK'"], "type": "CITY_NAME", "cleaned_nl": ["newark"]}
|
||||
{"sql": ["'LAS VEGAS'"], "type": "CITY_NAME", "cleaned_nl": ["las", "vegas"]}
|
||||
{"sql": ["'LOS ANGELES'"], "type": "CITY_NAME", "cleaned_nl": ["los", "angeles"]}
|
||||
{"sql": ["'LOS ANGELES'"], "type": "CITY_NAME", "cleaned_nl": ["la"]}
|
||||
{"sql": ["'LONG BEACH'"], "type": "CITY_NAME", "cleaned_nl": ["long", "beach"]}
|
||||
{"sql": ["'ATLANTA'"], "type": "CITY_NAME", "cleaned_nl": ["atlanta"]}
|
||||
{"sql": ["'MEMPHIS'"], "type": "CITY_NAME", "cleaned_nl": ["memphis"]}
|
||||
{"sql": ["'MIAMI'"], "type": "CITY_NAME", "cleaned_nl": ["miami"]}
|
||||
{"sql": ["'KANSAS CITY'"], "type": "CITY_NAME", "cleaned_nl": ["kansas", "city"]}
|
||||
{"sql": ["'MILWAUKEE'"], "type": "CITY_NAME", "cleaned_nl": ["milwaukee"]}
|
||||
{"sql": ["'MINNEAPOLIS'"], "type": "CITY_NAME", "cleaned_nl": ["minneapolis"]}
|
||||
{"sql": ["'NEW YORK'"], "type": "CITY_NAME", "cleaned_nl": ["new", "york"]}
|
||||
{"sql": ["'NEW YORK'"], "type": "CITY_NAME", "cleaned_nl": ["new", "york", "city"]}
|
||||
{"sql": ["'OAKLAND'"], "type": "CITY_NAME", "cleaned_nl": ["oakland"]}
|
||||
{"sql": ["'ONTARIO'"], "type": "CITY_NAME", "cleaned_nl": ["ontario"]}
|
||||
{"sql": ["'ORLANDO'"], "type": "CITY_NAME", "cleaned_nl": ["orlando"]}
|
||||
{"sql": ["'PHILADELPHIA'"], "type": "CITY_NAME", "cleaned_nl": ["philadelphia"]}
|
||||
{"sql": ["'PHOENIX'"], "type": "CITY_NAME", "cleaned_nl": ["phoenix"]}
|
||||
{"sql": ["'PITTSBURGH'"], "type": "CITY_NAME", "cleaned_nl": ["pittsburgh"]}
|
||||
{"sql": ["'ST. PAUL'"], "type": "CITY_NAME", "cleaned_nl": ["st.", "paul"]}
|
||||
{"sql": ["'ST. PAUL'"], "type": "CITY_NAME", "cleaned_nl": ["saint", "paul"]}
|
||||
{"sql": ["'SAN DIEGO'"], "type": "CITY_NAME", "cleaned_nl": ["san", "diego"]}
|
||||
{"sql": ["'SEATTLE'"], "type": "CITY_NAME", "cleaned_nl": ["seattle"]}
|
||||
{"sql": ["'SAN FRANCISCO'"], "type": "CITY_NAME", "cleaned_nl": ["san", "francisco"]}
|
||||
{"sql": ["'SAN JOSE'"], "type": "CITY_NAME", "cleaned_nl": ["san", "jose"]}
|
||||
{"sql": ["'SALT LAKE CITY'"], "type": "CITY_NAME", "cleaned_nl": ["salt", "lake", "city"]}
|
||||
{"sql": ["'ST. LOUIS'"], "type": "CITY_NAME", "cleaned_nl": ["st.", "louis"]}
|
||||
{"sql": ["'ST. LOUIS'"], "type": "CITY_NAME", "cleaned_nl": ["saint", "louis"]}
|
||||
{"sql": ["'ST. PETERSBURG'"], "type": "CITY_NAME", "cleaned_nl": ["saint", "petersburg"]}
|
||||
{"sql": ["'ST. PETERSBURG'"], "type": "CITY_NAME", "cleaned_nl": ["st.", "petersburg"]}
|
||||
{"sql": ["'TACOMA'"], "type": "CITY_NAME", "cleaned_nl": ["tacoma"]}
|
||||
{"sql": ["'TAMPA'"], "type": "CITY_NAME", "cleaned_nl": ["tampa"]}
|
||||
{"sql": ["'WASHINGTON'"], "type": "CITY_NAME", "cleaned_nl": ["washington"]}
|
||||
{"sql": ["'MONTREAL'"], "type": "CITY_NAME", "cleaned_nl": ["montreal"]}
|
||||
{"sql": ["'TORONTO'"], "type": "CITY_NAME", "cleaned_nl": ["toronto"]}
|
||||
{"sql": ["1"], "type": "MONTH_NUMBER", "cleaned_nl": ["january"]}
|
||||
{"sql": ["2"], "type": "MONTH_NUMBER", "cleaned_nl": ["february"]}
|
||||
{"sql": ["3"], "type": "MONTH_NUMBER", "cleaned_nl": ["march"]}
|
||||
{"sql": ["4"], "type": "MONTH_NUMBER", "cleaned_nl": ["april"]}
|
||||
{"sql": ["5"], "type": "MONTH_NUMBER", "cleaned_nl": ["may"]}
|
||||
{"sql": ["6"], "type": "MONTH_NUMBER", "cleaned_nl": ["june"]}
|
||||
{"sql": ["7"], "type": "MONTH_NUMBER", "cleaned_nl": ["july"]}
|
||||
{"sql": ["8"], "type": "MONTH_NUMBER", "cleaned_nl": ["august"]}
|
||||
{"sql": ["9"], "type": "MONTH_NUMBER", "cleaned_nl": ["september"]}
|
||||
{"sql": ["10"], "type": "MONTH_NUMBER", "cleaned_nl": ["october"]}
|
||||
{"sql": ["11"], "type": "MONTH_NUMBER", "cleaned_nl": ["november"]}
|
||||
{"sql": ["12"], "type": "MONTH_NUMBER", "cleaned_nl": ["december"]}
|
||||
{"sql": ["'MONDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["monday"]}
|
||||
{"sql": ["'TUESDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["tuesday"]}
|
||||
{"sql": ["'WEDNESDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["wednesday"]}
|
||||
{"sql": ["'THURSDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["thursday"]}
|
||||
{"sql": ["'FRIDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["friday"]}
|
||||
{"sql": ["'SATURDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["saturday"]}
|
||||
{"sql": ["'SUNDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["sunday"]}
|
||||
{"sql": ["'MONDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["mondays"]}
|
||||
{"sql": ["'TUESDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["tuesdays"]}
|
||||
{"sql": ["'WEDNESDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["wednesdays"]}
|
||||
{"sql": ["'THURSDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["thursdays"]}
|
||||
{"sql": ["'FRIDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["fridays"]}
|
||||
{"sql": ["'SATURDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["saturdays"]}
|
||||
{"sql": ["'SUNDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["sundays"]}
|
||||
{"sql": ["1"], "type": "DAY_NUMBER", "cleaned_nl": ["first"]}
|
||||
{"sql": ["2"], "type": "DAY_NUMBER", "cleaned_nl": ["second"]}
|
||||
{"sql": ["3"], "type": "DAY_NUMBER", "cleaned_nl": ["third"]}
|
||||
{"sql": ["4"], "type": "DAY_NUMBER", "cleaned_nl": ["fourth"]}
|
||||
{"sql": ["5"], "type": "DAY_NUMBER", "cleaned_nl": ["fifth"]}
|
||||
{"sql": ["6"], "type": "DAY_NUMBER", "cleaned_nl": ["sixth"]}
|
||||
{"sql": ["7"], "type": "DAY_NUMBER", "cleaned_nl": ["seventh"]}
|
||||
{"sql": ["8"], "type": "DAY_NUMBER", "cleaned_nl": ["eighth"]}
|
||||
{"sql": ["9"], "type": "DAY_NUMBER", "cleaned_nl": ["ninth"]}
|
||||
{"sql": ["10"], "type": "DAY_NUMBER", "cleaned_nl": ["tenth"]}
|
||||
{"sql": ["11"], "type": "DAY_NUMBER", "cleaned_nl": ["eleventh"]}
|
||||
{"sql": ["12"], "type": "DAY_NUMBER", "cleaned_nl": ["twelfth"]}
|
||||
{"sql": ["13"], "type": "DAY_NUMBER", "cleaned_nl": ["thirteenth"]}
|
||||
{"sql": ["14"], "type": "DAY_NUMBER", "cleaned_nl": ["fourteenth"]}
|
||||
{"sql": ["15"], "type": "DAY_NUMBER", "cleaned_nl": ["fifteenth"]}
|
||||
{"sql": ["16"], "type": "DAY_NUMBER", "cleaned_nl": ["sixteenth"]}
|
||||
{"sql": ["17"], "type": "DAY_NUMBER", "cleaned_nl": ["seventeenth"]}
|
||||
{"sql": ["18"], "type": "DAY_NUMBER", "cleaned_nl": ["eighteenth"]}
|
||||
{"sql": ["19"], "type": "DAY_NUMBER", "cleaned_nl": ["nineteenth"]}
|
||||
{"sql": ["20"], "type": "DAY_NUMBER", "cleaned_nl": ["twentieth"]}
|
||||
{"sql": ["21"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "first"]}
|
||||
{"sql": ["22"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "second"]}
|
||||
{"sql": ["23"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "third"]}
|
||||
{"sql": ["24"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "fourth"]}
|
||||
{"sql": ["25"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "fifth"]}
|
||||
{"sql": ["26"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "sixth"]}
|
||||
{"sql": ["27"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "seventh"]}
|
||||
{"sql": ["28"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "eighth"]}
|
||||
{"sql": ["29"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "ninth"]}
|
||||
{"sql": ["30"], "type": "DAY_NUMBER", "cleaned_nl": ["thirtieth"]}
|
||||
{"sql": ["31"], "type": "DAY_NUMBER", "cleaned_nl": ["thirty", "first"]}
|
||||
{"sql": ["2"], "type": "DAY_NUMBER", "cleaned_nl": ["two"]}
|
||||
{"sql": ["3"], "type": "DAY_NUMBER", "cleaned_nl": ["three"]}
|
||||
{"sql": ["4"], "type": "DAY_NUMBER", "cleaned_nl": ["four"]}
|
||||
{"sql": ["5"], "type": "DAY_NUMBER", "cleaned_nl": ["five"]}
|
||||
{"sql": ["6"], "type": "DAY_NUMBER", "cleaned_nl": ["six"]}
|
||||
{"sql": ["7"], "type": "DAY_NUMBER", "cleaned_nl": ["seven"]}
|
||||
{"sql": ["8"], "type": "DAY_NUMBER", "cleaned_nl": ["eight"]}
|
||||
{"sql": ["9"], "type": "DAY_NUMBER", "cleaned_nl": ["nine"]}
|
||||
{"sql": ["10"], "type": "DAY_NUMBER", "cleaned_nl": ["ten"]}
|
||||
{"sql": ["11"], "type": "DAY_NUMBER", "cleaned_nl": ["eleven"]}
|
||||
{"sql": ["12"], "type": "DAY_NUMBER", "cleaned_nl": ["twelve"]}
|
||||
{"sql": ["13"], "type": "DAY_NUMBER", "cleaned_nl": ["thirteen"]}
|
||||
{"sql": ["14"], "type": "DAY_NUMBER", "cleaned_nl": ["fourteen"]}
|
||||
{"sql": ["15"], "type": "DAY_NUMBER", "cleaned_nl": ["fifteen"]}
|
||||
{"sql": ["16"], "type": "DAY_NUMBER", "cleaned_nl": ["sixteen"]}
|
||||
{"sql": ["17"], "type": "DAY_NUMBER", "cleaned_nl": ["seventeen"]}
|
||||
{"sql": ["18"], "type": "DAY_NUMBER", "cleaned_nl": ["eighteen"]}
|
||||
{"sql": ["19"], "type": "DAY_NUMBER", "cleaned_nl": ["nineteen"]}
|
||||
{"sql": ["20"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty"]}
|
||||
{"sql": ["21"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "one"]}
|
||||
{"sql": ["22"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "two"]}
|
||||
{"sql": ["23"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "three"]}
|
||||
{"sql": ["24"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "four"]}
|
||||
{"sql": ["25"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "five"]}
|
||||
{"sql": ["26"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "six"]}
|
||||
{"sql": ["27"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "seven"]}
|
||||
{"sql": ["28"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "eight"]}
|
||||
{"sql": ["29"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "nine"]}
|
||||
{"sql": ["30"], "type": "DAY_NUMBER", "cleaned_nl": ["thirty"]}
|
||||
{"sql": ["31"], "type": "DAY_NUMBER", "cleaned_nl": ["thirty", "one"]}
|
||||
{"sql": ["'COACH'"], "type": "FARE_TYPE", "cleaned_nl": ["coach"]}
|
||||
{"sql": ["'BUSINESS'"], "type": "FARE_TYPE", "cleaned_nl": ["business"]}
|
||||
{"sql": ["'FIRST'"], "type": "FARE_TYPE", "cleaned_nl": ["first", "class"]}
|
||||
{"sql": ["'THRIFT'"], "type": "FARE_TYPE", "cleaned_nl": ["thrift"]}
|
||||
{"sql": ["'STANDARD'"], "type": "FARE_TYPE", "cleaned_nl": ["standard"]}
|
||||
{"sql": ["'SHUTTLE'"], "type": "FARE_TYPE", "cleaned_nl": ["shuttle"]}
|
||||
{"sql": ["'B'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["b"]}
|
||||
{"sql": ["'BH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bh"]}
|
||||
{"sql": ["'BHW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bhw"]}
|
||||
{"sql": ["'BHX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bhx"]}
|
||||
{"sql": ["'BL'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bl"]}
|
||||
{"sql": ["'BLW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["blw"]}
|
||||
{"sql": ["'BLX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["blx"]}
|
||||
{"sql": ["'BN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bn"]}
|
||||
{"sql": ["'BOW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bow"]}
|
||||
{"sql": ["'BOX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["box"]}
|
||||
{"sql": ["'BW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bw"]}
|
||||
{"sql": ["'BX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bx"]}
|
||||
{"sql": ["'C'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["c"]}
|
||||
{"sql": ["'CN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["cn"]}
|
||||
{"sql": ["'F'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["f"]}
|
||||
{"sql": ["'FN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["fn"]}
|
||||
{"sql": ["'H'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["h"]}
|
||||
{"sql": ["'HH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hh"]}
|
||||
{"sql": ["'HHW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hhw"]}
|
||||
{"sql": ["'HHX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hhx"]}
|
||||
{"sql": ["'HL'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hl"]}
|
||||
{"sql": ["'HLW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hlw"]}
|
||||
{"sql": ["'HLX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hlx"]}
|
||||
{"sql": ["'HOX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hox"]}
|
||||
{"sql": ["'J'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["j"]}
|
||||
{"sql": ["'K'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["k"]}
|
||||
{"sql": ["'KH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["kh"]}
|
||||
{"sql": ["'KL'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["kl"]}
|
||||
{"sql": ["'KN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["kn"]}
|
||||
{"sql": ["'LX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["lx"]}
|
||||
{"sql": ["'M'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["m"]}
|
||||
{"sql": ["'MH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["mh"]}
|
||||
{"sql": ["'ML'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["ml"]}
|
||||
{"sql": ["'MOW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["mow"]}
|
||||
{"sql": ["'P'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["p"]}
|
||||
{"sql": ["'Q'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["q"]}
|
||||
{"sql": ["'QH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qh"]}
|
||||
{"sql": ["'QHW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qhw"]}
|
||||
{"sql": ["'QHX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qhx"]}
|
||||
{"sql": ["'QLW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qlw"]}
|
||||
{"sql": ["'QLX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qlx"]}
|
||||
{"sql": ["'QO'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qo"]}
|
||||
{"sql": ["'QOW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qow"]}
|
||||
{"sql": ["'QOX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qox"]}
|
||||
{"sql": ["'QW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qw"]}
|
||||
{"sql": ["'QX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qx"]}
|
||||
{"sql": ["'S'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["s"]}
|
||||
{"sql": ["'U'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["u"]}
|
||||
{"sql": ["'V'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["v"]}
|
||||
{"sql": ["'VHW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["vhw"]}
|
||||
{"sql": ["'VHX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["vhx"]}
|
||||
{"sql": ["'VW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["vw"]}
|
||||
{"sql": ["'VX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["vx"]}
|
||||
{"sql": ["'Y'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["y"]}
|
||||
{"sql": ["'YH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yh"]}
|
||||
{"sql": ["'YL'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yl"]}
|
||||
{"sql": ["'YN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yn"]}
|
||||
{"sql": ["'YW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yw"]}
|
||||
{"sql": ["'YX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yx"]}
|
||||
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["rental", "car"]}
|
||||
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["rental", "cars"]}
|
||||
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["car", "rental"]}
|
||||
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["car", "rentals"]}
|
||||
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["car"]}
|
||||
{"sql": ["'AIR TAXI OPERATION'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["air", "taxi", "operation"]}
|
||||
{"sql": ["'LIMOUSINE'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["limousine"]}
|
||||
{"sql": ["'TAXI'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["taxi"]}
|
||||
{"sql": ["'ARIZONA'"], "type": "STATE_NAME", "cleaned_nl": ["arizona"]}
|
||||
{"sql": ["'CALIFORNIA'"], "type": "STATE_NAME", "cleaned_nl": ["california"]}
|
||||
{"sql": ["'COLORADO'"], "type": "STATE_NAME", "cleaned_nl": ["colorado"]}
|
||||
{"sql": ["'DISTRICT OF COLUMBIA'"], "type": "STATE_NAME", "cleaned_nl": ["district", "of", "columbia"]}
|
||||
{"sql": ["'DISTRICT OF COLUMBIA'"], "type": "STATE_NAME", "cleaned_nl": ["dc"]}
|
||||
{"sql": ["'DC'"], "type": "STATE_NAME", "cleaned_nl": ["dc"]}
|
||||
{"sql": ["'FLORIDA'"], "type": "STATE_NAME", "cleaned_nl": ["florida"]}
|
||||
{"sql": ["'GEORGIA'"], "type": "STATE_NAME", "cleaned_nl": ["georgia"]}
|
||||
{"sql": ["'ILLINOIS'"], "type": "STATE_NAME", "cleaned_nl": ["illinois"]}
|
||||
{"sql": ["'INDIANA'"], "type": "STATE_NAME", "cleaned_nl": ["indiana"]}
|
||||
{"sql": ["'MASSACHUSETTS'"], "type": "STATE_NAME", "cleaned_nl": ["massachusetts"]}
|
||||
{"sql": ["'MARYLAND'"], "type": "STATE_NAME", "cleaned_nl": ["maryland"]}
|
||||
{"sql": ["'MICHIGAN'"], "type": "STATE_NAME", "cleaned_nl": ["michigan"]}
|
||||
{"sql": ["'MINNESOTA'"], "type": "STATE_NAME", "cleaned_nl": ["minnesota"]}
|
||||
{"sql": ["'MISSOURI'"], "type": "STATE_NAME", "cleaned_nl": ["missouri"]}
|
||||
{"sql": ["'NORTH CAROLINA'"], "type": "STATE_NAME", "cleaned_nl": ["north", "carolina"]}
|
||||
{"sql": ["'NEW JERSEY'"], "type": "STATE_NAME", "cleaned_nl": ["new", "jersey"]}
|
||||
{"sql": ["'NEVADA'"], "type": "STATE_NAME", "cleaned_nl": ["nevada"]}
|
||||
{"sql": ["'NEW YORK'"], "type": "STATE_NAME", "cleaned_nl": ["new", "york"]}
|
||||
{"sql": ["'OHIO'"], "type": "STATE_NAME", "cleaned_nl": ["ohio"]}
|
||||
{"sql": ["'ONTARIO'"], "type": "STATE_NAME", "cleaned_nl": ["ontario"]}
|
||||
{"sql": ["'PENNSYLVANIA'"], "type": "STATE_NAME", "cleaned_nl": ["pennsylvania"]}
|
||||
{"sql": ["'QUEBEC'"], "type": "STATE_NAME", "cleaned_nl": ["quebec"]}
|
||||
{"sql": ["'TENNESSEE'"], "type": "STATE_NAME", "cleaned_nl": ["tennessee"]}
|
||||
{"sql": ["'TEXAS'"], "type": "STATE_NAME", "cleaned_nl": ["texas"]}
|
||||
{"sql": ["'UTAH'"], "type": "STATE_NAME", "cleaned_nl": ["utah"]}
|
||||
{"sql": ["'WASHINGTON'"], "type": "STATE_NAME", "cleaned_nl": ["washington"]}
|
||||
{"sql": ["'WISCONSIN'"], "type": "STATE_NAME", "cleaned_nl": ["wisconsin"]}
|
||||
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["1"]}
|
||||
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1"]}
|
||||
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1pm"]}
|
||||
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["1am"]}
|
||||
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["1", "o'clock", "am"]}
|
||||
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1", "o'clock", "pm"]}
|
||||
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["1", "o'clock"]}
|
||||
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1", "o'clock"]}
|
||||
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["100"]}
|
||||
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["100"]}
|
||||
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["100am"]}
|
||||
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["100pm"]}
|
||||
{"sql": ["115"], "type": "TIME", "cleaned_nl": ["115"]}
|
||||
{"sql": ["1315"], "type": "TIME", "cleaned_nl": ["115"]}
|
||||
{"sql": ["115"], "type": "TIME", "cleaned_nl": ["115am"]}
|
||||
{"sql": ["1315"], "type": "TIME", "cleaned_nl": ["115pm"]}
|
||||
{"sql": ["130"], "type": "TIME", "cleaned_nl": ["130"]}
|
||||
{"sql": ["1330"], "type": "TIME", "cleaned_nl": ["130"]}
|
||||
{"sql": ["130"], "type": "TIME", "cleaned_nl": ["130am"]}
|
||||
{"sql": ["1330"], "type": "TIME", "cleaned_nl": ["130pm"]}
|
||||
{"sql": ["145"], "type": "TIME", "cleaned_nl": ["145"]}
|
||||
{"sql": ["1345"], "type": "TIME", "cleaned_nl": ["145"]}
|
||||
{"sql": ["145"], "type": "TIME", "cleaned_nl": ["145am"]}
|
||||
{"sql": ["1345"], "type": "TIME", "cleaned_nl": ["145pm"]}
|
||||
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["2"]}
|
||||
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["2"]}
|
||||
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["2pm"]}
|
||||
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["2am"]}
|
||||
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["2", "o'clock", "am"]}
|
||||
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["2", "o'clock", "pm"]}
|
||||
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["2", "o'clock"]}
|
||||
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["2", "o'clock"]}
|
||||
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["200"]}
|
||||
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["200"]}
|
||||
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["200am"]}
|
||||
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["200pm"]}
|
||||
{"sql": ["215"], "type": "TIME", "cleaned_nl": ["215"]}
|
||||
{"sql": ["1415"], "type": "TIME", "cleaned_nl": ["215"]}
|
||||
{"sql": ["215"], "type": "TIME", "cleaned_nl": ["215am"]}
|
||||
{"sql": ["1415"], "type": "TIME", "cleaned_nl": ["215pm"]}
|
||||
{"sql": ["230"], "type": "TIME", "cleaned_nl": ["230"]}
|
||||
{"sql": ["1430"], "type": "TIME", "cleaned_nl": ["230"]}
|
||||
{"sql": ["230"], "type": "TIME", "cleaned_nl": ["230am"]}
|
||||
{"sql": ["1430"], "type": "TIME", "cleaned_nl": ["230pm"]}
|
||||
{"sql": ["245"], "type": "TIME", "cleaned_nl": ["245"]}
|
||||
{"sql": ["1445"], "type": "TIME", "cleaned_nl": ["245"]}
|
||||
{"sql": ["245"], "type": "TIME", "cleaned_nl": ["245am"]}
|
||||
{"sql": ["1445"], "type": "TIME", "cleaned_nl": ["245pm"]}
|
||||
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["3"]}
|
||||
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["3"]}
|
||||
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["3pm"]}
|
||||
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["3am"]}
|
||||
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["3", "o'clock", "am"]}
|
||||
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["3", "o'clock", "pm"]}
|
||||
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["3", "o'clock"]}
|
||||
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["3", "o'clock"]}
|
||||
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["300"]}
|
||||
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["300"]}
|
||||
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["300am"]}
|
||||
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["300pm"]}
|
||||
{"sql": ["315"], "type": "TIME", "cleaned_nl": ["315"]}
|
||||
{"sql": ["1515"], "type": "TIME", "cleaned_nl": ["315"]}
|
||||
{"sql": ["315"], "type": "TIME", "cleaned_nl": ["315am"]}
|
||||
{"sql": ["1515"], "type": "TIME", "cleaned_nl": ["315pm"]}
|
||||
{"sql": ["330"], "type": "TIME", "cleaned_nl": ["330"]}
|
||||
{"sql": ["1530"], "type": "TIME", "cleaned_nl": ["330"]}
|
||||
{"sql": ["330"], "type": "TIME", "cleaned_nl": ["330am"]}
|
||||
{"sql": ["1530"], "type": "TIME", "cleaned_nl": ["330pm"]}
|
||||
{"sql": ["345"], "type": "TIME", "cleaned_nl": ["345"]}
|
||||
{"sql": ["1545"], "type": "TIME", "cleaned_nl": ["345"]}
|
||||
{"sql": ["345"], "type": "TIME", "cleaned_nl": ["345am"]}
|
||||
{"sql": ["1545"], "type": "TIME", "cleaned_nl": ["345pm"]}
|
||||
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["4"]}
|
||||
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["4"]}
|
||||
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["4pm"]}
|
||||
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["4am"]}
|
||||
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["4", "o'clock", "am"]}
|
||||
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["4", "o'clock", "pm"]}
|
||||
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["4", "o'clock"]}
|
||||
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["4", "o'clock"]}
|
||||
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["400"]}
|
||||
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["400"]}
|
||||
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["400am"]}
|
||||
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["400pm"]}
|
||||
{"sql": ["415"], "type": "TIME", "cleaned_nl": ["415"]}
|
||||
{"sql": ["1615"], "type": "TIME", "cleaned_nl": ["415"]}
|
||||
{"sql": ["415"], "type": "TIME", "cleaned_nl": ["415am"]}
|
||||
{"sql": ["1615"], "type": "TIME", "cleaned_nl": ["415pm"]}
|
||||
{"sql": ["430"], "type": "TIME", "cleaned_nl": ["430"]}
|
||||
{"sql": ["1630"], "type": "TIME", "cleaned_nl": ["430"]}
|
||||
{"sql": ["430"], "type": "TIME", "cleaned_nl": ["430am"]}
|
||||
{"sql": ["1630"], "type": "TIME", "cleaned_nl": ["430pm"]}
|
||||
{"sql": ["445"], "type": "TIME", "cleaned_nl": ["445"]}
|
||||
{"sql": ["1645"], "type": "TIME", "cleaned_nl": ["445"]}
|
||||
{"sql": ["445"], "type": "TIME", "cleaned_nl": ["445am"]}
|
||||
{"sql": ["1645"], "type": "TIME", "cleaned_nl": ["445pm"]}
|
||||
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["5"]}
|
||||
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["5"]}
|
||||
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["5pm"]}
|
||||
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["5am"]}
|
||||
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["5", "o'clock", "am"]}
|
||||
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["5", "o'clock", "pm"]}
|
||||
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["5", "o'clock"]}
|
||||
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["5", "o'clock"]}
|
||||
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["500"]}
|
||||
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["500"]}
|
||||
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["500am"]}
|
||||
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["500pm"]}
|
||||
{"sql": ["515"], "type": "TIME", "cleaned_nl": ["515"]}
|
||||
{"sql": ["1715"], "type": "TIME", "cleaned_nl": ["515"]}
|
||||
{"sql": ["515"], "type": "TIME", "cleaned_nl": ["515am"]}
|
||||
{"sql": ["1715"], "type": "TIME", "cleaned_nl": ["515pm"]}
|
||||
{"sql": ["530"], "type": "TIME", "cleaned_nl": ["530"]}
|
||||
{"sql": ["1730"], "type": "TIME", "cleaned_nl": ["530"]}
|
||||
{"sql": ["530"], "type": "TIME", "cleaned_nl": ["530am"]}
|
||||
{"sql": ["1730"], "type": "TIME", "cleaned_nl": ["530pm"]}
|
||||
{"sql": ["545"], "type": "TIME", "cleaned_nl": ["545"]}
|
||||
{"sql": ["1745"], "type": "TIME", "cleaned_nl": ["545"]}
|
||||
{"sql": ["545"], "type": "TIME", "cleaned_nl": ["545am"]}
|
||||
{"sql": ["1745"], "type": "TIME", "cleaned_nl": ["545pm"]}
|
||||
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["6"]}
|
||||
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["6"]}
|
||||
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["6pm"]}
|
||||
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["6am"]}
|
||||
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["6", "o'clock", "am"]}
|
||||
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["6", "o'clock", "pm"]}
|
||||
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["6", "o'clock"]}
|
||||
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["6", "o'clock"]}
|
||||
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["600"]}
|
||||
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["600"]}
|
||||
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["600am"]}
|
||||
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["600pm"]}
|
||||
{"sql": ["615"], "type": "TIME", "cleaned_nl": ["615"]}
|
||||
{"sql": ["1815"], "type": "TIME", "cleaned_nl": ["615"]}
|
||||
{"sql": ["615"], "type": "TIME", "cleaned_nl": ["615am"]}
|
||||
{"sql": ["1815"], "type": "TIME", "cleaned_nl": ["615pm"]}
|
||||
{"sql": ["630"], "type": "TIME", "cleaned_nl": ["630"]}
|
||||
{"sql": ["1830"], "type": "TIME", "cleaned_nl": ["630"]}
|
||||
{"sql": ["630"], "type": "TIME", "cleaned_nl": ["630am"]}
|
||||
{"sql": ["1830"], "type": "TIME", "cleaned_nl": ["630pm"]}
|
||||
{"sql": ["645"], "type": "TIME", "cleaned_nl": ["645"]}
|
||||
{"sql": ["1845"], "type": "TIME", "cleaned_nl": ["645"]}
|
||||
{"sql": ["645"], "type": "TIME", "cleaned_nl": ["645am"]}
|
||||
{"sql": ["1845"], "type": "TIME", "cleaned_nl": ["645pm"]}
|
||||
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["7"]}
|
||||
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["7"]}
|
||||
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["7pm"]}
|
||||
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["7am"]}
|
||||
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["7", "o'clock", "am"]}
|
||||
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["7", "o'clock", "pm"]}
|
||||
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["7", "o'clock"]}
|
||||
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["7", "o'clock"]}
|
||||
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["700"]}
|
||||
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["700"]}
|
||||
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["700am"]}
|
||||
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["700pm"]}
|
||||
{"sql": ["715"], "type": "TIME", "cleaned_nl": ["715"]}
|
||||
{"sql": ["1915"], "type": "TIME", "cleaned_nl": ["715"]}
|
||||
{"sql": ["715"], "type": "TIME", "cleaned_nl": ["715am"]}
|
||||
{"sql": ["1915"], "type": "TIME", "cleaned_nl": ["715pm"]}
|
||||
{"sql": ["730"], "type": "TIME", "cleaned_nl": ["730"]}
|
||||
{"sql": ["1930"], "type": "TIME", "cleaned_nl": ["730"]}
|
||||
{"sql": ["730"], "type": "TIME", "cleaned_nl": ["730am"]}
|
||||
{"sql": ["1930"], "type": "TIME", "cleaned_nl": ["730pm"]}
|
||||
{"sql": ["745"], "type": "TIME", "cleaned_nl": ["745"]}
|
||||
{"sql": ["1945"], "type": "TIME", "cleaned_nl": ["745"]}
|
||||
{"sql": ["745"], "type": "TIME", "cleaned_nl": ["745am"]}
|
||||
{"sql": ["1945"], "type": "TIME", "cleaned_nl": ["745pm"]}
|
||||
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["8"]}
|
||||
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["8"]}
|
||||
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["8pm"]}
|
||||
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["8am"]}
|
||||
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["8", "o'clock", "am"]}
|
||||
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["8", "o'clock", "pm"]}
|
||||
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["8", "o'clock"]}
|
||||
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["8", "o'clock"]}
|
||||
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["800"]}
|
||||
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["800"]}
|
||||
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["800am"]}
|
||||
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["800pm"]}
|
||||
{"sql": ["815"], "type": "TIME", "cleaned_nl": ["815"]}
|
||||
{"sql": ["2015"], "type": "TIME", "cleaned_nl": ["815"]}
|
||||
{"sql": ["815"], "type": "TIME", "cleaned_nl": ["815am"]}
|
||||
{"sql": ["2015"], "type": "TIME", "cleaned_nl": ["815pm"]}
|
||||
{"sql": ["830"], "type": "TIME", "cleaned_nl": ["830"]}
|
||||
{"sql": ["2030"], "type": "TIME", "cleaned_nl": ["830"]}
|
||||
{"sql": ["830"], "type": "TIME", "cleaned_nl": ["830am"]}
|
||||
{"sql": ["2030"], "type": "TIME", "cleaned_nl": ["830pm"]}
|
||||
{"sql": ["845"], "type": "TIME", "cleaned_nl": ["845"]}
|
||||
{"sql": ["2045"], "type": "TIME", "cleaned_nl": ["845"]}
|
||||
{"sql": ["845"], "type": "TIME", "cleaned_nl": ["845am"]}
|
||||
{"sql": ["2045"], "type": "TIME", "cleaned_nl": ["845pm"]}
|
||||
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["9"]}
|
||||
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["9"]}
|
||||
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["9pm"]}
|
||||
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["9am"]}
|
||||
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["9", "o'clock", "am"]}
|
||||
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["9", "o'clock", "pm"]}
|
||||
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["9", "o'clock"]}
|
||||
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["9", "o'clock"]}
|
||||
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["900"]}
|
||||
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["900"]}
|
||||
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["900am"]}
|
||||
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["900pm"]}
|
||||
{"sql": ["915"], "type": "TIME", "cleaned_nl": ["915"]}
|
||||
{"sql": ["2115"], "type": "TIME", "cleaned_nl": ["915"]}
|
||||
{"sql": ["915"], "type": "TIME", "cleaned_nl": ["915am"]}
|
||||
{"sql": ["2115"], "type": "TIME", "cleaned_nl": ["915pm"]}
|
||||
{"sql": ["930"], "type": "TIME", "cleaned_nl": ["930"]}
|
||||
{"sql": ["2130"], "type": "TIME", "cleaned_nl": ["930"]}
|
||||
{"sql": ["930"], "type": "TIME", "cleaned_nl": ["930am"]}
|
||||
{"sql": ["2130"], "type": "TIME", "cleaned_nl": ["930pm"]}
|
||||
{"sql": ["945"], "type": "TIME", "cleaned_nl": ["945"]}
|
||||
{"sql": ["2145"], "type": "TIME", "cleaned_nl": ["945"]}
|
||||
{"sql": ["945"], "type": "TIME", "cleaned_nl": ["945am"]}
|
||||
{"sql": ["2145"], "type": "TIME", "cleaned_nl": ["945pm"]}
|
||||
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["10"]}
|
||||
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["10"]}
|
||||
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["10pm"]}
|
||||
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["10am"]}
|
||||
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["10", "o'clock", "am"]}
|
||||
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["10", "o'clock", "pm"]}
|
||||
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["10", "o'clock"]}
|
||||
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["10", "o'clock"]}
|
||||
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["1000"]}
|
||||
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["1000"]}
|
||||
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["1000am"]}
|
||||
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["1000pm"]}
|
||||
{"sql": ["1015"], "type": "TIME", "cleaned_nl": ["1015"]}
|
||||
{"sql": ["2215"], "type": "TIME", "cleaned_nl": ["1015"]}
|
||||
{"sql": ["1015"], "type": "TIME", "cleaned_nl": ["1015am"]}
|
||||
{"sql": ["2215"], "type": "TIME", "cleaned_nl": ["1015pm"]}
|
||||
{"sql": ["1030"], "type": "TIME", "cleaned_nl": ["1030"]}
|
||||
{"sql": ["2230"], "type": "TIME", "cleaned_nl": ["1030"]}
|
||||
{"sql": ["1030"], "type": "TIME", "cleaned_nl": ["1030am"]}
|
||||
{"sql": ["2230"], "type": "TIME", "cleaned_nl": ["1030pm"]}
|
||||
{"sql": ["1045"], "type": "TIME", "cleaned_nl": ["1045"]}
|
||||
{"sql": ["2245"], "type": "TIME", "cleaned_nl": ["1045"]}
|
||||
{"sql": ["1045"], "type": "TIME", "cleaned_nl": ["1045am"]}
|
||||
{"sql": ["2245"], "type": "TIME", "cleaned_nl": ["1045pm"]}
|
||||
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["11"]}
|
||||
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["11"]}
|
||||
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["11pm"]}
|
||||
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["11am"]}
|
||||
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["11", "o'clock", "am"]}
|
||||
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["11", "o'clock", "pm"]}
|
||||
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["11", "o'clock"]}
|
||||
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["11", "o'clock"]}
|
||||
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["1100"]}
|
||||
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["1100"]}
|
||||
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["1100am"]}
|
||||
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["1100pm"]}
|
||||
{"sql": ["1115"], "type": "TIME", "cleaned_nl": ["1115"]}
|
||||
{"sql": ["2315"], "type": "TIME", "cleaned_nl": ["1115"]}
|
||||
{"sql": ["1115"], "type": "TIME", "cleaned_nl": ["1115am"]}
|
||||
{"sql": ["2315"], "type": "TIME", "cleaned_nl": ["1115pm"]}
|
||||
{"sql": ["1130"], "type": "TIME", "cleaned_nl": ["1130"]}
|
||||
{"sql": ["2330"], "type": "TIME", "cleaned_nl": ["1130"]}
|
||||
{"sql": ["1130"], "type": "TIME", "cleaned_nl": ["1130am"]}
|
||||
{"sql": ["2330"], "type": "TIME", "cleaned_nl": ["1130pm"]}
|
||||
{"sql": ["1145"], "type": "TIME", "cleaned_nl": ["1145"]}
|
||||
{"sql": ["2345"], "type": "TIME", "cleaned_nl": ["1145"]}
|
||||
{"sql": ["1145"], "type": "TIME", "cleaned_nl": ["1145am"]}
|
||||
{"sql": ["2345"], "type": "TIME", "cleaned_nl": ["1145pm"]}
|
||||
{"sql": ["1200"], "type": "TIME", "cleaned_nl": ["12"]}
|
||||
{"sql": ["1200"], "type": "TIME", "cleaned_nl": ["1200"]}
|
||||
{"sql": ["2400"], "type": "TIME", "cleaned_nl": ["1200"]}
|
||||
{"sql": ["1200"], "type": "TIME", "cleaned_nl": ["1200am"]}
|
||||
{"sql": ["2400"], "type": "TIME", "cleaned_nl": ["1200pm"]}
|
||||
{"sql": ["1215"], "type": "TIME", "cleaned_nl": ["1215"]}
|
||||
{"sql": ["2415"], "type": "TIME", "cleaned_nl": ["1215"]}
|
||||
{"sql": ["1215"], "type": "TIME", "cleaned_nl": ["1215am"]}
|
||||
{"sql": ["2415"], "type": "TIME", "cleaned_nl": ["1215pm"]}
|
||||
{"sql": ["1230"], "type": "TIME", "cleaned_nl": ["1230"]}
|
||||
{"sql": ["2430"], "type": "TIME", "cleaned_nl": ["1230"]}
|
||||
{"sql": ["1230"], "type": "TIME", "cleaned_nl": ["1230am"]}
|
||||
{"sql": ["2430"], "type": "TIME", "cleaned_nl": ["1230pm"]}
|
||||
{"sql": ["1245"], "type": "TIME", "cleaned_nl": ["1245"]}
|
||||
{"sql": ["2445"], "type": "TIME", "cleaned_nl": ["1245"]}
|
||||
{"sql": ["1245"], "type": "TIME", "cleaned_nl": ["1245am"]}
|
||||
{"sql": ["2445"], "type": "TIME", "cleaned_nl": ["1245pm"]}
|
||||
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["13"]}
|
||||
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1300"]}
|
||||
{"sql": ["1315"], "type": "TIME", "cleaned_nl": ["1315"]}
|
||||
{"sql": ["1330"], "type": "TIME", "cleaned_nl": ["1330"]}
|
||||
{"sql": ["1345"], "type": "TIME", "cleaned_nl": ["1345"]}
|
||||
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["14"]}
|
||||
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["1400"]}
|
||||
{"sql": ["1415"], "type": "TIME", "cleaned_nl": ["1415"]}
|
||||
{"sql": ["1430"], "type": "TIME", "cleaned_nl": ["1430"]}
|
||||
{"sql": ["1445"], "type": "TIME", "cleaned_nl": ["1445"]}
|
||||
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["15"]}
|
||||
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["1500"]}
|
||||
{"sql": ["1515"], "type": "TIME", "cleaned_nl": ["1515"]}
|
||||
{"sql": ["1530"], "type": "TIME", "cleaned_nl": ["1530"]}
|
||||
{"sql": ["1545"], "type": "TIME", "cleaned_nl": ["1545"]}
|
||||
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["16"]}
|
||||
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["1600"]}
|
||||
{"sql": ["1615"], "type": "TIME", "cleaned_nl": ["1615"]}
|
||||
{"sql": ["1630"], "type": "TIME", "cleaned_nl": ["1630"]}
|
||||
{"sql": ["1645"], "type": "TIME", "cleaned_nl": ["1645"]}
|
||||
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["17"]}
|
||||
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["1700"]}
|
||||
{"sql": ["1715"], "type": "TIME", "cleaned_nl": ["1715"]}
|
||||
{"sql": ["1730"], "type": "TIME", "cleaned_nl": ["1730"]}
|
||||
{"sql": ["1745"], "type": "TIME", "cleaned_nl": ["1745"]}
|
||||
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["18"]}
|
||||
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["1800"]}
|
||||
{"sql": ["1815"], "type": "TIME", "cleaned_nl": ["1815"]}
|
||||
{"sql": ["1830"], "type": "TIME", "cleaned_nl": ["1830"]}
|
||||
{"sql": ["1845"], "type": "TIME", "cleaned_nl": ["1845"]}
|
||||
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["19"]}
|
||||
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["1900"]}
|
||||
{"sql": ["1915"], "type": "TIME", "cleaned_nl": ["1915"]}
|
||||
{"sql": ["1930"], "type": "TIME", "cleaned_nl": ["1930"]}
|
||||
{"sql": ["1945"], "type": "TIME", "cleaned_nl": ["1945"]}
|
||||
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["20"]}
|
||||
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["2000"]}
|
||||
{"sql": ["2015"], "type": "TIME", "cleaned_nl": ["2015"]}
|
||||
{"sql": ["2030"], "type": "TIME", "cleaned_nl": ["2030"]}
|
||||
{"sql": ["2045"], "type": "TIME", "cleaned_nl": ["2045"]}
|
||||
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["21"]}
|
||||
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["2100"]}
|
||||
{"sql": ["2115"], "type": "TIME", "cleaned_nl": ["2115"]}
|
||||
{"sql": ["2130"], "type": "TIME", "cleaned_nl": ["2130"]}
|
||||
{"sql": ["2145"], "type": "TIME", "cleaned_nl": ["2145"]}
|
||||
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["22"]}
|
||||
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["2200"]}
|
||||
{"sql": ["2215"], "type": "TIME", "cleaned_nl": ["2215"]}
|
||||
{"sql": ["2230"], "type": "TIME", "cleaned_nl": ["2230"]}
|
||||
{"sql": ["2245"], "type": "TIME", "cleaned_nl": ["2245"]}
|
||||
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["23"]}
|
||||
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["2300"]}
|
||||
{"sql": ["2315"], "type": "TIME", "cleaned_nl": ["2315"]}
|
||||
{"sql": ["2330"], "type": "TIME", "cleaned_nl": ["2330"]}
|
||||
{"sql": ["2345"], "type": "TIME", "cleaned_nl": ["2345"]}
|
||||
{"sql": ["1200"], "type": "TIME", "cleaned_nl": ["noon"]}
|
||||
{"sql": ["'BREAKFAST'"], "type": "MEAL_DESCRIPTION", "cleaned_nl": ["breakfast"]}
|
||||
{"sql": ["'LUNCH'"], "type": "MEAL_DESCRIPTION", "cleaned_nl": ["lunch"]}
|
||||
{"sql": ["'DINNER'"], "type": "MEAL_DESCRIPTION", "cleaned_nl": ["dinner"]}
|
||||
{"sql": ["'SNACK'"], "type": "MEAL_DESCRIPTION", "cleaned_nl": ["snack"]}
|
||||
{"sql": ["'ATL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["atl"]}
|
||||
{"sql": ["'ATL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["atlanta", "airport"]}
|
||||
{"sql": ["'BNA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["bna"]}
|
||||
{"sql": ["'BOS'"], "type": "AIRPORT_CODE", "cleaned_nl": ["bos"]}
|
||||
{"sql": ["'BOS'"], "type": "AIRPORT_CODE", "cleaned_nl": ["boston", "airport"]}
|
||||
{"sql": ["'BUR'"], "type": "AIRPORT_CODE", "cleaned_nl": ["bur"]}
|
||||
{"sql": ["'BWI'"], "type": "AIRPORT_CODE", "cleaned_nl": ["bwi"]}
|
||||
{"sql": ["'CLE'"], "type": "AIRPORT_CODE", "cleaned_nl": ["cle"]}
|
||||
{"sql": ["'CLT'"], "type": "AIRPORT_CODE", "cleaned_nl": ["clt"]}
|
||||
{"sql": ["'CMH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["cmh"]}
|
||||
{"sql": ["'CVG'"], "type": "AIRPORT_CODE", "cleaned_nl": ["cvg"]}
|
||||
{"sql": ["'DAL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dal"]}
|
||||
{"sql": ["'DAL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["love", "field"]}
|
||||
{"sql": ["'DCA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dca"]}
|
||||
{"sql": ["'DEN'"], "type": "AIRPORT_CODE", "cleaned_nl": ["den"]}
|
||||
{"sql": ["'DEN'"], "type": "AIRPORT_CODE", "cleaned_nl": ["denver", "airport"]}
|
||||
{"sql": ["'DET'"], "type": "AIRPORT_CODE", "cleaned_nl": ["det"]}
|
||||
{"sql": ["'DFW'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dfw"]}
|
||||
{"sql": ["'DFW'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dallas", "airport"]}
|
||||
{"sql": ["'DTW'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dtw"]}
|
||||
{"sql": ["'EWR'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ewr"]}
|
||||
{"sql": ["'HOU'"], "type": "AIRPORT_CODE", "cleaned_nl": ["hou"]}
|
||||
{"sql": ["'HPN'"], "type": "AIRPORT_CODE", "cleaned_nl": ["hpn"]}
|
||||
{"sql": ["'IAD'"], "type": "AIRPORT_CODE", "cleaned_nl": ["iad"]}
|
||||
{"sql": ["'IAH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["iah"]}
|
||||
{"sql": ["'IAH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["houston", "international"]}
|
||||
{"sql": ["'IAH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["houston", "international", "airport"]}
|
||||
{"sql": ["'IAH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["houston", "intercontinental", "airport"]}
|
||||
{"sql": ["'IND'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ind"]}
|
||||
{"sql": ["'JFK'"], "type": "AIRPORT_CODE", "cleaned_nl": ["jfk"]}
|
||||
{"sql": ["'LAS'"], "type": "AIRPORT_CODE", "cleaned_nl": ["las"]}
|
||||
{"sql": ["'LAX'"], "type": "AIRPORT_CODE", "cleaned_nl": ["lax"]}
|
||||
{"sql": ["'LAX'"], "type": "AIRPORT_CODE", "cleaned_nl": ["los", "angeles", "international", "airport"]}
|
||||
{"sql": ["'LGA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["lga"]}
|
||||
{"sql": ["'LGA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["la", "guardia"]}
|
||||
{"sql": ["'LGB'"], "type": "AIRPORT_CODE", "cleaned_nl": ["lgb"]}
|
||||
{"sql": ["'MCI'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mci"]}
|
||||
{"sql": ["'MCO'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mco"]}
|
||||
{"sql": ["'MDW'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mdw"]}
|
||||
{"sql": ["'MEM'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mem"]}
|
||||
{"sql": ["'MIA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mia"]}
|
||||
{"sql": ["'MKE'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mke"]}
|
||||
{"sql": ["'MKE'"], "type": "AIRPORT_CODE", "cleaned_nl": ["general", "mitchell", "airport"]}
|
||||
{"sql": ["'MSP'"], "type": "AIRPORT_CODE", "cleaned_nl": ["msp"]}
|
||||
{"sql": ["'OAK'"], "type": "AIRPORT_CODE", "cleaned_nl": ["oak"]}
|
||||
{"sql": ["'ONT'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ont"]}
|
||||
{"sql": ["'ORD'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ord"]}
|
||||
{"sql": ["'PHL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["phl"]}
|
||||
{"sql": ["'PHX'"], "type": "AIRPORT_CODE", "cleaned_nl": ["phx"]}
|
||||
{"sql": ["'PIE'"], "type": "AIRPORT_CODE", "cleaned_nl": ["pie"]}
|
||||
{"sql": ["'PIT'"], "type": "AIRPORT_CODE", "cleaned_nl": ["pit"]}
|
||||
{"sql": ["'PIT'"], "type": "AIRPORT_CODE", "cleaned_nl": ["pittsburgh", "airport"]}
|
||||
{"sql": ["'SAN'"], "type": "AIRPORT_CODE", "cleaned_nl": ["san"]}
|
||||
{"sql": ["'SEA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["sea"]}
|
||||
{"sql": ["'SFO'"], "type": "AIRPORT_CODE", "cleaned_nl": ["sfo"]}
|
||||
{"sql": ["'SFO'"], "type": "AIRPORT_CODE", "cleaned_nl": ["san", "francisco", "international"]}
|
||||
{"sql": ["'SFO'"], "type": "AIRPORT_CODE", "cleaned_nl": ["san", "francisco", "international", "airport"]}
|
||||
{"sql": ["'SJC'"], "type": "AIRPORT_CODE", "cleaned_nl": ["sjc"]}
|
||||
{"sql": ["'SLC'"], "type": "AIRPORT_CODE", "cleaned_nl": ["slc"]}
|
||||
{"sql": ["'STL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["stl"]}
|
||||
{"sql": ["'TPA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["tpa"]}
|
||||
{"sql": ["'YKZ'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ykz"]}
|
||||
{"sql": ["'YMX'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ymx"]}
|
||||
{"sql": ["'YTZ'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ytz"]}
|
||||
{"sql": ["'YUL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["yul"]}
|
||||
{"sql": ["'YYZ'"], "type": "AIRPORT_CODE", "cleaned_nl": ["yyz"]}
|
||||
{"sql": ["'72S'"], "type": "AIRCRAFT_CODE", "cleaned_nl": ["72s"]}
|
||||
{"sql": ["'734'"], "type": "AIRCRAFT_CODE", "cleaned_nl": ["734"]}
|
||||
{"sql": ["'M80'"], "type": "AIRCRAFT_CODE", "cleaned_nl": ["m80"]}
|
||||
{"sql": ["'D9S'"], "type": "AIRCRAFT_CODE", "cleaned_nl": ["d9s"]}
|
||||
{"cleaned_nl": ["ap/2"], "type": "RESTRICTION_CODE", "sql": ["'AP/2'"]}
|
||||
{"cleaned_nl": ["ap/6"], "type": "RESTRICTION_CODE", "sql": ["'AP/6'"]}
|
||||
{"cleaned_nl": ["ap/12"], "type": "RESTRICTION_CODE", "sql": ["'AP/12'"]}
|
||||
{"cleaned_nl": ["ap/20"], "type": "RESTRICTION_CODE", "sql": ["'AP/20'"]}
|
||||
{"cleaned_nl": ["ap/21"], "type": "RESTRICTION_CODE", "sql": ["'AP/21'"]}
|
||||
{"cleaned_nl": ["ap/57"], "type": "RESTRICTION_CODE", "sql": ["'AP/57'"]}
|
||||
{"cleaned_nl": ["ap", "57"], "type": "RESTRICTION_CODE", "sql": ["'AP/57'"]}
|
||||
{"cleaned_nl": ["ap/58"], "type": "RESTRICTION_CODE", "sql": ["'AP/58'"]}
|
||||
{"cleaned_nl": ["ap/60"], "type": "RESTRICTION_CODE", "sql": ["'AP/60'"]}
|
||||
{"cleaned_nl": ["ap/80"], "type": "RESTRICTION_CODE", "sql": ["'AP/80'"]}
|
||||
{"cleaned_nl": ["ap", "80"], "type": "RESTRICTION_CODE", "sql": ["'AP/80'"]}
|
||||
{"cleaned_nl": ["ap/75"], "type": "RESTRICTION_CODE", "sql": ["'AP/75'"]}
|
||||
{"cleaned_nl": ["ex/9"], "type": "RESTRICTION_CODE", "sql": ["'EX/9'"]}
|
||||
{"cleaned_nl": ["ex/13"], "type": "RESTRICTION_CODE", "sql": ["'EX/13'"]}
|
||||
{"cleaned_nl": ["ex/14"], "type": "RESTRICTION_CODE", "sql": ["'EX/14'"]}
|
||||
{"cleaned_nl": ["ex/17"], "type": "RESTRICTION_CODE", "sql": ["'EX/17'"]}
|
||||
{"cleaned_nl": ["ex/19"], "type": "RESTRICTION_CODE", "sql": ["'EX/19'"]}
|
|
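The JSONL entries above pair tokenized natural-language phrases (`cleaned_nl`) with candidate SQL literal values (`sql`) under a value type (`type`); note that a single phrase can anchor to several candidates (e.g. "1" maps to both 100 and 1300 for TIME). As a minimal illustrative sketch — not code from this repo, and with a hypothetical file path — such a table could be loaded into a lookup like this:

```python
import json
from collections import defaultdict

def load_anchor_table(path="value_anchors.jsonl"):
    """Minimal sketch (not part of the repo): build a lookup from the cleaned
    NL token tuple to its candidate (type, SQL literal) anchorings.
    The path "value_anchors.jsonl" is a hypothetical placeholder."""
    table = defaultdict(list)
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or not line.startswith("{"):
                continue  # skip blank lines and non-JSON separators
            entry = json.loads(line)
            key = tuple(entry["cleaned_nl"])
            # one NL phrase may anchor to several candidate SQL literals
            table[key].append((entry["type"], entry["sql"]))
    return table

# e.g. table[("st.", "paul")] -> [("CITY_NAME", ["'ST. PAUL'"])]
```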
@@ -0,0 +1,3 @@
|
|||
raw/atis2/12-1.1/ATIS2/TEXT/TRAIN/MIT/4L0/1
|
||||
raw/atis2/12-1.1/ATIS2/TEXT/TEST/NOV92/770/5
|
||||
raw/atis3_test/17_4.2/atis3/sp_tst/dev94/prelim/g0x/4
|
|
@@ -0,0 +1,902 @@
|
|||
{"input": "northeast express regional airlines", "output": ["'2V'"]}
|
||||
{"input": "air alliance", "output": ["'3J'"]}
|
||||
{"input": "alpha air", "output": ["'7V'"]}
|
||||
{"input": "express airlines i", "output": ["'9E'"]}
|
||||
{"input": "colgan air", "output": ["'9L'"]}
|
||||
{"input": "trans states airlines", "output": ["'9N'"]}
|
||||
{"input": "ontario express", "output": ["'9X'"]}
|
||||
{"input": "american airlines", "output": ["'AA'"]}
|
||||
{"input": "american airline", "output": ["'AA'"]}
|
||||
{"input": "american", "output": ["'AA'"]}
|
||||
{"input": "air canada", "output": ["'AC'"]}
|
||||
{"input": "aerolineas argentinas", "output": ["'AR'"]}
|
||||
{"input": "alaska airlines", "output": ["'AS'"]}
|
||||
{"input": "royal air maroc", "output": ["'AT'"]}
|
||||
{"input": "british airways", "output": ["'BA'"]}
|
||||
{"input": "braniff international airlines", "output": ["'BE'"]}
|
||||
{"input": "continental airlines", "output": ["'CO'"]}
|
||||
{"input": "continental", "output": ["'CO'"]}
|
||||
{"input": "canadian airlines international", "output": ["'CP'"]}
|
||||
{"input": "canadian", "output": ["'CP'"]}
|
||||
{"input": "canadian airlines", "output": ["'CP'"]}
|
||||
{"input": "atlantic coast airlines", "output": ["'DH'"]}
|
||||
{"input": "delta airlines", "output": ["'DL'"]}
|
||||
{"input": "delta", "output": ["'DL'"]}
|
||||
{"input": "delta airline", "output": ["'DL'"]}
|
||||
{"input": "atlantic southeast airlines", "output": ["'EV'"]}
|
||||
{"input": "tower air", "output": ["'FF'"]}
|
||||
{"input": "air ontario", "output": ["'GX'"]}
|
||||
{"input": "america west airlines", "output": ["'HP'"]}
|
||||
{"input": "america west", "output": ["'HP'"]}
|
||||
{"input": "business express", "output": ["'HQ'"]}
|
||||
{"input": "carnival air lines", "output": ["'KW'"]}
|
||||
{"input": "lufthansa german airlines", "output": ["'LH'"]}
|
||||
{"input": "lufthansa", "output": ["'LH'"]}
|
||||
{"input": "lufthansa airlines", "output": ["'LH'"]}
|
||||
{"input": "mgm grand air", "output": ["'MG'"]}
|
||||
{"input": "northwest airlines", "output": ["'NW'"]}
|
||||
{"input": "northwest airline", "output": ["'NW'"]}
|
||||
{"input": "northwest", "output": ["'NW'"]}
|
||||
{"input": "nationair", "output": ["'NX'"]}
|
||||
{"input": "westair airlines", "output": ["'OE'"]}
|
||||
{"input": "comair", "output": ["'OH'"]}
|
||||
{"input": "czechoslovak airlines", "output": ["'OK'"]}
|
||||
{"input": "sky west airlines", "output": ["'OO'"]}
|
||||
{"input": "grand airways", "output": ["'QD'"]}
|
||||
{"input": "precision airlines", "output": ["'RP'"]}
|
||||
{"input": "trans world express", "output": ["'RZ'"]}
|
||||
{"input": "sabena belgian world airlines", "output": ["'SN'"]}
|
||||
{"input": "christman air system", "output": ["'SX'"]}
|
||||
{"input": "thai airways international,ltd.", "output": ["'TG'"]}
|
||||
{"input": "trans world airlines", "output": ["'TW'"]}
|
||||
{"input": "twa", "output": ["'TW'"]}
|
||||
{"input": "american trans air", "output": ["'TZ'"]}
|
||||
{"input": "united airlines", "output": ["'UA'"]}
|
||||
{"input": "united", "output": ["'UA'"]}
|
||||
{"input": "usair", "output": ["'US'"]}
|
||||
{"input": "us air", "output": ["'US'"]}
|
||||
{"input": "southwest airlines", "output": ["'WN'"]}
|
||||
{"input": "southwest air", "output": ["'WN'"]}
|
||||
{"input": "southwest", "output": ["'WN'"]}
|
||||
{"input": "mesaba aviation", "output": ["'XJ'"]}
|
||||
{"input": "midwest express airlines", "output": ["'YX'"]}
|
||||
{"input": "midwest express", "output": ["'YX'"]}
|
||||
{"input": "air wisconsin", "output": ["'ZW'"]}
|
||||
{"input": "2v", "output": ["'2V'"]}
|
||||
{"input": "3j", "output": ["'3J'"]}
|
||||
{"input": "7v", "output": ["'7V'"]}
|
||||
{"input": "9e", "output": ["'9E'"]}
|
||||
{"input": "9l", "output": ["'9L'"]}
|
||||
{"input": "9n", "output": ["'9N'"]}
|
||||
{"input": "9x", "output": ["'9X'"]}
|
||||
{"input": "aa", "output": ["'AA'"]}
|
||||
{"input": "ac", "output": ["'AC'"]}
|
||||
{"input": "ar", "output": ["'AR'"]}
|
||||
{"input": "as", "output": ["'AS'"]}
|
||||
{"input": "at", "output": ["'AT'"]}
|
||||
{"input": "ba", "output": ["'BA'"]}
|
||||
{"input": "be", "output": ["'BE'"]}
|
||||
{"input": "co", "output": ["'CO'"]}
|
||||
{"input": "cp", "output": ["'CP'"]}
|
||||
{"input": "dh", "output": ["'DH'"]}
|
||||
{"input": "dl", "output": ["'DL'"]}
|
||||
{"input": "ev", "output": ["'EV'"]}
|
||||
{"input": "ff", "output": ["'FF'"]}
|
||||
{"input": "gx", "output": ["'GX'"]}
|
||||
{"input": "hp", "output": ["'HP'"]}
|
||||
{"input": "hq", "output": ["'HQ'"]}
|
||||
{"input": "kw", "output": ["'KW'"]}
|
||||
{"input": "lh", "output": ["'LH'"]}
|
||||
{"input": "mg", "output": ["'MG'"]}
|
||||
{"input": "nw", "output": ["'NW'"]}
|
||||
{"input": "nx", "output": ["'NX'"]}
|
||||
{"input": "oe", "output": ["'OE'"]}
|
||||
{"input": "oh", "output": ["'OH'"]}
|
||||
{"input": "ok", "output": ["'OK'"]}
|
||||
{"input": "oo", "output": ["'OO'"]}
|
||||
{"input": "qd", "output": ["'QD'"]}
|
||||
{"input": "rp", "output": ["'RP'"]}
|
||||
{"input": "rz", "output": ["'RZ'"]}
|
||||
{"input": "sn", "output": ["'SN'"]}
|
||||
{"input": "sx", "output": ["'SX'"]}
|
||||
{"input": "tg", "output": ["'TG'"]}
|
||||
{"input": "tw", "output": ["'TW'"]}
|
||||
{"input": "tz", "output": ["'TZ'"]}
|
||||
{"input": "ua", "output": ["'UA'"]}
|
||||
{"input": "us", "output": ["'US'"]}
|
||||
{"input": "wn", "output": ["'WN'"]}
|
||||
{"input": "xj", "output": ["'XJ'"]}
|
||||
{"input": "yx", "output": ["'YX'"]}
|
||||
{"input": "zw", "output": ["'ZW'"]}
|
||||
{"input": "nashville", "output": ["'NASHVILLE'"]}
|
||||
{"input": "boston", "output": ["'BOSTON'"]}
|
||||
{"input": "burbank", "output": ["'BURBANK'"]}
|
||||
{"input": "baltimore", "output": ["'BALTIMORE'"]}
|
||||
{"input": "chicago", "output": ["'CHICAGO'"]}
|
||||
{"input": "cleveland", "output": ["'CLEVELAND'"]}
|
||||
{"input": "charlotte", "output": ["'CHARLOTTE'"]}
|
||||
{"input": "columbus", "output": ["'COLUMBUS'"]}
|
||||
{"input": "cincinnati", "output": ["'CINCINNATI'"]}
|
||||
{"input": "denver", "output": ["'DENVER'"]}
|
||||
{"input": "dallas", "output": ["'DALLAS'"]}
|
||||
{"input": "detroit", "output": ["'DETROIT'"]}
|
||||
{"input": "fort worth", "output": ["'FORT", "WORTH'"]}
|
||||
{"input": "houston", "output": ["'HOUSTON'"]}
|
||||
{"input": "westchester county", "output": ["'WESTCHESTER COUNTY'"]}
|
||||
{"input": "westchester", "output": ["'WESTCHESTER COUNTY'"]}
|
||||
{"input": "indianapolis", "output": ["'INDIANAPOLIS'"]}
|
||||
{"input": "newark", "output": ["'NEWARK'"]}
|
||||
{"input": "las vegas", "output": ["'LAS VEGAS'"]}
|
||||
{"input": "los angeles", "output": ["'LOS ANGELES'"]}
|
||||
{"input": "long beach", "output": ["'LONG BEACH'"]}
|
||||
{"input": "atlanta", "output": ["'ATLANTA'"]}
|
||||
{"input": "memphis", "output": ["'MEMPHIS'"]}
|
||||
{"input": "miami", "output": ["'MIAMI'"]}
|
||||
{"input": "kansas city", "output": ["'KANSAS CITY'"]}
|
||||
{"input": "milwaukee", "output": ["'MILWAUKEE'"]}
|
||||
{"input": "minneapolis", "output": ["'MINNEAPOLIS'"]}
|
||||
{"input": "new york", "output": ["'NEW YORK'"]}
|
||||
{"input": "new york city", "output": ["'NEW YORK'"]}
|
||||
{"input": "oakland", "output": ["'OAKLAND'"]}
|
||||
{"input": "ontario", "output": ["'ONTARIO'"]}
|
||||
{"input": "orlando", "output": ["'ORLANDO'"]}
|
||||
{"input": "philadelphia", "output": ["'PHILADELPHIA'"]}
|
||||
{"input": "phoenix", "output": ["'PHOENIX'"]}
|
||||
{"input": "pittsburgh", "output": ["'PITTSBURGH'"]}
|
||||
{"input": "st. paul", "output": ["'ST. PAUL'"]}
|
||||
{"input": "saint paul", "output": ["'ST. PAUL'"]}
|
||||
{"input": "san diego", "output": ["'SAN DIEGO'"]}
|
||||
{"input": "seattle", "output": ["'SEATTLE'"]}
|
||||
{"input": "san francisco", "output": ["'SAN FRANCISCO'"]}
|
||||
{"input": "san jose", "output": ["'SAN JOSE'"]}
|
||||
{"input": "salt lake city", "output": ["'SALT LAKE CITY'"]}
|
||||
{"input": "st. louis", "output": ["'ST. LOUIS'"]}
|
||||
{"input": "saint louis", "output": ["'ST. LOUIS'"]}
|
||||
{"input": "saint petersburg", "output": ["'ST. PETERSBURG'"]}
|
||||
{"input": "st. petersburg", "output": ["'ST. PETERSBURG'"]}
|
||||
{"input": "tacoma", "output": ["'TACOMA'"]}
|
||||
{"input": "tampa", "output": ["'TAMPA'"]}
|
||||
{"input": "washington", "output": ["'WASHINGTON'"]}
|
||||
{"input": "montreal", "output": ["'MONTREAL'"]}
|
||||
{"input": "toronto", "output": ["'TORONTO'"]}
|
||||
{"input": "january", "output": ["1"]}
|
||||
{"input": "february", "output": ["2"]}
|
||||
{"input": "march", "output": ["3"]}
|
||||
{"input": "april", "output": ["4"]}
|
||||
{"input": "may", "output": ["5"]}
|
||||
{"input": "june", "output": ["6"]}
|
||||
{"input": "july", "output": ["7"]}
|
||||
{"input": "august", "output": ["8"]}
|
||||
{"input": "september", "output": ["9"]}
|
||||
{"input": "october", "output": ["10"]}
|
||||
{"input": "november", "output": ["11"]}
|
||||
{"input": "december", "output": ["12"]}
|
||||
{"input": "monday", "output": ["'MONDAY'"]}
|
||||
{"input": "tuesday", "output": ["'TUESDAY'"]}
|
||||
{"input": "wednesday", "output": ["'WEDNESDAY'"]}
|
||||
{"input": "thursday", "output": ["'THURSDAY'"]}
|
||||
{"input": "friday", "output": ["'FRIDAY'"]}
|
||||
{"input": "saturday", "output": ["'SATURDAY'"]}
|
||||
{"input": "sunday", "output": ["'SUNDAY'"]}
|
||||
{"input": "first", "output": ["1"]}
|
||||
{"input": "second", "output": ["2"]}
|
||||
{"input": "third", "output": ["3"]}
|
||||
{"input": "fourth", "output": ["4"]}
|
||||
{"input": "fifth", "output": ["5"]}
|
||||
{"input": "sixth", "output": ["6"]}
|
||||
{"input": "seventh", "output": ["7"]}
|
||||
{"input": "eighth", "output": ["8"]}
|
||||
{"input": "ninth", "output": ["9"]}
|
||||
{"input": "tenth", "output": ["10"]}
|
||||
{"input": "eleventh", "output": ["11"]}
|
||||
{"input": "twelfth", "output": ["12"]}
|
||||
{"input": "thirteenth", "output": ["13"]}
|
||||
{"input": "fourteenth", "output": ["14"]}
|
||||
{"input": "fifteenth", "output": ["15"]}
|
||||
{"input": "sixteenth", "output": ["16"]}
|
||||
{"input": "seventeenth", "output": ["17"]}
|
||||
{"input": "eighteenth", "output": ["18"]}
|
||||
{"input": "nineteenth", "output": ["19"]}
|
||||
{"input": "twentieth", "output": ["20"]}
|
||||
{"input": "twenty first", "output": ["21"]}
|
||||
{"input": "twenty second", "output": ["22"]}
|
||||
{"input": "twenty third", "output": ["23"]}
|
||||
{"input": "twenty fourth", "output": ["24"]}
|
||||
{"input": "twenty fifth", "output": ["25"]}
|
||||
{"input": "twenty sixth", "output": ["26"]}
|
||||
{"input": "twenty seventh", "output": ["27"]}
|
||||
{"input": "twenty eighth", "output": ["28"]}
|
||||
{"input": "twenty ninth", "output": ["29"]}
|
||||
{"input": "thirtieth", "output": ["30"]}
|
||||
{"input": "thirty first", "output": ["31"]}
|
||||
{"input": "coach", "output": ["'COACH'"]}
|
||||
{"input": "business", "output": ["'BUSINESS'"]}
|
||||
{"input": "first", "output": ["'FIRST'"]}
|
||||
{"input": "first class", "output": ["'FIRST'"]}
|
||||
{"input": "thrift", "output": ["'THRIFT'"]}
|
||||
{"input": "standard", "output": ["'STANDARD'"]}
|
||||
{"input": "shuttle", "output": ["'SHUTTLE'"]}
|
||||
{"input": "b", "output": ["'B'"]}
|
||||
{"input": "bh", "output": ["'BH'"]}
|
||||
{"input": "bhw", "output": ["'BHW'"]}
|
||||
{"input": "bhx", "output": ["'BHX'"]}
|
||||
{"input": "bl", "output": ["'BL'"]}
|
||||
{"input": "blw", "output": ["'BLW'"]}
|
||||
{"input": "blx", "output": ["'BLX'"]}
|
||||
{"input": "bn", "output": ["'BN'"]}
|
||||
{"input": "bow", "output": ["'BOW'"]}
|
||||
{"input": "box", "output": ["'BOX'"]}
|
||||
{"input": "bw", "output": ["'BW'"]}
|
||||
{"input": "bx", "output": ["'BX'"]}
|
||||
{"input": "c", "output": ["'C'"]}
|
||||
{"input": "cn", "output": ["'CN'"]}
|
||||
{"input": "f", "output": ["'F'"]}
|
||||
{"input": "fn", "output": ["'FN'"]}
|
||||
{"input": "h", "output": ["'H'"]}
|
||||
{"input": "hh", "output": ["'HH'"]}
|
||||
{"input": "hhw", "output": ["'HHW'"]}
|
||||
{"input": "hhx", "output": ["'HHX'"]}
|
||||
{"input": "hl", "output": ["'HL'"]}
|
||||
{"input": "hlw", "output": ["'HLW'"]}
|
||||
{"input": "hlx", "output": ["'HLX'"]}
|
||||
{"input": "how", "output": ["'HOW'"]}
|
||||
{"input": "hox", "output": ["'HOX'"]}
|
||||
{"input": "j", "output": ["'J'"]}
|
||||
{"input": "k", "output": ["'K'"]}
|
||||
{"input": "kh", "output": ["'KH'"]}
|
||||
{"input": "kl", "output": ["'KL'"]}
|
||||
{"input": "kn", "output": ["'KN'"]}
|
||||
{"input": "lx", "output": ["'LX'"]}
|
||||
{"input": "m", "output": ["'M'"]}
|
||||
{"input": "mh", "output": ["'MH'"]}
|
||||
{"input": "ml", "output": ["'ML'"]}
|
||||
{"input": "mow", "output": ["'MOW'"]}
|
||||
{"input": "p", "output": ["'P'"]}
|
||||
{"input": "q", "output": ["'Q'"]}
|
||||
{"input": "qh", "output": ["'QH'"]}
|
||||
{"input": "qhw", "output": ["'QHW'"]}
|
||||
{"input": "qhx", "output": ["'QHX'"]}
|
||||
{"input": "qlw", "output": ["'QLW'"]}
|
||||
{"input": "qlx", "output": ["'QLX'"]}
|
||||
{"input": "qo", "output": ["'QO'"]}
|
||||
{"input": "qow", "output": ["'QOW'"]}
|
||||
{"input": "qox", "output": ["'QOX'"]}
|
||||
{"input": "qw", "output": ["'QW'"]}
|
||||
{"input": "qx", "output": ["'QX'"]}
|
||||
{"input": "s", "output": ["'S'"]}
|
||||
{"input": "u", "output": ["'U'"]}
|
||||
{"input": "v", "output": ["'V'"]}
|
||||
{"input": "vhw", "output": ["'VHW'"]}
|
||||
{"input": "vhx", "output": ["'VHX'"]}
|
||||
{"input": "vw", "output": ["'VW'"]}
|
||||
{"input": "vx", "output": ["'VX'"]}
|
||||
{"input": "y", "output": ["'Y'"]}
|
||||
{"input": "yh", "output": ["'YH'"]}
|
||||
{"input": "yl", "output": ["'YL'"]}
|
||||
{"input": "yn", "output": ["'YN'"]}
|
||||
{"input": "yw", "output": ["'YW'"]}
|
||||
{"input": "yx", "output": ["'YX'"]}
|
||||
{"input": "rental car", "output": ["'RENTAL CAR'"]}
|
||||
{"input": "air taxi operation", "output": ["'AIR TAXI OPERATION'"]}
|
||||
{"input": "limousine", "output": ["'LIMOUSINE'"]}
|
||||
{"input": "taxi", "output": ["'TAXI'"]}
|
||||
{"input": "arizona", "output": ["'ARIZONA'"]}
|
||||
{"input": "california", "output": ["'CALIFORNIA'"]}
|
||||
{"input": "colorado", "output": ["'COLORADO'"]}
|
||||
{"input": "district of columbia", "output": ["'DISTRICT OF COLUMBIA'"]}
|
||||
{"input": "dc", "output": ["'DISTRICT OF COLUMBIA'"]}
|
||||
{"input": "dc", "output": ["'DC'"]}
|
||||
{"input": "florida", "output": ["'FLORIDA'"]}
|
||||
{"input": "georgia", "output": ["'GEORGIA'"]}
|
||||
{"input": "illinois", "output": ["'ILLINOIS'"]}
|
||||
{"input": "indiana", "output": ["'INDIANA'"]}
|
||||
{"input": "massachusetts", "output": ["'MASSACHUSETTS'"]}
|
||||
{"input": "maryland", "output": ["'MARYLAND'"]}
|
||||
{"input": "michigan", "output": ["'MICHIGAN'"]}
|
||||
{"input": "minnesota", "output": ["'MINNESOTA'"]}
|
||||
{"input": "missouri", "output": ["'MISSOURI'"]}
|
||||
{"input": "north carolina", "output": ["'NORTH CAROLINA'"]}
|
||||
{"input": "new jersey", "output": ["'NEW JERSEY'"]}
|
||||
{"input": "nevada", "output": ["'NEVADA'"]}
|
||||
{"input": "new york", "output": ["'NEW YORK'"]}
|
||||
{"input": "ohio", "output": ["'OHIO'"]}
|
||||
{"input": "ontario", "output": ["'ONTARIO'"]}
|
||||
{"input": "pennsylvania", "output": ["'PENNSYLVANIA'"]}
|
||||
{"input": "quebec", "output": ["'QUEBEC'"]}
|
||||
{"input": "tennessee", "output": ["'TENNESSEE'"]}
|
||||
{"input": "texas", "output": ["'TEXAS'"]}
|
||||
{"input": "utah", "output": ["'UTAH'"]}
|
||||
{"input": "washington", "output": ["'WASHINGTON'"]}
|
||||
{"input": "wisconsin", "output": ["'WISCONSIN'"]}
|
||||
{"input": "1", "output": ["100"]}
|
||||
{"input": "1", "output": ["1300"]}
|
||||
{"input": "1pm", "output": ["1300"]}
|
||||
{"input": "1am", "output": ["100"]}
|
||||
{"input": "1 o'clock am", "output": ["100"]}
|
||||
{"input": "1 o'clock pm", "output": ["1300"]}
|
||||
{"input": "1 o'clock", "output": ["100"]}
|
||||
{"input": "1 o'clock", "output": ["1300"]}
|
||||
{"input": "100", "output": ["100"]}
|
||||
{"input": "100", "output": ["1300"]}
|
||||
{"input": "100am", "output": ["100"]}
|
||||
{"input": "100pm", "output": ["1300"]}
|
||||
{"input": "115", "output": ["115"]}
|
||||
{"input": "115", "output": ["1315"]}
|
||||
{"input": "115am", "output": ["115"]}
|
||||
{"input": "115pm", "output": ["1315"]}
|
||||
{"input": "130", "output": ["130"]}
|
||||
{"input": "130", "output": ["1330"]}
|
||||
{"input": "130am", "output": ["130"]}
|
||||
{"input": "130pm", "output": ["1330"]}
|
||||
{"input": "145", "output": ["145"]}
|
||||
{"input": "145", "output": ["1345"]}
|
||||
{"input": "145am", "output": ["145"]}
|
||||
{"input": "145pm", "output": ["1345"]}
|
||||
{"input": "2", "output": ["200"]}
|
||||
{"input": "2", "output": ["1400"]}
|
||||
{"input": "2pm", "output": ["1400"]}
|
||||
{"input": "2am", "output": ["200"]}
|
||||
{"input": "2 o'clock am", "output": ["200"]}
|
||||
{"input": "2 o'clock pm", "output": ["1400"]}
|
||||
{"input": "2 o'clock", "output": ["200"]}
|
||||
{"input": "2 o'clock", "output": ["1400"]}
|
||||
{"input": "200", "output": ["200"]}
|
||||
{"input": "200", "output": ["1400"]}
|
||||
{"input": "200am", "output": ["200"]}
|
||||
{"input": "200pm", "output": ["1400"]}
|
||||
{"input": "215", "output": ["215"]}
|
||||
{"input": "215", "output": ["1415"]}
|
||||
{"input": "215am", "output": ["215"]}
|
||||
{"input": "215pm", "output": ["1415"]}
|
||||
{"input": "230", "output": ["230"]}
|
||||
{"input": "230", "output": ["1430"]}
|
||||
{"input": "230am", "output": ["230"]}
|
||||
{"input": "230pm", "output": ["1430"]}
|
||||
{"input": "245", "output": ["245"]}
|
||||
{"input": "245", "output": ["1445"]}
|
||||
{"input": "245am", "output": ["245"]}
|
||||
{"input": "245pm", "output": ["1445"]}
|
||||
{"input": "3", "output": ["300"]}
|
||||
{"input": "3", "output": ["1500"]}
|
||||
{"input": "3pm", "output": ["1500"]}
|
||||
{"input": "3am", "output": ["300"]}
|
||||
{"input": "3 o'clock am", "output": ["300"]}
|
||||
{"input": "3 o'clock pm", "output": ["1500"]}
|
||||
{"input": "3 o'clock", "output": ["300"]}
|
||||
{"input": "3 o'clock", "output": ["1500"]}
|
||||
{"input": "300", "output": ["300"]}
|
||||
{"input": "300", "output": ["1500"]}
|
||||
{"input": "300am", "output": ["300"]}
|
||||
{"input": "300pm", "output": ["1500"]}
|
||||
{"input": "315", "output": ["315"]}
|
||||
{"input": "315", "output": ["1515"]}
|
||||
{"input": "315am", "output": ["315"]}
|
||||
{"input": "315pm", "output": ["1515"]}
|
||||
{"input": "330", "output": ["330"]}
|
||||
{"input": "330", "output": ["1530"]}
|
||||
{"input": "330am", "output": ["330"]}
|
||||
{"input": "330pm", "output": ["1530"]}
|
||||
{"input": "345", "output": ["345"]}
|
||||
{"input": "345", "output": ["1545"]}
|
||||
{"input": "345am", "output": ["345"]}
|
||||
{"input": "345pm", "output": ["1545"]}
|
||||
{"input": "4", "output": ["400"]}
|
||||
{"input": "4", "output": ["1600"]}
|
||||
{"input": "4pm", "output": ["1600"]}
|
||||
{"input": "4am", "output": ["400"]}
|
||||
{"input": "4 o'clock am", "output": ["400"]}
|
||||
{"input": "4 o'clock pm", "output": ["1600"]}
|
||||
{"input": "4 o'clock", "output": ["400"]}
|
||||
{"input": "4 o'clock", "output": ["1600"]}
|
||||
{"input": "400", "output": ["400"]}
|
||||
{"input": "400", "output": ["1600"]}
|
||||
{"input": "400am", "output": ["400"]}
|
||||
{"input": "400pm", "output": ["1600"]}
|
||||
{"input": "415", "output": ["415"]}
|
||||
{"input": "415", "output": ["1615"]}
|
||||
{"input": "415am", "output": ["415"]}
|
||||
{"input": "415pm", "output": ["1615"]}
|
||||
{"input": "430", "output": ["430"]}
|
||||
{"input": "430", "output": ["1630"]}
|
||||
{"input": "430am", "output": ["430"]}
|
||||
{"input": "430pm", "output": ["1630"]}
|
||||
{"input": "445", "output": ["445"]}
|
||||
{"input": "445", "output": ["1645"]}
|
||||
{"input": "445am", "output": ["445"]}
|
||||
{"input": "445pm", "output": ["1645"]}
|
||||
{"input": "5", "output": ["500"]}
|
||||
{"input": "5", "output": ["1700"]}
|
||||
{"input": "5pm", "output": ["1700"]}
|
||||
{"input": "5am", "output": ["500"]}
|
||||
{"input": "5 o'clock am", "output": ["500"]}
|
||||
{"input": "5 o'clock pm", "output": ["1700"]}
|
||||
{"input": "5 o'clock", "output": ["500"]}
|
||||
{"input": "5 o'clock", "output": ["1700"]}
|
||||
{"input": "500", "output": ["500"]}
|
||||
{"input": "500", "output": ["1700"]}
|
||||
{"input": "500am", "output": ["500"]}
|
||||
{"input": "500pm", "output": ["1700"]}
|
||||
{"input": "515", "output": ["515"]}
|
||||
{"input": "515", "output": ["1715"]}
|
||||
{"input": "515am", "output": ["515"]}
|
||||
{"input": "515pm", "output": ["1715"]}
|
||||
{"input": "530", "output": ["530"]}
|
||||
{"input": "530", "output": ["1730"]}
|
||||
{"input": "530am", "output": ["530"]}
|
||||
{"input": "530pm", "output": ["1730"]}
|
||||
{"input": "545", "output": ["545"]}
|
||||
{"input": "545", "output": ["1745"]}
|
||||
{"input": "545am", "output": ["545"]}
|
||||
{"input": "545pm", "output": ["1745"]}
|
||||
{"input": "6", "output": ["600"]}
|
||||
{"input": "6", "output": ["1800"]}
|
||||
{"input": "6pm", "output": ["1800"]}
|
||||
{"input": "6am", "output": ["600"]}
|
||||
{"input": "6 o'clock am", "output": ["600"]}
|
||||
{"input": "6 o'clock pm", "output": ["1800"]}
|
||||
{"input": "6 o'clock", "output": ["600"]}
|
||||
{"input": "6 o'clock", "output": ["1800"]}
|
||||
{"input": "600", "output": ["600"]}
|
||||
{"input": "600", "output": ["1800"]}
|
||||
{"input": "600am", "output": ["600"]}
|
||||
{"input": "600pm", "output": ["1800"]}
|
||||
{"input": "615", "output": ["615"]}
|
||||
{"input": "615", "output": ["1815"]}
|
||||
{"input": "615am", "output": ["615"]}
|
||||
{"input": "615pm", "output": ["1815"]}
|
||||
{"input": "630", "output": ["630"]}
|
||||
{"input": "630", "output": ["1830"]}
|
||||
{"input": "630am", "output": ["630"]}
|
||||
{"input": "630pm", "output": ["1830"]}
|
||||
{"input": "645", "output": ["645"]}
|
||||
{"input": "645", "output": ["1845"]}
|
||||
{"input": "645am", "output": ["645"]}
|
||||
{"input": "645pm", "output": ["1845"]}
|
||||
{"input": "7", "output": ["700"]}
|
||||
{"input": "7", "output": ["1900"]}
|
||||
{"input": "7pm", "output": ["1900"]}
|
||||
{"input": "7am", "output": ["700"]}
|
||||
{"input": "7 o'clock am", "output": ["700"]}
|
||||
{"input": "7 o'clock pm", "output": ["1900"]}
|
||||
{"input": "7 o'clock", "output": ["700"]}
|
||||
{"input": "7 o'clock", "output": ["1900"]}
|
||||
{"input": "700", "output": ["700"]}
|
||||
{"input": "700", "output": ["1900"]}
|
||||
{"input": "700am", "output": ["700"]}
|
||||
{"input": "700pm", "output": ["1900"]}
|
||||
{"input": "715", "output": ["715"]}
|
||||
{"input": "715", "output": ["1915"]}
|
||||
{"input": "715am", "output": ["715"]}
|
||||
{"input": "715pm", "output": ["1915"]}
|
||||
{"input": "730", "output": ["730"]}
|
||||
{"input": "730", "output": ["1930"]}
|
||||
{"input": "730am", "output": ["730"]}
|
||||
{"input": "730pm", "output": ["1930"]}
|
||||
{"input": "745", "output": ["745"]}
|
||||
{"input": "745", "output": ["1945"]}
|
||||
{"input": "745am", "output": ["745"]}
|
||||
{"input": "745pm", "output": ["1945"]}
|
||||
{"input": "8", "output": ["800"]}
|
||||
{"input": "8", "output": ["2000"]}
|
||||
{"input": "8pm", "output": ["2000"]}
|
||||
{"input": "8am", "output": ["800"]}
|
||||
{"input": "8 o'clock am", "output": ["800"]}
|
||||
{"input": "8 o'clock pm", "output": ["2000"]}
|
||||
{"input": "8 o'clock", "output": ["800"]}
|
||||
{"input": "8 o'clock", "output": ["2000"]}
|
||||
{"input": "800", "output": ["800"]}
|
||||
{"input": "800", "output": ["2000"]}
|
||||
{"input": "800am", "output": ["800"]}
|
||||
{"input": "800pm", "output": ["2000"]}
|
||||
{"input": "815", "output": ["815"]}
|
||||
{"input": "815", "output": ["2015"]}
|
||||
{"input": "815am", "output": ["815"]}
|
||||
{"input": "815pm", "output": ["2015"]}
|
||||
{"input": "830", "output": ["830"]}
|
||||
{"input": "830", "output": ["2030"]}
|
||||
{"input": "830am", "output": ["830"]}
|
||||
{"input": "830pm", "output": ["2030"]}
|
||||
{"input": "845", "output": ["845"]}
|
||||
{"input": "845", "output": ["2045"]}
|
||||
{"input": "845am", "output": ["845"]}
|
||||
{"input": "845pm", "output": ["2045"]}
|
||||
{"input": "9", "output": ["900"]}
|
||||
{"input": "9", "output": ["2100"]}
|
||||
{"input": "9pm", "output": ["2100"]}
|
||||
{"input": "9am", "output": ["900"]}
|
||||
{"input": "9 o'clock am", "output": ["900"]}
|
||||
{"input": "9 o'clock pm", "output": ["2100"]}
|
||||
{"input": "9 o'clock", "output": ["900"]}
|
||||
{"input": "9 o'clock", "output": ["2100"]}
|
||||
{"input": "900", "output": ["900"]}
|
||||
{"input": "900", "output": ["2100"]}
|
||||
{"input": "900am", "output": ["900"]}
|
||||
{"input": "900pm", "output": ["2100"]}
|
||||
{"input": "915", "output": ["915"]}
|
||||
{"input": "915", "output": ["2115"]}
|
||||
{"input": "915am", "output": ["915"]}
|
||||
{"input": "915pm", "output": ["2115"]}
|
||||
{"input": "930", "output": ["930"]}
|
||||
{"input": "930", "output": ["2130"]}
|
||||
{"input": "930am", "output": ["930"]}
|
||||
{"input": "930pm", "output": ["2130"]}
|
||||
{"input": "945", "output": ["945"]}
|
||||
{"input": "945", "output": ["2145"]}
|
||||
{"input": "945am", "output": ["945"]}
|
||||
{"input": "945pm", "output": ["2145"]}
|
||||
{"input": "10", "output": ["1000"]}
|
||||
{"input": "10", "output": ["2200"]}
|
||||
{"input": "10pm", "output": ["2200"]}
|
||||
{"input": "10am", "output": ["1000"]}
|
||||
{"input": "10 o'clock am", "output": ["1000"]}
|
||||
{"input": "10 o'clock pm", "output": ["2200"]}
|
||||
{"input": "10 o'clock", "output": ["1000"]}
|
||||
{"input": "10 o'clock", "output": ["2200"]}
|
||||
{"input": "1000", "output": ["1000"]}
|
||||
{"input": "1000", "output": ["2200"]}
|
||||
{"input": "1000am", "output": ["1000"]}
|
||||
{"input": "1000pm", "output": ["2200"]}
|
||||
{"input": "1015", "output": ["1015"]}
|
||||
{"input": "1015", "output": ["2215"]}
|
||||
{"input": "1015am", "output": ["1015"]}
|
||||
{"input": "1015pm", "output": ["2215"]}
|
||||
{"input": "1030", "output": ["1030"]}
|
||||
{"input": "1030", "output": ["2230"]}
|
||||
{"input": "1030am", "output": ["1030"]}
|
||||
{"input": "1030pm", "output": ["2230"]}
|
||||
{"input": "1045", "output": ["1045"]}
|
||||
{"input": "1045", "output": ["2245"]}
|
||||
{"input": "1045am", "output": ["1045"]}
|
||||
{"input": "1045pm", "output": ["2245"]}
|
||||
{"input": "11", "output": ["1100"]}
|
||||
{"input": "11", "output": ["2300"]}
|
||||
{"input": "11pm", "output": ["2300"]}
|
||||
{"input": "11am", "output": ["1100"]}
|
||||
{"input": "11 o'clock am", "output": ["1100"]}
|
||||
{"input": "11 o'clock pm", "output": ["2300"]}
|
||||
{"input": "11 o'clock", "output": ["1100"]}
|
||||
{"input": "11 o'clock", "output": ["2300"]}
|
||||
{"input": "1100", "output": ["1100"]}
|
||||
{"input": "1100", "output": ["2300"]}
|
||||
{"input": "1100am", "output": ["1100"]}
|
||||
{"input": "1100pm", "output": ["2300"]}
|
||||
{"input": "1115", "output": ["1115"]}
|
||||
{"input": "1115", "output": ["2315"]}
|
||||
{"input": "1115am", "output": ["1115"]}
|
||||
{"input": "1115pm", "output": ["2315"]}
|
||||
{"input": "1130", "output": ["1130"]}
|
||||
{"input": "1130", "output": ["2330"]}
|
||||
{"input": "1130am", "output": ["1130"]}
|
||||
{"input": "1130pm", "output": ["2330"]}
|
||||
{"input": "1145", "output": ["1145"]}
|
||||
{"input": "1145", "output": ["2345"]}
|
||||
{"input": "1145am", "output": ["1145"]}
|
||||
{"input": "1145pm", "output": ["2345"]}
|
||||
{"input": "12", "output": ["1200"]}
|
||||
{"input": "1200", "output": ["1200"]}
|
||||
{"input": "1200", "output": ["2400"]}
|
||||
{"input": "1200am", "output": ["1200"]}
|
||||
{"input": "1200pm", "output": ["2400"]}
|
||||
{"input": "1215", "output": ["1215"]}
|
||||
{"input": "1215", "output": ["2415"]}
|
||||
{"input": "1215am", "output": ["1215"]}
|
||||
{"input": "1215pm", "output": ["2415"]}
|
||||
{"input": "1230", "output": ["1230"]}
|
||||
{"input": "1230", "output": ["2430"]}
|
||||
{"input": "1230am", "output": ["1230"]}
|
||||
{"input": "1230pm", "output": ["2430"]}
|
||||
{"input": "1245", "output": ["1245"]}
|
||||
{"input": "1245", "output": ["2445"]}
|
||||
{"input": "1245am", "output": ["1245"]}
|
||||
{"input": "1245pm", "output": ["2445"]}
|
||||
{"input": "13", "output": ["1300"]}
|
||||
{"input": "1300", "output": ["1300"]}
|
||||
{"input": "1315", "output": ["1315"]}
|
||||
{"input": "1330", "output": ["1330"]}
|
||||
{"input": "1345", "output": ["1345"]}
|
||||
{"input": "14", "output": ["1400"]}
|
||||
{"input": "1400", "output": ["1400"]}
|
||||
{"input": "1415", "output": ["1415"]}
|
||||
{"input": "1430", "output": ["1430"]}
|
||||
{"input": "1445", "output": ["1445"]}
|
||||
{"input": "15", "output": ["1500"]}
|
||||
{"input": "1500", "output": ["1500"]}
|
||||
{"input": "1515", "output": ["1515"]}
|
||||
{"input": "1530", "output": ["1530"]}
|
||||
{"input": "1545", "output": ["1545"]}
|
||||
{"input": "16", "output": ["1600"]}
|
||||
{"input": "1600", "output": ["1600"]}
|
||||
{"input": "1615", "output": ["1615"]}
|
||||
{"input": "1630", "output": ["1630"]}
|
||||
{"input": "1645", "output": ["1645"]}
|
||||
{"input": "17", "output": ["1700"]}
|
||||
{"input": "1700", "output": ["1700"]}
|
||||
{"input": "1715", "output": ["1715"]}
|
||||
{"input": "1730", "output": ["1730"]}
|
||||
{"input": "1745", "output": ["1745"]}
|
||||
{"input": "18", "output": ["1800"]}
|
||||
{"input": "1800", "output": ["1800"]}
|
||||
{"input": "1815", "output": ["1815"]}
|
||||
{"input": "1830", "output": ["1830"]}
|
||||
{"input": "1845", "output": ["1845"]}
|
||||
{"input": "19", "output": ["1900"]}
|
||||
{"input": "1900", "output": ["1900"]}
|
||||
{"input": "1915", "output": ["1915"]}
|
||||
{"input": "1930", "output": ["1930"]}
|
||||
{"input": "1945", "output": ["1945"]}
|
||||
{"input": "20", "output": ["2000"]}
|
||||
{"input": "2000", "output": ["2000"]}
|
||||
{"input": "2015", "output": ["2015"]}
|
||||
{"input": "2030", "output": ["2030"]}
|
||||
{"input": "2045", "output": ["2045"]}
|
||||
{"input": "21", "output": ["2100"]}
|
||||
{"input": "2100", "output": ["2100"]}
|
||||
{"input": "2115", "output": ["2115"]}
|
||||
{"input": "2130", "output": ["2130"]}
|
||||
{"input": "2145", "output": ["2145"]}
|
||||
{"input": "22", "output": ["2200"]}
|
||||
{"input": "2200", "output": ["2200"]}
|
||||
{"input": "2215", "output": ["2215"]}
|
||||
{"input": "2230", "output": ["2230"]}
|
||||
{"input": "2245", "output": ["2245"]}
|
||||
{"input": "23", "output": ["2300"]}
|
||||
{"input": "2300", "output": ["2300"]}
|
||||
{"input": "2315", "output": ["2315"]}
|
||||
{"input": "2330", "output": ["2330"]}
|
||||
{"input": "2345", "output": ["2345"]}
|
||||
{"input": "noon", "output": ["1200"]}
|
||||
{"input": "breakfast", "output": ["'BREAKFAST'"]}
|
||||
{"input": "lunch", "output": ["'LUNCH'"]}
|
||||
{"input": "dinner", "output": ["'DINNER'"]}
|
||||
{"input": "snack", "output": ["'SNACK'"]}
|
||||
{"input": "atl", "output": ["'ATL'"]}
|
||||
{"input": "bna", "output": ["'BNA'"]}
|
||||
{"input": "bos", "output": ["'BOS'"]}
|
||||
{"input": "bur", "output": ["'BUR'"]}
|
||||
{"input": "bwi", "output": ["'BWI'"]}
|
||||
{"input": "cle", "output": ["'CLE'"]}
|
||||
{"input": "clt", "output": ["'CLT'"]}
|
||||
{"input": "cmh", "output": ["'CMH'"]}
|
||||
{"input": "cvg", "output": ["'CVG'"]}
|
||||
{"input": "dal", "output": ["'DAL'"]}
|
||||
{"input": "love field", "output": ["'DAL'"]}
|
||||
{"input": "dca", "output": ["'DCA'"]}
|
||||
{"input": "den", "output": ["'DEN'"]}
|
||||
{"input": "det", "output": ["'DET'"]}
|
||||
{"input": "dfw", "output": ["'DFW'"]}
|
||||
{"input": "dtw", "output": ["'DTW'"]}
|
||||
{"input": "ewr", "output": ["'EWR'"]}
|
||||
{"input": "hou", "output": ["'HOU'"]}
|
||||
{"input": "hpn", "output": ["'HPN'"]}
|
||||
{"input": "iad", "output": ["'IAD'"]}
|
||||
{"input": "iah", "output": ["'IAH'"]}
|
||||
{"input": "houston international", "output": ["'IAH'"]}
|
||||
{"input": "houston international airport", "output": ["'IAH'"]}
|
||||
{"input": "houston intercontinental airport", "output": ["'IAH'"]}
|
||||
{"input": "ind", "output": ["'IND'"]}
|
||||
{"input": "jfk", "output": ["'JFK'"]}
|
||||
{"input": "las", "output": ["'LAS'"]}
|
||||
{"input": "lax", "output": ["'LAX'"]}
|
||||
{"input": "lga", "output": ["'LGA'"]}
|
||||
{"input": "la guardia", "output": ["'LGA'"]}
|
||||
{"input": "lgb", "output": ["'LGB'"]}
|
||||
{"input": "mci", "output": ["'MCI'"]}
|
||||
{"input": "mco", "output": ["'MCO'"]}
|
||||
{"input": "mdw", "output": ["'MDW'"]}
|
||||
{"input": "mem", "output": ["'MEM'"]}
|
||||
{"input": "mia", "output": ["'MIA'"]}
|
||||
{"input": "mke", "output": ["'MKE'"]}
|
||||
{"input": "msp", "output": ["'MSP'"]}
|
||||
{"input": "oak", "output": ["'OAK'"]}
|
||||
{"input": "ont", "output": ["'ONT'"]}
|
||||
{"input": "ord", "output": ["'ORD'"]}
|
||||
{"input": "phl", "output": ["'PHL'"]}
|
||||
{"input": "phx", "output": ["'PHX'"]}
|
||||
{"input": "pie", "output": ["'PIE'"]}
|
||||
{"input": "pit", "output": ["'PIT'"]}
|
||||
{"input": "san", "output": ["'SAN'"]}
|
||||
{"input": "sea", "output": ["'SEA'"]}
|
||||
{"input": "sfo", "output": ["'SFO'"]}
|
||||
{"input": "sjc", "output": ["'SJC'"]}
|
||||
{"input": "slc", "output": ["'SLC'"]}
|
||||
{"input": "stl", "output": ["'STL'"]}
|
||||
{"input": "tpa", "output": ["'TPA'"]}
|
||||
{"input": "ykz", "output": ["'YKZ'"]}
|
||||
{"input": "ymx", "output": ["'YMX'"]}
|
||||
{"input": "ytz", "output": ["'YTZ'"]}
|
||||
{"input": "yul", "output": ["'YUL'"]}
|
||||
{"input": "yyz", "output": ["'YYZ'"]}
|
||||
{"input": "ap/2", "output": ["'AP/2'"]}
|
||||
{"input": "ap/6", "output": ["'AP/6'"]}
|
||||
{"input": "ap/12", "output": ["'AP/12'"]}
|
||||
{"input": "ap/20", "output": ["'AP/20'"]}
|
||||
{"input": "ap/21", "output": ["'AP/21'"]}
|
||||
{"input": "ap/57", "output": ["'AP/57'"]}
|
||||
{"input": "ap/58", "output": ["'AP/58'"]}
|
||||
{"input": "ap/60", "output": ["'AP/60'"]}
|
||||
{"input": "ap/75", "output": ["'AP/75'"]}
|
||||
{"input": "ex/9", "output": ["'EX/9'"]}
|
||||
{"input": "ex/13", "output": ["'EX/13'"]}
|
||||
{"input": "ex/14", "output": ["'EX/14'"]}
|
||||
{"input": "ex/17", "output": ["'EX/17'"]}
|
||||
{"input": "ex/19", "output": ["'EX/19'"]}
|
||||
{"input": "flight", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
|
||||
{"input": "flights", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
|
||||
{"input": "fly", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
|
||||
{"input": "trip", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
|
||||
{"input": "cost", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "price", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "cheapest", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "expensive", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "fares", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "fare", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "rate", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "cost", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "prices", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "price", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "cheapest", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "expensive", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "fares", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "fare", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "rate", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "cheapest", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "least", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "cheapest one way", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "least expensive one way", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "most expensive one way", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "highest fare", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "highest one way fare", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "cheapest", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
|
||||
{"input": "round trip", "output": ["fare.round_trip_cost", "IS", "NOT", "NULL"]}
|
||||
{"input": "least", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
|
||||
{"input": "cheapest round trip", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
|
||||
{"input": "least expensive round trip", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
|
||||
{"input": "aircraft", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
|
||||
{"input": "airplane", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
|
||||
{"input": "plane", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
|
||||
{"input": "ground transportation", "output": ["DISTINCT", "ground_service.city_code", ",", "ground_service.airport_code", ",", "ground_service.transport_type", "FROM", "ground_service"]}
|
||||
{"input": "airline", "output": ["DISTINCT", "airline.airline_code", "FROM", "airline"]}
|
||||
{"input": "airlines", "output": ["DISTINCT", "airline.airline_code", "FROM", "airline"]}
|
||||
{"input": "meals", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "meal", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "breakfast", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "dinner", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "lunch", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "schedule", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
|
||||
{"input": "where", "output": ["DISTINCT", "airport.airport_code"]}
|
||||
{"input": "seating capacity", "output": ["DISTINCT", "aircraft.capacity", "FROM", "aircraft"]}
|
||||
{"input": "capacities", "output": ["DISTINCT", "aircraft.capacity", "FROM", "aircraft"]}
|
||||
{"input": "aircraft", "output": ["DISTINCT", "aircraft.basic_type", "FROM", "aircraft"]}
|
||||
{"input": "stop", "output": ["flight.stops"]}
|
||||
{"input": "schedules", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
|
||||
{"input": "time", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
|
||||
{"input": "times", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
|
||||
{"input": "flights and fares", "output": ["DISTINCT", "flight.flight_id", ",", "fare.fare_id", "FROM", "flight", ",", "flight_fare", ",", "fare"]}
|
||||
{"input": "fare code", "output": ["DISTINCT", "fare_basis.fare_basis_code", ",", "fare_basis.booking_class", ",", "fare_basis.class_type", ",", "fare_basis.premium", ",", "fare_basis.economy", ",", "fare_basis.discounted", ",", "fare_basis.night", ",", "fare_basis.season", ",", "fare_basis.basis_days", "FROM", "fare_basis"]}
|
||||
{"input": "information", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
|
||||
{"input": "tell me about", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
|
||||
{"input": "earliest", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "first", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "latest", "output": ["MAX", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "last", "output": ["MAX", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "shortest", "output": ["MIN", "(", "flight.time_elapsed", ")", "FROM", "flight"]}
|
||||
{"input": "how many", "output": ["count", "(", "*", ")"]}
|
||||
{"input": "number", "output": ["count", "(", "*", ")"]}
|
||||
{"input": "one way", "output": ["fare.round_trip_required", "=", "'NO'"]}
|
||||
{"input": "1 way", "output": ["fare.round_trip_required", "=", "'NO'"]}
|
||||
{"input": "cities", "output": ["DISTINCT", "city.city_code", "FROM", "city"]}
|
||||
{"input": "city", "output": ["DISTINCT", "city.city_code", "FROM", "city"]}
|
||||
{"input": "airport", "output": ["DISTINCT", "airport.airport_location"]}
|
||||
{"input": "economy", "output": ["'ECONOMY'"]}
|
||||
{"input": "monday", "output": ["date_day.month_number"]}
|
||||
{"input": "monday", "output": ["date_day.day_number"]}
|
||||
{"input": "tuesday", "output": ["date_day.month_number"]}
|
||||
{"input": "tuesday", "output": ["date_day.day_number"]}
|
||||
{"input": "wednesday", "output": ["date_day.month_number"]}
|
||||
{"input": "wednesday", "output": ["date_day.day_number"]}
|
||||
{"input": "thursday", "output": ["date_day.month_number"]}
|
||||
{"input": "thursday", "output": ["date_day.day_number"]}
|
||||
{"input": "friday", "output": ["date_day.month_number"]}
|
||||
{"input": "friday", "output": ["date_day.day_number"]}
|
||||
{"input": "saturday", "output": ["date_day.month_number"]}
|
||||
{"input": "saturday", "output": ["date_day.day_number"]}
|
||||
{"input": "sunday", "output": ["date_day.month_number"]}
|
||||
{"input": "sunday", "output": ["date_day.day_number"]}
|
||||
{"input": "class", "output": ["class_of_service.booking_class"]}
|
||||
{"input": "morning", "output": ["BETWEEN", "0", "AND", "1200"]}
|
||||
{"input": "what does", "output": ["DISTINCT", "airport.airport_name", "FROM", "airport"]}
|
||||
{"input": "what does", "output": ["DISTINCT", "class_of_service.class_description", "FROM", "class_of_service"]}
|
||||
{"input": "aircraft", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
|
||||
{"input": "meals", "output": ["food_service.meal_code"]}
|
||||
{"input": "flight", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
|
||||
{"input": "flights", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
|
||||
{"input": "fly", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
|
||||
{"input": "trip", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
|
||||
{"input": "cost", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "price", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "cheapest", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "expensive", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "fares", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "fare", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "rate", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
|
||||
{"input": "cost", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "prices", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "price", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "cheapest", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "expensive", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "fares", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "fare", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "rate", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
|
||||
{"input": "cheapest", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "least", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "cheapest one way", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "least expensive one way", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "most expensive one way", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "highest fare", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "highest one way fare", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
|
||||
{"input": "cheapest", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
|
||||
{"input": "round trip", "output": ["fare.round_trip_cost", "IS", "NOT", "NULL"]}
|
||||
{"input": "least", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
|
||||
{"input": "cheapest round trip", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
|
||||
{"input": "least expensive round trip", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
|
||||
{"input": "aircraft", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
|
||||
{"input": "airplane", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
|
||||
{"input": "plane", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
|
||||
{"input": "ground transportation", "output": ["DISTINCT", "ground_service.city_code", ",", "ground_service.airport_code", ",", "ground_service.transport_type", "FROM", "ground_service"]}
|
||||
{"input": "airline", "output": ["DISTINCT", "airline.airline_code", "FROM", "airline"]}
|
||||
{"input": "airlines", "output": ["DISTINCT", "airline.airline_code", "FROM", "airline"]}
|
||||
{"input": "meals", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "meal", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "breakfast", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "dinner", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "lunch", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
|
||||
{"input": "schedule", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
|
||||
{"input": "where", "output": ["DISTINCT", "airport.airport_code"]}
|
||||
{"input": "seating capacity", "output": ["DISTINCT", "aircraft.capacity", "FROM", "aircraft"]}
|
||||
{"input": "capacities", "output": ["DISTINCT", "aircraft.capacity", "FROM", "aircraft"]}
|
||||
{"input": "aircraft", "output": ["DISTINCT", "aircraft.basic_type", "FROM", "aircraft"]}
|
||||
{"input": "stop", "output": ["flight.stops"]}
|
||||
{"input": "schedules", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
|
||||
{"input": "time", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
|
||||
{"input": "times", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
|
||||
{"input": "flights and fares", "output": ["DISTINCT", "flight.flight_id", ",", "fare.fare_id", "FROM", "flight", ",", "flight_fare", ",", "fare"]}
|
||||
{"input": "fare code", "output": ["DISTINCT", "fare_basis.fare_basis_code", ",", "fare_basis.booking_class", ",", "fare_basis.class_type", ",", "fare_basis.premium", ",", "fare_basis.economy", ",", "fare_basis.discounted", ",", "fare_basis.night", ",", "fare_basis.season", ",", "fare_basis.basis_days", "FROM", "fare_basis"]}
|
||||
{"input": "information", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
|
||||
{"input": "tell me about", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
|
||||
{"input": "earliest", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "first", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "information", "output": ["DISTINCT", "flight.flight_id", ",", "flight", ".", "flight_days", ",", "flight", ".", "from_airport", ",", "flight", ".", "to_airport", ",", "flight", ".", "departure_time", ",", "flight", ".", "arrival_time", ",", "flight", ".", "airline_flight", ",", "flight", ".", "airline_code", ",", "flight", ".", "flight_number", ",", "flight", ".", "aircraft_code_sequence", ",", "flight", ".", "meal_code", ",", "flight", ".", "stops", ",", "flight", ".", "connections", ",", "flight", ".", "dual_carrier", ",", "flight", ".", "time_elapsed", "FROM", "flight"]}
|
||||
{"input": "tell me about", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
|
||||
{"input": "earliest", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "first", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "latest", "output": ["MAX", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "last", "output": ["MAX", "(", "flight.departure_time", ")", "FROM", "flight"]}
|
||||
{"input": "shortest", "output": ["MIN", "(", "flight.time_elapsed", ")", "FROM", "flight"]}
|
||||
{"input": "how many", "output": ["count", "(", "*", ")"]}
|
||||
{"input": "number", "output": ["count", "(", "*", ")"]}
|
||||
{"input": "one way", "output": ["fare.round_trip_required", "=", "'NO'"]}
|
||||
{"input": "1 way", "output": ["fare.round_trip_required", "=", "'NO'"]}
|
||||
{"input": "cities", "output": ["DISTINCT", "city.city_code", "FROM", "city"]}
|
||||
{"input": "city", "output": ["DISTINCT", "city.city_code", "FROM", "city"]}
|
||||
{"input": "airport", "output": ["DISTINCT", "airport.airport_location"]}
|
||||
{"input": "economy", "output": ["'ECONOMY'"]}
|
||||
{"input": "monday", "output": ["date_day.month_number"]}
|
||||
{"input": "monday", "output": ["date_day.day_number"]}
|
||||
{"input": "tuesday", "output": ["date_day.month_number"]}
|
||||
{"input": "tuesday", "output": ["date_day.day_number"]}
|
||||
{"input": "wednesday", "output": ["date_day.month_number"]}
|
||||
{"input": "wednesday", "output": ["date_day.day_number"]}
|
||||
{"input": "thursday", "output": ["date_day.month_number"]}
|
||||
{"input": "thursday", "output": ["date_day.day_number"]}
|
||||
{"input": "friday", "output": ["date_day.month_number"]}
|
||||
{"input": "friday", "output": ["date_day.day_number"]}
|
||||
{"input": "saturday", "output": ["date_day.month_number"]}
|
||||
{"input": "saturday", "output": ["date_day.day_number"]}
|
||||
{"input": "sunday", "output": ["date_day.month_number"]}
|
||||
{"input": "sunday", "output": ["date_day.day_number"]}
|
||||
{"input": "class", "output": ["class_of_service.booking_class"]}
|
||||
{"input": "morning", "output": ["BETWEEN", "0", "AND", "1200"]}
|
||||
{"input": "what does", "output": ["DISTINCT", "airport.airport_name", "FROM", "airport"]}
|
||||
{"input": "what does", "output": ["DISTINCT", "class_of_service.class_description", "FROM", "class_of_service"]}
|
||||
{"input": "aircraft", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
|
||||
{"input": "meals", "output": ["food_service.meal_code"]}
|
|
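The JSON-lines entries above map natural-language phrases from ATIS-style utterances to SQL tokens or literal values: ordinals, cabin classes and fare codes, states, clock times in 24-hour form, airport codes, and column/aggregate templates such as `MIN ( fare.one_direction_cost )`. A phrase may have several candidate outputs (e.g. "cheapest"). The sketch below shows one way such a file could be loaded into a lookup table; the file path and function name are illustrative assumptions, not part of the repository.

```python
import json
from collections import defaultdict

def load_phrase_map(path="data/atis_nl_to_sql.json"):  # hypothetical path
    """Read {"input": ..., "output": [...]} JSON lines into a
    phrase -> list-of-candidate-token-sequences lookup."""
    phrase_map = defaultdict(list)
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            phrase_map[entry["input"]].append(entry["output"])
    return phrase_map

# Example lookups, based on the entries above:
#   phrase_map["8pm"]      -> [["2000"]]
#   phrase_map["cheapest"] -> several candidates (ground fares, fare ids, MIN(...))
```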
@@ -0,0 +1,262 @@
|
|||
"""Code for identifying and anonymizing entities in NL and SQL."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from . import util
|
||||
|
||||
ENTITY_NAME = "ENTITY"
|
||||
CONSTANT_NAME = "CONSTANT"
|
||||
TIME_NAME = "TIME"
|
||||
SEPARATOR = "#"
|
||||
|
||||
|
||||
def timeval(string):
|
||||
"""Returns the numeric version of a time.
|
||||
|
||||
Inputs:
|
||||
string (str): String representing a time.
|
||||
|
||||
Returns:
|
||||
String representing the absolute time.
|
||||
"""
|
||||
if string.endswith("am") or string.endswith(
|
||||
"pm") and string[:-2].isdigit():
|
||||
numval = int(string[:-2])
|
||||
if len(string) == 3 or len(string) == 4:
|
||||
numval *= 100
|
||||
if string.endswith("pm"):
|
||||
numval += 1200
|
||||
return str(numval)
|
||||
return ""
|
||||
|
||||
|
||||
def is_time(string):
|
||||
"""Returns whether a string represents a time.
|
||||
|
||||
Inputs:
|
||||
string (str): String to check.
|
||||
|
||||
Returns:
|
||||
Whether the string represents a time.
|
||||
"""
|
||||
if string.endswith("am") or string.endswith("pm"):
|
||||
if string[:-2].isdigit():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
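# Worked examples for the two helpers above (illustrative sketch):
#   is_time("845pm")     -> True
#   timeval("845pm")     -> "2045"   (845 + 1200 for pm)
#   timeval("8am")       -> "800"    (single-digit hour scaled to hundreds)
#   is_time("8 o'clock") -> False    (no am/pm suffix)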
def deanonymize(sequence, ent_dict, key):
|
||||
"""Deanonymizes a sequence.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): List of tokens to deanonymize.
|
||||
ent_dict (dict str->(dict str->str)): Maps from tokens to the entity dictionary.
|
||||
key (str): The key to use, in this case either natural language or SQL.
|
||||
|
||||
Returns:
|
||||
Deanonymized sequence of tokens.
|
||||
"""
|
||||
new_sequence = []
|
||||
for token in sequence:
|
||||
if token in ent_dict:
|
||||
new_sequence.extend(ent_dict[token][key])
|
||||
else:
|
||||
new_sequence.append(token)
|
||||
|
||||
return new_sequence
|
||||
|
||||
|
||||
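# Worked example for deanonymize (illustrative; the "nl"/"sql" key names are an
# assumption -- the real keys come from the anonymization file):
#   ent_dict = {"ENTITY#0": {"nl": ["boston"], "sql": ["'BOSTON'"]}}
#   deanonymize(["from", "ENTITY#0"], ent_dict, "sql") -> ["from", "'BOSTON'"]
#   deanonymize(["from", "ENTITY#0"], ent_dict, "nl")  -> ["from", "boston"]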
class Anonymizer:
|
||||
"""Anonymization class for keeping track of entities in this domain and
|
||||
scripts for anonymizing/deanonymizing.
|
||||
|
||||
Members:
|
||||
anonymization_map (list of dict (str->str)): Entity mappings loaded from
|
||||
the anonymization file.
|
||||
entity_types (list of str): All entities in the anonymization file.
|
||||
keys (set of str): Possible keys (types of text handled); in this case it should be
|
||||
one for natural language and another for SQL.
|
||||
entity_set (set of str): entity_types as a set.
|
||||
"""
|
||||
def __init__(self, filename):
|
||||
self.anonymization_map = []
|
||||
self.entity_types = []
|
||||
self.keys = set()
|
||||
|
||||
pairs = [json.loads(line) for line in open(filename).readlines()]
|
||||
for pair in pairs:
|
||||
for key in pair:
|
||||
if key != "type":
|
||||
self.keys.add(key)
|
||||
self.anonymization_map.append(pair)
|
||||
if pair["type"] not in self.entity_types:
|
||||
self.entity_types.append(pair["type"])
|
||||
|
||||
self.entity_types.append(ENTITY_NAME)
|
||||
self.entity_types.append(CONSTANT_NAME)
|
||||
self.entity_types.append(TIME_NAME)
|
||||
|
||||
self.entity_set = set(self.entity_types)
|
||||
|
||||
def get_entity_type_from_token(self, token):
|
||||
"""Gets the type of an entity given an anonymized token.
|
||||
|
||||
Inputs:
|
||||
token (str): The entity token.
|
||||
|
||||
Returns:
|
||||
str, representing the type of the entity.
|
||||
"""
|
||||
# Anonymized tokens have the form TYPE + SEPARATOR + index (e.g. "ENTITY#0"),
|
||||
# so the entity type is everything before the separator.
|
||||
colon_loc = token.index(SEPARATOR)
|
||||
entity_type = token[:colon_loc]
|
||||
assert entity_type in self.entity_set
|
||||
|
||||
return entity_type
|
||||
|
||||
def is_anon_tok(self, token):
|
||||
"""Returns whether a token is an anonymized token or not.
|
||||
|
||||
Input:
|
||||
token (str): The token to check.
|
||||
|
||||
Returns:
|
||||
bool, whether the token is an anonymized token.
|
||||
"""
|
||||
return token.split(SEPARATOR)[0] in self.entity_set
|
||||
|
||||
def get_anon_id(self, token):
|
||||
"""Gets the entity index (unique ID) for a token.
|
||||
|
||||
Input:
|
||||
token (str): The token to get the index from.
|
||||
|
||||
Returns:
|
||||
int, the token ID if it is an anonymized token; otherwise -1.
|
||||
"""
|
||||
if self.is_anon_tok(token):
|
||||
return self.entity_types.index(token.split(SEPARATOR)[0])
|
||||
else:
|
||||
return -1
|
||||
|
||||
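# Illustrative checks (sketch; "ENTITY" is one of the built-in entity types above):
#   anonymizer.is_anon_tok("ENTITY#0")  -> True
#   anonymizer.is_anon_tok("boston")    -> False  (assuming "boston" is not a type)
#   anonymizer.get_anon_id("ENTITY#0")  -> the index of "ENTITY" in entity_types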
def anonymize(self,
|
||||
sequence,
|
||||
tok_to_entity_dict,
|
||||
key,
|
||||
add_new_anon_toks=False):
|
||||
"""Anonymizes a sequence.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): Sequence to anonymize.
|
||||
tok_to_entity_dict (dict): Existing dictionary mapping from anonymized
|
||||
tokens to entities.
|
||||
key (str): Which kind of text this is (natural language or SQL)
|
||||
add_new_anon_toks (bool): Whether to add new entities to tok_to_entity_dict.
|
||||
|
||||
Returns:
|
||||
list of str, the anonymized sequence.
|
||||
"""
|
||||
# Sort the tok-to-entity dict entries by the length of this key's modality, longest first.
|
||||
sorted_dict = sorted(tok_to_entity_dict.items(),
|
||||
key=lambda k: len(k[1][key]))[::-1]
|
||||
|
||||
anonymized_sequence = copy.deepcopy(sequence)
|
||||
|
||||
if add_new_anon_toks:
|
||||
type_counts = {}
|
||||
for entity_type in self.entity_types:
|
||||
type_counts[entity_type] = 0
|
||||
for token in tok_to_entity_dict:
|
||||
entity_type = self.get_entity_type_from_token(token)
|
||||
type_counts[entity_type] += 1
|
||||
|
||||
# First find occurrences of things in the anonymization dictionary.
|
||||
for token, modalities in sorted_dict:
|
||||
our_modality = modalities[key]
|
||||
|
||||
# Check if this key's version of the anonymized thing is in our
|
||||
# sequence.
|
||||
while util.subsequence(our_modality, anonymized_sequence):
|
||||
found = False
|
||||
for startidx in range(
|
||||
len(anonymized_sequence) - len(our_modality) + 1):
|
||||
if anonymized_sequence[startidx:startidx +
|
||||
len(our_modality)] == our_modality:
|
||||
anonymized_sequence = anonymized_sequence[:startidx] + [
|
||||
token] + anonymized_sequence[startidx + len(our_modality):]
|
||||
found = True
|
||||
break
|
||||
assert found, "Thought " \
|
||||
+ str(our_modality) + " was in [" \
|
||||
+ str(anonymized_sequence) + "] but could not find it"
|
||||
|
||||
# Now add new keys if they are present.
|
||||
if add_new_anon_toks:
|
||||
|
||||
# For every span in the sequence, check whether it is in the anon map
|
||||
# for this modality
|
||||
sorted_anon_map = sorted(self.anonymization_map,
|
||||
key=lambda k: len(k[key]))[::-1]
|
||||
|
||||
for pair in sorted_anon_map:
|
||||
our_modality = pair[key]
|
||||
|
||||
token_type = pair["type"]
|
||||
new_token = token_type + SEPARATOR + \
|
||||
str(type_counts[token_type])
|
||||
|
||||
while util.subsequence(our_modality, anonymized_sequence):
|
||||
found = False
|
||||
for startidx in range(
|
||||
len(anonymized_sequence) - len(our_modality) + 1):
|
||||
if anonymized_sequence[startidx:startidx + \
|
||||
len(our_modality)] == our_modality:
|
||||
if new_token not in tok_to_entity_dict:
|
||||
type_counts[token_type] += 1
|
||||
tok_to_entity_dict[new_token] = pair
|
||||
|
||||
anonymized_sequence = anonymized_sequence[:startidx] + [
|
||||
new_token] + anonymized_sequence[startidx + len(our_modality):]
|
||||
found = True
|
||||
break
|
||||
assert found, "Thought " \
|
||||
+ str(our_modality) + " was in [" \
|
||||
+ str(anonymized_sequence) + "] but could not find it"
|
||||
|
||||
# Also replace integers with constants
|
||||
for index, token in enumerate(anonymized_sequence):
|
||||
if token.isdigit() or is_time(token):
|
||||
if token.isdigit():
|
||||
entity_type = CONSTANT_NAME
|
||||
value = token
|
||||
if is_time(token):
|
||||
entity_type = TIME_NAME
|
||||
value = timeval(token)
|
||||
|
||||
# First try to find the constant in the entity dictionary already,
|
||||
# and get the name if it's found.
|
||||
new_token = ""
|
||||
new_dict = {}
|
||||
found = False
|
||||
for entity, value in tok_to_entity_dict.items():
|
||||
if value[key][0] == token:
|
||||
new_token = entity
|
||||
new_dict = value
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
new_token = entity_type + SEPARATOR + \
|
||||
str(type_counts[entity_type])
|
||||
new_dict = {}
|
||||
for tempkey in self.keys:
|
||||
new_dict[tempkey] = [token]
|
||||
|
||||
tok_to_entity_dict[new_token] = new_dict
|
||||
type_counts[entity_type] += 1
|
||||
|
||||
anonymized_sequence[index] = new_token
|
||||
|
||||
return anonymized_sequence
|
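# Usage sketch for the class above (the file name and the "nl" key are assumptions
# used only for illustration):
#   anon = Anonymizer("data/anonymization.json")
#   tok_to_entity = {}
#   anon_seq = anon.anonymize(["flights", "at", "845pm"], tok_to_entity,
#                             key="nl", add_new_anon_toks=True)
#   # "845pm" becomes a TIME#<i> placeholder recorded in tok_to_entity;
#   # deanonymize(anon_seq, tok_to_entity, key) later restores the stored tokens.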
|
@@ -0,0 +1,350 @@
|
|||
# TODO: review this entire file and make it much simpler.
|
||||
|
||||
import copy
|
||||
from . import snippets as snip
|
||||
from . import sql_util
|
||||
from . import vocabulary as vocab
|
||||
|
||||
|
||||
class UtteranceItem():
|
||||
def __init__(self, interaction, index):
|
||||
self.interaction = interaction
|
||||
self.utterance_index = index
|
||||
|
||||
def __str__(self):
|
||||
return str(self.interaction.utterances[self.utterance_index])
|
||||
|
||||
def histories(self, maximum):
|
||||
if maximum > 0:
|
||||
history_seqs = []
|
||||
for utterance in self.interaction.utterances[:self.utterance_index]:
|
||||
history_seqs.append(utterance.input_seq_to_use)
|
||||
|
||||
if len(history_seqs) > maximum:
|
||||
history_seqs = history_seqs[-maximum:]
|
||||
|
||||
return history_seqs
|
||||
return []
|
||||
|
||||
def input_sequence(self):
|
||||
return self.interaction.utterances[self.utterance_index].input_seq_to_use
|
||||
|
||||
def previous_query(self):
|
||||
if self.utterance_index == 0:
|
||||
return []
|
||||
return self.interaction.utterances[self.utterance_index -
|
||||
1].anonymized_gold_query
|
||||
|
||||
def anonymized_gold_query(self):
|
||||
return self.interaction.utterances[self.utterance_index].anonymized_gold_query
|
||||
|
||||
def snippets(self):
|
||||
return self.interaction.utterances[self.utterance_index].available_snippets
|
||||
|
||||
def original_gold_query(self):
|
||||
return self.interaction.utterances[self.utterance_index].original_gold_query
|
||||
|
||||
def contained_entities(self):
|
||||
return self.interaction.utterances[self.utterance_index].contained_entities
|
||||
|
||||
def original_gold_queries(self):
|
||||
return [
|
||||
q[0] for q in self.interaction.utterances[self.utterance_index].all_gold_queries]
|
||||
|
||||
def gold_tables(self):
|
||||
return [
|
||||
q[1] for q in self.interaction.utterances[self.utterance_index].all_gold_queries]
|
||||
|
||||
def gold_query(self):
|
||||
return self.interaction.utterances[self.utterance_index].gold_query_to_use + [
|
||||
vocab.EOS_TOK]
|
||||
|
||||
def gold_edit_sequence(self):
|
||||
return self.interaction.utterances[self.utterance_index].gold_edit_sequence
|
||||
|
||||
def gold_table(self):
|
||||
return self.interaction.utterances[self.utterance_index].gold_sql_results
|
||||
|
||||
def all_snippets(self):
|
||||
return self.interaction.snippets
|
||||
|
||||
def within_limits(self,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf')):
|
||||
return self.interaction.utterances[self.utterance_index].length_valid(
|
||||
max_input_length, max_output_length)
|
||||
|
||||
def expand_snippets(self, sequence):
|
||||
# Remove the EOS
|
||||
if sequence[-1] == vocab.EOS_TOK:
|
||||
sequence = sequence[:-1]
|
||||
|
||||
# First remove the snippets
|
||||
no_snippets_sequence = self.interaction.expand_snippets(sequence)
|
||||
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
|
||||
return no_snippets_sequence
|
||||
|
||||
def flatten_sequence(self, sequence):
|
||||
# Remove the EOS
|
||||
if sequence[-1] == vocab.EOS_TOK:
|
||||
sequence = sequence[:-1]
|
||||
|
||||
# First remove the snippets
|
||||
no_snippets_sequence = self.interaction.expand_snippets(sequence)
|
||||
|
||||
# Deanonymize
|
||||
deanon_sequence = self.interaction.deanonymize(
|
||||
no_snippets_sequence, "sql")
|
||||
return deanon_sequence
|
||||
|
||||
|
||||
class UtteranceBatch():
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def start(self):
|
||||
self.index = 0
|
||||
|
||||
def next(self):
|
||||
item = self.items[self.index]
|
||||
self.index += 1
|
||||
return item
|
||||
|
||||
def done(self):
|
||||
return self.index >= len(self.items)
|
||||
|
||||
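# Illustrative iteration over an UtteranceBatch (sketch):
#   batch = UtteranceBatch(items)
#   batch.start()
#   while not batch.done():
#       item = batch.next()
#       ...  # consume the UtteranceItem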
class PredUtteranceItem():
|
||||
def __init__(self,
|
||||
input_sequence,
|
||||
interaction_item,
|
||||
previous_query,
|
||||
index,
|
||||
available_snippets):
|
||||
self.input_seq_to_use = input_sequence
|
||||
self.interaction_item = interaction_item
|
||||
self.index = index
|
||||
self.available_snippets = available_snippets
|
||||
self.prev_pred_query = previous_query
|
||||
|
||||
def input_sequence(self):
|
||||
return self.input_seq_to_use
|
||||
|
||||
def histories(self, maximum):
|
||||
if maximum == 0:
|
||||
return []
|
||||
histories = []
|
||||
for utterance in self.interaction_item.processed_utterances[:self.index]:
|
||||
histories.append(utterance.input_sequence())
|
||||
if len(histories) > maximum:
|
||||
histories = histories[-maximum:]
|
||||
return histories
|
||||
|
||||
def snippets(self):
|
||||
return self.available_snippets
|
||||
|
||||
def previous_query(self):
|
||||
return self.prev_pred_query
|
||||
|
||||
def flatten_sequence(self, sequence):
|
||||
return self.interaction_item.flatten_sequence(sequence)
|
||||
|
||||
def remove_snippets(self, sequence):
|
||||
return sql_util.fix_parentheses(
|
||||
self.interaction_item.expand_snippets(sequence))
|
||||
|
||||
def set_predicted_query(self, query):
|
||||
self.anonymized_pred_query = query
|
||||
|
||||
# Mocks an Interaction item, but allows for the parameters to be updated during
|
||||
# the process
|
||||
|
||||
|
||||
class InteractionItem():
|
||||
def __init__(self,
|
||||
interaction,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf'),
|
||||
nl_to_sql_dict={},
|
||||
maximum_length=float('inf')):
|
||||
if maximum_length != float('inf'):
|
||||
self.interaction = copy.deepcopy(interaction)
|
||||
self.interaction.utterances = self.interaction.utterances[:maximum_length]
|
||||
else:
|
||||
self.interaction = interaction
|
||||
self.processed_utterances = []
|
||||
self.snippet_bank = []
|
||||
self.identifier = self.interaction.identifier
|
||||
|
||||
self.max_input_length = max_input_length
|
||||
self.max_output_length = max_output_length
|
||||
|
||||
self.nl_to_sql_dict = nl_to_sql_dict
|
||||
|
||||
self.index = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self.interaction)
|
||||
|
||||
def __str__(self):
|
||||
s = "Utterances, gold queries, and predictions:\n"
|
||||
for i, utterance in enumerate(self.interaction.utterances):
|
||||
s += " ".join(utterance.input_seq_to_use) + "\n"
|
||||
pred_utterance = self.processed_utterances[i]
|
||||
s += " ".join(pred_utterance.gold_query()) + "\n"
|
||||
s += " ".join(pred_utterance.anonymized_query()) + "\n"
|
||||
s += "\n"
|
||||
s += "Snippets:\n"
|
||||
for snippet in self.snippet_bank:
|
||||
s += str(snippet) + "\n"
|
||||
|
||||
return s
|
||||
|
||||
def start_interaction(self):
|
||||
assert len(self.snippet_bank) == 0
|
||||
assert len(self.processed_utterances) == 0
|
||||
assert self.index == 0
|
||||
|
||||
def next_utterance(self):
|
||||
utterance = self.interaction.utterances[self.index]
|
||||
self.index += 1
|
||||
|
||||
available_snippets = self.available_snippets(snippet_keep_age=1)
|
||||
|
||||
return PredUtteranceItem(utterance.input_seq_to_use,
|
||||
self,
|
||||
self.processed_utterances[-1].anonymized_pred_query if len(self.processed_utterances) > 0 else [],
|
||||
self.index - 1,
|
||||
available_snippets)
|
||||
|
||||
def done(self):
|
||||
return len(self.processed_utterances) == len(self.interaction)
|
||||
|
||||
def finish(self):
|
||||
self.snippet_bank = []
|
||||
self.processed_utterances = []
|
||||
self.index = 0
|
||||
|
||||
def utterance_within_limits(self, utterance_item):
|
||||
return utterance_item.within_limits(self.max_input_length,
|
||||
self.max_output_length)
|
||||
|
||||
def available_snippets(self, snippet_keep_age):
|
||||
return [
|
||||
snippet for snippet in self.snippet_bank if snippet.index <= snippet_keep_age]
|
||||
|
||||
def gold_utterances(self):
|
||||
utterances = []
|
||||
for i, utterance in enumerate(self.interaction.utterances):
|
||||
utterances.append(UtteranceItem(self.interaction, i))
|
||||
return utterances
|
||||
|
||||
def get_schema(self):
|
||||
return self.interaction.schema
|
||||
|
||||
def add_utterance(
|
||||
self,
|
||||
utterance,
|
||||
predicted_sequence,
|
||||
snippets=None,
|
||||
previous_snippets=[],
|
||||
simple=False):
|
||||
if not snippets:
|
||||
self.add_snippets(
|
||||
predicted_sequence,
|
||||
previous_snippets=previous_snippets, simple=simple)
|
||||
else:
|
||||
for snippet in snippets:
|
||||
snippet.assign_id(len(self.snippet_bank))
|
||||
self.snippet_bank.append(snippet)
|
||||
|
||||
for snippet in self.snippet_bank:
|
||||
snippet.increase_age()
|
||||
self.processed_utterances.append(utterance)
|
||||
|
||||
def add_snippets(self, sequence, previous_snippets=[], simple=False):
|
||||
if sequence:
|
||||
if simple:
|
||||
snippets = sql_util.get_subtrees_simple(
|
||||
sequence, oldsnippets=previous_snippets)
|
||||
else:
|
||||
snippets = sql_util.get_subtrees(
|
||||
sequence, oldsnippets=previous_snippets)
|
||||
for snippet in snippets:
|
||||
snippet.assign_id(len(self.snippet_bank))
|
||||
self.snippet_bank.append(snippet)
|
||||
|
||||
for snippet in self.snippet_bank:
|
||||
snippet.increase_age()
|
||||
|
||||
def expand_snippets(self, sequence):
|
||||
return sql_util.fix_parentheses(
|
||||
snip.expand_snippets(
|
||||
sequence, self.snippet_bank))
|
||||
|
||||
def remove_snippets(self, sequence):
|
||||
if sequence[-1] == vocab.EOS_TOK:
|
||||
sequence = sequence[:-1]
|
||||
|
||||
no_snippets_sequence = self.expand_snippets(sequence)
|
||||
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
|
||||
return no_snippets_sequence
|
||||
|
||||
def flatten_sequence(self, sequence, gold_snippets=False):
|
||||
if sequence[-1] == vocab.EOS_TOK:
|
||||
sequence = sequence[:-1]
|
||||
|
||||
if gold_snippets:
|
||||
no_snippets_sequence = self.interaction.expand_snippets(sequence)
|
||||
else:
|
||||
no_snippets_sequence = self.expand_snippets(sequence)
|
||||
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
|
||||
|
||||
deanon_sequence = self.interaction.deanonymize(
|
||||
no_snippets_sequence, "sql")
|
||||
return deanon_sequence
|
||||
|
||||
def gold_query(self, index):
|
||||
return self.interaction.utterances[index].gold_query_to_use + [
|
||||
vocab.EOS_TOK]
|
||||
|
||||
def original_gold_query(self, index):
|
||||
return self.interaction.utterances[index].original_gold_query
|
||||
|
||||
def gold_table(self, index):
|
||||
return self.interaction.utterances[index].gold_sql_results
|
||||
|
||||
|
||||
class InteractionBatch():
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def start(self):
|
||||
self.timestep = 0
|
||||
self.current_interactions = []
|
||||
|
||||
def get_next_utterance_batch(self, snippet_keep_age, use_gold=False):
|
||||
items = []
|
||||
self.current_interactions = []
|
||||
for interaction in self.items:
|
||||
if self.timestep < len(interaction):
|
||||
utterance_item = interaction.original_utterances(
|
||||
snippet_keep_age, use_gold)[self.timestep]
|
||||
self.current_interactions.append(interaction)
|
||||
items.append(utterance_item)
|
||||
|
||||
self.timestep += 1
|
||||
return UtteranceBatch(items)
|
||||
|
||||
def done(self):
|
||||
finished = True
|
||||
for interaction in self.items:
|
||||
if self.timestep < len(interaction):
|
||||
finished = False
|
||||
return finished
|
||||
return finished
|
|
@@ -0,0 +1,376 @@
|
|||
""" Utility functions for loading and processing ATIS data."""
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
|
||||
from . import anonymization as anon
|
||||
from . import atis_batch
|
||||
from . import dataset_split as ds
|
||||
|
||||
from .interaction import load_function
|
||||
from .entities import NLtoSQLDict
|
||||
from .atis_vocab import ATISVocabulary
|
||||
|
||||
ENTITIES_FILENAME = './entities.txt'
|
||||
ANONYMIZATION_FILENAME = 'data/anonymization.txt'
|
||||
|
||||
class ATISDataset():
|
||||
""" Contains the ATIS data. """
|
||||
def __init__(self, params):
|
||||
self.anonymizer = None
|
||||
if params.anonymize:
|
||||
self.anonymizer = anon.Anonymizer(ANONYMIZATION_FILENAME)
|
||||
|
||||
if not os.path.exists(params.data_directory):
|
||||
os.mkdir(params.data_directory)
|
||||
|
||||
self.entities_dictionary = NLtoSQLDict(ENTITIES_FILENAME)
|
||||
|
||||
database_schema = None
|
||||
if params.database_schema_filename:
|
||||
if 'removefrom' not in params.data_directory:
|
||||
database_schema, column_names_surface_form, column_names_embedder_input = self.read_database_schema_simple(params.database_schema_filename)
|
||||
else:
|
||||
database_schema, column_names_surface_form, column_names_embedder_input = self.read_database_schema(params.database_schema_filename)
|
||||
|
||||
int_load_function = load_function(params,
|
||||
self.entities_dictionary,
|
||||
self.anonymizer,
|
||||
database_schema=database_schema)
|
||||
|
||||
def collapse_list(the_list):
|
||||
""" Collapses a list of list into a single list."""
|
||||
return [s for i in the_list for s in i]
|
||||
|
||||
if 'atis' not in params.data_directory:
|
||||
self.train_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_train_filename),
|
||||
params.raw_train_filename,
|
||||
int_load_function)
|
||||
self.valid_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_validation_filename),
|
||||
params.raw_validation_filename,
|
||||
int_load_function)
|
||||
|
||||
train_input_seqs = collapse_list(self.train_data.get_ex_properties(lambda i: i.input_seqs()))
|
||||
valid_input_seqs = collapse_list(self.valid_data.get_ex_properties(lambda i: i.input_seqs()))
|
||||
all_input_seqs = train_input_seqs + valid_input_seqs
|
||||
|
||||
self.input_vocabulary = ATISVocabulary(
|
||||
all_input_seqs,
|
||||
os.path.join(params.data_directory, params.input_vocabulary_filename),
|
||||
params,
|
||||
is_input='input',
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None)
|
||||
|
||||
self.output_vocabulary_schema = ATISVocabulary(
|
||||
column_names_embedder_input,
|
||||
os.path.join(params.data_directory, 'schema_'+params.output_vocabulary_filename),
|
||||
params,
|
||||
is_input='schema',
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None)
|
||||
|
||||
#train_output_seqs = collapse_list(self.train_data.get_ex_properties(lambda i: i.output_seqs()))
|
||||
#valid_output_seqs = collapse_list(self.valid_data.get_ex_properties(lambda i: i.output_seqs()))
|
||||
#all_output_seqs = train_output_seqs + valid_output_seqs
|
||||
|
||||
sql_keywords = ['.', 't1', 't2', '=', 'select', 'as', 'join', 'on', ')', '(', 'where', 't3', 'by', ',', 'group', 'distinct', 't4', 'and', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 't5', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!', 'union', 'between', 't6', '-', 't7', '+', '/']
|
||||
sql_keywords += ['count', 'from', 'value', 'order']
|
||||
sql_keywords += ['group_by', 'order_by', 'limit_value', '!=']
|
||||
|
||||
# skip column_names_surface_form but keep sql_keywords
|
||||
skip_tokens = list(set(column_names_surface_form) - set(sql_keywords))
|
||||
#print(len(all_output_seqs))
|
||||
all_output_seqs = []
|
||||
out_vocab_ordered = ['select', 'value', ')', '(', 'where', '=', ',', 'count', 'group_by', 'order_by', 'limit_value', 'desc', '>', 'distinct', 'and', 'avg', 'having', '<', 'in', 'sum', 'max', 'asc', 'not', 'or', 'like', 'min', 'intersect', 'except', '!=', 'union', 'between', '-', '+']
|
||||
#if params.data_directory == 'processed_data_sparc_removefrom_test':
|
||||
#out_vocab_ordered = ['select', 'value', ')', '(', 'where', '=', ',', 'count', 'group_by', 'order_by', 'limit_value', 'desc', '>', 'distinct', 'avg', 'and', 'having', '<', 'in', 'max', 'sum', 'asc', 'like', 'not', 'or', 'min', 'intersect', 'except', '!=', 'union', 'between', '-', '+']
|
||||
for i in range(len(out_vocab_ordered)):
|
||||
all_output_seqs.append(out_vocab_ordered[:i+1])
|
||||
|
||||
self.output_vocabulary = ATISVocabulary(
|
||||
all_output_seqs,
|
||||
os.path.join(params.data_directory, params.output_vocabulary_filename),
|
||||
params,
|
||||
is_input='output',
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None,
|
||||
skip=skip_tokens)
|
||||
else:
|
||||
self.train_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_train_filename),
|
||||
params.raw_train_filename,
|
||||
int_load_function)
|
||||
if params.train:
|
||||
self.valid_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_validation_filename),
|
||||
params.raw_validation_filename,
|
||||
int_load_function)
|
||||
if params.evaluate or params.attention:
|
||||
self.dev_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_dev_filename),
|
||||
params.raw_dev_filename,
|
||||
int_load_function)
|
||||
if params.enable_testing:
|
||||
self.test_data = ds.DatasetSplit(
|
||||
os.path.join(params.data_directory, params.processed_test_filename),
|
||||
params.raw_test_filename,
|
||||
int_load_function)
|
||||
|
||||
train_input_seqs = []
|
||||
train_input_seqs = collapse_list(
|
||||
self.train_data.get_ex_properties(
|
||||
lambda i: i.input_seqs()))
|
||||
|
||||
self.input_vocabulary = ATISVocabulary(
|
||||
train_input_seqs,
|
||||
os.path.join(params.data_directory, params.input_vocabulary_filename),
|
||||
params,
|
||||
is_input='input',
|
||||
min_occur=2,
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None)
|
||||
|
||||
train_output_seqs = collapse_list(
|
||||
self.train_data.get_ex_properties(
|
||||
lambda i: i.output_seqs()))
|
||||
|
||||
self.output_vocabulary = ATISVocabulary(
|
||||
train_output_seqs,
|
||||
os.path.join(params.data_directory, params.output_vocabulary_filename),
|
||||
params,
|
||||
is_input='output',
|
||||
anonymizer=self.anonymizer if params.anonymization_scoring else None)
|
||||
|
||||
self.output_vocabulary_schema = None
|
||||
|
||||
def read_database_schema_simple(self, database_schema_filename):
|
||||
with open(database_schema_filename, "r") as f:
|
||||
database_schema = json.load(f)
|
||||
|
||||
database_schema_dict = {}
|
||||
column_names_surface_form = []
|
||||
column_names_embedder_input = []
|
||||
for table_schema in database_schema:
|
||||
db_id = table_schema['db_id']
|
||||
database_schema_dict[db_id] = table_schema
|
||||
|
||||
column_names = table_schema['column_names']
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
column_name_surface_form = column_name
|
||||
column_names_surface_form.append(column_name_surface_form.lower())
|
||||
|
||||
for table_name in table_names_original:
|
||||
column_names_surface_form.append(table_name.lower())
|
||||
|
||||
for i, (table_id, column_name) in enumerate(column_names):
|
||||
column_name_embedder_input = column_name
|
||||
column_names_embedder_input.append(column_name_embedder_input.split())
|
||||
|
||||
for table_name in table_names:
|
||||
column_names_embedder_input.append(table_name.split())
|
||||
|
||||
database_schema = database_schema_dict
|
||||
|
||||
return database_schema, column_names_surface_form, column_names_embedder_input
|
||||
|
||||
def read_database_schema(self, database_schema_filename):
|
||||
with open(database_schema_filename, "r") as f:
|
||||
database_schema = json.load(f)
|
||||
|
||||
database_schema_dict = {}
|
||||
column_names_surface_form = []
|
||||
column_names_embedder_input = []
|
||||
for table_schema in database_schema:
|
||||
db_id = table_schema['db_id']
|
||||
database_schema_dict[db_id] = table_schema
|
||||
|
||||
column_names = table_schema['column_names']
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
if table_id >= 0:
|
||||
table_name = table_names_original[table_id]
|
||||
column_name_surface_form = '{}.{}'.format(table_name,column_name)
|
||||
else:
|
||||
column_name_surface_form = column_name
|
||||
column_names_surface_form.append(column_name_surface_form.lower())
|
||||
|
||||
# also add table_name.*
|
||||
for table_name in table_names_original:
|
||||
column_names_surface_form.append('{}.*'.format(table_name.lower()))
|
||||
|
||||
for i, (table_id, column_name) in enumerate(column_names):
|
||||
if table_id >= 0:
|
||||
table_name = table_names[table_id]
|
||||
column_name_embedder_input = table_name + ' . ' + column_name
|
||||
else:
|
||||
column_name_embedder_input = column_name
|
||||
column_names_embedder_input.append(column_name_embedder_input.split())
|
||||
|
||||
for table_name in table_names:
|
||||
column_name_embedder_input = table_name + ' . *'
|
||||
column_names_embedder_input.append(column_name_embedder_input.split())
|
||||
|
||||
database_schema = database_schema_dict
|
||||
|
||||
return database_schema, column_names_surface_form, column_names_embedder_input
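# Illustrative example (hypothetical schema, not from the dataset): for a table
# "Singer" with columns "Singer_ID" and "Name", read_database_schema would add
# surface forms like "singer.singer_id", "singer.name" and "singer.*", and
# embedder inputs like ["singer", ".", "name"] and ["singer", ".", "*"].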
|
||||
|
||||
def get_all_utterances(self,
|
||||
dataset,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf')):
|
||||
""" Returns all utterances in a dataset."""
|
||||
items = []
|
||||
for interaction in dataset.examples:
|
||||
for i, utterance in enumerate(interaction.utterances):
|
||||
if utterance.length_valid(max_input_length, max_output_length):
|
||||
items.append(atis_batch.UtteranceItem(interaction, i))
|
||||
return items
|
||||
|
||||
def get_all_interactions(self,
|
||||
dataset,
|
||||
max_interaction_length=float('inf'),
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf'),
|
||||
sorted_by_length=False):
|
||||
"""Gets all interactions in a dataset that fit the criteria.
|
||||
|
||||
Inputs:
|
||||
dataset (ATISDatasetSplit): The dataset to use.
|
||||
max_interaction_length (int): Maximum interaction length to keep.
|
||||
max_input_length (int): Maximum input sequence length to keep.
|
||||
max_output_length (int): Maximum output sequence length to keep.
|
||||
sorted_by_length (bool): Whether to sort the examples by interaction length.
|
||||
"""
|
||||
ints = [
|
||||
atis_batch.InteractionItem(
|
||||
interaction,
|
||||
max_input_length,
|
||||
max_output_length,
|
||||
self.entities_dictionary,
|
||||
max_interaction_length) for interaction in dataset.examples]
|
||||
if sorted_by_length:
|
||||
return sorted(ints, key=lambda x: len(x))[::-1]
|
||||
else:
|
||||
return ints
|
||||
|
||||
# This defines a standard way of training: each example is an utterance, and
|
||||
# the batch can contain unrelated utterances.
|
||||
def get_utterance_batches(self,
|
||||
batch_size,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf'),
|
||||
randomize=True):
|
||||
"""Gets batches of utterances in the data.
|
||||
|
||||
Inputs:
|
||||
batch_size (int): Batch size to use.
|
||||
max_input_length (int): Maximum length of input to keep.
|
||||
max_output_length (int): Maximum length of output to use.
|
||||
randomize (bool): Whether to randomize the ordering.
|
||||
"""
|
||||
# First, get all interactions and the positions of the utterances that are
|
||||
# possible in them.
|
||||
items = self.get_all_utterances(self.train_data,
|
||||
max_input_length,
|
||||
max_output_length)
|
||||
if randomize:
|
||||
random.shuffle(items)
|
||||
|
||||
batches = []
|
||||
|
||||
current_batch_items = []
|
||||
for item in items:
|
||||
if len(current_batch_items) >= batch_size:
|
||||
batches.append(atis_batch.UtteranceBatch(current_batch_items))
|
||||
current_batch_items = []
|
||||
current_batch_items.append(item)
|
||||
batches.append(atis_batch.UtteranceBatch(current_batch_items))
|
||||
|
||||
assert sum([len(batch) for batch in batches]) == len(items)
|
||||
|
||||
return batches
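# Usage sketch (assumed Params object carrying the fields referenced above):
#   data = ATISDataset(params)
#   batches = data.get_utterance_batches(batch_size=16)
#   for batch in batches:
#       ...  # each batch is an atis_batch.UtteranceBatch of up to 16 items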
|
||||
|
||||
def get_interaction_batches(self,
|
||||
batch_size,
|
||||
max_interaction_length=float('inf'),
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf'),
|
||||
randomize=True):
|
||||
"""Gets batches of interactions in the data.
|
||||
|
||||
Inputs:
|
||||
batch_size (int): Batch size to use.
|
||||
max_interaction_length (int): Maximum length of interaction to keep
|
||||
max_input_length (int): Maximum length of input to keep.
|
||||
max_output_length (int): Maximum length of output to keep.
|
||||
randomize (bool): Whether to randomize the ordering.
|
||||
"""
|
||||
items = self.get_all_interactions(self.train_data,
|
||||
max_interaction_length,
|
||||
max_input_length,
|
||||
max_output_length,
|
||||
sorted_by_length=not randomize)
|
||||
if randomize:
|
||||
random.shuffle(items)
|
||||
|
||||
batches = []
|
||||
current_batch_items = []
|
||||
for item in items:
|
||||
if len(current_batch_items) >= batch_size:
|
||||
batches.append(
|
||||
atis_batch.InteractionBatch(current_batch_items))
|
||||
current_batch_items = []
|
||||
current_batch_items.append(item)
|
||||
batches.append(atis_batch.InteractionBatch(current_batch_items))
|
||||
|
||||
assert sum([len(batch) for batch in batches]) == len(items)
|
||||
|
||||
return batches
|
||||
|
||||
def get_random_utterances(self,
|
||||
num_samples,
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf')):
|
||||
"""Gets a random selection of utterances in the data.
|
||||
|
||||
Inputs:
|
||||
num_samples (int): Number of random utterances to get.
|
||||
max_input_length (int): Limit of input length.
|
||||
max_output_length (int): Limit on output length.
|
||||
"""
|
||||
items = self.get_all_utterances(self.train_data,
|
||||
max_input_length,
|
||||
max_output_length)
|
||||
random.shuffle(items)
|
||||
return items[:num_samples]
|
||||
|
||||
def get_random_interactions(self,
|
||||
num_samples,
|
||||
max_interaction_length=float('inf'),
|
||||
max_input_length=float('inf'),
|
||||
max_output_length=float('inf')):
|
||||
|
||||
"""Gets a random selection of interactions in the data.
|
||||
|
||||
Inputs:
|
||||
num_samples (int): Number of random interactions to get.
|
||||
max_input_length (int): Limit of input length.
|
||||
max_output_length (int): Limit on output length.
|
||||
"""
|
||||
items = self.get_all_interactions(self.train_data,
|
||||
max_interaction_length,
|
||||
max_input_length,
|
||||
max_output_length)
|
||||
random.shuffle(items)
|
||||
return items[:num_samples]
|
||||
|
||||
|
||||
def num_utterances(dataset):
|
||||
"""Returns the total number of utterances in the dataset."""
|
||||
return sum([len(interaction) for interaction in dataset.examples])
|
|
@@ -0,0 +1,74 @@
|
|||
"""Gets and stores vocabulary for the ATIS data."""
|
||||
|
||||
from . import snippets
|
||||
from .vocabulary import Vocabulary, UNK_TOK, DEL_TOK, EOS_TOK
|
||||
|
||||
INPUT_FN_TYPES = [UNK_TOK, DEL_TOK, EOS_TOK]
|
||||
OUTPUT_FN_TYPES = [UNK_TOK, EOS_TOK]
|
||||
|
||||
MIN_INPUT_OCCUR = 1
|
||||
MIN_OUTPUT_OCCUR = 1
|
||||
|
||||
class ATISVocabulary():
|
||||
""" Stores the vocabulary for the ATIS data.
|
||||
|
||||
Attributes:
|
||||
raw_vocab (Vocabulary): Vocabulary object.
|
||||
tokens (set of str): Set of all of the strings in the vocabulary.
|
||||
inorder_tokens (list of str): List of all tokens, with a strict and
|
||||
unchanging order.
|
||||
"""
|
||||
def __init__(self,
|
||||
token_sequences,
|
||||
filename,
|
||||
params,
|
||||
is_input='input',
|
||||
min_occur=1,
|
||||
anonymizer=None,
|
||||
skip=None):
|
||||
|
||||
if is_input=='input':
|
||||
functional_types = INPUT_FN_TYPES
|
||||
elif is_input=='output':
|
||||
functional_types = OUTPUT_FN_TYPES
|
||||
elif is_input=='schema':
|
||||
functional_types = [UNK_TOK]
|
||||
else:
|
||||
functional_types = []
|
||||
|
||||
self.raw_vocab = Vocabulary(
|
||||
token_sequences,
|
||||
filename,
|
||||
functional_types=functional_types,
|
||||
min_occur=min_occur,
|
||||
ignore_fn=lambda x: snippets.is_snippet(x) or (
|
||||
anonymizer and anonymizer.is_anon_tok(x)) or (skip and x in skip) )
|
||||
self.tokens = set(self.raw_vocab.token_to_id.keys())
|
||||
self.inorder_tokens = self.raw_vocab.id_to_token
|
||||
|
||||
assert len(self.inorder_tokens) == len(self.raw_vocab)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.raw_vocab)
|
||||
|
||||
def token_to_id(self, token):
|
||||
""" Maps from a token to a unique ID.
|
||||
|
||||
Inputs:
|
||||
token (str): The token to look up.
|
||||
|
||||
Returns:
|
||||
int, uniquely identifying the token.
|
||||
"""
|
||||
return self.raw_vocab.token_to_id[token]
|
||||
|
||||
def id_to_token(self, identifier):
|
||||
""" Maps from a unique integer to an identifier.
|
||||
|
||||
Inputs:
|
||||
identifier (int): The unique ID.
|
||||
|
||||
Returns:
|
||||
string, representing the token.
|
||||
"""
|
||||
return self.raw_vocab.id_to_token[identifier]
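# Usage sketch (file name and token value are assumptions): IDs and tokens
# round-trip through the two lookups, e.g.
#   vocab = ATISVocabulary(train_output_seqs, "output_vocab.pkl", params,
#                          is_input='output')
#   tok_id = vocab.token_to_id("select")
#   assert vocab.id_to_token(tok_id) == "select"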
|
|
@@ -0,0 +1,56 @@
|
|||
""" Utility functions for loading and processing ATIS data.
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
|
||||
class DatasetSplit:
|
||||
"""Stores a split of the ATIS dataset.
|
||||
|
||||
Attributes:
|
||||
examples (list of Interaction): Stores the examples in the split.
|
||||
"""
|
||||
def __init__(self, processed_filename, raw_filename, load_function):
|
||||
if os.path.exists(processed_filename):
|
||||
print("Loading preprocessed data from " + processed_filename)
|
||||
with open(processed_filename, 'rb') as infile:
|
||||
self.examples = pickle.load(infile)
|
||||
else:
|
||||
print(
|
||||
"Loading raw data from " +
|
||||
raw_filename +
|
||||
" and writing to " +
|
||||
processed_filename)
|
||||
|
||||
infile = open(raw_filename, 'rb')
|
||||
examples_from_file = pickle.load(infile)
|
||||
assert isinstance(examples_from_file, list), raw_filename + \
|
||||
" does not contain a list of examples"
|
||||
infile.close()
|
||||
|
||||
self.examples = []
|
||||
for example in examples_from_file:
|
||||
obj, keep = load_function(example)
|
||||
|
||||
if keep:
|
||||
self.examples.append(obj)
|
||||
|
||||
|
||||
print("Loaded " + str(len(self.examples)) + " examples")
|
||||
outfile = open(processed_filename, 'wb')
|
||||
pickle.dump(self.examples, outfile)
|
||||
outfile.close()
|
||||
|
||||
def get_ex_properties(self, function):
|
||||
""" Applies some function to the examples in the dataset.
|
||||
|
||||
Inputs:
|
||||
function: (lambda Interaction -> T): Function to apply to all
|
||||
examples.
|
||||
|
||||
Returns:
    list of the values returned by applying the function to each example.
|
||||
"""
|
||||
elems = []
|
||||
for example in self.examples:
|
||||
elems.append(function(example))
|
||||
return elems
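# Usage sketch (hypothetical file names): the split caches preprocessing, so the
# first call pickles the processed examples and later calls simply reload them.
#   split = DatasetSplit("data/processed_train.pkl", "data/train.pkl", load_fn)
#   input_seqs = split.get_ex_properties(lambda ex: ex.input_seqs())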
|
|
@@ -0,0 +1,63 @@
|
|||
""" Classes for keeping track of the entities in a natural language string. """
|
||||
import json
|
||||
|
||||
|
||||
class NLtoSQLDict:
|
||||
"""
|
||||
Entity dict file should contain, on each line, a JSON dictionary with
|
||||
"input" and "output" keys specifying the string for the input and output
|
||||
pairs. The idea is that the existence of the key in an input sequence
|
||||
likely corresponds to the existence of the value in the output sequence.
|
||||
|
||||
The entity_dict should map keys (input strings) to a list of values (output
|
||||
strings) where this property holds. This allows keys to map to multiple
|
||||
output strings (e.g. for times).
|
||||
"""
|
||||
def __init__(self, entity_dict_filename):
|
||||
self.entity_dict = {}
|
||||
|
||||
with open(entity_dict_filename) as entity_file:
    pairs = [json.loads(line) for line in entity_file]
|
||||
for pair in pairs:
|
||||
input_seq = pair["input"]
|
||||
output_seq = pair["output"]
|
||||
if input_seq not in self.entity_dict:
|
||||
self.entity_dict[input_seq] = []
|
||||
self.entity_dict[input_seq].append(output_seq)
|
||||
|
||||
def get_sql_entities(self, tokenized_nl_string):
|
||||
"""
|
||||
Gets the output-side entities which correspond to the input entities in
|
||||
the input sequence.
|
||||
Inputs:
|
||||
tokenized_nl_string (list of str): Tokens in the input utterance.
|
||||
Outputs:
|
||||
set of output strings.
|
||||
"""
|
||||
assert len(tokenized_nl_string) > 0
|
||||
flat_input_string = " ".join(tokenized_nl_string)
|
||||
entities = []
|
||||
|
||||
# See if any input strings are in our input sequence, and add the
|
||||
# corresponding output strings if so.
|
||||
for entry, values in self.entity_dict.items():
|
||||
in_middle = " " + entry + " " in flat_input_string
|
||||
|
||||
leftspace = " " + entry
|
||||
at_end = leftspace in flat_input_string and flat_input_string.endswith(
|
||||
leftspace)
|
||||
|
||||
rightspace = entry + " "
|
||||
at_beginning = rightspace in flat_input_string and flat_input_string.startswith(
|
||||
rightspace)
|
||||
if in_middle or at_end or at_beginning:
|
||||
for out_string in values:
|
||||
entities.append(out_string)
|
||||
|
||||
# Also add any integers in the input string (these aren't in the entity)
|
||||
# dict.
|
||||
for token in tokenized_nl_string:
|
||||
if token.isnumeric():
|
||||
entities.append(token)
|
||||
|
||||
return entities
|
|
@@ -0,0 +1,312 @@
|
|||
""" Contains the class for an interaction in ATIS. """
|
||||
|
||||
from . import anonymization as anon
|
||||
from . import sql_util
|
||||
from .snippets import expand_snippets
|
||||
from .utterance import Utterance, OUTPUT_KEY, ANON_INPUT_KEY
|
||||
|
||||
import torch
|
||||
|
||||
class Schema:
|
||||
def __init__(self, table_schema, simple=False):
|
||||
if simple:
|
||||
self.helper1(table_schema)
|
||||
else:
|
||||
self.helper2(table_schema)
|
||||
|
||||
def helper1(self, table_schema):
|
||||
self.table_schema = table_schema
|
||||
column_names = table_schema['column_names']
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original)
|
||||
|
||||
column_keep_index = []
|
||||
|
||||
self.column_names_surface_form = []
|
||||
self.column_names_surface_form_to_id = {}
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
column_name_surface_form = column_name
|
||||
column_name_surface_form = column_name_surface_form.lower()
|
||||
if column_name_surface_form not in self.column_names_surface_form_to_id:
|
||||
self.column_names_surface_form.append(column_name_surface_form)
|
||||
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
|
||||
column_keep_index.append(i)
|
||||
|
||||
column_keep_index_2 = []
|
||||
for i, table_name in enumerate(table_names_original):
|
||||
column_name_surface_form = table_name.lower()
|
||||
if column_name_surface_form not in self.column_names_surface_form_to_id:
|
||||
self.column_names_surface_form.append(column_name_surface_form)
|
||||
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
|
||||
column_keep_index_2.append(i)
|
||||
|
||||
self.column_names_embedder_input = []
|
||||
self.column_names_embedder_input_to_id = {}
|
||||
for i, (table_id, column_name) in enumerate(column_names):
|
||||
column_name_embedder_input = column_name
|
||||
if i in column_keep_index:
|
||||
self.column_names_embedder_input.append(column_name_embedder_input)
|
||||
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
|
||||
|
||||
for i, table_name in enumerate(table_names):
|
||||
column_name_embedder_input = table_name
|
||||
if i in column_keep_index_2:
|
||||
self.column_names_embedder_input.append(column_name_embedder_input)
|
||||
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
|
||||
|
||||
max_id_1 = max(v for k,v in self.column_names_surface_form_to_id.items())
|
||||
max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items())
|
||||
assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1
|
||||
|
||||
self.num_col = len(self.column_names_surface_form)
|
||||
|
||||
def helper2(self, table_schema):
|
||||
self.table_schema = table_schema
|
||||
column_names = table_schema['column_names']
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original)
|
||||
|
||||
column_keep_index = []
|
||||
|
||||
self.column_names_surface_form = []
|
||||
self.column_names_surface_form_to_id = {}
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
if table_id >= 0:
|
||||
table_name = table_names_original[table_id]
|
||||
column_name_surface_form = '{}.{}'.format(table_name,column_name)
|
||||
else:
|
||||
column_name_surface_form = column_name
|
||||
column_name_surface_form = column_name_surface_form.lower()
|
||||
if column_name_surface_form not in self.column_names_surface_form_to_id:
|
||||
self.column_names_surface_form.append(column_name_surface_form)
|
||||
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
|
||||
column_keep_index.append(i)
|
||||
|
||||
start_i = len(self.column_names_surface_form_to_id)
|
||||
for i, table_name in enumerate(table_names_original):
|
||||
column_name_surface_form = '{}.*'.format(table_name.lower())
|
||||
self.column_names_surface_form.append(column_name_surface_form)
|
||||
self.column_names_surface_form_to_id[column_name_surface_form] = i + start_i
|
||||
|
||||
self.column_names_embedder_input = []
|
||||
self.column_names_embedder_input_to_id = {}
|
||||
for i, (table_id, column_name) in enumerate(column_names):
|
||||
if table_id >= 0:
|
||||
table_name = table_names[table_id]
|
||||
column_name_embedder_input = table_name + ' . ' + column_name
|
||||
else:
|
||||
column_name_embedder_input = column_name
|
||||
if i in column_keep_index:
|
||||
self.column_names_embedder_input.append(column_name_embedder_input)
|
||||
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
|
||||
|
||||
start_i = len(self.column_names_embedder_input_to_id)
|
||||
for i, table_name in enumerate(table_names):
|
||||
column_name_embedder_input = table_name + ' . *'
|
||||
self.column_names_embedder_input.append(column_name_embedder_input)
|
||||
self.column_names_embedder_input_to_id[column_name_embedder_input] = i + start_i
|
||||
|
||||
assert len(self.column_names_surface_form) == len(self.column_names_surface_form_to_id) == len(self.column_names_embedder_input) == len(self.column_names_embedder_input_to_id)
|
||||
|
||||
max_id_1 = max(v for k,v in self.column_names_surface_form_to_id.items())
|
||||
max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items())
|
||||
assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1
|
||||
|
||||
self.num_col = len(self.column_names_surface_form)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_col
|
||||
|
||||
def in_vocabulary(self, column_name, surface_form=False):
|
||||
if surface_form:
|
||||
return column_name in self.column_names_surface_form_to_id
|
||||
else:
|
||||
return column_name in self.column_names_embedder_input_to_id
|
||||
|
||||
def column_name_embedder_bow(self, column_name, surface_form=False, column_name_token_embedder=None):
|
||||
assert self.in_vocabulary(column_name, surface_form)
|
||||
if surface_form:
|
||||
column_name_id = self.column_names_surface_form_to_id[column_name]
|
||||
column_name_embedder_input = self.column_names_embedder_input[column_name_id]
|
||||
else:
|
||||
column_name_embedder_input = column_name
|
||||
|
||||
column_name_embeddings = [column_name_token_embedder(token) for token in column_name_embedder_input.split()]
|
||||
column_name_embeddings = torch.stack(column_name_embeddings, dim=0)
|
||||
return torch.mean(column_name_embeddings, dim=0)
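# i.e. a bag-of-words column embedding: the mean of the token embeddings of the
# embedder input, e.g. for "flight . airline_code" the average of the vectors
# for "flight", "." and "airline_code" (the example column name is an assumption).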
|
||||
|
||||
def set_column_name_embeddings(self, column_name_embeddings):
|
||||
self.column_name_embeddings = column_name_embeddings
|
||||
assert len(self.column_name_embeddings) == self.num_col
|
||||
|
||||
def column_name_embedder(self, column_name, surface_form=False):
|
||||
assert self.in_vocabulary(column_name, surface_form)
|
||||
if surface_form:
|
||||
column_name_id = self.column_names_surface_form_to_id[column_name]
|
||||
else:
|
||||
column_name_id = self.column_names_embedder_input_to_id[column_name]
|
||||
|
||||
return self.column_name_embeddings[column_name_id]
|
||||
|
||||
class Interaction:
|
||||
""" ATIS interaction class.
|
||||
|
||||
Attributes:
|
||||
utterances (list of Utterance): The utterances in the interaction.
|
||||
snippets (list of Snippet): The snippets that appear through the interaction.
|
||||
anon_tok_to_ent:
|
||||
identifier (str): Unique identifier for the interaction in the dataset.
|
||||
"""
|
||||
def __init__(self,
|
||||
utterances,
|
||||
schema,
|
||||
snippets,
|
||||
anon_tok_to_ent,
|
||||
identifier,
|
||||
params):
|
||||
self.utterances = utterances
|
||||
self.schema = schema
|
||||
self.snippets = snippets
|
||||
self.anon_tok_to_ent = anon_tok_to_ent
|
||||
self.identifier = identifier
|
||||
|
||||
# Ensure that each utterance's input and output sequences, when remapped
|
||||
# without anonymization or snippets, are the same as the original
|
||||
# version.
|
||||
for i, utterance in enumerate(self.utterances):
|
||||
deanon_input = self.deanonymize(utterance.input_seq_to_use,
|
||||
ANON_INPUT_KEY)
|
||||
assert deanon_input == utterance.original_input_seq, "Anonymized sequence [" \
|
||||
+ " ".join(utterance.input_seq_to_use) + "] is not the same as [" \
|
||||
+ " ".join(utterance.original_input_seq) + "] when deanonymized (is [" \
|
||||
+ " ".join(deanon_input) + "] instead)"
|
||||
desnippet_gold = self.expand_snippets(utterance.gold_query_to_use)
|
||||
deanon_gold = self.deanonymize(desnippet_gold, OUTPUT_KEY)
|
||||
assert deanon_gold == utterance.original_gold_query, \
|
||||
"Anonymized and/or snippet'd query " \
|
||||
+ " ".join(utterance.gold_query_to_use) + " is not the same as " \
|
||||
+ " ".join(utterance.original_gold_query)
|
||||
|
||||
def __str__(self):
|
||||
string = "Utterances:\n"
|
||||
for utterance in self.utterances:
|
||||
string += str(utterance) + "\n"
|
||||
string += "Anonymization dictionary:\n"
|
||||
for ent_tok, deanon in self.anon_tok_to_ent.items():
|
||||
string += ent_tok + "\t" + str(deanon) + "\n"
|
||||
|
||||
return string
|
||||
|
||||
def __len__(self):
|
||||
return len(self.utterances)
|
||||
|
||||
def deanonymize(self, sequence, key):
|
||||
""" Deanonymizes a predicted query or an input utterance.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): The sequence to deanonymize.
|
||||
key (str): The key in the anonymization table, e.g. NL or SQL.
|
||||
"""
|
||||
return anon.deanonymize(sequence, self.anon_tok_to_ent, key)
|
||||
|
||||
def expand_snippets(self, sequence):
|
||||
""" Expands snippets for a sequence.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): A SQL query.
|
||||
|
||||
"""
|
||||
return expand_snippets(sequence, self.snippets)
|
||||
|
||||
def input_seqs(self):
|
||||
in_seqs = []
|
||||
for utterance in self.utterances:
|
||||
in_seqs.append(utterance.input_seq_to_use)
|
||||
return in_seqs
|
||||
|
||||
def output_seqs(self):
|
||||
out_seqs = []
|
||||
for utterance in self.utterances:
|
||||
out_seqs.append(utterance.gold_query_to_use)
|
||||
return out_seqs
|
||||
|
||||
def load_function(parameters,
|
||||
nl_to_sql_dict,
|
||||
anonymizer,
|
||||
database_schema=None):
|
||||
def fn(interaction_example):
|
||||
keep = False
|
||||
|
||||
raw_utterances = interaction_example["interaction"]
|
||||
|
||||
if "database_id" in interaction_example:
|
||||
database_id = interaction_example["database_id"]
|
||||
interaction_id = interaction_example["interaction_id"]
|
||||
identifier = database_id + '/' + str(interaction_id)
|
||||
else:
|
||||
identifier = interaction_example["id"]
|
||||
|
||||
schema = None
|
||||
if database_schema:
|
||||
if 'removefrom' not in parameters.data_directory:
|
||||
schema = Schema(database_schema[database_id], simple=True)
|
||||
else:
|
||||
schema = Schema(database_schema[database_id])
|
||||
|
||||
snippet_bank = []
|
||||
|
||||
utterance_examples = []
|
||||
|
||||
anon_tok_to_ent = {}
|
||||
|
||||
for utterance in raw_utterances:
|
||||
available_snippets = [
|
||||
snippet for snippet in snippet_bank if snippet.index <= 1]
|
||||
|
||||
proc_utterance = Utterance(utterance,
|
||||
available_snippets,
|
||||
nl_to_sql_dict,
|
||||
parameters,
|
||||
anon_tok_to_ent,
|
||||
anonymizer)
|
||||
keep_utterance = proc_utterance.keep
|
||||
|
||||
if schema:
|
||||
assert keep_utterance
|
||||
|
||||
if keep_utterance:
|
||||
keep = True
|
||||
utterance_examples.append(proc_utterance)
|
||||
|
||||
# Update the snippet bank, and age each snippet in it.
|
||||
if parameters.use_snippets:
|
||||
if 'atis' in parameters.data_directory:
|
||||
snippets = sql_util.get_subtrees(
|
||||
proc_utterance.anonymized_gold_query,
|
||||
proc_utterance.available_snippets)
|
||||
else:
|
||||
snippets = sql_util.get_subtrees_simple(
|
||||
proc_utterance.anonymized_gold_query,
|
||||
proc_utterance.available_snippets)
|
||||
|
||||
for snippet in snippets:
|
||||
snippet.assign_id(len(snippet_bank))
|
||||
snippet_bank.append(snippet)
|
||||
|
||||
for snippet in snippet_bank:
|
||||
snippet.increase_age()
|
||||
|
||||
interaction = Interaction(utterance_examples,
|
||||
schema,
|
||||
snippet_bank,
|
||||
anon_tok_to_ent,
|
||||
identifier,
|
||||
parameters)
|
||||
|
||||
return interaction, keep
|
||||
|
||||
return fn
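# Usage sketch (mirrors how ATISDataset wires things together; variable names
# are assumptions):
#   fn = load_function(params, entities_dictionary, anonymizer,
#                      database_schema=database_schema)
#   interaction, keep = fn(raw_example)   # keep is False if no utterance survives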
|
|
@@ -0,0 +1,105 @@
|
|||
""" Contains the Snippet class and methods for handling snippets.
|
||||
|
||||
Attributes:
|
||||
SNIPPET_PREFIX: string prefix for snippets.
|
||||
"""
|
||||
|
||||
SNIPPET_PREFIX = "SNIPPET_"
|
||||
|
||||
|
||||
def is_snippet(token):
|
||||
""" Determines whether a token is a snippet or not.
|
||||
|
||||
Inputs:
|
||||
token (str): The token to check.
|
||||
|
||||
Returns:
|
||||
bool, indicating whether it's a snippet.
|
||||
"""
|
||||
return token.startswith(SNIPPET_PREFIX)
|
||||
|
||||
def expand_snippets(sequence, snippets):
|
||||
""" Given a sequence and a list of snippets, expand the snippets in the sequence.
|
||||
|
||||
Inputs:
|
||||
sequence (list of str): Query containing snippet references.
|
||||
snippets (list of Snippet): List of available snippets.
|
||||
|
||||
return list of str representing the expanded sequence
|
||||
"""
|
||||
snippet_id_to_snippet = {}
|
||||
for snippet in snippets:
|
||||
assert snippet.name not in snippet_id_to_snippet
|
||||
snippet_id_to_snippet[snippet.name] = snippet
|
||||
expanded_seq = []
|
||||
for token in sequence:
|
||||
if token in snippet_id_to_snippet:
|
||||
expanded_seq.extend(snippet_id_to_snippet[token].sequence)
|
||||
else:
|
||||
assert not is_snippet(token)
|
||||
expanded_seq.append(token)
|
||||
|
||||
return expanded_seq
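# Illustrative expansion (hypothetical snippet): if snippet.name == "SNIPPET_0"
# and snippet.sequence == ["flight_id", "=", "1"], then
#   expand_snippets(["SELECT", "*", "WHERE", "SNIPPET_0"], [snippet])
# returns ["SELECT", "*", "WHERE", "flight_id", "=", "1"].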
|
||||
|
||||
def snippet_index(token):
|
||||
""" Returns the index of a snippet.
|
||||
|
||||
Inputs:
|
||||
token (str): The snippet to check.
|
||||
|
||||
Returns:
|
||||
integer, the index of the snippet.
|
||||
"""
|
||||
assert is_snippet(token)
|
||||
return int(token.split("_")[-1])
|
||||
|
||||
|
||||
class Snippet():
|
||||
""" Contains a snippet. """
|
||||
def __init__(self,
|
||||
sequence,
|
||||
startpos,
|
||||
sql,
|
||||
age=0):
|
||||
self.sequence = sequence
|
||||
self.startpos = startpos
|
||||
self.sql = sql
|
||||
|
||||
# TODO: age vs. index?
|
||||
self.age = age
|
||||
self.index = 0
|
||||
|
||||
self.name = ""
|
||||
self.embedding = None
|
||||
|
||||
self.endpos = self.startpos + len(self.sequence)
|
||||
assert self.endpos < len(self.sql), "End position of snippet is " + str(
|
||||
self.endpos) + " which is greater than length of SQL (" + str(len(self.sql)) + ")"
|
||||
assert self.sequence == self.sql[self.startpos:self.endpos], \
|
||||
"Value of snippet (" + " ".join(self.sequence) + ") " \
|
||||
"is not the same as SQL at the same positions (" \
|
||||
+ " ".join(self.sql[self.startpos:self.endpos]) + ")"
|
||||
|
||||
def __str__(self):
|
||||
return self.name + "\t" + \
|
||||
str(self.age) + "\t" + " ".join(self.sequence)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence)
|
||||
|
||||
def increase_age(self):
|
||||
""" Ages a snippet by one. """
|
||||
self.index += 1
|
||||
|
||||
def assign_id(self, number):
|
||||
""" Assigns the name of the snippet to be the prefix + the number. """
|
||||
self.name = SNIPPET_PREFIX + str(number)
|
||||
|
||||
def set_embedding(self, embedding):
|
||||
""" Sets the embedding of the snippet.
|
||||
|
||||
Inputs:
|
||||
embedding (dy.Expression)
|
||||
|
||||
"""
|
||||
self.embedding = embedding
|
|
@@ -0,0 +1,445 @@
|
|||
import copy
|
||||
import pymysql
|
||||
import random
|
||||
import signal
|
||||
import sqlparse
|
||||
from . import util
|
||||
|
||||
from .snippets import Snippet
|
||||
from sqlparse import tokens as token_types
|
||||
from sqlparse import sql as sql_types
|
||||
|
||||
interesting_selects = ["DISTINCT", "MAX", "MIN", "count"]
|
||||
ignored_subtrees = [["1", "=", "1"]]
|
||||
|
||||
# strip_whitespace_front
|
||||
# Strips whitespace and punctuation from the front of a SQL token list.
|
||||
#
|
||||
# Inputs:
|
||||
# token_list: the token list.
|
||||
#
|
||||
# Outputs:
|
||||
# new token list.
|
||||
|
||||
|
||||
def strip_whitespace_front(token_list):
|
||||
new_token_list = []
|
||||
found_valid = False
|
||||
|
||||
for token in token_list:
|
||||
if not (token.is_whitespace or token.ttype ==
|
||||
token_types.Punctuation) or found_valid:
|
||||
found_valid = True
|
||||
new_token_list.append(token)
|
||||
|
||||
return new_token_list
|
||||
|
||||
# strip_whitespace
|
||||
# Strips whitespace from a token list.
|
||||
#
|
||||
# Inputs:
|
||||
# token_list: the token list.
|
||||
#
|
||||
# Outputs:
|
||||
# new token list with no whitespace/punctuation surrounding.
|
||||
|
||||
|
||||
def strip_whitespace(token_list):
|
||||
subtokens = strip_whitespace_front(token_list)
|
||||
subtokens = strip_whitespace_front(subtokens[::-1])[::-1]
|
||||
return subtokens
|
||||
|
||||
# token_list_to_seq
|
||||
# Converts a Token list to a sequence of strings, stripping out surrounding
|
||||
# punctuation and all whitespace.
|
||||
#
|
||||
# Inputs:
|
||||
# token_list: the list of tokens.
|
||||
#
|
||||
# Outputs:
|
||||
# list of strings
|
||||
|
||||
|
||||
def token_list_to_seq(token_list):
|
||||
subtokens = strip_whitespace(token_list)
|
||||
|
||||
seq = []
|
||||
flat = sqlparse.sql.TokenList(subtokens).flatten()
|
||||
for i, token in enumerate(flat):
|
||||
strip_token = str(token).strip()
|
||||
if len(strip_token) > 0:
|
||||
seq.append(strip_token)
|
||||
if len(seq) > 0:
|
||||
if seq[0] == "(" and seq[-1] == ")":
|
||||
seq = seq[1:-1]
|
||||
|
||||
return seq
|
||||
|
||||
# TODO: clean this up
|
||||
# find_subtrees
|
||||
# Finds subtrees for a subsequence of SQL.
|
||||
#
|
||||
# Inputs:
|
||||
# sequence: sequence of SQL tokens.
|
||||
# current_subtrees: current list of subtrees.
|
||||
#
|
||||
# Optional inputs:
|
||||
# where_parent: whether the parent of the current sequence was a where clause
|
||||
# keep_conj_subtrees: whether to look for a conjunction in this sequence and
|
||||
# keep its arguments
|
||||
|
||||
|
||||
def find_subtrees(sequence,
|
||||
current_subtrees,
|
||||
where_parent=False,
|
||||
keep_conj_subtrees=False):
|
||||
# If the parent of the subsequence was a WHERE clause, keep everything in the
|
||||
# sequence except for the beginning WHERE and any surrounding parentheses.
|
||||
if where_parent:
|
||||
# Strip out the beginning WHERE, and any punctuation or whitespace at the
|
||||
# beginning or end of the token list.
|
||||
seq = token_list_to_seq(sequence.tokens[1:])
|
||||
if len(seq) > 0 and seq not in current_subtrees:
|
||||
current_subtrees.append(seq)
|
||||
|
||||
# If the current sequence has subtokens, i.e. if it's a node that can be
|
||||
# expanded, check for a conjunction in its subtrees, and expand its subtrees.
|
||||
# Also check for any SELECT statements and keep track of what follows.
|
||||
if sequence.is_group:
|
||||
if keep_conj_subtrees:
|
||||
subtokens = strip_whitespace(sequence.tokens)
|
||||
|
||||
# Check if there is a conjunction in the subsequence. If so, keep the
|
||||
# children. Also make sure you don't split where AND is used within a
|
||||
# child -- the subtokens sequence won't treat those ANDs differently (a
|
||||
# bit hacky but it works)
|
||||
has_and = False
|
||||
for i, token in enumerate(subtokens):
|
||||
if token.value == "OR" or token.value == "AND":
|
||||
has_and = True
|
||||
break
|
||||
|
||||
if has_and:
|
||||
and_subtrees = []
|
||||
current_subtree = []
|
||||
for i, token in enumerate(subtokens):
|
||||
if token.value == "OR" or (token.value == "AND" and i - 4 >= 0 and i - 4 < len(
|
||||
subtokens) and subtokens[i - 4].value != "BETWEEN"):
|
||||
and_subtrees.append(current_subtree)
|
||||
current_subtree = []
|
||||
else:
|
||||
current_subtree.append(token)
|
||||
and_subtrees.append(current_subtree)
|
||||
|
||||
for subtree in and_subtrees:
|
||||
seq = token_list_to_seq(subtree)
|
||||
if len(seq) > 0 and seq[0] == "WHERE":
|
||||
seq = seq[1:]
|
||||
if seq not in current_subtrees:
|
||||
current_subtrees.append(seq)
|
||||
|
||||
in_select = False
|
||||
select_toks = []
|
||||
for i, token in enumerate(sequence.tokens):
|
||||
# Mark whether this current token is a WHERE.
|
||||
is_where = (isinstance(token, sql_types.Where))
|
||||
|
||||
# If you are in a SELECT, start recording what follows until you hit a
|
||||
# FROM
|
||||
if token.value == "SELECT":
|
||||
in_select = True
|
||||
elif in_select:
|
||||
select_toks.append(token)
|
||||
if token.value == "FROM":
|
||||
in_select = False
|
||||
|
||||
seq = []
|
||||
if len(sequence.tokens) > i + 2:
|
||||
seq = token_list_to_seq(
|
||||
select_toks + [sequence.tokens[i + 2]])
|
||||
|
||||
if seq not in current_subtrees and len(
|
||||
seq) > 0 and seq[0] in interesting_selects:
|
||||
current_subtrees.append(seq)
|
||||
|
||||
select_toks = []
|
||||
|
||||
# Recursively find subtrees in the children of the node.
|
||||
find_subtrees(token,
|
||||
current_subtrees,
|
||||
is_where,
|
||||
where_parent or keep_conj_subtrees)
|
||||
|
||||
# get_subtrees
|
||||
|
||||
|
||||
def get_subtrees(sql, oldsnippets=[]):
|
||||
parsed = sqlparse.parse(" ".join(sql))[0]
|
||||
|
||||
subtrees = []
|
||||
find_subtrees(parsed, subtrees)
|
||||
|
||||
final_subtrees = []
|
||||
for subtree in subtrees:
|
||||
if subtree not in ignored_subtrees:
|
||||
final_version = []
|
||||
keep = True
|
||||
|
||||
parens_counts = 0
|
||||
for i, token in enumerate(subtree):
|
||||
if token == ".":
|
||||
newtoken = final_version[-1] + "." + subtree[i + 1]
|
||||
final_version = final_version[:-1] + [newtoken]
|
||||
keep = False
|
||||
elif keep:
|
||||
final_version.append(token)
|
||||
else:
|
||||
keep = True
|
||||
|
||||
if token == "(":
|
||||
parens_counts -= 1
|
||||
elif token == ")":
|
||||
parens_counts += 1
|
||||
|
||||
if parens_counts == 0:
|
||||
final_subtrees.append(final_version)
|
||||
|
||||
snippets = []
|
||||
sql = [str(tok) for tok in sql]
|
||||
for subtree in final_subtrees:
|
||||
startpos = -1
|
||||
for i in range(len(sql) - len(subtree) + 1):
|
||||
if sql[i:i + len(subtree)] == subtree:
|
||||
startpos = i
|
||||
if startpos >= 0 and startpos + len(subtree) < len(sql):
|
||||
age = 0
|
||||
for prevsnippet in oldsnippets:
|
||||
if prevsnippet.sequence == subtree:
|
||||
age = prevsnippet.age + 1
|
||||
snippet = Snippet(subtree, startpos, sql, age=age)
|
||||
snippets.append(snippet)
|
||||
|
||||
return snippets
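# Illustrative call (the exact snippets depend on sqlparse's parse tree, so
# treat them as an assumption): get_subtrees takes a tokenized SQL query, e.g.
# "SELECT ... WHERE a = 1 AND b = 2".split(), and returns Snippet objects for
# WHERE-clause subtrees such as ["a", "=", "1"] and ["b", "=", "2"].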
|
||||
|
||||
|
||||
def get_subtrees_simple(sql, oldsnippets=[]):
|
||||
sql_string = " ".join(sql)
|
||||
format_sql = sqlparse.format(sql_string, reindent=True)
|
||||
|
||||
# get subtrees
|
||||
subtrees = []
|
||||
for sub_sql in format_sql.split('\n'):
|
||||
sub_sql = sub_sql.replace('(', ' ( ').replace(')', ' ) ').replace(',', ' , ')
|
||||
|
||||
subtree = sub_sql.strip().split()
|
||||
if len(subtree) > 1:
|
||||
subtrees.append(subtree)
|
||||
|
||||
final_subtrees = subtrees
|
||||
|
||||
snippets = []
|
||||
sql = [str(tok) for tok in sql]
|
||||
for subtree in final_subtrees:
|
||||
startpos = -1
|
||||
for i in range(len(sql) - len(subtree) + 1):
|
||||
if sql[i:i + len(subtree)] == subtree:
|
||||
startpos = i
|
||||
|
||||
if startpos >= 0 and startpos + len(subtree) <= len(sql):
|
||||
age = 0
|
||||
for prevsnippet in oldsnippets:
|
||||
if prevsnippet.sequence == subtree:
|
||||
age = prevsnippet.age + 1
|
||||
new_sql = sql + [';']
|
||||
snippet = Snippet(subtree, startpos, new_sql, age=age)
|
||||
snippets.append(snippet)
|
||||
|
||||
return snippets
|
||||
|
||||
|
||||
conjunctions = {"AND", "OR", "WHERE"}
|
||||
|
||||
|
||||
def get_all_in_parens(sequence):
|
||||
if sequence[-1] == ";":
|
||||
sequence = sequence[:-1]
|
||||
|
||||
if not "(" in sequence:
|
||||
return []
|
||||
|
||||
if sequence[0] == "(" and sequence[-1] == ")":
|
||||
in_parens = sequence[1:-1]
|
||||
return [in_parens] + get_all_in_parens(in_parens)
|
||||
else:
|
||||
paren_subseqs = []
|
||||
current_seq = []
|
||||
num_parens = 0
|
||||
in_parens = False
|
||||
for token in sequence:
|
||||
if in_parens:
|
||||
current_seq.append(token)
|
||||
if token == ")":
|
||||
num_parens -= 1
|
||||
if num_parens == 0:
|
||||
in_parens = False
|
||||
paren_subseqs.append(current_seq)
|
||||
current_seq = []
|
||||
elif token == "(":
|
||||
in_parens = True
|
||||
current_seq.append(token)
|
||||
if token == "(":
|
||||
num_parens += 1
|
||||
|
||||
all_subseqs = []
|
||||
for subseq in paren_subseqs:
|
||||
all_subseqs.extend(get_all_in_parens(subseq))
|
||||
return all_subseqs
|
||||
|
||||
|
||||
def split_by_conj(sequence):
|
||||
num_parens = 0
|
||||
current_seq = []
|
||||
subsequences = []
|
||||
|
||||
for token in sequence:
|
||||
if num_parens == 0:
|
||||
if token in conjunctions:
|
||||
subsequences.append(current_seq)
|
||||
current_seq = []
|
||||
break
|
||||
current_seq.append(token)
|
||||
if token == "(":
|
||||
num_parens += 1
|
||||
elif token == ")":
|
||||
num_parens -= 1
|
||||
|
||||
assert num_parens >= 0
|
||||
|
||||
return subsequences
|
||||
|
||||
|
||||
def get_sql_snippets(sequence):
|
||||
# First, get all subsequences of the sequence that are surrounded by
|
||||
# parentheses.
|
||||
all_in_parens = get_all_in_parens(sequence)
|
||||
all_subseq = []
|
||||
|
||||
# Then for each one, split the sequence on conjunctions (AND/OR).
|
||||
for seq in all_in_parens:
|
||||
subsequences = split_by_conj(seq)
|
||||
all_subseq.append(seq)
|
||||
all_subseq.extend(subsequences)
|
||||
|
||||
# Finally, also get "interesting" selects
|
||||
|
||||
for i, seq in enumerate(all_subseq):
|
||||
print(str(i) + "\t" + " ".join(seq))
|
||||
exit()
|
||||
|
||||
# add_snippets_to_query
|
||||
|
||||
|
||||
def add_snippets_to_query(snippets, ignored_entities, query, prob_align=1.):
|
||||
query_copy = copy.copy(query)
|
||||
|
||||
# Replace the longest snippets first, so sort by length descending.
|
||||
sorted_snippets = sorted(snippets, key=lambda s: len(s.sequence))[::-1]
|
||||
|
||||
for snippet in sorted_snippets:
|
||||
ignore = False
|
||||
snippet_seq = snippet.sequence
|
||||
|
||||
# TODO: continue here
|
||||
# If it contains an ignored entity, then don't use it.
|
||||
for entity in ignored_entities:
|
||||
ignore = ignore or util.subsequence(entity, snippet_seq)
|
||||
|
||||
# No NL entities found in snippet, then see if snippet is a substring of
|
||||
# the gold sequence
|
||||
if not ignore:
|
||||
snippet_length = len(snippet_seq)
|
||||
|
||||
# Iterate through gold sequence to see if it's a subsequence.
|
||||
for start_idx in range(len(query_copy) - snippet_length + 1):
|
||||
if query_copy[start_idx:start_idx +
|
||||
snippet_length] == snippet_seq:
|
||||
align = random.random() < prob_align
|
||||
|
||||
if align:
|
||||
prev_length = len(query_copy)
|
||||
|
||||
# At the start position of the snippet, replace with an
|
||||
# identifier.
|
||||
query_copy[start_idx] = snippet.name
|
||||
|
||||
# Then cut out the indices which were collapsed into
|
||||
# the snippet.
|
||||
query_copy = query_copy[:start_idx + 1] + \
|
||||
query_copy[start_idx + snippet_length:]
|
||||
|
||||
# Make sure the length is as expected
|
||||
assert len(query_copy) == prev_length - \
|
||||
(snippet_length - 1)
|
||||
|
||||
return query_copy
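# Illustrative behaviour (hypothetical snippet named "SNIPPET_0" whose sequence
# is ["a", "=", "1"]): with prob_align=1.0,
#   add_snippets_to_query([snippet], [], ["SELECT", "*", "WHERE", "a", "=", "1"])
# returns ["SELECT", "*", "WHERE", "SNIPPET_0"].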
|
||||
|
||||
|
||||
def execution_results(query, username, password, timeout=3):
|
||||
connection = pymysql.connect(user=username, password=password)
|
||||
|
||||
class TimeoutException(Exception):
|
||||
pass
|
||||
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutException
|
||||
|
||||
signal.signal(signal.SIGALRM, timeout_handler)
|
||||
|
||||
syntactic = True
|
||||
semantic = True
|
||||
|
||||
table = []
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
signal.alarm(timeout)
|
||||
try:
|
||||
cursor.execute("SET sql_mode='IGNORE_SPACE';")
|
||||
cursor.execute("use atis3;")
|
||||
cursor.execute(query)
|
||||
table = cursor.fetchall()
|
||||
cursor.close()
|
||||
except TimeoutException:
|
||||
signal.alarm(0)
|
||||
cursor.close()
|
||||
except pymysql.err.ProgrammingError:
|
||||
syntactic = False
|
||||
semantic = False
|
||||
cursor.close()
|
||||
except pymysql.err.InternalError:
|
||||
semantic = False
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
signal.alarm(0)
|
||||
signal.alarm(0)
|
||||
cursor.close()
|
||||
signal.alarm(0)
|
||||
|
||||
connection.close()
|
||||
|
||||
return (syntactic, semantic, sorted(table))
|
||||
|
||||
|
||||
def executable(query, username, password, timeout=2):
|
||||
return execution_results(query, username, password, timeout)[1]
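# Usage note (requires a local MySQL server with the atis3 database loaded;
# the credentials below are assumptions):
#   syntactic, semantic, table = execution_results("SELECT 1 ;", "user", "pw")
# executable() returns only the semantic flag of that tuple.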
|
||||
|
||||
|
||||
def fix_parentheses(sequence):
|
||||
num_left = sequence.count("(")
|
||||
num_right = sequence.count(")")
|
||||
|
||||
if num_right < num_left:
|
||||
fixed_sequence = sequence[:-1] + \
|
||||
[")" for _ in range(num_left - num_right)] + [sequence[-1]]
|
||||
return fixed_sequence
|
||||
|
||||
return sequence
|
|
@@ -0,0 +1,82 @@
|
|||
"""Tokenizers for natural language SQL queries, and lambda calculus."""
|
||||
import nltk
|
||||
import sqlparse
|
||||
|
||||
def nl_tokenize(string):
|
||||
"""Tokenizes a natural language string into tokens.
|
||||
|
||||
Inputs:
|
||||
string: the string to tokenize.
|
||||
Outputs:
|
||||
a list of tokens.
|
||||
|
||||
Assumes data is space-separated (this is true of ZC07 data in ATIS2/3).
|
||||
"""
|
||||
return nltk.word_tokenize(string)
|
||||
|
||||
def sql_tokenize(string):
|
||||
""" Tokenizes a SQL statement into tokens.
|
||||
|
||||
Inputs:
|
||||
string: string to tokenize.
|
||||
|
||||
Outputs:
|
||||
a list of tokens.
|
||||
"""
|
||||
tokens = []
|
||||
statements = sqlparse.parse(string)
|
||||
|
||||
# SQLparse gives you a list of statements.
|
||||
for statement in statements:
|
||||
# Flatten the tokens in each statement and add to the tokens list.
|
||||
flat_tokens = sqlparse.sql.TokenList(statement.tokens).flatten()
|
||||
for token in flat_tokens:
|
||||
strip_token = str(token).strip()
|
||||
if len(strip_token) > 0:
|
||||
tokens.append(strip_token)
|
||||
|
||||
newtokens = []
|
||||
keep = True
|
||||
for i, token in enumerate(tokens):
|
||||
if token == ".":
|
||||
newtoken = newtokens[-1] + "." + tokens[i + 1]
|
||||
newtokens = newtokens[:-1] + [newtoken]
|
||||
keep = False
|
||||
elif keep:
|
||||
newtokens.append(token)
|
||||
else:
|
||||
keep = True
|
||||
|
||||
return newtokens
|
||||
|
||||
def lambda_tokenize(string):
|
||||
""" Tokenizes a lambda-calculus statement into tokens.
|
||||
|
||||
Inputs:
|
||||
string: a lambda-calculus string
|
||||
|
||||
Outputs:
|
||||
a list of tokens.
|
||||
"""
|
||||
|
||||
space_separated = string.split(" ")
|
||||
|
||||
new_tokens = []
|
||||
|
||||
# Separate the string by spaces, then separate based on existence of ( or
|
||||
# ).
|
||||
for token in space_separated:
|
||||
tokens = []
|
||||
|
||||
current_token = ""
|
||||
for char in token:
|
||||
if char == ")" or char == "(":
|
||||
tokens.append(current_token)
|
||||
tokens.append(char)
|
||||
current_token = ""
|
||||
else:
|
||||
current_token += char
|
||||
tokens.append(current_token)
|
||||
new_tokens.extend([tok for tok in tokens if tok])
|
||||
|
||||
return new_tokens
|
|
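A minimal sketch of how the tokenizers above behave, assuming this file is importable as `tokenizers` (exact sqlparse output can vary slightly across versions):

```python
from tokenizers import sql_tokenize, lambda_tokenize

# "." pieces emitted by sqlparse are merged back into one qualified column name.
print(sql_tokenize("SELECT city.city_name FROM city ;"))
# e.g. ['SELECT', 'city.city_name', 'FROM', 'city', ';']

# Lambda-calculus strings are split on spaces and then on parentheses.
print(lambda_tokenize("( lambda $0 ( flight $0 ) )"))
# ['(', 'lambda', '$0', '(', 'flight', '$0', ')', ')']
```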
@ -0,0 +1,16 @@
|
|||
"""Contains various utility functions."""
|
||||
def subsequence(first_sequence, second_sequence):
|
||||
"""
|
||||
Returns whether the first sequence is a subsequence of the second sequence.
|
||||
|
||||
Inputs:
|
||||
first_sequence (list): A sequence.
|
||||
second_sequence (list): Another sequence.
|
||||
|
||||
Returns:
|
||||
Boolean indicating whether first_sequence is a subsequence of second_sequence.
|
||||
"""
|
||||
for startidx in range(len(second_sequence) - len(first_sequence) + 1):
|
||||
if second_sequence[startidx:startidx + len(first_sequence)] == first_sequence:
|
||||
return True
|
||||
return False
|
|
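A minimal sketch of the contiguous-subsequence check above (the module name `util` is an assumption):

```python
from util import subsequence

assert subsequence(["b", "c"], ["a", "b", "c", "d"])      # contiguous match
assert not subsequence(["c", "b"], ["a", "b", "c", "d"])  # order matters
assert subsequence([], ["a", "b"])                        # empty sequence always matches
```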
@ -0,0 +1,115 @@
|
|||
""" Contains the Utterance class. """
|
||||
|
||||
from . import sql_util
|
||||
from . import tokenizers
|
||||
|
||||
ANON_INPUT_KEY = "cleaned_nl"
|
||||
OUTPUT_KEY = "sql"
|
||||
|
||||
class Utterance:
|
||||
""" Utterance class. """
|
||||
def process_input_seq(self,
|
||||
anonymize,
|
||||
anonymizer,
|
||||
anon_tok_to_ent):
|
||||
assert not anon_tok_to_ent or anonymize
|
||||
assert not anonymize or anonymizer
|
||||
|
||||
if anonymize:
|
||||
assert anonymizer
|
||||
|
||||
self.input_seq_to_use = anonymizer.anonymize(
|
||||
self.original_input_seq, anon_tok_to_ent, ANON_INPUT_KEY, add_new_anon_toks=True)
|
||||
else:
|
||||
self.input_seq_to_use = self.original_input_seq
|
||||
|
||||
def process_gold_seq(self,
|
||||
output_sequences,
|
||||
nl_to_sql_dict,
|
||||
available_snippets,
|
||||
anonymize,
|
||||
anonymizer,
|
||||
anon_tok_to_ent):
|
||||
# Get entities in the input sequence:
|
||||
# anonymized entity types
|
||||
# other recognized entities (this includes "flight")
|
||||
entities_in_input = [
|
||||
[tok] for tok in self.input_seq_to_use if tok in anon_tok_to_ent]
|
||||
entities_in_input.extend(
|
||||
nl_to_sql_dict.get_sql_entities(
|
||||
self.input_seq_to_use))
|
||||
|
||||
# Get the shortest gold query (this is what we use to train)
|
||||
shortest_gold_and_results = min(output_sequences,
|
||||
key=lambda x: len(x[0]))
|
||||
|
||||
# Tokenize and anonymize it if necessary.
|
||||
self.original_gold_query = shortest_gold_and_results[0]
|
||||
self.gold_sql_results = shortest_gold_and_results[1]
|
||||
|
||||
self.contained_entities = entities_in_input
|
||||
|
||||
# Keep track of all gold queries and the resulting tables so that we can
|
||||
# give credit if it predicts a different correct sequence.
|
||||
self.all_gold_queries = output_sequences
|
||||
|
||||
self.anonymized_gold_query = self.original_gold_query
|
||||
if anonymize:
|
||||
self.anonymized_gold_query = anonymizer.anonymize(
|
||||
self.original_gold_query, anon_tok_to_ent, OUTPUT_KEY, add_new_anon_toks=False)
|
||||
|
||||
# Add snippets to it.
|
||||
self.gold_query_to_use = sql_util.add_snippets_to_query(
|
||||
available_snippets, entities_in_input, self.anonymized_gold_query)
|
||||
|
||||
def __init__(self,
|
||||
example,
|
||||
available_snippets,
|
||||
nl_to_sql_dict,
|
||||
params,
|
||||
anon_tok_to_ent={},
|
||||
anonymizer=None):
|
||||
# Get output and input sequences from the dictionary representation.
|
||||
output_sequences = example[OUTPUT_KEY]
|
||||
self.original_input_seq = tokenizers.nl_tokenize(example[params.input_key])
|
||||
self.available_snippets = available_snippets
|
||||
self.keep = False
|
||||
|
||||
# pruned_output_sequences = []
|
||||
# for sequence in output_sequences:
|
||||
# if len(sequence[0]) > 3:
|
||||
# pruned_output_sequences.append(sequence)
|
||||
|
||||
# output_sequences = pruned_output_sequences
|
||||
if len(output_sequences) > 0 and len(self.original_input_seq) > 0:
|
||||
# Only keep this example if there is at least one output sequence.
|
||||
self.keep = True
|
||||
if len(output_sequences) == 0 or len(self.original_input_seq) == 0:
|
||||
return
|
||||
|
||||
# Process the input sequence
|
||||
self.process_input_seq(params.anonymize,
|
||||
anonymizer,
|
||||
anon_tok_to_ent)
|
||||
|
||||
# Process the gold sequence
|
||||
self.process_gold_seq(output_sequences,
|
||||
nl_to_sql_dict,
|
||||
self.available_snippets,
|
||||
params.anonymize,
|
||||
anonymizer,
|
||||
anon_tok_to_ent)
|
||||
|
||||
def __str__(self):
|
||||
string = "Original input: " + " ".join(self.original_input_seq) + "\n"
|
||||
string += "Modified input: " + " ".join(self.input_seq_to_use) + "\n"
|
||||
string += "Original output: " + " ".join(self.original_gold_query) + "\n"
|
||||
string += "Modified output: " + " ".join(self.gold_query_to_use) + "\n"
|
||||
string += "Snippets:\n"
|
||||
for snippet in self.available_snippets:
|
||||
string += str(snippet) + "\n"
|
||||
return string
|
||||
|
||||
def length_valid(self, input_limit, output_limit):
|
||||
return (len(self.input_seq_to_use) < input_limit \
|
||||
and len(self.gold_query_to_use) < output_limit)
|
|
@ -0,0 +1,98 @@
|
|||
"""Contains class and methods for storing and computing a vocabulary from text."""
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
|
||||
# Special sequencing tokens.
|
||||
UNK_TOK = "_UNK" # Replaces out-of-vocabulary words.
|
||||
EOS_TOK = "_EOS" # Appended to the end of a sequence to indicate its end.
|
||||
DEL_TOK = ";"
|
||||
|
||||
|
||||
class Vocabulary:
|
||||
"""Vocabulary class: stores information about words in a corpus.
|
||||
|
||||
Members:
|
||||
functional_types (list of str): Functional vocabulary words, such as EOS.
|
||||
max_size (int): The maximum size of vocabulary to keep.
|
||||
min_occur (int): The minimum number of times a word should occur to keep it.
|
||||
id_to_token (list of str): Ordered list of word types.
|
||||
token_to_id (dict str->int): Maps from each unique word type to its index.
|
||||
"""
|
||||
def get_vocab(self, sequences, ignore_fn):
|
||||
"""Gets vocabulary from a list of sequences.
|
||||
|
||||
Inputs:
|
||||
sequences (list of list of str): Sequences from which to compute the vocabulary.
|
||||
ignore_fn (lambda str: bool): Function used to tell whether to ignore a
|
||||
token during computation of the vocabulary.
|
||||
|
||||
Returns:
|
||||
list of str, representing the unique word types in the vocabulary.
|
||||
"""
|
||||
type_counts = {}
|
||||
|
||||
for sequence in sequences:
|
||||
for token in sequence:
|
||||
if not ignore_fn(token):
|
||||
if token not in type_counts:
|
||||
type_counts[token] = 0
|
||||
type_counts[token] += 1
|
||||
|
||||
# Create sorted list of tokens, by their counts. Reverse so it is in order of
|
||||
# most frequent to least frequent.
|
||||
sorted_type_counts = sorted(sorted(type_counts.items()),
|
||||
key=operator.itemgetter(1))[::-1]
|
||||
|
||||
sorted_types = [typecount[0]
|
||||
for typecount in sorted_type_counts if typecount[1] >= self.min_occur]
|
||||
|
||||
# Append the necessary functional tokens.
|
||||
sorted_types = self.functional_types + sorted_types
|
||||
|
||||
# Cut off if vocab_size is set (nonnegative)
|
||||
if self.max_size >= 0:
|
||||
vocab = sorted_types[:min(self.max_size, len(sorted_types))]
|
||||
else:
|
||||
vocab = sorted_types
|
||||
|
||||
return vocab
|
||||
|
||||
def __init__(self,
|
||||
sequences,
|
||||
filename,
|
||||
functional_types=None,
|
||||
max_size=-1,
|
||||
min_occur=0,
|
||||
ignore_fn=lambda x: False):
|
||||
self.functional_types = functional_types
|
||||
self.max_size = max_size
|
||||
self.min_occur = min_occur
|
||||
|
||||
vocab = self.get_vocab(sequences, ignore_fn)
|
||||
|
||||
self.id_to_token = []
|
||||
self.token_to_id = {}
|
||||
|
||||
for i, word_type in enumerate(vocab):
|
||||
self.id_to_token.append(word_type)
|
||||
self.token_to_id[word_type] = i
|
||||
|
||||
# Load the previous vocab, if it exists.
|
||||
if os.path.exists(filename):
|
||||
infile = open(filename, 'rb')
|
||||
loaded_vocab = pickle.load(infile)
|
||||
infile.close()
|
||||
|
||||
print("Loaded vocabulary from " + str(filename))
|
||||
if loaded_vocab.id_to_token != self.id_to_token \
|
||||
or loaded_vocab.token_to_id != self.token_to_id:
|
||||
print("Loaded vocabulary is different than generated vocabulary.")
|
||||
else:
|
||||
print("Writing vocabulary to " + str(filename))
|
||||
outfile = open(filename, 'wb')
|
||||
pickle.dump(self, outfile)
|
||||
outfile.close()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.id_to_token)
|
|
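A minimal sketch of building a `Vocabulary` over a toy corpus (the module name and pickle path are assumptions; the vocabulary is pickled on first use and compared against the saved copy afterwards):

```python
from vocabulary import Vocabulary, UNK_TOK, EOS_TOK

sequences = [["show", "me", "flights"],
             ["show", "flights", "to", "denver"]]
vocab = Vocabulary(sequences,
                   filename="/tmp/toy_vocab.pkl",
                   functional_types=[UNK_TOK, EOS_TOK],
                   min_occur=1)

print(len(vocab))                    # 7: two functional tokens + five word types
print(vocab.token_to_id["flights"])  # integer id assigned to a corpus token
```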
@ -0,0 +1,36 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import sqlite3 as db
|
||||
import pprint
|
||||
|
||||
def process_schema(args):
|
||||
schema_name = args.db
|
||||
schema_path = os.path.join(args.root_path, schema_name, schema_name + '.sqlite')
|
||||
print("load db data from", schema_path)
|
||||
conn = db.connect(schema_path)
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
tables_name = cur.fetchall()
|
||||
table_name_lst = [tuple[0] for tuple in tables_name]
|
||||
print(table_name_lst)
|
||||
for table_name in table_name_lst:
|
||||
cur.execute("SELECT * FROM %s" % table_name)
|
||||
col_name_list = [tuple[0] for tuple in cur.description]
|
||||
print(table_name, col_name_list)
|
||||
tables = cur.fetchall()
|
||||
#print(tables)
|
||||
|
||||
|
||||
def main(args):
|
||||
process_schema(args)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--root_path", type=str, default="./data/database")
|
||||
parser.add_argument("--save_path", type=str, default="./database_content.json")
|
||||
parser.add_argument("--db", type=str, default="aircraft")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1 @@
|
|||
{"input": "northeast express regional airlines", "output": ["'2V'"]}
|
|
@ -0,0 +1,866 @@
|
|||
################################
|
||||
# val: number(float)/string(str)/sql(dict)
|
||||
# col_unit: (agg_id, col_id, isDistinct(bool))
|
||||
# val_unit: (unit_op, col_unit1, col_unit2)
|
||||
# table_unit: (table_type, col_unit/sql)
|
||||
# cond_unit: (not_op, op_id, val_unit, val1, val2)
|
||||
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
|
||||
# sql {
|
||||
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
|
||||
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
|
||||
# 'where': condition
|
||||
# 'groupBy': [col_unit1, col_unit2, ...]
|
||||
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
|
||||
# 'having': condition
|
||||
# 'limit': None/limit value
|
||||
# 'intersect': None/sql
|
||||
# 'except': None/sql
|
||||
# 'union': None/sql
|
||||
# }
|
||||
################################
|
||||
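# Illustrative example (an assumption added for clarity, not part of the
# original annotation): the query
#     SELECT count(*) FROM flight WHERE flight.airline_code = "AA"
# is represented roughly as
#     {'select': (False, [(3, (0, (0, '__all__', False), None))]),
#      'from': {'table_units': [('table_unit', '__flight__')], 'conds': []},
#      'where': [(False, 2, (0, (0, '__flight.airline_code__', False), None), '"AA"', None)],
#      'groupBy': [], 'orderBy': [], 'having': [], 'limit': None,
#      'intersect': None, 'except': None, 'union': None}
# where 3 == AGG_OPS.index('count') and 2 == WHERE_OPS.index('=').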
|
||||
import os, sys
|
||||
import json
|
||||
import sqlite3
|
||||
import traceback
|
||||
import argparse
|
||||
|
||||
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
|
||||
|
||||
# Flag to disable value evaluation
|
||||
DISABLE_VALUE = True
|
||||
# Flag to disable distinct in select evaluation
|
||||
DISABLE_DISTINCT = True
|
||||
|
||||
|
||||
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
|
||||
JOIN_KEYWORDS = ('join', 'on', 'as')
|
||||
|
||||
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
UNIT_OPS = ('none', '-', '+', "*", '/')
|
||||
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
|
||||
TABLE_TYPE = {
|
||||
'sql': "sql",
|
||||
'table_unit': "table_unit",
|
||||
}
|
||||
|
||||
COND_OPS = ('and', 'or')
|
||||
SQL_OPS = ('intersect', 'union', 'except')
|
||||
ORDER_OPS = ('desc', 'asc')
|
||||
|
||||
|
||||
HARDNESS = {
|
||||
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
|
||||
"component2": ('except', 'union', 'intersect')
|
||||
}
|
||||
|
||||
|
||||
def condition_has_or(conds):
|
||||
return 'or' in conds[1::2]
|
||||
|
||||
|
||||
def condition_has_like(conds):
|
||||
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
|
||||
|
||||
|
||||
def condition_has_sql(conds):
|
||||
for cond_unit in conds[::2]:
|
||||
val1, val2 = cond_unit[3], cond_unit[4]
|
||||
if val1 is not None and type(val1) is dict:
|
||||
return True
|
||||
if val2 is not None and type(val2) is dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def val_has_op(val_unit):
|
||||
return val_unit[0] != UNIT_OPS.index('none')
|
||||
|
||||
|
||||
def has_agg(unit):
|
||||
return unit[0] != AGG_OPS.index('none')
|
||||
|
||||
|
||||
def accuracy(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def recall(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def F1(acc, rec):
|
||||
if (acc + rec) == 0:
|
||||
return 0
|
||||
return (2. * acc * rec) / (acc + rec)
|
||||
|
||||
|
||||
def get_scores(count, pred_total, label_total):
|
||||
if pred_total != label_total:
|
||||
return 0,0,0
|
||||
elif count == pred_total:
|
||||
return 1,1,1
|
||||
return 0,0,0
|
||||
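# Note (added for clarity): clause-level scoring is all-or-nothing, e.g.
# get_scores(1, 2, 2) returns (0, 0, 0) because only one of the two predicted
# units matched, while get_scores(2, 2, 2) returns (1, 1, 1).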
|
||||
|
||||
def eval_sel(pred, label):
|
||||
pred_sel = pred['select'][1]
|
||||
label_sel = label['select'][1]
|
||||
label_wo_agg = [unit[1] for unit in label_sel]
|
||||
pred_total = len(pred_sel)
|
||||
label_total = len(label_sel)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_sel:
|
||||
if unit in label_sel:
|
||||
cnt += 1
|
||||
label_sel.remove(unit)
|
||||
if unit[1] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[1])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_where(pred, label):
|
||||
pred_conds = [unit for unit in pred['where'][::2]]
|
||||
label_conds = [unit for unit in label['where'][::2]]
|
||||
label_wo_agg = [unit[2] for unit in label_conds]
|
||||
pred_total = len(pred_conds)
|
||||
label_total = len(label_conds)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_conds:
|
||||
if unit in label_conds:
|
||||
cnt += 1
|
||||
label_conds.remove(unit)
|
||||
if unit[2] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[2])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_group(pred, label):
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
pred_total = len(pred_cols)
|
||||
label_total = len(label_cols)
|
||||
cnt = 0
|
||||
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
|
||||
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
|
||||
for col in pred_cols:
|
||||
if col in label_cols:
|
||||
cnt += 1
|
||||
label_cols.remove(col)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_having(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['groupBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['groupBy']) > 0:
|
||||
label_total = 1
|
||||
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
if pred_total == label_total == 1 \
|
||||
and pred_cols == label_cols \
|
||||
and pred['having'] == label['having']:
|
||||
cnt = 1
|
||||
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_order(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['orderBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['orderBy']) > 0:
|
||||
label_total = 1
|
||||
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
|
||||
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
|
||||
cnt = 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_and_or(pred, label):
|
||||
pred_ao = pred['where'][1::2]
|
||||
label_ao = label['where'][1::2]
|
||||
pred_ao = set(pred_ao)
|
||||
label_ao = set(label_ao)
|
||||
|
||||
if pred_ao == label_ao:
|
||||
return 1,1,1
|
||||
return len(pred_ao),len(label_ao),0
|
||||
|
||||
|
||||
def get_nestedSQL(sql):
|
||||
nested = []
|
||||
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
|
||||
if type(cond_unit[3]) is dict:
|
||||
nested.append(cond_unit[3])
|
||||
if type(cond_unit[4]) is dict:
|
||||
nested.append(cond_unit[4])
|
||||
if sql['intersect'] is not None:
|
||||
nested.append(sql['intersect'])
|
||||
if sql['except'] is not None:
|
||||
nested.append(sql['except'])
|
||||
if sql['union'] is not None:
|
||||
nested.append(sql['union'])
|
||||
return nested
|
||||
|
||||
|
||||
def eval_nested(pred, label):
|
||||
label_total = 0
|
||||
pred_total = 0
|
||||
cnt = 0
|
||||
if pred is not None:
|
||||
pred_total += 1
|
||||
if label is not None:
|
||||
label_total += 1
|
||||
if pred is not None and label is not None:
|
||||
cnt += Evaluator().eval_exact_match(pred, label)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_IUEN(pred, label):
|
||||
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
|
||||
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
|
||||
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
|
||||
label_total = lt1 + lt2 + lt3
|
||||
pred_total = pt1 + pt2 + pt3
|
||||
cnt = cnt1 + cnt2 + cnt3
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def get_keywords(sql):
|
||||
res = set()
|
||||
if len(sql['where']) > 0:
|
||||
res.add('where')
|
||||
if len(sql['groupBy']) > 0:
|
||||
res.add('group')
|
||||
if len(sql['having']) > 0:
|
||||
res.add('having')
|
||||
if len(sql['orderBy']) > 0:
|
||||
res.add(sql['orderBy'][0])
|
||||
res.add('order')
|
||||
if sql['limit'] is not None:
|
||||
res.add('limit')
|
||||
if sql['except'] is not None:
|
||||
res.add('except')
|
||||
if sql['union'] is not None:
|
||||
res.add('union')
|
||||
if sql['intersect'] is not None:
|
||||
res.add('intersect')
|
||||
|
||||
# or keyword
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
if len([token for token in ao if token == 'or']) > 0:
|
||||
res.add('or')
|
||||
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
# not keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
|
||||
res.add('not')
|
||||
|
||||
# in keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
|
||||
res.add('in')
|
||||
|
||||
# like keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
|
||||
res.add('like')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def eval_keywords(pred, label):
|
||||
pred_keywords = get_keywords(pred)
|
||||
label_keywords = get_keywords(label)
|
||||
pred_total = len(pred_keywords)
|
||||
label_total = len(label_keywords)
|
||||
cnt = 0
|
||||
|
||||
for k in pred_keywords:
|
||||
if k in label_keywords:
|
||||
cnt += 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def count_agg(units):
|
||||
return len([unit for unit in units if has_agg(unit)])
|
||||
|
||||
|
||||
def count_component1(sql):
|
||||
count = 0
|
||||
if len(sql['where']) > 0:
|
||||
count += 1
|
||||
if len(sql['groupBy']) > 0:
|
||||
count += 1
|
||||
if len(sql['orderBy']) > 0:
|
||||
count += 1
|
||||
if sql['limit'] is not None:
|
||||
count += 1
|
||||
if len(sql['from']['table_units']) > 0: # JOIN
|
||||
count += len(sql['from']['table_units']) - 1
|
||||
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
count += len([token for token in ao if token == 'or'])
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def count_component2(sql):
|
||||
nested = get_nestedSQL(sql)
|
||||
return len(nested)
|
||||
|
||||
|
||||
def count_others(sql):
|
||||
count = 0
|
||||
# number of aggregation
|
||||
agg_count = count_agg(sql['select'][1])
|
||||
agg_count += count_agg(sql['where'][::2])
|
||||
agg_count += count_agg(sql['groupBy'])
|
||||
if len(sql['orderBy']) > 0:
|
||||
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
|
||||
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
|
||||
agg_count += count_agg(sql['having'])
|
||||
if agg_count > 1:
|
||||
count += 1
|
||||
|
||||
# number of select columns
|
||||
if len(sql['select'][1]) > 1:
|
||||
count += 1
|
||||
|
||||
# number of where conditions
|
||||
if len(sql['where']) > 1:
|
||||
count += 1
|
||||
|
||||
# number of group by clauses
|
||||
if len(sql['groupBy']) > 1:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""A simple evaluator"""
|
||||
def __init__(self):
|
||||
self.partial_scores = None
|
||||
|
||||
def eval_hardness(self, sql):
|
||||
count_comp1_ = count_component1(sql)
|
||||
count_comp2_ = count_component2(sql)
|
||||
count_others_ = count_others(sql)
|
||||
|
||||
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
|
||||
return "easy"
|
||||
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
|
||||
return "medium"
|
||||
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
|
||||
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
|
||||
return "hard"
|
||||
else:
|
||||
return "extra"
|
||||
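# Illustrative example (an assumption, for clarity): a bare "SELECT name FROM city"
# has component1 == component2 == others == 0 and is rated "easy"; adding a WHERE
# clause plus ORDER BY ... LIMIT raises component1 to 3, which falls into "hard"
# as long as no nested INTERSECT/UNION/EXCEPT query is involved.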
|
||||
def eval_exact_match(self, pred, label):
|
||||
partial_scores = self.eval_partial_match(pred, label)
|
||||
self.partial_scores = partial_scores
|
||||
|
||||
for _, score in partial_scores.items():
|
||||
if score['f1'] != 1:
|
||||
return 0
|
||||
if len(label['from']['table_units']) > 0:
|
||||
label_tables = sorted(label['from']['table_units'])
|
||||
pred_tables = sorted(pred['from']['table_units'])
|
||||
return label_tables == pred_tables
|
||||
return 1
|
||||
|
||||
def eval_partial_match(self, pred, label):
|
||||
res = {}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_group(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_having(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_order(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_and_or(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_IUEN(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_keywords(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def isValidSQL(sql, db):
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(sql)
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def print_scores(scores, etype):
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all']
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
|
||||
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
|
||||
counts = [scores[level]['count'] for level in levels]
|
||||
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
print('===================== EXECUTION ACCURACY =====================')
|
||||
this_scores = [scores[level]['exec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
print('\n====================== EXACT MATCHING ACCURACY =====================')
|
||||
exact_scores = [scores[level]['exact'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
|
||||
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING F1 --------------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
|
||||
def evaluate(gold, predict, db_dir, etype, kmaps):
|
||||
with open(gold) as f:
|
||||
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
|
||||
with open(predict) as f:
|
||||
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
|
||||
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
|
||||
evaluator = Evaluator()
|
||||
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all']
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
entries = []
|
||||
scores = {}
|
||||
|
||||
for level in levels:
|
||||
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
|
||||
scores[level]['exec'] = 0
|
||||
for type_ in partial_types:
|
||||
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
|
||||
|
||||
eval_err_num = 0
|
||||
for p, g in zip(plist, glist):
|
||||
p_str = p[0]
|
||||
g_str, db = g
|
||||
db_name = db
|
||||
db = os.path.join(db_dir, db, db + ".sqlite")
|
||||
schema = Schema(get_schema(db))
|
||||
g_sql = get_sql(schema, g_str)
|
||||
hardness = evaluator.eval_hardness(g_sql)
|
||||
scores[hardness]['count'] += 1
|
||||
scores['all']['count'] += 1
|
||||
|
||||
try:
|
||||
p_sql = get_sql(schema, p_str)
|
||||
except:
|
||||
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
|
||||
p_sql = {
|
||||
"except": None,
|
||||
"from": {
|
||||
"conds": [],
|
||||
"table_units": []
|
||||
},
|
||||
"groupBy": [],
|
||||
"having": [],
|
||||
"intersect": None,
|
||||
"limit": None,
|
||||
"orderBy": [],
|
||||
"select": [
|
||||
False,
|
||||
[]
|
||||
],
|
||||
"union": None,
|
||||
"where": []
|
||||
}
|
||||
eval_err_num += 1
|
||||
print("eval_err_num:{}".format(eval_err_num))
|
||||
|
||||
# rebuild sql for value evaluation
|
||||
kmap = kmaps[db_name]
|
||||
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
|
||||
g_sql = rebuild_sql_val(g_sql)
|
||||
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
|
||||
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
|
||||
p_sql = rebuild_sql_val(p_sql)
|
||||
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
|
||||
if exec_score:
|
||||
scores[hardness]['exec'] += 1
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
|
||||
partial_scores = evaluator.partial_scores
|
||||
if exact_score == 0:
|
||||
print("{} pred: {}".format(hardness,p_str))
|
||||
print("{} gold: {}".format(hardness,g_str))
|
||||
print("")
|
||||
scores[hardness]['exact'] += exact_score
|
||||
scores['all']['exact'] += exact_score
|
||||
for type_ in partial_types:
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores[hardness]['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores[hardness]['partial'][type_]['rec_count'] += 1
|
||||
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores['all']['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores['all']['partial'][type_]['rec_count'] += 1
|
||||
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
|
||||
entries.append({
|
||||
'predictSQL': p_str,
|
||||
'goldSQL': g_str,
|
||||
'hardness': hardness,
|
||||
'exact': exact_score,
|
||||
'partial': partial_scores
|
||||
})
|
||||
|
||||
for level in levels:
|
||||
if scores[level]['count'] == 0:
|
||||
continue
|
||||
if etype in ["all", "exec"]:
|
||||
scores[level]['exec'] /= scores[level]['count']
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
scores[level]['exact'] /= scores[level]['count']
|
||||
for type_ in partial_types:
|
||||
if scores[level]['partial'][type_]['acc_count'] == 0:
|
||||
scores[level]['partial'][type_]['acc'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
|
||||
scores[level]['partial'][type_]['acc_count'] * 1.0
|
||||
if scores[level]['partial'][type_]['rec_count'] == 0:
|
||||
scores[level]['partial'][type_]['rec'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
|
||||
scores[level]['partial'][type_]['rec_count'] * 1.0
|
||||
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
|
||||
scores[level]['partial'][type_]['f1'] = 1
|
||||
else:
|
||||
scores[level]['partial'][type_]['f1'] = \
|
||||
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
|
||||
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
|
||||
|
||||
print_scores(scores, etype)
|
||||
|
||||
|
||||
def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||
"""
|
||||
return 1 if the values between prediction and gold are matching
|
||||
in the corresponding index. Currently does not support multiple col_unit (pairs).
|
||||
"""
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(p_str)
|
||||
p_res = cursor.fetchall()
|
||||
except:
|
||||
return False
|
||||
|
||||
cursor.execute(g_str)
|
||||
q_res = cursor.fetchall()
|
||||
|
||||
def res_map(res, val_units):
|
||||
rmap = {}
|
||||
for idx, val_unit in enumerate(val_units):
|
||||
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
|
||||
rmap[key] = [r[idx] for r in res]
|
||||
return rmap
|
||||
|
||||
p_val_units = [unit[1] for unit in pred['select'][1]]
|
||||
q_val_units = [unit[1] for unit in gold['select'][1]]
|
||||
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
|
||||
|
||||
|
||||
# Rebuild SQL functions for value evaluation
|
||||
def rebuild_cond_unit_val(cond_unit):
|
||||
if cond_unit is None or not DISABLE_VALUE:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
if type(val1) is not dict:
|
||||
val1 = None
|
||||
else:
|
||||
val1 = rebuild_sql_val(val1)
|
||||
if type(val2) is not dict:
|
||||
val2 = None
|
||||
else:
|
||||
val2 = rebuild_sql_val(val2)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_val(condition):
|
||||
if condition is None or not DISABLE_VALUE:
|
||||
return condition
|
||||
|
||||
res = []
|
||||
for idx, it in enumerate(condition):
|
||||
if idx % 2 == 0:
|
||||
res.append(rebuild_cond_unit_val(it))
|
||||
else:
|
||||
res.append(it)
|
||||
return res
|
||||
|
||||
|
||||
def rebuild_sql_val(sql):
|
||||
if sql is None or not DISABLE_VALUE:
|
||||
return sql
|
||||
|
||||
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
|
||||
sql['having'] = rebuild_condition_val(sql['having'])
|
||||
sql['where'] = rebuild_condition_val(sql['where'])
|
||||
sql['intersect'] = rebuild_sql_val(sql['intersect'])
|
||||
sql['except'] = rebuild_sql_val(sql['except'])
|
||||
sql['union'] = rebuild_sql_val(sql['union'])
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
# Rebuild SQL functions for foreign key evaluation
|
||||
def build_valid_col_units(table_units, schema):
|
||||
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
|
||||
prefixs = [col_id[:-2] for col_id in col_ids]
|
||||
valid_col_units= []
|
||||
for value in schema.idMap.values():
|
||||
if '.' in value and value[:value.index('.')] in prefixs:
|
||||
valid_col_units.append(value)
|
||||
return valid_col_units
|
||||
|
||||
|
||||
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
|
||||
if col_unit is None:
|
||||
return col_unit
|
||||
|
||||
agg_id, col_id, distinct = col_unit
|
||||
if col_id in kmap and col_id in valid_col_units:
|
||||
col_id = kmap[col_id]
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return agg_id, col_id, distinct
|
||||
|
||||
|
||||
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
|
||||
if val_unit is None:
|
||||
return val_unit
|
||||
|
||||
unit_op, col_unit1, col_unit2 = val_unit
|
||||
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
|
||||
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
|
||||
return unit_op, col_unit1, col_unit2
|
||||
|
||||
|
||||
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
|
||||
if table_unit is None:
|
||||
return table_unit
|
||||
|
||||
table_type, col_unit_or_sql = table_unit
|
||||
if isinstance(col_unit_or_sql, tuple):
|
||||
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
|
||||
return table_type, col_unit_or_sql
|
||||
|
||||
|
||||
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
|
||||
if cond_unit is None:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_col(valid_col_units, condition, kmap):
|
||||
for idx in range(len(condition)):
|
||||
if idx % 2 == 0:
|
||||
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
|
||||
return condition
|
||||
|
||||
|
||||
def rebuild_select_col(valid_col_units, sel, kmap):
|
||||
if sel is None:
|
||||
return sel
|
||||
distinct, _list = sel
|
||||
new_list = []
|
||||
for it in _list:
|
||||
agg_id, val_unit = it
|
||||
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return distinct, new_list
|
||||
|
||||
|
||||
def rebuild_from_col(valid_col_units, from_, kmap):
|
||||
if from_ is None:
|
||||
return from_
|
||||
|
||||
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
|
||||
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
|
||||
return from_
|
||||
|
||||
|
||||
def rebuild_group_by_col(valid_col_units, group_by, kmap):
|
||||
if group_by is None:
|
||||
return group_by
|
||||
|
||||
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
|
||||
|
||||
|
||||
def rebuild_order_by_col(valid_col_units, order_by, kmap):
|
||||
if order_by is None or len(order_by) == 0:
|
||||
return order_by
|
||||
|
||||
direction, val_units = order_by
|
||||
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
|
||||
return direction, new_val_units
|
||||
|
||||
|
||||
def rebuild_sql_col(valid_col_units, sql, kmap):
|
||||
if sql is None:
|
||||
return sql
|
||||
|
||||
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
|
||||
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
|
||||
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
|
||||
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
|
||||
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
|
||||
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
|
||||
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
|
||||
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
|
||||
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def build_foreign_key_map(entry):
|
||||
cols_orig = entry["column_names_original"]
|
||||
tables_orig = entry["table_names_original"]
|
||||
|
||||
# rebuild cols corresponding to idmap in Schema
|
||||
cols = []
|
||||
for col_orig in cols_orig:
|
||||
if col_orig[0] >= 0:
|
||||
t = tables_orig[col_orig[0]]
|
||||
c = col_orig[1]
|
||||
cols.append("__" + t.lower() + "." + c.lower() + "__")
|
||||
else:
|
||||
cols.append("__all__")
|
||||
|
||||
def keyset_in_list(k1, k2, k_list):
|
||||
for k_set in k_list:
|
||||
if k1 in k_set or k2 in k_set:
|
||||
return k_set
|
||||
new_k_set = set()
|
||||
k_list.append(new_k_set)
|
||||
return new_k_set
|
||||
|
||||
foreign_key_list = []
|
||||
foreign_keys = entry["foreign_keys"]
|
||||
for fkey in foreign_keys:
|
||||
key1, key2 = fkey
|
||||
key_set = keyset_in_list(key1, key2, foreign_key_list)
|
||||
key_set.add(key1)
|
||||
key_set.add(key2)
|
||||
|
||||
foreign_key_map = {}
|
||||
for key_set in foreign_key_list:
|
||||
sorted_list = sorted(list(key_set))
|
||||
midx = sorted_list[0]
|
||||
for idx in sorted_list:
|
||||
foreign_key_map[cols[idx]] = cols[midx]
|
||||
|
||||
return foreign_key_map
|
||||
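# Illustrative example (hypothetical schema): if concert.stadium_id references
# stadium.stadium_id, both column ids land in the same key set and map to one
# canonical column (the one with the smaller column index), e.g.
#     {'__stadium.stadium_id__': '__stadium.stadium_id__',
#      '__concert.stadium_id__': '__stadium.stadium_id__'}
# so rebuild_sql_col treats the two columns as interchangeable during matching.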
|
||||
|
||||
def build_foreign_key_map_from_json(table):
|
||||
with open(table) as f:
|
||||
data = json.load(f)
|
||||
tables = {}
|
||||
for entry in data:
|
||||
tables[entry['db_id']] = build_foreign_key_map(entry)
|
||||
return tables
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--gold', dest='gold', type=str)
|
||||
parser.add_argument('--pred', dest='pred', type=str)
|
||||
parser.add_argument('--db', dest='db', type=str)
|
||||
parser.add_argument('--table', dest='table', type=str)
|
||||
parser.add_argument('--etype', dest='etype', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
gold = args.gold
|
||||
pred = args.pred
|
||||
db_dir = args.db
|
||||
table = args.table
|
||||
etype = args.etype
|
||||
|
||||
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
|
||||
|
||||
kmaps = build_foreign_key_map_from_json(table)
|
||||
|
||||
evaluate(gold, pred, db_dir, etype, kmaps)
|
|
@ -0,0 +1,938 @@
|
|||
################################
|
||||
# val: number(float)/string(str)/sql(dict)
|
||||
# col_unit: (agg_id, col_id, isDistinct(bool))
|
||||
# val_unit: (unit_op, col_unit1, col_unit2)
|
||||
# table_unit: (table_type, col_unit/sql)
|
||||
# cond_unit: (not_op, op_id, val_unit, val1, val2)
|
||||
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
|
||||
# sql {
|
||||
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
|
||||
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
|
||||
# 'where': condition
|
||||
# 'groupBy': [col_unit1, col_unit2, ...]
|
||||
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
|
||||
# 'having': condition
|
||||
# 'limit': None/limit value
|
||||
# 'intersect': None/sql
|
||||
# 'except': None/sql
|
||||
# 'union': None/sql
|
||||
# }
|
||||
################################
|
||||
|
||||
import os, sys
|
||||
import json
|
||||
import sqlite3
|
||||
import traceback
|
||||
import argparse
|
||||
|
||||
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
|
||||
|
||||
# Flag to disable value evaluation
|
||||
DISABLE_VALUE = True
|
||||
# Flag to disable distinct in select evaluation
|
||||
DISABLE_DISTINCT = True
|
||||
|
||||
|
||||
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
|
||||
JOIN_KEYWORDS = ('join', 'on', 'as')
|
||||
|
||||
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
UNIT_OPS = ('none', '-', '+', "*", '/')
|
||||
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
|
||||
TABLE_TYPE = {
|
||||
'sql': "sql",
|
||||
'table_unit': "table_unit",
|
||||
}
|
||||
|
||||
COND_OPS = ('and', 'or')
|
||||
SQL_OPS = ('intersect', 'union', 'except')
|
||||
ORDER_OPS = ('desc', 'asc')
|
||||
|
||||
|
||||
HARDNESS = {
|
||||
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
|
||||
"component2": ('except', 'union', 'intersect')
|
||||
}
|
||||
|
||||
|
||||
def condition_has_or(conds):
|
||||
return 'or' in conds[1::2]
|
||||
|
||||
|
||||
def condition_has_like(conds):
|
||||
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
|
||||
|
||||
|
||||
def condition_has_sql(conds):
|
||||
for cond_unit in conds[::2]:
|
||||
val1, val2 = cond_unit[3], cond_unit[4]
|
||||
if val1 is not None and type(val1) is dict:
|
||||
return True
|
||||
if val2 is not None and type(val2) is dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def val_has_op(val_unit):
|
||||
return val_unit[0] != UNIT_OPS.index('none')
|
||||
|
||||
|
||||
def has_agg(unit):
|
||||
return unit[0] != AGG_OPS.index('none')
|
||||
|
||||
|
||||
def accuracy(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def recall(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def F1(acc, rec):
|
||||
if (acc + rec) == 0:
|
||||
return 0
|
||||
return (2. * acc * rec) / (acc + rec)
|
||||
|
||||
|
||||
def get_scores(count, pred_total, label_total):
|
||||
if pred_total != label_total:
|
||||
return 0,0,0
|
||||
elif count == pred_total:
|
||||
return 1,1,1
|
||||
return 0,0,0
|
||||
|
||||
|
||||
def eval_sel(pred, label):
|
||||
pred_sel = pred['select'][1]
|
||||
label_sel = label['select'][1]
|
||||
label_wo_agg = [unit[1] for unit in label_sel]
|
||||
pred_total = len(pred_sel)
|
||||
label_total = len(label_sel)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_sel:
|
||||
if unit in label_sel:
|
||||
cnt += 1
|
||||
label_sel.remove(unit)
|
||||
if unit[1] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[1])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_where(pred, label):
|
||||
pred_conds = [unit for unit in pred['where'][::2]]
|
||||
label_conds = [unit for unit in label['where'][::2]]
|
||||
label_wo_agg = [unit[2] for unit in label_conds]
|
||||
pred_total = len(pred_conds)
|
||||
label_total = len(label_conds)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_conds:
|
||||
if unit in label_conds:
|
||||
cnt += 1
|
||||
label_conds.remove(unit)
|
||||
if unit[2] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[2])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_group(pred, label):
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
pred_total = len(pred_cols)
|
||||
label_total = len(label_cols)
|
||||
cnt = 0
|
||||
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
|
||||
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
|
||||
for col in pred_cols:
|
||||
if col in label_cols:
|
||||
cnt += 1
|
||||
label_cols.remove(col)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_having(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['groupBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['groupBy']) > 0:
|
||||
label_total = 1
|
||||
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
if pred_total == label_total == 1 \
|
||||
and pred_cols == label_cols \
|
||||
and pred['having'] == label['having']:
|
||||
cnt = 1
|
||||
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_order(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['orderBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['orderBy']) > 0:
|
||||
label_total = 1
|
||||
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
|
||||
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
|
||||
cnt = 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_and_or(pred, label):
|
||||
pred_ao = pred['where'][1::2]
|
||||
label_ao = label['where'][1::2]
|
||||
pred_ao = set(pred_ao)
|
||||
label_ao = set(label_ao)
|
||||
|
||||
if pred_ao == label_ao:
|
||||
return 1,1,1
|
||||
return len(pred_ao),len(label_ao),0
|
||||
|
||||
|
||||
def get_nestedSQL(sql):
|
||||
nested = []
|
||||
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
|
||||
if type(cond_unit[3]) is dict:
|
||||
nested.append(cond_unit[3])
|
||||
if type(cond_unit[4]) is dict:
|
||||
nested.append(cond_unit[4])
|
||||
if sql['intersect'] is not None:
|
||||
nested.append(sql['intersect'])
|
||||
if sql['except'] is not None:
|
||||
nested.append(sql['except'])
|
||||
if sql['union'] is not None:
|
||||
nested.append(sql['union'])
|
||||
return nested
|
||||
|
||||
|
||||
def eval_nested(pred, label):
|
||||
label_total = 0
|
||||
pred_total = 0
|
||||
cnt = 0
|
||||
if pred is not None:
|
||||
pred_total += 1
|
||||
if label is not None:
|
||||
label_total += 1
|
||||
if pred is not None and label is not None:
|
||||
cnt += Evaluator().eval_exact_match(pred, label)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_IUEN(pred, label):
|
||||
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
|
||||
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
|
||||
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
|
||||
label_total = lt1 + lt2 + lt3
|
||||
pred_total = pt1 + pt2 + pt3
|
||||
cnt = cnt1 + cnt2 + cnt3
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def get_keywords(sql):
|
||||
res = set()
|
||||
if len(sql['where']) > 0:
|
||||
res.add('where')
|
||||
if len(sql['groupBy']) > 0:
|
||||
res.add('group')
|
||||
if len(sql['having']) > 0:
|
||||
res.add('having')
|
||||
if len(sql['orderBy']) > 0:
|
||||
res.add(sql['orderBy'][0])
|
||||
res.add('order')
|
||||
if sql['limit'] is not None:
|
||||
res.add('limit')
|
||||
if sql['except'] is not None:
|
||||
res.add('except')
|
||||
if sql['union'] is not None:
|
||||
res.add('union')
|
||||
if sql['intersect'] is not None:
|
||||
res.add('intersect')
|
||||
|
||||
# or keyword
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
if len([token for token in ao if token == 'or']) > 0:
|
||||
res.add('or')
|
||||
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
# not keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
|
||||
res.add('not')
|
||||
|
||||
# in keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
|
||||
res.add('in')
|
||||
|
||||
# like keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
|
||||
res.add('like')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def eval_keywords(pred, label):
|
||||
pred_keywords = get_keywords(pred)
|
||||
label_keywords = get_keywords(label)
|
||||
pred_total = len(pred_keywords)
|
||||
label_total = len(label_keywords)
|
||||
cnt = 0
|
||||
|
||||
for k in pred_keywords:
|
||||
if k in label_keywords:
|
||||
cnt += 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def count_agg(units):
|
||||
return len([unit for unit in units if has_agg(unit)])
|
||||
|
||||
|
||||
def count_component1(sql):
|
||||
count = 0
|
||||
if len(sql['where']) > 0:
|
||||
count += 1
|
||||
if len(sql['groupBy']) > 0:
|
||||
count += 1
|
||||
if len(sql['orderBy']) > 0:
|
||||
count += 1
|
||||
if sql['limit'] is not None:
|
||||
count += 1
|
||||
if len(sql['from']['table_units']) > 0: # JOIN
|
||||
count += len(sql['from']['table_units']) - 1
|
||||
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
count += len([token for token in ao if token == 'or'])
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def count_component2(sql):
|
||||
nested = get_nestedSQL(sql)
|
||||
return len(nested)
|
||||
|
||||
|
||||
def count_others(sql):
|
||||
count = 0
|
||||
# number of aggregation
|
||||
agg_count = count_agg(sql['select'][1])
|
||||
agg_count += count_agg(sql['where'][::2])
|
||||
agg_count += count_agg(sql['groupBy'])
|
||||
if len(sql['orderBy']) > 0:
|
||||
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
|
||||
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
|
||||
agg_count += count_agg(sql['having'])
|
||||
if agg_count > 1:
|
||||
count += 1
|
||||
|
||||
# number of select columns
|
||||
if len(sql['select'][1]) > 1:
|
||||
count += 1
|
||||
|
||||
# number of where conditions
|
||||
if len(sql['where']) > 1:
|
||||
count += 1
|
||||
|
||||
# number of group by clauses
|
||||
if len(sql['groupBy']) > 1:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""A simple evaluator"""
|
||||
def __init__(self):
|
||||
self.partial_scores = None
|
||||
|
||||
def eval_hardness(self, sql):
|
||||
count_comp1_ = count_component1(sql)
|
||||
count_comp2_ = count_component2(sql)
|
||||
count_others_ = count_others(sql)
|
||||
|
||||
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
|
||||
return "easy"
|
||||
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
|
||||
return "medium"
|
||||
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
|
||||
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
|
||||
return "hard"
|
||||
else:
|
||||
return "extra"
|
||||
|
||||
def eval_exact_match(self, pred, label):
|
||||
partial_scores = self.eval_partial_match(pred, label)
|
||||
self.partial_scores = partial_scores
|
||||
|
||||
for key, score in partial_scores.items():
|
||||
if score['f1'] != 1:
|
||||
return 0
|
||||
|
||||
if len(label['from']['table_units']) > 0:
|
||||
label_tables = sorted(label['from']['table_units'])
|
||||
pred_tables = sorted(pred['from']['table_units'])
|
||||
return label_tables == pred_tables
|
||||
return 1
|
||||
|
||||
def eval_partial_match(self, pred, label):
|
||||
res = {}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_group(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_having(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_order(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_and_or(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_IUEN(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_keywords(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def isValidSQL(sql, db):
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(sql)
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def print_scores(scores, etype):
|
||||
turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all', "joint_all"]
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
|
||||
print("{:20} {:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
|
||||
counts = [scores[level]['count'] for level in levels]
|
||||
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
print('===================== EXECUTION ACCURACY =====================')
|
||||
this_scores = [scores[level]['exec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
print('\n====================== EXACT MATCHING ACCURACY =====================')
|
||||
exact_scores = [scores[level]['exact'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
|
||||
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING F1 --------------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
|
||||
print("\n\n{:20} {:20} {:20} {:20} {:20} {:20}".format("", *turns))
|
||||
counts = [scores[turn]['count'] for turn in turns]
|
||||
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
print('===================== TURN EXECUTION ACCURACY =====================')
|
||||
this_scores = [scores[turn]['exec'] for turn in turns]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
print('\n====================== TURN EXACT MATCHING ACCURACY =====================')
|
||||
exact_scores = [scores[turn]['exact'] for turn in turns]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
|
||||
|
||||
def evaluate(gold, predict, db_dir, etype, kmaps):
|
||||
with open(gold) as f:
|
||||
glist = []
|
||||
gseq_one = []
|
||||
for l in f.readlines():
|
||||
if len(l.strip()) == 0:
|
||||
glist.append(gseq_one)
|
||||
gseq_one = []
|
||||
else:
|
||||
lstrip = l.strip().split('\t')
|
||||
gseq_one.append(lstrip)
|
||||
#glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
|
||||
with open(predict) as f:
|
||||
plist = []
|
||||
pseq_one = []
|
||||
for l in f.readlines():
|
||||
if len(l.strip()) == 0:
|
||||
plist.append(pseq_one)
|
||||
pseq_one = []
|
||||
else:
|
||||
pseq_one.append(l.strip().split('\t'))
|
||||
#plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
# plist = [[("select product_type_code from products group by product_type_code order by count ( * ) desc limit value", "orchestra")]]
|
||||
# glist = [[("SELECT product_type_code FROM Products GROUP BY product_type_code ORDER BY count(*) DESC LIMIT 1", "customers_and_orders")]]
|
||||
evaluator = Evaluator()
|
||||
|
||||
turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
entries = []
|
||||
scores = {}
|
||||
|
||||
for turn in turns:
|
||||
scores[turn] = {'count': 0, 'exact': 0.}
|
||||
scores[turn]['exec'] = 0
|
||||
|
||||
for level in levels:
|
||||
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
|
||||
scores[level]['exec'] = 0
|
||||
for type_ in partial_types:
|
||||
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
|
||||
|
||||
eval_err_num = 0
|
||||
for p, g in zip(plist, glist):
|
||||
scores['joint_all']['count'] += 1
|
||||
turn_scores = {"exec": [], "exact": []}
|
||||
for idx, pg in enumerate(zip(p, g)):
|
||||
p, g = pg
|
||||
p_str = p[0]
|
||||
p_str = p_str.replace("value", "1")
|
||||
g_str, db = g
|
||||
db_name = db
|
||||
db = os.path.join(db_dir, db, db + ".sqlite")
|
||||
schema = Schema(get_schema(db))
|
||||
g_sql = get_sql(schema, g_str)
|
||||
hardness = evaluator.eval_hardness(g_sql)
|
||||
if idx > 3:
|
||||
idx = ">4"
|
||||
else:
|
||||
idx += 1
|
||||
turn_id = "turn " + str(idx)
|
||||
scores[turn_id]['count'] += 1
|
||||
scores[hardness]['count'] += 1
|
||||
scores['all']['count'] += 1
|
||||
|
||||
try:
|
||||
p_sql = get_sql(schema, p_str)
|
||||
except:
|
||||
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
|
||||
p_sql = {
|
||||
"except": None,
|
||||
"from": {
|
||||
"conds": [],
|
||||
"table_units": []
|
||||
},
|
||||
"groupBy": [],
|
||||
"having": [],
|
||||
"intersect": None,
|
||||
"limit": None,
|
||||
"orderBy": [],
|
||||
"select": [
|
||||
False,
|
||||
[]
|
||||
],
|
||||
"union": None,
|
||||
"where": []
|
||||
}
|
||||
eval_err_num += 1
|
||||
print("eval_err_num:{}".format(eval_err_num))
|
||||
|
||||
# rebuild sql for value evaluation
|
||||
kmap = kmaps[db_name]
|
||||
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
|
||||
g_sql = rebuild_sql_val(g_sql)
|
||||
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
|
||||
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
|
||||
p_sql = rebuild_sql_val(p_sql)
|
||||
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
|
||||
if exec_score:
|
||||
scores[hardness]['exec'] += 1
|
||||
scores[turn_id]['exec'] += 1
|
||||
turn_scores['exec'].append(1)
|
||||
else:
|
||||
turn_scores['exec'].append(0)
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
|
||||
partial_scores = evaluator.partial_scores
|
||||
if exact_score == 0:
|
||||
turn_scores['exact'].append(0)
|
||||
print("{} pred: {}".format(hardness,p_str))
|
||||
print("{} gold: {}".format(hardness,g_str))
|
||||
print("")
|
||||
else:
|
||||
turn_scores['exact'].append(1)
|
||||
scores[turn_id]['exact'] += exact_score
|
||||
scores[hardness]['exact'] += exact_score
|
||||
scores['all']['exact'] += exact_score
|
||||
for type_ in partial_types:
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores[hardness]['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores[hardness]['partial'][type_]['rec_count'] += 1
|
||||
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores['all']['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores['all']['partial'][type_]['rec_count'] += 1
|
||||
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
|
||||
entries.append({
|
||||
'predictSQL': p_str,
|
||||
'goldSQL': g_str,
|
||||
'hardness': hardness,
|
||||
'exact': exact_score,
|
||||
'partial': partial_scores
|
||||
})
|
||||
|
||||
if all(v == 1 for v in turn_scores["exec"]):
|
||||
scores['joint_all']['exec'] += 1
|
||||
|
||||
if all(v == 1 for v in turn_scores["exact"]):
|
||||
scores['joint_all']['exact'] += 1
|
||||
|
||||
for turn in turns:
|
||||
if scores[turn]['count'] == 0:
|
||||
continue
|
||||
if etype in ["all", "exec"]:
|
||||
scores[turn]['exec'] /= scores[turn]['count']
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
scores[turn]['exact'] /= scores[turn]['count']
|
||||
|
||||
for level in levels:
|
||||
if scores[level]['count'] == 0:
|
||||
continue
|
||||
if etype in ["all", "exec"]:
|
||||
scores[level]['exec'] /= scores[level]['count']
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
scores[level]['exact'] /= scores[level]['count']
|
||||
for type_ in partial_types:
|
||||
if scores[level]['partial'][type_]['acc_count'] == 0:
|
||||
scores[level]['partial'][type_]['acc'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
|
||||
scores[level]['partial'][type_]['acc_count'] * 1.0
|
||||
if scores[level]['partial'][type_]['rec_count'] == 0:
|
||||
scores[level]['partial'][type_]['rec'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
|
||||
scores[level]['partial'][type_]['rec_count'] * 1.0
|
||||
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
|
||||
scores[level]['partial'][type_]['f1'] = 1
|
||||
else:
|
||||
scores[level]['partial'][type_]['f1'] = \
|
||||
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
|
||||
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
|
||||
|
||||
print_scores(scores, etype)
|
||||
|
||||
|
||||
def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||
"""
|
||||
return 1 if the values between prediction and gold match
|
||||
at the corresponding index. Currently does not support multiple col_unit (pairs).
|
||||
"""
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(p_str)
|
||||
p_res = cursor.fetchall()
|
||||
except:
|
||||
return False
|
||||
|
||||
cursor.execute(g_str)
|
||||
q_res = cursor.fetchall()
|
||||
|
||||
def res_map(res, val_units):
|
||||
rmap = {}
|
||||
for idx, val_unit in enumerate(val_units):
|
||||
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
|
||||
rmap[key] = [r[idx] for r in res]
|
||||
return rmap
|
||||
|
||||
p_val_units = [unit[1] for unit in pred['select'][1]]
|
||||
q_val_units = [unit[1] for unit in gold['select'][1]]
|
||||
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
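# Descriptive note (illustrative): res_map keys each selected val_unit and collects the
# corresponding column of the execution result, so two results match only if, for every
# selected expression, the full column of returned values is identical; the order of the
# selected columns does not matter as long as the expressions correspond, but row order does.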
|
||||
|
||||
|
||||
# Rebuild SQL functions for value evaluation
|
||||
def rebuild_cond_unit_val(cond_unit):
|
||||
if cond_unit is None or not DISABLE_VALUE:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
if type(val1) is not dict:
|
||||
val1 = None
|
||||
else:
|
||||
val1 = rebuild_sql_val(val1)
|
||||
if type(val2) is not dict:
|
||||
val2 = None
|
||||
else:
|
||||
val2 = rebuild_sql_val(val2)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_val(condition):
|
||||
if condition is None or not DISABLE_VALUE:
|
||||
return condition
|
||||
|
||||
res = []
|
||||
for idx, it in enumerate(condition):
|
||||
if idx % 2 == 0:
|
||||
res.append(rebuild_cond_unit_val(it))
|
||||
else:
|
||||
res.append(it)
|
||||
return res
|
||||
|
||||
|
||||
def rebuild_sql_val(sql):
|
||||
if sql is None or not DISABLE_VALUE:
|
||||
return sql
|
||||
|
||||
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
|
||||
sql['having'] = rebuild_condition_val(sql['having'])
|
||||
sql['where'] = rebuild_condition_val(sql['where'])
|
||||
sql['intersect'] = rebuild_sql_val(sql['intersect'])
|
||||
sql['except'] = rebuild_sql_val(sql['except'])
|
||||
sql['union'] = rebuild_sql_val(sql['union'])
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
# Rebuild SQL functions for foreign key evaluation
|
||||
def build_valid_col_units(table_units, schema):
|
||||
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
|
||||
prefixs = [col_id[:-2] for col_id in col_ids]
|
||||
valid_col_units= []
|
||||
for value in schema.idMap.values():
|
||||
if '.' in value and value[:value.index('.')] in prefixs:
|
||||
valid_col_units.append(value)
|
||||
return valid_col_units
|
||||
|
||||
|
||||
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
|
||||
if col_unit is None:
|
||||
return col_unit
|
||||
|
||||
agg_id, col_id, distinct = col_unit
|
||||
if col_id in kmap and col_id in valid_col_units:
|
||||
col_id = kmap[col_id]
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return agg_id, col_id, distinct
|
||||
|
||||
|
||||
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
|
||||
if val_unit is None:
|
||||
return val_unit
|
||||
|
||||
unit_op, col_unit1, col_unit2 = val_unit
|
||||
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
|
||||
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
|
||||
return unit_op, col_unit1, col_unit2
|
||||
|
||||
|
||||
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
|
||||
if table_unit is None:
|
||||
return table_unit
|
||||
|
||||
table_type, col_unit_or_sql = table_unit
|
||||
if isinstance(col_unit_or_sql, tuple):
|
||||
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
|
||||
return table_type, col_unit_or_sql
|
||||
|
||||
|
||||
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
|
||||
if cond_unit is None:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_col(valid_col_units, condition, kmap):
|
||||
for idx in range(len(condition)):
|
||||
if idx % 2 == 0:
|
||||
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
|
||||
return condition
|
||||
|
||||
|
||||
def rebuild_select_col(valid_col_units, sel, kmap):
|
||||
if sel is None:
|
||||
return sel
|
||||
distinct, _list = sel
|
||||
new_list = []
|
||||
for it in _list:
|
||||
agg_id, val_unit = it
|
||||
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return distinct, new_list
|
||||
|
||||
|
||||
def rebuild_from_col(valid_col_units, from_, kmap):
|
||||
if from_ is None:
|
||||
return from_
|
||||
|
||||
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
|
||||
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
|
||||
return from_
|
||||
|
||||
|
||||
def rebuild_group_by_col(valid_col_units, group_by, kmap):
|
||||
if group_by is None:
|
||||
return group_by
|
||||
|
||||
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
|
||||
|
||||
|
||||
def rebuild_order_by_col(valid_col_units, order_by, kmap):
|
||||
if order_by is None or len(order_by) == 0:
|
||||
return order_by
|
||||
|
||||
direction, val_units = order_by
|
||||
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
|
||||
return direction, new_val_units
|
||||
|
||||
|
||||
def rebuild_sql_col(valid_col_units, sql, kmap):
|
||||
if sql is None:
|
||||
return sql
|
||||
|
||||
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
|
||||
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
|
||||
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
|
||||
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
|
||||
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
|
||||
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
|
||||
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
|
||||
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
|
||||
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def build_foreign_key_map(entry):
|
||||
cols_orig = entry["column_names_original"]
|
||||
tables_orig = entry["table_names_original"]
|
||||
|
||||
# rebuild cols corresponding to idmap in Schema
|
||||
cols = []
|
||||
for col_orig in cols_orig:
|
||||
if col_orig[0] >= 0:
|
||||
t = tables_orig[col_orig[0]]
|
||||
c = col_orig[1]
|
||||
cols.append("__" + t.lower() + "." + c.lower() + "__")
|
||||
else:
|
||||
cols.append("__all__")
|
||||
|
||||
def keyset_in_list(k1, k2, k_list):
|
||||
for k_set in k_list:
|
||||
if k1 in k_set or k2 in k_set:
|
||||
return k_set
|
||||
new_k_set = set()
|
||||
k_list.append(new_k_set)
|
||||
return new_k_set
|
||||
|
||||
foreign_key_list = []
|
||||
foreign_keys = entry["foreign_keys"]
|
||||
for fkey in foreign_keys:
|
||||
key1, key2 = fkey
|
||||
key_set = keyset_in_list(key1, key2, foreign_key_list)
|
||||
key_set.add(key1)
|
||||
key_set.add(key2)
|
||||
|
||||
foreign_key_map = {}
|
||||
for key_set in foreign_key_list:
|
||||
sorted_list = sorted(list(key_set))
|
||||
midx = sorted_list[0]
|
||||
for idx in sorted_list:
|
||||
foreign_key_map[cols[idx]] = cols[midx]
|
||||
|
||||
return foreign_key_map
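# Illustrative sketch (hypothetical indices): for foreign_keys [[3, 1], [5, 3]], the two
# pairs share column 3, so keyset_in_list merges them into one set {1, 3, 5}; every column
# in the set is then mapped to the column with the smallest index, roughly
#   foreign_key_map = {cols[1]: cols[1], cols[3]: cols[1], cols[5]: cols[1]}
# so equivalent foreign-key columns compare equal during evaluation.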
|
||||
|
||||
|
||||
def build_foreign_key_map_from_json(table):
|
||||
with open(table) as f:
|
||||
data = json.load(f)
|
||||
tables = {}
|
||||
for entry in data:
|
||||
tables[entry['db_id']] = build_foreign_key_map(entry)
|
||||
return tables
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--gold', dest='gold', type=str)
|
||||
parser.add_argument('--pred', dest='pred', type=str)
|
||||
parser.add_argument('--db', dest='db', type=str)
|
||||
parser.add_argument('--table', dest='table', type=str)
|
||||
parser.add_argument('--etype', dest='etype', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
gold = args.gold
|
||||
pred = args.pred
|
||||
db_dir = args.db
|
||||
table = args.table
|
||||
etype = args.etype
|
||||
|
||||
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
|
||||
|
||||
kmaps = build_foreign_key_map_from_json(table)
|
||||
|
||||
evaluate(gold, pred, db_dir, etype, kmaps)
|
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
|
@@ -0,0 +1,53 @@
|
|||
import json
|
||||
import sys
|
||||
|
||||
predictions = [json.loads(line) for line in open(sys.argv[1]).readlines() if line]
|
||||
|
||||
string_count = 0.
|
||||
sem_count = 0.
|
||||
syn_count = 0.
|
||||
table_count = 0.
|
||||
strict_table_count = 0.
|
||||
|
||||
precision_denom = 0.
|
||||
precision = 0.
|
||||
recall_denom = 0.
|
||||
recall = 0.
|
||||
f1_score = 0.
|
||||
f1_denom = 0.
|
||||
|
||||
time = 0.
|
||||
|
||||
for prediction in predictions:
|
||||
if prediction["correct_string"]:
|
||||
string_count += 1.
|
||||
if prediction["semantic"]:
|
||||
sem_count += 1.
|
||||
if prediction["syntactic"]:
|
||||
syn_count += 1.
|
||||
if prediction["correct_table"]:
|
||||
table_count += 1.
|
||||
if prediction["strict_correct_table"]:
|
||||
strict_table_count += 1.
|
||||
if prediction["gold_tables"] !="[[]]":
|
||||
precision += prediction["table_prec"]
|
||||
precision_denom += 1
|
||||
if prediction["pred_table"] != "[]":
|
||||
recall += prediction["table_rec"]
|
||||
recall_denom += 1
|
||||
|
||||
if prediction["gold_tables"] != "[[]]":
|
||||
f1_score += prediction["table_f1"]
|
||||
f1_denom += 1
|
||||
|
||||
num_p = len(predictions)
|
||||
print("string precision: " + str(string_count / num_p))
|
||||
print("% semantic: " + str(sem_count / num_p))
|
||||
print("% syntactic: " + str(syn_count / num_p))
|
||||
print("table prec: " + str(table_count / num_p))
|
||||
print("strict table prec: " + str(strict_table_count / num_p))
|
||||
print("table row prec: " + str(precision / precision_denom))
|
||||
print("table row recall: " + str(recall / recall_denom))
|
||||
print("table row f1: " + str(f1_score / f1_denom))
|
||||
print("inference time: " + str(time / num_p))
|
||||
|
|
@@ -0,0 +1,569 @@
|
|||
################################
|
||||
# Assumptions:
|
||||
# 1. sql is correct
|
||||
# 2. only table name has alias
|
||||
# 3. only one intersect/union/except
|
||||
#
|
||||
# val: number(float)/string(str)/sql(dict)
|
||||
# col_unit: (agg_id, col_id, isDistinct(bool))
|
||||
# val_unit: (unit_op, col_unit1, col_unit2)
|
||||
# table_unit: (table_type, col_unit/sql)
|
||||
# cond_unit: (not_op, op_id, val_unit, val1, val2)
|
||||
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
|
||||
# sql {
|
||||
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
|
||||
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
|
||||
# 'where': condition
|
||||
# 'groupBy': [col_unit1, col_unit2, ...]
|
||||
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
|
||||
# 'having': condition
|
||||
# 'limit': None/limit value
|
||||
# 'intersect': None/sql
|
||||
# 'except': None/sql
|
||||
# 'union': None/sql
|
||||
# }
|
||||
################################
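# Illustrative example (hypothetical table name): under the structure documented above,
# a query such as "SELECT count(*) FROM singer" would roughly parse into
#   {
#     'select': (False, [(3, (0, (0, '__all__', False), None))]),  # agg_id 3 == 'count'
#     'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
#     'where': [], 'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#     'intersect': None, 'except': None, 'union': None,
#   }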
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
|
||||
from nltk import word_tokenize
|
||||
|
||||
|
||||
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
|
||||
JOIN_KEYWORDS = ('join', 'on', 'as')
|
||||
|
||||
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
UNIT_OPS = ('none', '-', '+', "*", '/')
|
||||
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
|
||||
TABLE_TYPE = {
|
||||
'sql': "sql",
|
||||
'table_unit': "table_unit",
|
||||
}
|
||||
|
||||
COND_OPS = ('and', 'or')
|
||||
SQL_OPS = ('intersect', 'union', 'except')
|
||||
ORDER_OPS = ('desc', 'asc')
|
||||
|
||||
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
Simple schema that maps each table and column to a unique identifier
|
||||
"""
|
||||
def __init__(self, schema):
|
||||
self._schema = schema
|
||||
self._idMap = self._map(self._schema)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def idMap(self):
|
||||
return self._idMap
|
||||
|
||||
def _map(self, schema):
|
||||
idMap = {'*': "__all__"}
|
||||
id = 1
|
||||
for key, vals in schema.items():
|
||||
for val in vals:
|
||||
idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
|
||||
id += 1
|
||||
|
||||
for key in schema:
|
||||
idMap[key.lower()] = "__" + key.lower() + "__"
|
||||
id += 1
|
||||
|
||||
return idMap
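# Illustrative example (hypothetical schema): for {'singer': ['name', 'age']},
# _map produces roughly
#   {'*': '__all__',
#    'singer.name': '__singer.name__', 'singer.age': '__singer.age__',
#    'singer': '__singer__'}
# so every table and table.column pair gets a unique identifier string.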
|
||||
|
||||
|
||||
def get_schema(db):
|
||||
"""
|
||||
Get database's schema, which is a dict with table name as key
|
||||
and list of column names as value
|
||||
:param db: database path
|
||||
:return: schema dict
|
||||
"""
|
||||
|
||||
schema = {}
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# fetch table names
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
tables = [str(table[0].lower()) for table in cursor.fetchall()]
|
||||
|
||||
# fetch table info
|
||||
for table in tables:
|
||||
cursor.execute("PRAGMA table_info({})".format(table))
|
||||
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def get_schema_from_json(fpath):
|
||||
with open(fpath) as f:
|
||||
data = json.load(f)
|
||||
|
||||
schema = {}
|
||||
for entry in data:
|
||||
table = str(entry['table'].lower())
|
||||
cols = [str(col['column_name'].lower()) for col in entry['col_data']]
|
||||
schema[table] = cols
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def tokenize(string):
|
||||
string = str(string)
|
||||
string = string.replace("\'", "\"") # ensure all string values are wrapped in "" (may be problematic for values that contain quotes)
|
||||
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
|
||||
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
|
||||
|
||||
# keep string value as token
|
||||
vals = {}
|
||||
for i in range(len(quote_idxs)-1, -1, -2):
|
||||
qidx1 = quote_idxs[i-1]
|
||||
qidx2 = quote_idxs[i]
|
||||
val = string[qidx1: qidx2+1]
|
||||
key = "__val_{}_{}__".format(qidx1, qidx2)
|
||||
string = string[:qidx1] + key + string[qidx2+1:]
|
||||
vals[key] = val
|
||||
|
||||
toks = [word.lower() for word in word_tokenize(string)]
|
||||
# replace with string value token
|
||||
for i in range(len(toks)):
|
||||
if toks[i] in vals:
|
||||
toks[i] = vals[toks[i]]
|
||||
|
||||
# find if there exists !=, >=, <=
|
||||
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
|
||||
eq_idxs.reverse()
|
||||
prefix = ('!', '>', '<')
|
||||
for eq_idx in eq_idxs:
|
||||
pre_tok = toks[eq_idx-1]
|
||||
if pre_tok in prefix:
|
||||
toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
|
||||
|
||||
return toks
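# Illustrative example (hypothetical query): roughly,
#   tokenize('SELECT name FROM singer WHERE city != "New York"')
# yields
#   ['select', 'name', 'from', 'singer', 'where', 'city', '!=', '"New York"']
# i.e. keywords are lowercased, the quoted value is kept as a single token with its
# original casing, and the '!' '=' tokens are merged back into '!='.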
|
||||
|
||||
|
||||
def scan_alias(toks):
|
||||
"""Scan the index of 'as' and build the map for all alias"""
|
||||
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
|
||||
alias = {}
|
||||
for idx in as_idxs:
|
||||
alias[toks[idx+1]] = toks[idx-1]
|
||||
return alias
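# Illustrative example (hypothetical tokens): for
#   ['from', 'singer', 'as', 't1', 'join', 'concert', 'as', 't2']
# scan_alias returns {'t1': 'singer', 't2': 'concert'}.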
|
||||
|
||||
|
||||
def get_tables_with_alias(schema, toks):
|
||||
tables = scan_alias(toks)
|
||||
for key in schema:
|
||||
assert key not in tables, "Alias {} has the same name in table".format(key)
|
||||
tables[key] = key
|
||||
return tables
|
||||
|
||||
|
||||
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, column id
|
||||
"""
|
||||
tok = toks[start_idx]
|
||||
if tok == "*":
|
||||
return start_idx + 1, schema.idMap[tok]
|
||||
|
||||
if '.' in tok: # if token is a composite
|
||||
alias, col = tok.split('.')
|
||||
key = tables_with_alias[alias] + "." + col
|
||||
return start_idx+1, schema.idMap[key]
|
||||
|
||||
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
|
||||
|
||||
for alias in default_tables:
|
||||
table = tables_with_alias[alias]
|
||||
if tok in schema.schema[table]:
|
||||
key = table + "." + tok
|
||||
return start_idx+1, schema.idMap[key]
|
||||
|
||||
assert False, "Error col: {}".format(tok)
|
||||
|
||||
|
||||
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, (agg_id, col_id, isDistinct)
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
isDistinct = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] in AGG_OPS:
|
||||
agg_id = AGG_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
assert idx < len_ and toks[idx] == '('
|
||||
idx += 1
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
assert idx < len_ and toks[idx] == ')'
|
||||
idx += 1
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
agg_id = AGG_OPS.index("none")
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
|
||||
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
col_unit1 = None
|
||||
col_unit2 = None
|
||||
unit_op = UNIT_OPS.index('none')
|
||||
|
||||
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
if idx < len_ and toks[idx] in UNIT_OPS:
|
||||
unit_op = UNIT_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (unit_op, col_unit1, col_unit2)
|
||||
|
||||
|
||||
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
|
||||
"""
|
||||
:returns next idx, table id, table name
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
key = tables_with_alias[toks[idx]]
|
||||
|
||||
if idx + 1 < len_ and toks[idx+1] == "as":
|
||||
idx += 3
|
||||
else:
|
||||
idx += 1
|
||||
|
||||
return idx, schema.idMap[key], key
|
||||
|
||||
|
||||
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] == 'select':
|
||||
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
elif "\"" in toks[idx]: # token is a string value
|
||||
val = toks[idx]
|
||||
idx += 1
|
||||
else:
|
||||
try:
|
||||
val = float(toks[idx])
|
||||
idx += 1
|
||||
except:
|
||||
end_idx = idx
|
||||
while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
|
||||
and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
|
||||
end_idx += 1
|
||||
|
||||
idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
|
||||
idx = end_idx
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1
|
||||
|
||||
return idx, val
|
||||
|
||||
|
||||
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
conds = []
|
||||
|
||||
while idx < len_:
|
||||
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
not_op = False
|
||||
if toks[idx] == 'not':
|
||||
not_op = True
|
||||
idx += 1
|
||||
|
||||
assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
|
||||
op_id = WHERE_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
val1 = val2 = None
|
||||
if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
|
||||
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
|
||||
assert toks[idx] == 'and'
|
||||
idx += 1
|
||||
idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
|
||||
else: # normal case: single value
|
||||
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
|
||||
val2 = None
|
||||
|
||||
conds.append((not_op, op_id, val_unit, val1, val2))
|
||||
|
||||
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
|
||||
break
|
||||
|
||||
if idx < len_ and toks[idx] in COND_OPS:
|
||||
conds.append(toks[idx])
|
||||
idx += 1 # skip and/or
|
||||
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
assert toks[idx] == 'select', "'select' not found"
|
||||
idx += 1
|
||||
isDistinct = False
|
||||
if idx < len_ and toks[idx] == 'distinct':
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
val_units = []
|
||||
|
||||
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
|
||||
agg_id = AGG_OPS.index("none")
|
||||
if toks[idx] in AGG_OPS:
|
||||
agg_id = AGG_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
val_units.append((agg_id, val_unit))
|
||||
if idx < len_ and toks[idx] == ',':
|
||||
idx += 1 # skip ','
|
||||
|
||||
return idx, (isDistinct, val_units)
|
||||
|
||||
|
||||
def parse_from(toks, start_idx, tables_with_alias, schema):
|
||||
"""
|
||||
Assume in the from clause, all table units are combined with join
|
||||
"""
|
||||
assert 'from' in toks[start_idx:], "'from' not found"
|
||||
|
||||
len_ = len(toks)
|
||||
idx = toks.index('from', start_idx) + 1
|
||||
default_tables = []
|
||||
table_units = []
|
||||
conds = []
|
||||
|
||||
while idx < len_:
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] == 'select':
|
||||
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
table_units.append((TABLE_TYPE['sql'], sql))
|
||||
else:
|
||||
if idx < len_ and toks[idx] == 'join':
|
||||
idx += 1 # skip join
|
||||
idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
|
||||
table_units.append((TABLE_TYPE['table_unit'],table_unit))
|
||||
default_tables.append(table_name)
|
||||
if idx < len_ and toks[idx] == "on":
|
||||
idx += 1 # skip on
|
||||
idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
if len(conds) > 0:
|
||||
conds.append('and')
|
||||
conds.extend(this_conds)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1
|
||||
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
break
|
||||
|
||||
return idx, table_units, conds, default_tables
|
||||
|
||||
|
||||
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx >= len_ or toks[idx] != 'where':
|
||||
return idx, []
|
||||
|
||||
idx += 1
|
||||
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
col_units = []
|
||||
|
||||
if idx >= len_ or toks[idx] != 'group':
|
||||
return idx, col_units
|
||||
|
||||
idx += 1
|
||||
assert toks[idx] == 'by'
|
||||
idx += 1
|
||||
|
||||
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
col_units.append(col_unit)
|
||||
if idx < len_ and toks[idx] == ',':
|
||||
idx += 1 # skip ','
|
||||
else:
|
||||
break
|
||||
|
||||
return idx, col_units
|
||||
|
||||
|
||||
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
val_units = []
|
||||
order_type = 'asc' # default type is 'asc'
|
||||
|
||||
if idx >= len_ or toks[idx] != 'order':
|
||||
return idx, val_units
|
||||
|
||||
idx += 1
|
||||
assert toks[idx] == 'by'
|
||||
idx += 1
|
||||
|
||||
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
val_units.append(val_unit)
|
||||
if idx < len_ and toks[idx] in ORDER_OPS:
|
||||
order_type = toks[idx]
|
||||
idx += 1
|
||||
if idx < len_ and toks[idx] == ',':
|
||||
idx += 1 # skip ','
|
||||
else:
|
||||
break
|
||||
|
||||
return idx, (order_type, val_units)
|
||||
|
||||
|
||||
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx >= len_ or toks[idx] != 'having':
|
||||
return idx, []
|
||||
|
||||
idx += 1
|
||||
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_limit(toks, start_idx):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx < len_ and toks[idx] == 'limit':
|
||||
idx += 2
|
||||
# handle the 'limit value' placeholder: if the token after 'limit' is not an integer, fall back to 1 as a fake limit number
|
||||
if type(toks[idx-1]) != int:
|
||||
return idx, 1
|
||||
|
||||
return idx, int(toks[idx-1])
|
||||
|
||||
return idx, None
|
||||
|
||||
|
||||
def parse_sql(toks, start_idx, tables_with_alias, schema):
|
||||
isBlock = False # indicate whether this is a block of sql/sub-sql
|
||||
len_ = len(toks)
|
||||
idx = start_idx
|
||||
|
||||
sql = {}
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
# parse from clause in order to get default tables
|
||||
from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
|
||||
sql['from'] = {'table_units': table_units, 'conds': conds}
|
||||
# select clause
|
||||
_, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
|
||||
idx = from_end_idx
|
||||
sql['select'] = select_col_units
|
||||
# where clause
|
||||
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['where'] = where_conds
|
||||
# group by clause
|
||||
idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['groupBy'] = group_col_units
|
||||
# having clause
|
||||
idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['having'] = having_conds
|
||||
# order by clause
|
||||
idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['orderBy'] = order_col_units
|
||||
# limit clause
|
||||
idx, limit_val = parse_limit(toks, idx)
|
||||
sql['limit'] = limit_val
|
||||
|
||||
idx = skip_semicolon(toks, idx)
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
idx = skip_semicolon(toks, idx)
|
||||
|
||||
# intersect/union/except clause
|
||||
for op in SQL_OPS: # initialize IUE
|
||||
sql[op] = None
|
||||
if idx < len_ and toks[idx] in SQL_OPS:
|
||||
sql_op = toks[idx]
|
||||
idx += 1
|
||||
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
sql[sql_op] = IUE_sql
|
||||
return idx, sql
|
||||
|
||||
|
||||
def load_data(fpath):
|
||||
with open(fpath) as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def get_sql(schema, query):
|
||||
toks = tokenize(query)
|
||||
tables_with_alias = get_tables_with_alias(schema.schema, toks)
|
||||
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def skip_semicolon(toks, start_idx):
|
||||
idx = start_idx
|
||||
while idx < len(toks) and toks[idx] == ";":
|
||||
idx += 1
|
||||
return idx
|
|
@@ -0,0 +1,58 @@
|
|||
"""Contains the logging class."""
|
||||
|
||||
class Logger():
|
||||
"""Attributes:
|
||||
|
||||
fileptr (file): File pointer for input/output.
|
||||
lines (list of str): The lines read from the log.
|
||||
"""
|
||||
def __init__(self, filename, option):
|
||||
self.fileptr = open(filename, option)
|
||||
if option == "r":
|
||||
self.lines = self.fileptr.readlines()
|
||||
else:
|
||||
self.lines = []
|
||||
|
||||
def put(self, string):
|
||||
"""Writes to the file."""
|
||||
self.fileptr.write(string + "\n")
|
||||
self.fileptr.flush()
|
||||
|
||||
def close(self):
|
||||
"""Closes the logger."""
|
||||
self.fileptr.close()
|
||||
|
||||
def findlast(self, identifier, default=0.):
|
||||
"""Finds the last line in the log with a certain value."""
|
||||
for line in self.lines[::-1]:
|
||||
if line.lower().startswith(identifier):
|
||||
string = line.strip().split("\t")[1]
|
||||
if string.replace(".", "").isdigit():
|
||||
return float(string)
|
||||
elif string.lower() == "true":
|
||||
return True
|
||||
elif string.lower() == "false":
|
||||
return False
|
||||
else:
|
||||
return string
|
||||
return default
|
||||
|
||||
def contains(self, string):
|
||||
"""Dtermines whether the string is present in the log."""
|
||||
for line in self.lines[::-1]:
|
||||
if string.lower() in line.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
def findlast_log_before(self, before_str):
|
||||
"""Finds the last entry in the log before another entry."""
|
||||
loglines = []
|
||||
in_line = False
|
||||
for line in self.lines[::-1]:
|
||||
if line.startswith(before_str):
|
||||
in_line = True
|
||||
elif in_line:
|
||||
loglines.append(line)
|
||||
if line.strip() == "" and in_line:
|
||||
return "".join(loglines[::-1])
|
||||
return "".join(loglines[::-1])
|
|
@@ -0,0 +1,654 @@
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
import json
|
||||
import shutil
|
||||
import sqlparse
|
||||
from postprocess_eval import get_candidate_tables
|
||||
|
||||
|
||||
def write_interaction(interaction_list,split,output_dir):
|
||||
json_split = os.path.join(output_dir,split+'.json')
|
||||
pkl_split = os.path.join(output_dir,split+'.pkl')
|
||||
|
||||
with open(json_split, 'w') as outfile:
|
||||
for interaction in interaction_list:
|
||||
json.dump(interaction, outfile, indent = 4)
|
||||
outfile.write('\n')
|
||||
|
||||
new_objs = []
|
||||
for i, obj in enumerate(interaction_list):
|
||||
new_interaction = []
|
||||
for ut in obj["interaction"]:
|
||||
sql = ut["sql"]
|
||||
sqls = [sql]
|
||||
tok_sql_list = []
|
||||
for sql in sqls:
|
||||
results = []
|
||||
tokenized_sql = sql.split()
|
||||
tok_sql_list.append((tokenized_sql, results))
|
||||
ut["sql"] = tok_sql_list
|
||||
new_interaction.append(ut)
|
||||
obj["interaction"] = new_interaction
|
||||
new_objs.append(obj)
|
||||
|
||||
with open(pkl_split,'wb') as outfile:
|
||||
pickle.dump(new_objs, outfile)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def read_database_schema(database_schema_filename, schema_tokens, column_names, database_schemas_dict):
|
||||
with open(database_schema_filename) as f:
|
||||
database_schemas = json.load(f)
|
||||
|
||||
def get_schema_tokens(table_schema):
|
||||
column_names_surface_form = []
|
||||
column_names = []
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
if table_id >= 0:
|
||||
table_name = table_names_original[table_id]
|
||||
column_name_surface_form = '{}.{}'.format(table_name,column_name)
|
||||
else:
|
||||
# this is just *
|
||||
column_name_surface_form = column_name
|
||||
column_names_surface_form.append(column_name_surface_form.lower())
|
||||
column_names.append(column_name.lower())
|
||||
|
||||
# also add table_name.*
|
||||
for table_name in table_names_original:
|
||||
column_names_surface_form.append('{}.*'.format(table_name.lower()))
|
||||
|
||||
return column_names_surface_form, column_names
|
||||
|
||||
for table_schema in database_schemas:
|
||||
database_id = table_schema['db_id']
|
||||
database_schemas_dict[database_id] = table_schema
|
||||
schema_tokens[database_id], column_names[database_id] = get_schema_tokens(table_schema)
|
||||
|
||||
return schema_tokens, column_names, database_schemas_dict
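# Illustrative note (hypothetical schema): for a table 'singer' with columns 'name' and
# 'age', get_schema_tokens returns surface forms roughly like
#   ['*', 'singer.name', 'singer.age', 'singer.*']
# together with the bare column names ['*', 'name', 'age'].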
|
||||
|
||||
|
||||
def remove_from_with_join(format_sql_2):
|
||||
used_tables_list = []
|
||||
format_sql_3 = []
|
||||
table_to_name = {}
|
||||
table_list = []
|
||||
old_table_to_name = {}
|
||||
old_table_list = []
|
||||
for sub_sql in format_sql_2.split('\n'):
|
||||
if 'select ' in sub_sql:
|
||||
# only replace alias: t1 -> table_name, t2 -> table_name, etc...
|
||||
if len(table_list) > 0:
|
||||
for i in range(len(format_sql_3)):
|
||||
for table, name in table_to_name.items():
|
||||
format_sql_3[i] = format_sql_3[i].replace(table, name)
|
||||
|
||||
old_table_list = table_list
|
||||
old_table_to_name = table_to_name
|
||||
table_to_name = {}
|
||||
table_list = []
|
||||
format_sql_3.append(sub_sql)
|
||||
elif sub_sql.startswith('from'):
|
||||
new_sub_sql = None
|
||||
sub_sql_tokens = sub_sql.split()
|
||||
for t_i, t in enumerate(sub_sql_tokens):
|
||||
if t == 'as':
|
||||
table_to_name[sub_sql_tokens[t_i+1]] = sub_sql_tokens[t_i-1]
|
||||
table_list.append(sub_sql_tokens[t_i-1])
|
||||
elif t == ')' and new_sub_sql is None:
|
||||
# new_sub_sql keeps some trailing parts after ')'
|
||||
new_sub_sql = ' '.join(sub_sql_tokens[t_i:])
|
||||
if len(table_list) > 0:
|
||||
# if it's a from clause with join
|
||||
if new_sub_sql is not None:
|
||||
format_sql_3.append(new_sub_sql)
|
||||
|
||||
used_tables_list.append(table_list)
|
||||
else:
|
||||
# if it's a from clause without join
|
||||
table_list = old_table_list
|
||||
table_to_name = old_table_to_name
|
||||
assert 'join' not in sub_sql
|
||||
if new_sub_sql is not None:
|
||||
sub_sub_sql = sub_sql[:-len(new_sub_sql)].strip()
|
||||
assert len(sub_sub_sql.split()) == 2
|
||||
used_tables_list.append([sub_sub_sql.split()[1]])
|
||||
format_sql_3.append(sub_sub_sql)
|
||||
format_sql_3.append(new_sub_sql)
|
||||
elif 'join' not in sub_sql:
|
||||
assert len(sub_sql.split()) == 2 or len(sub_sql.split()) == 1
|
||||
if len(sub_sql.split()) == 2:
|
||||
used_tables_list.append([sub_sql.split()[1]])
|
||||
|
||||
format_sql_3.append(sub_sql)
|
||||
else:
|
||||
print('bad from clause in remove_from_with_join')
|
||||
exit()
|
||||
else:
|
||||
format_sql_3.append(sub_sql)
|
||||
|
||||
if len(table_list) > 0:
|
||||
for i in range(len(format_sql_3)):
|
||||
for table, name in table_to_name.items():
|
||||
format_sql_3[i] = format_sql_3[i].replace(table, name)
|
||||
|
||||
used_tables = []
|
||||
for t in used_tables_list:
|
||||
for tt in t:
|
||||
used_tables.append(tt)
|
||||
used_tables = list(set(used_tables))
|
||||
|
||||
return format_sql_3, used_tables, used_tables_list
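# Illustrative sketch (hypothetical tables): for a formatted SQL such as
#   select t1 . name
#   from singer as t1 join concert as t2 on t1 . id = t2 . singer_id
# remove_from_with_join drops the join-style from clause, rewrites aliases back to table
# names (t1 -> singer, t2 -> concert), and records the tables that were used, giving
# roughly ['select singer . name'] with used_tables ['singer', 'concert'].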
|
||||
|
||||
|
||||
def remove_from_without_join(format_sql_3, column_names, schema_tokens):
|
||||
format_sql_4 = []
|
||||
table_name = None
|
||||
for sub_sql in format_sql_3.split('\n'):
|
||||
if 'select ' in sub_sql:
|
||||
if table_name:
|
||||
for i in range(len(format_sql_4)):
|
||||
tokens = format_sql_4[i].split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token in column_names and tokens[ii-1] != '.':
|
||||
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
|
||||
if '{}.{}'.format(table_name,token) in schema_tokens:
|
||||
tokens[ii] = '{} . {}'.format(table_name,token)
|
||||
format_sql_4[i] = ' '.join(tokens)
|
||||
|
||||
format_sql_4.append(sub_sql)
|
||||
elif sub_sql.startswith('from'):
|
||||
sub_sql_tokens = sub_sql.split()
|
||||
if len(sub_sql_tokens) == 1:
|
||||
table_name = None
|
||||
elif len(sub_sql_tokens) == 2:
|
||||
table_name = sub_sql_tokens[1]
|
||||
else:
|
||||
print('bad from clause in remove_from_without_join')
|
||||
print(format_sql_3)
|
||||
exit()
|
||||
else:
|
||||
format_sql_4.append(sub_sql)
|
||||
|
||||
if table_name:
|
||||
for i in range(len(format_sql_4)):
|
||||
tokens = format_sql_4[i].split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token in column_names and tokens[ii-1] != '.':
|
||||
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
|
||||
if '{}.{}'.format(table_name,token) in schema_tokens:
|
||||
tokens[ii] = '{} . {}'.format(table_name,token)
|
||||
format_sql_4[i] = ' '.join(tokens)
|
||||
|
||||
return format_sql_4
|
||||
|
||||
|
||||
def add_table_name(format_sql_3, used_tables, column_names, schema_tokens):
|
||||
# If just one table used, easy case, replace all column_name -> table_name.column_name
|
||||
if len(used_tables) == 1:
|
||||
table_name = used_tables[0]
|
||||
format_sql_4 = []
|
||||
for sub_sql in format_sql_3.split('\n'):
|
||||
if sub_sql.startswith('from'):
|
||||
format_sql_4.append(sub_sql)
|
||||
continue
|
||||
|
||||
tokens = sub_sql.split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token in column_names and tokens[ii-1] != '.':
|
||||
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
|
||||
if '{}.{}'.format(table_name,token) in schema_tokens:
|
||||
tokens[ii] = '{} . {}'.format(table_name,token)
|
||||
format_sql_4.append(' '.join(tokens))
|
||||
return format_sql_4
|
||||
|
||||
def get_table_name_for(token):
|
||||
table_names = []
|
||||
for table_name in used_tables:
|
||||
if '{}.{}'.format(table_name, token) in schema_tokens:
|
||||
table_names.append(table_name)
|
||||
if len(table_names) == 0:
|
||||
return 'table'
|
||||
if len(table_names) > 1:
|
||||
return None
|
||||
else:
|
||||
return table_names[0]
|
||||
|
||||
format_sql_4 = []
|
||||
for sub_sql in format_sql_3.split('\n'):
|
||||
if sub_sql.startswith('from'):
|
||||
format_sql_4.append(sub_sql)
|
||||
continue
|
||||
|
||||
tokens = sub_sql.split()
|
||||
for ii, token in enumerate(tokens):
|
||||
# skip *
|
||||
if token == '*':
|
||||
continue
|
||||
if token in column_names and tokens[ii-1] != '.':
|
||||
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
|
||||
table_name = get_table_name_for(token)
|
||||
if table_name:
|
||||
tokens[ii] = '{} . {}'.format(table_name, token)
|
||||
format_sql_4.append(' '.join(tokens))
|
||||
|
||||
return format_sql_4
|
||||
|
||||
|
||||
def check_oov(format_sql_final, output_vocab, schema_tokens):
|
||||
for sql_tok in format_sql_final.split():
|
||||
if not (sql_tok in schema_tokens or sql_tok in output_vocab):
|
||||
print('OOV!', sql_tok)
|
||||
raise Exception('OOV')
|
||||
|
||||
|
||||
def normalize_space(format_sql):
|
||||
format_sql_1 = [' '.join(sub_sql.strip().replace(',',' , ').replace('.',' . ').replace('(',' ( ').replace(')',' ) ').split()) for sub_sql in format_sql.split('\n')]
|
||||
format_sql_1 = '\n'.join(format_sql_1)
|
||||
|
||||
format_sql_2 = format_sql_1.replace('\njoin',' join').replace(',\n',', ').replace(' where','\nwhere').replace(' intersect', '\nintersect').replace('\nand',' and').replace('order by t2 .\nstart desc', 'order by t2 . start desc')
|
||||
|
||||
format_sql_2 = format_sql_2.replace('select\noperator', 'select operator').replace('select\nconstructor', 'select constructor').replace('select\nstart', 'select start').replace('select\ndrop', 'select drop').replace('select\nwork', 'select work').replace('select\ngroup', 'select group').replace('select\nwhere_built', 'select where_built').replace('select\norder', 'select order').replace('from\noperator', 'from operator').replace('from\nforward', 'from forward').replace('from\nfor', 'from for').replace('from\ndrop', 'from drop').replace('from\norder', 'from order').replace('.\nstart', '. start').replace('.\norder', '. order').replace('.\noperator', '. operator').replace('.\nsets', '. sets').replace('.\nwhere_built', '. where_built').replace('.\nwork', '. work').replace('.\nconstructor', '. constructor').replace('.\ngroup', '. group').replace('.\nfor', '. for').replace('.\ndrop', '. drop').replace('.\nwhere', '. where')
|
||||
|
||||
format_sql_2 = format_sql_2.replace('group by', 'group_by').replace('order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
|
||||
return format_sql_2
|
||||
|
||||
|
||||
def normalize_final_sql(format_sql_5):
|
||||
format_sql_final = format_sql_5.replace('\n', ' ').replace(' . ', '.').replace('group by', 'group_by').replace('order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
|
||||
|
||||
# normalize two bad sqls
|
||||
if 't1' in format_sql_final or 't2' in format_sql_final or 't3' in format_sql_final or 't4' in format_sql_final:
|
||||
format_sql_final = format_sql_final.replace('t2.dormid', 'dorm.dormid')
|
||||
|
||||
# This is the failure case of remove_from_without_join()
|
||||
format_sql_final = format_sql_final.replace('select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by population desc limit_value', 'select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by city.population desc limit_value')
|
||||
|
||||
return format_sql_final
|
||||
|
||||
|
||||
def parse_sql(sql_string, db_id, column_names, output_vocab, schema_tokens, schema):
|
||||
format_sql = sqlparse.format(sql_string, reindent=True)
|
||||
format_sql_2 = normalize_space(format_sql)
|
||||
|
||||
num_from = sum([1 for sub_sql in format_sql_2.split('\n') if sub_sql.startswith('from')])
|
||||
num_select = format_sql_2.count('select ') + format_sql_2.count('select\n')
|
||||
|
||||
format_sql_3, used_tables, used_tables_list = remove_from_with_join(format_sql_2)
|
||||
|
||||
format_sql_3 = '\n'.join(format_sql_3)
|
||||
format_sql_4 = add_table_name(format_sql_3, used_tables, column_names, schema_tokens)
|
||||
|
||||
format_sql_4 = '\n'.join(format_sql_4)
|
||||
format_sql_5 = remove_from_without_join(format_sql_4, column_names, schema_tokens)
|
||||
|
||||
format_sql_5 = '\n'.join(format_sql_5)
|
||||
format_sql_final = normalize_final_sql(format_sql_5)
|
||||
|
||||
candidate_tables_id, table_names_original = get_candidate_tables(format_sql_final, schema)
|
||||
|
||||
failure = False
|
||||
if len(candidate_tables_id) != len(used_tables):
|
||||
failure = True
|
||||
|
||||
check_oov(format_sql_final, output_vocab, schema_tokens)
|
||||
|
||||
return format_sql_final
|
||||
|
||||
|
||||
def read_spider_split(split_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
with open(split_json) as f:
|
||||
split_data = json.load(f)
|
||||
print('read_spider_split', split_json, len(split_data))
|
||||
|
||||
for i, ex in enumerate(split_data):
|
||||
db_id = ex['db_id']
|
||||
|
||||
final_sql = []
|
||||
skip = False
|
||||
for query_tok in ex['query_toks_no_value']:
|
||||
if query_tok != '.' and '.' in query_tok:
|
||||
# invalid sql; didn't use table alias in join
|
||||
final_sql += query_tok.replace('.',' . ').split()
|
||||
skip = True
|
||||
else:
|
||||
final_sql.append(query_tok)
|
||||
final_sql = ' '.join(final_sql)
|
||||
|
||||
if skip and 'train' in split_json:
|
||||
continue
|
||||
|
||||
if remove_from:
|
||||
final_sql_parse = parse_sql(final_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
|
||||
else:
|
||||
final_sql_parse = final_sql
|
||||
|
||||
final_utterance = ' '.join(ex['question_toks'])
|
||||
|
||||
if db_id not in interaction_list:
|
||||
interaction_list[db_id] = []
|
||||
|
||||
interaction = {}
|
||||
interaction['id'] = ''
|
||||
interaction['scenario'] = ''
|
||||
interaction['database_id'] = db_id
|
||||
interaction['interaction_id'] = len(interaction_list[db_id])
|
||||
interaction['final'] = {}
|
||||
interaction['final']['utterance'] = final_utterance
|
||||
interaction['final']['sql'] = final_sql_parse
|
||||
interaction['interaction'] = [{'utterance': final_utterance, 'sql': final_sql_parse}]
|
||||
interaction_list[db_id].append(interaction)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_data_json(split_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
with open(split_json) as f:
|
||||
split_data = json.load(f)
|
||||
print('read_data_json', split_json, len(split_data))
|
||||
|
||||
for interaction_data in split_data:
|
||||
db_id = interaction_data['database_id']
|
||||
final_sql = interaction_data['final']['query']
|
||||
final_utterance = interaction_data['final']['utterance']
|
||||
|
||||
if db_id not in interaction_list:
|
||||
interaction_list[db_id] = []
|
||||
|
||||
# no interaction_id in train
|
||||
if 'interaction_id' in interaction_data['interaction']:
|
||||
interaction_id = interaction_data['interaction']['interaction_id']
|
||||
else:
|
||||
interaction_id = len(interaction_list[db_id])
|
||||
|
||||
interaction = {}
|
||||
interaction['id'] = ''
|
||||
interaction['scenario'] = ''
|
||||
interaction['database_id'] = db_id
|
||||
interaction['interaction_id'] = interaction_id
|
||||
interaction['final'] = {}
|
||||
interaction['final']['utterance'] = final_utterance
|
||||
interaction['final']['sql'] = final_sql
|
||||
interaction['interaction'] = []
|
||||
|
||||
for turn in interaction_data['interaction']:
|
||||
turn_sql = []
|
||||
skip = False
|
||||
for query_tok in turn['query_toks_no_value']:
|
||||
if query_tok != '.' and '.' in query_tok:
|
||||
# invalid sql; didn't use table alias in join
|
||||
turn_sql += query_tok.replace('.',' . ').split()
|
||||
skip = True
|
||||
else:
|
||||
turn_sql.append(query_tok)
|
||||
turn_sql = ' '.join(turn_sql)
|
||||
|
||||
# Correct some human SQL annotation errors
|
||||
turn_sql = turn_sql.replace('select f_id from files as t1 join song as t2 on t1 . f_id = t2 . f_id', 'select t1 . f_id from files as t1 join song as t2 on t1 . f_id = t2 . f_id')
|
||||
turn_sql = turn_sql.replace('select name from climber mountain', 'select name from climber')
|
||||
turn_sql = turn_sql.replace('select sid from sailors as t1 join reserves as t2 on t1 . sid = t2 . sid join boats as t3 on t3 . bid = t2 . bid', 'select t1 . sid from sailors as t1 join reserves as t2 on t1 . sid = t2 . sid join boats as t3 on t3 . bid = t2 . bid')
|
||||
turn_sql = turn_sql.replace('select avg ( price ) from goods )', 'select avg ( price ) from goods')
|
||||
turn_sql = turn_sql.replace('select min ( annual_fuel_cost ) , from vehicles', 'select min ( annual_fuel_cost ) from vehicles')
|
||||
turn_sql = turn_sql.replace('select * from goods where price < ( select avg ( price ) from goods', 'select * from goods where price < ( select avg ( price ) from goods )')
|
||||
turn_sql = turn_sql.replace('select distinct id , price from goods where price < ( select avg ( price ) from goods', 'select distinct id , price from goods where price < ( select avg ( price ) from goods )')
|
||||
turn_sql = turn_sql.replace('select id from goods where price > ( select avg ( price ) from goods', 'select id from goods where price > ( select avg ( price ) from goods )')
|
||||
|
||||
if skip and 'train' in split_json:
|
||||
continue
|
||||
|
||||
if remove_from:
|
||||
try:
|
||||
turn_sql_parse = parse_sql(turn_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
|
||||
except:
|
||||
print('continue')
|
||||
continue
|
||||
else:
|
||||
turn_sql_parse = turn_sql
|
||||
|
||||
if 'utterance_toks' in turn:
|
||||
turn_utterance = ' '.join(turn['utterance_toks']) # not lower()
|
||||
else:
|
||||
turn_utterance = turn['utterance']
|
||||
|
||||
interaction['interaction'].append({'utterance': turn_utterance, 'sql': turn_sql_parse})
|
||||
if len(interaction['interaction']) > 0:
|
||||
interaction_list[db_id].append(interaction)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_spider(spider_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
interaction_list = {}
|
||||
|
||||
train_json = os.path.join(spider_dir, 'train.json')
|
||||
interaction_list = read_spider_split(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
dev_json = os.path.join(spider_dir, 'dev.json')
|
||||
interaction_list = read_spider_split(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_sparc(sparc_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
interaction_list = {}
|
||||
|
||||
train_json = os.path.join(sparc_dir, 'train_no_value.json')
|
||||
interaction_list = read_data_json(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
dev_json = os.path.join(sparc_dir, 'dev_no_value.json')
|
||||
interaction_list = read_data_json(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_cosql(cosql_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
|
||||
interaction_list = {}
|
||||
|
||||
train_json = os.path.join(cosql_dir, 'train.json')
|
||||
interaction_list = read_data_json(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
dev_json = os.path.join(cosql_dir, 'dev.json')
|
||||
interaction_list = read_data_json(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def read_db_split(data_dir):
|
||||
train_database = []
|
||||
with open(os.path.join(data_dir,'train_db_ids.txt')) as f:
|
||||
for line in f:
|
||||
train_database.append(line.strip())
|
||||
|
||||
dev_database = []
|
||||
with open(os.path.join(data_dir,'dev_db_ids.txt')) as f:
|
||||
for line in f:
|
||||
dev_database.append(line.strip())
|
||||
|
||||
return train_database, dev_database
|
||||
|
||||
|
||||
def preprocess(datasets, remove_from=False):
|
||||
# Validate output_vocab
|
||||
output_vocab = ['_UNK', '_EOS', '.', 't1', 't2', '=', 'select', 'from', 'as', 'value', 'join', 'on', ')', '(', 'where', 't3', 'by', ',', 'count', 'group', 'order', 'distinct', 't4', 'and', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 't5', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!', 'union', 'between', 't6', '-', 't7', '+', '/']
|
||||
if remove_from:
|
||||
output_vocab = ['_UNK', '_EOS', '=', 'select', 'value', ')', '(', 'where', ',', 'count', 'group_by', 'order_by', 'distinct', 'and', 'limit_value', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!=', 'union', 'between', '-', '+', '/']
|
||||
print('size of output_vocab', len(output_vocab))
|
||||
print('output_vocab', output_vocab)
|
||||
print()
|
||||
|
||||
print('datasets', datasets)
|
||||
|
||||
if 'spider' in datasets:
|
||||
spider_dir = 'data/spider/'
|
||||
database_schema_filename = 'data/spider/tables.json'
|
||||
output_dir = 'data/spider_data'
|
||||
if remove_from:
|
||||
output_dir = 'data/spider_data_removefrom'
|
||||
spider_train_database, _ = read_db_split(spider_dir)
|
||||
if 'sparc' in datasets:
|
||||
sparc_dir = 'data/sparc/'
|
||||
database_schema_filename = 'data/sparc/tables.json'
|
||||
output_dir = 'data/sparc_data'
|
||||
if remove_from:
|
||||
output_dir = 'data/sparc_data_removefrom'
|
||||
sparc_train_database, dev_database = read_db_split(sparc_dir)
|
||||
if 'cosql' in datasets:
|
||||
cosql_dir = 'data/cosql/'
|
||||
database_schema_filename = 'data/cosql/tables.json'
|
||||
output_dir = 'data/cosql_data'
|
||||
if remove_from:
|
||||
output_dir = 'data/cosql_data_removefrom'
|
||||
cosql_train_database, _ = read_db_split(cosql_dir)
|
||||
|
||||
#print('sparc_train_database', sparc_train_database)
|
||||
#print('cosql_train_database', cosql_train_database)
|
||||
#print('spider_train_database', spider_train_database)
|
||||
#print('dev_database', dev_database)
|
||||
|
||||
#for x in dev_database:
|
||||
# if x in sparc_train_database:
|
||||
# print('x1', x)
|
||||
# if x in spider_train_database:
|
||||
# print('x2', x)
|
||||
# if x in cosql_train_database:
|
||||
# print('x3', x)
|
||||
|
||||
output_dir = './data/merge_data_removefrom'
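# Note: this merged output directory overrides the per-dataset output_dir values set above.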
|
||||
|
||||
|
||||
if os.path.isdir(output_dir):
|
||||
shutil.rmtree(output_dir)
|
||||
os.mkdir(output_dir)
|
||||
|
||||
schema_tokens = {}
|
||||
column_names = {}
|
||||
database_schemas = {}
|
||||
#assert False
|
||||
|
||||
print('Reading spider database schema file')
|
||||
schema_tokens = dict()
|
||||
column_names = dict()
|
||||
database_schemas = dict()
|
||||
assert remove_from
|
||||
for file in ['data/sparc_data_removefrom', 'data/cosql_data_removefrom', 'data/sparc_data_removefrom']:
|
||||
s, c, d = read_database_schema(database_schema_filename, schema_tokens, column_names, database_schemas)
|
||||
schema_tokens.update(s)
|
||||
column_names.update(c)
|
||||
database_schemas.update(d)
|
||||
|
||||
#x = 0
|
||||
#for k in schema_tokens.keys():
|
||||
# x = k
|
||||
#print('schema_tokens', schema_tokens[x], len(schema_tokens) )
|
||||
#print('column_names', column_names[x], len(column_names) )
|
||||
#print('database_schemas', database_schemas[x], len(database_schemas) )
|
||||
#assert False
|
||||
num_database = len(schema_tokens)
|
||||
print('num_database', num_database)
|
||||
print('total number of schema_tokens / databases:', len(schema_tokens))
|
||||
|
||||
#output_dir = 'data/merge_data_removefrom'
|
||||
output_database_schema_filename = os.path.join(output_dir, 'tables.json')
|
||||
#print("output_database_schema_filename", output_database_schema_filename)
|
||||
with open(output_database_schema_filename, 'w') as outfile:
|
||||
json.dump([v for k,v in database_schemas.items()], outfile, indent=4)
|
||||
#assert False
|
||||
if 'spider' in datasets:
|
||||
interaction_list1 = read_spider(spider_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
if 'sparc' in datasets:
|
||||
interaction_list2 = read_sparc(sparc_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
if 'cosql' in datasets:
|
||||
interaction_list3 = read_cosql(cosql_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
|
||||
|
||||
|
||||
print('interaction_list1 length', len(interaction_list1))
|
||||
print('interaction_list2 length', len(interaction_list2))
|
||||
print('interaction_list3 length', len(interaction_list3))
|
||||
|
||||
#assert False
|
||||
s1 = 1.0
|
||||
s2 = 1.0
|
||||
n1 = 1.0
|
||||
n2 = 1.0
|
||||
m1 = 1.0
|
||||
m2 = 1.0
|
||||
n3 = 1.0
|
||||
train_interaction = []
|
||||
vis = set()
|
||||
for database_id in interaction_list1:
|
||||
if database_id not in dev_database:
|
||||
#Len = len(interaction_list1[database_id])//2
|
||||
Len = len(interaction_list1[database_id])
|
||||
train_interaction += interaction_list1[database_id][:Len]
|
||||
#print('xxx', type(interaction_list1[database_id]))
|
||||
#print(interaction_list1[database_id][0])
|
||||
#print(interaction_list2[database_id][0])
|
||||
#print(interaction_list3[database_id][0])
|
||||
#assert False
|
||||
if database_id in interaction_list2:
|
||||
train_interaction += interaction_list2[database_id]
|
||||
#for x in interaction_list1[database_id]:
|
||||
# if len( x['interaction'][0] ) > 300:
|
||||
# continue
|
||||
# train_interaction.append(x)
|
||||
# n3+=1
|
||||
'''if database_id in interaction_list2:
|
||||
for x in interaction_list2[database_id]:
|
||||
# continue
|
||||
#n1 += 1
|
||||
#m1 = max(m1, len(x['interaction']))
|
||||
s = '&'.join([k['utterance'] for k in x['interaction']])
|
||||
if s in vis:
|
||||
print('s', s)
|
||||
vis.add(s)
|
||||
'''
|
||||
if database_id in interaction_list3:
|
||||
for x in interaction_list3[database_id]:
|
||||
# continue
|
||||
#n1 += 1
|
||||
#m1 = max(m1, len(x['interaction']))
|
||||
#s = '&'.join([k['utterance'] for k in x['interaction']])
|
||||
#if s in vis:
|
||||
# print('s1', s)
|
||||
#vis.add(s)
|
||||
#print(x['interaction'][0]['utterance'])
|
||||
#print(x['interaction'][0]['utterance'])
|
||||
if len(x['interaction']) > 4:
|
||||
#assert False
|
||||
continue
|
||||
sumlen = 0
|
||||
for one in x['interaction']:
|
||||
sumlen += len(one['utterance'])
|
||||
if sumlen > 300:
|
||||
continue
|
||||
n1 += 1
|
||||
train_interaction.append(x)
|
||||
print('n1', n1)
|
||||
print('n3', n3)
|
||||
|
||||
dev_interaction = []
|
||||
for database_id in dev_database:
|
||||
dev_interaction += interaction_list2[database_id]
|
||||
|
||||
print('train interaction: ', len(train_interaction))
|
||||
print('dev interaction: ', len(dev_interaction))
|
||||
|
||||
write_interaction(train_interaction, 'train', output_dir)
|
||||
write_interaction(dev_interaction, 'dev', output_dir)
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--datasets", type=str, default='')
|
||||
parser.add_argument('--remove_from', action='store_true', default=False)
|
||||
args = parser.parse_args()
|
||||
assert args.datasets != ''
|
||||
preprocess(args.datasets, args.remove_from)
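# Example invocation (an assumption based on the substring checks in preprocess(), not a documented command):
#   python3 preprocess.py --datasets "spider sparc cosql" --remove_from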
|
||||
|
|
@@ -0,0 +1,80 @@
|
|||
"""Contains classes for computing and keeping track of attention distributions.
|
||||
"""
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from . import torch_utils
|
||||
|
||||
class AttentionResult(namedtuple('AttentionResult',
|
||||
('scores',
|
||||
'distribution',
|
||||
'vector'))):
|
||||
"""Stores the result of an attention calculation."""
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
"""Attention mechanism class. Stores parameters for and computes attention.
|
||||
|
||||
Attributes:
|
||||
transform_query (bool): Whether or not to transform the query being
|
||||
passed in with a weight transformation before computing attention.
|
||||
transform_key (bool): Whether or not to transform the key being
|
||||
passed in with a weight transformation before computing attention.
|
||||
transform_value (bool): Whether or not to transform the value being
|
||||
passed in with a weight transformation before computing attention.
|
||||
key_size (int): The size of the key vectors.
|
||||
value_size (int): The size of the value vectors.
|
||||
the query or key.
|
||||
query_weights (torch.nn.Parameter): Weights for transforming the query.
|
||||
key_weights (torch.nn.Parameter): Weights for transforming the key.
|
||||
value_weights (torch.nn.Parameter): Weights for transforming the value.
|
||||
"""
|
||||
def __init__(self, query_size, key_size, value_size):
|
||||
super().__init__()
|
||||
self.key_size = key_size
|
||||
self.value_size = value_size
|
||||
|
||||
self.query_weights = torch_utils.add_params((query_size, self.key_size), "weights-attention-q")
|
||||
|
||||
def transform_arguments(self, query, keys, values):
|
||||
""" Transforms the query/key/value inputs before attention calculations.
|
||||
|
||||
Arguments:
|
||||
query (torch.Tensor): Vector representing the query (e.g., a hidden state).
|
||||
keys (list of torch.Tensor): List of vectors representing the key
|
||||
values.
|
||||
values (list of torch.Tensor): List of vectors representing the values.
|
||||
|
||||
Returns:
|
||||
triple of torch.Tensor, where the first represents the (transformed)
|
||||
query, the second represents the (transformed and concatenated)
|
||||
keys, and the third represents the (transformed and concatenated)
|
||||
values.
|
||||
"""
|
||||
assert len(keys) == len(values)
|
||||
|
||||
all_keys = torch.stack(keys, dim=1)
|
||||
all_values = torch.stack(values, dim=1)
|
||||
|
||||
assert all_keys.size()[0] == self.key_size, "Expected key size of " + str(self.key_size) + " but got " + str(all_keys.size()[0])
|
||||
assert all_values.size()[0] == self.value_size
|
||||
|
||||
query = torch_utils.linear_layer(query, self.query_weights)
|
||||
|
||||
return query, all_keys, all_values
|
||||
|
||||
def forward(self, query, keys, values=None):
|
||||
if not values:
|
||||
values = keys
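# With no separate values provided, the keys double as the values (self-attention-style lookup).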
|
||||
|
||||
query_t, keys_t, values_t = self.transform_arguments(query, keys, values)
|
||||
|
||||
scores = torch.t(torch.mm(query_t,keys_t)) # len(key) x len(query)
|
||||
|
||||
distribution = F.softmax(scores, dim=0) # len(key) x len(query)
|
||||
|
||||
context_vector = torch.mm(values_t, distribution).squeeze() # value_size x len(query)
|
||||
|
||||
return AttentionResult(scores, distribution, context_vector)
|
|
@@ -0,0 +1,97 @@
|
|||
import heapq
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import Callable, Dict, Tuple, List
|
||||
import torch
|
||||
|
||||
class BeamSearchState:
|
||||
def __init__(self, sequence=None, log_probability=0, state=None):
|
||||
self.sequence = sequence if sequence is not None else []  # avoid a shared mutable default
|
||||
self.log_probability = log_probability
|
||||
self.state = state if state is not None else {}
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.log_probability < other.log_probability
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.log_probability}: {' '.join(self.sequence)}"
|
||||
|
||||
class BeamSearch:
|
||||
"""
|
||||
Implements the beam search algorithm for decoding the most likely sequences.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
is_end_of_sequence: Callable[[Dict, int, List[str]], bool],
|
||||
max_steps: int = 50,
|
||||
beam_size: int = 10,
|
||||
per_node_beam_size: int = None,
|
||||
) -> None:
|
||||
self.is_end_of_sequence = is_end_of_sequence
|
||||
self.max_steps = max_steps
|
||||
self.beam_size = beam_size
|
||||
self.per_node_beam_size = per_node_beam_size
|
||||
if not per_node_beam_size:
|
||||
self.per_node_beam_size = beam_size
|
||||
|
||||
self.beam = []
|
||||
|
||||
def append(self, log_probability, new_state):
|
||||
if len(self.beam) < self.beam_size:
|
||||
heapq.heappush(self.beam, (log_probability, new_state))
|
||||
else:
|
||||
heapq.heapreplace(self.beam, (log_probability, new_state))
|
||||
|
||||
def get_largest(self, beam):
|
||||
return list(heapq.nlargest(1, beam))[0]
|
||||
|
||||
def search(
|
||||
self,
|
||||
start_state: Dict,
|
||||
step_function: Callable[[Dict], Tuple[List[str], torch.Tensor, Tuple]],
|
||||
append_token_function: Callable[[Tuple, Tuple, str, float], Tuple]
|
||||
) -> Tuple[List, float, Dict, List]:
|
||||
state = BeamSearchState(state=start_state)
|
||||
heapq.heappush(self.beam, (0, state))
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
current_beam = self.beam
|
||||
|
||||
self.beam = []
|
||||
found_non_terminal = False
|
||||
|
||||
for (log_probability, state) in current_beam:
|
||||
if self.is_end_of_sequence(state.state, self.max_steps, state.sequence):
|
||||
self.append(log_probability, state)
|
||||
|
||||
else:
|
||||
found_non_terminal = True
|
||||
tokens, token_probabilities, extra_state = step_function(state.state)
|
||||
|
||||
tps = np.array(token_probabilities)
|
||||
top_indices = heapq.nlargest(self.per_node_beam_size, range(len(tps)), tps.take)
|
||||
top_tokens = [tokens[i] for i in top_indices]
|
||||
top_probabilities = [token_probabilities[i] for i in top_indices]
|
||||
sequence_length = len(state.sequence)
|
||||
for i in range(len(top_tokens)):
|
||||
if top_probabilities[i] > 1e-6:
|
||||
#lp = ((log_probability * sequence_length) + math.log(top_probabilities[i])) / (sequence_length + 1.0)
|
||||
lp = log_probability + math.log(top_probabilities[i])
|
||||
t = top_tokens[i]
|
||||
new_state = append_token_function(state.state, extra_state, t, lp)
|
||||
new_sequence = state.sequence.copy()
|
||||
new_sequence.append(t)
|
||||
|
||||
ns = BeamSearchState(sequence=new_sequence, log_probability=lp, state=new_state)
|
||||
|
||||
self.append(lp, ns)
|
||||
|
||||
if not found_non_terminal:
|
||||
done = True
|
||||
|
||||
output = self.get_largest(self.beam)[1]
|
||||
|
||||
return output.sequence, output.log_probability, output.state, self.beam
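# Toy usage sketch (illustrative only, not part of the original file): decode over a two-token
# vocabulary where 'a' is always more likely than 'b', stopping after three tokens. The decoder
# "state" here is just the partial string; a real decoder would carry hidden states instead.
if __name__ == '__main__':
    def toy_is_eos(state, max_steps, sequence):
        return len(sequence) >= 3 or len(sequence) >= max_steps

    def toy_step(state):
        # candidate tokens, their probabilities, and any extra decoder state
        return ['a', 'b'], [0.7, 0.3], None

    def toy_append(state, extra_state, token, log_probability):
        return state + token

    searcher = BeamSearch(is_end_of_sequence=toy_is_eos, max_steps=10, beam_size=5)
    sequence, log_prob, final_state, beam = searcher.search('', toy_step, toy_append)
    print(sequence, log_prob, final_state)  # ['a', 'a', 'a'], 3 * log(0.7), 'aaa'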
|
||||
|
||||
|
|
@@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@@ -0,0 +1,255 @@
|
|||
# PyTorch implementation of Google AI's BERT model with a script to load Google's pre-trained models
|
||||
|
||||
## Forked for wikisql application
|
||||
|
||||
## NSML
|
||||
|
||||
### SQuAD1.1 finetuning
|
||||
|
||||
```
|
||||
nsml run -d squad_bert -g 4 -e run_squad.py -a "--do_lower_case --do_train --do_predict --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu"
|
||||
|
||||
```
|
||||
|
||||
### SQuAD2.0 finetuning
|
||||
|
||||
```
|
||||
nsml run -d squad_bert -g 4 -e run_squad2.py -a "--do_lower_case --do_train --do_predict --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu"
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
1. Download prediction file from NSML session
|
||||
|
||||
```
|
||||
nsml download -f /app/squad_base [NSML_ID]/squad_bert/[SESSION] .
|
||||
```
|
||||
|
||||
2. Run official evaluation file
|
||||
|
||||
```
|
||||
python3 evaluate-v1.1.py [dev.json] [predictions.json]
|
||||
|
||||
python3 evaluate-v2.0.py [dev.json] [predictions.json] -n [na_probs.json]
|
||||
```
|
||||
|
||||
## Introduction
|
||||
|
||||
This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
|
||||
|
||||
This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular [Google's pre-trained models](https://github.com/google-research/bert)) and a conversion script is provided (see below).
|
||||
|
||||
The code to use, in addition, [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will be added later this week (it's actually just the tokenization code that needs to be updated).
|
||||
|
||||
## Loading a TensorFlow checkpoint (e.g. [Google's pre-trained models](https://github.com/google-research/bert#pre-trained-models))
|
||||
|
||||
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
|
||||
|
||||
This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`).
|
||||
|
||||
You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too.
|
||||
|
||||
To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch.
|
||||
|
||||
Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model:
|
||||
|
||||
```shell
|
||||
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
|
||||
|
||||
python convert_tf_checkpoint_to_pytorch.py \
|
||||
--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
|
||||
--bert_config_file $BERT_BASE_DIR/bert_config.json \
|
||||
--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
|
||||
```
|
||||
|
||||
You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models).
|
||||
|
||||
## PyTorch models for BERT
|
||||
|
||||
We included three PyTorch models in this repository that you will find in [`modeling.py`](modeling.py):
|
||||
|
||||
- `BertModel` - the basic BERT Transformer model
|
||||
- `BertForSequenceClassification` - the BERT model with a sequence classification head on top
|
||||
- `BertForQuestionAnswering` - the BERT model with a token classification head on top
|
||||
|
||||
Here are some details on each class.
|
||||
|
||||
### 1. `BertModel`
|
||||
|
||||
`BertModel` is the basic BERT Transformer model with a layer of summed token, position and sequence embeddings followed by a series of identical self-attention blocks (12 for BERT-base, 24 for BERT-large).
|
||||
|
||||
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
|
||||
|
||||
We detail them here. This model takes as inputs:
|
||||
|
||||
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), and
|
||||
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
|
||||
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
|
||||
|
||||
This model outputs a tuple composed of:
|
||||
|
||||
- `all_encoder_layers`: a list of torch.FloatTensor of size [batch_size, sequence_length, hidden_size] which is a list of the full sequences of hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), and
|
||||
- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated with the first token of the input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
|
||||
|
||||
An example of how to use this class is given in the `extract_features.py` script, which can be used to extract the hidden states of the model for a given input.
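As a quick orientation, here is a minimal sketch of a forward pass through `BertModel` (an illustration following the input/output description above, not code from the repository; it assumes `modeling.py` is importable and that the forward signature matches this README):

```python
import torch
from modeling import BertConfig, BertModel

# A BERT-base sized configuration (values follow bert_config.json in this repo).
config = BertConfig(vocab_size=30522, hidden_size=768, num_hidden_layers=12,
                    num_attention_heads=12, intermediate_size=3072, type_vocab_size=2)
model = BertModel(config)

batch_size, seq_len = 2, 16
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))  # token indices
token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)     # all "sentence A"
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)      # no padding here

all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
print(len(all_encoder_layers))        # 12 hidden-state tensors for BERT-base
print(all_encoder_layers[-1].shape)   # [batch_size, seq_len, hidden_size]
print(pooled_output.shape)            # [batch_size, hidden_size]
```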
|
||||
|
||||
### 2. `BertForSequenceClassification`
|
||||
|
||||
`BertForSequenceClassification` is a fine-tuning model that includes `BertModel` and a sequence-level (sequence or pair of sequences) classifier on top of the `BertModel`.
|
||||
|
||||
The sequence-level classifier is a linear layer that takes as input the last hidden state of the first token in the input sequence (see Figures 3a and 3b in the BERT paper).
|
||||
|
||||
An example of how to use this class is given in the `run_classifier.py` script, which can be used to fine-tune a single sequence (or pair of sequences) classifier using BERT, for example for the MRPC task.
|
||||
|
||||
### 3. `BertForQuestionAnswering`
|
||||
|
||||
`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifier on top of the full sequence of last hidden states.
|
||||
|
||||
The token-level classifier takes as input the full sequence of the last hidden states and computes several (e.g., two) scores for each token, which can, for example, be the scores that a given token is a `start_span` or an `end_span` token (see Figures 3c and 3d in the BERT paper).
|
||||
|
||||
An example of how to use this class is given in the `run_squad.py` script, which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
|
||||
|
||||
## Installation, requirements, test
|
||||
|
||||
This code was tested on Python 3.5+. The requirements are:
|
||||
|
||||
- PyTorch (>= 0.4.1)
|
||||
- tqdm
|
||||
|
||||
To install the dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r ./requirements.txt
|
||||
```
|
||||
|
||||
A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`).
|
||||
|
||||
You can run the tests with the command:
|
||||
```bash
|
||||
python -m pytest -sv tests/
|
||||
```
|
||||
|
||||
## Training on large batches: gradient accumulation, multi-GPU and distributed training
|
||||
|
||||
BERT-base and BERT-large are respectively 110M and 340M parameter models, and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most cases a batch size of 32).
|
||||
|
||||
To help with fine-tuning these models, we have included four techniques that you can activate in the fine-tuning scripts `run_classifier.py` and `run_squad.py`: optimize on CPU, gradient-accumulation, multi-gpu and distributed training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
|
||||
|
||||
Here is how to use these techniques in our scripts:
|
||||
|
||||
- **Optimize on CPU**: The Adam optimizer comprises 2 moving averages of all the weights of the model, which means that if you keep them on GPU 1 (typical behavior), your first GPU will have to store 3 times the size of the model. This is not optimal when using a large model like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU to free more room on the GPU(s). As the most computationally intensive operation is the backward pass, this usually doesn't increase the computation time by a lot. This is the only way to fine-tune `BERT-large` in a reasonable time on GPU(s) (see below). Activate this option with `--optimize_on_cpu` on the `run_squad.py` script.
|
||||
- **Gradient Accumulation**: Gradient accumulation can be used by supplying an integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradients will be accumulated over `gradient_accumulation_steps` steps (a minimal sketch appears at the end of this section).
|
||||
- **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected, and the batches are split across the GPUs.
|
||||
- **Distributed training**: Distributed training can be activated by supplying an integer greater than or equal to 0 to the `--local_rank` argument. To use distributed training, you will need to run one training script on each of your machines. This can be done, for example, by running the following command on each server (see the above blog post for more details):
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script)
|
||||
```
|
||||
|
||||
Where `$THIS_MACHINE_INDEX` is a sequential index assigned to each of your machines (0, 1, 2...), and the machine with rank 0 has an IP address `192.168.1.1` and an open port `1234`.
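To make the accumulation mechanics concrete, here is a minimal, generic PyTorch sketch (an illustration, not the repository's actual training loop; the toy model, optimizer, and data below are stand-ins):

```python
import torch
import torch.nn as nn

# Toy stand-ins; in run_classifier.py / run_squad.py these come from the real data and model setup.
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
train_dataloader = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)]

gradient_accumulation_steps = 2  # effective batch size = per-step batch * this value

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(train_dataloader):
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss = loss / gradient_accumulation_steps   # scale so accumulated gradients match one large batch
    loss.backward()                             # gradients add up in .grad across micro-batches
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()                        # one parameter update per effective (large) batch
        optimizer.zero_grad()
```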
|
||||
|
||||
## TPU support and pretraining scripts
|
||||
|
||||
TPUs are not supported by the current stable release of PyTorch (0.4.1). However, the next version of PyTorch (v1.0) should support training on TPUs and is expected to be released soon (see the recent [official announcement](https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-across-google-cloud)).
|
||||
|
||||
We will add TPU support when this next release is published.
|
||||
|
||||
The original TensorFlow code further comprises two scripts for pre-training BERT: [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) and [run_pretraining.py](https://github.com/google-research/bert/blob/master/run_pretraining.py).
|
||||
|
||||
Since pre-training BERT is a particularly expensive operation that basically requires one or several TPUs to be completed in a reasonable amount of time (see details [here](https://github.com/google-research/bert#pre-training-with-bert)), we have decided to wait for the inclusion of TPU support in PyTorch to convert these pre-training scripts.
|
||||
|
||||
## Comparing the PyTorch model and the TensorFlow model predictions
|
||||
|
||||
We also include [two Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
|
||||
|
||||
- The first NoteBook ([Comparing TF and PT models.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models.ipynb)) extracts the hidden states of a full sequence on each layer of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden states of the models.
|
||||
|
||||
- The second NoteBook ([Comparing TF and PT models SQuAD predictions.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models%20SQuAD%20predictions.ipynb)) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of the `BertForQuestionAnswering` and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models.
|
||||
|
||||
Please follow the instructions given in the notebooks to run and modify them. They are also nice examples of how to use the models in a simpler way than the full fine-tuning scripts we provide.
|
||||
|
||||
## Fine-tuning with BERT: running the examples
|
||||
|
||||
We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD.
|
||||
|
||||
Before running these examples you should download the
|
||||
[GLUE data](https://gluebenchmark.com/tasks) by running
|
||||
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
|
||||
and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base`
|
||||
checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section.
|
||||
|
||||
This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase
|
||||
Corpus (MRPC) corpus and runs in less than 10 minutes on a single K-80.
|
||||
|
||||
```shell
|
||||
export GLUE_DIR=/path/to/glue
|
||||
|
||||
python run_classifier.py \
|
||||
--task_name MRPC \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $GLUE_DIR/MRPC/ \
|
||||
--vocab_file $BERT_BASE_DIR/vocab.txt \
|
||||
--bert_config_file $BERT_BASE_DIR/bert_config.json \
|
||||
--init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \
|
||||
--max_seq_length 128 \
|
||||
--train_batch_size 32 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--output_dir /tmp/mrpc_output/
|
||||
```
|
||||
|
||||
Our tests, run on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks), gave evaluation results between 84% and 88%.
|
||||
|
||||
The second example fine-tunes `BERT-Base` on the SQuAD question answering task.
|
||||
|
||||
The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.
|
||||
|
||||
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
|
||||
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
|
||||
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
|
||||
|
||||
```shell
|
||||
export SQUAD_DIR=/path/to/SQUAD
|
||||
|
||||
python run_squad.py \
|
||||
--vocab_file $BERT_BASE_DIR/vocab.txt \
|
||||
--bert_config_file $BERT_BASE_DIR/bert_config.json \
|
||||
--init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--do_lower_case \
|
||||
--train_file $SQUAD_DIR/train-v1.1.json \
|
||||
--predict_file $SQUAD_DIR/dev-v1.1.json \
|
||||
--train_batch_size 12 \
|
||||
--learning_rate 3e-5 \
|
||||
--num_train_epochs 2.0 \
|
||||
--max_seq_length 384 \
|
||||
--doc_stride 128 \
|
||||
--output_dir ../debug_squad/
|
||||
```
|
||||
|
||||
Training with the previous hyper-parameters gave us the following results:
|
||||
```bash
|
||||
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
|
||||
```
|
||||
|
||||
# Fine-tuning BERT-large on GPUs
|
||||
|
||||
The options we list above allow you to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.
|
||||
|
||||
For example, fine-tuning BERT-large on SQuAD can be done on a server with 4 K-80s (these are pretty old now) in 18 hours. Our results are similar to the TensorFlow implementation results (actually slightly higher):
|
||||
```bash
|
||||
{"exact_match": 84.56953642384106, "f1": 91.04028647786927}
|
||||
```
|
||||
To get these results we used a combination of:
|
||||
- multi-GPU training (automatically activated on a multi-GPU server),
|
||||
- 2 steps of gradient accumulation and
|
||||
- performing the optimization step on CPU to store Adam's averages in RAM.
|
||||
|
||||
Here is the full list of hyper-parameters we used for this run:
|
||||
```bash
|
||||
python ./run_squad.py --vocab_file $BERT_LARGE_DIR/vocab.txt --bert_config_file $BERT_LARGE_DIR/bert_config.json --init_checkpoint $BERT_LARGE_DIR/pytorch_model.bin --do_lower_case --do_train --do_predict --train_file $SQUAD_TRAIN --predict_file $SQUAD_EVAL --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --output_dir $OUTPUT_DIR/bert_large_bsz_24 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu
|
||||
```
|
||||
|
|
@@ -0,0 +1,105 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 The HugginFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BERT checkpoint."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from modeling import BertConfig, BertModel
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--bert_config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def convert():
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(args.bert_config_file)
|
||||
model = BertModel(config)
|
||||
|
||||
# Load weights from TF model
|
||||
path = args.tf_checkpoint_path
|
||||
print("Converting TensorFlow checkpoint from {}".format(path))
|
||||
|
||||
init_vars = tf.train.list_variables(path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(path, name)
|
||||
print("Numpy array shape {}".format(array.shape))
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name[5:] # skip "bert/"
|
||||
print("Loading {}".format(name))
|
||||
name = name.split('/')
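# Note: name[5:] above stripped five leading characters to skip "bert/", so checkpoint variables
# under "cls/" (e.g. "cls/predictions", "cls/seq_relationship") appear here as "redictions" /
# "eq_relationship" and are skipped below.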
|
||||
if name[0] in ['redictions', 'eq_relationship']:
|
||||
print("Skipping")
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
else:
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
pointer.data = torch.from_numpy(array)
|
||||
|
||||
# Save pytorch-model
|
||||
torch.save(model.state_dict(), args.pytorch_dump_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
|
@@ -0,0 +1,13 @@
|
|||
{
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"hidden_size": 768,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"max_position_embeddings": 512,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"type_vocab_size": 2,
|
||||
"vocab_size": 30522
|
||||
}
|
File diff suppressed because it is too large
|
@@ -0,0 +1,668 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import six
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
class BertConfig(object):
|
||||
"""Configuration class to store the configuration of a `BertModel`.
|
||||
"""
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02):
|
||||
"""Constructs BertConfig.
|
||||
|
||||
Args:
|
||||
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
|
||||
hidden_size: Size of the encoder layers and the pooler layer.
|
||||
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads: Number of attention heads for each attention layer in
|
||||
the Transformer encoder.
|
||||
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
||||
layer in the Transformer encoder.
|
||||
hidden_act: The non-linear activation function (function or string) in the
|
||||
encoder and pooler.
|
||||
hidden_dropout_prob: The dropout probability for all fully connected
|
||||
layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob: The dropout ratio for the attention
|
||||
probabilities.
|
||||
max_position_embeddings: The maximum sequence length that this model might
|
||||
ever be used with. Typically set this to something large just in case
|
||||
(e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
||||
`BertModel`.
|
||||
initializer_range: The stddev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
def print_status(self):
|
||||
"""
|
||||
Added by Wonseok.
|
||||
"""
|
||||
|
||||
print( f"vocab size: {self.vocab_size}")
|
||||
print( f"hidden_size: {self.hidden_size}")
|
||||
print( f"num_hidden_layer: {self.num_hidden_layers}")
|
||||
print( f"num_attention_heads: {self.num_attention_heads}")
|
||||
print( f"hidden_act: {self.hidden_act}")
|
||||
print( f"intermediate_size: {self.intermediate_size}")
|
||||
print( f"hidden_dropout_prob: {self.hidden_dropout_prob}")
|
||||
print( f"attention_probs_dropout_prob: {self.attention_probs_dropout_prob}")
|
||||
print( f"max_position_embeddings: {self.max_position_embeddings}")
|
||||
print( f"type_vocab_size: {self.type_vocab_size}")
|
||||
print( f"initializer_range: {self.initializer_range}")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
|
||||
config = BertConfig(vocab_size=None)
|
||||
for (key, value) in six.iteritems(json_object):
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `BertConfig` from a json file of parameters."""
|
||||
with open(json_file, "r") as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
|
||||
|
||||
def to_dict(self):
|
||||
"""Serializes this instance to a Python dictionary."""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def to_json_string(self):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
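# A minimal sketch of the config round trip: to_json_string and from_dict are inverses
# for the fields stored on the instance.
import json

_cfg = BertConfig(vocab_size=30522)
_clone = BertConfig.from_dict(json.loads(_cfg.to_json_string()))
assert _clone.vocab_size == 30522 and _clone.hidden_size == _cfg.hidden_size == 768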
|
||||
|
||||
|
||||
class BERTLayerNorm(nn.Module):
|
||||
def __init__(self, config, variance_epsilon=1e-12):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
"""
|
||||
super(BERTLayerNorm, self).__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(config.hidden_size))
|
||||
self.beta = nn.Parameter(torch.zeros(config.hidden_size))
|
||||
self.variance_epsilon = variance_epsilon
|
||||
|
||||
def forward(self, x):
|
||||
# x: the input activations.
# Normalize each vector (each token) separately.
# If x follows a Gaussian distribution, the result is standard normal (i.e., mu=0, std=1).
u = x.mean(-1, keepdim=True)  # keepdim preserves the tensor rank
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True) # variance
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon) # standard
|
||||
|
||||
# Gamma & Beta is trainable parameters.
|
||||
return self.gamma * x + self.beta
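# A minimal sketch of the layer norm behaviour: with the default gamma=1 and beta=0 each
# token vector comes out with (numerically) zero mean.
import torch

_ln = BERTLayerNorm(BertConfig(vocab_size=30522))
_h = torch.randn(2, 5, 768)
assert torch.allclose(_ln(_h).mean(-1), torch.zeros(2, 5), atol=1e-5)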
|
||||
|
||||
class BERTEmbeddings(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTEmbeddings, self).__init__()
|
||||
"""Construct the embedding module from word, position and token_type embeddings.
|
||||
"""
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = BERTLayerNorm(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None):
|
||||
seq_length = input_ids.size(1)
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class BERTSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTSelfAttention, self).__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
""" From this,
|
||||
|
||||
"""
|
||||
# Maybe: x = [B, seq_len, all_head_size=hidden_size ]
|
||||
# [B, seq_len] + [num_attention_heads, attention_head_size] ???
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # it is append (for torch.Size tensor, + acts as an concat. operation)
|
||||
|
||||
# [B, seq_len, all_head_size=hidden_size ] -> [B, seq_len, num_attention_heads, attention_head_size]
|
||||
x = x.view(*new_x_shape)
|
||||
|
||||
# [B, seq_len, num_attention_heads, attention_head_size] -> [B, num_attention_heads, seq_len, attention_head_size]
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
# [B, num_attention_heads, seq_len, attention_head_size] * [B, num_attention_heads, attention_head_size, seq_len]
|
||||
# -> [B, num_attention_heads, seq_len, seq_len]
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
# Apply the attention mask (precomputed for all layers in BertModel's forward() function).
attention_scores = attention_scores + attention_mask  # masked positions are ~ -10000, so they vanish after the softmax
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# [B, num_attention_heads, seq_len, seq_len] * [B, num_attention_heads, seq_len, attention_head_size]
|
||||
# -> [B, num_attention_heads, seq_len, attention_head_size]
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
# -> [B, seq_len, num_attention_heads, attention_head_size]
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
# [B, seq_len] + [all_head_size=hidden_size] -> [B, seq_len, all_head_size]
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
return context_layer
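# A minimal shape check for the self-attention block with a small, illustrative config:
# the output keeps the [batch, seq_len, hidden] layout of the input.
import torch

_small = BertConfig(vocab_size=100, hidden_size=16, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=32)
_attn = BERTSelfAttention(_small)
_h = torch.randn(2, 7, 16)                 # [B, seq_len, hidden]
_mask = torch.zeros(2, 1, 1, 7)            # additive mask; 0.0 means "attend everywhere"
assert _attn(_h, _mask).shape == (2, 7, 16)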
|
||||
|
||||
|
||||
class BERTSelfOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTSelfOutput, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.LayerNorm = BERTLayerNorm(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BERTAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTAttention, self).__init__()
|
||||
self.self = BERTSelfAttention(config)
|
||||
self.output = BERTSelfOutput(config)
|
||||
|
||||
def forward(self, input_tensor, attention_mask):
|
||||
self_output = self.self(input_tensor, attention_mask)
|
||||
attention_output = self.output(self_output, input_tensor)
|
||||
return attention_output
|
||||
|
||||
|
||||
class BERTIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTIntermediate, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.intermediate_act_fn = gelu
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BERTOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTOutput, self).__init__()
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = BERTLayerNorm(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BERTLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTLayer, self).__init__()
|
||||
self.attention = BERTAttention(config)
|
||||
self.intermediate = BERTIntermediate(config)
|
||||
self.output = BERTOutput(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class BERTEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTEncoder, self).__init__()
|
||||
layer = BERTLayer(config)
|
||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
all_encoder_layers = []
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(hidden_states, attention_mask)
|
||||
all_encoder_layers.append(hidden_states)
|
||||
return all_encoder_layers
|
||||
|
||||
|
||||
class BERTPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTPooler, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class BertModel(nn.Module):
|
||||
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
|
||||
|
||||
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
|
||||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
||||
|
||||
model = modeling.BertModel(config=config)
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config: BertConfig):
|
||||
"""Constructor for BertModel.
|
||||
|
||||
Args:
|
||||
config: `BertConfig` instance.
|
||||
"""
|
||||
super(BertModel, self).__init__()
|
||||
self.embeddings = BERTEmbeddings(config)
|
||||
self.encoder = BERTEncoder(config)
|
||||
self.pooler = BERTPooler(config)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# This attention mask is simpler than the triangular mask used for causal attention
# in OpenAI GPT; we only need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.float()
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||
all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
return all_encoder_layers, pooled_output
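# A minimal sketch of the mask transformation performed in forward(): a 1/0 padding mask
# becomes an additive bias of 0 / -10000 with broadcastable shape [B, 1, 1, seq_len].
import torch

_pad = torch.LongTensor([[1, 1, 0]])
_ext = (1.0 - _pad.unsqueeze(1).unsqueeze(2).float()) * -10000.0
assert _ext.shape == (1, 1, 1, 3) and _ext[0, 0, 0, 2].item() == -10000.0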
|
||||
|
||||
class BertForSequenceClassification(nn.Module):
|
||||
"""BERT model for classification.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the pooled output.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
|
||||
|
||||
config = BertConfig(vocab_size=32000, hidden_size=512,
|
||||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
||||
|
||||
num_labels = 2
|
||||
|
||||
model = BertForSequenceClassification(config, num_labels)
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels):
|
||||
super(BertForSequenceClassification, self).__init__()
|
||||
self.bert = BertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||
|
||||
def init_weights(module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
elif isinstance(module, BERTLayerNorm):
|
||||
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
if isinstance(module, nn.Linear):
|
||||
module.bias.data.zero_()
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
|
||||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
return loss, logits
|
||||
else:
|
||||
return logits
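# A minimal sketch of the two calling modes, using a small illustrative config: with labels
# the module returns (loss, logits); without labels it returns logits only.
import torch

_small = BertConfig(vocab_size=100, hidden_size=16, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=32)
_clf = BertForSequenceClassification(_small, num_labels=3)
_ids = torch.randint(0, 100, (2, 5))
_loss, _logits = _clf(_ids, torch.zeros_like(_ids), torch.ones_like(_ids),
                      labels=torch.tensor([0, 2]))
assert _logits.shape == (2, 3) and _loss.dim() == 0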
|
||||
|
||||
class BertForQuestionAnswering(nn.Module):
|
||||
"""BERT model for Question Answering (span extraction).
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the sequence output that computes start_logits and end_logits
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
|
||||
|
||||
config = BertConfig(vocab_size=32000, hidden_size=512,
|
||||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
||||
|
||||
model = BertForQuestionAnswering(config)
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(BertForQuestionAnswering, self).__init__()
|
||||
self.bert = BertModel(config)
|
||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
|
||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||
|
||||
def init_weights(module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
elif isinstance(module, BERTLayerNorm):
|
||||
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
if isinstance(module, nn.Linear):
|
||||
module.bias.data.zero_()
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
|
||||
all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, the gathered positions may carry an extra dimension; squeeze it
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
return total_loss
|
||||
else:
|
||||
return start_logits, end_logits
|
||||
|
||||
|
||||
class BertNoAnswer(nn.Module):
|
||||
def __init__(self,hidden_size,context_length=317):
|
||||
super(BertNoAnswer,self).__init__()
|
||||
self.context_length = context_length
|
||||
self.W_no = nn.Linear(hidden_size,1)
|
||||
self.no_answer = nn.Sequential(nn.Linear(hidden_size*3,hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size,2))
|
||||
|
||||
def forward(self,sequence_output,start_logit,end_logit,mask=None):
|
||||
if mask is None:
|
||||
nbatch,length,_ = sequence_output.size()
|
||||
mask = torch.ones(nbatch,length)
|
||||
mask = mask.float()
|
||||
mask = mask.unsqueeze(-1)[:,1:self.context_length+1]
|
||||
mask = (1.0 - mask) * -10000.0
|
||||
sequence_output = sequence_output[:,1:self.context_length+1]
|
||||
start_logit = start_logit[:,1:self.context_length+1] + mask
|
||||
end_logit = end_logit[:,1:self.context_length+1] + mask
|
||||
|
||||
# No-answer option
|
||||
pa_1 = nn.functional.softmax(start_logit.transpose(1,2),-1) # B,1,T
|
||||
v1 = torch.bmm(pa_1,sequence_output).squeeze(1) # B,H
|
||||
pa_2 = nn.functional.softmax(end_logit.transpose(1,2),-1) # B,1,T
|
||||
v2 = torch.bmm(pa_2,sequence_output).squeeze(1) # B,H
|
||||
pa_3 = self.W_no(sequence_output) + mask
|
||||
pa_3 = nn.functional.softmax(pa_3.transpose(1,2),-1) # B,1,T
|
||||
v3 = torch.bmm(pa_3,sequence_output).squeeze(1) # B,H
|
||||
|
||||
bias = self.no_answer(torch.cat([v1,v2,v3],-1)) # B,1
|
||||
|
||||
return bias
|
||||
|
||||
class BertForSQuAD2(nn.Module):
|
||||
def __init__(self, config, context_length=317):
|
||||
super(BertForSQuAD2, self).__init__()
|
||||
self.bert = BertModel(config)
|
||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
|
||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, 2)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||
self.na_head = BertNoAnswer(config.hidden_size, context_length)
|
||||
|
||||
def init_weights(module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
elif isinstance(module, BERTLayerNorm):
|
||||
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
if isinstance(module, nn.Linear):
|
||||
module.bias.data.zero_()
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None, labels=None):
|
||||
all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
span_logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = span_logits.split(1, dim=-1)
|
||||
na_logits = self.na_head(sequence_output,start_logits,end_logits,attention_mask)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
logits = (na_logits+logits)/2 # mean
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, the gathered positions may carry an extra dimension; squeeze it
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
unanswerable_loss = loss_fct(logits, labels)
|
||||
span_loss = (start_loss + end_loss) / 2
|
||||
total_loss = span_loss + unanswerable_loss
|
||||
return total_loss
|
||||
else:
|
||||
probs = nn.functional.softmax(logits,-1)
|
||||
_, probs = probs.split(1, dim=-1)
|
||||
return start_logits, end_logits, probs
|
||||
|
||||
|
||||
class BertForWikiSQL(nn.Module):
|
||||
def __init__(self, config, context_length=317):
|
||||
super(BertForWikiSQL, self).__init__()
|
||||
self.bert = BertModel(config)
|
||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, 2)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||
self.na_head = BertNoAnswer(config.hidden_size, context_length)
|
||||
|
||||
def init_weights(module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
elif isinstance(module, BERTLayerNorm):
|
||||
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
if isinstance(module, nn.Linear):
|
||||
module.bias.data.zero_()
|
||||
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None, labels=None):
|
||||
all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
span_logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = span_logits.split(1, dim=-1)
|
||||
na_logits = self.na_head(sequence_output, start_logits, end_logits, attention_mask)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
logits = (na_logits + logits) / 2 # mean
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, the gathered positions may carry an extra dimension; squeeze it
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
unanswerable_loss = loss_fct(logits, labels)
|
||||
span_loss = (start_loss + end_loss) / 2
|
||||
total_loss = span_loss + unanswerable_loss
|
||||
return total_loss
|
||||
else:
|
||||
probs = nn.functional.softmax(logits, -1)
|
||||
_, probs = probs.split(1, dim=-1)
|
||||
return start_logits, end_logits, probs
|
@@ -0,0 +1,333 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import unicodedata
|
||||
import six
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
elif isinstance(text, unicode):
|
||||
return text
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode):
|
||||
return text.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, "r", encoding="utf-8") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
"""Converts a sequence of tokens into ids using the vocab."""
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(vocab[token])
|
||||
return ids
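# A minimal sketch: convert_tokens_to_ids is a plain lookup into the loaded vocabulary dict.
_toy_vocab = {"[CLS]": 0, "hello": 1, "[SEP]": 2}
assert convert_tokens_to_ids(_toy_vocab, ["[CLS]", "hello", "[SEP]"]) == [0, 1, 2]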
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class FullTokenizer(object):
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_tokens_to_ids(self.vocab, tokens)
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenization."""
|
||||
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer`.
|
||||
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
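# A minimal sketch of the greedy longest-match-first behaviour with a toy vocabulary,
# reproducing the docstring example; unknown words collapse to [UNK].
_toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
_wp = WordpieceTokenizer(vocab=_toy_vocab)
assert _wp.tokenize("unaffable") == ["un", "##aff", "##able"]
assert _wp.tokenize("xyz") == ["[UNK]"]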
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
|
@@ -0,0 +1,287 @@
|
|||
""" Decoder for the SQL generation problem."""
|
||||
|
||||
from collections import namedtuple
|
||||
import copy
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from . import torch_utils
|
||||
|
||||
from .token_predictor import PredictionInput, PredictionInputWithSchema, PredictionStepInputWithSchema
|
||||
from .beam_search import BeamSearch
|
||||
import data_util.snippets as snippet_handler
|
||||
from . import embedder
|
||||
from data_util.vocabulary import EOS_TOK, UNK_TOK
|
||||
import math
|
||||
|
||||
def flatten_distribution(distribution_map, probabilities):
|
||||
""" Flattens a probability distribution given a map of "unique" values.
|
||||
All values in distribution_map with the same value should get the sum
|
||||
of the probabilities.
|
||||
|
||||
Arguments:
|
||||
distribution_map (list of str): List of values to get the probability for.
|
||||
probabilities (np.ndarray): Probabilities corresponding to the values in
|
||||
distribution_map.
|
||||
|
||||
Returns:
|
||||
list, np.ndarray of the same size where probabilities for duplicates
|
||||
in distribution_map are given the sum of the probabilities in probabilities.
|
||||
"""
|
||||
assert len(distribution_map) == len(probabilities)
|
||||
if len(distribution_map) != len(set(distribution_map)):
|
||||
idx_first_dup = 0
|
||||
seen_set = set()
|
||||
for i, tok in enumerate(distribution_map):
|
||||
if tok in seen_set:
|
||||
idx_first_dup = i
|
||||
break
|
||||
seen_set.add(tok)
|
||||
new_dist_map = distribution_map[:idx_first_dup] + list(
|
||||
set(distribution_map) - set(distribution_map[:idx_first_dup]))
|
||||
assert len(new_dist_map) == len(set(new_dist_map))
|
||||
new_probs = np.array(
|
||||
probabilities[:idx_first_dup] \
|
||||
+ [0. for _ in range(len(set(distribution_map)) \
|
||||
- idx_first_dup)])
|
||||
assert len(new_probs) == len(new_dist_map)
|
||||
|
||||
for i, token_name in enumerate(
|
||||
distribution_map[idx_first_dup:]):
|
||||
if token_name not in new_dist_map:
|
||||
new_dist_map.append(token_name)
|
||||
|
||||
new_index = new_dist_map.index(token_name)
|
||||
new_probs[new_index] += probabilities[i +
|
||||
idx_first_dup]
|
||||
new_probs = new_probs.tolist()
|
||||
else:
|
||||
new_dist_map = distribution_map
|
||||
new_probs = probabilities
|
||||
|
||||
assert len(new_dist_map) == len(new_probs)
|
||||
|
||||
return new_dist_map, new_probs
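# A minimal sketch of flatten_distribution: probabilities of duplicate entries are summed
# onto a single entry, so the flattened map has unique values.
_toks, _probs = flatten_distribution(["select", "from", "select"], [0.5, 0.3, 0.2])
assert sorted(_toks) == ["from", "select"]
assert abs(_probs[_toks.index("select")] - 0.7) < 1e-9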
|
||||
|
||||
class SQLPrediction(namedtuple('SQLPrediction',
|
||||
('predictions',
|
||||
'sequence',
|
||||
'probability',
|
||||
'beam'))):
|
||||
"""Contains prediction for a sequence."""
|
||||
__slots__ = ()
|
||||
|
||||
def __str__(self):
|
||||
return str(self.probability) + "\t" + " ".join(self.sequence)
|
||||
|
||||
class SequencePredictorWithSchema(torch.nn.Module):
|
||||
""" Predicts a sequence.
|
||||
|
||||
Attributes:
|
||||
lstms: The stacked decoder LSTM parameters (built by torch_utils.create_multilayer_lstm_params).
|
||||
token_predictor (TokenPredictor): Used to actually predict tokens.
|
||||
"""
|
||||
def __init__(self,
|
||||
params,
|
||||
input_size,
|
||||
output_embedder,
|
||||
column_name_token_embedder,
|
||||
token_predictor):
|
||||
super().__init__()
|
||||
|
||||
self.lstms = torch_utils.create_multilayer_lstm_params(params.decoder_num_layers, input_size, params.decoder_state_size, "LSTM-d")
|
||||
self.token_predictor = token_predictor
|
||||
self.output_embedder = output_embedder
|
||||
self.column_name_token_embedder = column_name_token_embedder
|
||||
self.start_token_embedding = torch_utils.add_params((params.output_embedding_size,), "y-0")
|
||||
|
||||
self.input_size = input_size
|
||||
self.params = params
|
||||
|
||||
def _initialize_decoder_lstm(self, encoder_state, utterance_index):
|
||||
decoder_lstm_states = []
|
||||
for i, lstm in enumerate(self.lstms):
|
||||
encoder_layer_num = 0
|
||||
if len(encoder_state[0]) > 1:
|
||||
encoder_layer_num = i
|
||||
|
||||
# check which one is h_0, which is c_0
|
||||
c_0 = encoder_state[0][encoder_layer_num].view(1,-1)
|
||||
h_0 = encoder_state[1][encoder_layer_num].view(1,-1)
|
||||
|
||||
decoder_lstm_states.append((h_0, c_0))
|
||||
return decoder_lstm_states
|
||||
|
||||
def get_output_token_embedding(self, output_token, input_schema, snippets):
|
||||
if self.params.use_snippets and snippet_handler.is_snippet(output_token):
|
||||
output_token_embedding = embedder.bow_snippets(output_token, snippets, self.output_embedder, input_schema)
|
||||
else:
|
||||
if input_schema:
|
||||
assert self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
|
||||
if self.output_embedder.in_vocabulary(output_token):
|
||||
output_token_embedding = self.output_embedder(output_token)
|
||||
else:
|
||||
output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True)
|
||||
else:
|
||||
output_token_embedding = self.output_embedder(output_token)
|
||||
return output_token_embedding
|
||||
|
||||
def get_decoder_input(self, output_token_embedding, prediction):
|
||||
if self.params.use_schema_attention and self.params.use_query_attention:
|
||||
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector, prediction.query_attention_results.vector], dim=0)
|
||||
elif self.params.use_schema_attention:
|
||||
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector], dim=0)
|
||||
else:
|
||||
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector], dim=0)
|
||||
return decoder_input
|
||||
|
||||
def forward(self,
|
||||
final_encoder_state,
|
||||
encoder_states,
|
||||
schema_states,
|
||||
max_generation_length,
|
||||
utterance_index=None,
|
||||
snippets=None,
|
||||
gold_sequence=None,
|
||||
input_sequence=None,
|
||||
previous_queries=None,
|
||||
previous_query_states=None,
|
||||
input_schema=None,
|
||||
dropout_amount=0.):
|
||||
""" Generates a sequence. """
|
||||
index = 0
|
||||
|
||||
context_vector_size = self.input_size - self.params.output_embedding_size
|
||||
|
||||
# Decoder states: just the initialized decoder.
|
||||
# Current input to decoder: phi(start_token) ; zeros the size of the
|
||||
# context vector
|
||||
predictions = []
|
||||
sequence = []
|
||||
probability = 1.
|
||||
|
||||
decoder_states = self._initialize_decoder_lstm(final_encoder_state, utterance_index)
|
||||
|
||||
if self.start_token_embedding.is_cuda:
|
||||
decoder_input = torch.cat([self.start_token_embedding, torch.cuda.FloatTensor(context_vector_size).fill_(0)], dim=0)
|
||||
else:
|
||||
decoder_input = torch.cat([self.start_token_embedding, torch.zeros(context_vector_size)], dim=0)
|
||||
|
||||
beam_search = BeamSearch(is_end_of_sequence=self.is_end_of_sequence, max_steps=max_generation_length, beam_size=10)
|
||||
prediction_step_input = PredictionStepInputWithSchema(
|
||||
index=index,
|
||||
decoder_input=decoder_input,
|
||||
decoder_states=decoder_states,
|
||||
encoder_states=encoder_states,
|
||||
schema_states=schema_states,
|
||||
snippets=snippets,
|
||||
gold_sequence=gold_sequence,
|
||||
input_sequence=input_sequence,
|
||||
previous_queries=previous_queries,
|
||||
previous_query_states=previous_query_states,
|
||||
input_schema=input_schema,
|
||||
dropout_amount=dropout_amount,
|
||||
predictions=predictions,
|
||||
)
|
||||
|
||||
sequence, probability, final_state, beam = beam_search.search(start_state=prediction_step_input, step_function=self.beam_search_step_function, append_token_function=self.beam_search_append_token)
|
||||
|
||||
return SQLPrediction(final_state.predictions,
|
||||
sequence,
|
||||
probability,
|
||||
beam)
|
||||
|
||||
def beam_search_step_function(self, prediction_step_input):
|
||||
# get decoded token probabilities
|
||||
prediction, tokens, token_probabilities, decoder_states = self.decode_next_token(prediction_step_input)
|
||||
return tokens, token_probabilities, (prediction, decoder_states)
|
||||
|
||||
def beam_search_append_token(self, prediction_step_input, step_function_output, token_to_append, token_log_probability):
|
||||
output_token_embedding = self.get_output_token_embedding(token_to_append, prediction_step_input.input_schema, prediction_step_input.snippets)
|
||||
decoder_input = self.get_decoder_input(output_token_embedding, step_function_output[0])
|
||||
|
||||
predictions = prediction_step_input.predictions.copy()
|
||||
predictions.append(step_function_output[0])
|
||||
new_state = prediction_step_input._replace(
|
||||
index=prediction_step_input.index+1,
|
||||
decoder_input=decoder_input,
|
||||
predictions=predictions,
|
||||
decoder_states=step_function_output[1],
|
||||
)
|
||||
|
||||
return new_state
|
||||
|
||||
def is_end_of_sequence(self, prediction_step_input, max_generation_length, sequence):
|
||||
if prediction_step_input.gold_sequence:
|
||||
return prediction_step_input.index >= len(prediction_step_input.gold_sequence)
|
||||
else:
|
||||
return prediction_step_input.index >= max_generation_length or (len(sequence) > 0 and sequence[-1] == EOS_TOK)
|
||||
|
||||
def decode_next_token(self, prediction_step_input):
|
||||
(index,
|
||||
decoder_input,
|
||||
decoder_states,
|
||||
encoder_states,
|
||||
schema_states,
|
||||
snippets,
|
||||
gold_sequence,
|
||||
input_sequence,
|
||||
previous_queries,
|
||||
previous_query_states,
|
||||
input_schema,
|
||||
dropout_amount,
|
||||
predictions) = prediction_step_input
|
||||
|
||||
_, decoder_state, decoder_states = torch_utils.forward_one_multilayer(self.lstms, decoder_input, decoder_states, dropout_amount)
|
||||
prediction_input = PredictionInputWithSchema(decoder_state=decoder_state,
|
||||
input_hidden_states=encoder_states,
|
||||
schema_states=schema_states,
|
||||
snippets=snippets,
|
||||
input_sequence=input_sequence,
|
||||
previous_queries=previous_queries,
|
||||
previous_query_states=previous_query_states,
|
||||
input_schema=input_schema)
|
||||
|
||||
prediction = self.token_predictor(prediction_input, dropout_amount=dropout_amount)
|
||||
|
||||
#gold_sequence = None
|
||||
|
||||
if gold_sequence:
|
||||
decoded_token = gold_sequence[index]
|
||||
distribution_map = [decoded_token]
|
||||
probabilities = [1.0]
|
||||
|
||||
else:
|
||||
assert prediction.scores.dim() == 1
|
||||
probabilities = F.softmax(prediction.scores, dim=0).cpu().data.numpy().tolist()
|
||||
|
||||
distribution_map = prediction.aligned_tokens
|
||||
assert len(probabilities) == len(distribution_map)
|
||||
|
||||
if self.params.use_previous_query and self.params.use_copy_switch and len(previous_queries) > 0:
|
||||
assert prediction.query_scores.dim() == 1
|
||||
query_token_probabilities = F.softmax(prediction.query_scores, dim=0).cpu().data.numpy().tolist()
|
||||
|
||||
query_token_distribution_map = prediction.query_tokens
|
||||
|
||||
assert len(query_token_probabilities) == len(query_token_distribution_map)
|
||||
|
||||
copy_switch = prediction.copy_switch.cpu().data.numpy()
|
||||
|
||||
# Merge the two
|
||||
probabilities = ((np.array(probabilities) * (1 - copy_switch)).tolist() +
|
||||
(np.array(query_token_probabilities) * copy_switch).tolist()
|
||||
)
|
||||
distribution_map = distribution_map + query_token_distribution_map
|
||||
assert len(probabilities) == len(distribution_map)
|
||||
|
||||
# Get a new probabilities and distribution_map consolidating duplicates
|
||||
distribution_map, probabilities = flatten_distribution(distribution_map, probabilities)
|
||||
|
||||
# Modify the probability distribution so that the UNK token can never be produced
|
||||
probabilities[distribution_map.index(UNK_TOK)] = 0.
|
||||
|
||||
return prediction, distribution_map, probabilities, decoder_states
|
|
@@ -0,0 +1,100 @@
|
|||
""" Embedder for tokens. """
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import data_util.snippets as snippet_handler
|
||||
import data_util.vocabulary as vocabulary_handler
|
||||
|
||||
class Embedder(torch.nn.Module):
|
||||
""" Embeds tokens. """
|
||||
def __init__(self, embedding_size, name="", initializer=None, vocabulary=None, num_tokens=-1, anonymizer=None, freeze=False, use_unk=True):
|
||||
super().__init__()
|
||||
|
||||
if vocabulary:
|
||||
assert num_tokens < 0, "Specified a vocabulary but also set number of tokens to " + \
|
||||
str(num_tokens)
|
||||
self.in_vocabulary = lambda token: token in vocabulary.tokens
|
||||
self.vocab_token_lookup = lambda token: vocabulary.token_to_id(token)
|
||||
if use_unk:
|
||||
self.unknown_token_id = vocabulary.token_to_id(vocabulary_handler.UNK_TOK)
|
||||
else:
|
||||
self.unknown_token_id = -1
|
||||
self.vocabulary_size = len(vocabulary)
|
||||
else:
|
||||
def check_vocab(index):
|
||||
""" Makes sure the index is in the vocabulary."""
|
||||
assert index < num_tokens, "Passed token ID " + \
|
||||
str(index) + "; expecting something less than " + str(num_tokens)
|
||||
return index < num_tokens
|
||||
self.in_vocabulary = check_vocab
|
||||
self.vocab_token_lookup = lambda x: x
|
||||
self.unknown_token_id = num_tokens  # Deliberately out of range: looking it up would fail,
# but the assert in check_vocab should fire before that happens.
|
||||
self.vocabulary_size = num_tokens
|
||||
|
||||
self.anonymizer = anonymizer
|
||||
|
||||
emb_name = name + "-tokens"
|
||||
#print("Creating token embedder called " + emb_name + " of size " + str(self.vocabulary_size) + " x " + str(embedding_size))
|
||||
|
||||
if initializer is not None:
|
||||
word_embeddings_tensor = torch.FloatTensor(initializer)
|
||||
self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(word_embeddings_tensor, freeze=freeze)
|
||||
else:
|
||||
init_tensor = torch.empty(self.vocabulary_size, embedding_size).uniform_(-0.1, 0.1)
|
||||
self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False)
|
||||
|
||||
if self.anonymizer:
|
||||
emb_name = name + "-entities"
|
||||
entity_size = len(self.anonymizer.entity_types)
|
||||
#print("Creating entity embedder called " + emb_name + " of size " + str(entity_size) + " x " + str(embedding_size))
|
||||
init_tensor = torch.empty(entity_size, embedding_size).uniform_(-0.1, 0.1)
|
||||
self.entity_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False)
|
||||
|
||||
|
||||
def forward(self, token):
|
||||
assert isinstance(token, int) or not snippet_handler.is_snippet(token), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets"
|
||||
|
||||
if self.in_vocabulary(token):
|
||||
index_list = torch.LongTensor([self.vocab_token_lookup(token)])
|
||||
if self.token_embedding_matrix.weight.is_cuda:
|
||||
index_list = index_list.cuda()
|
||||
return self.token_embedding_matrix(index_list).squeeze()
|
||||
elif self.anonymizer and self.anonymizer.is_anon_tok(token):
|
||||
index_list = torch.LongTensor([self.anonymizer.get_anon_id(token)])
|
||||
if self.token_embedding_matrix.weight.is_cuda:
|
||||
index_list = index_list.cuda()
|
||||
return self.entity_embedding_matrix(index_list).squeeze()
|
||||
else:
|
||||
index_list = torch.LongTensor([self.unknown_token_id])
|
||||
if self.token_embedding_matrix.weight.is_cuda:
|
||||
index_list = index_list.cuda()
|
||||
return self.token_embedding_matrix(index_list).squeeze()
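# A minimal sketch of the fixed-size-vocabulary mode (no Vocabulary object): integer token
# ids are embedded directly; the sizes here are illustrative.
import torch

_toy = Embedder(8, name="toy", num_tokens=5)
assert _toy(3).shape == (8,)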
|
||||
|
||||
|
||||
def bow_snippets(token, snippets, output_embedder, input_schema):
|
||||
""" Bag of words embedding for snippets"""
|
||||
assert snippet_handler.is_snippet(token) and snippets
|
||||
|
||||
snippet_sequence = []
|
||||
for snippet in snippets:
|
||||
if snippet.name == token:
|
||||
snippet_sequence = snippet.sequence
|
||||
break
|
||||
assert snippet_sequence
|
||||
|
||||
if input_schema:
|
||||
snippet_embeddings = []
|
||||
for output_token in snippet_sequence:
|
||||
assert output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
|
||||
if output_embedder.in_vocabulary(output_token):
|
||||
snippet_embeddings.append(output_embedder(output_token))
|
||||
else:
|
||||
snippet_embeddings.append(input_schema.column_name_embedder(output_token, surface_form=True))
|
||||
else:
|
||||
snippet_embeddings = [output_embedder(subtoken) for subtoken in snippet_sequence]
|
||||
|
||||
snippet_embeddings = torch.stack(snippet_embeddings, dim=0) # len(snippet_sequence) x emb_size
|
||||
return torch.mean(snippet_embeddings, dim=0) # emb_size
|
||||
|