add: r2sql repo

出蛰 2022-07-15 19:18:00 +08:00
parent d3eba969df
commit 2659e4dd89
135 changed files with 79350 additions and 0 deletions

145
r2sql/README.md Normal file

@@ -0,0 +1,145 @@
# R²SQL
The PyTorch implementation of the paper [Dynamic Hybrid Relation Network for Cross-Domain Context-Dependent Semantic Parsing](https://arxiv.org/pdf/2101.01686) (AAAI 2021).
## Requirements
The model is tested with Python 3.6 and the following requirements:
```
torch==1.0.0
transformers==2.10.0
sqlparse
pymysql
progressbar
nltk
numpy
six
spacy
```
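A minimal setup sketch, assuming a fresh Python 3.6 environment (the exact `torch==1.0.0` build to install depends on your CUDA version):
```bash
# install the pinned requirements listed above
pip install torch==1.0.0 transformers==2.10.0 sqlparse pymysql progressbar nltk numpy six spacy
```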
All experiments on the SParC and CoSQL datasets were run on an NVIDIA V100 GPU with 32GB of GPU memory.
* Tip: a 16GB GPU may run out of memory.
## Setup
The SParC and CoSQL experiments live in two separate folders; download the corresponding dataset from [[SParC](https://yale-lily.github.io/spider) | [CoSQL](https://yale-lily.github.io/cosql)] into the `{sparc|cosql}/data` folder of each.
Additional related data files can be downloaded from [EditSQL](https://github.com/ryanzhumich/editsql/tree/master/data).
Then, download the database sqlite files from [[here](https://drive.google.com/file/d/1a828mkHcgyQCBgVla0jGxKJ58aV8RsYK/view?usp=sharing)] and place them under `data/database`.
Download the pretrained BERT model from [[here](https://drive.google.com/file/d/1f_LEWVgrtZLRuoiExJa5fNzTS8-WcAX9/view?usp=sharing)] and save it as `model/bert/data/annotated_wikisql_and_PyTorch_bert_param/pytorch_model_uncased_L-12_H-768_A-12.bin`.
Download the GloVe embeddings file (`glove.840B.300d.txt`) and change `GLOVE_PATH` in all scripts to your own path.
Download the reranker models from [[SParC reranker](https://drive.google.com/file/d/1cA106xgSx6KeonOxD2sZ06Eolptxt_OG/view?usp=sharing) | [CoSQL reranker](https://drive.google.com/file/d/1UURYw15T6zORcYRTvP51MYkzaxNmvRIU/view?usp=sharing)] and save them as `submit_models/reranker_roberta.pt`. In addition, the roberta-base model can be downloaded from [here](https://drive.google.com/file/d/1LkTe-Z0AFg2dAAWgUKuCLEhSmtW-CWXh/view?usp=sharing) and placed under `./[sparc|cosql]/local_param/`.
## Usage
Train the model from scratch.
```bash
./sparc_train.sh
```
Test the model with a concrete checkpoint:
```bash
./sparc_test.sh
```
The dev prediction file will then appear in the `results` folder, with a name like `save_%d_predictions.json`.
Get the evaluation result from the prediction file:
```bash
./sparc_evaluate.sh
```
The final result will appear in the `results` folder, named `*.eval`.
The CoSQL experiments can be reproduced in the same way.
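The CoSQL counterparts of the three steps above are included in the `cosql` folder:
```bash
./cosql_train.sh
./cosql_test.sh
./cosql_evaluate.sh
```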
---
You can download our trained checkpoints and results here:
* SParC: [[log](https://drive.google.com/file/d/19ySQ_4x3R-T0cML2uJQBaYI2EyTlPr1G/view?usp=sharing) | [results](https://drive.google.com/file/d/12-kTEnNJKKblPDx5UIz5W0lVvf_sWpyS/view?usp=sharing)]
* CoSQL: [[log](https://drive.google.com/file/d/1QaxM8AUu3cQUXIZvCgoqW115tZCcEppl/view?usp=sharing) | [results](https://drive.google.com/file/d/1fCTRagV46gvEKU5XPje0Um69rMkEAztU/view?usp=sharing)]
### Reranker
If you want to train your own reranker model, you can download the training data here:
* SParC: [[reranker training data](https://drive.google.com/file/d/1XEiYUmDsVGouCO6NZS1yyMkUDxvWgCZ9/view?usp=sharing)]
* CoSQL: [[reranker training data](https://drive.google.com/file/d/1mzjywnMiABOTHYC9BWOoUOn4HnokcX8i/view?usp=sharing)]
Then you can train, test, and predict with it:
train:
```bash
python -m reranker.main --train --batch_size 64 --epoches 50
```
test:
```bash
python -m reranker.main --test --batch_size 64
```
predict:
```bash
python -m reranker.predict
```
## Improvements
We have improved on the original version (described in the paper) and obtained further performance gains :partying_face:!
Compared with the original version, we made the following changes:
* Added a self-ensemble strategy for prediction, which combines checkpoints from different epochs to produce the final result (see the example after this list). To make this strategy easy to apply, we removed the task-related representation from the reranker module.
* Removed the decay function in DCRI; we found DCRI unstable with the decay function, so we let it degenerate into vanilla cross attention.
* Replaced the BERT-based model with a RoBERTa-based model in the reranker module.
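The self-ensemble in the first item amounts to passing several checkpoints' prediction files to the evaluation step at once; the CoSQL ensemble evaluation script in this repo, for example, combines four checkpoints:
```bash
python3 postprocess_eval.py --dataset=cosql --split=dev \
  --pred_file results/save_14_predictions.json,results/save_9_predictions.json,results/save_6_predictions.json,results/save_3_predictions.json \
  --remove_from
```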
The final performance comparison on the dev sets is as follows:
<table>
<tr>
<th></th>
<th colspan="2">SParC</th>
<th colspan="2">CoSQL</th>
</tr>
<tr>
<td></td>
<td>QM</td>
<td>IM</td>
<td>QM</td>
<td>IM</td>
</tr>
<tr>
<td>EditSQL</td>
<td>47.2</td>
<td>29.5</td>
<td>39.9</td>
<td>12.3</td>
</tr>
<tr>
<td>R²SQL v1 (origin paper)</td>
<td>54.1</td>
<td>35.2</td>
<td>45.7</td>
<td>19.5</td>
</tr>
<tr>
<td>R²SQL v2 (this repo)</td>
<td>54.0</td>
<td>35.2</td>
<td>46.3</td>
<td>19.5</td>
</tr>
<tr>
<td>R²SQL v2 + ensemble </td>
<td>55.1</td>
<td>36.8</td>
<td>47.3</td>
<td>20.9</td>
</tr>
</table>
## Citation
Please star this repo and cite the paper if you use it in your work.
## Acknowledgments
This implementation is based on ["Editing-Based SQL Query Generation for Cross-Domain Context-Dependent Questions"](https://github.com/ryanzhumich/editsql) (EMNLP 2019).

1
r2sql/cosql/.gitattributes vendored Normal file

@@ -0,0 +1 @@
save_* filter=lfs diff=lfs merge=lfs -text

4
r2sql/cosql/.gitignore vendored Normal file

@@ -0,0 +1,4 @@
*pyc
data/*
logs_spider_editsql/*
processed_data_spider_removefrom/*

4
r2sql/cosql/cosql_evaluate.sh Executable file

@@ -0,0 +1,4 @@
#! /bin/bash
# 3. get evaluation result
python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file results/save_14_predictions.json --remove_from


@@ -0,0 +1,4 @@
#! /bin/bash
# 3. get evaluation result
python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file results/save_14_predictions.json,results/save_9_predictions.json,results/save_6_predictions.json,results/save_3_predictions.json --remove_from

36
r2sql/cosql/cosql_test.sh Executable file

@@ -0,0 +1,36 @@
#! /bin/bash
GLOVE_PATH="/dev/shm/glove_embeddings.pkl" # you need to change this
LOGDIR="log/cosql"
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
--embedding_filename=$GLOVE_PATH \
--data_directory="processed_data_cosql_removefrom" \
--input_key="utterance" \
--state_positional_embeddings=1 \
--discourse_level_lstm=1 \
--use_utterance_attention=1 \
--use_previous_query=1 \
--use_query_attention=1 \
--use_copy_switch=1 \
--use_schema_encoder=1 \
--use_schema_attention=1 \
--use_encoder_attention=1 \
--use_bert=1 \
--bert_type_abb=uS \
--fine_tune_bert=1 \
--use_schema_self_attention=1 \
--use_schema_encoder_2=1 \
--interaction_level=1 \
--reweight_batch=1 \
--freeze=1 \
--logdir=$LOGDIR \
--evaluate=1 \
--evaluate_split="valid" \
--use_predicted_queries=1 \
--save_file="$LOGDIR/save_14"
# 3. get evaluation result
#python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from


@@ -0,0 +1,125 @@
#! /bin/bash
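# Run inference with each of the saved checkpoints (save_3, save_6, save_9, save_14);
# their prediction files can then be passed together to postprocess_eval.py for the self-ensemble.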
GLOVE_PATH="/dev/shm/glove_embeddings.pkl" # you need to change this
LOGDIR="log/cosql"
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
--embedding_filename=$GLOVE_PATH \
--data_directory="processed_data_cosql_removefrom" \
--input_key="utterance" \
--state_positional_embeddings=1 \
--discourse_level_lstm=1 \
--use_utterance_attention=1 \
--use_previous_query=1 \
--use_query_attention=1 \
--use_copy_switch=1 \
--use_schema_encoder=1 \
--use_schema_attention=1 \
--use_encoder_attention=1 \
--use_bert=1 \
--bert_type_abb=uS \
--fine_tune_bert=1 \
--use_schema_self_attention=1 \
--use_schema_encoder_2=1 \
--interaction_level=1 \
--reweight_batch=1 \
--freeze=1 \
--logdir=$LOGDIR \
--evaluate=1 \
--evaluate_split="valid" \
--use_predicted_queries=1 \
--save_file="$LOGDIR/save_3"
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
--embedding_filename=$GLOVE_PATH \
--data_directory="processed_data_cosql_removefrom" \
--input_key="utterance" \
--state_positional_embeddings=1 \
--discourse_level_lstm=1 \
--use_utterance_attention=1 \
--use_previous_query=1 \
--use_query_attention=1 \
--use_copy_switch=1 \
--use_schema_encoder=1 \
--use_schema_attention=1 \
--use_encoder_attention=1 \
--use_bert=1 \
--bert_type_abb=uS \
--fine_tune_bert=1 \
--use_schema_self_attention=1 \
--use_schema_encoder_2=1 \
--interaction_level=1 \
--reweight_batch=1 \
--freeze=1 \
--logdir=$LOGDIR \
--evaluate=1 \
--evaluate_split="valid" \
--use_predicted_queries=1 \
--save_file="$LOGDIR/save_6"
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
--embedding_filename=$GLOVE_PATH \
--data_directory="processed_data_cosql_removefrom" \
--input_key="utterance" \
--state_positional_embeddings=1 \
--discourse_level_lstm=1 \
--use_utterance_attention=1 \
--use_previous_query=1 \
--use_query_attention=1 \
--use_copy_switch=1 \
--use_schema_encoder=1 \
--use_schema_attention=1 \
--use_encoder_attention=1 \
--use_bert=1 \
--bert_type_abb=uS \
--fine_tune_bert=1 \
--use_schema_self_attention=1 \
--use_schema_encoder_2=1 \
--interaction_level=1 \
--reweight_batch=1 \
--freeze=1 \
--logdir=$LOGDIR \
--evaluate=1 \
--evaluate_split="valid" \
--use_predicted_queries=1 \
--save_file="$LOGDIR/save_9"
CUDA_VISIBLE_DEVICES=1 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
--embedding_filename=$GLOVE_PATH \
--data_directory="processed_data_cosql_removefrom" \
--input_key="utterance" \
--state_positional_embeddings=1 \
--discourse_level_lstm=1 \
--use_utterance_attention=1 \
--use_previous_query=1 \
--use_query_attention=1 \
--use_copy_switch=1 \
--use_schema_encoder=1 \
--use_schema_attention=1 \
--use_encoder_attention=1 \
--use_bert=1 \
--bert_type_abb=uS \
--fine_tune_bert=1 \
--use_schema_self_attention=1 \
--use_schema_encoder_2=1 \
--interaction_level=1 \
--reweight_batch=1 \
--freeze=1 \
--logdir=$LOGDIR \
--evaluate=1 \
--evaluate_split="valid" \
--use_predicted_queries=1 \
--save_file="$LOGDIR/save_14"
# 3. get evaluation result
#python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from

40
r2sql/cosql/cosql_train.sh Executable file

@@ -0,0 +1,40 @@
#! /bin/bash
# 1. preprocess dataset by the following. It will produce data/cosql_data_removefrom/
python3 preprocess.py --dataset=cosql --remove_from
# 2. train and evaluate.
# the result (models, logs, prediction outputs) are saved in $LOGDIR
LOGDIR='log/cosql'
GLOVE_PATH="/dev/shm/glove_embeddings.pkl" # you need to change this
CUDA_VISIBLE_DEVICES=0 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
--embedding_filename=$GLOVE_PATH \
--data_directory="processed_data_cosql_removefrom" \
--input_key="utterance" \
--state_positional_embeddings=1 \
--discourse_level_lstm=1 \
--use_utterance_attention=1 \
--use_previous_query=1 \
--use_query_attention=1 \
--use_copy_switch=1 \
--use_schema_encoder=1 \
--use_schema_attention=1 \
--use_encoder_attention=1 \
--use_bert=1 \
--bert_type_abb=uS \
--fine_tune_bert=1 \
--use_schema_self_attention=1 \
--use_schema_encoder_2=1 \
--interaction_level=1 \
--reweight_batch=1 \
--freeze=1 \
--train=1 \
--logdir=$LOGDIR \
--evaluate=1 \
--evaluate_split="valid" \
--use_predicted_queries=1 \
--seed=484


@@ -0,0 +1,262 @@
"""Code for identifying and anonymizing entities in NL and SQL."""
import copy
import json
from . import util
ENTITY_NAME = "ENTITY"
CONSTANT_NAME = "CONSTANT"
TIME_NAME = "TIME"
SEPARATOR = "#"
def timeval(string):
"""Returns the numeric version of a time.
Inputs:
string (str): String representing a time.
Returns:
String representing the absolute time.
"""
if (string.endswith("am") or string.endswith("pm")) and string[:-2].isdigit():
numval = int(string[:-2])
if len(string) == 3 or len(string) == 4:
numval *= 100
if string.endswith("pm"):
numval += 1200
return str(numval)
return ""
def is_time(string):
"""Returns whether a string represents a time.
Inputs:
string (str): String to check.
Returns:
Whether the string represents a time.
"""
if string.endswith("am") or string.endswith("pm"):
if string[:-2].isdigit():
return True
return False
def deanonymize(sequence, ent_dict, key):
"""Deanonymizes a sequence.
Inputs:
sequence (list of str): List of tokens to deanonymize.
ent_dict (dict str->(dict str->str)): Maps from tokens to the entity dictionary.
key (str): The key to use, in this case either natural language or SQL.
Returns:
Deanonymized sequence of tokens.
"""
new_sequence = []
for token in sequence:
if token in ent_dict:
new_sequence.extend(ent_dict[token][key])
else:
new_sequence.append(token)
return new_sequence
class Anonymizer:
"""Anonymization class for keeping track of entities in this domain and
scripts for anonymizing/deanonymizing.
Members:
anonymization_map (list of dict (str->str)): Containing entities from
the anonymization file.
entity_types (list of str): All entities in the anonymization file.
keys (set of str): Possible keys (types of text handled); in this case it should be
one for natural language and another for SQL.
entity_set (set of str): entity_types as a set.
"""
def __init__(self, filename):
self.anonymization_map = []
self.entity_types = []
self.keys = set()
pairs = [json.loads(line) for line in open(filename).readlines()]
for pair in pairs:
for key in pair:
if key != "type":
self.keys.add(key)
self.anonymization_map.append(pair)
if pair["type"] not in self.entity_types:
self.entity_types.append(pair["type"])
self.entity_types.append(ENTITY_NAME)
self.entity_types.append(CONSTANT_NAME)
self.entity_types.append(TIME_NAME)
self.entity_set = set(self.entity_types)
def get_entity_type_from_token(self, token):
"""Gets the type of an entity given an anonymized token.
Inputs:
token (str): The entity token.
Returns:
str, representing the type of the entity.
"""
# these are in the pattern NAME:#, so just strip the thing after the
# colon
colon_loc = token.index(SEPARATOR)
entity_type = token[:colon_loc]
assert entity_type in self.entity_set
return entity_type
def is_anon_tok(self, token):
"""Returns whether a token is an anonymized token or not.
Input:
token (str): The token to check.
Returns:
bool, whether the token is an anonymized token.
"""
return token.split(SEPARATOR)[0] in self.entity_set
def get_anon_id(self, token):
"""Gets the entity index (unique ID) for a token.
Input:
token (str): The token to get the index from.
Returns:
int, the token ID if it is an anonymized token; otherwise -1.
"""
if self.is_anon_tok(token):
return self.entity_types.index(token.split(SEPARATOR)[0])
else:
return -1
def anonymize(self,
sequence,
tok_to_entity_dict,
key,
add_new_anon_toks=False):
"""Anonymizes a sequence.
Inputs:
sequence (list of str): Sequence to anonymize.
tok_to_entity_dict (dict): Existing dictionary mapping from anonymized
tokens to entities.
key (str): Which kind of text this is (natural language or SQL)
add_new_anon_toks (bool): Whether to add new entities to tok_to_entity_dict.
Returns:
list of str, the anonymized sequence.
"""
# Sort the token-tok-entity dict by the length of the modality.
sorted_dict = sorted(tok_to_entity_dict.items(),
key=lambda k: len(k[1][key]))[::-1]
anonymized_sequence = copy.deepcopy(sequence)
if add_new_anon_toks:
type_counts = {}
for entity_type in self.entity_types:
type_counts[entity_type] = 0
for token in tok_to_entity_dict:
entity_type = self.get_entity_type_from_token(token)
type_counts[entity_type] += 1
# First find occurrences of things in the anonymization dictionary.
for token, modalities in sorted_dict:
our_modality = modalities[key]
# Check if this key's version of the anonymized thing is in our
# sequence.
while util.subsequence(our_modality, anonymized_sequence):
found = False
for startidx in range(
len(anonymized_sequence) - len(our_modality) + 1):
if anonymized_sequence[startidx:startidx +
len(our_modality)] == our_modality:
anonymized_sequence = anonymized_sequence[:startidx] + [
token] + anonymized_sequence[startidx + len(our_modality):]
found = True
break
assert found, "Thought " \
+ str(our_modality) + " was in [" \
+ str(anonymized_sequence) + "] but could not find it"
# Now add new keys if they are present.
if add_new_anon_toks:
# For every span in the sequence, check whether it is in the anon map
# for this modality
sorted_anon_map = sorted(self.anonymization_map,
key=lambda k: len(k[key]))[::-1]
for pair in sorted_anon_map:
our_modality = pair[key]
token_type = pair["type"]
new_token = token_type + SEPARATOR + \
str(type_counts[token_type])
while util.subsequence(our_modality, anonymized_sequence):
found = False
for startidx in range(
len(anonymized_sequence) - len(our_modality) + 1):
if anonymized_sequence[startidx:startidx + \
len(our_modality)] == our_modality:
if new_token not in tok_to_entity_dict:
type_counts[token_type] += 1
tok_to_entity_dict[new_token] = pair
anonymized_sequence = anonymized_sequence[:startidx] + [
new_token] + anonymized_sequence[startidx + len(our_modality):]
found = True
break
assert found, "Thought " \
+ str(our_modality) + " was in [" \
+ str(anonymized_sequence) + "] but could not find it"
# Also replace integers with constants
for index, token in enumerate(anonymized_sequence):
if token.isdigit() or is_time(token):
if token.isdigit():
entity_type = CONSTANT_NAME
value = new_token
if is_time(token):
entity_type = TIME_NAME
value = timeval(token)
# First try to find the constant in the entity dictionary already,
# and get the name if it's found.
new_token = ""
new_dict = {}
found = False
for entity, value in tok_to_entity_dict.items():
if value[key][0] == token:
new_token = entity
new_dict = value
found = True
break
if not found:
new_token = entity_type + SEPARATOR + \
str(type_counts[entity_type])
new_dict = {}
for tempkey in self.keys:
new_dict[tempkey] = [token]
tok_to_entity_dict[new_token] = new_dict
type_counts[entity_type] += 1
anonymized_sequence[index] = new_token
return anonymized_sequence


@@ -0,0 +1,350 @@
# TODO: review this entire file and make it much simpler.
import copy
from . import snippets as snip
from . import sql_util
from . import vocabulary as vocab
class UtteranceItem():
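"""Read-only view of one utterance inside a gold Interaction; exposes its input sequence, history, gold query, and available snippets."""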
def __init__(self, interaction, index):
self.interaction = interaction
self.utterance_index = index
def __str__(self):
return str(self.interaction.utterances[self.utterance_index])
def histories(self, maximum):
if maximum > 0:
history_seqs = []
for utterance in self.interaction.utterances[:self.utterance_index]:
history_seqs.append(utterance.input_seq_to_use)
if len(history_seqs) > maximum:
history_seqs = history_seqs[-maximum:]
return history_seqs
return []
def input_sequence(self):
return self.interaction.utterances[self.utterance_index].input_seq_to_use
def previous_query(self):
if self.utterance_index == 0:
return []
return self.interaction.utterances[self.utterance_index -
1].anonymized_gold_query
def anonymized_gold_query(self):
return self.interaction.utterances[self.utterance_index].anonymized_gold_query
def snippets(self):
return self.interaction.utterances[self.utterance_index].available_snippets
def original_gold_query(self):
return self.interaction.utterances[self.utterance_index].original_gold_query
def contained_entities(self):
return self.interaction.utterances[self.utterance_index].contained_entities
def original_gold_queries(self):
return [
q[0] for q in self.interaction.utterances[self.utterance_index].all_gold_queries]
def gold_tables(self):
return [
q[1] for q in self.interaction.utterances[self.utterance_index].all_gold_queries]
def gold_query(self):
return self.interaction.utterances[self.utterance_index].gold_query_to_use + [
vocab.EOS_TOK]
def gold_edit_sequence(self):
return self.interaction.utterances[self.utterance_index].gold_edit_sequence
def gold_table(self):
return self.interaction.utterances[self.utterance_index].gold_sql_results
def all_snippets(self):
return self.interaction.snippets
def within_limits(self,
max_input_length=float('inf'),
max_output_length=float('inf')):
return self.interaction.utterances[self.utterance_index].length_valid(
max_input_length, max_output_length)
def expand_snippets(self, sequence):
# Remove the EOS
if sequence[-1] == vocab.EOS_TOK:
sequence = sequence[:-1]
# First remove the snippets
no_snippets_sequence = self.interaction.expand_snippets(sequence)
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
return no_snippets_sequence
def flatten_sequence(self, sequence):
# Remove the EOS
if sequence[-1] == vocab.EOS_TOK:
sequence = sequence[:-1]
# First remove the snippets
no_snippets_sequence = self.interaction.expand_snippets(sequence)
# Deanonymize
deanon_sequence = self.interaction.deanonymize(
no_snippets_sequence, "sql")
return deanon_sequence
class UtteranceBatch():
def __init__(self, items):
self.items = items
def __len__(self):
return len(self.items)
def start(self):
self.index = 0
def next(self):
item = self.items[self.index]
self.index += 1
return item
def done(self):
return self.index >= len(self.items)
class PredUtteranceItem():
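"""Utterance wrapper used at prediction time; unlike UtteranceItem, it carries the previously predicted query rather than the gold one."""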
def __init__(self,
input_sequence,
interaction_item,
previous_query,
index,
available_snippets):
self.input_seq_to_use = input_sequence
self.interaction_item = interaction_item
self.index = index
self.available_snippets = available_snippets
self.prev_pred_query = previous_query
def input_sequence(self):
return self.input_seq_to_use
def histories(self, maximum):
if maximum == 0:
return []
histories = []
for utterance in self.interaction_item.processed_utterances[:self.index]:
histories.append(utterance.input_sequence())
if len(histories) > maximum:
histories = histories[-maximum:]
return histories
def snippets(self):
return self.available_snippets
def previous_query(self):
return self.prev_pred_query
def flatten_sequence(self, sequence):
return self.interaction_item.flatten_sequence(sequence)
def remove_snippets(self, sequence):
return sql_util.fix_parentheses(
self.interaction_item.expand_snippets(sequence))
def set_predicted_query(self, query):
self.anonymized_pred_query = query
# Mocks an Interaction item, but allows for the parameters to be updated during
# the process
class InteractionItem():
def __init__(self,
interaction,
max_input_length=float('inf'),
max_output_length=float('inf'),
nl_to_sql_dict={},
maximum_length=float('inf')):
if maximum_length != float('inf'):
self.interaction = copy.deepcopy(interaction)
self.interaction.utterances = self.interaction.utterances[:maximum_length]
else:
self.interaction = interaction
self.processed_utterances = []
self.snippet_bank = []
self.identifier = self.interaction.identifier
self.max_input_length = max_input_length
self.max_output_length = max_output_length
self.nl_to_sql_dict = nl_to_sql_dict
self.index = 0
def __len__(self):
return len(self.interaction)
def __str__(self):
s = "Utterances, gold queries, and predictions:\n"
for i, utterance in enumerate(self.interaction.utterances):
s += " ".join(utterance.input_seq_to_use) + "\n"
pred_utterance = self.processed_utterances[i]
s += " ".join(pred_utterance.gold_query()) + "\n"
s += " ".join(pred_utterance.anonymized_query()) + "\n"
s += "\n"
s += "Snippets:\n"
for snippet in self.snippet_bank:
s += str(snippet) + "\n"
return s
def start_interaction(self):
assert len(self.snippet_bank) == 0
assert len(self.processed_utterances) == 0
assert self.index == 0
def next_utterance(self):
utterance = self.interaction.utterances[self.index]
self.index += 1
available_snippets = self.available_snippets(snippet_keep_age=1)
return PredUtteranceItem(utterance.input_seq_to_use,
self,
self.processed_utterances[-1].anonymized_pred_query if len(self.processed_utterances) > 0 else [],
self.index - 1,
available_snippets)
def done(self):
return len(self.processed_utterances) == len(self.interaction)
def finish(self):
self.snippet_bank = []
self.processed_utterances = []
self.index = 0
def utterance_within_limits(self, utterance_item):
return utterance_item.within_limits(self.max_input_length,
self.max_output_length)
def available_snippets(self, snippet_keep_age):
return [
snippet for snippet in self.snippet_bank if snippet.index <= snippet_keep_age]
def gold_utterances(self):
utterances = []
for i, utterance in enumerate(self.interaction.utterances):
utterances.append(UtteranceItem(self.interaction, i))
return utterances
def get_schema(self):
return self.interaction.schema
def add_utterance(
self,
utterance,
predicted_sequence,
snippets=None,
previous_snippets=[],
simple=False):
if not snippets:
self.add_snippets(
predicted_sequence,
previous_snippets=previous_snippets, simple=simple)
else:
for snippet in snippets:
snippet.assign_id(len(self.snippet_bank))
self.snippet_bank.append(snippet)
for snippet in self.snippet_bank:
snippet.increase_age()
self.processed_utterances.append(utterance)
def add_snippets(self, sequence, previous_snippets=[], simple=False):
if sequence:
if simple:
snippets = sql_util.get_subtrees_simple(
sequence, oldsnippets=previous_snippets)
else:
snippets = sql_util.get_subtrees(
sequence, oldsnippets=previous_snippets)
for snippet in snippets:
snippet.assign_id(len(self.snippet_bank))
self.snippet_bank.append(snippet)
for snippet in self.snippet_bank:
snippet.increase_age()
def expand_snippets(self, sequence):
return sql_util.fix_parentheses(
snip.expand_snippets(
sequence, self.snippet_bank))
def remove_snippets(self, sequence):
if sequence[-1] == vocab.EOS_TOK:
sequence = sequence[:-1]
no_snippets_sequence = self.expand_snippets(sequence)
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
return no_snippets_sequence
def flatten_sequence(self, sequence, gold_snippets=False):
if sequence[-1] == vocab.EOS_TOK:
sequence = sequence[:-1]
if gold_snippets:
no_snippets_sequence = self.interaction.expand_snippets(sequence)
else:
no_snippets_sequence = self.expand_snippets(sequence)
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
deanon_sequence = self.interaction.deanonymize(
no_snippets_sequence, "sql")
return deanon_sequence
def gold_query(self, index):
return self.interaction.utterances[index].gold_query_to_use + [
vocab.EOS_TOK]
def original_gold_query(self, index):
return self.interaction.utterances[index].original_gold_query
def gold_table(self, index):
return self.interaction.utterances[index].gold_sql_results
class InteractionBatch():
def __init__(self, items):
self.items = items
def __len__(self):
return len(self.items)
def start(self):
self.timestep = 0
self.current_interactions = []
def get_next_utterance_batch(self, snippet_keep_age, use_gold=False):
items = []
self.current_interactions = []
for interaction in self.items:
if self.timestep < len(interaction):
utterance_item = interaction.original_utterances(
snippet_keep_age, use_gold)[self.timestep]
self.current_interactions.append(interaction)
items.append(utterance_item)
self.timestep += 1
return UtteranceBatch(items)
def done(self):
finished = True
for interaction in self.items:
if self.timestep < len(interaction):
finished = False
return finished
return finished


@@ -0,0 +1,376 @@
""" Utility functions for loading and processing ATIS data."""
import os
import random
import json
from . import anonymization as anon
from . import atis_batch
from . import dataset_split as ds
from .interaction import load_function
from .entities import NLtoSQLDict
from .atis_vocab import ATISVocabulary
ENTITIES_FILENAME = 'data/entities.txt'
ANONYMIZATION_FILENAME = 'data/anonymization.txt'
class ATISDataset():
""" Contains the ATIS data. """
def __init__(self, params):
self.anonymizer = None
if params.anonymize:
self.anonymizer = anon.Anonymizer(ANONYMIZATION_FILENAME)
if not os.path.exists(params.data_directory):
os.mkdir(params.data_directory)
self.entities_dictionary = NLtoSQLDict(ENTITIES_FILENAME)
database_schema = None
if params.database_schema_filename:
if 'removefrom' not in params.data_directory:
database_schema, column_names_surface_form, column_names_embedder_input = self.read_database_schema_simple(params.database_schema_filename)
else:
database_schema, column_names_surface_form, column_names_embedder_input = self.read_database_schema(params.database_schema_filename)
int_load_function = load_function(params,
self.entities_dictionary,
self.anonymizer,
database_schema=database_schema)
def collapse_list(the_list):
""" Collapses a list of list into a single list."""
return [s for i in the_list for s in i]
if 'atis' not in params.data_directory:
self.train_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_train_filename),
params.raw_train_filename,
int_load_function)
self.valid_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_validation_filename),
params.raw_validation_filename,
int_load_function)
train_input_seqs = collapse_list(self.train_data.get_ex_properties(lambda i: i.input_seqs()))
valid_input_seqs = collapse_list(self.valid_data.get_ex_properties(lambda i: i.input_seqs()))
all_input_seqs = train_input_seqs + valid_input_seqs
self.input_vocabulary = ATISVocabulary(
all_input_seqs,
os.path.join(params.data_directory, params.input_vocabulary_filename),
params,
is_input='input',
anonymizer=self.anonymizer if params.anonymization_scoring else None)
self.output_vocabulary_schema = ATISVocabulary(
column_names_embedder_input,
os.path.join(params.data_directory, 'schema_'+params.output_vocabulary_filename),
params,
is_input='schema',
anonymizer=self.anonymizer if params.anonymization_scoring else None)
train_output_seqs = collapse_list(self.train_data.get_ex_properties(lambda i: i.output_seqs()))
valid_output_seqs = collapse_list(self.valid_data.get_ex_properties(lambda i: i.output_seqs()))
all_output_seqs = train_output_seqs + valid_output_seqs
sql_keywords = ['.', 't1', 't2', '=', 'select', 'as', 'join', 'on', ')', '(', 'where', 't3', 'by', ',', 'group', 'distinct', 't4', 'and', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 't5', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!', 'union', 'between', 't6', '-', 't7', '+', '/']
sql_keywords += ['count', 'from', 'value', 'order']
sql_keywords += ['group_by', 'order_by', 'limit_value', '!=']
# skip column_names_surface_form but keep sql_keywords
skip_tokens = list(set(column_names_surface_form) - set(sql_keywords))
if params.data_directory == 'processed_data_sparc_removefrom_test':
all_output_seqs = []
out_vocab_ordered = ['select', 'value', ')', '(', 'where', '=', ',', 'count', 'group_by', 'order_by', 'limit_value', 'desc', '>', 'distinct', 'avg', 'and', 'having', '<', 'in', 'max', 'sum', 'asc', 'like', 'not', 'or', 'min', 'intersect', 'except', '!=', 'union', 'between', '-', '+']
for i in range(len(out_vocab_ordered)):
all_output_seqs.append(out_vocab_ordered[:i+1])
self.output_vocabulary = ATISVocabulary(
all_output_seqs,
os.path.join(params.data_directory, params.output_vocabulary_filename),
params,
is_input='output',
anonymizer=self.anonymizer if params.anonymization_scoring else None,
skip=skip_tokens)
else:
self.train_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_train_filename),
params.raw_train_filename,
int_load_function)
if params.train:
self.valid_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_validation_filename),
params.raw_validation_filename,
int_load_function)
if params.evaluate or params.attention:
self.dev_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_dev_filename),
params.raw_dev_filename,
int_load_function)
if params.enable_testing:
self.test_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_test_filename),
params.raw_test_filename,
int_load_function)
train_input_seqs = []
train_input_seqs = collapse_list(
self.train_data.get_ex_properties(
lambda i: i.input_seqs()))
self.input_vocabulary = ATISVocabulary(
train_input_seqs,
os.path.join(params.data_directory, params.input_vocabulary_filename),
params,
is_input='input',
min_occur=2,
anonymizer=self.anonymizer if params.anonymization_scoring else None)
train_output_seqs = collapse_list(
self.train_data.get_ex_properties(
lambda i: i.output_seqs()))
self.output_vocabulary = ATISVocabulary(
train_output_seqs,
os.path.join(params.data_directory, params.output_vocabulary_filename),
params,
is_input='output',
anonymizer=self.anonymizer if params.anonymization_scoring else None)
self.output_vocabulary_schema = None
def read_database_schema_simple(self, database_schema_filename):
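"""Reads tables.json and returns (schema dict, column/table surface forms, embedder inputs) using bare column and table names; used when the data directory is not a 'removefrom' variant."""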
with open(database_schema_filename, "r") as f:
database_schema = json.load(f)
database_schema_dict = {}
column_names_surface_form = []
column_names_embedder_input = []
for table_schema in database_schema:
db_id = table_schema['db_id']
database_schema_dict[db_id] = table_schema
column_names = table_schema['column_names']
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
for i, (table_id, column_name) in enumerate(column_names_original):
column_name_surface_form = column_name
column_names_surface_form.append(column_name_surface_form.lower())
for table_name in table_names_original:
column_names_surface_form.append(table_name.lower())
for i, (table_id, column_name) in enumerate(column_names):
column_name_embedder_input = column_name
column_names_embedder_input.append(column_name_embedder_input.split())
for table_name in table_names:
column_names_embedder_input.append(table_name.split())
database_schema = database_schema_dict
return database_schema, column_names_surface_form, column_names_embedder_input
def read_database_schema(self, database_schema_filename):
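"""Like the simple variant, but surface forms are table.column (plus table.*) and embedder inputs are 'table . column' token lists."""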
with open(database_schema_filename, "r") as f:
database_schema = json.load(f)
database_schema_dict = {}
column_names_surface_form = []
column_names_embedder_input = []
for table_schema in database_schema:
db_id = table_schema['db_id']
database_schema_dict[db_id] = table_schema
column_names = table_schema['column_names']
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
for i, (table_id, column_name) in enumerate(column_names_original):
if table_id >= 0:
table_name = table_names_original[table_id]
column_name_surface_form = '{}.{}'.format(table_name,column_name)
else:
column_name_surface_form = column_name
column_names_surface_form.append(column_name_surface_form.lower())
# also add table_name.*
for table_name in table_names_original:
column_names_surface_form.append('{}.*'.format(table_name.lower()))
for i, (table_id, column_name) in enumerate(column_names):
if table_id >= 0:
table_name = table_names[table_id]
column_name_embedder_input = table_name + ' . ' + column_name
else:
column_name_embedder_input = column_name
column_names_embedder_input.append(column_name_embedder_input.split())
for table_name in table_names:
column_name_embedder_input = table_name + ' . *'
column_names_embedder_input.append(column_name_embedder_input.split())
database_schema = database_schema_dict
return database_schema, column_names_surface_form, column_names_embedder_input
def get_all_utterances(self,
dataset,
max_input_length=float('inf'),
max_output_length=float('inf')):
""" Returns all utterances in a dataset."""
items = []
for interaction in dataset.examples:
for i, utterance in enumerate(interaction.utterances):
if utterance.length_valid(max_input_length, max_output_length):
items.append(atis_batch.UtteranceItem(interaction, i))
return items
def get_all_interactions(self,
dataset,
max_interaction_length=float('inf'),
max_input_length=float('inf'),
max_output_length=float('inf'),
sorted_by_length=False):
"""Gets all interactions in a dataset that fit the criteria.
Inputs:
dataset (ATISDatasetSplit): The dataset to use.
max_interaction_length (int): Maximum interaction length to keep.
max_input_length (int): Maximum input sequence length to keep.
max_output_length (int): Maximum output sequence length to keep.
sorted_by_length (bool): Whether to sort the examples by interaction length.
"""
ints = [
atis_batch.InteractionItem(
interaction,
max_input_length,
max_output_length,
self.entities_dictionary,
max_interaction_length) for interaction in dataset.examples]
if sorted_by_length:
return sorted(ints, key=lambda x: len(x))[::-1]
else:
return ints
# This defines a standard way of training: each example is an utterance, and
# the batch can contain unrelated utterances.
def get_utterance_batches(self,
batch_size,
max_input_length=float('inf'),
max_output_length=float('inf'),
randomize=True):
"""Gets batches of utterances in the data.
Inputs:
batch_size (int): Batch size to use.
max_input_length (int): Maximum length of input to keep.
max_output_length (int): Maximum length of output to use.
randomize (bool): Whether to randomize the ordering.
"""
# First, get all interactions and the positions of the utterances that are
# possible in them.
items = self.get_all_utterances(self.train_data,
max_input_length,
max_output_length)
if randomize:
random.shuffle(items)
batches = []
current_batch_items = []
for item in items:
if len(current_batch_items) >= batch_size:
batches.append(atis_batch.UtteranceBatch(current_batch_items))
current_batch_items = []
current_batch_items.append(item)
batches.append(atis_batch.UtteranceBatch(current_batch_items))
assert sum([len(batch) for batch in batches]) == len(items)
return batches
def get_interaction_batches(self,
batch_size,
max_interaction_length=float('inf'),
max_input_length=float('inf'),
max_output_length=float('inf'),
randomize=True):
"""Gets batches of interactions in the data.
Inputs:
batch_size (int): Batch size to use.
max_interaction_length (int): Maximum length of interaction to keep
max_input_length (int): Maximum length of input to keep.
max_output_length (int): Maximum length of output to keep.
randomize (bool): Whether to randomize the ordering.
"""
items = self.get_all_interactions(self.train_data,
max_interaction_length,
max_input_length,
max_output_length,
sorted_by_length=not randomize)
if randomize:
random.shuffle(items)
batches = []
current_batch_items = []
for item in items:
if len(current_batch_items) >= batch_size:
batches.append(
atis_batch.InteractionBatch(current_batch_items))
current_batch_items = []
current_batch_items.append(item)
batches.append(atis_batch.InteractionBatch(current_batch_items))
assert sum([len(batch) for batch in batches]) == len(items)
return batches
def get_random_utterances(self,
num_samples,
max_input_length=float('inf'),
max_output_length=float('inf')):
"""Gets a random selection of utterances in the data.
Inputs:
num_samples (bool): Number of random utterances to get.
max_input_length (int): Limit of input length.
max_output_length (int): Limit on output length.
"""
items = self.get_all_utterances(self.train_data,
max_input_length,
max_output_length)
random.shuffle(items)
return items[:num_samples]
def get_random_interactions(self,
num_samples,
max_interaction_length=float('inf'),
max_input_length=float('inf'),
max_output_length=float('inf')):
"""Gets a random selection of interactions in the data.
Inputs:
num_samples (bool): Number of random interactions to get.
max_input_length (int): Limit of input length.
max_output_length (int): Limit on output length.
"""
items = self.get_all_interactions(self.train_data,
max_interaction_length,
max_input_length,
max_output_length)
random.shuffle(items)
return items[:num_samples]
def num_utterances(dataset):
"""Returns the total number of utterances in the dataset."""
return sum([len(interaction) for interaction in dataset.examples])


@@ -0,0 +1,74 @@
"""Gets and stores vocabulary for the ATIS data."""
from . import snippets
from .vocabulary import Vocabulary, UNK_TOK, DEL_TOK, EOS_TOK
INPUT_FN_TYPES = [UNK_TOK, DEL_TOK, EOS_TOK]
OUTPUT_FN_TYPES = [UNK_TOK, EOS_TOK]
MIN_INPUT_OCCUR = 1
MIN_OUTPUT_OCCUR = 1
class ATISVocabulary():
""" Stores the vocabulary for the ATIS data.
Attributes:
raw_vocab (Vocabulary): Vocabulary object.
tokens (set of str): Set of all of the strings in the vocabulary.
inorder_tokens (list of str): List of all tokens, with a strict and
unchanging order.
"""
def __init__(self,
token_sequences,
filename,
params,
is_input='input',
min_occur=1,
anonymizer=None,
skip=None):
if is_input=='input':
functional_types = INPUT_FN_TYPES
elif is_input=='output':
functional_types = OUTPUT_FN_TYPES
elif is_input=='schema':
functional_types = [UNK_TOK]
else:
functional_types = []
self.raw_vocab = Vocabulary(
token_sequences,
filename,
functional_types=functional_types,
min_occur=min_occur,
ignore_fn=lambda x: snippets.is_snippet(x) or (
anonymizer and anonymizer.is_anon_tok(x)) or (skip and x in skip) )
self.tokens = set(self.raw_vocab.token_to_id.keys())
self.inorder_tokens = self.raw_vocab.id_to_token
assert len(self.inorder_tokens) == len(self.raw_vocab)
def __len__(self):
return len(self.raw_vocab)
def token_to_id(self, token):
""" Maps from a token to a unique ID.
Inputs:
token (str): The token to look up.
Returns:
int, uniquely identifying the token.
"""
return self.raw_vocab.token_to_id[token]
def id_to_token(self, identifier):
""" Maps from a unique integer to an identifier.
Inputs:
identifier (int): The unique ID.
Returns:
string, representing the token.
"""
return self.raw_vocab.id_to_token[identifier]


@@ -0,0 +1,56 @@
""" Utility functions for loading and processing ATIS data.
"""
import os
import pickle
class DatasetSplit:
"""Stores a split of the ATIS dataset.
Attributes:
examples (list of Interaction): Stores the examples in the split.
"""
def __init__(self, processed_filename, raw_filename, load_function):
if os.path.exists(processed_filename):
print("Loading preprocessed data from " + processed_filename)
with open(processed_filename, 'rb') as infile:
self.examples = pickle.load(infile)
else:
print(
"Loading raw data from " +
raw_filename +
" and writing to " +
processed_filename)
infile = open(raw_filename, 'rb')
examples_from_file = pickle.load(infile)
assert isinstance(examples_from_file, list), raw_filename + \
" does not contain a list of examples"
infile.close()
self.examples = []
for example in examples_from_file:
obj, keep = load_function(example)
if keep:
self.examples.append(obj)
print("Loaded " + str(len(self.examples)) + " examples")
outfile = open(processed_filename, 'wb')
pickle.dump(self.examples, outfile)
outfile.close()
def get_ex_properties(self, function):
""" Applies some function to the examples in the dataset.
Inputs:
function: (lambda Interaction -> T): Function to apply to all
examples.
Returns
list of the return value of the function
"""
elems = []
for example in self.examples:
elems.append(function(example))
return elems


@@ -0,0 +1,63 @@
""" Classes for keeping track of the entities in a natural language string. """
import json
class NLtoSQLDict:
"""
Entity dict file should contain, on each line, a JSON dictionary with
"input" and "output" keys specifying the string for the input and output
pairs. The idea is that the existence of the key in an input sequence
likely corresponds to the existence of the value in the output sequence.
The entity_dict should map keys (input strings) to a list of values (output
strings) where this property holds. This allows keys to map to multiple
output strings (e.g. for times).
"""
def __init__(self, entity_dict_filename):
self.entity_dict = {}
pairs = [json.loads(line)
for line in open(entity_dict_filename).readlines()]
for pair in pairs:
input_seq = pair["input"]
output_seq = pair["output"]
if input_seq not in self.entity_dict:
self.entity_dict[input_seq] = []
self.entity_dict[input_seq].append(output_seq)
def get_sql_entities(self, tokenized_nl_string):
"""
Gets the output-side entities which correspond to the input entities in
the input sequence.
Inputs:
tokenized_input_string: list of tokens in the input string.
Outputs:
set of output strings.
"""
assert len(tokenized_nl_string) > 0
flat_input_string = " ".join(tokenized_nl_string)
entities = []
# See if any input strings are in our input sequence, and add the
# corresponding output strings if so.
for entry, values in self.entity_dict.items():
in_middle = " " + entry + " " in flat_input_string
leftspace = " " + entry
at_end = leftspace in flat_input_string and flat_input_string.endswith(
leftspace)
rightspace = entry + " "
at_beginning = rightspace in flat_input_string and flat_input_string.startswith(
rightspace)
if in_middle or at_end or at_beginning:
for out_string in values:
entities.append(out_string)
# Also add any integers in the input string (these aren't in the entity)
# dict.
for token in tokenized_nl_string:
if token.isnumeric():
entities.append(token)
return entities


@@ -0,0 +1,312 @@
""" Contains the class for an interaction in ATIS. """
from . import anonymization as anon
from . import sql_util
from .snippets import expand_snippets
from .utterance import Utterance, OUTPUT_KEY, ANON_INPUT_KEY
import torch
class Schema:
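"""Column/table vocabulary of one database schema, kept in two parallel forms: surface forms used in SQL and natural-language forms fed to the embedder. With simple=True bare column names are kept; otherwise table.column forms are built."""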
def __init__(self, table_schema, simple=False):
if simple:
self.helper1(table_schema)
else:
self.helper2(table_schema)
def helper1(self, table_schema):
self.table_schema = table_schema
column_names = table_schema['column_names']
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original)
column_keep_index = []
self.column_names_surface_form = []
self.column_names_surface_form_to_id = {}
for i, (table_id, column_name) in enumerate(column_names_original):
column_name_surface_form = column_name
column_name_surface_form = column_name_surface_form.lower()
if column_name_surface_form not in self.column_names_surface_form_to_id:
self.column_names_surface_form.append(column_name_surface_form)
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
column_keep_index.append(i)
column_keep_index_2 = []
for i, table_name in enumerate(table_names_original):
column_name_surface_form = table_name.lower()
if column_name_surface_form not in self.column_names_surface_form_to_id:
self.column_names_surface_form.append(column_name_surface_form)
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
column_keep_index_2.append(i)
self.column_names_embedder_input = []
self.column_names_embedder_input_to_id = {}
for i, (table_id, column_name) in enumerate(column_names):
column_name_embedder_input = column_name
if i in column_keep_index:
self.column_names_embedder_input.append(column_name_embedder_input)
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
for i, table_name in enumerate(table_names):
column_name_embedder_input = table_name
if i in column_keep_index_2:
self.column_names_embedder_input.append(column_name_embedder_input)
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
max_id_1 = max(v for k,v in self.column_names_surface_form_to_id.items())
max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items())
assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1
self.num_col = len(self.column_names_surface_form)
def helper2(self, table_schema):
self.table_schema = table_schema
column_names = table_schema['column_names']
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original)
column_keep_index = []
self.column_names_surface_form = []
self.column_names_surface_form_to_id = {}
for i, (table_id, column_name) in enumerate(column_names_original):
if table_id >= 0:
table_name = table_names_original[table_id]
column_name_surface_form = '{}.{}'.format(table_name,column_name)
else:
column_name_surface_form = column_name
column_name_surface_form = column_name_surface_form.lower()
if column_name_surface_form not in self.column_names_surface_form_to_id:
self.column_names_surface_form.append(column_name_surface_form)
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
column_keep_index.append(i)
start_i = len(self.column_names_surface_form_to_id)
for i, table_name in enumerate(table_names_original):
column_name_surface_form = '{}.*'.format(table_name.lower())
self.column_names_surface_form.append(column_name_surface_form)
self.column_names_surface_form_to_id[column_name_surface_form] = i + start_i
self.column_names_embedder_input = []
self.column_names_embedder_input_to_id = {}
for i, (table_id, column_name) in enumerate(column_names):
if table_id >= 0:
table_name = table_names[table_id]
column_name_embedder_input = table_name + ' . ' + column_name
else:
column_name_embedder_input = column_name
if i in column_keep_index:
self.column_names_embedder_input.append(column_name_embedder_input)
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
start_i = len(self.column_names_embedder_input_to_id)
for i, table_name in enumerate(table_names):
column_name_embedder_input = table_name + ' . *'
self.column_names_embedder_input.append(column_name_embedder_input)
self.column_names_embedder_input_to_id[column_name_embedder_input] = i + start_i
assert len(self.column_names_surface_form) == len(self.column_names_surface_form_to_id) == len(self.column_names_embedder_input) == len(self.column_names_embedder_input_to_id)
max_id_1 = max(v for k,v in self.column_names_surface_form_to_id.items())
max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items())
assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1
self.num_col = len(self.column_names_surface_form)
def __len__(self):
return self.num_col
def in_vocabulary(self, column_name, surface_form=False):
if surface_form:
return column_name in self.column_names_surface_form_to_id
else:
return column_name in self.column_names_embedder_input_to_id
def column_name_embedder_bow(self, column_name, surface_form=False, column_name_token_embedder=None):
assert self.in_vocabulary(column_name, surface_form)
if surface_form:
column_name_id = self.column_names_surface_form_to_id[column_name]
column_name_embedder_input = self.column_names_embedder_input[column_name_id]
else:
column_name_embedder_input = column_name
column_name_embeddings = [column_name_token_embedder(token) for token in column_name_embedder_input.split()]
column_name_embeddings = torch.stack(column_name_embeddings, dim=0)
return torch.mean(column_name_embeddings, dim=0)
def set_column_name_embeddings(self, column_name_embeddings):
self.column_name_embeddings = column_name_embeddings
assert len(self.column_name_embeddings) == self.num_col
def column_name_embedder(self, column_name, surface_form=False):
assert self.in_vocabulary(column_name, surface_form)
if surface_form:
column_name_id = self.column_names_surface_form_to_id[column_name]
else:
column_name_id = self.column_names_embedder_input_to_id[column_name]
return self.column_name_embeddings[column_name_id]
class Interaction:
""" ATIS interaction class.
Attributes:
utterances (list of Utterance): The utterances in the interaction.
snippets (list of Snippet): The snippets that appear through the interaction.
anon_tok_to_ent:
identifier (str): Unique identifier for the interaction in the dataset.
"""
def __init__(self,
utterances,
schema,
snippets,
anon_tok_to_ent,
identifier,
params):
self.utterances = utterances
self.schema = schema
self.snippets = snippets
self.anon_tok_to_ent = anon_tok_to_ent
self.identifier = identifier
# Ensure that each utterance's input and output sequences, when remapped
# without anonymization or snippets, are the same as the original
# version.
for i, utterance in enumerate(self.utterances):
deanon_input = self.deanonymize(utterance.input_seq_to_use,
ANON_INPUT_KEY)
assert deanon_input == utterance.original_input_seq, "Anonymized sequence [" \
+ " ".join(utterance.input_seq_to_use) + "] is not the same as [" \
+ " ".join(utterance.original_input_seq) + "] when deanonymized (is [" \
+ " ".join(deanon_input) + "] instead)"
desnippet_gold = self.expand_snippets(utterance.gold_query_to_use)
deanon_gold = self.deanonymize(desnippet_gold, OUTPUT_KEY)
assert deanon_gold == utterance.original_gold_query, \
"Anonymized and/or snippet'd query " \
+ " ".join(utterance.gold_query_to_use) + " is not the same as " \
+ " ".join(utterance.original_gold_query)
def __str__(self):
string = "Utterances:\n"
for utterance in self.utterances:
string += str(utterance) + "\n"
string += "Anonymization dictionary:\n"
for ent_tok, deanon in self.anon_tok_to_ent.items():
string += ent_tok + "\t" + str(deanon) + "\n"
return string
def __len__(self):
return len(self.utterances)
def deanonymize(self, sequence, key):
""" Deanonymizes a predicted query or an input utterance.
Inputs:
sequence (list of str): The sequence to deanonymize.
key (str): The key in the anonymization table, e.g. NL or SQL.
"""
return anon.deanonymize(sequence, self.anon_tok_to_ent, key)
def expand_snippets(self, sequence):
""" Expands snippets for a sequence.
Inputs:
sequence (list of str): A SQL query.
"""
return expand_snippets(sequence, self.snippets)
def input_seqs(self):
in_seqs = []
for utterance in self.utterances:
in_seqs.append(utterance.input_seq_to_use)
return in_seqs
def output_seqs(self):
out_seqs = []
for utterance in self.utterances:
out_seqs.append(utterance.gold_query_to_use)
return out_seqs
def load_function(parameters,
nl_to_sql_dict,
anonymizer,
database_schema=None):
def fn(interaction_example):
keep = False
raw_utterances = interaction_example["interaction"]
if "database_id" in interaction_example:
database_id = interaction_example["database_id"]
interaction_id = interaction_example["interaction_id"]
identifier = database_id + '/' + str(interaction_id)
else:
identifier = interaction_example["id"]
schema = None
if database_schema:
if 'removefrom' not in parameters.data_directory:
schema = Schema(database_schema[database_id], simple=True)
else:
schema = Schema(database_schema[database_id])
snippet_bank = []
utterance_examples = []
anon_tok_to_ent = {}
for utterance in raw_utterances:
available_snippets = [
snippet for snippet in snippet_bank if snippet.index <= 1]
proc_utterance = Utterance(utterance,
available_snippets,
nl_to_sql_dict,
parameters,
anon_tok_to_ent,
anonymizer)
keep_utterance = proc_utterance.keep
if schema:
assert keep_utterance
if keep_utterance:
keep = True
utterance_examples.append(proc_utterance)
# Update the snippet bank, and age each snippet in it.
if parameters.use_snippets:
if 'atis' in parameters.data_directory:
snippets = sql_util.get_subtrees(
proc_utterance.anonymized_gold_query,
proc_utterance.available_snippets)
else:
snippets = sql_util.get_subtrees_simple(
proc_utterance.anonymized_gold_query,
proc_utterance.available_snippets)
for snippet in snippets:
snippet.assign_id(len(snippet_bank))
snippet_bank.append(snippet)
for snippet in snippet_bank:
snippet.increase_age()
interaction = Interaction(utterance_examples,
schema,
snippet_bank,
anon_tok_to_ent,
identifier,
parameters)
return interaction, keep
return fn


@@ -0,0 +1,105 @@
""" Contains the Snippet class and methods for handling snippets.
Attributes:
SNIPPET_PREFIX: string prefix for snippets.
"""
SNIPPET_PREFIX = "SNIPPET_"
def is_snippet(token):
""" Determines whether a token is a snippet or not.
Inputs:
token (str): The token to check.
Returns:
bool, indicating whether it's a snippet.
"""
return token.startswith(SNIPPET_PREFIX)
def expand_snippets(sequence, snippets):
""" Given a sequence and a list of snippets, expand the snippets in the sequence.
Inputs:
sequence (list of str): Query containing snippet references.
snippets (list of Snippet): List of available snippets.
Returns:
list of str, representing the expanded sequence.
"""
snippet_id_to_snippet = {}
for snippet in snippets:
assert snippet.name not in snippet_id_to_snippet
snippet_id_to_snippet[snippet.name] = snippet
expanded_seq = []
for token in sequence:
if token in snippet_id_to_snippet:
expanded_seq.extend(snippet_id_to_snippet[token].sequence)
else:
assert not is_snippet(token)
expanded_seq.append(token)
return expanded_seq
def snippet_index(token):
""" Returns the index of a snippet.
Inputs:
token (str): The snippet to check.
Returns:
integer, the index of the snippet.
"""
assert is_snippet(token)
return int(token.split("_")[-1])
class Snippet():
""" Contains a snippet. """
def __init__(self,
sequence,
startpos,
sql,
age=0):
self.sequence = sequence
self.startpos = startpos
self.sql = sql
# TODO: age vs. index?
self.age = age
self.index = 0
self.name = ""
self.embedding = None
self.endpos = self.startpos + len(self.sequence)
assert self.endpos < len(self.sql), "End position of snippet is " + str(
self.endpos) + " which is greater than length of SQL (" + str(len(self.sql)) + ")"
assert self.sequence == self.sql[self.startpos:self.endpos], \
"Value of snippet (" + " ".join(self.sequence) + ") " \
"is not the same as SQL at the same positions (" \
+ " ".join(self.sql[self.startpos:self.endpos]) + ")"
def __str__(self):
return self.name + "\t" + \
str(self.age) + "\t" + " ".join(self.sequence)
def __len__(self):
return len(self.sequence)
def increase_age(self):
""" Ages a snippet by one. """
self.index += 1
def assign_id(self, number):
""" Assigns the name of the snippet to be the prefix + the number. """
self.name = SNIPPET_PREFIX + str(number)
def set_embedding(self, embedding):
""" Sets the embedding of the snippet.
Inputs:
embedding (dy.Expression)
"""
self.embedding = embedding
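
A minimal round trip through the snippet helpers defined above: name a snippet, reference it in a query, and expand it back. The `data_util.snippets` import path is an assumption; adjust it to wherever this file lives in your checkout.

```python
# Round trip: name a snippet, reference it in a query, expand it back.
# The import path is an assumption; adjust it to your checkout layout.
from data_util.snippets import Snippet, expand_snippets, is_snippet, snippet_index

sql = "select name from users where id = 1 ;".split()
snippet = Snippet(sequence=["where", "id", "=", "1"], startpos=4, sql=sql)
snippet.assign_id(0)                                  # name becomes "SNIPPET_0"

query_with_ref = ["select", "name", "from", "users", "SNIPPET_0"]
assert is_snippet("SNIPPET_0") and snippet_index("SNIPPET_0") == 0
print(expand_snippets(query_with_ref, [snippet]))
# ['select', 'name', 'from', 'users', 'where', 'id', '=', '1']
```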

View File

@ -0,0 +1,445 @@
import copy
import pymysql
import random
import signal
import sqlparse
from . import util
from .snippets import Snippet
from sqlparse import tokens as token_types
from sqlparse import sql as sql_types
interesting_selects = ["DISTINCT", "MAX", "MIN", "count"]
ignored_subtrees = [["1", "=", "1"]]
# strip_whitespace_front
# Strips whitespace and punctuation from the front of a SQL token list.
#
# Inputs:
# token_list: the token list.
#
# Outputs:
# new token list.
def strip_whitespace_front(token_list):
new_token_list = []
found_valid = False
for token in token_list:
if not (token.is_whitespace or token.ttype ==
token_types.Punctuation) or found_valid:
found_valid = True
new_token_list.append(token)
return new_token_list
# strip_whitespace
# Strips whitespace and punctuation from both ends of a token list.
#
# Inputs:
# token_list: the token list.
#
# Outputs:
# new token list with no whitespace/punctuation surrounding.
def strip_whitespace(token_list):
subtokens = strip_whitespace_front(token_list)
subtokens = strip_whitespace_front(subtokens[::-1])[::-1]
return subtokens
# token_list_to_seq
# Converts a Token list to a sequence of strings, stripping out surrounding
# punctuation and all whitespace.
#
# Inputs:
# token_list: the list of tokens.
#
# Outputs:
# list of strings
def token_list_to_seq(token_list):
subtokens = strip_whitespace(token_list)
seq = []
flat = sqlparse.sql.TokenList(subtokens).flatten()
for i, token in enumerate(flat):
strip_token = str(token).strip()
if len(strip_token) > 0:
seq.append(strip_token)
if len(seq) > 0:
if seq[0] == "(" and seq[-1] == ")":
seq = seq[1:-1]
return seq
# TODO: clean this up
# find_subtrees
# Finds subtrees for a subsequence of SQL.
#
# Inputs:
# sequence: sequence of SQL tokens.
# current_subtrees: current list of subtrees.
#
# Optional inputs:
# where_parent: whether the parent of the current sequence was a where clause
# keep_conj_subtrees: whether to look for a conjunction in this sequence and
# keep its arguments
def find_subtrees(sequence,
current_subtrees,
where_parent=False,
keep_conj_subtrees=False):
# If the parent of the subsequence was a WHERE clause, keep everything in the
# sequence except for the beginning WHERE and any surrounding parentheses.
if where_parent:
# Strip out the beginning WHERE, and any punctuation or whitespace at the
# beginning or end of the token list.
seq = token_list_to_seq(sequence.tokens[1:])
if len(seq) > 0 and seq not in current_subtrees:
current_subtrees.append(seq)
# If the current sequence has subtokens, i.e. if it's a node that can be
# expanded, check for a conjunction in its subtrees, and expand its subtrees.
# Also check for any SELECT statements and keep track of what follows.
if sequence.is_group:
if keep_conj_subtrees:
subtokens = strip_whitespace(sequence.tokens)
# Check if there is a conjunction in the subsequence. If so, keep the
# children. Also make sure you don't split where AND is used within a
# child -- the subtokens sequence won't treat those ANDs differently (a
# bit hacky but it works)
has_and = False
for i, token in enumerate(subtokens):
if token.value == "OR" or token.value == "AND":
has_and = True
break
if has_and:
and_subtrees = []
current_subtree = []
for i, token in enumerate(subtokens):
if token.value == "OR" or (token.value == "AND" and i - 4 >= 0 and i - 4 < len(
subtokens) and subtokens[i - 4].value != "BETWEEN"):
and_subtrees.append(current_subtree)
current_subtree = []
else:
current_subtree.append(token)
and_subtrees.append(current_subtree)
for subtree in and_subtrees:
seq = token_list_to_seq(subtree)
if len(seq) > 0 and seq[0] == "WHERE":
seq = seq[1:]
if seq not in current_subtrees:
current_subtrees.append(seq)
in_select = False
select_toks = []
for i, token in enumerate(sequence.tokens):
# Mark whether this current token is a WHERE.
is_where = (isinstance(token, sql_types.Where))
# If you are in a SELECT, start recording what follows until you hit a
# FROM
if token.value == "SELECT":
in_select = True
elif in_select:
select_toks.append(token)
if token.value == "FROM":
in_select = False
seq = []
if len(sequence.tokens) > i + 2:
seq = token_list_to_seq(
select_toks + [sequence.tokens[i + 2]])
if seq not in current_subtrees and len(
seq) > 0 and seq[0] in interesting_selects:
current_subtrees.append(seq)
select_toks = []
# Recursively find subtrees in the children of the node.
find_subtrees(token,
current_subtrees,
is_where,
where_parent or keep_conj_subtrees)
# get_subtrees
def get_subtrees(sql, oldsnippets=[]):
parsed = sqlparse.parse(" ".join(sql))[0]
subtrees = []
find_subtrees(parsed, subtrees)
final_subtrees = []
for subtree in subtrees:
if subtree not in ignored_subtrees:
final_version = []
keep = True
parens_counts = 0
for i, token in enumerate(subtree):
if token == ".":
newtoken = final_version[-1] + "." + subtree[i + 1]
final_version = final_version[:-1] + [newtoken]
keep = False
elif keep:
final_version.append(token)
else:
keep = True
if token == "(":
parens_counts -= 1
elif token == ")":
parens_counts += 1
if parens_counts == 0:
final_subtrees.append(final_version)
snippets = []
sql = [str(tok) for tok in sql]
for subtree in final_subtrees:
startpos = -1
for i in range(len(sql) - len(subtree) + 1):
if sql[i:i + len(subtree)] == subtree:
startpos = i
if startpos >= 0 and startpos + len(subtree) < len(sql):
age = 0
for prevsnippet in oldsnippets:
if prevsnippet.sequence == subtree:
age = prevsnippet.age + 1
snippet = Snippet(subtree, startpos, sql, age=age)
snippets.append(snippet)
return snippets
def get_subtrees_simple(sql, oldsnippets=[]):
sql_string = " ".join(sql)
format_sql = sqlparse.format(sql_string, reindent=True)
# get subtrees
subtrees = []
for sub_sql in format_sql.split('\n'):
sub_sql = sub_sql.replace('(', ' ( ').replace(')', ' ) ').replace(',', ' , ')
subtree = sub_sql.strip().split()
if len(subtree) > 1:
subtrees.append(subtree)
final_subtrees = subtrees
snippets = []
sql = [str(tok) for tok in sql]
for subtree in final_subtrees:
startpos = -1
for i in range(len(sql) - len(subtree) + 1):
if sql[i:i + len(subtree)] == subtree:
startpos = i
if startpos >= 0 and startpos + len(subtree) <= len(sql):
age = 0
for prevsnippet in oldsnippets:
if prevsnippet.sequence == subtree:
age = prevsnippet.age + 1
new_sql = sql + [';']
snippet = Snippet(subtree, startpos, new_sql, age=age)
snippets.append(snippet)
return snippets
conjunctions = {"AND", "OR", "WHERE"}
def get_all_in_parens(sequence):
if sequence[-1] == ";":
sequence = sequence[:-1]
if not "(" in sequence:
return []
if sequence[0] == "(" and sequence[-1] == ")":
in_parens = sequence[1:-1]
return [in_parens] + get_all_in_parens(in_parens)
else:
paren_subseqs = []
current_seq = []
num_parens = 0
in_parens = False
for token in sequence:
if in_parens:
current_seq.append(token)
if token == ")":
num_parens -= 1
if num_parens == 0:
in_parens = False
paren_subseqs.append(current_seq)
current_seq = []
elif token == "(":
in_parens = True
current_seq.append(token)
if token == "(":
num_parens += 1
all_subseqs = []
for subseq in paren_subseqs:
all_subseqs.extend(get_all_in_parens(subseq))
return all_subseqs
def split_by_conj(sequence):
num_parens = 0
current_seq = []
subsequences = []
for token in sequence:
if num_parens == 0:
if token in conjunctions:
subsequences.append(current_seq)
current_seq = []
break
current_seq.append(token)
if token == "(":
num_parens += 1
elif token == ")":
num_parens -= 1
assert num_parens >= 0
return subsequences
def get_sql_snippets(sequence):
# First, get all subsequences of the sequence that are surrounded by
# parentheses.
all_in_parens = get_all_in_parens(sequence)
all_subseq = []
# Then for each one, split the sequence on conjunctions (AND/OR).
for seq in all_in_parens:
subsequences = split_by_conj(seq)
all_subseq.append(seq)
all_subseq.extend(subsequences)
# Finally, also get "interesting" selects
for i, seq in enumerate(all_subseq):
print(str(i) + "\t" + " ".join(seq))
exit()
# add_snippets_to_query
def add_snippets_to_query(snippets, ignored_entities, query, prob_align=1.):
query_copy = copy.copy(query)
# Replace the longest snippets first, so sort by length descending.
sorted_snippets = sorted(snippets, key=lambda s: len(s.sequence))[::-1]
for snippet in sorted_snippets:
ignore = False
snippet_seq = snippet.sequence
# TODO: continue here
# If it contains an ignored entity, then don't use it.
for entity in ignored_entities:
ignore = ignore or util.subsequence(entity, snippet_seq)
# If no NL entities were found in the snippet, check whether the snippet is a
# contiguous subsequence of the gold query.
if not ignore:
snippet_length = len(snippet_seq)
# Iterate through gold sequence to see if it's a subsequence.
for start_idx in range(len(query_copy) - snippet_length + 1):
if query_copy[start_idx:start_idx +
snippet_length] == snippet_seq:
align = random.random() < prob_align
if align:
prev_length = len(query_copy)
# At the start position of the snippet, replace with an
# identifier.
query_copy[start_idx] = snippet.name
# Then cut out the indices which were collapsed into
# the snippet.
query_copy = query_copy[:start_idx + 1] + \
query_copy[start_idx + snippet_length:]
# Make sure the length is as expected
assert len(query_copy) == prev_length - \
(snippet_length - 1)
return query_copy
def execution_results(query, username, password, timeout=3):
connection = pymysql.connect(user=username, password=password)
class TimeoutException(Exception):
pass
def timeout_handler(signum, frame):
raise TimeoutException
signal.signal(signal.SIGALRM, timeout_handler)
syntactic = True
semantic = True
table = []
with connection.cursor() as cursor:
signal.alarm(timeout)
try:
cursor.execute("SET sql_mode='IGNORE_SPACE';")
cursor.execute("use atis3;")
cursor.execute(query)
table = cursor.fetchall()
cursor.close()
except TimeoutException:
signal.alarm(0)
cursor.close()
except pymysql.err.ProgrammingError:
syntactic = False
semantic = False
cursor.close()
except pymysql.err.InternalError:
semantic = False
cursor.close()
except Exception as e:
signal.alarm(0)
cursor.close()
signal.alarm(0)
connection.close()
return (syntactic, semantic, sorted(table))
def executable(query, username, password, timeout=2):
return execution_results(query, username, password, timeout)[1]
def fix_parentheses(sequence):
num_left = sequence.count("(")
num_right = sequence.count(")")
if num_right < num_left:
fixed_sequence = sequence[:-1] + \
[")" for _ in range(num_left - num_right)] + [sequence[-1]]
return fixed_sequence
return sequence
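
`add_snippets_to_query` collapses the longest matching snippet sequence in a gold query into its `SNIPPET_i` name; with `prob_align=1.0` the replacement is deterministic. A usage sketch, with import paths as assumptions:

```python
# With prob_align=1.0 the longest matching snippet sequence in the gold
# query is always collapsed into its name.
# Import paths are assumptions; adjust them to your checkout layout.
from data_util import sql_util
from data_util.snippets import Snippet

sql = "select name from users where id = 1 ;".split()
snippet = Snippet(["where", "id", "=", "1"], startpos=4, sql=sql)
snippet.assign_id(0)

gold = ["select", "name", "from", "users", "where", "id", "=", "1"]
print(sql_util.add_snippets_to_query([snippet], [], gold, prob_align=1.0))
# ['select', 'name', 'from', 'users', 'SNIPPET_0']
```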

View File

@ -0,0 +1,83 @@
"""Tokenizers for natural language SQL queries, and lambda calculus."""
import nltk
nltk.data.path.append('./nltk_data/')
import sqlparse
def nl_tokenize(string):
"""Tokenizes a natural language string into tokens.
Inputs:
string: the string to tokenize.
Outputs:
a list of tokens.
Assumes data is space-separated (this is true of ZC07 data in ATIS2/3).
"""
return nltk.word_tokenize(string)
def sql_tokenize(string):
""" Tokenizes a SQL statement into tokens.
Inputs:
string: string to tokenize.
Outputs:
a list of tokens.
"""
tokens = []
statements = sqlparse.parse(string)
# SQLparse gives you a list of statements.
for statement in statements:
# Flatten the tokens in each statement and add to the tokens list.
flat_tokens = sqlparse.sql.TokenList(statement.tokens).flatten()
for token in flat_tokens:
strip_token = str(token).strip()
if len(strip_token) > 0:
tokens.append(strip_token)
newtokens = []
keep = True
for i, token in enumerate(tokens):
if token == ".":
newtoken = newtokens[-1] + "." + tokens[i + 1]
newtokens = newtokens[:-1] + [newtoken]
keep = False
elif keep:
newtokens.append(token)
else:
keep = True
return newtokens
def lambda_tokenize(string):
""" Tokenizes a lambda-calculus statement into tokens.
Inputs:
string: a lambda-calculus string
Outputs:
a list of tokens.
"""
space_separated = string.split(" ")
new_tokens = []
# Separate the string by spaces, then separate based on existence of ( or
# ).
for token in space_separated:
tokens = []
current_token = ""
for char in token:
if char == ")" or char == "(":
tokens.append(current_token)
tokens.append(char)
current_token = ""
else:
current_token += char
tokens.append(current_token)
new_tokens.extend([tok for tok in tokens if tok])
return new_tokens
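
A quick sketch of the two tokenizers above. The import path is an assumption, and the exact SQL tokenization depends on `sqlparse`:

```python
# Import path is an assumption. lambda_tokenize splits parentheses away from
# adjacent symbols; sql_tokenize re-joins dotted column references.
from data_util.tokenizers import lambda_tokenize, sql_tokenize

print(lambda_tokenize("(lambda x (flight x))"))
# ['(', 'lambda', 'x', '(', 'flight', 'x', ')', ')']

print(sql_tokenize("SELECT t.name FROM users t"))
# typically ['SELECT', 't.name', 'FROM', 'users', 't'], depending on sqlparse
```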

View File

@ -0,0 +1,16 @@
"""Contains various utility functions."""
def subsequence(first_sequence, second_sequence):
"""
Returns whether the first sequence appears as a contiguous subsequence of the second sequence.
Inputs:
first_sequence (list): A sequence.
second_sequence (list): Another sequence.
Returns:
Boolean indicating whether first_sequence is a subsequence of second_sequence.
"""
for startidx in range(len(second_sequence) - len(first_sequence) + 1):
if second_sequence[startidx:startidx + len(first_sequence)] == first_sequence:
return True
return False
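
Despite its name, `subsequence` checks for a contiguous sublist:

```python
from data_util.util import subsequence   # import path is an assumption

print(subsequence(["id", "="], ["where", "id", "=", "1"]))   # True
print(subsequence(["id", "1"], ["where", "id", "=", "1"]))   # False: not contiguous
```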

View File

@ -0,0 +1,115 @@
""" Contains the Utterance class. """
from . import sql_util
from . import tokenizers
ANON_INPUT_KEY = "cleaned_nl"
OUTPUT_KEY = "sql"
class Utterance:
""" Utterance class. """
def process_input_seq(self,
anonymize,
anonymizer,
anon_tok_to_ent):
assert not anon_tok_to_ent or anonymize
assert not anonymize or anonymizer
if anonymize:
assert anonymizer
self.input_seq_to_use = anonymizer.anonymize(
self.original_input_seq, anon_tok_to_ent, ANON_INPUT_KEY, add_new_anon_toks=True)
else:
self.input_seq_to_use = self.original_input_seq
def process_gold_seq(self,
output_sequences,
nl_to_sql_dict,
available_snippets,
anonymize,
anonymizer,
anon_tok_to_ent):
# Get entities in the input sequence:
# anonymized entity types
# other recognized entities (this includes "flight")
entities_in_input = [
[tok] for tok in self.input_seq_to_use if tok in anon_tok_to_ent]
entities_in_input.extend(
nl_to_sql_dict.get_sql_entities(
self.input_seq_to_use))
# Get the shortest gold query (this is what we use to train)
shortest_gold_and_results = min(output_sequences,
key=lambda x: len(x[0]))
# Tokenize and anonymize it if necessary.
self.original_gold_query = shortest_gold_and_results[0]
self.gold_sql_results = shortest_gold_and_results[1]
self.contained_entities = entities_in_input
# Keep track of all gold queries and the resulting tables so that we can
# give credit if it predicts a different correct sequence.
self.all_gold_queries = output_sequences
self.anonymized_gold_query = self.original_gold_query
if anonymize:
self.anonymized_gold_query = anonymizer.anonymize(
self.original_gold_query, anon_tok_to_ent, OUTPUT_KEY, add_new_anon_toks=False)
# Add snippets to it.
self.gold_query_to_use = sql_util.add_snippets_to_query(
available_snippets, entities_in_input, self.anonymized_gold_query)
def __init__(self,
example,
available_snippets,
nl_to_sql_dict,
params,
anon_tok_to_ent={},
anonymizer=None):
# Get output and input sequences from the dictionary representation.
output_sequences = example[OUTPUT_KEY]
self.original_input_seq = tokenizers.nl_tokenize(example[params.input_key])
self.available_snippets = available_snippets
self.keep = False
# pruned_output_sequences = []
# for sequence in output_sequences:
# if len(sequence[0]) > 3:
# pruned_output_sequences.append(sequence)
# output_sequences = pruned_output_sequences
if len(output_sequences) > 0 and len(self.original_input_seq) > 0:
# Only keep this example if there is at least one output sequence.
self.keep = True
if len(output_sequences) == 0 or len(self.original_input_seq) == 0:
return
# Process the input sequence
self.process_input_seq(params.anonymize,
anonymizer,
anon_tok_to_ent)
# Process the gold sequence
self.process_gold_seq(output_sequences,
nl_to_sql_dict,
self.available_snippets,
params.anonymize,
anonymizer,
anon_tok_to_ent)
def __str__(self):
string = "Original input: " + " ".join(self.original_input_seq) + "\n"
string += "Modified input: " + " ".join(self.input_seq_to_use) + "\n"
string += "Original output: " + " ".join(self.original_gold_query) + "\n"
string += "Modified output: " + " ".join(self.gold_query_to_use) + "\n"
string += "Snippets:\n"
for snippet in self.available_snippets:
string += str(snippet) + "\n"
return string
def length_valid(self, input_limit, output_limit):
return (len(self.input_seq_to_use) < input_limit \
and len(self.gold_query_to_use) < output_limit)
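
`process_gold_seq` trains on the shortest of the alternative gold queries. A toy illustration of that selection; the (query tokens, execution results) pairs are made up, not real data:

```python
# Made-up (query tokens, execution results) pairs, just to show the selection.
output_sequences = [
    (["select", "*", "from", "users", "where", "id", "=", "1"], [("alice",)]),
    (["select", "name", "from", "users"], [("alice",), ("bob",)]),
]
shortest_gold_and_results = min(output_sequences, key=lambda x: len(x[0]))
print(shortest_gold_and_results[0])   # ['select', 'name', 'from', 'users']
```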

View File

@ -0,0 +1,98 @@
"""Contains class and methods for storing and computing a vocabulary from text."""
import operator
import os
import pickle
# Special sequencing tokens.
UNK_TOK = "_UNK" # Replaces out-of-vocabulary words.
EOS_TOK = "_EOS" # Appended to the end of a sequence to indicate its end.
DEL_TOK = ";"
class Vocabulary:
"""Vocabulary class: stores information about words in a corpus.
Members:
functional_types (list of str): Functional vocabulary words, such as EOS.
max_size (int): The maximum size of vocabulary to keep.
min_occur (int): The minimum number of times a word should occur to keep it.
id_to_token (list of str): Ordered list of word types.
token_to_id (dict str->int): Maps from each unique word type to its index.
"""
def get_vocab(self, sequences, ignore_fn):
"""Gets vocabulary from a list of sequences.
Inputs:
sequences (list of list of str): Sequences from which to compute the vocabulary.
ignore_fn (lambda str: bool): Function used to tell whether to ignore a
token during computation of the vocabulary.
Returns:
list of str, representing the unique word types in the vocabulary.
"""
type_counts = {}
for sequence in sequences:
for token in sequence:
if not ignore_fn(token):
if token not in type_counts:
type_counts[token] = 0
type_counts[token] += 1
# Create sorted list of tokens, by their counts. Reverse so it is in order of
# most frequent to least frequent.
sorted_type_counts = sorted(sorted(type_counts.items()),
key=operator.itemgetter(1))[::-1]
sorted_types = [typecount[0]
for typecount in sorted_type_counts if typecount[1] >= self.min_occur]
# Append the necessary functional tokens.
sorted_types = self.functional_types + sorted_types
# Cut off if vocab_size is set (nonnegative)
if self.max_size >= 0:
vocab = sorted_types[:min(self.max_size, len(sorted_types))]
else:
vocab = sorted_types
return vocab
def __init__(self,
sequences,
filename,
functional_types=None,
max_size=-1,
min_occur=0,
ignore_fn=lambda x: False):
self.functional_types = functional_types
self.max_size = max_size
self.min_occur = min_occur
vocab = self.get_vocab(sequences, ignore_fn)
self.id_to_token = []
self.token_to_id = {}
for i, word_type in enumerate(vocab):
self.id_to_token.append(word_type)
self.token_to_id[word_type] = i
# Load the previous vocab, if it exists.
if os.path.exists(filename):
infile = open(filename, 'rb')
loaded_vocab = pickle.load(infile)
infile.close()
print("Loaded vocabulary from " + str(filename))
if loaded_vocab.id_to_token != self.id_to_token \
or loaded_vocab.token_to_id != self.token_to_id:
print("Loaded vocabulary is different than generated vocabulary.")
else:
print("Writing vocabulary to " + str(filename))
outfile = open(filename, 'wb')
pickle.dump(self, outfile)
outfile.close()
def __len__(self):
return len(self.id_to_token)
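
A minimal usage sketch of `Vocabulary`; the import path and the throwaway pickle location are assumptions:

```python
# Import path and the throwaway pickle location are assumptions.
import os
import tempfile
from data_util.vocabulary import Vocabulary, UNK_TOK, EOS_TOK

sequences = [["select", "name", "from", "users"],
             ["select", "id", "from", "users"]]
vocab_file = os.path.join(tempfile.mkdtemp(), "vocab.pkl")
vocab = Vocabulary(sequences, vocab_file, functional_types=[UNK_TOK, EOS_TOK])

print(len(vocab))                    # functional tokens + unique word types
print(vocab.token_to_id["select"])   # integer id of "select"
```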

View File

@ -0,0 +1,866 @@
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
import os, sys
import json
import sqlite3
import traceback
import argparse
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]]
label_wo_agg = [unit[2] for unit in label_conds]
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for _, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
return label_tables == pred_tables
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql)
except:
return False
return True
def print_scores(scores, etype):
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
def evaluate(gold, predict, db_dir, etype, kmaps):
with open(gold) as f:
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict) as f:
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
evaluator = Evaluator()
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
eval_err_num = 0
for p, g in zip(plist, glist):
p_str = p[0]
g_str, db = g
db_name = db
db = os.path.join(db_dir, db, db + ".sqlite")
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
scores[hardness]['count'] += 1
scores['all']['count'] += 1
try:
p_sql = get_sql(schema, p_str)
except:
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
p_sql = {
"except": None,
"from": {
"conds": [],
"table_units": []
},
"groupBy": [],
"having": [],
"intersect": None,
"limit": None,
"orderBy": [],
"select": [
False,
[]
],
"union": None,
"where": []
}
eval_err_num += 1
print("eval_err_num:{}".format(eval_err_num))
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if exec_score:
scores[hardness]['exec'] += 1
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
return 1 if the values between prediction and gold are matching
in the corresponding index. Multiple col_units (pairs) are currently not supported.
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(p_str)
p_res = cursor.fetchall()
except:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, val_units):
rmap = {}
for idx, val_unit in enumerate(val_units):
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
rmap[key] = [r[idx] for r in res]
return rmap
p_val_units = [unit[1] for unit in pred['select'][1]]
q_val_units = [unit[1] for unit in gold['select'][1]]
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixs = [col_id[:-2] for col_id in col_ids]
valid_col_units= []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixs:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
def build_foreign_key_map_from_json(table):
with open(table) as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gold', dest='gold', type=str)
parser.add_argument('--pred', dest='pred', type=str)
parser.add_argument('--db', dest='db', type=str)
parser.add_argument('--table', dest='table', type=str)
parser.add_argument('--etype', dest='etype', type=str)
args = parser.parse_args()
gold = args.gold
pred = args.pred
db_dir = args.db
table = args.table
etype = args.etype
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps)
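
The partial scores computed above are all-or-nothing per clause: credit is given only when the predicted and gold unit counts match and every unit is found. A small sketch of that convention; the module name is an assumption, and importing it requires `process_sql.py` on the path:

```python
# Module name is an assumption; importing it also needs process_sql.py on the path.
from evaluation import get_scores, F1

print(get_scores(count=2, pred_total=2, label_total=2))   # (1, 1, 1)
print(get_scores(count=1, pred_total=2, label_total=2))   # (0, 0, 0)
print(get_scores(count=1, pred_total=1, label_total=2))   # (0, 0, 0): totals differ
print(F1(1, 1))                                           # 1.0
```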

View File

@ -0,0 +1,938 @@
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
import os, sys
import json
import sqlite3
import traceback
import argparse
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]]
label_wo_agg = [unit[2] for unit in label_conds]
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for key, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
return label_tables == pred_tables
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql)
except:
return False
return True
def print_scores(scores, etype):
turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
levels = ['easy', 'medium', 'hard', 'extra', 'all', "joint_all"]
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print("\n\n{:20} {:20} {:20} {:20} {:20} {:20}".format("", *turns))
counts = [scores[turn]['count'] for turn in turns]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== TURN EXECUTION ACCURACY =====================')
this_scores = [scores[turn]['exec'] for turn in turns]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== TURN EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[turn]['exact'] for turn in turns]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
def evaluate(gold, predict, db_dir, etype, kmaps):
with open(gold) as f:
glist = []
gseq_one = []
for l in f.readlines():
if len(l.strip()) == 0:
glist.append(gseq_one)
gseq_one = []
else:
lstrip = l.strip().split('\t')
gseq_one.append(lstrip)
#glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict) as f:
plist = []
pseq_one = []
for l in f.readlines():
if len(l.strip()) == 0:
plist.append(pseq_one)
pseq_one = []
else:
pseq_one.append(l.strip().split('\t'))
#plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
# plist = [[("select product_type_code from products group by product_type_code order by count ( * ) desc limit value", "orchestra")]]
# glist = [[("SELECT product_type_code FROM Products GROUP BY product_type_code ORDER BY count(*) DESC LIMIT 1", "customers_and_orders")]]
evaluator = Evaluator()
turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
for turn in turns:
scores[turn] = {'count': 0, 'exact': 0.}
scores[turn]['exec'] = 0
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
eval_err_num = 0
for p, g in zip(plist, glist):
scores['joint_all']['count'] += 1
turn_scores = {"exec": [], "exact": []}
for idx, pg in enumerate(zip(p, g)):
p, g = pg
p_str = p[0]
p_str = p_str.replace("value", "1")
g_str, db = g
db_name = db
db = os.path.join(db_dir, db, db + ".sqlite")
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
if idx > 3:
idx = ">4"
else:
idx += 1
turn_id = "turn " + str(idx)
scores[turn_id]['count'] += 1
scores[hardness]['count'] += 1
scores['all']['count'] += 1
try:
p_sql = get_sql(schema, p_str)
except:
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
p_sql = {
"except": None,
"from": {
"conds": [],
"table_units": []
},
"groupBy": [],
"having": [],
"intersect": None,
"limit": None,
"orderBy": [],
"select": [
False,
[]
],
"union": None,
"where": []
}
eval_err_num += 1
print("eval_err_num:{}".format(eval_err_num))
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if exec_score:
scores[hardness]['exec'] += 1
scores[turn_id]['exec'] += 1
turn_scores['exec'].append(1)
else:
turn_scores['exec'].append(0)
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
turn_scores['exact'].append(0)
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
else:
turn_scores['exact'].append(1)
scores[turn_id]['exact'] += exact_score
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
if all(v == 1 for v in turn_scores["exec"]):
scores['joint_all']['exec'] += 1
if all(v == 1 for v in turn_scores["exact"]):
scores['joint_all']['exact'] += 1
for turn in turns:
if scores[turn]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[turn]['exec'] /= scores[turn]['count']
if etype in ["all", "match"]:
scores[turn]['exact'] /= scores[turn]['count']
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
Return True if the values returned by the predicted and gold queries match
at every corresponding column index. Multiple col_unit pairs are currently not supported.
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(p_str)
p_res = cursor.fetchall()
except:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, val_units):
rmap = {}
for idx, val_unit in enumerate(val_units):
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
rmap[key] = [r[idx] for r in res]
return rmap
p_val_units = [unit[1] for unit in pred['select'][1]]
q_val_units = [unit[1] for unit in gold['select'][1]]
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixs = [col_id[:-2] for col_id in col_ids]
valid_col_units= []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixs:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
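# Illustrative example (added for clarity, not part of the original script; the
# table/column names are hypothetical): if entry["foreign_keys"] links column i
# (`orders.customer_id`) to column j (`customers.id`), both
# "__orders.customer_id__" and "__customers.id__" are mapped to the same
# canonical column (the member of the key set with the smallest column index),
# so gold and predicted SQL that join through either side of the foreign key
# compare as equal.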
def build_foreign_key_map_from_json(table):
with open(table) as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gold', dest='gold', type=str)
parser.add_argument('--pred', dest='pred', type=str)
parser.add_argument('--db', dest='db', type=str)
parser.add_argument('--table', dest='table', type=str)
parser.add_argument('--etype', dest='etype', type=str)
args = parser.parse_args()
gold = args.gold
pred = args.pred
db_dir = args.db
table = args.table
etype = args.etype
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps)

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,53 @@
import json
import sys
predictions = [json.loads(line) for line in open(sys.argv[1]).readlines() if line]
string_count = 0.
sem_count = 0.
syn_count = 0.
table_count = 0.
strict_table_count = 0.
precision_denom = 0.
precision = 0.
recall_denom = 0.
recall = 0.
f1_score = 0.
f1_denom = 0.
time = 0.
for prediction in predictions:
if prediction["correct_string"]:
string_count += 1.
if prediction["semantic"]:
sem_count += 1.
if prediction["syntactic"]:
syn_count += 1.
if prediction["correct_table"]:
table_count += 1.
if prediction["strict_correct_table"]:
strict_table_count += 1.
if prediction["gold_tables"] !="[[]]":
precision += prediction["table_prec"]
precision_denom += 1
if prediction["pred_table"] != "[]":
recall += prediction["table_rec"]
recall_denom += 1
if prediction["gold_tables"] != "[[]]":
f1_score += prediction["table_f1"]
f1_denom += 1
num_p = len(predictions)
print("string precision: " + str(string_count / num_p))
print("% semantic: " + str(sem_count / num_p))
print("% syntactic: " + str(syn_count / num_p))
print("table prec: " + str(table_count / num_p))
print("strict table prec: " + str(strict_table_count / num_p))
print("table row prec: " + str(precision / precision_denom))
print("table row recall: " + str(recall / recall_denom))
print("table row f1: " + str(f1_score / f1_denom))
print("inference time: " + str(time / num_p))

View File

@ -0,0 +1,569 @@
################################
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
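# Illustrative example (added for clarity, not part of the original file; the
# `singer` table is hypothetical). get_sql(schema, "SELECT count(*) FROM singer")
# produces roughly:
# {
#     'select': (False, [(3, (0, (0, '__all__', False), None))]),   # 3 == AGG_OPS.index('count')
#     'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
#     'where': [], 'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#     'intersect': None, 'union': None, 'except': None
# }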
import json
import sqlite3
import sys
from nltk import word_tokenize
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema):
self._schema = schema
self._idMap = self._map(self._schema)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema):
idMap = {'*': "__all__"}
id = 1
for key, vals in schema.items():
for val in vals:
idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
id += 1
for key in schema:
idMap[key.lower()] = "__" + key.lower() + "__"
id += 1
return idMap
def get_schema(db):
"""
Get database's schema, which is a dict with table name as key
and list of column names as value
:param db: database path
:return: schema dict
"""
schema = {}
conn = sqlite3.connect(db)
cursor = conn.cursor()
# fetch table names
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [str(table[0].lower()) for table in cursor.fetchall()]
# fetch table info
for table in tables:
cursor.execute("PRAGMA table_info({})".format(table))
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
return schema
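# Illustrative example (added for clarity, not part of the original file; the
# database and its columns are hypothetical): for a sqlite file containing a
# single table singer(singer_id, name, country, age),
#   get_schema("singer.sqlite") -> {'singer': ['singer_id', 'name', 'country', 'age']}
# and Schema(get_schema("singer.sqlite")).idMap['singer.name'] == '__singer.name__'.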
def get_schema_from_json(fpath):
with open(fpath) as f:
data = json.load(f)
schema = {}
for entry in data:
table = str(entry['table'].lower())
cols = [str(col['column_name'].lower()) for col in entry['col_data']]
schema[table] = cols
return schema
def tokenize(string):
string = str(string)
string = string.replace("\'", "\"")  # normalize quotes so every string literal is wrapped in double quotes
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
# keep string value as token
vals = {}
for i in range(len(quote_idxs)-1, -1, -2):
qidx1 = quote_idxs[i-1]
qidx2 = quote_idxs[i]
val = string[qidx1: qidx2+1]
key = "__val_{}_{}__".format(qidx1, qidx2)
string = string[:qidx1] + key + string[qidx2+1:]
vals[key] = val
toks = [word.lower() for word in word_tokenize(string)]
# replace with string value token
for i in range(len(toks)):
if toks[i] in vals:
toks[i] = vals[toks[i]]
# find if there exists !=, >=, <=
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
eq_idxs.reverse()
prefix = ('!', '>', '<')
for eq_idx in eq_idxs:
pre_tok = toks[eq_idx-1]
if pre_tok in prefix:
toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
return toks
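# Illustrative example (added for clarity, not part of the original file):
#   tokenize("SELECT name FROM singer WHERE age != 20")
#       -> ['select', 'name', 'from', 'singer', 'where', 'age', '!=', '20']
# String literals are kept as single tokens, and '!', '>', '<' followed by '='
# are merged back into '!=', '>=', '<='.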
def scan_alias(toks):
"""Scan the index of 'as' and build the map for all alias"""
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
alias = {}
for idx in as_idxs:
alias[toks[idx+1]] = toks[idx-1]
return alias
def get_tables_with_alias(schema, toks):
tables = scan_alias(toks)
for key in schema:
assert key not in tables, "Alias {} has the same name in table".format(key)
tables[key] = key
return tables
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, column id
"""
tok = toks[start_idx]
if tok == "*":
return start_idx + 1, schema.idMap[tok]
if '.' in tok: # if token is a composite
alias, col = tok.split('.')
key = tables_with_alias[alias] + "." + col
return start_idx+1, schema.idMap[key]
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
for alias in default_tables:
table = tables_with_alias[alias]
if tok in schema.schema[table]:
key = table + "." + tok
return start_idx+1, schema.idMap[key]
assert False, "Error col: {}".format(tok)
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, (agg_op id, col_id)
"""
idx = start_idx
len_ = len(toks)
isBlock = False
isDistinct = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
assert idx < len_ and toks[idx] == '('
idx += 1
if toks[idx] == "distinct":
idx += 1
isDistinct = True
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
assert idx < len_ and toks[idx] == ')'
idx += 1
return idx, (agg_id, col_id, isDistinct)
if toks[idx] == "distinct":
idx += 1
isDistinct = True
agg_id = AGG_OPS.index("none")
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (agg_id, col_id, isDistinct)
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
col_unit1 = None
col_unit2 = None
unit_op = UNIT_OPS.index('none')
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if idx < len_ and toks[idx] in UNIT_OPS:
unit_op = UNIT_OPS.index(toks[idx])
idx += 1
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (unit_op, col_unit1, col_unit2)
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
"""
:returns next idx, table id, table name
"""
idx = start_idx
len_ = len(toks)
key = tables_with_alias[toks[idx]]
if idx + 1 < len_ and toks[idx+1] == "as":
idx += 3
else:
idx += 1
return idx, schema.idMap[key], key
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
elif "\"" in toks[idx]: # token is a string value
val = toks[idx]
idx += 1
else:
try:
val = float(toks[idx])
idx += 1
except:
end_idx = idx
while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
end_idx += 1
idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
idx = end_idx
if isBlock:
assert toks[idx] == ')'
idx += 1
return idx, val
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
conds = []
while idx < len_:
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
not_op = False
if toks[idx] == 'not':
not_op = True
idx += 1
assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
op_id = WHERE_OPS.index(toks[idx])
idx += 1
val1 = val2 = None
if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
assert toks[idx] == 'and'
idx += 1
idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
else: # normal case: single value
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
val2 = None
conds.append((not_op, op_id, val_unit, val1, val2))
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
break
if idx < len_ and toks[idx] in COND_OPS:
conds.append(toks[idx])
idx += 1 # skip and/or
return idx, conds
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
assert toks[idx] == 'select', "'select' not found"
idx += 1
isDistinct = False
if idx < len_ and toks[idx] == 'distinct':
idx += 1
isDistinct = True
val_units = []
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
agg_id = AGG_OPS.index("none")
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append((agg_id, val_unit))
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
return idx, (isDistinct, val_units)
def parse_from(toks, start_idx, tables_with_alias, schema):
"""
Assume in the from clause, all table units are combined with join
"""
assert 'from' in toks[start_idx:], "'from' not found"
len_ = len(toks)
idx = toks.index('from', start_idx) + 1
default_tables = []
table_units = []
conds = []
while idx < len_:
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['sql'], sql))
else:
if idx < len_ and toks[idx] == 'join':
idx += 1 # skip join
idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['table_unit'],table_unit))
default_tables.append(table_name)
if idx < len_ and toks[idx] == "on":
idx += 1 # skip on
idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
if len(conds) > 0:
conds.append('and')
conds.extend(this_conds)
if isBlock:
assert toks[idx] == ')'
idx += 1
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
break
return idx, table_units, conds, default_tables
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'where':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
col_units = []
if idx >= len_ or toks[idx] != 'group':
return idx, col_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
col_units.append(col_unit)
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, col_units
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
val_units = []
order_type = 'asc' # default type is 'asc'
if idx >= len_ or toks[idx] != 'order':
return idx, val_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append(val_unit)
if idx < len_ and toks[idx] in ORDER_OPS:
order_type = toks[idx]
idx += 1
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, (order_type, val_units)
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'having':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_limit(toks, start_idx):
idx = start_idx
len_ = len(toks)
if idx < len_ and toks[idx] == 'limit':
idx += 2
# the LIMIT argument may be a placeholder token (e.g. 'value') rather than a number; fall back to 1 instead of assuming an integer
if type(toks[idx-1]) != int:
return idx, 1
return idx, int(toks[idx-1])
return idx, None
def parse_sql(toks, start_idx, tables_with_alias, schema):
isBlock = False # indicate whether this is a block of sql/sub-sql
len_ = len(toks)
idx = start_idx
sql = {}
if toks[idx] == '(':
isBlock = True
idx += 1
# parse from clause in order to get default tables
from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
sql['from'] = {'table_units': table_units, 'conds': conds}
# select clause
_, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
idx = from_end_idx
sql['select'] = select_col_units
# where clause
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
sql['where'] = where_conds
# group by clause
idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
sql['groupBy'] = group_col_units
# having clause
idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
sql['having'] = having_conds
# order by clause
idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
sql['orderBy'] = order_col_units
# limit clause
idx, limit_val = parse_limit(toks, idx)
sql['limit'] = limit_val
idx = skip_semicolon(toks, idx)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
idx = skip_semicolon(toks, idx)
# intersect/union/except clause
for op in SQL_OPS: # initialize IUE
sql[op] = None
if idx < len_ and toks[idx] in SQL_OPS:
sql_op = toks[idx]
idx += 1
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
sql[sql_op] = IUE_sql
return idx, sql
def load_data(fpath):
with open(fpath) as f:
data = json.load(f)
return data
def get_sql(schema, query):
toks = tokenize(query)
tables_with_alias = get_tables_with_alias(schema.schema, toks)
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
return sql
def skip_semicolon(toks, start_idx):
idx = start_idx
while idx < len(toks) and toks[idx] == ";":
idx += 1
return idx
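# Usage sketch (added for clarity, not part of the original file; the database
# path is hypothetical):
#   schema = Schema(get_schema("data/database/concert_singer/concert_singer.sqlite"))
#   sql_dict = get_sql(schema, "SELECT count(*) FROM singer")
#   # sql_dict follows the nested structure documented at the top of this file.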

58
r2sql/cosql/logger.py Normal file
View File

@ -0,0 +1,58 @@
"""Contains the logging class."""
class Logger():
"""Attributes:
fileptr (file): File pointer for input/output.
lines (list of str): The lines read from the log.
"""
def __init__(self, filename, option):
self.fileptr = open(filename, option)
if option == "r":
self.lines = self.fileptr.readlines()
else:
self.lines = []
def put(self, string):
"""Writes to the file."""
self.fileptr.write(string + "\n")
self.fileptr.flush()
def close(self):
"""Closes the logger."""
self.fileptr.close()
def findlast(self, identifier, default=0.):
"""Finds the last line in the log with a certain value."""
for line in self.lines[::-1]:
if line.lower().startswith(identifier):
string = line.strip().split("\t")[1]
if string.replace(".", "").isdigit():
return float(string)
elif string.lower() == "true":
return True
elif string.lower() == "false":
return False
else:
return string
return default
def contains(self, string):
"""Dtermines whether the string is present in the log."""
for line in self.lines[::-1]:
if string.lower() in line.lower():
return True
return False
def findlast_log_before(self, before_str):
"""Finds the last entry in the log before another entry."""
loglines = []
in_line = False
for line in self.lines[::-1]:
if line.startswith(before_str):
in_line = True
elif in_line:
loglines.append(line)
if line.strip() == "" and in_line:
return "".join(loglines[::-1])
return "".join(loglines[::-1])

View File

@ -0,0 +1,80 @@
"""Contains classes for computing and keeping track of attention distributions.
"""
from collections import namedtuple
import torch
import torch.nn.functional as F
from . import torch_utils
class AttentionResult(namedtuple('AttentionResult',
('scores',
'distribution',
'vector'))):
"""Stores the result of an attention calculation."""
__slots__ = ()
class Attention(torch.nn.Module):
"""Attention mechanism class. Stores parameters for and computes attention.
Attributes:
transform_query (bool): Whether or not to transform the query being
passed in with a weight transformation before computing attention.
transform_key (bool): Whether or not to transform the key being
passed in with a weight transformation before computing attention.
transform_value (bool): Whether or not to transform the value being
passed in with a weight transformation before computing attention.
key_size (int): The size of the key vectors.
value_size (int): The size of the value vectors.
query_weights (torch.nn.Parameter): Weights for transforming the query.
key_weights (torch.nn.Parameter): Weights for transforming the key.
value_weights (torch.nn.Parameter): Weights for transforming the value.
"""
def __init__(self, query_size, key_size, value_size):
super().__init__()
self.key_size = key_size
self.value_size = value_size
self.query_weights = torch_utils.add_params((query_size, self.key_size), "weights-attention-q")
def transform_arguments(self, query, keys, values):
""" Transforms the query/key/value inputs before attention calculations.
Arguments:
query (torch.Tensor): Vector representing the query (e.g., hidden state).
keys (list of torch.Tensor): List of vectors representing the keys.
values (list of torch.Tensor): List of vectors representing the values.
Returns:
triple of torch.Tensor, where the first represents the (transformed)
query, the second represents the (transformed and concatenated)
keys, and the third represents the (transformed and concatenated)
values.
"""
assert len(keys) == len(values)
all_keys = torch.stack(keys, dim=1)
all_values = torch.stack(values, dim=1)
assert all_keys.size()[0] == self.key_size, "Expected key size of " + str(self.key_size) + " but got " + str(all_keys.size()[0])
assert all_values.size()[0] == self.value_size
query = torch_utils.linear_layer(query, self.query_weights)
return query, all_keys, all_values
def forward(self, query, keys, values=None):
if not values:
values = keys
query_t, keys_t, values_t = self.transform_arguments(query, keys, values)
scores = torch.t(torch.mm(query_t,keys_t)) # len(key) x len(query)
distribution = F.softmax(scores, dim=0) # len(key) x len(query)
context_vector = torch.mm(values_t, distribution).squeeze() # value_size x len(query)
return AttentionResult(scores, distribution, context_vector)
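# Usage note (added for clarity, not part of the original file): `query` is a
# 1-D tensor of size query_size and `keys`/`values` are lists of 1-D tensors of
# size key_size/value_size. forward() returns an AttentionResult whose
# `distribution` is a softmax over the keys and whose `vector` is the
# attention-weighted sum of the values.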

View File

@ -0,0 +1,97 @@
import heapq
import math
import numpy as np
from typing import Callable, Dict, Tuple, List
import torch
class BeamSearchState:
def __init__(self, sequence=None, log_probability=0, state=None):
# avoid shared mutable default arguments
self.sequence = sequence if sequence is not None else []
self.log_probability = log_probability
self.state = state if state is not None else {}
def __lt__(self, other):
return self.log_probability < other.log_probability
def __str__(self):
return f"{self.log_probability}: {' '.join(self.sequence)}"
class BeamSearch:
"""
Implements the beam search algorithm for decoding the most likely sequences.
"""
def __init__(
self,
is_end_of_sequence: Callable[[int, str], bool],
max_steps: int = 50,
beam_size: int = 10,
per_node_beam_size: int = None,
) -> None:
self.is_end_of_sequence = is_end_of_sequence
self.max_steps = max_steps
self.beam_size = beam_size
self.per_node_beam_size = per_node_beam_size
if not per_node_beam_size:
self.per_node_beam_size = beam_size
self.beam = []
def append(self, log_probability, new_state):
if len(self.beam) < self.beam_size:
heapq.heappush(self.beam, (log_probability, new_state))
else:
heapq.heapreplace(self.beam, (log_probability, new_state))
def get_largest(self, beam):
return list(heapq.nlargest(1, beam))[0]
def search(
self,
start_state: Dict,
step_function: Callable[[Dict], Tuple[List[str], torch.Tensor, Tuple]],
append_token_function: Callable[[Tuple, Tuple, str, float], Tuple]
) -> Tuple[List, float, Tuple]:
state = BeamSearchState(state=start_state)
heapq.heappush(self.beam, (0, state))
done = False
while not done:
current_beam = self.beam
self.beam = []
found_non_terminal = False
for (log_probability, state) in current_beam:
if self.is_end_of_sequence(state.state, self.max_steps, state.sequence):
self.append(log_probability, state)
else:
found_non_terminal = True
tokens, token_probabilities, extra_state = step_function(state.state)
tps = np.array(token_probabilities)
top_indices = heapq.nlargest(self.per_node_beam_size, range(len(tps)), tps.take)
top_tokens = [tokens[i] for i in top_indices]
top_probabilities = [token_probabilities[i] for i in top_indices]
sequence_length = len(state.sequence)
for i in range(len(top_tokens)):
if top_probabilities[i] > 1e-6:
#lp = ((log_probability * sequence_length) + math.log(top_probabilities[i])) / (sequence_length + 1.0)
lp = log_probability + math.log(top_probabilities[i])
t = top_tokens[i]
new_state = append_token_function(state.state, extra_state, t, lp)
new_sequence = state.sequence.copy()
new_sequence.append(t)
ns = BeamSearchState(sequence=new_sequence, log_probability=lp, state=new_state)
self.append(lp, ns)
if not found_non_terminal:
done = True
output = self.get_largest(self.beam)[1]
return output.sequence, output.log_probability, output.state, self.beam
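# Usage sketch (added for clarity, not part of the original file): the decoder
# drives the search through three callbacks.
#   beam = BeamSearch(is_end_of_sequence=..., max_steps=50, beam_size=10)
#   sequence, log_prob, final_state, full_beam = beam.search(
#       start_state=initial_state,
#       step_function=...,          # state -> (tokens, token_probabilities, extra_state)
#       append_token_function=...)  # (state, extra_state, token, log_prob) -> new_state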

View File

@ -0,0 +1,336 @@
""" Decoder for the SQL generation problem."""
from collections import namedtuple
import copy
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import torch_utils
from .token_predictor import PredictionInput, PredictionInputWithSchema, PredictionStepInputWithSchema
from .beam_search import BeamSearch
import data_util.snippets as snippet_handler
from . import embedder
from data_util.vocabulary import EOS_TOK, UNK_TOK
import math
def flatten_distribution(distribution_map, probabilities):
""" Flattens a probability distribution given a map of "unique" values.
All values in distribution_map with the same value should get the sum
of the probabilities.
Arguments:
distribution_map (list of str): List of values to get the probability for.
probabilities (np.ndarray): Probabilities corresponding to the values in
distribution_map.
Returns:
list, np.ndarray of the same size where probabilities for duplicates
in distribution_map are given the sum of the probabilities in probabilities.
"""
assert len(distribution_map) == len(probabilities)
if len(distribution_map) != len(set(distribution_map)):
idx_first_dup = 0
seen_set = set()
for i, tok in enumerate(distribution_map):
if tok in seen_set:
idx_first_dup = i
break
seen_set.add(tok)
new_dist_map = distribution_map[:idx_first_dup] + list(
set(distribution_map) - set(distribution_map[:idx_first_dup]))
assert len(new_dist_map) == len(set(new_dist_map))
new_probs = np.array(
probabilities[:idx_first_dup] \
+ [0. for _ in range(len(set(distribution_map)) \
- idx_first_dup)])
assert len(new_probs) == len(new_dist_map)
for i, token_name in enumerate(
distribution_map[idx_first_dup:]):
if token_name not in new_dist_map:
new_dist_map.append(token_name)
new_index = new_dist_map.index(token_name)
new_probs[new_index] += probabilities[i +
idx_first_dup]
new_probs = new_probs.tolist()
else:
new_dist_map = distribution_map
new_probs = probabilities
assert len(new_dist_map) == len(new_probs)
return new_dist_map, new_probs
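# Illustrative example (added for clarity, not part of the original file):
#   flatten_distribution(['a', 'b', 'a'], [0.2, 0.5, 0.3]) -> (['a', 'b'], [0.5, 0.5])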
class SQLPrediction(namedtuple('SQLPrediction',
('predictions',
'sequence',
'probability',
'beam'))):
"""Contains prediction for a sequence."""
__slots__ = ()
def __str__(self):
return str(self.probability) + "\t" + " ".join(self.sequence)
Count = 0
class SequencePredictorWithSchema(torch.nn.Module):
""" Predicts a sequence.
Attributes:
lstms (list of dy.RNNBuilder): The RNN used.
token_predictor (TokenPredictor): Used to actually predict tokens.
"""
def __init__(self,
params,
input_size,
output_embedder,
column_name_token_embedder,
token_predictor):
super().__init__()
self.lstms = torch_utils.create_multilayer_lstm_params(params.decoder_num_layers, input_size, params.decoder_state_size, "LSTM-d")
self.token_predictor = token_predictor
self.output_embedder = output_embedder
self.column_name_token_embedder = column_name_token_embedder
self.start_token_embedding = torch_utils.add_params((params.output_embedding_size,), "y-0")
self.input_size = input_size
self.params = params
self.params.use_turn_num = False
if self.params.use_turn_num:
assert params.decoder_num_layers <= 2
state_size = params.decoder_state_size
self.layer1_c_0 = nn.Linear(state_size+5, state_size)
self.layer1_h_0 = nn.Linear(state_size+5, state_size)
if params.decoder_num_layers == 2:
self.layer2_c_0 = nn.Linear(state_size+5, state_size)
self.layer2_h_0 = nn.Linear(state_size+5, state_size)
def _initialize_decoder_lstm(self, encoder_state, utterance_index):
decoder_lstm_states = []
#print('utterance_index', utterance_index)
#print('self.params.decoder_num_layers', self.params.decoder_num_layers)
if self.params.use_turn_num:
if utterance_index >= 4:
utterance_index = 4
turn_number_embedding = torch.zeros([1, 5]).cuda()
turn_number_embedding[0, utterance_index] = 1
#print('turn_number_embedding', turn_number_embedding)
for i, lstm in enumerate(self.lstms):
encoder_layer_num = 0
if len(encoder_state[0]) > 1:
encoder_layer_num = i
# check which one is h_0, which is c_0
c_0 = encoder_state[0][encoder_layer_num].view(1,-1)
h_0 = encoder_state[1][encoder_layer_num].view(1,-1)
if self.params.use_turn_num:
if i == 0:
c_0 = self.layer1_c_0( torch.cat([c_0, turn_number_embedding], -1) )
h_0 = self.layer1_h_0( torch.cat([h_0, turn_number_embedding], -1) )
else:
c_0 = self.layer2_c_0( torch.cat([c_0, turn_number_embedding], -1) )
h_0 = self.layer2_h_0( torch.cat([h_0, turn_number_embedding], -1) )
decoder_lstm_states.append((h_0, c_0))
return decoder_lstm_states
def get_output_token_embedding(self, output_token, input_schema, snippets):
if self.params.use_snippets and snippet_handler.is_snippet(output_token):
output_token_embedding = embedder.bow_snippets(output_token, snippets, self.output_embedder, input_schema)
else:
if input_schema:
assert self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
if self.output_embedder.in_vocabulary(output_token):
output_token_embedding = self.output_embedder(output_token)
else:
output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True)
else:
output_token_embedding = self.output_embedder(output_token)
return output_token_embedding
def get_decoder_input(self, output_token_embedding, prediction):
if self.params.use_schema_attention and self.params.use_query_attention:
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector, prediction.query_attention_results.vector], dim=0)
elif self.params.use_schema_attention:
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector], dim=0)
else:
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector], dim=0)
return decoder_input
def forward(self,
final_encoder_state,
encoder_states,
schema_states,
max_generation_length,
utterance_index=None,
snippets=None,
gold_sequence=None,
input_sequence=None,
previous_queries=None,
previous_query_states=None,
input_schema=None,
dropout_amount=0.):
""" Generates a sequence. """
global Count
index = 0
context_vector_size = self.input_size - self.params.output_embedding_size
# Decoder states: just the initialized decoder.
# Current input to decoder: phi(start_token) ; zeros the size of the
# context vector
predictions = []
sequence = []
probability = 1.
decoder_states = self._initialize_decoder_lstm(final_encoder_state, utterance_index)
if self.start_token_embedding.is_cuda:
decoder_input = torch.cat([self.start_token_embedding, torch.cuda.FloatTensor(context_vector_size).fill_(0)], dim=0)
else:
decoder_input = torch.cat([self.start_token_embedding, torch.zeros(context_vector_size)], dim=0)
beam_search = BeamSearch(is_end_of_sequence=self.is_end_of_sequence, max_steps=max_generation_length, beam_size=10)
prediction_step_input = PredictionStepInputWithSchema(
index=index,
decoder_input=decoder_input,
decoder_states=decoder_states,
encoder_states=encoder_states,
schema_states=schema_states,
snippets=snippets,
gold_sequence=gold_sequence,
input_sequence=input_sequence,
previous_queries=previous_queries,
previous_query_states=previous_query_states,
input_schema=input_schema,
dropout_amount=dropout_amount,
predictions=predictions,
)
sequence, probability, final_state, beam = beam_search.search(start_state=prediction_step_input, step_function=self.beam_search_step_function, append_token_function=self.beam_search_append_token)
sum_p = 0
for x in beam:
# print('math.exp(x[0])', math.exp(x[0]))
#print('x[1]', x[1].sequence)
sum_p += math.exp(x[0])
#print('sum_p', sum_p)
assert sum_p <= 1.2
'''if len(sequence) <= 2:
print('beam', beam)
for x in beam:
print('x[0]', x[0])
print('x[1]', x[1].sequence)
print('='*20)
Count+=1
assert Count <= 10'''
return SQLPrediction(final_state.predictions,
sequence,
probability,
beam)
def beam_search_step_function(self, prediction_step_input):
# get decoded token probabilities
prediction, tokens, token_probabilities, decoder_states = self.decode_next_token(prediction_step_input)
return tokens, token_probabilities, (prediction, decoder_states)
def beam_search_append_token(self, prediction_step_input, step_function_output, token_to_append, token_log_probability):
output_token_embedding = self.get_output_token_embedding(token_to_append, prediction_step_input.input_schema, prediction_step_input.snippets)
decoder_input = self.get_decoder_input(output_token_embedding, step_function_output[0])
predictions = prediction_step_input.predictions.copy()
predictions.append(step_function_output[0])
new_state = prediction_step_input._replace(
index=prediction_step_input.index+1,
decoder_input=decoder_input,
predictions=predictions,
decoder_states=step_function_output[1],
)
return new_state
def is_end_of_sequence(self, prediction_step_input, max_generation_length, sequence):
if prediction_step_input.gold_sequence:
return prediction_step_input.index >= len(prediction_step_input.gold_sequence)
else:
return prediction_step_input.index >= max_generation_length or (len(sequence) > 0 and sequence[-1] == EOS_TOK)
def decode_next_token(self, prediction_step_input):
(index,
decoder_input,
decoder_states,
encoder_states,
schema_states,
snippets,
gold_sequence,
input_sequence,
previous_queries,
previous_query_states,
input_schema,
dropout_amount,
predictions) = prediction_step_input
_, decoder_state, decoder_states = torch_utils.forward_one_multilayer(self.lstms, decoder_input, decoder_states, dropout_amount)
prediction_input = PredictionInputWithSchema(decoder_state=decoder_state,
input_hidden_states=encoder_states,
schema_states=schema_states,
snippets=snippets,
input_sequence=input_sequence,
previous_queries=previous_queries,
previous_query_states=previous_query_states,
input_schema=input_schema)
prediction = self.token_predictor(prediction_input, dropout_amount=dropout_amount)
#gold_sequence = None
if gold_sequence:
decoded_token = gold_sequence[index]
distribution_map = [decoded_token]
probabilities = [1.0]
else:
assert prediction.scores.dim() == 1
probabilities = F.softmax(prediction.scores, dim=0).cpu().data.numpy().tolist()
distribution_map = prediction.aligned_tokens
assert len(probabilities) == len(distribution_map)
if self.params.use_previous_query and self.params.use_copy_switch and len(previous_queries) > 0:
assert prediction.query_scores.dim() == 1
query_token_probabilities = F.softmax(prediction.query_scores, dim=0).cpu().data.numpy().tolist()
query_token_distribution_map = prediction.query_tokens
assert len(query_token_probabilities) == len(query_token_distribution_map)
copy_switch = prediction.copy_switch.cpu().data.numpy()
# Merge the two
probabilities = ((np.array(probabilities) * (1 - copy_switch)).tolist() +
(np.array(query_token_probabilities) * copy_switch).tolist()
)
distribution_map = distribution_map + query_token_distribution_map
assert len(probabilities) == len(distribution_map)
# Get a new probabilities and distribution_map consolidating duplicates
distribution_map, probabilities = flatten_distribution(distribution_map, probabilities)
# Modify the probability distribution so that the UNK token can never be produced
probabilities[distribution_map.index(UNK_TOK)] = 0.
return prediction, distribution_map, probabilities, decoder_states

View File

@ -0,0 +1,100 @@
""" Embedder for tokens. """
import torch
import torch.nn.functional as F
import data_util.snippets as snippet_handler
import data_util.vocabulary as vocabulary_handler
class Embedder(torch.nn.Module):
""" Embeds tokens. """
def __init__(self, embedding_size, name="", initializer=None, vocabulary=None, num_tokens=-1, anonymizer=None, freeze=False, use_unk=True):
super().__init__()
if vocabulary:
assert num_tokens < 0, "Specified a vocabulary but also set number of tokens to " + \
str(num_tokens)
self.in_vocabulary = lambda token: token in vocabulary.tokens
self.vocab_token_lookup = lambda token: vocabulary.token_to_id(token)
if use_unk:
self.unknown_token_id = vocabulary.token_to_id(vocabulary_handler.UNK_TOK)
else:
self.unknown_token_id = -1
self.vocabulary_size = len(vocabulary)
else:
def check_vocab(index):
""" Makes sure the index is in the vocabulary."""
assert index < num_tokens, "Passed token ID " + \
str(index) + "; expecting something less than " + str(num_tokens)
return index < num_tokens
self.in_vocabulary = check_vocab
self.vocab_token_lookup = lambda x: x
self.unknown_token_id = num_tokens  # deliberately out of range: an UNK lookup
# without a vocabulary should fail before this id is ever used
self.vocabulary_size = num_tokens
self.anonymizer = anonymizer
emb_name = name + "-tokens"
print("Creating token embedder called " + emb_name + " of size " + str(self.vocabulary_size) + " x " + str(embedding_size))
if initializer is not None:
word_embeddings_tensor = torch.FloatTensor(initializer)
self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(word_embeddings_tensor, freeze=freeze)
else:
init_tensor = torch.empty(self.vocabulary_size, embedding_size).uniform_(-0.1, 0.1)
self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False)
if self.anonymizer:
emb_name = name + "-entities"
entity_size = len(self.anonymizer.entity_types)
print("Creating entity embedder called " + emb_name + " of size " + str(entity_size) + " x " + str(embedding_size))
init_tensor = torch.empty(entity_size, embedding_size).uniform_(-0.1, 0.1)
self.entity_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False)
def forward(self, token):
assert isinstance(token, int) or not snippet_handler.is_snippet(token), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets"
if self.in_vocabulary(token):
index_list = torch.LongTensor([self.vocab_token_lookup(token)])
if self.token_embedding_matrix.weight.is_cuda:
index_list = index_list.cuda()
return self.token_embedding_matrix(index_list).squeeze()
elif self.anonymizer and self.anonymizer.is_anon_tok(token):
index_list = torch.LongTensor([self.anonymizer.get_anon_id(token)])
if self.token_embedding_matrix.weight.is_cuda:
index_list = index_list.cuda()
return self.entity_embedding_matrix(index_list).squeeze()
else:
index_list = torch.LongTensor([self.unknown_token_id])
if self.token_embedding_matrix.weight.is_cuda:
index_list = index_list.cuda()
return self.token_embedding_matrix(index_list).squeeze()
def bow_snippets(token, snippets, output_embedder, input_schema):
""" Bag of words embedding for snippets"""
assert snippet_handler.is_snippet(token) and snippets
snippet_sequence = []
for snippet in snippets:
if snippet.name == token:
snippet_sequence = snippet.sequence
break
assert snippet_sequence
if input_schema:
snippet_embeddings = []
for output_token in snippet_sequence:
assert output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
if output_embedder.in_vocabulary(output_token):
snippet_embeddings.append(output_embedder(output_token))
else:
snippet_embeddings.append(input_schema.column_name_embedder(output_token, surface_form=True))
else:
snippet_embeddings = [output_embedder(subtoken) for subtoken in snippet_sequence]
snippet_embeddings = torch.stack(snippet_embeddings, dim=0) # len(snippet_sequence) x emb_size
return torch.mean(snippet_embeddings, dim=0) # emb_size

View File

@ -0,0 +1,60 @@
""" Contains code for encoding an input sequence. """
import torch
import torch.nn.functional as F
from .torch_utils import create_multilayer_lstm_params, encode_sequence
class Encoder(torch.nn.Module):
""" Encodes an input sequence. """
def __init__(self, num_layers, input_size, state_size):
super().__init__()
self.num_layers = num_layers
self.forward_lstms = create_multilayer_lstm_params(self.num_layers, input_size, state_size / 2, "LSTM-ef")
self.backward_lstms = create_multilayer_lstm_params(self.num_layers, input_size, state_size / 2, "LSTM-eb")
def forward(self, sequence, embedder, dropout_amount=0.):
""" Encodes a sequence forward and backward.
Inputs:
sequence (list of str): The sequence of tokens to encode.
embedder (callable): Maps a token to its embedding vector.
dropout_amount (float, optional): The amount of dropout to apply.
Returns:
(list of torch.Tensor, list of torch.Tensor), list of torch.Tensor,
where the first pair is the (final cell memories, final hidden states) of
all layers, with the forward and backward halves concatenated, and the
second list holds the final layer's output for each token in the sequence.
"""
forward_state, forward_outputs = encode_sequence(
sequence,
self.forward_lstms,
embedder,
dropout_amount=dropout_amount)
backward_state, backward_outputs = encode_sequence(
sequence[::-1],
self.backward_lstms,
embedder,
dropout_amount=dropout_amount)
cell_memories = []
hidden_states = []
for i in range(self.num_layers):
cell_memories.append(torch.cat([forward_state[0][i], backward_state[0][i]], dim=0))
hidden_states.append(torch.cat([forward_state[1][i], backward_state[1][i]], dim=0))
assert len(forward_outputs) == len(backward_outputs)
backward_outputs = backward_outputs[::-1]
final_outputs = []
for i in range(len(sequence)):
final_outputs.append(torch.cat([forward_outputs[i], backward_outputs[i]], dim=0))
return (cell_memories, hidden_states), final_outputs
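# Minimal usage sketch (hypothetical sizes, not taken from any config):
#   encoder = Encoder(num_layers=1, input_size=300, state_size=600)
#   (cell_memories, hidden_states), outputs = encoder(tokens, embedder_fn)
# Each element of `outputs` concatenates the forward and backward LSTM outputs
# for one token, so it has dimension state_size.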

394
r2sql/cosql/model/model.py Normal file

@ -0,0 +1,394 @@
""" Class for the Sequence to sequence model for ATIS."""
import os
import torch
import torch.nn.functional as F
from . import torch_utils
from . import utils_bert
from data_util.vocabulary import DEL_TOK, UNK_TOK
from .encoder import Encoder
from .embedder import Embedder
from .token_predictor import construct_token_predictor
import numpy as np
from data_util.atis_vocab import ATISVocabulary
def get_token_indices(token, index_to_token):
""" Maps from a gold token (string) to a list of indices.
Inputs:
token (string): String to look up.
index_to_token (list of tokens): Ordered list of tokens.
Returns:
list of int, representing the indices of the token in the probability
distribution.
"""
if token in index_to_token:
if len(set(index_to_token)) == len(index_to_token): # no duplicates
return [index_to_token.index(token)]
else:
indices = []
for index, other_token in enumerate(index_to_token):
if token == other_token:
indices.append(index)
assert len(indices) == len(set(indices))
return indices
else:
return [index_to_token.index(UNK_TOK)]
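# Example: with index_to_token = ["select", "from", "select"], the gold token
# "select" maps to [0, 2]; a token not in the list maps to the index of UNK_TOK.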
def flatten_utterances(utterances):
""" Gets a flat sequence from a sequence of utterances.
Inputs:
utterances (list of list of str): Utterances to concatenate.
Returns:
list of str, representing the flattened sequence with separating
delimiter tokens.
"""
sequence = []
for i, utterance in enumerate(utterances):
sequence.extend(utterance)
if i < len(utterances) - 1:
sequence.append(DEL_TOK)
return sequence
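# Example: [["show", "flights"], ["to", "boston"]] becomes
# ["show", "flights", DEL_TOK, "to", "boston"].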
def encode_snippets_with_states(snippets, states):
""" Encodes snippets by using previous query states instead.
Inputs:
snippets (list of Snippet): Input snippets.
states (list of torch.Tensor): Previous hidden states to use.
TODO: should this be torch.Tensor or vector values?
"""
for snippet in snippets:
snippet.set_embedding(torch.cat([states[snippet.startpos],states[snippet.endpos]], dim=0))
return snippets
def load_word_embeddings(input_vocabulary, output_vocabulary, output_vocabulary_schema, params):
print(output_vocabulary.inorder_tokens)
print()
def read_glove_embedding(embedding_filename, embedding_size):
glove_embeddings = {}
with open(embedding_filename) as f:
cnt = 1
for line in f:
cnt += 1
if params.debug or not params.train:
if cnt == 1000:
print('Read 1000 word embeddings')
break
l_split = line.split()
word = " ".join(l_split[0:len(l_split) - embedding_size])
embedding = np.array([float(val) for val in l_split[-embedding_size:]])
glove_embeddings[word] = embedding
return glove_embeddings
print('Loading Glove Embedding from', params.embedding_filename)
glove_embedding_size = 300
# glove_embeddings = read_glove_embedding(params.embedding_filename, glove_embedding_size)
import pickle as pkl
# pkl.dump(glove_embeddings, open("glove_embeddings.pkl", "wb"))
# exit()
glove_embeddings = pkl.load(open("glove_embeddings.pkl", "rb"))
print('Done')
input_embedding_size = glove_embedding_size
def create_word_embeddings(vocab):
vocabulary_embeddings = np.zeros((len(vocab), glove_embedding_size), dtype=np.float32)
vocabulary_tokens = vocab.inorder_tokens
glove_oov = 0
para_oov = 0
for token in vocabulary_tokens:
token_id = vocab.token_to_id(token)
if token in glove_embeddings:
vocabulary_embeddings[token_id][:glove_embedding_size] = glove_embeddings[token]
else:
glove_oov += 1
print('Glove OOV:', glove_oov, 'Para OOV', para_oov, 'Total', len(vocab))
return vocabulary_embeddings
input_vocabulary_embeddings = create_word_embeddings(input_vocabulary)
output_vocabulary_embeddings = create_word_embeddings(output_vocabulary)
output_vocabulary_schema_embeddings = None
if output_vocabulary_schema:
output_vocabulary_schema_embeddings = create_word_embeddings(output_vocabulary_schema)
del glove_embeddings
return input_vocabulary_embeddings, output_vocabulary_embeddings, output_vocabulary_schema_embeddings, input_embedding_size
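# Note: as written, this function expects a pre-built "glove_embeddings.pkl"
# cache in the working directory; the commented-out read_glove_embedding call
# shows how that cache can be regenerated from the raw GloVe text file given
# by params.embedding_filename.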
class ATISModel(torch.nn.Module):
""" Sequence-to-sequence model for predicting a SQL query given an utterance
and an interaction prefix.
"""
def __init__(
self,
params,
input_vocabulary,
output_vocabulary,
output_vocabulary_schema,
anonymizer):
super().__init__()
self.params = params
if params.use_bert:
self.model_bert, self.tokenizer, self.bert_config = utils_bert.get_bert(params)
if 'atis' not in params.data_directory:
if params.use_bert:
input_vocabulary_embeddings, output_vocabulary_embeddings, output_vocabulary_schema_embeddings, input_embedding_size = load_word_embeddings(input_vocabulary, output_vocabulary, output_vocabulary_schema, params)
# Create the output embeddings
self.output_embedder = Embedder(params.output_embedding_size,
name="output-embedding",
initializer=output_vocabulary_embeddings,
vocabulary=output_vocabulary,
anonymizer=anonymizer,
freeze=False)
self.column_name_token_embedder = None
else:
input_vocabulary_embeddings, output_vocabulary_embeddings, output_vocabulary_schema_embeddings, input_embedding_size = load_word_embeddings(input_vocabulary, output_vocabulary, output_vocabulary_schema, params)
params.input_embedding_size = input_embedding_size
self.params.input_embedding_size = input_embedding_size
# Create the input embeddings
self.input_embedder = Embedder(params.input_embedding_size,
name="input-embedding",
initializer=input_vocabulary_embeddings,
vocabulary=input_vocabulary,
anonymizer=anonymizer,
freeze=params.freeze)
# Create the output embeddings
self.output_embedder = Embedder(params.output_embedding_size,
name="output-embedding",
initializer=output_vocabulary_embeddings,
vocabulary=output_vocabulary,
anonymizer=anonymizer,
freeze=False)
self.column_name_token_embedder = Embedder(params.input_embedding_size,
name="schema-embedding",
initializer=output_vocabulary_schema_embeddings,
vocabulary=output_vocabulary_schema,
anonymizer=anonymizer,
freeze=params.freeze)
else:
# Create the input embeddings
self.input_embedder = Embedder(params.input_embedding_size,
name="input-embedding",
vocabulary=input_vocabulary,
anonymizer=anonymizer,
freeze=False)
# Create the output embeddings
self.output_embedder = Embedder(params.output_embedding_size,
name="output-embedding",
vocabulary=output_vocabulary,
anonymizer=anonymizer,
freeze=False)
self.column_name_token_embedder = None
# Create the encoder
encoder_input_size = params.input_embedding_size
encoder_output_size = params.encoder_state_size
if params.use_bert:
encoder_input_size = self.bert_config.hidden_size
if params.discourse_level_lstm:
encoder_input_size += params.encoder_state_size / 2
self.utterance_encoder = Encoder(params.encoder_num_layers, encoder_input_size, encoder_output_size)
# Positional embedder for utterances
attention_key_size = params.encoder_state_size
self.schema_attention_key_size = attention_key_size
if params.state_positional_embeddings:
attention_key_size += params.positional_embedding_size
self.positional_embedder = Embedder(
params.positional_embedding_size,
name="positional-embedding",
num_tokens=params.maximum_utterances)
self.utterance_attention_key_size = attention_key_size
# Create the discourse-level LSTM parameters
if params.discourse_level_lstm:
self.discourse_lstms = torch_utils.create_multilayer_lstm_params(1, params.encoder_state_size, params.encoder_state_size / 2, "LSTM-t")
self.initial_discourse_state = torch_utils.add_params(tuple([params.encoder_state_size / 2]), "V-turn-state-0")
# Snippet encoder
final_snippet_size = 0
if params.use_snippets and not params.previous_decoder_snippet_encoding:
snippet_encoding_size = int(params.encoder_state_size / 2)
final_snippet_size = params.encoder_state_size
if params.snippet_age_embedding:
snippet_encoding_size -= int(
params.snippet_age_embedding_size / 4)
self.snippet_age_embedder = Embedder(
params.snippet_age_embedding_size,
name="snippet-age-embedding",
num_tokens=params.max_snippet_age_embedding)
final_snippet_size = params.encoder_state_size + params.snippet_age_embedding_size / 2
self.snippet_encoder = Encoder(params.snippet_num_layers,
params.output_embedding_size,
snippet_encoding_size)
# Previous query Encoder
if params.use_previous_query:
self.query_encoder = Encoder(params.encoder_num_layers, params.output_embedding_size, params.encoder_state_size)
self.final_snippet_size = final_snippet_size
self.dropout = 0.
def _encode_snippets(self, previous_query, snippets, input_schema):
""" Computes a single vector representation for each snippet.
Inputs:
previous_query (list of str): Previous query in the interaction.
snippets (list of Snippet): Snippets extracted from the previous query.
input_schema: The database schema for the interaction, if any.
Returns:
list of Snippets, where the embedding is set to a vector.
"""
startpoints = [snippet.startpos for snippet in snippets]
endpoints = [snippet.endpos for snippet in snippets]
assert len(startpoints) == 0 or min(startpoints) >= 0
if input_schema:
assert len(endpoints) == 0 or max(endpoints) <= len(previous_query)
else:
assert len(endpoints) == 0 or max(endpoints) < len(previous_query)
snippet_embedder = lambda query_token: self.get_query_token_embedding(query_token, input_schema)
if previous_query and snippets:
_, previous_outputs = self.snippet_encoder(
previous_query, snippet_embedder, dropout_amount=self.dropout)
assert len(previous_outputs) == len(previous_query)
for snippet in snippets:
if input_schema:
embedding = torch.cat([previous_outputs[snippet.startpos],previous_outputs[snippet.endpos-1]], dim=0)
else:
embedding = torch.cat([previous_outputs[snippet.startpos],previous_outputs[snippet.endpos]], dim=0)
if self.params.snippet_age_embedding:
embedding = torch.cat([embedding, self.snippet_age_embedder(min(snippet.age, self.params.max_snippet_age_embedding - 1))], dim=0)
snippet.set_embedding(embedding)
return snippets
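# Each snippet embedding is the concatenation of the snippet encoder's outputs
# at the snippet's start and end positions (end - 1 when a schema is present),
# optionally extended with a snippet-age embedding.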
def _initialize_discourse_states(self):
discourse_state = self.initial_discourse_state
discourse_lstm_states = []
for lstm in self.discourse_lstms:
hidden_size = lstm.weight_hh.size()[1]
if lstm.weight_hh.is_cuda:
h_0 = torch.cuda.FloatTensor(1,hidden_size).fill_(0)
c_0 = torch.cuda.FloatTensor(1,hidden_size).fill_(0)
else:
h_0 = torch.zeros(1,hidden_size)
c_0 = torch.zeros(1,hidden_size)
discourse_lstm_states.append((h_0, c_0))
return discourse_state, discourse_lstm_states
def _add_positional_embeddings(self, hidden_states, utterances, group=False):
grouped_states = []
start_index = 0
for utterance in utterances:
grouped_states.append(hidden_states[start_index:start_index + len(utterance)])
start_index += len(utterance)
assert len(hidden_states) == sum([len(seq) for seq in grouped_states]) == sum([len(utterance) for utterance in utterances])
new_states = []
flat_sequence = []
num_utterances_to_keep = min(self.params.maximum_utterances, len(utterances))
for i, (states, utterance) in enumerate(zip(
grouped_states[-num_utterances_to_keep:], utterances[-num_utterances_to_keep:])):
positional_sequence = []
index = num_utterances_to_keep - i - 1
for state in states:
positional_sequence.append(torch.cat([state, self.positional_embedder(index)], dim=0))
assert len(positional_sequence) == len(utterance), \
"Expected utterance and state sequence length to be the same, " \
+ "but they were " + str(len(utterance)) \
+ " and " + str(len(positional_sequence))
if group:
new_states.append(positional_sequence)
else:
new_states.extend(positional_sequence)
flat_sequence.extend(utterance)
return new_states, flat_sequence
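# Example: with maximum_utterances = 5 (illustrative) and a longer history,
# only the five most recent utterances are kept; the most recent utterance's
# states receive positional index 0, the one before it index 1, and so on.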
def build_optim(self):
params_trainer = []
params_bert_trainer = []
for name, param in self.named_parameters():
if param.requires_grad:
if 'model_bert' in name:
params_bert_trainer.append(param)
else:
params_trainer.append(param)
self.trainer = torch.optim.Adam(params_trainer, lr=self.params.initial_learning_rate)
if self.params.fine_tune_bert:
self.bert_trainer = torch.optim.Adam(params_bert_trainer, lr=self.params.lr_bert)
def set_dropout(self, value):
""" Sets the dropout to a specified value.
Inputs:
value (float): Value to set dropout to.
"""
self.dropout = value
def set_learning_rate(self, value):
""" Sets the learning rate for the trainer.
Inputs:
value (float): The new learning rate.
"""
for param_group in self.trainer.param_groups:
param_group['lr'] = value
def save(self, filename):
""" Saves the model to the specified filename.
Inputs:
filename (str): The filename to save to.
"""
torch.save(self.state_dict(), filename)
def load(self, filename):
""" Loads saved parameters into the parameter collection.
Inputs:
filename (str): Name of file containing parameters.
"""
self.load_state_dict(torch.load(filename))
print("Loaded model from file " + filename)


@ -0,0 +1,740 @@
""" Class for the Sequence to sequence model for ATIS."""
import torch
import torch.nn.functional as F
from . import torch_utils
import data_util.snippets as snippet_handler
import data_util.sql_util
import data_util.vocabulary as vocab
from data_util.vocabulary import EOS_TOK, UNK_TOK
import data_util.tokenizers
from .token_predictor import construct_token_predictor
from .attention import Attention
from .model import ATISModel, encode_snippets_with_states, get_token_indices
from data_util.utterance import ANON_INPUT_KEY
from .encoder import Encoder
from .transformer_layer.transformer_attention import TransformerAttention
from .decoder import SequencePredictorWithSchema
from . import utils_bert
import data_util.atis_batch
from postprocess_eval import read_schema, postprocess_one
from eval_scripts.process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
LIMITED_INTERACTIONS = {"raw/atis2/12-1.1/ATIS2/TEXT/TRAIN/SRI/QS0/1": 22,
"raw/atis3/17-1.1/ATIS3/SP_TRN/MIT/8K7/5": 14,
"raw/atis2/12-1.1/ATIS2/TEXT/TEST/NOV92/770/5": -1}
END_OF_INTERACTION = {"quit", "exit", "done"}
class SchemaInteractionATISModel(ATISModel):
""" Interaction ATIS model, where an interaction is processed all at once.
"""
def __init__(self,
params,
input_vocabulary,
output_vocabulary,
output_vocabulary_schema,
anonymizer):
ATISModel.__init__(
self,
params,
input_vocabulary,
output_vocabulary,
output_vocabulary_schema,
anonymizer)
if self.params.use_schema_encoder:
# Create the schema encoder
schema_encoder_num_layer = 1
schema_encoder_input_size = params.input_embedding_size
schema_encoder_state_size = params.encoder_state_size
if params.use_bert:
schema_encoder_input_size = self.bert_config.hidden_size
self.schema_encoder = Encoder(schema_encoder_num_layer, schema_encoder_input_size, schema_encoder_state_size)
# self-attention
if self.params.use_schema_self_attention:
self.schema2schema_attention_module = Attention(self.schema_attention_key_size, self.schema_attention_key_size, self.schema_attention_key_size)
# utterance level attention
if self.params.use_utterance_attention:
self.utterance_attention_module = Attention(self.params.encoder_state_size, self.params.encoder_state_size, self.params.encoder_state_size)
# Use attention module between input_hidden_states and schema_states
# schema_states: self.schema_attention_key_size x len(schema)
# input_hidden_states: self.utterance_attention_key_size x len(input)
if params.use_encoder_attention:
self.utterance2schema_attention_module = Attention(self.schema_attention_key_size, self.utterance_attention_key_size, self.utterance_attention_key_size)
self.schema2utterance_attention_module = Attention(self.utterance_attention_key_size, self.schema_attention_key_size, self.schema_attention_key_size)
new_attention_key_size = self.schema_attention_key_size + self.utterance_attention_key_size
if self.params.use_transformer_relation:
self.transoformer_attention_key_size = self.schema_attention_key_size
self.transoformer_attention_module = TransformerAttention(self.transoformer_attention_key_size, self.transoformer_attention_key_size)
new_attention_key_size += self.transoformer_attention_key_size
self.schema_attention_key_size = new_attention_key_size
self.utterance_attention_key_size = new_attention_key_size
if self.params.use_schema_encoder_2:
self.schema_encoder_2 = Encoder(schema_encoder_num_layer, self.schema_attention_key_size, self.schema_attention_key_size)
self.utterance_encoder_2 = Encoder(params.encoder_num_layers, self.utterance_attention_key_size, self.utterance_attention_key_size)
self.token_predictor = construct_token_predictor(params,
output_vocabulary,
self.utterance_attention_key_size,
self.schema_attention_key_size,
self.final_snippet_size,
anonymizer)
# Use schema_attention in decoder
if params.use_schema_attention and params.use_query_attention:
decoder_input_size = params.output_embedding_size + self.utterance_attention_key_size + self.schema_attention_key_size + params.encoder_state_size
elif params.use_schema_attention:
decoder_input_size = params.output_embedding_size + self.utterance_attention_key_size + self.schema_attention_key_size
else:
decoder_input_size = params.output_embedding_size + self.utterance_attention_key_size
self.decoder = SequencePredictorWithSchema(params, decoder_input_size, self.output_embedder, self.column_name_token_embedder, self.token_predictor)
table_schema_path = params.table_schema_path
self.database_schema = read_schema(table_schema_path)
def predict_turn(self,
utterance_final_state,
input_hidden_states,
schema_states,
max_generation_length,
gold_query=None,
snippets=None,
input_sequence=None,
previous_queries=None,
previous_query_states=None,
input_schema=None,
feed_gold_tokens=False,
training=False,
unflat_sequence=None):
""" Gets a prediction for a single turn -- calls decoder and updates loss, etc.
TODO: this can probably be split into two methods, one that just predicts
and another that computes the loss.
"""
predicted_sequence = []
fed_sequence = []
loss = None
token_accuracy = 0.
if self.params.use_encoder_attention:
schema_attention = self.utterance2schema_attention_module(torch.stack(schema_states,dim=0), input_hidden_states).vector # input_value_size x len(schema)
utterance_attention = self.schema2utterance_attention_module(torch.stack(input_hidden_states,dim=0), schema_states).vector # schema_value_size x len(input)
if schema_attention.dim() == 1:
schema_attention = schema_attention.unsqueeze(1)
if utterance_attention.dim() == 1:
utterance_attention = utterance_attention.unsqueeze(1)
if self.params.use_transformer_relation:
utterance_transoformer_attention, schema_transoformer_attention = self.transoformer_attention_module(torch.stack(input_hidden_states,dim=0 ), torch.stack(schema_states,dim=0), input_sequence, unflat_sequence, input_schema, self.schema, dropout_amount=self.dropout)
schema_attention = torch.cat([schema_transoformer_attention, schema_attention], dim=0)
utterance_attention = torch.cat([utterance_transoformer_attention, utterance_attention], dim=0)
new_schema_states = torch.cat([torch.stack(schema_states, dim=1), schema_attention], dim=0) # (input_value_size+schema_value_size) x len(schema)
schema_states = list(torch.split(new_schema_states, split_size_or_sections=1, dim=1))
schema_states = [schema_state.squeeze() for schema_state in schema_states]
new_input_hidden_states = torch.cat([torch.stack(input_hidden_states, dim=1), utterance_attention], dim=0) # (input_value_size+schema_value_size) x len(input)
input_hidden_states = list(torch.split(new_input_hidden_states, split_size_or_sections=1, dim=1))
input_hidden_states = [input_hidden_state.squeeze() for input_hidden_state in input_hidden_states]
# bi-lstm over schema_states and input_hidden_states (embedder is an identity function)
if self.params.use_schema_encoder_2:
final_schema_state, schema_states = self.schema_encoder_2(schema_states, lambda x: x, dropout_amount=self.dropout)
final_utterance_state, input_hidden_states = self.utterance_encoder_2(input_hidden_states, lambda x: x, dropout_amount=self.dropout)
if feed_gold_tokens:
decoder_results = self.decoder(utterance_final_state,
input_hidden_states,
schema_states,
max_generation_length,
gold_sequence=gold_query,
input_sequence=input_sequence,
previous_queries=previous_queries,
previous_query_states=previous_query_states,
input_schema=input_schema,
snippets=snippets,
dropout_amount=self.dropout)
all_scores = []
all_alignments = []
for prediction in decoder_results.predictions:
scores = F.softmax(prediction.scores, dim=0)
alignments = prediction.aligned_tokens
if self.params.use_previous_query and self.params.use_copy_switch and len(previous_queries) > 0:
query_scores = F.softmax(prediction.query_scores, dim=0)
copy_switch = prediction.copy_switch
scores = torch.cat([scores * (1 - copy_switch), query_scores * copy_switch], dim=0)
alignments = alignments + prediction.query_tokens
all_scores.append(scores)
all_alignments.append(alignments)
# Compute the loss
gold_sequence = gold_query
loss = torch_utils.compute_loss(gold_sequence, all_scores, all_alignments, get_token_indices)
if not training:
predicted_sequence = torch_utils.get_seq_from_scores(all_scores, all_alignments)
token_accuracy = torch_utils.per_token_accuracy(gold_sequence, predicted_sequence)
fed_sequence = gold_sequence
else:
decoder_results = self.decoder(utterance_final_state,
input_hidden_states,
schema_states,
max_generation_length,
input_sequence=input_sequence,
previous_queries=previous_queries,
previous_query_states=previous_query_states,
input_schema=input_schema,
snippets=snippets,
dropout_amount=self.dropout)
predicted_sequence = decoder_results.sequence
fed_sequence = predicted_sequence
decoder_states = [pred.decoder_state for pred in decoder_results.predictions]
# fed_sequence contains EOS, which we don't need when encoding snippets.
# also ignore the first state, as it contains the BEG encoding.
for token, state in zip(fed_sequence[:-1], decoder_states[1:]):
if snippet_handler.is_snippet(token):
snippet_length = 0
for snippet in snippets:
if snippet.name == token:
snippet_length = len(snippet.sequence)
break
assert snippet_length > 0
decoder_states.extend([state for _ in range(snippet_length)])
else:
decoder_states.append(state)
return (predicted_sequence,
loss,
token_accuracy,
decoder_states,
decoder_results)
def encode_schema_bow_simple(self, input_schema):
schema_states = []
for column_name in input_schema.column_names_embedder_input:
schema_states.append(input_schema.column_name_embedder_bow(column_name, surface_form=False, column_name_token_embedder=self.column_name_token_embedder))
input_schema.set_column_name_embeddings(schema_states)
return schema_states
def encode_schema_self_attention(self, schema_states):
schema_self_attention = self.schema2schema_attention_module(torch.stack(schema_states,dim=0), schema_states).vector
if schema_self_attention.dim() == 1:
schema_self_attention = schema_self_attention.unsqueeze(1)
residual_schema_states = list(torch.split(schema_self_attention, split_size_or_sections=1, dim=1))
residual_schema_states = [schema_state.squeeze() for schema_state in residual_schema_states]
new_schema_states = [schema_state+residual_schema_state for schema_state, residual_schema_state in zip(schema_states, residual_schema_states)]
return new_schema_states
def encode_schema(self, input_schema, dropout=False):
schema_states = []
for column_name_embedder_input in input_schema.column_names_embedder_input:
tokens = column_name_embedder_input.split()
if dropout:
final_schema_state_one, schema_states_one = self.schema_encoder(tokens, self.column_name_token_embedder, dropout_amount=self.dropout)
else:
final_schema_state_one, schema_states_one = self.schema_encoder(tokens, self.column_name_token_embedder)
# final_schema_state_one: 1 means hidden_states instead of cell_memories, -1 means last layer
schema_states.append(final_schema_state_one[1][-1])
input_schema.set_column_name_embeddings(schema_states)
# self-attention over schema_states
if self.params.use_schema_self_attention:
schema_states = self.encode_schema_self_attention(schema_states)
return schema_states
def get_bert_encoding(self, input_sequence, input_schema, discourse_state, dropout):
utterance_states, schema_token_states = utils_bert.get_bert_encoding(self.bert_config, self.model_bert, self.tokenizer, input_sequence, input_schema, bert_input_version=self.params.bert_input_version, num_out_layers_n=1, num_out_layers_h=1)
if self.params.discourse_level_lstm:
utterance_token_embedder = lambda x: torch.cat([x, discourse_state], dim=0)
else:
utterance_token_embedder = lambda x: x
if dropout:
final_utterance_state, utterance_states = self.utterance_encoder(
utterance_states,
utterance_token_embedder,
dropout_amount=self.dropout)
else:
final_utterance_state, utterance_states = self.utterance_encoder(
utterance_states,
utterance_token_embedder)
schema_states = []
for schema_token_states1 in schema_token_states:
if dropout:
final_schema_state_one, schema_states_one = self.schema_encoder(schema_token_states1, lambda x: x, dropout_amount=self.dropout)
else:
final_schema_state_one, schema_states_one = self.schema_encoder(schema_token_states1, lambda x: x)
# final_schema_state_one: 1 means hidden_states instead of cell_memories, -1 means last layer
schema_states.append(final_schema_state_one[1][-1])
input_schema.set_column_name_embeddings(schema_states)
# self-attention over schema_states
if self.params.use_schema_self_attention:
schema_states = self.encode_schema_self_attention(schema_states)
return final_utterance_state, utterance_states, schema_states
def get_query_token_embedding(self, output_token, input_schema):
if input_schema:
if not (self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)):
output_token = 'value'
if self.output_embedder.in_vocabulary(output_token):
output_token_embedding = self.output_embedder(output_token)
else:
output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True)
else:
output_token_embedding = self.output_embedder(output_token)
return output_token_embedding
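# Query tokens that are in neither the output vocabulary nor the schema's
# surface forms are backed off to the generic 'value' token before embedding.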
def get_utterance_attention(self, final_utterance_states_c, final_utterance_states_h, final_utterance_state, num_utterances_to_keep):
# self-attention between utterance_states
final_utterance_states_c.append(final_utterance_state[0][0])
final_utterance_states_h.append(final_utterance_state[1][0])
final_utterance_states_c = final_utterance_states_c[-num_utterances_to_keep:]
final_utterance_states_h = final_utterance_states_h[-num_utterances_to_keep:]
attention_result = self.utterance_attention_module(final_utterance_states_c[-1], final_utterance_states_c)
final_utterance_state_attention_c = final_utterance_states_c[-1] + attention_result.vector.squeeze()
attention_result = self.utterance_attention_module(final_utterance_states_h[-1], final_utterance_states_h)
final_utterance_state_attention_h = final_utterance_states_h[-1] + attention_result.vector.squeeze()
final_utterance_state = ([final_utterance_state_attention_c],[final_utterance_state_attention_h])
return final_utterance_states_c, final_utterance_states_h, final_utterance_state
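# The turn-level attention is residual: the most recent final cell/hidden
# state attends over the last num_utterances_to_keep states, and the attention
# vector is added back onto that most recent state.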
def get_previous_queries(self, previous_queries, previous_query_states, previous_query, input_schema):
previous_queries.append(previous_query)
num_queries_to_keep = min(self.params.maximum_queries, len(previous_queries))
previous_queries = previous_queries[-num_queries_to_keep:]
query_token_embedder = lambda query_token: self.get_query_token_embedding(query_token, input_schema)
_, previous_outputs = self.query_encoder(previous_query, query_token_embedder, dropout_amount=self.dropout)
assert len(previous_outputs) == len(previous_query)
previous_query_states.append(previous_outputs)
previous_query_states = previous_query_states[-num_queries_to_keep:]
return previous_queries, previous_query_states
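# Only the most recent maximum_queries previous queries (and their encoder
# outputs) are retained for previous-query attention and copying.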
def train_step(self, interaction, max_generation_length, snippet_alignment_probability=1.):
""" Trains the interaction-level model on a single interaction.
Inputs:
interaction (Interaction): The interaction to train on.
max_generation_length (int): Maximum number of tokens to generate for a query.
snippet_alignment_probability (float): The probability that a snippet will
be used in constructing the gold sequence.
"""
# assert self.params.discourse_level_lstm
losses = []
total_gold_tokens = 0
input_hidden_states = []
input_sequences = []
final_utterance_states_c = []
final_utterance_states_h = []
previous_query_states = []
previous_queries = []
decoder_states = []
discourse_state = None
if self.params.discourse_level_lstm:
discourse_state, discourse_lstm_states = self._initialize_discourse_states()
discourse_states = []
# Schema and schema embeddings
input_schema = interaction.get_schema()
schema_states = []
if input_schema and not self.params.use_bert:
schema_states = self.encode_schema_bow_simple(input_schema)
identifier = interaction.identifier
if len(identifier.split('/')) == 2:
database_id, interaction_id = identifier.split('/')
else:
database_id = 'atis'
interaction_id = identifier
self.schema = self.database_schema[database_id]
for utterance_index, utterance in enumerate(interaction.gold_utterances()):
if interaction.identifier in LIMITED_INTERACTIONS and utterance_index > LIMITED_INTERACTIONS[interaction.identifier]:
break
input_sequence = utterance.input_sequence()
available_snippets = utterance.snippets()
previous_query = utterance.previous_query()
# Get the gold query: reconstruct if the alignment probability is less than one
if snippet_alignment_probability < 1.:
gold_query = sql_util.add_snippets_to_query(
available_snippets,
utterance.contained_entities(),
utterance.anonymized_gold_query(),
prob_align=snippet_alignment_probability) + [vocab.EOS_TOK]
else:
gold_query = utterance.gold_query()
# Encode the utterance, and update the discourse-level states
if not self.params.use_bert:
if self.params.discourse_level_lstm:
utterance_token_embedder = lambda token: torch.cat([self.input_embedder(token), discourse_state], dim=0)
else:
utterance_token_embedder = self.input_embedder
final_utterance_state, utterance_states = self.utterance_encoder(
input_sequence,
utterance_token_embedder,
dropout_amount=self.dropout)
else:
final_utterance_state, utterance_states, schema_states = self.get_bert_encoding(input_sequence, input_schema, discourse_state, dropout=True)
input_hidden_states.extend(utterance_states)
input_sequences.append(input_sequence)
num_utterances_to_keep = min(self.params.maximum_utterances, len(input_sequences))
# final_utterance_state[1][0] is the first layer's hidden states at the last time step (concat forward lstm and backward lstm)
if self.params.discourse_level_lstm:
_, discourse_state, discourse_lstm_states = torch_utils.forward_one_multilayer(self.discourse_lstms, final_utterance_state[1][0], discourse_lstm_states, self.dropout)
if self.params.use_utterance_attention:
final_utterance_states_c, final_utterance_states_h, final_utterance_state = self.get_utterance_attention(final_utterance_states_c, final_utterance_states_h, final_utterance_state, num_utterances_to_keep)
if self.params.state_positional_embeddings:
utterance_states, flat_sequence = self._add_positional_embeddings(input_hidden_states, input_sequences)
else:
flat_sequence = []
for utt in input_sequences[-num_utterances_to_keep:]:
flat_sequence.extend(utt)
snippets = None
if self.params.use_snippets:
if self.params.previous_decoder_snippet_encoding:
snippets = encode_snippets_with_states(available_snippets, decoder_states)
else:
snippets = self._encode_snippets(previous_query, available_snippets, input_schema)
if self.params.use_previous_query and len(previous_query) > 0:
previous_queries, previous_query_states = self.get_previous_queries(previous_queries, previous_query_states, previous_query, input_schema)
if len(gold_query) <= max_generation_length and len(previous_query) <= max_generation_length:
prediction = self.predict_turn(final_utterance_state,
utterance_states,
schema_states,
max_generation_length,
gold_query=gold_query,
snippets=snippets,
input_sequence=flat_sequence,
previous_queries=previous_queries,
previous_query_states=previous_query_states,
input_schema=input_schema,
feed_gold_tokens=True,
training=True)
loss = prediction[1]
decoder_states = prediction[3]
total_gold_tokens += len(gold_query)
losses.append(loss)
print('<<', loss.item())
else:
# Break if previous decoder snippet encoding -- because the previous
# sequence was too long to run the decoder.
if self.params.previous_decoder_snippet_encoding:
break
continue
torch.cuda.empty_cache()
if losses:
average_loss = torch.sum(torch.stack(losses)) / total_gold_tokens
if average_loss.item() > 80.0:
print("Loss is too large", average_loss.item())
loss_scalar = 0.0
else:
# Renormalize so the effect is normalized by the batch size.
normalized_loss = average_loss
if self.params.reweight_batch:
normalized_loss = len(losses) * average_loss / float(self.params.batch_size)
normalized_loss.backward()
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=10, norm_type=2)
self.trainer.step()
if self.params.fine_tune_bert:
self.bert_trainer.step()
self.zero_grad()
loss_scalar = normalized_loss.item()
else:
loss_scalar = 0.
return loss_scalar
def predict_with_predicted_queries(self, interaction, max_generation_length, syntax_restrict=True):
""" Predicts an interaction, using the predicted queries to get snippets."""
# assert self.params.discourse_level_lstm
syntax_restrict=False
predictions = []
input_hidden_states = []
input_sequences = []
final_utterance_states_c = []
final_utterance_states_h = []
previous_query_states = []
previous_queries = []
discourse_state = None
if self.params.discourse_level_lstm:
discourse_state, discourse_lstm_states = self._initialize_discourse_states()
discourse_states = []
# Schema and schema embeddings
input_schema = interaction.get_schema()
schema_states = []
if input_schema and not self.params.use_bert:
schema_states = self.encode_schema_bow_simple(input_schema)
interaction.start_interaction()
while not interaction.done():
utterance = interaction.next_utterance()
available_snippets = utterance.snippets()
previous_query = utterance.previous_query()
input_sequence = utterance.input_sequence()
if not self.params.use_bert:
if self.params.discourse_level_lstm:
utterance_token_embedder = lambda token: torch.cat([self.input_embedder(token), discourse_state], dim=0)
else:
utterance_token_embedder = self.input_embedder
final_utterance_state, utterance_states = self.utterance_encoder(
input_sequence,
utterance_token_embedder)
else:
final_utterance_state, utterance_states, schema_states = self.get_bert_encoding(input_sequence, input_schema, discourse_state, dropout=False)
input_hidden_states.extend(utterance_states)
input_sequences.append(input_sequence)
num_utterances_to_keep = min(self.params.maximum_utterances, len(input_sequences))
if self.params.discourse_level_lstm:
_, discourse_state, discourse_lstm_states = torch_utils.forward_one_multilayer(self.discourse_lstms, final_utterance_state[1][0], discourse_lstm_states)
if self.params.use_utterance_attention:
final_utterance_states_c, final_utterance_states_h, final_utterance_state = self.get_utterance_attention(final_utterance_states_c, final_utterance_states_h, final_utterance_state, num_utterances_to_keep)
if self.params.state_positional_embeddings:
utterance_states, flat_sequence = self._add_positional_embeddings(input_hidden_states, input_sequences)
else:
flat_sequence = []
for utt in input_sequences[-num_utterances_to_keep:]:
flat_sequence.extend(utt)
snippets = None
if self.params.use_snippets:
snippets = self._encode_snippets(previous_query, available_snippets, input_schema)
if self.params.use_previous_query and len(previous_query) > 0:
previous_queries, previous_query_states = self.get_previous_queries(previous_queries, previous_query_states, previous_query, input_schema)
identifier = interaction.identifier
if len(identifier.split('/')) == 2:
database_id, interaction_id = identifier.split('/')
else:
database_id = 'atis'
interaction_id = identifier
self.schema = self.database_schema[database_id]
results = self.predict_turn(final_utterance_state,
utterance_states,
schema_states,
max_generation_length,
input_sequence=flat_sequence,
previous_queries=previous_queries,
previous_query_states=previous_query_states,
input_schema=input_schema,
snippets=snippets)
predicted_sequence = results[0]
predictions.append(results)
# Update things necessary for using predicted queries
anonymized_sequence = utterance.remove_snippets(predicted_sequence)
if EOS_TOK in anonymized_sequence:
anonymized_sequence = anonymized_sequence[:-1] # Remove _EOS
else:
anonymized_sequence = ['select', '*', 'from', 't1']
if not syntax_restrict:
utterance.set_predicted_query(interaction.remove_snippets(predicted_sequence))
if input_schema:
# on SParC
interaction.add_utterance(utterance, anonymized_sequence, previous_snippets=utterance.snippets(), simple=True)
else:
# on ATIS
interaction.add_utterance(utterance, anonymized_sequence, previous_snippets=utterance.snippets(), simple=False)
else:
utterance.set_predicted_query(utterance.previous_query())
interaction.add_utterance(utterance, utterance.previous_query(), previous_snippets=utterance.snippets())
return predictions
def predict_with_gold_queries(self, interaction, max_generation_length, feed_gold_query=False):
""" Predicts SQL queries for an interaction.
Inputs:
interaction (Interaction): Interaction to predict for.
max_generation_length (int): Maximum number of tokens to generate for a query.
feed_gold_query (bool): Whether or not to feed the gold token to the
generation step.
"""
# assert self.params.discourse_level_lstm
predictions = []
input_hidden_states = []
input_sequences = []
final_utterance_states_c = []
final_utterance_states_h = []
previous_query_states = []
previous_queries = []
decoder_states = []
discourse_state = None
if self.params.discourse_level_lstm:
discourse_state, discourse_lstm_states = self._initialize_discourse_states()
discourse_states = []
# Schema and schema embeddings
input_schema = interaction.get_schema()
schema_states = []
if input_schema and not self.params.use_bert:
schema_states = self.encode_schema_bow_simple(input_schema)
for utterance in interaction.gold_utterances():
input_sequence = utterance.input_sequence()
available_snippets = utterance.snippets()
previous_query = utterance.previous_query()
# Encode the utterance, and update the discourse-level states
if not self.params.use_bert:
if self.params.discourse_level_lstm:
utterance_token_embedder = lambda token: torch.cat([self.input_embedder(token), discourse_state], dim=0)
else:
utterance_token_embedder = self.input_embedder
final_utterance_state, utterance_states = self.utterance_encoder(
input_sequence,
utterance_token_embedder,
dropout_amount=self.dropout)
else:
final_utterance_state, utterance_states, schema_states = self.get_bert_encoding(input_sequence, input_schema, discourse_state, dropout=True)
input_hidden_states.extend(utterance_states)
input_sequences.append(input_sequence)
num_utterances_to_keep = min(self.params.maximum_utterances, len(input_sequences))
if self.params.discourse_level_lstm:
_, discourse_state, discourse_lstm_states = torch_utils.forward_one_multilayer(self.discourse_lstms, final_utterance_state[1][0], discourse_lstm_states, self.dropout)
if self.params.use_utterance_attention:
final_utterance_states_c, final_utterance_states_h, final_utterance_state = self.get_utterance_attention(final_utterance_states_c, final_utterance_states_h, final_utterance_state, num_utterances_to_keep)
if self.params.state_positional_embeddings:
utterance_states, flat_sequence = self._add_positional_embeddings(input_hidden_states, input_sequences)
else:
flat_sequence = []
for utt in input_sequences[-num_utterances_to_keep:]:
flat_sequence.extend(utt)
snippets = None
if self.params.use_snippets:
if self.params.previous_decoder_snippet_encoding:
snippets = encode_snippets_with_states(available_snippets, decoder_states)
else:
snippets = self._encode_snippets(previous_query, available_snippets, input_schema)
if self.params.use_previous_query and len(previous_query) > 0:
previous_queries, previous_query_states = self.get_previous_queries(previous_queries, previous_query_states, previous_query, input_schema)
identifier = interaction.identifier
if len(identifier.split('/')) == 2:
database_id, interaction_id = identifier.split('/')
else:
database_id = 'atis'
interaction_id = identifier
self.schema = self.database_schema[database_id]
prediction = self.predict_turn(final_utterance_state,
utterance_states,
schema_states,
max_generation_length,
gold_query=utterance.gold_query(),
snippets=snippets,
input_sequence=flat_sequence,
previous_queries=previous_queries,
previous_query_states=previous_query_states,
input_schema=input_schema,
feed_gold_tokens=feed_gold_query)
decoder_states = prediction[3]
predictions.append(prediction)
return predictions


@ -0,0 +1,476 @@
"""Predicts a token."""
from collections import namedtuple
import torch
import torch.nn.functional as F
from . import torch_utils
from .attention import Attention, AttentionResult
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PredictionInput(namedtuple('PredictionInput',
('decoder_state',
'input_hidden_states',
'snippets',
'input_sequence'))):
""" Inputs to the token predictor. """
__slots__ = ()
class PredictionStepInputWithSchema(namedtuple('PredictionStepInputWithSchema',
('index',
'decoder_input',
'decoder_states',
'encoder_states',
'schema_states',
'snippets',
'gold_sequence',
'input_sequence',
'previous_queries',
'previous_query_states',
'input_schema',
'dropout_amount',
'predictions'))):
""" Inputs to the next token predictor. """
__slots__ = ()
class PredictionInputWithSchema(namedtuple('PredictionInputWithSchema',
('decoder_state',
'input_hidden_states',
'schema_states',
'snippets',
'input_sequence',
'previous_queries',
'previous_query_states',
'input_schema'))):
""" Inputs to the token predictor. """
__slots__ = ()
class TokenPrediction(namedtuple('TokenPrediction',
('scores',
'aligned_tokens',
'utterance_attention_results',
'schema_attention_results',
'query_attention_results',
'copy_switch',
'query_scores',
'query_tokens',
'decoder_state'))):
"""A token prediction."""
__slots__ = ()
def score_snippets(snippets, scorer):
""" Scores snippets given a scorer.
Inputs:
snippets (list of Snippet): The snippets to score.
scorer (torch.Tensor): Vector against which to score the snippets.
Returns:
torch.Tensor, list of str, where the first is the scores and the second
is the names of the snippets that were scored.
"""
snippet_expressions = [snippet.embedding for snippet in snippets]
all_snippet_embeddings = torch.stack(snippet_expressions, dim=1)
scores = torch.t(torch.mm(torch.t(scorer), all_snippet_embeddings))
if scores.size()[0] != len(snippets):
raise ValueError("Got " + str(scores.size()[0]) + " scores for " + str(len(snippets)) + " snippets")
return scores, [snippet.name for snippet in snippets]
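# Shapes, as called from the predictors: the scorer is a (snippet_size, 1)
# column, the snippet embeddings are stacked column-wise into
# (snippet_size, num_snippets), and the returned scores have shape
# (num_snippets, 1).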
def score_schema_tokens(input_schema, schema_states, scorer):
# schema_states: emb_dim x num_tokens
scores = torch.t(torch.mm(torch.t(scorer), schema_states)) # num_tokens x 1
if scores.size()[0] != len(input_schema):
raise ValueError("Got " + str(scores.size()[0]) + " scores for " + str(len(input_schema)) + " schema tokens")
return scores, input_schema.column_names_surface_form
def score_query_tokens(previous_query, previous_query_states, scorer):
scores = torch.t(torch.mm(torch.t(scorer), previous_query_states)) # num_tokens x 1
if scores.size()[0] != len(previous_query):
raise ValueError("Got " + str(scores.size()[0]) + " scores for " + str(len(previous_query)) + " query tokens")
return scores, previous_query
class TokenPredictor(torch.nn.Module):
""" Predicts a token given a (decoder) state.
Attributes:
vocabulary (Vocabulary): A vocabulary object for the output.
attention_module (Attention): An attention module.
state_transform_weights (torch.nn.Parameter): Transforms the input state
before predicting a token.
vocabulary_weights (torch.nn.Parameter): Final layer weights.
vocabulary_biases (torch.nn.Parameter): Final layer biases.
"""
def __init__(self, params, vocabulary, attention_key_size):
super().__init__()
self.params = params
self.vocabulary = vocabulary
self.attention_module = Attention(params.decoder_state_size, attention_key_size, attention_key_size)
self.state_transform_weights = torch_utils.add_params((params.decoder_state_size + attention_key_size, params.decoder_state_size), "weights-state-transform")
self.vocabulary_weights = torch_utils.add_params((params.decoder_state_size, len(vocabulary)), "weights-vocabulary")
self.vocabulary_biases = torch_utils.add_params(tuple([len(vocabulary)]), "biases-vocabulary")
def _get_intermediate_state(self, state, dropout_amount=0.):
intermediate_state = torch.tanh(torch_utils.linear_layer(state, self.state_transform_weights))
return F.dropout(intermediate_state, dropout_amount)
def _score_vocabulary_tokens(self, state):
scores = torch.t(torch_utils.linear_layer(state, self.vocabulary_weights, self.vocabulary_biases))
if scores.size()[0] != len(self.vocabulary.inorder_tokens):
raise ValueError("Got " + str(scores.size()[0]) + " scores for " + str(len(self.vocabulary.inorder_tokens)) + " vocabulary items")
return scores, self.vocabulary.inorder_tokens
def forward(self, prediction_input, dropout_amount=0.):
decoder_state = prediction_input.decoder_state
input_hidden_states = prediction_input.input_hidden_states
attention_results = self.attention_module(decoder_state, input_hidden_states)
state_and_attn = torch.cat([decoder_state, attention_results.vector], dim=0)
intermediate_state = self._get_intermediate_state(state_and_attn, dropout_amount=dropout_amount)
vocab_scores, vocab_tokens = self._score_vocabulary_tokens(intermediate_state)
return TokenPrediction(vocab_scores, vocab_tokens, attention_results, None, None, None, None, None, decoder_state)
class SchemaTokenPredictor(TokenPredictor):
""" Token predictor that also predicts snippets.
Attributes:
snippet_weights (dy.Parameter): Weights for scoring snippets against some
state.
"""
def __init__(self, params, vocabulary, utterance_attention_key_size, schema_attention_key_size, snippet_size):
TokenPredictor.__init__(self, params, vocabulary, utterance_attention_key_size)
if params.use_snippets:
if snippet_size <= 0:
raise ValueError("Snippet size must be greater than zero; was " + str(snippet_size))
self.snippet_weights = torch_utils.add_params((params.decoder_state_size, snippet_size), "weights-snippet")
if params.use_schema_attention:
self.utterance_attention_module = self.attention_module
self.schema_attention_module = Attention(params.decoder_state_size, schema_attention_key_size, schema_attention_key_size)
if self.params.use_query_attention:
self.query_attention_module = Attention(params.decoder_state_size, params.encoder_state_size, params.encoder_state_size)
self.start_query_attention_vector = torch_utils.add_params((params.encoder_state_size,), "start_query_attention_vector")
if params.use_schema_attention and self.params.use_query_attention:
self.state_transform_weights = torch_utils.add_params((params.decoder_state_size + utterance_attention_key_size + schema_attention_key_size + params.encoder_state_size, params.decoder_state_size), "weights-state-transform")
elif params.use_schema_attention:
self.state_transform_weights = torch_utils.add_params((params.decoder_state_size + utterance_attention_key_size + schema_attention_key_size, params.decoder_state_size), "weights-state-transform")
# Use lstm schema encoder
self.schema_token_weights = torch_utils.add_params((params.decoder_state_size, schema_attention_key_size), "weights-schema-token")
if self.params.use_previous_query:
self.query_token_weights = torch_utils.add_params((params.decoder_state_size, self.params.encoder_state_size), "weights-query-token")
if self.params.use_copy_switch:
if self.params.use_query_attention:
self.state2copyswitch_transform_weights = torch_utils.add_params((params.decoder_state_size + utterance_attention_key_size + schema_attention_key_size + params.encoder_state_size, 1), "weights-state-transform")
else:
self.state2copyswitch_transform_weights = torch_utils.add_params((params.decoder_state_size + utterance_attention_key_size + schema_attention_key_size, 1), "weights-state-transform")
def _get_snippet_scorer(self, state):
scorer = torch.t(torch_utils.linear_layer(state, self.snippet_weights))
return scorer
def _get_schema_token_scorer(self, state):
scorer = torch.t(torch_utils.linear_layer(state, self.schema_token_weights))
return scorer
def _get_query_token_scorer(self, state):
scorer = torch.t(torch_utils.linear_layer(state, self.query_token_weights))
return scorer
def _get_copy_switch(self, state):
copy_switch = torch.sigmoid(torch_utils.linear_layer(state, self.state2copyswitch_transform_weights))
return copy_switch.squeeze()
def forward(self, prediction_input, dropout_amount=0.):
decoder_state = prediction_input.decoder_state
input_hidden_states = prediction_input.input_hidden_states
snippets = prediction_input.snippets
input_schema = prediction_input.input_schema
schema_states = prediction_input.schema_states
if self.params.use_schema_attention:
schema_attention_results = self.schema_attention_module(decoder_state, schema_states)
utterance_attention_results = self.utterance_attention_module(decoder_state, input_hidden_states)
else:
utterance_attention_results = self.attention_module(decoder_state, input_hidden_states)
schema_attention_results = None
query_attention_results = None
if self.params.use_query_attention:
previous_query_states = prediction_input.previous_query_states
if len(previous_query_states) > 0:
query_attention_results = self.query_attention_module(decoder_state, previous_query_states[-1])
else:
query_attention_results = self.start_query_attention_vector
query_attention_results = AttentionResult(None, None, query_attention_results)
if self.params.use_schema_attention and self.params.use_query_attention:
state_and_attn = torch.cat([decoder_state, utterance_attention_results.vector, schema_attention_results.vector, query_attention_results.vector], dim=0)
elif self.params.use_schema_attention:
state_and_attn = torch.cat([decoder_state, utterance_attention_results.vector, schema_attention_results.vector], dim=0)
else:
state_and_attn = torch.cat([decoder_state, utterance_attention_results.vector], dim=0)
intermediate_state = self._get_intermediate_state(state_and_attn, dropout_amount=dropout_amount)
vocab_scores, vocab_tokens = self._score_vocabulary_tokens(intermediate_state)
final_scores = vocab_scores
aligned_tokens = []
aligned_tokens.extend(vocab_tokens)
if self.params.use_snippets and snippets:
snippet_scores, snippet_tokens = score_snippets(snippets, self._get_snippet_scorer(intermediate_state))
final_scores = torch.cat([final_scores, snippet_scores], dim=0)
aligned_tokens.extend(snippet_tokens)
schema_states = torch.stack(schema_states, dim=1)
schema_scores, schema_tokens = score_schema_tokens(input_schema, schema_states, self._get_schema_token_scorer(intermediate_state))
final_scores = torch.cat([final_scores, schema_scores], dim=0)
aligned_tokens.extend(schema_tokens)
# Previous Queries
previous_queries = prediction_input.previous_queries
previous_query_states = prediction_input.previous_query_states
copy_switch = None
query_scores = None
query_tokens = None
if self.params.use_previous_query and len(previous_queries) > 0:
if self.params.use_copy_switch:
copy_switch = self._get_copy_switch(state_and_attn)
for turn, (previous_query, previous_query_state) in enumerate(zip(previous_queries, previous_query_states)):
assert len(previous_query) == len(previous_query_state)
previous_query_state = torch.stack(previous_query_state, dim=1)
query_scores, query_tokens = score_query_tokens(previous_query, previous_query_state, self._get_query_token_scorer(intermediate_state))
query_scores = query_scores.squeeze()
final_scores = final_scores.squeeze()
return TokenPrediction(final_scores, aligned_tokens, utterance_attention_results, schema_attention_results, query_attention_results, copy_switch, query_scores, query_tokens, decoder_state)
class SnippetTokenPredictor(TokenPredictor):
""" Token predictor that also predicts snippets.
Attributes:
snippet_weights (torch.nn.Parameter): Weights for scoring snippets against some
state.
"""
def __init__(self, params, vocabulary, attention_key_size, snippet_size):
TokenPredictor.__init__(self, params, vocabulary, attention_key_size)
if snippet_size <= 0:
raise ValueError("Snippet size must be greater than zero; was " + str(snippet_size))
self.snippet_weights = torch_utils.add_params((params.decoder_state_size, snippet_size), "weights-snippet")
def _get_snippet_scorer(self, state):
scorer = torch.t(torch_utils.linear_layer(state, self.snippet_weights))
return scorer
def forward(self, prediction_input, dropout_amount=0.):
decoder_state = prediction_input.decoder_state
input_hidden_states = prediction_input.input_hidden_states
snippets = prediction_input.snippets
attention_results = self.attention_module(decoder_state,
input_hidden_states)
state_and_attn = torch.cat([decoder_state, attention_results.vector], dim=0)
intermediate_state = self._get_intermediate_state(
state_and_attn, dropout_amount=dropout_amount)
vocab_scores, vocab_tokens = self._score_vocabulary_tokens(
intermediate_state)
final_scores = vocab_scores
aligned_tokens = []
aligned_tokens.extend(vocab_tokens)
if snippets:
snippet_scores, snippet_tokens = score_snippets(
snippets,
self._get_snippet_scorer(intermediate_state))
final_scores = torch.cat([final_scores, snippet_scores], dim=0)
aligned_tokens.extend(snippet_tokens)
final_scores = final_scores.squeeze()
return TokenPrediction(final_scores, aligned_tokens, attention_results, None, None, None, None, None, decoder_state)
class AnonymizationTokenPredictor(TokenPredictor):
""" Token predictor that also predicts anonymization tokens.
Attributes:
anonymizer (Anonymizer): The anonymization object.
"""
def __init__(self, params, vocabulary, attention_key_size, anonymizer):
TokenPredictor.__init__(self, params, vocabulary, attention_key_size)
if not anonymizer:
raise ValueError("Expected an anonymizer, but was None")
self.anonymizer = anonymizer
def _score_anonymized_tokens(self,
input_sequence,
attention_scores):
scores = []
tokens = []
for i, token in enumerate(input_sequence):
if self.anonymizer.is_anon_tok(token):
scores.append(attention_scores[i])
tokens.append(token)
if len(scores) > 0:
if len(scores) != len(tokens):
raise ValueError("Got " + str(len(scores)) + " scores for "
+ str(len(tokens)) + " anonymized tokens")
anonymized_scores = torch.cat(scores, dim=0)
if anonymized_scores.dim() == 1:
anonymized_scores = anonymized_scores.unsqueeze(1)
return anonymized_scores, tokens
else:
return None, []
def forward(self,
prediction_input,
dropout_amount=0.):
decoder_state = prediction_input.decoder_state
input_hidden_states = prediction_input.input_hidden_states
input_sequence = prediction_input.input_sequence
assert input_sequence
attention_results = self.attention_module(decoder_state,
input_hidden_states)
state_and_attn = torch.cat([decoder_state, attention_results.vector], dim=0)
intermediate_state = self._get_intermediate_state(
state_and_attn, dropout_amount=dropout_amount)
vocab_scores, vocab_tokens = self._score_vocabulary_tokens(
intermediate_state)
final_scores = vocab_scores
aligned_tokens = []
aligned_tokens.extend(vocab_tokens)
anonymized_scores, anonymized_tokens = self._score_anonymized_tokens(
input_sequence,
attention_results.scores)
if anonymized_scores is not None:
final_scores = torch.cat([final_scores, anonymized_scores], dim=0)
aligned_tokens.extend(anonymized_tokens)
final_scores = final_scores.squeeze()
return TokenPrediction(final_scores, aligned_tokens, attention_results, None, None, None, None, None, decoder_state)
# For Atis
class SnippetAnonymizationTokenPredictor(SnippetTokenPredictor, AnonymizationTokenPredictor):
""" Token predictor that both anonymizes and scores snippets."""
def __init__(self, params, vocabulary, attention_key_size, snippet_size, anonymizer):
AnonymizationTokenPredictor.__init__(self, params, vocabulary, attention_key_size, anonymizer)
SnippetTokenPredictor.__init__(self, params, vocabulary, attention_key_size, snippet_size)
def forward(self, prediction_input, dropout_amount=0.):
decoder_state = prediction_input.decoder_state
assert prediction_input.input_sequence
snippets = prediction_input.snippets
input_hidden_states = prediction_input.input_hidden_states
attention_results = self.attention_module(decoder_state,
prediction_input.input_hidden_states)
state_and_attn = torch.cat([decoder_state, attention_results.vector], dim=0)
intermediate_state = self._get_intermediate_state(state_and_attn, dropout_amount=dropout_amount)
# Vocabulary tokens
final_scores, vocab_tokens = self._score_vocabulary_tokens(intermediate_state)
aligned_tokens = []
aligned_tokens.extend(vocab_tokens)
# Snippets
if snippets:
snippet_scores, snippet_tokens = score_snippets(
snippets,
self._get_snippet_scorer(intermediate_state))
final_scores = torch.cat([final_scores, snippet_scores], dim=0)
aligned_tokens.extend(snippet_tokens)
# Anonymized tokens
anonymized_scores, anonymized_tokens = self._score_anonymized_tokens(
prediction_input.input_sequence,
attention_results.scores)
if anonymized_scores is not None:
final_scores = torch.cat([final_scores, anonymized_scores], dim=0)
aligned_tokens.extend(anonymized_tokens)
final_scores = final_scores.squeeze()
return TokenPrediction(final_scores, aligned_tokens, attention_results, None, None, None, None, None, decoder_state)
def construct_token_predictor(params,
vocabulary,
utterance_attention_key_size,
schema_attention_key_size,
snippet_size,
anonymizer=None):
""" Constructs a token predictor given the parameters.
Inputs:
params (dictionary): Contains the command line parameters/hyperparameters.
vocabulary (Vocabulary): Vocabulary object for output generation.
utterance_attention_key_size (int): The size of the utterance attention keys.
schema_attention_key_size (int): The size of the schema attention keys.
snippet_size (int): The size of snippet encodings.
anonymizer (Anonymizer, optional): An anonymization object.
"""
if not anonymizer and not params.previous_decoder_snippet_encoding:
print('using SchemaTokenPredictor')
return SchemaTokenPredictor(params, vocabulary, utterance_attention_key_size, schema_attention_key_size, snippet_size)
elif params.use_snippets and anonymizer and not params.previous_decoder_snippet_encoding:
print('using SnippetAnonymizationTokenPredictor')
return SnippetAnonymizationTokenPredictor(params,
vocabulary,
utterance_attention_key_size,
snippet_size,
anonymizer)
else:
print('Unknown token_predictor')
exit()

View File

@ -0,0 +1,217 @@
"""Contains various utility functions for Dynet models."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def linear_layer(exp, weights, biases=None):
# exp: input as size_1 or 1 x size_1
# weight: size_1 x size_2
# bias: size_2
if exp.dim() == 1:
exp = torch.unsqueeze(exp, 0)
assert exp.size()[1] == weights.size()[0]
if biases is not None:
assert weights.size()[1] == biases.size()[0]
result = torch.mm(exp, weights) + biases
else:
result = torch.mm(exp, weights)
return result
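# Illustrative usage sketch (added for clarity; _linear_layer_example is a hypothetical
# helper, not part of the original module). It shows the shape contract of linear_layer:
# a 1-D input of length size_1 is promoted to a 1 x size_1 row vector, so the result
# of the multiply is 1 x size_2.
def _linear_layer_example():
    exp = torch.randn(4)           # size_1 = 4
    weights = torch.randn(4, 3)    # size_1 x size_2
    biases = torch.randn(3)        # size_2
    out = linear_layer(exp, weights, biases)
    assert out.size() == (1, 3)
    return out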
def compute_loss(gold_seq,
scores,
index_to_token_maps,
gold_tok_to_id,
noise=0.00000001):
""" Computes the loss of a gold sequence given scores.
Inputs:
gold_seq (list of str): A sequence of gold tokens.
scores (list of torch.Tensor): Distributions over the
potential output tokens for each token in gold_seq.
index_to_token_maps (list of dict str->list of int): Maps from index in the
sequence to a dictionary mapping from a string to a set of integers.
gold_tok_to_id (lambda (str, str)->list of int): Maps from the gold token
and some lookup function to the indices in the probability distribution
where the gold token occurs.
noise (float, optional): The amount of noise to add to the loss.
Returns:
torch.Tensor representing the sum of losses over the sequence.
"""
assert len(gold_seq) == len(scores) == len(index_to_token_maps)
losses = []
for i, gold_tok in enumerate(gold_seq):
score = scores[i]
token_map = index_to_token_maps[i]
gold_indices = gold_tok_to_id(gold_tok, token_map)
assert len(gold_indices) > 0
noise_i = noise
if len(gold_indices) == 1:
noise_i = 0
probdist = score
prob_of_tok = noise_i + torch.sum(probdist[gold_indices])
losses.append(-torch.log(prob_of_tok))
return torch.sum(torch.stack(losses))
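# Illustrative usage sketch (added for clarity; _compute_loss_example is a hypothetical
# helper, not part of the original module). compute_loss expects one probability
# distribution per gold token plus a lookup mapping a gold token to its indices in
# that distribution; it returns the summed negative log-probability of the gold tokens.
def _compute_loss_example():
    gold_seq = ["select", "*"]
    scores = [torch.tensor([0.7, 0.2, 0.1]), torch.tensor([0.1, 0.8, 0.1])]
    index_to_token_maps = [{"select": [0], "*": [1], "from": [2]}] * 2
    gold_tok_to_id = lambda tok, token_map: token_map[tok]
    return compute_loss(gold_seq, scores, index_to_token_maps, gold_tok_to_id)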
def get_seq_from_scores(scores, index_to_token_maps):
"""Gets the argmax sequence from a set of scores.
Inputs:
scores (list of torch.Tensor): Sequence of output score distributions.
index_to_token_maps (list of list of str): For each output token, maps
the index in the probability distribution to a string.
Returns:
list of str, representing the argmax sequence.
"""
seq = []
for score, tok_map in zip(scores, index_to_token_maps):
# score_numpy_list = score.cpu().detach().numpy()
score_numpy_list = score.cpu().data.numpy()
assert score.size()[0] == len(tok_map) == len(list(score_numpy_list))
seq.append(tok_map[np.argmax(score_numpy_list)])
return seq
def per_token_accuracy(gold_seq, pred_seq):
""" Returns the per-token accuracy comparing two strings (recall).
Inputs:
gold_seq (list of str): A list of gold tokens.
pred_seq (list of str): A list of predicted tokens.
Returns:
float, representing the accuracy.
"""
num_correct = 0
for i, gold_token in enumerate(gold_seq):
if i < len(pred_seq) and pred_seq[i] == gold_token:
num_correct += 1
return float(num_correct) / len(gold_seq)
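# Illustrative sketch (added for clarity; _per_token_accuracy_example is a hypothetical
# helper, not part of the original module): one of three positions differs, so the
# per-token accuracy is 2/3.
def _per_token_accuracy_example():
    return per_token_accuracy(["select", "name", "from"], ["select", "id", "from"])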
def forward_one_multilayer(rnns, lstm_input, layer_states, dropout_amount=0.):
""" Goes forward for one multilayer RNN cell step.
Inputs:
rnns (list of nn.LSTMCell): The stacked LSTM cells.
lstm_input (torch.Tensor): Some input to the step.
layer_states (list of (torch.Tensor, torch.Tensor)): The (h, c) state of each layer in the cell.
dropout_amount (float, optional): The amount of dropout to apply, in
between the layers.
Returns:
(list of torch.Tensor, list of torch.Tensor), torch.Tensor, list of (torch.Tensor, torch.Tensor),
representing (each layer's cell memory, each layer's cell hidden state),
the final hidden state, and each layer's updated (h, c) state.
"""
num_layers = len(layer_states)
new_states = []
cell_states = []
hidden_states = []
state = lstm_input
for i in range(num_layers):
# view as (1, input_size)
layer_h, layer_c = rnns[i](torch.unsqueeze(state,0), layer_states[i])
new_states.append((layer_h, layer_c))
layer_h = layer_h.squeeze()
layer_c = layer_c.squeeze()
state = layer_h
if i < num_layers - 1:
# In both DyNet and PyTorch,
# p stands for the probability of an element being zeroed, i.e. p=1 switches off all activations.
state = F.dropout(state, p=dropout_amount)
cell_states.append(layer_c)
hidden_states.append(layer_h)
return (cell_states, hidden_states), state, new_states
def encode_sequence(sequence, rnns, embedder, dropout_amount=0.):
""" Encodes a sequence given RNN cells and an embedding function.
Inputs:
sequence (list of str): The sequence to encode.
rnns (list of nn.LSTMCell): The stacked RNN cells to use.
embedder: Function that embeds a token string to a word vector.
dropout_amount (float, optional): The amount of dropout to apply.
Returns:
(list of torch.Tensor, list of torch.Tensor), list of torch.Tensor,
where the first pair is the (final cell memories, final hidden states) of
all layers, and the second list contains the final layer's hidden
state for all tokens in the sequence.
"""
batch_size = 1
layer_states = []
for rnn in rnns:
hidden_size = rnn.weight_hh.size()[1]
# h_0 of shape (batch, hidden_size)
# c_0 of shape (batch, hidden_size)
if rnn.weight_hh.is_cuda:
h_0 = torch.cuda.FloatTensor(batch_size,hidden_size).fill_(0)
c_0 = torch.cuda.FloatTensor(batch_size,hidden_size).fill_(0)
else:
h_0 = torch.zeros(batch_size,hidden_size)
c_0 = torch.zeros(batch_size,hidden_size)
layer_states.append((h_0, c_0))
outputs = []
for token in sequence:
rnn_input = embedder(token)
(cell_states, hidden_states), output, layer_states = forward_one_multilayer(rnns,rnn_input,layer_states,dropout_amount)
outputs.append(output)
return (cell_states, hidden_states), outputs
def create_multilayer_lstm_params(num_layers, in_size, state_size, name=""):
""" Adds a multilayer LSTM to the model parameters.
Inputs:
num_layers (int): Number of layers to create.
in_size (int): The input size to the first layer.
state_size (int): The size of the states.
name (str, optional): The name of the multilayer LSTM.
Returns:
torch.nn.ModuleList containing one nn.LSTMCell per layer.
"""
lstm_layers = []
for i in range(num_layers):
layer_name = name + "-" + str(i)
# print("LSTM " + layer_name + ": " + str(in_size) + " x " + str(state_size) + "; default Dynet initialization of hidden weights")
lstm_layer = torch.nn.LSTMCell(input_size=int(in_size), hidden_size=int(state_size), bias=True)
lstm_layers.append(lstm_layer)
in_size = state_size
return torch.nn.ModuleList(lstm_layers)
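# Illustrative usage sketch (added for clarity; _encode_sequence_example and its random
# embedding table are hypothetical, not part of the original module). It wires
# create_multilayer_lstm_params together with encode_sequence: a 2-layer LSTM stack run
# over a toy token sequence on CPU.
def _encode_sequence_example():
    embedding_size, state_size = 16, 32
    rnns = create_multilayer_lstm_params(num_layers=2, in_size=embedding_size, state_size=state_size)
    table = {tok: torch.randn(embedding_size) for tok in ["show", "all", "flights"]}
    embedder = lambda tok: table[tok]
    (cell_memories, hidden_states), outputs = encode_sequence(["show", "all", "flights"], rnns, embedder)
    assert len(outputs) == 3 and outputs[0].size()[0] == state_size
    return outputs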
def add_params(size, name=""):
""" Adds parameters to the model.
Inputs:
size (tuple of int): The size to create.
name (str, optional): The name of the parameters.
Returns:
torch.nn.Parameter initialized uniformly in [-0.1, 0.1].
"""
# if len(size) == 1:
# print("vector " + name + ": " + str(size[0]) + "; uniform in [-0.1, 0.1]")
# else:
# print("matrix " + name + ": " + str(size[0]) + " x " + str(size[1]) + "; uniform in [-0.1, 0.1]")
size_int = tuple([int(ss) for ss in size])
return torch.nn.Parameter(torch.empty(size_int).uniform_(-0.1, 0.1))

View File

@ -0,0 +1,16 @@
# -*- encoding:utf-8 -*-
import torch
import torch.nn as nn
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super(LayerNorm, self).__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(hidden_size) * 0.2)
self.beta = nn.Parameter(torch.zeros(hidden_size))
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.gamma * (x-mean) / (std+self.eps) + self.beta

View File

@ -0,0 +1,65 @@
# -*- encoding:utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadedAttention(nn.Module):
"""
Each head is a self-attention operation.
self-attention refers to https://arxiv.org/pdf/1706.03762.pdf
"""
def __init__(self, hidden_size, heads_num):
super(MultiHeadedAttention, self).__init__()
self.hidden_size = hidden_size
self.heads_num = heads_num
self.per_head_size = hidden_size // heads_num
self.linear_layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size) for _ in range(3)
])
self.final_linear = nn.Linear(hidden_size, hidden_size)
def forward(self, key, value, query, mask, dropout):
"""
Args:
key: [batch_size x seq_length x hidden_size]
value: [batch_size x seq_length x hidden_size]
query: [batch_size x seq_length x hidden_size]
mask: [batch_size x 1 x seq_length x seq_length]
Returns:
output: [batch_size x seq_length x hidden_size]
"""
batch_size, seq_length, hidden_size = key.size()
heads_num = self.heads_num
per_head_size = self.per_head_size
def shape(x):
return x. \
contiguous(). \
view(batch_size, seq_length, heads_num, per_head_size). \
transpose(1, 2)
def unshape(x):
return x. \
transpose(1, 2). \
contiguous(). \
view(batch_size, seq_length, hidden_size)
query, key, value = [l(x). \
view(batch_size, -1, heads_num, per_head_size). \
transpose(1, 2) \
for l, x in zip(self.linear_layers, (query, key, value))
]
scores = torch.matmul(query, key.transpose(-2, -1))
scores = scores / math.sqrt(float(per_head_size))
scores = scores + mask
probs = nn.Softmax(dim=-1)(scores)
probs = F.dropout(probs, p=dropout)
output = unshape(torch.matmul(probs, value))
output = self.final_linear(output)
return output
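# Illustrative usage sketch (added for clarity; _multi_headed_attention_example is a
# hypothetical helper, not part of the original module). The mask is additive: zeros
# let every position be attended to, while large negative values would block positions.
def _multi_headed_attention_example():
    attn = MultiHeadedAttention(hidden_size=12, heads_num=4)
    x = torch.randn(1, 5, 12)          # [batch_size x seq_length x hidden_size]
    mask = torch.zeros(1, 1, 5, 5)     # attend everywhere
    out = attn(x, x, x, mask, dropout=0.0)
    assert out.size() == (1, 5, 12)
    return out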

View File

@ -0,0 +1,19 @@
import torch.nn as nn
import math
import torch
def gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class PositionwiseFeedForward(nn.Module):
""" Feed Forward Layer """
def __init__(self, hidden_size, feedforward_size):
super(PositionwiseFeedForward, self).__init__()
self.linear_1 = nn.Linear(hidden_size, feedforward_size)
self.linear_2 = nn.Linear(feedforward_size, hidden_size)
def forward(self, x):
inter = gelu(self.linear_1(x))
output = self.linear_2(inter)
return output
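# Illustrative sketch (added for clarity; _position_ffn_example is a hypothetical helper,
# not part of the original module). The feed-forward block maps hidden_size ->
# feedforward_size -> hidden_size with a GELU in between, so the output shape matches
# the input shape.
def _position_ffn_example():
    ffn = PositionwiseFeedForward(hidden_size=8, feedforward_size=16)
    x = torch.randn(2, 5, 8)           # [batch_size x seq_length x hidden_size]
    assert ffn(x).size() == x.size()
    return ffn(x)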

View File

@ -0,0 +1,122 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
class RatMultiHeadedAttention(nn.Module):
"""
Each head is a self-attention operation.
self-attention refers to https://arxiv.org/pdf/1706.03762.pdf
"""
def __init__(self, hidden_size, heads_num, relationship_number):
super(RatMultiHeadedAttention, self).__init__()
self.hidden_size = hidden_size
self.heads_num = heads_num
self.per_head_size = hidden_size // heads_num
self.linear_layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size) for _ in range(3)
])
if relationship_number != -1:
assert relationship_number == self.per_head_size
self.relationship_K_parameter = nn.Parameter(torch.nn.init.uniform_(torch.Tensor(heads_num, self.per_head_size)))
self.relationship_V_parameter = nn.Parameter(torch.nn.init.uniform_(torch.Tensor(heads_num, self.per_head_size)))
# cross mutli-head
#self.relationship_K_parameter = nn.Parameter(torch.nn.init.uniform_(torch.Tensor(1, self.per_head_size)))
#self.relationship_V_parameter = nn.Parameter(torch.nn.init.uniform_(torch.Tensor(1, self.per_head_size)))
self.final_linear = nn.Linear(hidden_size, hidden_size)
def forward(self, key, value, query, relationship_matrix, dropout):
"""
Args:
key: [batch_size x seq_length x hidden_size]
value: [batch_size x seq_length x hidden_size]
query: [batch_size x seq_length x hidden_size]
relationship_matrix: [batch_size x seq_length x seq_length x relationship_number]
Returns:
output: [batch_size x seq_length x hidden_size]
"""
batch_size, seq_length, hidden_size = key.size()
heads_num = self.heads_num
per_head_size = self.per_head_size
def shape(x):
return x. \
contiguous(). \
view(batch_size, seq_length, heads_num, per_head_size). \
transpose(1, 2)
def unshape(x):
return x. \
transpose(1, 2). \
contiguous(). \
view(batch_size, seq_length, hidden_size)
# torch.Size([1, 30, 43, 10]) torch.Size([1, 30, 43, 10]) torch.Size([1, 30, 43, 10])
query, key, value = [l(x). \
view(batch_size, -1, heads_num, per_head_size). \
transpose(1, 2) \
for l, x in zip(self.linear_layers, (query, key, value))
]
#self.relationship_K_parameter [30, 10]
# relation ship torch.Size([1, 43, 43, 10])
relationship_matrix = relationship_matrix.unsqueeze(3).repeat(1, 1, 1, heads_num, 1)
# ↑ torch.Size([1, 43, 43, 30, 10])
kk = self.relationship_K_parameter.unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, seq_length, 1, 1)
#kk = self.relationship_K_parameter.unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, seq_length,1,1)
# ↑ torch.Size([1, 43, 43, 30, 10])
vv = self.relationship_V_parameter.unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, seq_length, 1, 1)
#vv = self.relationship_V_parameter.unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, seq_length,1,1)
# ↑ torch.Size([1, 43, 43, 30, 10])
relationship_K = relationship_matrix * kk
relationship_V = relationship_matrix * vv
#print(relationship_K[0,17,7,:,:])
#print(relationship_V[0,17,7,:,:])
# [1, 43, 43 ,30 ,10] -> [1, 30, 43, 43, 10]
relationship_K = relationship_K.transpose(1, 3)
relationship_V = relationship_V.transpose(1, 3)
#[1, 30, 43, 10] -> [1, 30, 43, 43, 10]
query = query.unsqueeze(3).repeat(1, 1, 1, seq_length, 1)
key = key.unsqueeze(2).repeat(1, 1, seq_length, 1, 1)
#↓ [1, 30, 43, 43]
#scores = (query*(relationship_K)).sum(dim=-1)
scores = (query*(key+relationship_K)).sum(dim=-1)
scores = scores / math.sqrt(float(per_head_size))
probs = nn.Softmax(dim=-1)(scores)
"""
save_dct = {}
for i in range(30):
attention_map = probs[0, i, :, :]
save_dct[i] = attention_map.cpu().numpy()
import pickle as pkl
#print(save_dct)
pkl.dump(save_dct, open("./save_att_relation.pkl", "wb"))
print("finish")
time.sleep(5)
"""
#probs = F.dropout(probs, p=dropout)
probs = probs.unsqueeze(4).repeat(1, 1, 1, 1, per_head_size)
value = value.unsqueeze(2).repeat(1, 1, seq_length, 1, 1)
output = (probs*(value+relationship_V)).sum(dim=-2)
output = unshape(output)
output = self.final_linear(output)
return output
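# Illustrative usage sketch (added for clarity; _rat_attention_example is a hypothetical
# helper, not part of the original module). The relation-aware variant additionally takes
# a relationship tensor of shape [batch_size x seq_length x seq_length x relationship_number],
# where relationship_number must equal hidden_size // heads_num (the per-head size).
def _rat_attention_example():
    hidden_size, heads_num = 30, 3                 # per-head size = 10
    attn = RatMultiHeadedAttention(hidden_size, heads_num, relationship_number=10)
    x = torch.randn(1, 4, hidden_size)             # [batch_size x seq_length x hidden_size]
    relations = torch.zeros(1, 4, 4, 10)           # all-zero: no relations between any pair
    out = attn(x, x, x, relations, dropout=0.0)
    assert out.size() == (1, 4, hidden_size)
    return out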

View File

@ -0,0 +1,47 @@
# -*- encoding:utf-8 -*-
import torch.nn as nn
import torch
import math
import torch.nn.functional as F
from .layer_norm import LayerNorm
from .position_ffn import PositionwiseFeedForward
from .rat_multi_headed_attn import RatMultiHeadedAttention
class RATTransoformer(nn.Module):
"""
Transformer layer mainly consists of two parts:
multi-headed self-attention and feed forward layer.
"""
def __init__(self, input_size, state_size, relationship_number):
super(RATTransoformer, self).__init__()
input_size = int(input_size)
assert input_size == state_size
heads_num = 30
self.self_attn = RatMultiHeadedAttention(state_size, heads_num, relationship_number)
#self.layer_norm_1 = LayerNorm(state_size)
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
#self.layer_norm_2 = LayerNorm(state_size)
def forward(self, hidden, relationship_matrix, dropout_amount=0):
"""
Args:
hidden: [batch_size x seq_length x emb_size]
Returns:
output: [batch_size x seq_length x hidden_size]
"""
assert dropout_amount == 0
batch_size = 1
seq_length = hidden.size()[1]
inter = self.self_attn(hidden, hidden, hidden, relationship_matrix, dropout_amount)
output = self.feed_forward(inter)
# inter = F.dropout(self.self_attn(hidden, hidden, hidden, relationship_matrix, dropout_amount), p=dropout_amount)
#inter = self.layer_norm_1(inter + hidden)
# output = F.dropout(self.feed_forward(inter), p=dropout_amount)
#output = self.layer_norm_2(output + inter)
return output

View File

@ -0,0 +1,182 @@
import torch.nn as nn
import torch
import math
import torch.nn.functional as F
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from .layer_norm import LayerNorm
from .position_ffn import PositionwiseFeedForward
from .multi_headed_attn import MultiHeadedAttention
from .rat_transformer_layer import RATTransoformer
class TransformerAttention(nn.Module):
def __init__(self, input_size, state_size):
super(TransformerAttention, self).__init__()
input_size = int(input_size)
assert input_size == state_size
self.relationship_number = 10
self.rat_transoformer = RATTransoformer(input_size, state_size, relationship_number=self.relationship_number)
def forward(self, utterance_states, schema_states, input_sequence, unflat_sequence, input_schema, schema, dropout_amount=0):
"""
Args:
utterance_states : [utterance_len x emb_size]
schema_states : [schema_len x emb_size]
input_sequence : utterance_len
input_schema : schema_len
Returns:
utterance_output: [emb_size x utterance_len]
schema_output: [emb_size x schema_len]
"""
assert utterance_states.size()[1] == 350
utterance_states = utterance_states[:,:300]
print('input_sequence', input_sequence)
#print('input_schema.column_names_surface_form', input_schema.column_names_surface_form)
if True:
#print('foreign_keys', schema['foreign_keys'])
#print('input_schema.column_names_surface_form', input_schema.column_names_surface_form)
utterance_len = utterance_states.size()[0]
schema_len = schema_states.size()[0]
#print(utterance_len, schema_len)
assert len(input_sequence) == utterance_len
assert len(input_schema.column_names_surface_form) == schema_len
relationship_matrix = torch.zeros([ utterance_len+schema_len, utterance_len+schema_len,self.relationship_number])
def get_name(s):
if s == '*':
return '', ''
table_name, column_num = s.split('.')
if column_num == '*':
column_num = ''
return table_name, column_num
def is_ngram_match(s1, s2, n):
vis = set()
for i in range(len(s1)-n+1):
vis.add( s1[i:i+n] )
for i in range(len(s2)-n+1):
if s2[i:i+n] in vis:
return True
return False
''' Relation channels (a standalone sketch of this tagging scheme is included at the end of this file):
0. same table
1. foreign key
2. word and table: exact match
3-5. word and table: n-gram 3-5
6. word and column: exact match
7-9. word and column: n-gram 3-5 '''
for i in range(schema_len):
table_i, column_i = get_name(input_schema.column_names_surface_form[i])
idx = i+utterance_len
for j in range(schema_len):
table_j, column_j = get_name(input_schema.column_names_surface_form[j])
jdx = j+utterance_len
if table_i == table_j:
relationship_matrix[idx][jdx][0] = 1
for i, j in schema['foreign_keys']:
idx = i+utterance_len
jdx = j+utterance_len
relationship_matrix[idx][jdx][1] = 1
relationship_matrix[jdx][idx][1] = 1
#print(idx, jdx)
#print(input_schema.column_names_surface_form)
#print(input_schema.column_names_surface_form[i], input_schema.column_names_surface_form[j])
for i in range(schema_len):
table, column = get_name(input_schema.column_names_surface_form[i])
idx = i+utterance_len
for j in range(utterance_len):
word = input_sequence[j].lower()
jdx = j
# word and table
no_match = True
if word == table:
relationship_matrix[idx][jdx][2] = 1
relationship_matrix[jdx][idx][2] = 1
for n in [5,4,3]:
if no_match and is_ngram_match(word, table, n):
relationship_matrix[idx][jdx][n] = 1
relationship_matrix[jdx][idx][n] = 1
# word and column
no_match = True
if word == column:
relationship_matrix[idx][jdx][6] = 1
relationship_matrix[jdx][idx][6] = 1
for n in [5,4,3]:
if no_match and is_ngram_match(word, column, n):
relationship_matrix[idx][jdx][4+n] = 1
relationship_matrix[jdx][idx][4+n] = 1
#print(relationship_matrix)
#print(unflat_sequence)
# start = 0
# for i in unflat_sequence:
# start = len(i) + start
# relationship_matrix[0:start][0:start][:] *= 0.98
relationship_matrix = relationship_matrix.unsqueeze(0).cuda() #[1, all_len, all_len, 17]
"""
print(relationship_matrix.shape)
matrix = torch.zeros_like(relationship_matrix[:,:,:,0]).cuda()
print(matrix.shape)
for i in range(relationship_matrix.shape[-1]):
matrix += relationship_matrix[:,:,:,i]
#matrix = matrix.squeeze(0).cpu().numpy()
label = input_sequence + input_schema.column_names_surface_form
matrix = matrix.squeeze(0).cpu().numpy()
save_matrix(matrix, label, 0)
"""
utterance_and_schema_states = torch.cat([utterance_states, schema_states], dim=0).unsqueeze(0) #[1, all_len, emb_size]
utterance_and_schema_output = self.rat_transoformer(utterance_and_schema_states, relationship_matrix, 0) #[1, all_len, emb_size]
#print(utterance_and_schema_output.shape)
utterance_output = utterance_and_schema_output[0,:utterance_len,:]
schema_output = utterance_and_schema_output[0,utterance_len:,:]
else:
utterance_output = utterance_states
schema_output = schema_states
utterance_output = torch.transpose(utterance_output, 1, 0)
schema_output = torch.transpose(schema_output, 1, 0)
return utterance_output, schema_output
def save_matrix(matrix, label, i):
"""
matrix: n x n
label: len(str) == n
"""
plt.rcParams["figure.figsize"] = (1000,1000)
plt.matshow(matrix, cmap=plt.cm.Reds)
plt.xticks(np.arange(len(label)), label, rotation=90, fontsize=16)
plt.yticks(np.arange(len(label)), label, fontsize=16)
plt.margins(2)
plt.subplots_adjust()
plt.savefig('relation_' + str(i), dpi=300)
exit()
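# Illustrative sketch (added for clarity; _ngram_match_sketch is a hypothetical helper,
# not part of the original module). It mirrors the relation tagging in
# TransformerAttention.forward: channel 0 marks columns of the same table, channel 1
# marks foreign-key pairs, channels 2 and 6 mark exact word/table and word/column
# matches, and channels 3-5 / 7-9 mark shared character n-grams (n = 3..5) between an
# utterance word and a table or column name.
def _ngram_match_sketch(word, column_name, n=3):
    # Same logic as the nested is_ngram_match helper above: any shared character n-gram.
    grams = {word[i:i + n] for i in range(len(word) - n + 1)}
    return any(column_name[i:i + n] in grams for i in range(len(column_name) - n + 1))
# e.g. _ngram_match_sketch("flights", "flight_id", n=5) is True because "fligh" is shared.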

View File

@ -0,0 +1,455 @@
# -*- encoding:utf-8 -*-
import torch.nn as nn
import torch
import math
import torch.nn.functional as F
from .layer_norm import LayerNorm
from .position_ffn import PositionwiseFeedForward
from .multi_headed_attn import MultiHeadedAttention
from ..encoder import Encoder as Encoder2
def gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class oldestTransformerLayer(nn.Module): #oldest
"""
Transformer layer mainly consists of two parts:
multi-headed self-attention and feed forward layer.
"""
def __init__(self, num_layers, input_size, state_size):
super(oldestTransformerLayer, self).__init__()
# num_layers is not used
print('input_size',input_size, type(input_size))
print('state_size',state_size, type(state_size))
input_size = int(input_size)
heads_num = -1
for i in range(4, 10):
if state_size % i == 0:
heads_num = i
assert heads_num != -1
self.first_linear = nn.Linear(input_size, state_size)
self.self_attn = MultiHeadedAttention(state_size, heads_num)
self.layer_norm_1 = LayerNorm(state_size)
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
self.layer_norm_2 = LayerNorm(state_size)
self._1_self_attn = MultiHeadedAttention(state_size, heads_num)
self._1_layer_norm_1 = LayerNorm(state_size)
self._1_feed_forward = PositionwiseFeedForward(state_size, state_size)
self._1_layer_norm_2 = LayerNorm(state_size)
self._2_self_attn = MultiHeadedAttention(state_size, heads_num)
self._2_layer_norm_1 = LayerNorm(state_size)
self._2_feed_forward = PositionwiseFeedForward(state_size, state_size)
self._2_layer_norm_2 = LayerNorm(state_size)
'''self._3_self_attn = MultiHeadedAttention(state_size, heads_num)
self._3_layer_norm_1 = LayerNorm(state_size)
self._3_feed_forward = PositionwiseFeedForward(state_size, state_size)
self._3_layer_norm_2 = LayerNorm(state_size)
self._4_self_attn = MultiHeadedAttention(state_size, heads_num)
self._4_layer_norm_1 = LayerNorm(state_size)
self._4_feed_forward = PositionwiseFeedForward(state_size, state_size)
self._4_layer_norm_2 = LayerNorm(state_size)'''
self.X = Encoder2(num_layers, input_size, state_size)
'''self.cell_memories_linear = nn.Linear(1, input_size)
self.hidden_states_linear = nn.Linear(1, input_size)'''
self.last_linear = nn.Linear(state_size*2, state_size)
def forward(self, sequence, embedder, dropout_amount=0):
"""
Args:
hidden: [batch_size x seq_length x emb_size]
mask: [batch_size x 1 x seq_length x seq_length]
Returns:
output: [batch_size x seq_length x hidden_size]
"""
batch_size = 1
seq_length = len(sequence)
hidden = []
for token in sequence:
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
'''hidden.append(self.cell_memories_linear( torch.ones([1,1,1]).cuda() ))
hidden.append(self.hidden_states_linear( torch.ones([1,1,1]).cuda() ))'''
hidden = torch.cat(hidden, 1)
mask = torch.zeros([batch_size, 1, seq_length, seq_length]).cuda()
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
inter = F.dropout(self._1_self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
inter = self._1_layer_norm_1(inter + hidden)
output = F.dropout(self._1_feed_forward(inter), p=dropout_amount)
hidden = self._1_layer_norm_2(output + inter)
inter = F.dropout(self._2_self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
inter = self._2_layer_norm_1(inter + hidden)
output = F.dropout(self._2_feed_forward(inter), p=dropout_amount)
hidden = self._2_layer_norm_2(output + inter)
'''inter = F.dropout(self._3_self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
inter = self._3_layer_norm_1(inter + hidden)
output = F.dropout(self._3_feed_forward(inter), p=dropout_amount)
hidden = self._3_layer_norm_2(output + inter)
inter = F.dropout(self._4_self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
inter = self._4_layer_norm_1(inter + hidden)
output = F.dropout(self._4_feed_forward(inter), p=dropout_amount)
hidden = self._4_layer_norm_2(output + inter)'''
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
final_outputs = []
#assert len(sequence) + 2 == output.size()[1]
for i in range(len(sequence)):
x = output[0,i,]
final_outputs.append(x)
#cell_memories = [final_outputs[-2]]
#hidden_states = [final_outputs[-1]]
#final_outputs = final_outputs[:-2]
#return (cell_memories, hidden_states), final_outputs
x, final_outputs2 = self.X(sequence, embedder, dropout_amount)
#print('<', final_outputs[0].mean().item(), final_outputs[0].std().item(), '>', '<', final_outputs2[0].mean().item(), final_outputs2[0].std().item(), '>')
for i in range(len(sequence)):
final_outputs[i] = F.dropout(gelu(self.last_linear(torch.cat( [final_outputs2[i], final_outputs[i]]))), p=dropout_amount)
return x, final_outputs
class NewNewTransformerLayer(nn.Module): #Newnew
"""
Transformer layer mainly consists of two parts:
multi-headed self-attention and feed forward layer.
"""
def __init__(self, num_layers, input_size, _state_size):
super(NewNewTransformerLayer, self).__init__()
# num_layers is not used
print('input_size', input_size, type(input_size))
print('_state_size', _state_size, type(_state_size))
if _state_size == 650:
state_size = 324
else:
state_size = _state_size//2
input_size = int(input_size)
heads_num = -1
for i in range(4, 10):
if state_size % i == 0:
heads_num = i
assert heads_num != -1
self.first_linear = nn.Linear(input_size, state_size)
self.self_attn = MultiHeadedAttention(state_size, heads_num)
self.layer_norm_1 = LayerNorm(state_size)
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
self.layer_norm_2 = LayerNorm(state_size)
self.cell_memories_linear = nn.Linear(state_size, state_size)
self.hidden_states_linear = nn.Linear(state_size, state_size)
self.X = Encoder2(num_layers, input_size, _state_size-state_size)
def forward(self, sequence, embedder, dropout_amount=0):
"""
Args:
hidden: [batch_size x seq_length x emb_size]
mask: [batch_size x 1 x seq_length x seq_length]
Returns:
output: [batch_size x seq_length x hidden_size]
"""
batch_size = 1
seq_length = len(sequence)
hidden = []
for token in sequence:
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
hidden = torch.cat(hidden, 1)
mask = torch.zeros([batch_size, 1, seq_length, seq_length]).cuda()
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
inter = self.layer_norm_1(inter + hidden)
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
output = self.layer_norm_2(output + inter)
final_outputs = []
assert len(sequence) == output.size()[1]
for i in range(len(sequence)):
x = output[0,i,]
final_outputs.append(x)
x = output[0].mean(dim=0)
cell_memories = self.cell_memories_linear(x)
cell_memories = [F.dropout(gelu(cell_memories), p=dropout_amount)]
hidden_states = self.hidden_states_linear(x)
hidden_states = [F.dropout(gelu(hidden_states), p=dropout_amount)]
'''print('hidden_states[0]', hidden_states[0].size())
print('cell_memories[0]', cell_memories[0].size())
print('final_outputs[0]', final_outputs[0].size())'''
x, final_outputs2 = self.X(sequence, embedder, dropout_amount)
cell_memories2 = x[0]
hidden_states2 = x[1]
for i in range(len(sequence)):
final_outputs[i] = torch.cat( [final_outputs[i], final_outputs2[i]])
cell_memories[0] = torch.cat( [cell_memories[0], cell_memories2[0]])
hidden_states[0] = torch.cat( [hidden_states[0], hidden_states2[0]])
'''print('hidden_states[0]', hidden_states[0].size())
print('cell_memories[0]', cell_memories[0].size())
print('final_outputs[0]', final_outputs[0].size())'''
return (cell_memories, hidden_states), final_outputs
class TransformerLayer(nn.Module): #New
"""
Transformer layer mainly consists of two parts:
multi-headed self-attention and feed forward layer.
"""
def __init__(self, num_layers, input_size, state_size):
super(TransformerLayer, self).__init__()
# num_layers is not used
print('input_size',input_size, type(input_size))
print('state_size',state_size, type(state_size))
input_size = int(input_size)
heads_num = -1
for i in range(4, 10):
if state_size % i == 0:
heads_num = i
assert heads_num != -1
self.first_linear = nn.Linear(input_size, state_size)
self.self_attn = MultiHeadedAttention(state_size, heads_num)
self.layer_norm_1 = LayerNorm(state_size)
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
self.layer_norm_2 = LayerNorm(state_size)
self.X = Encoder2(num_layers, input_size, state_size)
self.last_linear = nn.Linear(state_size*2, state_size)
def forward(self, sequence, embedder, dropout_amount=0):
"""
Args:
hidden: [batch_size x seq_length x emb_size]
mask: [batch_size x 1 x seq_length x seq_length]
Returns:
output: [batch_size x seq_length x hidden_size]
"""
batch_size = 1
seq_length = len(sequence)
hidden = []
for token in sequence:
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
hidden = torch.cat(hidden, 1)
mask = torch.zeros([batch_size, 1, seq_length, seq_length]).cuda()
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
inter = self.layer_norm_1(inter + hidden)
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
output = self.layer_norm_2(output + inter)
final_outputs = []
assert len(sequence) == output.size()[1]
for i in range(len(sequence)):
x = output[0,i,]
final_outputs.append(x)
x, final_outputs2 = self.X(sequence, embedder, dropout_amount)
for i in range(len(sequence)):
#final_outputs[i] = F.dropout(final_outputs2[i]+final_outputs[i], p=dropout_amount)
final_outputs[i] = F.dropout(gelu(self.last_linear(torch.cat( [final_outputs2[i], final_outputs[i]]))), p=dropout_amount)
return x, final_outputs
class Old2only2TransformerLayer(nn.Module): #Old2 2020 05 24 01.47
"""
Transformer layer mainly consists of two parts:
multi-headed self-attention and feed forward layer.
"""
def __init__(self, num_layers, input_size, state_size):
super(Old2only2TransformerLayer, self).__init__()
# num_layers is not used
print('input_size',input_size, type(input_size))
print('state_size',state_size, type(state_size))
input_size = int(input_size)
heads_num = -1
for i in range(4, 10):
if state_size % i == 0:
heads_num = i
assert heads_num != -1
self.first_linear = nn.Linear(input_size, state_size)
self.self_attn = MultiHeadedAttention(state_size, heads_num)
self.layer_norm_1 = LayerNorm(state_size)
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
self.layer_norm_2 = LayerNorm(state_size)
#self.cell_memories_linear = nn.Linear(state_size, state_size)
#self.hidden_states_linear = nn.Linear(state_size, state_size)
#self.last_linear = nn.Linear(state_size+state_size, state_size)
def forward(self, sequence, embedder, dropout_amount=0):
"""
Args:
hidden: [batch_size x seq_length x emb_size]
mask: [batch_size x 1 x seq_length x seq_length]
Returns:
output: [batch_size x seq_length x hidden_size]
"""
batch_size = 1
seq_length = len(sequence)
hidden = []
for token in sequence:
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
hidden = torch.cat(hidden, 1)
mask = torch.zeros([batch_size, 1, seq_length, seq_length]).cuda()
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
inter = self.layer_norm_1(inter + hidden)
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
output = self.layer_norm_2(output + inter)
final_outputs = []
assert len(sequence) == output.size()[1]
#output = F.dropout(gelu( self.last_linear(torch.cat([output, hidden], dim=-1)) ), p=dropout_amount)
for i in range(len(sequence)):
x = output[0,i,]
final_outputs.append(x)
#x = output[0].mean(dim=0)
#cell_memories = self.cell_memories_linear(x)
#cell_memories = [F.dropout(gelu(cell_memories), p=dropout_amount)]
#hidden_states = self.hidden_states_linear(x)
#hidden_states = [F.dropout(gelu(hidden_states), p=dropout_amount)]
#return (cell_memories, hidden_states), final_outputs
return None, final_outputs
class OldTransformerLayer(nn.Module):
"""
Transformer layer mainly consists of two parts:
multi-headed self-attention and feed forward layer.
"""
def __init__(self, num_layers, input_size, state_size):
super(OldTransformerLayer, self).__init__()
# num_layers is not used
print('input_size',input_size, type(input_size))
print('state_size',state_size, type(state_size))
input_size = int(input_size)
heads_num = -1
for i in range(4, 10):
if state_size % i == 0:
heads_num = i
assert heads_num != -1
self.first_linear = nn.Linear(input_size, state_size)
self.self_attn = MultiHeadedAttention(state_size, heads_num)
self.layer_norm_1 = LayerNorm(state_size)
self.feed_forward = PositionwiseFeedForward(state_size, state_size)
self.layer_norm_2 = LayerNorm(state_size)
self.cell_memories_linear = nn.Linear(1, input_size)
self.hidden_states_linear = nn.Linear(1, input_size)
def forward(self, sequence, embedder, dropout_amount=0):
"""
Args:
hidden: [batch_size x seq_length x emb_size]
mask: [batch_size x 1 x seq_length x seq_length]
Returns:
output: [batch_size x seq_length x hidden_size]
"""
batch_size = 1
seq_length = len(sequence)
hidden = []
for token in sequence:
hidden.append(embedder(token).unsqueeze(0).unsqueeze(0))
hidden.append(self.cell_memories_linear( torch.ones([1,1,1]).cuda() ))
hidden.append(self.hidden_states_linear( torch.ones([1,1,1]).cuda() ))
hidden = torch.cat(hidden, 1)
mask = torch.zeros([batch_size, 1, seq_length+2, seq_length+2]).cuda()
hidden = F.dropout(gelu(self.first_linear(hidden)), p=dropout_amount)
inter = F.dropout(self.self_attn(hidden, hidden, hidden, mask, dropout_amount), p=dropout_amount)
inter = self.layer_norm_1(inter + hidden)
output = F.dropout(self.feed_forward(inter), p=dropout_amount)
output = self.layer_norm_2(output + inter)
final_outputs = []
assert len(sequence) + 2 == output.size()[1]
for i in range(len(sequence)+2):
x = output[0,i,]
final_outputs.append(x)
cell_memories = [final_outputs[-2]]
hidden_states = [final_outputs[-1]]
final_outputs = final_outputs[:-2]
return (cell_memories, hidden_states), final_outputs

View File

@ -0,0 +1,398 @@
# modified from https://github.com/naver/sqlova
import os, json
import random as rd
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .bert import tokenization as tokenization
from .bert.modeling import BertConfig, BertModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_bert(params):
BERT_PT_PATH = './model/bert/data/annotated_wikisql_and_PyTorch_bert_param'
map_bert_type_abb = {'uS': 'uncased_L-12_H-768_A-12',
'uL': 'uncased_L-24_H-1024_A-16',
'cS': 'cased_L-12_H-768_A-12',
'cL': 'cased_L-24_H-1024_A-16',
'mcS': 'multi_cased_L-12_H-768_A-12'}
bert_type = map_bert_type_abb[params.bert_type_abb]
if params.bert_type_abb == 'cS' or params.bert_type_abb == 'cL' or params.bert_type_abb == 'mcS':
do_lower_case = False
else:
do_lower_case = True
no_pretraining = False
bert_config_file = os.path.join(BERT_PT_PATH, f'bert_config_{bert_type}.json')
vocab_file = os.path.join(BERT_PT_PATH, f'vocab_{bert_type}.txt')
init_checkpoint = os.path.join(BERT_PT_PATH, f'pytorch_model_{bert_type}.bin')
print('bert_config_file', bert_config_file)
print('vocab_file', vocab_file)
print('init_checkpoint', init_checkpoint)
bert_config = BertConfig.from_json_file(bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_file, do_lower_case=do_lower_case)
bert_config.print_status()
model_bert = BertModel(bert_config)
if no_pretraining:
pass
else:
model_bert.load_state_dict(torch.load(init_checkpoint, map_location='cpu'))
print("Load pre-trained parameters.")
model_bert.to(device)
return model_bert, tokenizer, bert_config
def generate_inputs(tokenizer, nlu1_tok, hds1):
tokens = []
segment_ids = []
t_to_tt_idx_hds1 = []
tokens.append("[CLS]")
i_st_nlu = len(tokens) # to use it later
segment_ids.append(0)
for token in nlu1_tok:
tokens.append(token)
segment_ids.append(0)
i_ed_nlu = len(tokens)
tokens.append("[SEP]")
segment_ids.append(0)
i_hds = []
for i, hds11 in enumerate(hds1):
i_st_hd = len(tokens)
t_to_tt_idx_hds11 = []
sub_tok = []
for sub_tok1 in hds11.split():
t_to_tt_idx_hds11.append(len(sub_tok))
sub_tok += tokenizer.tokenize(sub_tok1)
t_to_tt_idx_hds1.append(t_to_tt_idx_hds11)
tokens += sub_tok
i_ed_hd = len(tokens)
i_hds.append((i_st_hd, i_ed_hd))
segment_ids += [1] * len(sub_tok)
if i < len(hds1)-1:
tokens.append("[SEP]")
segment_ids.append(0)
elif i == len(hds1)-1:
tokens.append("[SEP]")
segment_ids.append(1)
else:
raise EnvironmentError
i_nlu = (i_st_nlu, i_ed_nlu)
return tokens, segment_ids, i_nlu, i_hds, t_to_tt_idx_hds1
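# Illustrative sketch (added for clarity; _FakeTokenizer and _generate_inputs_example are
# hypothetical stand-ins, not part of the original module). It shows the layout produced
# by generate_inputs: [CLS] question tokens [SEP] header-1 [SEP] header-2 [SEP] ..., with
# segment id 0 for the question part and 1 for the header tokens.
class _FakeTokenizer(object):
    def tokenize(self, text):
        return text.lower().split()
def _generate_inputs_example():
    tokens, segment_ids, i_nlu, i_hds, _ = generate_inputs(
        _FakeTokenizer(), ["show", "flights"], ["flight id", "airline name"])
    # tokens == ['[CLS]', 'show', 'flights', '[SEP]', 'flight', 'id', '[SEP]',
    #            'airline', 'name', '[SEP]']
    return tokens, segment_ids, i_nlu, i_hds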
def gen_l_hpu(i_hds):
"""
Treat columns as if they were a batch of natural language utterances, with batch size = (# of columns) * batch_size, e.g.
i_hds = [(17, 18), (19, 21), (22, 23), (24, 25), (26, 29), (30, 34)]
"""
l_hpu = []
for i_hds1 in i_hds:
for i_hds11 in i_hds1:
l_hpu.append(i_hds11[1] - i_hds11[0])
return l_hpu
def get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length):
"""
Here, the input is tokenized further by the WordPiece (WP) tokenizer and fed into BERT.
INPUT
:param model_bert:
:param tokenizer: WordPiece tokenizer
:param nlu_t: CoreNLP-tokenized natural language questions.
:param hds: Headers
:param max_seq_length: max input token length
OUTPUT
tokens: BERT input tokens
nlu_tt: WP-tokenized input natural language questions
orig_to_tok_index: map the index of 1st-level-token to the index of 2nd-level-token
tok_to_orig_index: inverse map.
"""
l_n = []
l_hs = [] # The length of columns for each batch
input_ids = []
tokens = []
segment_ids = []
input_mask = []
i_nlu = []  # index to retrieve the position of the contextual vector later.
i_hds = []
doc_tokens = []
nlu_tt = []
t_to_tt_idx = []
tt_to_t_idx = []
t_to_tt_idx_hds = []
for b, nlu_t1 in enumerate(nlu_t):
hds1 = hds[b]
l_hs.append(len(hds1))
# 1. 2nd tokenization using WordPiece
tt_to_t_idx1 = []  # tt_to_t_idx1[j] indicates which 1st-level token (here, CoreNLP) sub-token j belongs to.
t_to_tt_idx1 = []  # t_to_tt_idx1[i] = start index of the i-th 1st-level token in the sub-token list.
nlu_tt1 = []  # nlu_tt1[ t_to_tt_idx1[i] ] is the first sub-token segment of the i-th 1st-level token.
for (i, token) in enumerate(nlu_t1):
t_to_tt_idx1.append(
len(nlu_tt1))  # indicates the start position of the original whitespace token within nlu_tt1.
sub_tokens = tokenizer.tokenize(token)
for sub_token in sub_tokens:
tt_to_t_idx1.append(i)
nlu_tt1.append(sub_token) # all_doc_tokens are further tokenized using WordPiece tokenizer
nlu_tt.append(nlu_tt1)
tt_to_t_idx.append(tt_to_t_idx1)
t_to_tt_idx.append(t_to_tt_idx1)
l_n.append(len(nlu_tt1))
# [CLS] nlu [SEP] col1 [SEP] col2 [SEP] ...col-n [SEP]
# 2. Generate BERT inputs & indices.
tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, hds1)
assert len(t_to_tt_idx_hds1) == len(hds1)
t_to_tt_idx_hds.append(t_to_tt_idx_hds1)
input_ids1 = tokenizer.convert_tokens_to_ids(tokens1)
# Input masks
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask1 = [1] * len(input_ids1)
# 3. Zero-pad up to the sequence length.
if len(nlu_t) == 1:
max_seq_length = len(input_ids1)
while len(input_ids1) < max_seq_length:
input_ids1.append(0)
input_mask1.append(0)
segment_ids1.append(0)
assert len(input_ids1) == max_seq_length
assert len(input_mask1) == max_seq_length
assert len(segment_ids1) == max_seq_length
input_ids.append(input_ids1)
tokens.append(tokens1)
segment_ids.append(segment_ids1)
input_mask.append(input_mask1)
i_nlu.append(i_nlu1)
i_hds.append(i_hds1)
# Convert to tensor
all_input_ids = torch.tensor(input_ids, dtype=torch.long).to(device)
all_input_mask = torch.tensor(input_mask, dtype=torch.long).to(device)
all_segment_ids = torch.tensor(segment_ids, dtype=torch.long).to(device)
# 4. Generate BERT output.
all_encoder_layer, pooled_output = model_bert(all_input_ids, all_segment_ids, all_input_mask)
# 5. generate l_hpu from i_hds
l_hpu = gen_l_hpu(i_hds)
assert len(set(l_n)) == 1 and len(set(i_nlu)) == 1
assert l_n[0] == i_nlu[0][1] - i_nlu[0][0]
return all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \
l_n, l_hpu, l_hs, \
nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds
def get_wemb_n(i_nlu, l_n, hS, num_hidden_layers, all_encoder_layer, num_out_layers_n):
"""
Get the representation of each token.
"""
bS = len(l_n)
l_n_max = max(l_n)
# print('wemb_n: [bS, l_n_max, hS * num_out_layers_n] = ', bS, l_n_max, hS * num_out_layers_n)
wemb_n = torch.zeros([bS, l_n_max, hS * num_out_layers_n]).to(device)
for b in range(bS):
# [B, max_len, dim]
# Fill zero for non-exist part.
l_n1 = l_n[b]
i_nlu1 = i_nlu[b]
for i_noln in range(num_out_layers_n):
i_layer = num_hidden_layers - 1 - i_noln
st = i_noln * hS
ed = (i_noln + 1) * hS
wemb_n[b, 0:(i_nlu1[1] - i_nlu1[0]), st:ed] = all_encoder_layer[i_layer][b, i_nlu1[0]:i_nlu1[1], :]
return wemb_n
def get_wemb_h(i_hds, l_hpu, l_hs, hS, num_hidden_layers, all_encoder_layer, num_out_layers_h):
"""
As if
[ [table-1-col-1-tok1, t1-c1-t2, ...],
[t1-c2-t1, t1-c2-t2, ...].
...
[t2-c1-t1, ...,]
]
"""
bS = len(l_hs)
l_hpu_max = max(l_hpu)
num_of_all_hds = sum(l_hs)
wemb_h = torch.zeros([num_of_all_hds, l_hpu_max, hS * num_out_layers_h]).to(device)
# print('wemb_h: [num_of_all_hds, l_hpu_max, hS * num_out_layers_h] = ', wemb_h.size())
b_pu = -1
for b, i_hds1 in enumerate(i_hds):
for b1, i_hds11 in enumerate(i_hds1):
b_pu += 1
for i_nolh in range(num_out_layers_h):
i_layer = num_hidden_layers - 1 - i_nolh
st = i_nolh * hS
ed = (i_nolh + 1) * hS
wemb_h[b_pu, 0:(i_hds11[1] - i_hds11[0]), st:ed] \
= all_encoder_layer[i_layer][b, i_hds11[0]:i_hds11[1],:]
return wemb_h
def get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length, num_out_layers_n=1, num_out_layers_h=1):
# get contextual output of all tokens from bert
all_encoder_layer, pooled_output, tokens, i_nlu, i_hds,\
l_n, l_hpu, l_hs, \
nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds = get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length)
# all_encoder_layer: BERT outputs from all layers.
# pooled_output: output of [CLS] vec.
# tokens: BERT input tokens
# i_nlu: start and end indices of question in tokens
# i_hds: start and end indices of headers
# get the wemb
wemb_n = get_wemb_n(i_nlu, l_n, bert_config.hidden_size, bert_config.num_hidden_layers, all_encoder_layer,
num_out_layers_n)
wemb_h = get_wemb_h(i_hds, l_hpu, l_hs, bert_config.hidden_size, bert_config.num_hidden_layers, all_encoder_layer,
num_out_layers_h)
return wemb_n, wemb_h, l_n, l_hpu, l_hs, \
nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds
def prepare_input(tokenizer, input_sequence, input_schema, max_seq_length):
nlu_t = []
hds = []
nlu_t1 = input_sequence
all_hds = input_schema.column_names_embedder_input
nlu_tt1 = []
for (i, token) in enumerate(nlu_t1):
nlu_tt1 += tokenizer.tokenize(token)
current_hds1 = []
for hds1 in all_hds:
new_hds1 = current_hds1 + [hds1]
tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, new_hds1)
if len(segment_ids1) > max_seq_length:
nlu_t.append(nlu_t1)
hds.append(current_hds1)
current_hds1 = [hds1]
else:
current_hds1 = new_hds1
if len(current_hds1) > 0:
nlu_t.append(nlu_t1)
hds.append(current_hds1)
return nlu_t, hds
def prepare_input_v2(tokenizer, input_sequence, input_schema):
nlu_t = []
hds = []
max_seq_length = 0
nlu_t1 = input_sequence
all_hds = input_schema.column_names_embedder_input
nlu_tt1 = []
for (i, token) in enumerate(nlu_t1):
nlu_tt1 += tokenizer.tokenize(token)
current_hds1 = []
current_table = ''
for hds1 in all_hds:
hds1_table = hds1.split('.')[0].strip()
if hds1_table == current_table:
current_hds1.append(hds1)
else:
tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, current_hds1)
max_seq_length = max(max_seq_length, len(segment_ids1))
nlu_t.append(nlu_t1)
hds.append(current_hds1)
current_hds1 = [hds1]
current_table = hds1_table
if len(current_hds1) > 0:
tokens1, segment_ids1, i_nlu1, i_hds1, t_to_tt_idx_hds1 = generate_inputs(tokenizer, nlu_tt1, current_hds1)
max_seq_length = max(max_seq_length, len(segment_ids1))
nlu_t.append(nlu_t1)
hds.append(current_hds1)
return nlu_t, hds, max_seq_length
def get_bert_encoding(bert_config, model_bert, tokenizer, input_sequence, input_schema, bert_input_version='v1', max_seq_length=512, num_out_layers_n=1, num_out_layers_h=1):
if bert_input_version == 'v1':
nlu_t, hds = prepare_input(tokenizer, input_sequence, input_schema, max_seq_length)
elif bert_input_version == 'v2':
nlu_t, hds, max_seq_length = prepare_input_v2(tokenizer, input_sequence, input_schema)
wemb_n, wemb_h, l_n, l_hpu, l_hs, nlu_tt, t_to_tt_idx, tt_to_t_idx, t_to_tt_idx_hds = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length, num_out_layers_n, num_out_layers_h)
t_to_tt_idx = t_to_tt_idx[0]
assert len(t_to_tt_idx) == len(input_sequence)
assert sum(len(t_to_tt_idx_hds1) for t_to_tt_idx_hds1 in t_to_tt_idx_hds) == len(input_schema.column_names_embedder_input)
assert list(wemb_h.size())[0] == len(input_schema.column_names_embedder_input)
utterance_states = []
for i in range(len(t_to_tt_idx)):
start = t_to_tt_idx[i]
if i == len(t_to_tt_idx)-1:
end = l_n[0]
else:
end = t_to_tt_idx[i+1]
utterance_states.append(torch.mean(wemb_n[:,start:end,:], dim=[0,1]))
assert len(utterance_states) == len(input_sequence)
schema_token_states = []
cnt = -1
for t_to_tt_idx_hds1 in t_to_tt_idx_hds:
for t_to_tt_idx_hds11 in t_to_tt_idx_hds1:
cnt += 1
schema_token_states1 = []
for i in range(len(t_to_tt_idx_hds11)):
start = t_to_tt_idx_hds11[i]
if i == len(t_to_tt_idx_hds11)-1:
end = l_hpu[cnt]
else:
end = t_to_tt_idx_hds11[i+1]
schema_token_states1.append(torch.mean(wemb_h[cnt,start:end,:], dim=0))
assert len(schema_token_states1) == len(input_schema.column_names_embedder_input[cnt].split())
schema_token_states.append(schema_token_states1)
assert len(schema_token_states) == len(input_schema.column_names_embedder_input)
return utterance_states, schema_token_states

609
r2sql/cosql/model_util.py Normal file
View File

@ -0,0 +1,609 @@
"""Basic model training and evaluation functions."""
from enum import Enum
import random
import sys
import json
import progressbar
import model.torch_utils
import data_util.sql_util
import torch
def write_prediction(fileptr,
identifier,
input_seq,
probability,
prediction,
flat_prediction,
gold_query,
flat_gold_queries,
gold_tables,
index_in_interaction,
database_username,
database_password,
database_timeout,
compute_metrics=True,
beam=None):
pred_obj = {}
pred_obj["identifier"] = identifier
if len(identifier.split('/')) == 2:
database_id, interaction_id = identifier.split('/')
else:
database_id = 'atis'
interaction_id = identifier
pred_obj["database_id"] = database_id
pred_obj["interaction_id"] = interaction_id
pred_obj["input_seq"] = input_seq
pred_obj["probability"] = probability
pred_obj["prediction"] = prediction
pred_obj["flat_prediction"] = flat_prediction
pred_obj["gold_query"] = gold_query
pred_obj["flat_gold_queries"] = flat_gold_queries
pred_obj["index_in_interaction"] = index_in_interaction
pred_obj["gold_tables"] = str(gold_tables)
pred_obj["beam"] = beam
# Now compute the metrics we want.
if compute_metrics:
# First metric: whether flat predicted query is in the gold query set.
correct_string = " ".join(flat_prediction) in [
" ".join(q) for q in flat_gold_queries]
pred_obj["correct_string"] = correct_string
# Database metrics
if not correct_string:
syntactic, semantic, pred_table = sql_util.execution_results(
" ".join(flat_prediction), database_username, database_password, database_timeout)
pred_table = sorted(pred_table)
best_prec = 0.
best_rec = 0.
best_f1 = 0.
for gold_table in gold_tables:
num_overlap = float(len(set(pred_table) & set(gold_table)))
if len(set(gold_table)) > 0:
prec = num_overlap / len(set(gold_table))
else:
prec = 1.
if len(set(pred_table)) > 0:
rec = num_overlap / len(set(pred_table))
else:
rec = 1.
if prec > 0. and rec > 0.:
f1 = (2 * (prec * rec)) / (prec + rec)
else:
f1 = 1.
best_prec = max(best_prec, prec)
best_rec = max(best_rec, rec)
best_f1 = max(best_f1, f1)
else:
syntactic = True
semantic = True
pred_table = []
best_prec = 1.
best_rec = 1.
best_f1 = 1.
assert best_prec <= 1.
assert best_rec <= 1.
assert best_f1 <= 1.
pred_obj["syntactic"] = syntactic
pred_obj["semantic"] = semantic
correct_table = (pred_table in gold_tables) or correct_string
pred_obj["correct_table"] = correct_table
pred_obj["strict_correct_table"] = correct_table and syntactic
pred_obj["pred_table"] = str(pred_table)
pred_obj["table_prec"] = best_prec
pred_obj["table_rec"] = best_rec
pred_obj["table_f1"] = best_f1
fileptr.write(json.dumps(pred_obj) + "\n")
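# Illustrative sketch (added for clarity; _table_f1_sketch is a hypothetical helper, not
# part of the original module). It mirrors the table-overlap metric computed above: the
# overlap between predicted and gold result rows is divided by the gold and predicted
# set sizes respectively, and the best F1 over all gold tables is reported.
def _table_f1_sketch(pred_table, gold_table):
    num_overlap = float(len(set(pred_table) & set(gold_table)))
    prec = num_overlap / len(set(gold_table)) if gold_table else 1.
    rec = num_overlap / len(set(pred_table)) if pred_table else 1.
    return (2 * prec * rec) / (prec + rec) if prec > 0. and rec > 0. else 1.
# e.g. _table_f1_sketch([("a",), ("b",)], [("a",), ("c",)]) == 0.5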
class Metrics(Enum):
"""Definitions of simple metrics to compute."""
LOSS = 1
TOKEN_ACCURACY = 2
STRING_ACCURACY = 3
CORRECT_TABLES = 4
STRICT_CORRECT_TABLES = 5
SEMANTIC_QUERIES = 6
SYNTACTIC_QUERIES = 7
def get_progressbar(name, size):
"""Gets a progress bar object given a name and the total size.
Inputs:
name (str): The name to display on the side.
size (int): The maximum size of the progress bar.
"""
return progressbar.ProgressBar(maxval=size,
widgets=[name,
progressbar.Bar('=', '[', ']'),
' ',
progressbar.Percentage(),
' ',
progressbar.ETA()])
def train_epoch_with_utterances(batches,
model,
randomize=True):
"""Trains model for a single epoch given batches of utterance data.
Inputs:
batches (UtteranceBatch): The batches to give to training.
model (ATISModel): The model object.
randomize (bool): Whether or not to randomize the order that the batches are seen.
"""
if randomize:
random.shuffle(batches)
progbar = get_progressbar("train ", len(batches))
progbar.start()
loss_sum = 0.
for i, batch in enumerate(batches):
batch_loss = model.train_step(batch)
loss_sum += batch_loss
progbar.update(i)
progbar.finish()
total_loss = loss_sum / len(batches)
return total_loss
def train_epoch_with_interactions(interaction_batches,
params,
model,
randomize=True):
"""Trains model for single epoch given batches of interactions.
Inputs:
interaction_batches (list of InteractionBatch): The batches to train on.
params (namespace): Parameters to run with.
model (ATISModel): Model to train.
randomize (bool): Whether or not to randomize the order that batches are seen.
"""
if randomize:
random.shuffle(interaction_batches)
progbar = get_progressbar("train ", len(interaction_batches))
progbar.start()
loss_sum = 0.
break_lst = ["baseball_1", "soccer_1", "formula_1", "cre_Drama_Workshop_Groups", "sakila_1"]
#break_lst = ["baseball_1", "soccer_1", "formula_1", "cre_Drama_Workshop_Groups", "sakila_1", "behavior_monitoring"]
for i, interaction_batch in enumerate(interaction_batches):
assert len(interaction_batch) == 1
interaction = interaction_batch.items[0]
name = interaction.identifier.split('/')[0]
if name in break_lst:
continue
print(i, interaction.identifier)
# try:
batch_loss = model.train_step(interaction, params.train_maximum_sql_length)
# except RuntimeError:
# torch.cuda.empty_cache()
# continue
print(batch_loss)
loss_sum += batch_loss
torch.cuda.empty_cache()
progbar.update(i)
progbar.finish()
total_loss = loss_sum / len(interaction_batches)
return total_loss
def update_sums(metrics,
metrics_sums,
predicted_sequence,
flat_sequence,
gold_query,
original_gold_query,
gold_forcing=False,
loss=None,
token_accuracy=0.,
database_username="",
database_password="",
database_timeout=0,
gold_table=None):
"""" Updates summing for metrics in an aggregator.
TODO: don't use sums, just keep the raw value.
"""
if Metrics.LOSS in metrics:
metrics_sums[Metrics.LOSS] += loss.item()
if Metrics.TOKEN_ACCURACY in metrics:
if gold_forcing:
metrics_sums[Metrics.TOKEN_ACCURACY] += token_accuracy
else:
num_tokens_correct = 0.
for j, token in enumerate(gold_query):
if len(
predicted_sequence) > j and predicted_sequence[j] == token:
num_tokens_correct += 1
metrics_sums[Metrics.TOKEN_ACCURACY] += num_tokens_correct / \
len(gold_query)
if Metrics.STRING_ACCURACY in metrics:
metrics_sums[Metrics.STRING_ACCURACY] += int(
flat_sequence == original_gold_query)
if Metrics.CORRECT_TABLES in metrics:
assert database_username, "You did not provide a database username"
assert database_password, "You did not provide a database password"
assert database_timeout > 0, "Database timeout is 0 seconds"
# Evaluate SQL
if flat_sequence != original_gold_query:
syntactic, semantic, table = sql_util.execution_results(
" ".join(flat_sequence), database_username, database_password, database_timeout)
else:
syntactic = True
semantic = True
table = gold_table
metrics_sums[Metrics.CORRECT_TABLES] += int(table == gold_table)
if Metrics.SYNTACTIC_QUERIES in metrics:
metrics_sums[Metrics.SYNTACTIC_QUERIES] += int(syntactic)
if Metrics.SEMANTIC_QUERIES in metrics:
metrics_sums[Metrics.SEMANTIC_QUERIES] += int(semantic)
if Metrics.STRICT_CORRECT_TABLES in metrics:
metrics_sums[Metrics.STRICT_CORRECT_TABLES] += int(
table == gold_table and syntactic)
def construct_averages(metrics_sums, total_num):
""" Computes the averages for metrics.
Inputs:
metrics_sums (dict Metric -> float): Sums for a metric.
total_num (int): Number to divide by (average).
"""
metrics_averages = {}
for metric, value in metrics_sums.items():
metrics_averages[metric] = value / total_num
if metric != "loss":
metrics_averages[metric] *= 100.
return metrics_averages
def evaluate_utterance_sample(sample,
model,
max_generation_length,
name="",
gold_forcing=False,
metrics=None,
total_num=-1,
database_username="",
database_password="",
database_timeout=0,
write_results=False):
"""Evaluates a sample of utterance examples.
Inputs:
sample (list of Utterance): Examples to evaluate.
model (ATISModel): Model to predict with.
max_generation_length (int): Maximum length to generate.
name (str): Name to log with.
gold_forcing (bool): Whether to force the gold tokens during decoding.
metrics (list of Metric): Metrics to evaluate with.
total_num (int): Number to divide by when reporting results.
database_username (str): Username to use for executing queries.
database_password (str): Password to use when executing queries.
database_timeout (float): Timeout on queries when executing.
write_results (bool): Whether to write the results to a file.
"""
assert metrics
if total_num < 0:
total_num = len(sample)
metrics_sums = {}
for metric in metrics:
metrics_sums[metric] = 0.
predictions_file = open(name + "_predictions.json", "w")
print("Predicting with filename " + str(name) + "_predictions.json")
progbar = get_progressbar(name, len(sample))
progbar.start()
predictions = []
for i, item in enumerate(sample):
_, loss, predicted_seq = model.eval_step(
item, max_generation_length, feed_gold_query=gold_forcing)
loss = loss / len(item.gold_query())
predictions.append(predicted_seq)
flat_sequence = item.flatten_sequence(predicted_seq)
token_accuracy = torch_utils.per_token_accuracy(
item.gold_query(), predicted_seq)
if write_results:
write_prediction(
predictions_file,
identifier=item.interaction.identifier,
input_seq=item.input_sequence(),
probability=0,
prediction=predicted_seq,
flat_prediction=flat_sequence,
gold_query=item.gold_query(),
flat_gold_queries=item.original_gold_queries(),
gold_tables=item.gold_tables(),
index_in_interaction=item.utterance_index,
database_username=database_username,
database_password=database_password,
database_timeout=database_timeout)
update_sums(metrics,
metrics_sums,
predicted_seq,
flat_sequence,
item.gold_query(),
item.original_gold_queries()[0],
gold_forcing,
loss,
token_accuracy,
database_username=database_username,
database_password=database_password,
database_timeout=database_timeout,
gold_table=item.gold_tables()[0])
progbar.update(i)
progbar.finish()
predictions_file.close()
return construct_averages(metrics_sums, total_num), None
def evaluate_interaction_sample(sample,
model,
max_generation_length,
name="",
gold_forcing=False,
metrics=None,
total_num=-1,
database_username="",
database_password="",
database_timeout=0,
use_predicted_queries=False,
write_results=False,
use_gpu=False,
compute_metrics=False):
""" Evaluates a sample of interactions. """
predictions_file = open(name + "_predictions.json", "w")
print("Predicting with file " + str(name + "_predictions.json"))
metrics_sums = {}
for metric in metrics:
metrics_sums[metric] = 0.
progbar = get_progressbar(name, len(sample))
progbar.start()
num_utterances = 0
predictions = []
use_gpu = not ("--no_gpus" in sys.argv or "--no_gpus=1" in sys.argv)
model.eval()
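# Interactions from these databases are skipped during evaluation (presumably because
# their large schemas are too slow or memory-hungry to decode; this is an assumption).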
break_lst = ["baseball_1", "soccer_1", "formula_1", "cre_Drama_Workshop_Groups", "sakila_1"]
for i, interaction in enumerate(sample):
name = interaction.identifier.split('/')[0]
if name in break_lst:
continue
try:
with torch.no_grad():
if use_predicted_queries:
example_preds = model.predict_with_predicted_queries(
interaction,
max_generation_length)
else:
example_preds = model.predict_with_gold_queries(
interaction,
max_generation_length,
feed_gold_query=gold_forcing)
torch.cuda.empty_cache()
except RuntimeError as exception:
print("Failed on interaction: " + str(interaction.identifier))
print(exception)
print("\n\n")
exit()
predictions.extend(example_preds)
assert len(example_preds) == len(
interaction.interaction.utterances) or not example_preds
for j, pred in enumerate(example_preds):
num_utterances += 1
sequence, loss, token_accuracy, _, decoder_results = pred
if use_predicted_queries:
item = interaction.processed_utterances[j]
original_utt = interaction.interaction.utterances[item.index]
gold_query = original_utt.gold_query_to_use
original_gold_query = original_utt.original_gold_query
gold_table = original_utt.gold_sql_results
gold_queries = [q[0] for q in original_utt.all_gold_queries]
gold_tables = [q[1] for q in original_utt.all_gold_queries]
index = item.index
else:
item = interaction.gold_utterances()[j]
gold_query = item.gold_query()
original_gold_query = item.original_gold_query()
gold_table = item.gold_table()
gold_queries = item.original_gold_queries()
gold_tables = item.gold_tables()
index = item.utterance_index
if loss:
loss = loss / len(gold_query)
flat_sequence = item.flatten_sequence(sequence)
if write_results:
ori_beam = decoder_results.beam
beam = []
for x in ori_beam:
beam.append((-x[0], item.flatten_sequence(x[1].sequence)))
beam.sort()
write_prediction(
predictions_file,
identifier=interaction.identifier,
input_seq=item.input_sequence(),
probability=decoder_results.probability,
prediction=sequence,
flat_prediction=flat_sequence,
gold_query=gold_query,
flat_gold_queries=gold_queries,
gold_tables=gold_tables,
index_in_interaction=index,
database_username=database_username,
database_password=database_password,
database_timeout=database_timeout,
compute_metrics=compute_metrics,
beam=beam)
update_sums(metrics,
metrics_sums,
sequence,
flat_sequence,
gold_query,
original_gold_query,
gold_forcing,
loss,
token_accuracy,
database_username=database_username,
database_password=database_password,
database_timeout=database_timeout,
gold_table=gold_table)
progbar.update(i)
progbar.finish()
if total_num < 0:
total_num = num_utterances
predictions_file.close()
return construct_averages(metrics_sums, total_num), predictions
def evaluate_using_predicted_queries(sample,
model,
name="",
gold_forcing=False,
metrics=None,
total_num=-1,
database_username="",
database_password="",
database_timeout=0,
snippet_keep_age=1):
predictions_file = open(name + "_predictions.json", "w")
print("Predicting with file " + str(name + "_predictions.json"))
assert not gold_forcing
metrics_sums = {}
for metric in metrics:
metrics_sums[metric] = 0.
progbar = get_progressbar(name, len(sample))
progbar.start()
num_utterances = 0
predictions = []
for i, item in enumerate(sample):
int_predictions = []
item.start_interaction()
while not item.done():
utterance = item.next_utterance(snippet_keep_age)
predicted_sequence, loss, _, probability = model.eval_step(
utterance)
int_predictions.append((utterance, predicted_sequence))
flat_sequence = utterance.flatten_sequence(predicted_sequence)
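# Keep the predicted query only when it actually executes against the database and the
# decoder probability clears the fixed 0.24 confidence threshold used below.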
if sql_util.executable(
flat_sequence,
username=database_username,
password=database_password,
timeout=database_timeout) and probability >= 0.24:
utterance.set_pred_query(
item.remove_snippets(predicted_sequence))
item.add_utterance(utterance,
item.remove_snippets(predicted_sequence),
previous_snippets=utterance.snippets())
else:
# The prediction was not executable (or low-confidence); fall back to an empty
# query so later turns do not condition on an invalid prediction.
seq = []
utterance.set_pred_query(seq)
item.add_utterance(
utterance, seq, previous_snippets=utterance.snippets())
original_utt = item.interaction.utterances[utterance.index]
write_prediction(
predictions_file,
identifier=item.interaction.identifier,
input_seq=utterance.input_sequence(),
probability=probability,
prediction=predicted_sequence,
flat_prediction=flat_sequence,
gold_query=original_utt.gold_query_to_use,
flat_gold_queries=[
q[0] for q in original_utt.all_gold_queries],
gold_tables=[
q[1] for q in original_utt.all_gold_queries],
index_in_interaction=utterance.index,
database_username=database_username,
database_password=database_password,
database_timeout=database_timeout)
update_sums(metrics,
metrics_sums,
predicted_sequence,
flat_sequence,
original_utt.gold_query_to_use,
original_utt.original_gold_query,
gold_forcing,
loss,
token_accuracy=0,
database_username=database_username,
database_password=database_password,
database_timeout=database_timeout,
gold_table=original_utt.gold_sql_results)
predictions.append(int_predictions)
progbar.update(i)
progbar.finish()
if total_num < 0:
total_num = num_utterances
predictions_file.close()
return construct_averages(metrics_sums, total_num), predictions

3310
r2sql/cosql/output_temp.txt Normal file

File diff suppressed because one or more lines are too long

167
r2sql/cosql/parse_args.py Normal file
View File

@ -0,0 +1,167 @@
import sys
args = sys.argv
import os
import argparse
def interpret_args():
""" Interprets the command line arguments, and returns a dictionary. """
parser = argparse.ArgumentParser()
parser.add_argument("--no_gpus", type=bool, default=1)
parser.add_argument("--seed", type=int, default=141)
### Data parameters
parser.add_argument(
'--raw_train_filename',
type=str,
default='../atis_data/data/resplit/processed/train_with_tables.pkl')
parser.add_argument(
'--raw_dev_filename',
type=str,
default='../atis_data/data/resplit/processed/dev_with_tables.pkl')
parser.add_argument(
'--raw_validation_filename',
type=str,
default='../atis_data/data/resplit/processed/valid_with_tables.pkl')
parser.add_argument(
'--raw_test_filename',
type=str,
default='../atis_data/data/resplit/processed/test_with_tables.pkl')
parser.add_argument('--data_directory', type=str, default='processed_data')
parser.add_argument('--processed_train_filename', type=str, default='train.pkl')
parser.add_argument('--processed_dev_filename', type=str, default='dev.pkl')
parser.add_argument('--processed_validation_filename', type=str, default='validation.pkl')
parser.add_argument('--processed_test_filename', type=str, default='test.pkl')
parser.add_argument('--database_schema_filename', type=str, default=None)
parser.add_argument('--embedding_filename', type=str, default=None)
parser.add_argument('--input_vocabulary_filename', type=str, default='input_vocabulary.pkl')
parser.add_argument('--output_vocabulary_filename',
type=str,
default='output_vocabulary.pkl')
parser.add_argument('--input_key', type=str, default='nl_with_dates')
parser.add_argument('--anonymize', type=bool, default=False)
parser.add_argument('--anonymization_scoring', type=bool, default=False)
parser.add_argument('--use_snippets', type=bool, default=False)
parser.add_argument('--use_previous_query', type=bool, default=False)
parser.add_argument('--maximum_queries', type=int, default=1)
parser.add_argument('--use_copy_switch', type=bool, default=False)
parser.add_argument('--use_query_attention', type=bool, default=False)
parser.add_argument('--use_utterance_attention', type=bool, default=False)
parser.add_argument('--freeze', type=bool, default=False)
parser.add_argument('--scheduler', type=bool, default=False)
parser.add_argument('--use_bert', type=bool, default=False)
parser.add_argument("--bert_type_abb", type=str, help="Type of BERT model to load. e.g.) uS, uL, cS, cL, and mcS")
parser.add_argument("--bert_input_version", type=str, default='v1')
parser.add_argument('--fine_tune_bert', type=bool, default=False)
parser.add_argument('--lr_bert', default=1e-5, type=float, help='BERT model learning rate.')
### Debugging/logging parameters
parser.add_argument('--logdir', type=str, default='logs')
parser.add_argument('--deterministic', type=bool, default=False)
parser.add_argument('--num_train', type=int, default=-1)
parser.add_argument('--logfile', type=str, default='log.txt')
parser.add_argument('--results_file', type=str, default='results.txt')
### Model architecture
parser.add_argument('--input_embedding_size', type=int, default=300)
parser.add_argument('--output_embedding_size', type=int, default=300)
parser.add_argument('--encoder_state_size', type=int, default=300)
parser.add_argument('--decoder_state_size', type=int, default=300)
parser.add_argument('--encoder_num_layers', type=int, default=1)
parser.add_argument('--decoder_num_layers', type=int, default=2)
parser.add_argument('--snippet_num_layers', type=int, default=1)
parser.add_argument('--maximum_utterances', type=int, default=5)
parser.add_argument('--state_positional_embeddings', type=bool, default=False)
parser.add_argument('--positional_embedding_size', type=int, default=50)
parser.add_argument('--snippet_age_embedding', type=bool, default=False)
parser.add_argument('--snippet_age_embedding_size', type=int, default=64)
parser.add_argument('--max_snippet_age_embedding', type=int, default=4)
parser.add_argument('--previous_decoder_snippet_encoding', type=bool, default=False)
parser.add_argument('--discourse_level_lstm', type=bool, default=False)
parser.add_argument('--use_schema_attention', type=bool, default=False)
parser.add_argument('--use_encoder_attention', type=bool, default=False)
parser.add_argument('--use_schema_encoder', type=bool, default=False)
parser.add_argument('--use_schema_self_attention', type=bool, default=False)
parser.add_argument('--use_schema_encoder_2', type=bool, default=False)
### Training parameters
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--train_maximum_sql_length', type=int, default=200)
parser.add_argument('--train_evaluation_size', type=int, default=100)
parser.add_argument('--dropout_amount', type=float, default=0.5)
parser.add_argument('--initial_patience', type=float, default=10.)
parser.add_argument('--patience_ratio', type=float, default=1.01)
parser.add_argument('--initial_learning_rate', type=float, default=0.001)
parser.add_argument('--learning_rate_ratio', type=float, default=0.8)
parser.add_argument('--interaction_level', type=bool, default=False)
parser.add_argument('--reweight_batch', type=bool, default=False)
### Setting
parser.add_argument('--train', type=bool, default=False)
parser.add_argument('--debug', type=bool, default=False)
parser.add_argument('--evaluate', type=bool, default=False)
parser.add_argument('--attention', type=bool, default=False)
parser.add_argument('--save_file', type=str, default="")
parser.add_argument('--enable_testing', type=bool, default=False)
parser.add_argument('--use_predicted_queries', type=bool, default=False)
parser.add_argument('--evaluate_split', type=str, default='dev')
parser.add_argument('--evaluate_with_gold_forcing', type=bool, default=False)
parser.add_argument('--eval_maximum_sql_length', type=int, default=1000)
parser.add_argument('--results_note', type=str, default='')
parser.add_argument('--compute_metrics', type=bool, default=False)
parser.add_argument('--reference_results', type=str, default='')
parser.add_argument('--interactive', type=bool, default=False)
parser.add_argument('--database_username', type=str, default="aviarmy")
parser.add_argument('--database_password', type=str, default="aviarmy")
parser.add_argument('--database_timeout', type=int, default=2)
parser.add_argument('--use_transformer_relation', type=bool, default=True)
parser.add_argument('--table_schema_path', type=str, default="data/cosql/tables.json")
## replace the path above with the line below when running on Colab
# parser.add_argument('--table_schema_path', type=str, default="data/table.json")
args = parser.parse_args()
if not os.path.exists(args.logdir):
os.makedirs(args.logdir)
if not (args.train or args.evaluate or args.interactive or args.attention):
raise ValueError('You need to be training or evaluating')
if args.enable_testing and not args.evaluate:
raise ValueError('You should evaluate the model if enabling testing')
if args.train:
args_file = args.logdir + '/args.log'
if os.path.exists(args_file):
raise ValueError('Warning: arguments already exist in ' + str(args_file))
with open(args_file, 'w') as infile:
infile.write(str(args))
return args

View File

@ -0,0 +1,614 @@
import os
import json
import copy
import torch
import sqlparse
import argparse
import subprocess
import traceback  # used by the error handler in gen_from
from collections import defaultdict
from reranker.predict import Ranker
import numpy
numpy.random.seed()
torch.manual_seed(1)
#from eval_scripts.process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
stop_word = set({'=', 'is', 'and', 'desc', '!=', 'select', ')', '_UNK', 'max', '0', 'none', 'order', 'like', 'join', '>', 'on', 'count', 'or', 'from', 'group_by', 'limit', '<=', 'sum', 'group', '(', '_EOS', 'intersect', 'order_by', 'limit_value', '-', 'except', 'where', 'distinct', 'avg', '<', 'min', 'asc', '+', 'in', 'union', 'between', 'exists', '/', 'having', 'not', 'as', '>=', 'value', ',', '*'})
#ranker = Ranker("./cosql_train_48.pt")
ranker = Ranker("./submit_models/cosql_reranker_roberta.pt")
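# Depth-first search over the foreign-key graph: returns the list of
# (table, (col_a, col_b)) hops leading from `start` to `end`, or None if no path exists.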
def find_shortest_path(start, end, graph):
stack = [[start, []]]
visited = list()
while len(stack) > 0:
ele, history = stack.pop()
if ele == end:
return history
for node in graph[ele]:
if node[0] not in visited:
stack.append((node[0], history + [(node[0], node[1])]))
visited.append(node[0])
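# Build a FROM clause for the candidate tables: a single table is used directly, while
# multiple tables are joined along foreign-key paths found by find_shortest_path and
# given T1, T2, ... aliases. Returns (table_alias_dict, from_clause_string).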
def gen_from(candidate_tables, schema):
if len(candidate_tables) <= 1:
if len(candidate_tables) == 1:
ret = "from {}".format(schema["table_names_original"][list(candidate_tables)[0]])
else:
ret = "from {}".format(schema["table_names_original"][0])
return {}, ret
table_alias_dict = {}
uf_dict = {}
for t in candidate_tables:
uf_dict[t] = -1
idx = 1
graph = defaultdict(list)
for acol, bcol in schema["foreign_keys"]:
t1 = schema["column_names"][acol][0]
t2 = schema["column_names"][bcol][0]
graph[t1].append((t2, (acol, bcol)))
graph[t2].append((t1, (bcol, acol)))
candidate_tables = list(candidate_tables)
start = candidate_tables[0]
table_alias_dict[start] = idx
idx += 1
ret = "from {} as T1".format(schema["table_names_original"][start])
try:
for end in candidate_tables[1:]:
if end in table_alias_dict:
continue
path = find_shortest_path(start, end, graph)
prev_table = start
if not path:
table_alias_dict[end] = idx
idx += 1
ret = "{} join {} as T{}".format(ret, schema["table_names_original"][end],
table_alias_dict[end],
)
continue
for node, (acol, bcol) in path:
if node in table_alias_dict:
prev_table = node
continue
table_alias_dict[node] = idx
idx += 1
ret = "{} join {} as T{} on T{}.{} = T{}.{}".format(ret, schema["table_names_original"][node],
table_alias_dict[node],
table_alias_dict[prev_table],
schema["column_names_original"][acol][1],
table_alias_dict[node],
schema["column_names_original"][bcol][1])
prev_table = node
except:
traceback.print_exc()
print("db:{}".format(schema["db_id"]))
return table_alias_dict, ret
return table_alias_dict, ret
def normalize_space(format_sql):
format_sql_1 = [' '.join(sub_sql.strip().replace(',',' , ').replace('(',' ( ').replace(')',' ) ').split())
for sub_sql in format_sql.split('\n')]
format_sql_1 = '\n'.join(format_sql_1)
format_sql_2 = format_sql_1.replace('\njoin',' join').replace(',\n',', ').replace(' where','\nwhere').replace(' intersect', '\nintersect').replace('union ', 'union\n').replace('\nand',' and').replace('order by t2 .\nstart desc', 'order by t2 . start desc')
return format_sql_2
def get_candidate_tables(format_sql, schema):
candidate_tables = []
tokens = format_sql.split()
for ii, token in enumerate(tokens):
if '.' in token:
table_name = token.split('.')[0]
candidate_tables.append(table_name)
candidate_tables = list(set(candidate_tables))
candidate_tables.sort()
table_names_original = [table_name.lower() for table_name in schema['table_names_original']]
candidate_tables_id = [table_names_original.index(table_name) for table_name in candidate_tables]
assert -1 not in candidate_tables_id
table_names_original = schema['table_names_original']
return candidate_tables_id, table_names_original
def get_surface_form_orig(format_sql_2, schema):
column_names_surface_form = []
column_names_surface_form_original = []
column_names_original = schema['column_names_original']
table_names_original = schema['table_names_original']
for i, (table_id, column_name) in enumerate(column_names_original):
if table_id >= 0:
table_name = table_names_original[table_id]
column_name_surface_form = '{}.{}'.format(table_name,column_name)
else:
# this is just *
column_name_surface_form = column_name
column_names_surface_form.append(column_name_surface_form.lower())
column_names_surface_form_original.append(column_name_surface_form)
# also add table_name.*
for table_name in table_names_original:
column_names_surface_form.append('{}.*'.format(table_name.lower()))
column_names_surface_form_original.append('{}.*'.format(table_name))
assert len(column_names_surface_form) == len(column_names_surface_form_original)
for surface_form, surface_form_original in zip(column_names_surface_form, column_names_surface_form_original):
format_sql_2 = format_sql_2.replace(surface_form, surface_form_original)
return format_sql_2
def add_from_clase(sub_sql, from_clause):
select_right_sub_sql = []
left_sub_sql = []
left = True
num_left_parathesis = 0 # in select_right_sub_sql
num_right_parathesis = 0 # in select_right_sub_sql
tokens = sub_sql.split()
for ii, token in enumerate(tokens):
if token == 'select':
left = False
if left:
left_sub_sql.append(token)
continue
select_right_sub_sql.append(token)
if token == '(':
num_left_parathesis += 1
elif token == ')':
num_right_parathesis += 1
def remove_missing_tables_from_select(select_statement):
tokens = select_statement.split(',')
stop_idx = -1
for i in range(len(tokens)):
idx = len(tokens) - 1 - i
token = tokens[idx]
if '.*' in token and 'count ' not in token:
pass
else:
stop_idx = idx+1
break
if stop_idx > 0:
new_select_statement = ','.join(tokens[:stop_idx]).strip()
else:
new_select_statement = select_statement
return new_select_statement
if num_left_parathesis == num_right_parathesis or num_left_parathesis > num_right_parathesis:
sub_sqls = []
sub_sqls.append(remove_missing_tables_from_select(sub_sql))
sub_sqls.append(from_clause)
else:
assert num_left_parathesis < num_right_parathesis
select_sub_sql = []
right_sub_sql = []
for i in range(len(select_right_sub_sql)):
token_idx = len(select_right_sub_sql)-1-i
token = select_right_sub_sql[token_idx]
if token == ')':
num_right_parathesis -= 1
if num_right_parathesis == num_left_parathesis:
select_sub_sql = select_right_sub_sql[:token_idx]
right_sub_sql = select_right_sub_sql[token_idx:]
break
sub_sqls = []
if len(left_sub_sql) > 0:
sub_sqls.append(' '.join(left_sub_sql))
if len(select_sub_sql) > 0:
new_select_statement = remove_missing_tables_from_select(' '.join(select_sub_sql))
sub_sqls.append(new_select_statement)
sub_sqls.append(from_clause)
if len(right_sub_sql) > 0:
sub_sqls.append(' '.join(right_sub_sql))
return sub_sqls
def postprocess_single(format_sql_2, schema, start_alias_id=0):
candidate_tables_id, table_names_original = get_candidate_tables(format_sql_2, schema)
format_sql_2 = get_surface_form_orig(format_sql_2, schema)
if len(candidate_tables_id) == 0:
final_sql = format_sql_2.replace('\n', ' ')
elif len(candidate_tables_id) == 1:
# easy case
table_name = table_names_original[candidate_tables_id[0]]
from_clause = 'from {}'.format(table_name)
format_sql_3 = []
for sub_sql in format_sql_2.split('\n'):
if 'select' in sub_sql:
format_sql_3 += add_from_clase(sub_sql, from_clause)
else:
format_sql_3.append(sub_sql)
final_sql = ' '.join(format_sql_3).replace('{}.'.format(table_name), '')
else:
# more than 1 candidate_tables
table_alias_dict, ret = gen_from(candidate_tables_id, schema)
from_clause = ret
for i in range(len(table_alias_dict)):
from_clause = from_clause.replace('T{}'.format(i+1), 'T{}'.format(i+1+start_alias_id))
table_name_to_alias = {}
for table_id, alias_id in table_alias_dict.items():
table_name = table_names_original[table_id]
alias = 'T{}'.format(alias_id+start_alias_id)
table_name_to_alias[table_name] = alias
start_alias_id = start_alias_id + len(table_alias_dict)
format_sql_3 = []
for sub_sql in format_sql_2.split('\n'):
if 'select' in sub_sql:
format_sql_3 += add_from_clase(sub_sql, from_clause)
else:
format_sql_3.append(sub_sql)
format_sql_3 = ' '.join(format_sql_3)
for table_name, alias in table_name_to_alias.items():
format_sql_3 = format_sql_3.replace('{}.'.format(table_name), '{}.'.format(alias))
final_sql = format_sql_3
for i in range(5):
final_sql = final_sql.replace('select count ( T{}.* ) '.format(i), 'select count ( * ) ')
final_sql = final_sql.replace('count ( T{}.* ) from '.format(i), 'count ( * ) from ')
final_sql = final_sql.replace('order by count ( T{}.* ) '.format(i), 'order by count ( * ) ')
final_sql = final_sql.replace('having count ( T{}.* ) '.format(i), 'having count ( * ) ')
return final_sql, start_alias_id
def postprocess_nested(format_sql_2, schema):
candidate_tables_id, table_names_original = get_candidate_tables(format_sql_2, schema)
if len(candidate_tables_id) == 1:
format_sql_2 = get_surface_form_orig(format_sql_2, schema)
# easy case
table_name = table_names_original[candidate_tables_id[0]]
from_clause = 'from {}'.format(table_name)
format_sql_3 = []
for sub_sql in format_sql_2.split('\n'):
if 'select' in sub_sql:
format_sql_3 += add_from_clase(sub_sql, from_clause)
else:
format_sql_3.append(sub_sql)
final_sql = ' '.join(format_sql_3).replace('{}.'.format(table_name), '')
else:
# case 1: easy case, except / union / intersect
# case 2: nested queries in condition
final_sql = []
num_keywords = format_sql_2.count('except') + format_sql_2.count('union') + format_sql_2.count('intersect')
num_select = format_sql_2.count('select')
def postprocess_subquery(sub_query_one, schema, start_alias_id_1):
num_select = sub_query_one.count('select ')
final_sub_sql = []
sub_query = []
for sub_sql in sub_query_one.split('\n'):
if 'select' in sub_sql:
if len(sub_query) > 0:
sub_query = '\n'.join(sub_query)
sub_query, start_alias_id_1 = postprocess_single(sub_query, schema, start_alias_id_1)
final_sub_sql.append(sub_query)
sub_query = []
sub_query.append(sub_sql)
else:
sub_query.append(sub_sql)
if len(sub_query) > 0:
sub_query = '\n'.join(sub_query)
sub_query, start_alias_id_1 = postprocess_single(sub_query, schema, start_alias_id_1)
final_sub_sql.append(sub_query)
final_sub_sql = ' '.join(final_sub_sql)
return final_sub_sql, False, start_alias_id_1
start_alias_id = 0
sub_query = []
for sub_sql in format_sql_2.split('\n'):
if 'except' in sub_sql or 'union' in sub_sql or 'intersect' in sub_sql:
sub_query = '\n'.join(sub_query)
sub_query, _, start_alias_id = postprocess_subquery(sub_query, schema, start_alias_id)
final_sql.append(sub_query)
final_sql.append(sub_sql)
sub_query = []
else:
sub_query.append(sub_sql)
if len(sub_query) > 0:
sub_query = '\n'.join(sub_query)
sub_query, _, start_alias_id = postprocess_subquery(sub_query, schema, start_alias_id)
final_sql.append(sub_query)
final_sql = ' '.join(final_sql)
# special case of from a subquery
final_sql = final_sql.replace('select count ( * ) (', 'select count ( * ) from (')
return final_sql
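# Convert one predicted SQL string from the from-removed representation back into
# executable SQL: restore keywords (group_by -> group by, limit_value -> limit 1, ...),
# pretty-print with sqlparse, then rebuild the FROM clause(s).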
def postprocess_one(pred_sql, schema):
pred_sql = pred_sql.replace('group_by', 'group by').replace('order_by', 'order by').replace('limit_value', 'limit 1').replace('_EOS', '').replace(' value ',' 1 ').replace('distinct', '').strip(',').strip()
if pred_sql.endswith('value'):
pred_sql = pred_sql[:-len('value')] + '1'
try:
format_sql = sqlparse.format(pred_sql, reindent=True)
except:
if len(pred_sql) == 0:
return "None"
return pred_sql
format_sql_2 = normalize_space(format_sql)
num_select = format_sql_2.count('select')
if num_select > 1:
final_sql = postprocess_nested(format_sql_2, schema)
else:
final_sql, _ = postprocess_single(format_sql_2, schema)
if final_sql == "":
return "None"
return final_sql
def postprocess(predictions, database_schema, remove_from=False):
import math
use_reranker = True
correct = 0
total = 0
postprocess_sqls = {}
utterances = []
Count = 0
score_vis = dict()
if os.path.exists('./score_vis.json'):
with open('./score_vis.json', 'r') as f:
score_vis = json.load(f)
for pred in predictions:
Count += 1
if Count % 10 == 0:
print('Count', Count, "score_vis", len(score_vis))
db_id = pred['database_id']
schema = database_schema[db_id]
try:
return_match = pred['return_match']
except:
return_match = ''
if db_id not in postprocess_sqls:
postprocess_sqls[db_id] = []
interaction_id = pred['interaction_id']
turn_id = pred['index_in_interaction']
total += 1
if turn_id == 0:
utterances = []
question = ' '.join(pred['input_seq'])
pred_sql_str = ' '.join(pred['flat_prediction'])
utterances.append(question)
beam_sql_strs = []
for score, beam_sql_str in pred['beam']:
beam_sql_strs.append( (score, ' '.join(beam_sql_str)))
gold_sql_str = ' '.join(pred['flat_gold_queries'][0])
if pred_sql_str == gold_sql_str:
correct += 1
postprocess_sql = pred_sql_str
if remove_from:
postprocess_sql = postprocess_one(pred_sql_str, schema)
sqls = []
key_idx = dict()
for i in range(len(beam_sql_strs)):
sql = postprocess_one(beam_sql_strs[i][1], schema)
key = '&'.join(utterances)+'#'+sql
if key not in score_vis:
if key not in key_idx:
key_idx[key]=len(sqls)
sqls.append(sql.replace("value", "1"))
if use_reranker and len(sqls) > 0:
score_list = ranker.get_score_batch(utterances, sqls)
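# Blend the decoder beam score with the reranker: the cached reranker probability is
# offset by a small rank-dependent term, converted to a cost via -log, and added to the
# original beam score; the blended scores are written out with the beam for later selection.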
for i in range(len(beam_sql_strs)):
score = beam_sql_strs[i][0]
sql = postprocess_one(beam_sql_strs[i][1], schema)
if use_reranker:
key = '&'.join(utterances)+'#'+sql
old_score = score
if key in score_vis:
score = score_vis[key]
else:
score = score_list[key_idx[key]]
score_vis[key] = score
score += i * 1e-4
score = -math.log(score)
score += old_score
beam_sql_strs[i] = (score, sql)
assert beam_sql_strs[0][1] == postprocess_sql
if use_reranker and Count % 100 == 0:
with open('score_vis.json', 'w') as file:
json.dump(score_vis, file)
gold_sql_str = postprocess_one(gold_sql_str, schema)
postprocess_sqls[db_id].append((postprocess_sql, interaction_id, turn_id, beam_sql_strs, question, gold_sql_str))
print (correct, total, float(correct)/total)
return postprocess_sqls
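# Load JSON-line predictions, keeping only the top-5 beam entries per example. When a
# comma-separated list of files is given, their beams are concatenated example-by-example
# so several checkpoints can be ensembled.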
def read_prediction(pred_file):
if len( pred_file.split(',') ) > 1:
pred_files = pred_file.split(',')
sum_predictions = []
for pred_file in pred_files:
print('Read prediction from', pred_file)
predictions = []
with open(pred_file) as f:
for line in f:
pred = json.loads(line)
pred["beam"] = pred["beam"][:5]
predictions.append(pred)
print('Number of predictions', len(predictions))
if len(sum_predictions) == 0:
sum_predictions = predictions
else:
assert len(sum_predictions) == len(predictions)
for i in range(len(sum_predictions)):
sum_predictions[i]['beam'] += predictions[i]['beam']
return sum_predictions
else:
print('Read prediction from', pred_file)
predictions = []
with open(pred_file) as f:
for line in f:
pred = json.loads(line)
pred["beam"] = pred["beam"][:5]
predictions.append(pred)
print('Number of predictions', len(predictions))
return predictions
def read_schema(table_schema_path):
with open(table_schema_path) as f:
database_schema = json.load(f)
database_schema_dict = {}
for table_schema in database_schema:
db_id = table_schema['db_id']
database_schema_dict[db_id] = table_schema
return database_schema_dict
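# Collect the database list (writing a fake gold file when no gold path is given), dump
# the beam output grouped by interaction, and return the evaluation command to run. For
# sparc/cosql each turn contributes three lines: the beam SQLs, their scores, and the question.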
def write_and_evaluate(postprocess_sqls, db_path, table_schema_path, gold_path, dataset):
db_list = []
if gold_path is not None:
with open(gold_path) as f:
for line in f:
line_split = line.strip().split('\t')
if len(line_split) != 2:
continue
db = line.strip().split('\t')[1]
if db not in db_list:
db_list.append(db)
else:
gold_path = './fake_gold.txt'
with open(gold_path, "w") as file:
db_list = list(postprocess_sqls.keys())
last_id = None
for db in db_list:
for postprocess_sql, interaction_id, turn_id, beam_sql_strs, question, gold_str in postprocess_sqls[db]:
if last_id is not None and last_id != str(interaction_id)+db:
file.write('\n')
last_id = str(interaction_id)+db
file.write(gold_str+'\t'+db+'\n')
output_file = 'output_temp.txt'
if dataset == 'spider':
with open(output_file, "w") as f:
for db in db_list:
for postprocess_sql, interaction_id, turn_id, beam_sql_strs, question, gold_str in postprocess_sqls[db]:
f.write(postprocess_sql+'\n')
command = 'python3 eval_scripts/evaluation.py --db {} --table {} --etype match --gold {} --pred {}'.format(db_path,
table_schema_path,
gold_path,
os.path.abspath(output_file))
elif dataset in ['sparc', 'cosql']:
cnt = 0
with open(output_file, "w") as f:
last_id = None
for db in db_list:
for postprocess_sql, interaction_id, turn_id, beam_sql_strs, question, gold_str in postprocess_sqls[db]:
if last_id is not None and last_id != str(interaction_id)+db:
f.write('\n')
last_id = str(interaction_id) + db
f.write('{}\n'.format( '\t'.join( [x[1] for x in beam_sql_strs] ) ))
f.write('{}\n'.format( '\t'.join( [str(x[0]) for x in beam_sql_strs] )) )
f.write('{}\n'.format( question ))
cnt += 1
"""
predict_file = 'predicted_sql.txt'
cnt = 0
with open(predict_file, "w") as f:
for db in db_list:
for postprocess_sql, interaction_id, turn_id, beam_sql_strs, question, gold_str in postprocess_sqls[db]:
print(postprocess_sql)
print(beam_sql_strs)
print(question)
print(gold_str)
if turn_id == 0 and cnt > 0:
f.write('\n')
f.write('{}\n'.format(postprocess_sql))
cnt += 1
"""
command = 'python3 eval_scripts/gen_final.py --db {} --table {} --etype match --gold {} --pred {}'.format(db_path,table_schema_path,gold_path,os.path.abspath(output_file))
#command = 'python3 eval_scripts/evaluation_sqa.py --db {} --table {} --etype match --gold {} --pred {}'.format(db_path,table_schema_path,gold_path,os.path.abspath(output_file))
#command += '; rm output_temp.txt'
print('begin command')
return command
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', choices=('spider', 'sparc', 'cosql'), default='sparc')
parser.add_argument('--split', type=str, default='dev')
parser.add_argument('--pred_file', type=str, default='')
parser.add_argument('--remove_from', action='store_true', default=False)
args = parser.parse_args()
db_path = 'data/database/'
gold_path = None
if args.dataset == 'spider':
table_schema_path = 'data/spider/tables.json'
if args.split == 'dev':
gold_path = 'data/spider/dev_gold.sql'
elif args.dataset == 'sparc':
table_schema_path = 'data/sparc/tables.json'
if args.split == 'dev':
gold_path = 'data/sparc/dev_gold.txt'
elif args.dataset == 'cosql':
table_schema_path = 'data/cosql/tables.json'
if args.split == 'dev':
gold_path = 'data/cosql/dev_gold.txt'
pred_file = args.pred_file
database_schema = read_schema(table_schema_path)
predictions = read_prediction(pred_file)
postprocess_sqls = postprocess(predictions, database_schema, args.remove_from)
command = write_and_evaluate(postprocess_sqls, db_path, table_schema_path, gold_path, args.dataset)
print('command', command)
eval_output = subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True)
if len( pred_file.split(',') ) > 1:
pred_file = 'ensemble'
else:
pred_file = pred_file.split(',')[0]
with open(pred_file+'.eval', 'w') as f:
f.write(eval_output.decode("utf-8"))
print('Eval result in', pred_file+'.eval')

1296
r2sql/cosql/predict.txt Normal file

File diff suppressed because it is too large

556
r2sql/cosql/preprocess.py Normal file
View File

@ -0,0 +1,556 @@
import argparse
import os
import sys
import pickle
import json
import shutil
import sqlparse
from postprocess_eval import get_candidate_tables
def write_interaction(interaction_list,split,output_dir):
json_split = os.path.join(output_dir,split+'.json')
pkl_split = os.path.join(output_dir,split+'.pkl')
with open(json_split, 'w') as outfile:
for interaction in interaction_list:
json.dump(interaction, outfile, indent = 4)
outfile.write('\n')
new_objs = []
for i, obj in enumerate(interaction_list):
new_interaction = []
for ut in obj["interaction"]:
sql = ut["sql"]
sqls = [sql]
tok_sql_list = []
for sql in sqls:
results = []
tokenized_sql = sql.split()
tok_sql_list.append((tokenized_sql, results))
ut["sql"] = tok_sql_list
new_interaction.append(ut)
obj["interaction"] = new_interaction
new_objs.append(obj)
with open(pkl_split,'wb') as outfile:
pickle.dump(new_objs, outfile)
return
def read_database_schema(database_schema_filename, schema_tokens, column_names, database_schemas_dict):
with open(database_schema_filename) as f:
database_schemas = json.load(f)
def get_schema_tokens(table_schema):
column_names_surface_form = []
column_names = []
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
for i, (table_id, column_name) in enumerate(column_names_original):
if table_id >= 0:
table_name = table_names_original[table_id]
column_name_surface_form = '{}.{}'.format(table_name,column_name)
else:
# this is just *
column_name_surface_form = column_name
column_names_surface_form.append(column_name_surface_form.lower())
column_names.append(column_name.lower())
# also add table_name.*
for table_name in table_names_original:
column_names_surface_form.append('{}.*'.format(table_name.lower()))
return column_names_surface_form, column_names
for table_schema in database_schemas:
database_id = table_schema['db_id']
database_schemas_dict[database_id] = table_schema
schema_tokens[database_id], column_names[database_id] = get_schema_tokens(table_schema)
return schema_tokens, column_names, database_schemas_dict
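# Drop FROM clauses that contain joins from the formatted SQL, replace the t1/t2/...
# aliases with real table names, and record which tables each SELECT block used.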
def remove_from_with_join(format_sql_2):
used_tables_list = []
format_sql_3 = []
table_to_name = {}
table_list = []
old_table_to_name = {}
old_table_list = []
for sub_sql in format_sql_2.split('\n'):
if 'select ' in sub_sql:
# only replace alias: t1 -> table_name, t2 -> table_name, etc...
if len(table_list) > 0:
for i in range(len(format_sql_3)):
for table, name in table_to_name.items():
format_sql_3[i] = format_sql_3[i].replace(table, name)
old_table_list = table_list
old_table_to_name = table_to_name
table_to_name = {}
table_list = []
format_sql_3.append(sub_sql)
elif sub_sql.startswith('from'):
new_sub_sql = None
sub_sql_tokens = sub_sql.split()
for t_i, t in enumerate(sub_sql_tokens):
if t == 'as':
table_to_name[sub_sql_tokens[t_i+1]] = sub_sql_tokens[t_i-1]
table_list.append(sub_sql_tokens[t_i-1])
elif t == ')' and new_sub_sql is None:
# new_sub_sql keeps some trailing parts after ')'
new_sub_sql = ' '.join(sub_sql_tokens[t_i:])
if len(table_list) > 0:
# if it's a from clause with join
if new_sub_sql is not None:
format_sql_3.append(new_sub_sql)
used_tables_list.append(table_list)
else:
# if it's a from clause without join
table_list = old_table_list
table_to_name = old_table_to_name
assert 'join' not in sub_sql
if new_sub_sql is not None:
sub_sub_sql = sub_sql[:-len(new_sub_sql)].strip()
assert len(sub_sub_sql.split()) == 2
used_tables_list.append([sub_sub_sql.split()[1]])
format_sql_3.append(sub_sub_sql)
format_sql_3.append(new_sub_sql)
elif 'join' not in sub_sql:
assert len(sub_sql.split()) == 2 or len(sub_sql.split()) == 1
if len(sub_sql.split()) == 2:
used_tables_list.append([sub_sql.split()[1]])
format_sql_3.append(sub_sql)
else:
print('bad from clause in remove_from_with_join')
exit()
else:
format_sql_3.append(sub_sql)
if len(table_list) > 0:
for i in range(len(format_sql_3)):
for table, name in table_to_name.items():
format_sql_3[i] = format_sql_3[i].replace(table, name)
used_tables = []
for t in used_tables_list:
for tt in t:
used_tables.append(tt)
used_tables = list(set(used_tables))
return format_sql_3, used_tables, used_tables_list
def remove_from_without_join(format_sql_3, column_names, schema_tokens):
format_sql_4 = []
table_name = None
for sub_sql in format_sql_3.split('\n'):
if 'select ' in sub_sql:
if table_name:
for i in range(len(format_sql_4)):
tokens = format_sql_4[i].split()
for ii, token in enumerate(tokens):
if token in column_names and tokens[ii-1] != '.':
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
if '{}.{}'.format(table_name,token) in schema_tokens:
tokens[ii] = '{} . {}'.format(table_name,token)
format_sql_4[i] = ' '.join(tokens)
format_sql_4.append(sub_sql)
elif sub_sql.startswith('from'):
sub_sql_tokens = sub_sql.split()
if len(sub_sql_tokens) == 1:
table_name = None
elif len(sub_sql_tokens) == 2:
table_name = sub_sql_tokens[1]
else:
print('bad from clause in remove_from_without_join')
print(format_sql_3)
exit()
else:
format_sql_4.append(sub_sql)
if table_name:
for i in range(len(format_sql_4)):
tokens = format_sql_4[i].split()
for ii, token in enumerate(tokens):
if token in column_names and tokens[ii-1] != '.':
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
if '{}.{}'.format(table_name,token) in schema_tokens:
tokens[ii] = '{} . {}'.format(table_name,token)
format_sql_4[i] = ' '.join(tokens)
return format_sql_4
def add_table_name(format_sql_3, used_tables, column_names, schema_tokens):
# If just one table used, easy case, replace all column_name -> table_name.column_name
if len(used_tables) == 1:
table_name = used_tables[0]
format_sql_4 = []
for sub_sql in format_sql_3.split('\n'):
if sub_sql.startswith('from'):
format_sql_4.append(sub_sql)
continue
tokens = sub_sql.split()
for ii, token in enumerate(tokens):
if token in column_names and tokens[ii-1] != '.':
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
if '{}.{}'.format(table_name,token) in schema_tokens:
tokens[ii] = '{} . {}'.format(table_name,token)
format_sql_4.append(' '.join(tokens))
return format_sql_4
def get_table_name_for(token):
table_names = []
for table_name in used_tables:
if '{}.{}'.format(table_name, token) in schema_tokens:
table_names.append(table_name)
if len(table_names) == 0:
return 'table'
if len(table_names) > 1:
return None
else:
return table_names[0]
format_sql_4 = []
for sub_sql in format_sql_3.split('\n'):
if sub_sql.startswith('from'):
format_sql_4.append(sub_sql)
continue
tokens = sub_sql.split()
for ii, token in enumerate(tokens):
# skip *
if token == '*':
continue
if token in column_names and tokens[ii-1] != '.':
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
table_name = get_table_name_for(token)
if table_name:
tokens[ii] = '{} . {}'.format(table_name, token)
format_sql_4.append(' '.join(tokens))
return format_sql_4
def check_oov(format_sql_final, output_vocab, schema_tokens):
for sql_tok in format_sql_final.split():
if not (sql_tok in schema_tokens or sql_tok in output_vocab):
print('OOV!', sql_tok)
raise Exception('OOV')
def normalize_space(format_sql):
format_sql_1 = [' '.join(sub_sql.strip().replace(',',' , ').replace('.',' . ').replace('(',' ( ').replace(')',' ) ').split()) for sub_sql in format_sql.split('\n')]
format_sql_1 = '\n'.join(format_sql_1)
format_sql_2 = format_sql_1.replace('\njoin',' join').replace(',\n',', ').replace(' where','\nwhere').replace(' intersect', '\nintersect').replace('\nand',' and').replace('order by t2 .\nstart desc', 'order by t2 . start desc')
format_sql_2 = format_sql_2.replace('select\noperator', 'select operator').replace('select\nconstructor', 'select constructor').replace('select\nstart', 'select start').replace('select\ndrop', 'select drop').replace('select\nwork', 'select work').replace('select\ngroup', 'select group').replace('select\nwhere_built', 'select where_built').replace('select\norder', 'select order').replace('from\noperator', 'from operator').replace('from\nforward', 'from forward').replace('from\nfor', 'from for').replace('from\ndrop', 'from drop').replace('from\norder', 'from order').replace('.\nstart', '. start').replace('.\norder', '. order').replace('.\noperator', '. operator').replace('.\nsets', '. sets').replace('.\nwhere_built', '. where_built').replace('.\nwork', '. work').replace('.\nconstructor', '. constructor').replace('.\ngroup', '. group').replace('.\nfor', '. for').replace('.\ndrop', '. drop').replace('.\nwhere', '. where')
format_sql_2 = format_sql_2.replace('group by', 'group_by').replace('order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
return format_sql_2
def normalize_final_sql(format_sql_5):
format_sql_final = format_sql_5.replace('\n', ' ').replace(' . ', '.').replace('group by', 'group_by').replace('order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
# normalize two bad sqls
if 't1' in format_sql_final or 't2' in format_sql_final or 't3' in format_sql_final or 't4' in format_sql_final:
format_sql_final = format_sql_final.replace('t2.dormid', 'dorm.dormid')
# This is the failure case of remove_from_without_join()
format_sql_final = format_sql_final.replace('select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by population desc limit_value', 'select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by city.population desc limit_value')
return format_sql_final
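# End-to-end preprocessing of one SQL query: pretty-print with sqlparse, normalize
# spacing, strip FROM clauses (with and without joins), prefix columns with their table
# names, and check that every remaining token is a schema token or in the output vocab.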
def parse_sql(sql_string, db_id, column_names, output_vocab, schema_tokens, schema):
format_sql = sqlparse.format(sql_string, reindent=True)
format_sql_2 = normalize_space(format_sql)
num_from = sum([1 for sub_sql in format_sql_2.split('\n') if sub_sql.startswith('from')])
num_select = format_sql_2.count('select ') + format_sql_2.count('select\n')
format_sql_3, used_tables, used_tables_list = remove_from_with_join(format_sql_2)
format_sql_3 = '\n'.join(format_sql_3)
format_sql_4 = add_table_name(format_sql_3, used_tables, column_names, schema_tokens)
format_sql_4 = '\n'.join(format_sql_4)
format_sql_5 = remove_from_without_join(format_sql_4, column_names, schema_tokens)
format_sql_5 = '\n'.join(format_sql_5)
format_sql_final = normalize_final_sql(format_sql_5)
candidate_tables_id, table_names_original = get_candidate_tables(format_sql_final, schema)
failure = False
if len(candidate_tables_id) != len(used_tables):
failure = True
check_oov(format_sql_final, output_vocab, schema_tokens)
return format_sql_final
def read_spider_split(split_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
with open(split_json) as f:
split_data = json.load(f)
print('read_spider_split', split_json, len(split_data))
for i, ex in enumerate(split_data):
db_id = ex['db_id']
final_sql = []
skip = False
for query_tok in ex['query_toks_no_value']:
if query_tok != '.' and '.' in query_tok:
# invalid sql; didn't use table alias in join
final_sql += query_tok.replace('.',' . ').split()
skip = True
else:
final_sql.append(query_tok)
final_sql = ' '.join(final_sql)
if skip and 'train' in split_json:
continue
if remove_from:
final_sql_parse = parse_sql(final_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
else:
final_sql_parse = final_sql
final_utterance = ' '.join(ex['question_toks'])
if db_id not in interaction_list:
interaction_list[db_id] = []
interaction = {}
interaction['id'] = ''
interaction['scenario'] = ''
interaction['database_id'] = db_id
interaction['interaction_id'] = len(interaction_list[db_id])
interaction['final'] = {}
interaction['final']['utterance'] = final_utterance
interaction['final']['sql'] = final_sql_parse
interaction['interaction'] = [{'utterance': final_utterance, 'sql': final_sql_parse}]
interaction_list[db_id].append(interaction)
return interaction_list
def read_data_json(split_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
with open(split_json) as f:
split_data = json.load(f)
print('read_data_json', split_json, len(split_data))
for interaction_data in split_data:
db_id = interaction_data['database_id']
final_sql = interaction_data['final']['query']
final_utterance = interaction_data['final']['utterance']
if db_id not in interaction_list:
interaction_list[db_id] = []
# no interaction_id in train
if 'interaction_id' in interaction_data['interaction']:
interaction_id = interaction_data['interaction']['interaction_id']
else:
interaction_id = len(interaction_list[db_id])
interaction = {}
interaction['id'] = ''
interaction['scenario'] = ''
interaction['database_id'] = db_id
interaction['interaction_id'] = interaction_id
interaction['final'] = {}
interaction['final']['utterance'] = final_utterance
interaction['final']['sql'] = final_sql
interaction['interaction'] = []
for turn in interaction_data['interaction']:
turn_sql = []
skip = False
for query_tok in turn['query_toks_no_value']:
if query_tok != '.' and '.' in query_tok:
# invalid sql; didn't use table alias in join
turn_sql += query_tok.replace('.',' . ').split()
skip = True
else:
turn_sql.append(query_tok)
turn_sql = ' '.join(turn_sql)
# Correct some human sql annotation error
turn_sql = turn_sql.replace('select f_id from files as t1 join song as t2 on t1 . f_id = t2 . f_id', 'select t1 . f_id from files as t1 join song as t2 on t1 . f_id = t2 . f_id')
turn_sql = turn_sql.replace('select name from climber mountain', 'select name from climber')
turn_sql = turn_sql.replace('select sid from sailors as t1 join reserves as t2 on t1 . sid = t2 . sid join boats as t3 on t3 . bid = t2 . bid', 'select t1 . sid from sailors as t1 join reserves as t2 on t1 . sid = t2 . sid join boats as t3 on t3 . bid = t2 . bid')
turn_sql = turn_sql.replace('select avg ( price ) from goods )', 'select avg ( price ) from goods')
turn_sql = turn_sql.replace('select min ( annual_fuel_cost ) , from vehicles', 'select min ( annual_fuel_cost ) from vehicles')
turn_sql = turn_sql.replace('select * from goods where price < ( select avg ( price ) from goods', 'select * from goods where price < ( select avg ( price ) from goods )')
turn_sql = turn_sql.replace('select distinct id , price from goods where price < ( select avg ( price ) from goods', 'select distinct id , price from goods where price < ( select avg ( price ) from goods )')
turn_sql = turn_sql.replace('select id from goods where price > ( select avg ( price ) from goods', 'select id from goods where price > ( select avg ( price ) from goods )')
if skip and 'train' in split_json:
continue
if remove_from:
try:
turn_sql_parse = parse_sql(turn_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
except:
print('continue')
continue
else:
turn_sql_parse = turn_sql
if 'utterance_toks' in turn:
turn_utterance = ' '.join(turn['utterance_toks']) # not lower()
else:
turn_utterance = turn['utterance']
interaction['interaction'].append({'utterance': turn_utterance, 'sql': turn_sql_parse})
if len(interaction['interaction']) > 0:
interaction_list[db_id].append(interaction)
return interaction_list
def read_spider(spider_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
interaction_list = {}
train_json = os.path.join(spider_dir, 'train.json')
interaction_list = read_spider_split(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
dev_json = os.path.join(spider_dir, 'dev.json')
interaction_list = read_spider_split(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
return interaction_list
def read_sparc(sparc_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
interaction_list = {}
train_json = os.path.join(sparc_dir, 'train_no_value.json')
interaction_list = read_data_json(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
dev_json = os.path.join(sparc_dir, 'dev_no_value.json')
interaction_list = read_data_json(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
return interaction_list
def read_cosql(cosql_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
interaction_list = {}
train_json = os.path.join(cosql_dir, 'train.json')
interaction_list = read_data_json(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
dev_json = os.path.join(cosql_dir, 'dev.json')
interaction_list = read_data_json(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
return interaction_list
def read_db_split(data_dir):
train_database = []
with open(os.path.join(data_dir,'train_db_ids.txt')) as f:
for line in f:
train_database.append(line.strip())
dev_database = []
with open(os.path.join(data_dir,'dev_db_ids.txt')) as f:
for line in f:
dev_database.append(line.strip())
return train_database, dev_database
def preprocess(dataset, remove_from=False):
# Validate output_vocab
output_vocab = ['_UNK', '_EOS', '.', 't1', 't2', '=', 'select', 'from', 'as', 'value', 'join', 'on', ')', '(', 'where', 't3', 'by', ',', 'count', 'group', 'order', 'distinct', 't4', 'and', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 't5', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!', 'union', 'between', 't6', '-', 't7', '+', '/']
if remove_from:
output_vocab = ['_UNK', '_EOS', '=', 'select', 'value', ')', '(', 'where', ',', 'count', 'group_by', 'order_by', 'distinct', 'and', 'limit_value', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!=', 'union', 'between', '-', '+', '/']
print('size of output_vocab', len(output_vocab))
print('output_vocab', output_vocab)
print()
if dataset == 'spider':
spider_dir = 'data/spider/'
database_schema_filename = 'data/spider/tables.json'
output_dir = 'data/spider_data'
if remove_from:
output_dir = 'data/spider_data_removefrom'
train_database, dev_database = read_db_split(spider_dir)
elif dataset == 'sparc':
sparc_dir = 'data/sparc/'
database_schema_filename = 'data/sparc/tables.json'
output_dir = 'data/sparc_data'
if remove_from:
output_dir = 'data/sparc_data_removefrom'
train_database, dev_database = read_db_split(sparc_dir)
elif dataset == 'cosql':
cosql_dir = 'data/cosql/'
database_schema_filename = 'data/cosql/tables.json'
output_dir = 'data/cosql_data'
if remove_from:
output_dir = 'data/cosql_data_removefrom'
train_database, dev_database = read_db_split(cosql_dir)
if os.path.isdir(output_dir):
shutil.rmtree(output_dir)
os.mkdir(output_dir)
schema_tokens = {}
column_names = {}
database_schemas = {}
print('Reading spider database schema file')
schema_tokens, column_names, database_schemas = read_database_schema(database_schema_filename, schema_tokens, column_names, database_schemas)
num_database = len(schema_tokens)
print('num_database', num_database, len(train_database), len(dev_database))
print('total number of schema_tokens / databases:', len(schema_tokens))
output_database_schema_filename = os.path.join(output_dir, 'tables.json')
with open(output_database_schema_filename, 'w') as outfile:
json.dump([v for k,v in database_schemas.items()], outfile, indent=4)
if dataset == 'spider':
interaction_list = read_spider(spider_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
elif dataset == 'sparc':
interaction_list = read_sparc(sparc_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
elif dataset == 'cosql':
interaction_list = read_cosql(cosql_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
print('interaction_list length', len(interaction_list))
train_interaction = []
for database_id in interaction_list:
if database_id not in dev_database:
train_interaction += interaction_list[database_id]
dev_interaction = []
for database_id in dev_database:
dev_interaction += interaction_list[database_id]
print('train interaction: ', len(train_interaction))
print('dev interaction: ', len(dev_interaction))
write_interaction(train_interaction, 'train', output_dir)
write_interaction(dev_interaction, 'dev', output_dir)
return
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", choices=('spider', 'sparc', 'cosql'), default='sparc')
parser.add_argument('--remove_from', action='store_true', default=False)
args = parser.parse_args()
preprocess(args.dataset, args.remove_from)

View File

View File

@ -0,0 +1,77 @@
import torch
from torch.utils import data
import numpy as np
import os
import json
import random
from transformers import RobertaTokenizer
from .utils import preprocess
class NL2SQL_Dataset(data.Dataset):
def __init__(self, args, data_type="train", shuffle=True):
self.data_type = data_type
assert self.data_type in ["train", "valid", "test"]
# tokenizer
self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', max_len=128)
if self.data_type in ['train', 'valid']:
self.file_path = os.path.join(args.data_path, 'rerank_train_data.json')
else:
self.file_path = os.path.join(args.data_path, 'rerank_dev_data.json')
with open(self.file_path, 'r') as f:
self.data_lines = f.readlines()
print('load data from :', self.file_path)
# split to pos and neg data
self.pos_data_all = []
self.neg_data_all = []
for i in range(len(self.data_lines)):
content = json.loads(self.data_lines[i])
if content['match']:
self.pos_data_all.append(content)
else:
self.neg_data_all.append(content)
if self.data_type in ['train', 'valid']:
# np.random.shuffle(self.pos_data_all)
# np.random.shuffle(self.neg_data_all)
# valid_range = int(len(self.pos_data_all + self.neg_data_all) * 0.9 )
valid_range_pos = int(len(self.pos_data_all) * 0.9)
valid_range_neg = int(len(self.neg_data_all) * 0.9)
self.train_pos_data = self.pos_data_all[:valid_range_pos]
self.val_pos_data = self.pos_data_all[valid_range_pos:]
self.train_neg_data = self.neg_data_all[:valid_range_neg]
self.val_neg_data = self.neg_data_all[valid_range_neg:]
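# Balance the classes: training data downsamples the larger of the positive/negative
# pools to the size of the smaller one and reshuffles on every call; validation uses a
# fixed balanced slice; test returns everything.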
def get_random_data(self, shuffle=True):
if self.data_type == "train":
max_train_num = len(self.train_pos_data) if len(self.train_pos_data) < len(self.train_neg_data) else len(self.train_neg_data)
np.random.shuffle(self.train_pos_data)
np.random.shuffle(self.train_neg_data)
data = self.train_pos_data[:max_train_num] + self.train_neg_data[:max_train_num]
np.random.shuffle(data)
elif self.data_type == "valid":
max_val_num = len(self.val_pos_data) if len(self.val_pos_data) < len(self.val_neg_data) else len(self.val_neg_data)
data = self.val_pos_data[:max_val_num] + self.val_neg_data[:max_val_num]
else:
data = self.pos_data_all + self.neg_data_all
return data
def __getitem__(self, idx):
data_sample = self.get_random_data()
data = data_sample[idx]
turn_id = data['turn_id']
utterances = data['utterances'][:turn_id + 1]
label = data['match']
sql = data['pred']
tokens_tensor, attention_mask_tensor = preprocess(utterances, sql, self.tokenizer)
label_tensor = torch.tensor(label)
return tokens_tensor, attention_mask_tensor, label_tensor
def __len__(self):
return len(self.get_random_data())
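A toy sketch of the class-balancing scheme in `get_random_data` above: each training pass keeps equally many positive and negative candidates (plain integer lists stand in for the JSON records):

```python
# Toy illustration of the pos/neg balancing used by get_random_data(); integers
# stand in for the parsed JSON records.
import numpy as np

pos = list(range(10))            # e.g. 10 matching (utterance, SQL) pairs
neg = list(range(100, 160))      # e.g. 60 non-matching pairs
k = min(len(pos), len(neg))
np.random.shuffle(pos)
np.random.shuffle(neg)
epoch_data = pos[:k] + neg[:k]   # balanced sample for one pass
np.random.shuffle(epoch_data)
print(len(epoch_data))           # -> 2 * k = 20
```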


@ -0,0 +1,187 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
import numpy as np
import logging
logging.basicConfig(level=logging.INFO)
import argparse
import os
from .dataset import NL2SQL_Dataset
from .model import ReRanker
from sklearn.metrics import confusion_matrix, accuracy_score
def parser_arg():
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=666)
parser.add_argument("--gpu", type=str, default='0')
parser.add_argument("--data_path", type=str, default="./reranker/data")
parser.add_argument("--train", action="store_true")
parser.add_argument("--test", action="store_true")
parser.add_argument("--epoches", type=int, default=20)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--cls_lr", default=1e-3)
parser.add_argument("--bert_lr", default=5e-6)
parser.add_argument("--threshold", default=0.5)
parser.add_argument("--save_dir", default="./reranker/checkpoints")
parser.add_argument("--base_model", default="roberta")
args = parser.parse_args()
return args
def init(args):
"""
Fix random seeds and cuDNN settings so runs are reproducible.
"""
os.environ["CUDA_VISIBLE_DEVICES"]= str(args.gpu)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if not os.path.exists(args.save_dir):
os.mkdir(args.save_dir)
print("make dir : ", args.save_dir)
def print_evalute(evalute_lst):
all_cnt, all_acc, pos_cnt, pos_acc, neg_cnt, neg_acc = evalute_lst
print("\n All Acc {}/{} ({:.2f}%) \t Pos Acc: {}/{} ({:.2f}%) \t Neg Acc: {}/{} ({:.2f}%) \n".format(all_acc, all_cnt, \
100.* all_acc / all_cnt, \
pos_acc, pos_cnt, \
100. * pos_acc / pos_cnt, \
neg_acc, neg_cnt, \
100. * neg_acc / neg_cnt))
return 100. * all_acc / all_cnt
def evaluate2(args, output_flat, target_flat):
print("start evalute ...")
pred = torch.gt(output_flat, args.threshold).float()
pred = pred.cpu().numpy()
target = target_flat.cpu().numpy()
ret = confusion_matrix(target, pred)
neg_recall, pos_recall = ret.diagonal() / ret.sum(axis=1)
neg_acc, pos_acc = ret.diagonal() / ret.sum(axis=0)
acc = accuracy_score(target, pred)
print(" All Acc:\t {:.3f}%".format(100.0 * acc))
print(" Neg Recall: \t {:.3f}% \t Pos Recall: \t {:.3f}% \n Neg Acc: \t {:.3f}% \t Pos Acc: \t {:.3f}% \n".format(100.0 * neg_recall, 100.0 * pos_recall, 100.0 * neg_acc, 100.0 * pos_acc))
def evaluate(args, output, target, evalute_lst):
all_cnt, all_acc, pos_cnt, pos_acc, neg_cnt, neg_acc = evalute_lst
output = torch.gt(output, args.threshold).float()
for idx in range(target.shape[0]):
gt = target[idx]
pred = output[idx]
if gt == 1:
pos_cnt += 1
if gt == pred:
pos_acc += 1
elif gt == 0:
neg_cnt += 1
if gt == pred:
neg_acc += 1
all_acc = pos_acc + neg_acc
all_cnt = pos_cnt + neg_cnt
return [all_cnt, all_acc, pos_cnt, pos_acc, neg_cnt, neg_acc]
def train(args, model, train_loader, epoch):
model.train()
criterion = nn.BCELoss()
idx = 0
evalute_lst = [0] * 6
for batch_idx, (tokens, attention_mask, target) in enumerate(train_loader):
tokens, attention_mask, target = tokens.cuda(), attention_mask.cuda(), target.cuda().float()
model.zero_grad()
output = model(input_ids=tokens, attention_mask=attention_mask)
output = output.squeeze(dim=-1)
loss = criterion(output, target)
loss.backward()
model.cls_trainer.step()
model.bert_trainer.step()
idx += len(target)
evalute_lst = evaluate(args, output, target, evalute_lst)
if batch_idx % 50 == 0:
print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch, idx, len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
return print_evalute(evalute_lst)
def test(args, model, data_loader):
model.eval()
criterion = nn.BCELoss()
test_loss = 0
evalute_lst = [0] * 6
output_lst = []
target_lst = []
with torch.no_grad():
for tokens, attention_mask, target in data_loader:
tokens, attention_mask, target = tokens.cuda(), attention_mask.cuda(), target.cuda().float()
output = model(input_ids=tokens, attention_mask=attention_mask)
output = output.squeeze(dim=-1)
output_lst.append(output)
target_lst.append(target)
test_loss += criterion(output, target).item()
evalute_lst = evaluate(args, output, target, evalute_lst)
output_flat = torch.cat(output_lst)
target_flat = torch.cat(target_lst)
print('\n{} set: loss: {:.4f}\n'.format(data_loader.dataset.data_type, test_loss / len(data_loader)))
acc = print_evalute(evalute_lst)
return acc
def main():
# init
args = parser_arg()
init(args)
# model define
model = ReRanker(args, args.base_model)
model = model.cuda()
model.build_optim()
# learning rate scheduler
scheduler = StepLR(model.cls_trainer, step_size=1)
test_loader = DataLoader(
NL2SQL_Dataset(args, data_type="test"),
batch_size=args.batch_size
)
if args.train:
train_loader = DataLoader(
NL2SQL_Dataset(args, data_type="train"),
batch_size=args.batch_size
)
valid_loader = DataLoader(
NL2SQL_Dataset(args, data_type="valid"),
batch_size=args.batch_size
)
print("train set num: ", len(train_loader.dataset))
print("valid set num: ", len(valid_loader.dataset))
print("test set num: ", len(test_loader.dataset))
for epoch in range(args.epoches):
train_acc = train(args, model, train_loader, epoch)
valid_acc = test(args, model, valid_loader)
save_path = os.path.join(args.save_dir, "roberta_%s_train+%s_val+%s.pt" % (str(epoch), str(train_acc), str(valid_acc)))
torch.save(model.state_dict(), save_path)
print('save model %s...' % save_path)
scheduler.step()
test(args, model, test_loader)
if args.test:
model_path = 'xxx.pt'
model.load_state_dict(torch.load(model_path))
print("load model from ", model_path)
test(args, model, test_loader)
if __name__ == '__main__':
main()


@ -0,0 +1,37 @@
import torch
import torch.nn as nn
from transformers import BertModel, RobertaModel
class ReRanker(nn.Module):
def __init__(self, args=None, base_model="roberta"):
super(ReRanker, self).__init__()
self.args = args
self.bert_model = RobertaModel.from_pretrained('./local_param/')
self.cls_model = nn.Sequential(
nn.Linear(768, 128),
nn.Tanh(),
nn.Linear(128, 1),
nn.Sigmoid()
)
def build_optim(self):
params_cls_trainer = []
params_bert_trainer = []
for name, param in self.named_parameters():
if param.requires_grad:
if 'bert_model' in name:
params_bert_trainer.append(param)
else:
params_cls_trainer.append(param)
self.bert_trainer = torch.optim.Adam(params_bert_trainer, lr=self.args.bert_lr)
self.cls_trainer = torch.optim.Adam(params_cls_trainer, lr=self.args.cls_lr)
def forward(self, input_ids, attention_mask):
# bert model
x = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)[0]
# get first [CLS] representation
x = x[:, 0, :]
# classification layer
x = self.cls_model(x)
return x
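A minimal shape check for the classification head above: the same `Linear → Tanh → Linear → Sigmoid` stack applied to random [CLS] vectors (hidden size 768 assumed), so each candidate gets a matching score in (0, 1) without loading the RoBERTa weights:

```python
# Sketch of the scoring head only; random vectors stand in for RoBERTa [CLS] output.
import torch
import torch.nn as nn

cls_head = nn.Sequential(nn.Linear(768, 128), nn.Tanh(), nn.Linear(128, 1), nn.Sigmoid())
fake_cls = torch.randn(4, 768)            # 4 candidate (utterance, SQL) pairs
scores = cls_head(fake_cls).squeeze(-1)   # one probability-like score per candidate
print(scores.shape)                       # torch.Size([4]); values lie in (0, 1)
```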


@ -0,0 +1,56 @@
import torch
from transformers import BertTokenizer, RobertaTokenizer
from .model import ReRanker
from .utils import preprocess
class Ranker:
def __init__(self, model_path, base_model='roberta'):
self.model = ReRanker(base_model=base_model)
self.model.load_state_dict(torch.load(model_path))
self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', max_len=512)
print("load reranker model from ", model_path)
def get_score(self, utterances, sql):
tokens_tensor, attention_mask_tensor = preprocess(utterances, sql, self.tokenizer)
tokens_tensor = tokens_tensor.unsqueeze(0)
attention_mask_tensor = attention_mask_tensor.unsqueeze(0)
output = self.model(input_ids=tokens_tensor, attention_mask=attention_mask_tensor)
output = output.squeeze(dim=-1)
return output.item()
def get_score_batch(self, utterances, sqls):
assert isinstance(sqls, list), "only support sql list input"
tokens_tensor_batch, attention_mask_tensor_batch = preprocess(utterances, sqls[0], self.tokenizer)
tokens_tensor_batch = tokens_tensor_batch.unsqueeze(0)
attention_mask_tensor_batch = attention_mask_tensor_batch.unsqueeze(0)
if len(sqls) > 1:
for s in sqls[1:]:
new_token, new_mask = preprocess(utterances, s, self.tokenizer)
tokens_tensor_batch = torch.cat([tokens_tensor_batch, new_token.unsqueeze(0)], dim=0)
attention_mask_tensor_batch = torch.cat([attention_mask_tensor_batch, new_mask.unsqueeze(0)])
output = self.model(input_ids=tokens_tensor_batch, attention_mask=attention_mask_tensor_batch)
output = output.view(-1)
ret = []
for i in output:
ret.append(i.item())
return ret
if __name__ == '__main__':
utterances_1 = ["What are all the airlines ?", "Of these , which is Jetblue Airways ?", "What is the country corresponding it ?"]
sql_1 = "select Country from airlines where Airline = 1 select from airlines select * from airlines where Airline = 1"
utterances_2 = ["What are all the airlines ?", "Of these , which is Jetblue Airways ?", "What is the country corresponding it ?"]
sql_2 = "select T2.Countryabbrev from airlines as T1 join airports as T2 where T1.Airline = 1"
sql_lst = ["select *", "select T2.Countryabbrev from airlines as T1 join airports as T2 where T1.Airline = 1"]
utterances_test = ['Show the document name for all documents .', 'Also show their description .']
sql_test = ['select Document_Name , Document_Description from Documents', 'select Document_Name , Document_Description from Documents select Document_Name from Documents', 'select Document_Name , Document_Name , Document_Description from Documents', 'select Document_ID , Document_Description from Documents', 'select Document_Name from Documents']
model_path = 'xxx.pt'  # placeholder: path to a trained reranker checkpoint
ranker = Ranker(model_path)
#print("utterance : ", utterances_1)
#print("sql : ", sql_1)
#print(ranker.get_score(utterances_1, sql_1))
#print("utterance : ", utterances_2)
#print("sql : ", sql_2)
#print(ranker.get_score(utterances_2, sql_2))
print(ranker.get_score_batch(utterances_test, sql_test))


@ -0,0 +1,25 @@
import torch
def preprocess(utterances, sql, tokenizer):
text = ""
for u in utterances:
for t in u.split(' '):
text = text + ' ' + t.strip()
sql = sql.strip()
sql = sql.replace(".", " ")
sql = sql.replace("_", " ")
l = []
for char in sql:
if char.isupper():
l.append(' ')
l.append(char)
sql = ''.join(l)
sql = ' '.join(sql.split())
# input:
# [CLS] utterance [SEP] sql [SEP]
token_encoding = tokenizer.encode_plus(text, sql, max_length=128, pad_to_max_length=True)
tokenized_token = token_encoding['input_ids']
attention_mask = token_encoding['attention_mask']
tokens_tensor = torch.tensor(tokenized_token)
attention_mask_tensor = torch.tensor(attention_mask)
return tokens_tensor, attention_mask_tensor
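To make the SQL normalization above concrete, here is a standalone sketch that mirrors the string handling in `preprocess` (no tokenizer needed; `clean_sql` is just an illustrative name):

```python
# Mirrors the SQL cleanup in preprocess(): '.' and '_' become spaces, a space is
# inserted before each uppercase character, and whitespace is collapsed.
def clean_sql(sql):
    sql = sql.strip().replace(".", " ").replace("_", " ")
    chars = []
    for ch in sql:
        if ch.isupper():
            chars.append(" ")
        chars.append(ch)
    return " ".join("".join(chars).split())

print(clean_sql("select T1.Document_Name from Documents as T1"))
# -> "select T1 Document Name from Documents as T1"
```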

311
r2sql/cosql/run.py Normal file

@ -0,0 +1,311 @@
"""Contains a main function for training and/or evaluating a model."""
import os
import sys
import numpy as np
import random
from parse_args import interpret_args
import data_util
from data_util import atis_data
from model.schema_interaction_model import SchemaInteractionATISModel
from logger import Logger
from model.model import ATISModel
from model_util import Metrics, evaluate_utterance_sample, evaluate_interaction_sample, \
train_epoch_with_utterances, train_epoch_with_interactions, evaluate_using_predicted_queries
import torch
import time
VALID_EVAL_METRICS = [Metrics.LOSS, Metrics.TOKEN_ACCURACY, Metrics.STRING_ACCURACY]
TRAIN_EVAL_METRICS = [Metrics.LOSS, Metrics.TOKEN_ACCURACY, Metrics.STRING_ACCURACY]
FINAL_EVAL_METRICS = [Metrics.STRING_ACCURACY, Metrics.TOKEN_ACCURACY]
def train(model, data, params):
""" Trains a model.
Inputs:
model (ATISModel): The model to train.
data (ATISData): The data that is used to train.
params (namespace): Training parameters.
"""
# Get the training batches.
log = Logger(os.path.join(params.logdir, params.logfile), "a+")
num_train_original = atis_data.num_utterances(data.train_data)
eval_fn = evaluate_utterance_sample
trainbatch_fn = data.get_utterance_batches
trainsample_fn = data.get_random_utterances
validsample_fn = data.get_all_utterances
batch_size = params.batch_size
if params.interaction_level:
batch_size = 1
eval_fn = evaluate_interaction_sample
trainbatch_fn = data.get_interaction_batches
trainsample_fn = data.get_random_interactions
validsample_fn = data.get_all_interactions
maximum_output_length = params.train_maximum_sql_length
train_batches = trainbatch_fn(batch_size,
max_output_length=maximum_output_length,
randomize=not params.deterministic)
if params.num_train >= 0:
train_batches = train_batches[:params.num_train]
training_sample = trainsample_fn(params.train_evaluation_size,
max_output_length=maximum_output_length)
valid_examples = validsample_fn(data.valid_data,
max_output_length=maximum_output_length)
num_train_examples = sum([len(batch) for batch in train_batches])
num_steps_per_epoch = len(train_batches)
# Keeping track of things during training.
epochs = 0
patience = params.initial_patience
learning_rate_coefficient = 1.
previous_epoch_loss = float('inf')
maximum_validation_accuracy = 0.
maximum_string_accuracy = 0.
countdown = int(patience)
if params.scheduler:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(model.trainer, mode='min', )
keep_training = True
while keep_training:
log.put("Epoch:\t" + str(epochs))
model.set_dropout(params.dropout_amount)
if not params.scheduler:
model.set_learning_rate(learning_rate_coefficient * params.initial_learning_rate)
# Run a training step.
if params.interaction_level:
epoch_loss = train_epoch_with_interactions(
train_batches,
params,
model,
randomize=not params.deterministic)
else:
epoch_loss = train_epoch_with_utterances(
train_batches,
model,
randomize=not params.deterministic)
log.put("train epoch loss:\t" + str(epoch_loss))
model.set_dropout(0.)
# Run an evaluation step on a sample of the training data.
"""
train_eval_results = eval_fn(training_sample,
model,
params.train_maximum_sql_length,
name=os.path.join(params.logdir, "train-eval"),
write_results=True,
gold_forcing=True,
metrics=TRAIN_EVAL_METRICS)[0]
for name, value in train_eval_results.items():
log.put(
"train final gold-passing " +
name.name +
":\t" +
"%.2f" %
value)
"""
# Run an evaluation step on the validation set.
valid_eval_results = eval_fn(valid_examples,
model,
params.eval_maximum_sql_length,
name=os.path.join(params.logdir, "valid-eval"),
write_results=True,
gold_forcing=True,
metrics=VALID_EVAL_METRICS)[0]
for name, value in valid_eval_results.items():
log.put("valid gold-passing " + name.name + ":\t" + "%.2f" % value)
valid_loss = valid_eval_results[Metrics.LOSS]
valid_token_accuracy = valid_eval_results[Metrics.TOKEN_ACCURACY]
string_accuracy = valid_eval_results[Metrics.STRING_ACCURACY]
assert string_accuracy >= 20
if params.scheduler:
scheduler.step(valid_loss)
if valid_loss > previous_epoch_loss:
learning_rate_coefficient *= params.learning_rate_ratio
log.put(
"learning rate coefficient:\t" +
str(learning_rate_coefficient))
previous_epoch_loss = valid_loss
saved = False
if not saved and string_accuracy > maximum_string_accuracy:
maximum_string_accuracy = string_accuracy
patience = patience * params.patience_ratio
countdown = int(patience)
last_save_file = os.path.join(params.logdir, "save_" + str(epochs))
model.save(last_save_file)
log.put(
"maximum string accuracy:\t" +
str(maximum_string_accuracy))
log.put("patience:\t" + str(patience))
log.put("save file:\t" + str(last_save_file))
if countdown <= 0:
keep_training = False
countdown -= 1
log.put("countdown:\t" + str(countdown))
log.put("")
epochs += 1
log.put("Finished training!")
log.close()
return last_save_file
def evaluate(model, data, params, last_save_file, split):
"""Evaluates a pretrained model on a dataset.
Inputs:
model (ATISModel): Model class.
data (ATISData): All of the data.
params (namespace): Parameters for the model.
last_save_file (str): Location where the model save file is.
split (str): Name of the split to evaluate ('train', 'valid', 'dev' or 'test').
"""
if last_save_file:
model.load(last_save_file)
else:
if not params.save_file:
raise ValueError(
"Must provide a save file name if not training first.")
model.load(params.save_file)
filename = split
if filename == 'dev':
split = data.dev_data
elif filename == 'train':
split = data.train_data
elif filename == 'test':
split = data.test_data
elif filename == 'valid':
split = data.valid_data
else:
raise ValueError("Split not recognized: " + str(params.evaluate_split))
if params.use_predicted_queries:
filename += "_use_predicted_queries"
else:
filename += "_use_gold_queries"
if filename == 'train':
full_name = os.path.join(params.logdir, filename) + params.results_note
else:
full_name = os.path.join("results", params.save_file.split('/')[-1]) + params.results_note
if params.interaction_level or params.use_predicted_queries:
examples = data.get_all_interactions(split)
if params.interaction_level:
evaluate_interaction_sample(
examples,
model,
name=full_name,
metrics=FINAL_EVAL_METRICS,
total_num=atis_data.num_utterances(split),
database_username=params.database_username,
database_password=params.database_password,
database_timeout=params.database_timeout,
use_predicted_queries=params.use_predicted_queries,
max_generation_length=params.eval_maximum_sql_length,
write_results=True,
use_gpu=True,
compute_metrics=params.compute_metrics)
else:
evaluate_using_predicted_queries(
examples,
model,
name=full_name,
metrics=FINAL_EVAL_METRICS,
total_num=atis_data.num_utterances(split),
database_username=params.database_username,
database_password=params.database_password,
database_timeout=params.database_timeout)
else:
examples = data.get_all_utterances(split)
evaluate_utterance_sample(
examples,
model,
name=full_name,
gold_forcing=False,
metrics=FINAL_EVAL_METRICS,
total_num=atis_data.num_utterances(split),
max_generation_length=params.eval_maximum_sql_length,
database_username=params.database_username,
database_password=params.database_password,
database_timeout=params.database_timeout,
write_results=True)
def init_env(params):
if params.seed == 0:
params.seed = int(time.time()) % 1000
seed = params.seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
def main():
"""Main function that trains and/or evaluates a model."""
params = interpret_args()
init_env(params)
# Prepare the dataset into the proper form.
data = atis_data.ATISDataset(params)
# Construct the model object.
if params.interaction_level:
model_type = SchemaInteractionATISModel
else:
print('not implemented')
exit()
model = model_type(
params,
data.input_vocabulary,
data.output_vocabulary,
data.output_vocabulary_schema,
data.anonymizer if params.anonymize and params.anonymization_scoring else None)
model = model.cuda()
model.build_optim()
sys.stdout.flush()
last_save_file = ""
if params.train:
last_save_file = train(model, data, params)
if params.evaluate and 'train' in params.evaluate_split:
evaluate(model, data, params, last_save_file, split='train')
if params.evaluate and 'valid' in params.evaluate_split:
evaluate(model, data, params, last_save_file, split='valid')
if params.evaluate and 'dev' in params.evaluate_split:
evaluate(model, data, params, last_save_file, split='dev')
if params.evaluate and 'test' in params.evaluate_split:
evaluate(model, data, params, last_save_file, split='test')
if __name__ == "__main__":
main()

9
r2sql/requirement.txt Normal file

@ -0,0 +1,9 @@
torch==1.0.0
sqlparse
pymysql
progressbar
nltk
numpy
six
spacy
transformers==2.10.0


@ -0,0 +1,779 @@
{"sql": ["'2V'"], "type": "AIRLINE_CODE", "cleaned_nl": ["northeast", "express", "regional", "airlines"]}
{"sql": ["'3J'"], "type": "AIRLINE_CODE", "cleaned_nl": ["air", "alliance"]}
{"sql": ["'7V'"], "type": "AIRLINE_CODE", "cleaned_nl": ["alpha", "air"]}
{"sql": ["'9E'"], "type": "AIRLINE_CODE", "cleaned_nl": ["express", "airlines", "i"]}
{"sql": ["'9L'"], "type": "AIRLINE_CODE", "cleaned_nl": ["colgan", "air"]}
{"sql": ["'9N'"], "type": "AIRLINE_CODE", "cleaned_nl": ["trans", "states", "airlines"]}
{"sql": ["'9X'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ontario", "express"]}
{"sql": ["'AA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american", "airlines"]}
{"sql": ["'AA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american", "airline"]}
{"sql": ["'AA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american"]}
{"sql": ["'AC'"], "type": "AIRLINE_CODE", "cleaned_nl": ["air", "canada"]}
{"sql": ["'AR'"], "type": "AIRLINE_CODE", "cleaned_nl": ["aerolineas", "argentinas"]}
{"sql": ["'AS'"], "type": "AIRLINE_CODE", "cleaned_nl": ["alaska", "airlines"]}
{"sql": ["'AT'"], "type": "AIRLINE_CODE", "cleaned_nl": ["royal", "air", "maroc"]}
{"sql": ["'BA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["british", "airways"]}
{"sql": ["'BE'"], "type": "AIRLINE_CODE", "cleaned_nl": ["braniff", "international", "airlines"]}
{"sql": ["'CO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["continental", "airlines"]}
{"sql": ["'CO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["continental"]}
{"sql": ["'CP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["canadian", "airlines", "international"]}
{"sql": ["'CP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["canadian"]}
{"sql": ["'CP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["canadian", "airlines"]}
{"sql": ["'DH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["atlantic", "coast", "airlines"]}
{"sql": ["'DL'"], "type": "AIRLINE_CODE", "cleaned_nl": ["delta", "airlines"]}
{"sql": ["'DL'"], "type": "AIRLINE_CODE", "cleaned_nl": ["delta"]}
{"sql": ["'DL'"], "type": "AIRLINE_CODE", "cleaned_nl": ["delta", "airline"]}
{"sql": ["'EV'"], "type": "AIRLINE_CODE", "cleaned_nl": ["atlantic", "southeast", "airlines"]}
{"sql": ["'EA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ea"]}
{"sql": ["'EA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["eastern"]}
{"sql": ["'FF'"], "type": "AIRLINE_CODE", "cleaned_nl": ["tower", "air"]}
{"sql": ["'GX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["air", "ontario"]}
{"sql": ["'HP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["america", "west", "airlines"]}
{"sql": ["'HP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["america", "west"]}
{"sql": ["'HP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american", "west"]}
{"sql": ["'HQ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["business", "express"]}
{"sql": ["'KW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["carnival", "air", "lines"]}
{"sql": ["'LH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["lufthansa", "german", "airlines"]}
{"sql": ["'LH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["lufthansa"]}
{"sql": ["'LH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["lufthansa", "airlines"]}
{"sql": ["'MG'"], "type": "AIRLINE_CODE", "cleaned_nl": ["mgm", "grand", "air"]}
{"sql": ["'NW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["northwest", "airlines"]}
{"sql": ["'NW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["northwest", "airline"]}
{"sql": ["'NW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["northwest"]}
{"sql": ["'NX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["nationair"]}
{"sql": ["'OE'"], "type": "AIRLINE_CODE", "cleaned_nl": ["westair", "airlines"]}
{"sql": ["'OH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["comair"]}
{"sql": ["'OK'"], "type": "AIRLINE_CODE", "cleaned_nl": ["czechoslovak", "airlines"]}
{"sql": ["'OO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["sky", "west", "airlines"]}
{"sql": ["'QD'"], "type": "AIRLINE_CODE", "cleaned_nl": ["grand", "airways"]}
{"sql": ["'RP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["precision", "airlines"]}
{"sql": ["'RZ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["trans", "world", "express"]}
{"sql": ["'SN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["sabena", "belgian", "world", "airlines"]}
{"sql": ["'SX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["christman", "air"]}
{"sql": ["'TG'"], "type": "AIRLINE_CODE", "cleaned_nl": ["thai", "airways"]}
{"sql": ["'TW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["trans", "world"]}
{"sql": ["'TW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["twa"]}
{"sql": ["'TZ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["american", "trans", "air"]}
{"sql": ["'UA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["united", "airlines"]}
{"sql": ["'UA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["united"]}
{"sql": ["'US'"], "type": "AIRLINE_CODE", "cleaned_nl": ["usair"]}
{"sql": ["'US'"], "type": "AIRLINE_CODE", "cleaned_nl": ["us", "air"]}
{"sql": ["'WN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["southwest", "airlines"]}
{"sql": ["'WN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["southwest", "air"]}
{"sql": ["'WN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["southwest"]}
{"sql": ["'XJ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["mesaba", "aviation"]}
{"sql": ["'YX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["midwest", "express", "airlines"]}
{"sql": ["'YX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["midwest", "express"]}
{"sql": ["'ZW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["air", "wisconsin"]}
{"sql": ["'2V'"], "type": "AIRLINE_CODE", "cleaned_nl": ["2v"]}
{"sql": ["'3J'"], "type": "AIRLINE_CODE", "cleaned_nl": ["3j"]}
{"sql": ["'7V'"], "type": "AIRLINE_CODE", "cleaned_nl": ["7v"]}
{"sql": ["'9E'"], "type": "AIRLINE_CODE", "cleaned_nl": ["9e"]}
{"sql": ["'9L'"], "type": "AIRLINE_CODE", "cleaned_nl": ["9l"]}
{"sql": ["'9N'"], "type": "AIRLINE_CODE", "cleaned_nl": ["9n"]}
{"sql": ["'9X'"], "type": "AIRLINE_CODE", "cleaned_nl": ["9x"]}
{"sql": ["'AA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["aa"]}
{"sql": ["'AC'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ac"]}
{"sql": ["'AR'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ar"]}
{"sql": ["'AS'"], "type": "AIRLINE_CODE", "cleaned_nl": ["as"]}
{"sql": ["'BA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ba"]}
{"sql": ["'CO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["co"]}
{"sql": ["'CP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["cp"]}
{"sql": ["'DH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["dh"]}
{"sql": ["'DL'"], "type": "AIRLINE_CODE", "cleaned_nl": ["dl"]}
{"sql": ["'EV'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ev"]}
{"sql": ["'FF'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ff"]}
{"sql": ["'GX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["gx"]}
{"sql": ["'HP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["hp"]}
{"sql": ["'HQ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["hq"]}
{"sql": ["'KW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["kw"]}
{"sql": ["'LH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["lh"]}
{"sql": ["'MG'"], "type": "AIRLINE_CODE", "cleaned_nl": ["mg"]}
{"sql": ["'NW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["nw"]}
{"sql": ["'NX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["nx"]}
{"sql": ["'OE'"], "type": "AIRLINE_CODE", "cleaned_nl": ["oe"]}
{"sql": ["'OH'"], "type": "AIRLINE_CODE", "cleaned_nl": ["oh"]}
{"sql": ["'OK'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ok"]}
{"sql": ["'OO'"], "type": "AIRLINE_CODE", "cleaned_nl": ["oo"]}
{"sql": ["'QD'"], "type": "AIRLINE_CODE", "cleaned_nl": ["qd"]}
{"sql": ["'RP'"], "type": "AIRLINE_CODE", "cleaned_nl": ["rp"]}
{"sql": ["'RZ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["rz"]}
{"sql": ["'SN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["sn"]}
{"sql": ["'SX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["sx"]}
{"sql": ["'TG'"], "type": "AIRLINE_CODE", "cleaned_nl": ["tg"]}
{"sql": ["'TW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["tw"]}
{"sql": ["'TZ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["tz"]}
{"sql": ["'UA'"], "type": "AIRLINE_CODE", "cleaned_nl": ["ua"]}
{"sql": ["'US'"], "type": "AIRLINE_CODE", "cleaned_nl": ["us"]}
{"sql": ["'WN'"], "type": "AIRLINE_CODE", "cleaned_nl": ["wn"]}
{"sql": ["'XJ'"], "type": "AIRLINE_CODE", "cleaned_nl": ["xj"]}
{"sql": ["'YX'"], "type": "AIRLINE_CODE", "cleaned_nl": ["yx"]}
{"sql": ["'ZW'"], "type": "AIRLINE_CODE", "cleaned_nl": ["zw"]}
{"sql": ["'NASHVILLE'"], "type": "CITY_NAME", "cleaned_nl": ["nashville"]}
{"sql": ["'BOSTON'"], "type": "CITY_NAME", "cleaned_nl": ["boston"]}
{"sql": ["'BURBANK'"], "type": "CITY_NAME", "cleaned_nl": ["burbank"]}
{"sql": ["'BALTIMORE'"], "type": "CITY_NAME", "cleaned_nl": ["baltimore"]}
{"sql": ["'CHICAGO'"], "type": "CITY_NAME", "cleaned_nl": ["chicago"]}
{"sql": ["'CLEVELAND'"], "type": "CITY_NAME", "cleaned_nl": ["cleveland"]}
{"sql": ["'CHARLOTTE'"], "type": "CITY_NAME", "cleaned_nl": ["charlotte"]}
{"sql": ["'COLUMBUS'"], "type": "CITY_NAME", "cleaned_nl": ["columbus"]}
{"sql": ["'CINCINNATI'"], "type": "CITY_NAME", "cleaned_nl": ["cincinnati"]}
{"sql": ["'DENVER'"], "type": "CITY_NAME", "cleaned_nl": ["denver"]}
{"sql": ["'DALLAS'"], "type": "CITY_NAME", "cleaned_nl": ["dallas"]}
{"sql": ["'DETROIT'"], "type": "CITY_NAME", "cleaned_nl": ["detroit"]}
{"sql": ["'FORT WORTH'"], "type": "CITY_NAME", "cleaned_nl": ["fort", "worth"]}
{"sql": ["'HOUSTON'"], "type": "CITY_NAME", "cleaned_nl": ["houston"]}
{"sql": ["'WESTCHESTER COUNTY'"], "type": "CITY_NAME", "cleaned_nl": ["westchester", "county"]}
{"sql": ["'WESTCHESTER COUNTY'"], "type": "CITY_NAME", "cleaned_nl": ["westchester"]}
{"sql": ["'INDIANAPOLIS'"], "type": "CITY_NAME", "cleaned_nl": ["indianapolis"]}
{"sql": ["'NEWARK'"], "type": "CITY_NAME", "cleaned_nl": ["newark"]}
{"sql": ["'LAS VEGAS'"], "type": "CITY_NAME", "cleaned_nl": ["las", "vegas"]}
{"sql": ["'LOS ANGELES'"], "type": "CITY_NAME", "cleaned_nl": ["los", "angeles"]}
{"sql": ["'LOS ANGELES'"], "type": "CITY_NAME", "cleaned_nl": ["la"]}
{"sql": ["'LONG BEACH'"], "type": "CITY_NAME", "cleaned_nl": ["long", "beach"]}
{"sql": ["'ATLANTA'"], "type": "CITY_NAME", "cleaned_nl": ["atlanta"]}
{"sql": ["'MEMPHIS'"], "type": "CITY_NAME", "cleaned_nl": ["memphis"]}
{"sql": ["'MIAMI'"], "type": "CITY_NAME", "cleaned_nl": ["miami"]}
{"sql": ["'KANSAS CITY'"], "type": "CITY_NAME", "cleaned_nl": ["kansas", "city"]}
{"sql": ["'MILWAUKEE'"], "type": "CITY_NAME", "cleaned_nl": ["milwaukee"]}
{"sql": ["'MINNEAPOLIS'"], "type": "CITY_NAME", "cleaned_nl": ["minneapolis"]}
{"sql": ["'NEW YORK'"], "type": "CITY_NAME", "cleaned_nl": ["new", "york"]}
{"sql": ["'NEW YORK'"], "type": "CITY_NAME", "cleaned_nl": ["new", "york", "city"]}
{"sql": ["'OAKLAND'"], "type": "CITY_NAME", "cleaned_nl": ["oakland"]}
{"sql": ["'ONTARIO'"], "type": "CITY_NAME", "cleaned_nl": ["ontario"]}
{"sql": ["'ORLANDO'"], "type": "CITY_NAME", "cleaned_nl": ["orlando"]}
{"sql": ["'PHILADELPHIA'"], "type": "CITY_NAME", "cleaned_nl": ["philadelphia"]}
{"sql": ["'PHOENIX'"], "type": "CITY_NAME", "cleaned_nl": ["phoenix"]}
{"sql": ["'PITTSBURGH'"], "type": "CITY_NAME", "cleaned_nl": ["pittsburgh"]}
{"sql": ["'ST. PAUL'"], "type": "CITY_NAME", "cleaned_nl": ["st.", "paul"]}
{"sql": ["'ST. PAUL'"], "type": "CITY_NAME", "cleaned_nl": ["saint", "paul"]}
{"sql": ["'SAN DIEGO'"], "type": "CITY_NAME", "cleaned_nl": ["san", "diego"]}
{"sql": ["'SEATTLE'"], "type": "CITY_NAME", "cleaned_nl": ["seattle"]}
{"sql": ["'SAN FRANCISCO'"], "type": "CITY_NAME", "cleaned_nl": ["san", "francisco"]}
{"sql": ["'SAN JOSE'"], "type": "CITY_NAME", "cleaned_nl": ["san", "jose"]}
{"sql": ["'SALT LAKE CITY'"], "type": "CITY_NAME", "cleaned_nl": ["salt", "lake", "city"]}
{"sql": ["'ST. LOUIS'"], "type": "CITY_NAME", "cleaned_nl": ["st.", "louis"]}
{"sql": ["'ST. LOUIS'"], "type": "CITY_NAME", "cleaned_nl": ["saint", "louis"]}
{"sql": ["'ST. PETERSBURG'"], "type": "CITY_NAME", "cleaned_nl": ["saint", "petersburg"]}
{"sql": ["'ST. PETERSBURG'"], "type": "CITY_NAME", "cleaned_nl": ["st.", "petersburg"]}
{"sql": ["'TACOMA'"], "type": "CITY_NAME", "cleaned_nl": ["tacoma"]}
{"sql": ["'TAMPA'"], "type": "CITY_NAME", "cleaned_nl": ["tampa"]}
{"sql": ["'WASHINGTON'"], "type": "CITY_NAME", "cleaned_nl": ["washington"]}
{"sql": ["'MONTREAL'"], "type": "CITY_NAME", "cleaned_nl": ["montreal"]}
{"sql": ["'TORONTO'"], "type": "CITY_NAME", "cleaned_nl": ["toronto"]}
{"sql": ["1"], "type": "MONTH_NUMBER", "cleaned_nl": ["january"]}
{"sql": ["2"], "type": "MONTH_NUMBER", "cleaned_nl": ["february"]}
{"sql": ["3"], "type": "MONTH_NUMBER", "cleaned_nl": ["march"]}
{"sql": ["4"], "type": "MONTH_NUMBER", "cleaned_nl": ["april"]}
{"sql": ["5"], "type": "MONTH_NUMBER", "cleaned_nl": ["may"]}
{"sql": ["6"], "type": "MONTH_NUMBER", "cleaned_nl": ["june"]}
{"sql": ["7"], "type": "MONTH_NUMBER", "cleaned_nl": ["july"]}
{"sql": ["8"], "type": "MONTH_NUMBER", "cleaned_nl": ["august"]}
{"sql": ["9"], "type": "MONTH_NUMBER", "cleaned_nl": ["september"]}
{"sql": ["10"], "type": "MONTH_NUMBER", "cleaned_nl": ["october"]}
{"sql": ["11"], "type": "MONTH_NUMBER", "cleaned_nl": ["november"]}
{"sql": ["12"], "type": "MONTH_NUMBER", "cleaned_nl": ["december"]}
{"sql": ["'MONDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["monday"]}
{"sql": ["'TUESDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["tuesday"]}
{"sql": ["'WEDNESDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["wednesday"]}
{"sql": ["'THURSDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["thursday"]}
{"sql": ["'FRIDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["friday"]}
{"sql": ["'SATURDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["saturday"]}
{"sql": ["'SUNDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["sunday"]}
{"sql": ["'MONDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["mondays"]}
{"sql": ["'TUESDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["tuesdays"]}
{"sql": ["'WEDNESDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["wednesdays"]}
{"sql": ["'THURSDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["thursdays"]}
{"sql": ["'FRIDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["fridays"]}
{"sql": ["'SATURDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["saturdays"]}
{"sql": ["'SUNDAY'"], "type": "DAY_OF_WEEK", "cleaned_nl": ["sundays"]}
{"sql": ["1"], "type": "DAY_NUMBER", "cleaned_nl": ["first"]}
{"sql": ["2"], "type": "DAY_NUMBER", "cleaned_nl": ["second"]}
{"sql": ["3"], "type": "DAY_NUMBER", "cleaned_nl": ["third"]}
{"sql": ["4"], "type": "DAY_NUMBER", "cleaned_nl": ["fourth"]}
{"sql": ["5"], "type": "DAY_NUMBER", "cleaned_nl": ["fifth"]}
{"sql": ["6"], "type": "DAY_NUMBER", "cleaned_nl": ["sixth"]}
{"sql": ["7"], "type": "DAY_NUMBER", "cleaned_nl": ["seventh"]}
{"sql": ["8"], "type": "DAY_NUMBER", "cleaned_nl": ["eighth"]}
{"sql": ["9"], "type": "DAY_NUMBER", "cleaned_nl": ["ninth"]}
{"sql": ["10"], "type": "DAY_NUMBER", "cleaned_nl": ["tenth"]}
{"sql": ["11"], "type": "DAY_NUMBER", "cleaned_nl": ["eleventh"]}
{"sql": ["12"], "type": "DAY_NUMBER", "cleaned_nl": ["twelfth"]}
{"sql": ["13"], "type": "DAY_NUMBER", "cleaned_nl": ["thirteenth"]}
{"sql": ["14"], "type": "DAY_NUMBER", "cleaned_nl": ["fourteenth"]}
{"sql": ["15"], "type": "DAY_NUMBER", "cleaned_nl": ["fifteenth"]}
{"sql": ["16"], "type": "DAY_NUMBER", "cleaned_nl": ["sixteenth"]}
{"sql": ["17"], "type": "DAY_NUMBER", "cleaned_nl": ["seventeenth"]}
{"sql": ["18"], "type": "DAY_NUMBER", "cleaned_nl": ["eighteenth"]}
{"sql": ["19"], "type": "DAY_NUMBER", "cleaned_nl": ["nineteenth"]}
{"sql": ["20"], "type": "DAY_NUMBER", "cleaned_nl": ["twentieth"]}
{"sql": ["21"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "first"]}
{"sql": ["22"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "second"]}
{"sql": ["23"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "third"]}
{"sql": ["24"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "fourth"]}
{"sql": ["25"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "fifth"]}
{"sql": ["26"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "sixth"]}
{"sql": ["27"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "seventh"]}
{"sql": ["28"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "eighth"]}
{"sql": ["29"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "ninth"]}
{"sql": ["30"], "type": "DAY_NUMBER", "cleaned_nl": ["thirtieth"]}
{"sql": ["31"], "type": "DAY_NUMBER", "cleaned_nl": ["thirty", "first"]}
{"sql": ["2"], "type": "DAY_NUMBER", "cleaned_nl": ["two"]}
{"sql": ["3"], "type": "DAY_NUMBER", "cleaned_nl": ["three"]}
{"sql": ["4"], "type": "DAY_NUMBER", "cleaned_nl": ["four"]}
{"sql": ["5"], "type": "DAY_NUMBER", "cleaned_nl": ["five"]}
{"sql": ["6"], "type": "DAY_NUMBER", "cleaned_nl": ["six"]}
{"sql": ["7"], "type": "DAY_NUMBER", "cleaned_nl": ["seven"]}
{"sql": ["8"], "type": "DAY_NUMBER", "cleaned_nl": ["eight"]}
{"sql": ["9"], "type": "DAY_NUMBER", "cleaned_nl": ["nine"]}
{"sql": ["10"], "type": "DAY_NUMBER", "cleaned_nl": ["ten"]}
{"sql": ["11"], "type": "DAY_NUMBER", "cleaned_nl": ["eleven"]}
{"sql": ["12"], "type": "DAY_NUMBER", "cleaned_nl": ["twelve"]}
{"sql": ["13"], "type": "DAY_NUMBER", "cleaned_nl": ["thirteen"]}
{"sql": ["14"], "type": "DAY_NUMBER", "cleaned_nl": ["fourteen"]}
{"sql": ["15"], "type": "DAY_NUMBER", "cleaned_nl": ["fifteen"]}
{"sql": ["16"], "type": "DAY_NUMBER", "cleaned_nl": ["sixteen"]}
{"sql": ["17"], "type": "DAY_NUMBER", "cleaned_nl": ["seventeen"]}
{"sql": ["18"], "type": "DAY_NUMBER", "cleaned_nl": ["eighteen"]}
{"sql": ["19"], "type": "DAY_NUMBER", "cleaned_nl": ["nineteen"]}
{"sql": ["20"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty"]}
{"sql": ["21"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "one"]}
{"sql": ["22"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "two"]}
{"sql": ["23"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "three"]}
{"sql": ["24"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "four"]}
{"sql": ["25"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "five"]}
{"sql": ["26"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "six"]}
{"sql": ["27"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "seven"]}
{"sql": ["28"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "eight"]}
{"sql": ["29"], "type": "DAY_NUMBER", "cleaned_nl": ["twenty", "nine"]}
{"sql": ["30"], "type": "DAY_NUMBER", "cleaned_nl": ["thirty"]}
{"sql": ["31"], "type": "DAY_NUMBER", "cleaned_nl": ["thirty", "one"]}
{"sql": ["'COACH'"], "type": "FARE_TYPE", "cleaned_nl": ["coach"]}
{"sql": ["'BUSINESS'"], "type": "FARE_TYPE", "cleaned_nl": ["business"]}
{"sql": ["'FIRST'"], "type": "FARE_TYPE", "cleaned_nl": ["first", "class"]}
{"sql": ["'THRIFT'"], "type": "FARE_TYPE", "cleaned_nl": ["thrift"]}
{"sql": ["'STANDARD'"], "type": "FARE_TYPE", "cleaned_nl": ["standard"]}
{"sql": ["'SHUTTLE'"], "type": "FARE_TYPE", "cleaned_nl": ["shuttle"]}
{"sql": ["'B'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["b"]}
{"sql": ["'BH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bh"]}
{"sql": ["'BHW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bhw"]}
{"sql": ["'BHX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bhx"]}
{"sql": ["'BL'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bl"]}
{"sql": ["'BLW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["blw"]}
{"sql": ["'BLX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["blx"]}
{"sql": ["'BN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bn"]}
{"sql": ["'BOW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bow"]}
{"sql": ["'BOX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["box"]}
{"sql": ["'BW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bw"]}
{"sql": ["'BX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["bx"]}
{"sql": ["'C'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["c"]}
{"sql": ["'CN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["cn"]}
{"sql": ["'F'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["f"]}
{"sql": ["'FN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["fn"]}
{"sql": ["'H'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["h"]}
{"sql": ["'HH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hh"]}
{"sql": ["'HHW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hhw"]}
{"sql": ["'HHX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hhx"]}
{"sql": ["'HL'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hl"]}
{"sql": ["'HLW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hlw"]}
{"sql": ["'HLX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hlx"]}
{"sql": ["'HOX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["hox"]}
{"sql": ["'J'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["j"]}
{"sql": ["'K'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["k"]}
{"sql": ["'KH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["kh"]}
{"sql": ["'KL'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["kl"]}
{"sql": ["'KN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["kn"]}
{"sql": ["'LX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["lx"]}
{"sql": ["'M'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["m"]}
{"sql": ["'MH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["mh"]}
{"sql": ["'ML'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["ml"]}
{"sql": ["'MOW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["mow"]}
{"sql": ["'P'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["p"]}
{"sql": ["'Q'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["q"]}
{"sql": ["'QH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qh"]}
{"sql": ["'QHW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qhw"]}
{"sql": ["'QHX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qhx"]}
{"sql": ["'QLW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qlw"]}
{"sql": ["'QLX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qlx"]}
{"sql": ["'QO'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qo"]}
{"sql": ["'QOW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qow"]}
{"sql": ["'QOX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qox"]}
{"sql": ["'QW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qw"]}
{"sql": ["'QX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["qx"]}
{"sql": ["'S'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["s"]}
{"sql": ["'U'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["u"]}
{"sql": ["'V'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["v"]}
{"sql": ["'VHW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["vhw"]}
{"sql": ["'VHX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["vhx"]}
{"sql": ["'VW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["vw"]}
{"sql": ["'VX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["vx"]}
{"sql": ["'Y'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["y"]}
{"sql": ["'YH'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yh"]}
{"sql": ["'YL'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yl"]}
{"sql": ["'YN'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yn"]}
{"sql": ["'YW'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yw"]}
{"sql": ["'YX'"], "type": "FARE_BASIS_CODE", "cleaned_nl": ["yx"]}
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["rental", "car"]}
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["rental", "cars"]}
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["car", "rental"]}
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["car", "rentals"]}
{"sql": ["'RENTAL CAR'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["car"]}
{"sql": ["'AIR TAXI OPERATION'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["air", "taxi", "operation"]}
{"sql": ["'LIMOUSINE'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["limousine"]}
{"sql": ["'TAXI'"], "type": "TRANSPORT_TYPE", "cleaned_nl": ["taxi"]}
{"sql": ["'ARIZONA'"], "type": "STATE_NAME", "cleaned_nl": ["arizona"]}
{"sql": ["'CALIFORNIA'"], "type": "STATE_NAME", "cleaned_nl": ["california"]}
{"sql": ["'COLORADO'"], "type": "STATE_NAME", "cleaned_nl": ["colorado"]}
{"sql": ["'DISTRICT OF COLUMBIA'"], "type": "STATE_NAME", "cleaned_nl": ["district", "of", "columbia"]}
{"sql": ["'DISTRICT OF COLUMBIA'"], "type": "STATE_NAME", "cleaned_nl": ["dc"]}
{"sql": ["'DC'"], "type": "STATE_NAME", "cleaned_nl": ["dc"]}
{"sql": ["'FLORIDA'"], "type": "STATE_NAME", "cleaned_nl": ["florida"]}
{"sql": ["'GEORGIA'"], "type": "STATE_NAME", "cleaned_nl": ["georgia"]}
{"sql": ["'ILLINOIS'"], "type": "STATE_NAME", "cleaned_nl": ["illinois"]}
{"sql": ["'INDIANA'"], "type": "STATE_NAME", "cleaned_nl": ["indiana"]}
{"sql": ["'MASSACHUSETTS'"], "type": "STATE_NAME", "cleaned_nl": ["massachusetts"]}
{"sql": ["'MARYLAND'"], "type": "STATE_NAME", "cleaned_nl": ["maryland"]}
{"sql": ["'MICHIGAN'"], "type": "STATE_NAME", "cleaned_nl": ["michigan"]}
{"sql": ["'MINNESOTA'"], "type": "STATE_NAME", "cleaned_nl": ["minnesota"]}
{"sql": ["'MISSOURI'"], "type": "STATE_NAME", "cleaned_nl": ["missouri"]}
{"sql": ["'NORTH CAROLINA'"], "type": "STATE_NAME", "cleaned_nl": ["north", "carolina"]}
{"sql": ["'NEW JERSEY'"], "type": "STATE_NAME", "cleaned_nl": ["new", "jersey"]}
{"sql": ["'NEVADA'"], "type": "STATE_NAME", "cleaned_nl": ["nevada"]}
{"sql": ["'NEW YORK'"], "type": "STATE_NAME", "cleaned_nl": ["new", "york"]}
{"sql": ["'OHIO'"], "type": "STATE_NAME", "cleaned_nl": ["ohio"]}
{"sql": ["'ONTARIO'"], "type": "STATE_NAME", "cleaned_nl": ["ontario"]}
{"sql": ["'PENNSYLVANIA'"], "type": "STATE_NAME", "cleaned_nl": ["pennsylvania"]}
{"sql": ["'QUEBEC'"], "type": "STATE_NAME", "cleaned_nl": ["quebec"]}
{"sql": ["'TENNESSEE'"], "type": "STATE_NAME", "cleaned_nl": ["tennessee"]}
{"sql": ["'TEXAS'"], "type": "STATE_NAME", "cleaned_nl": ["texas"]}
{"sql": ["'UTAH'"], "type": "STATE_NAME", "cleaned_nl": ["utah"]}
{"sql": ["'WASHINGTON'"], "type": "STATE_NAME", "cleaned_nl": ["washington"]}
{"sql": ["'WISCONSIN'"], "type": "STATE_NAME", "cleaned_nl": ["wisconsin"]}
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["1"]}
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1"]}
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1pm"]}
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["1am"]}
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["1", "o'clock", "am"]}
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1", "o'clock", "pm"]}
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["1", "o'clock"]}
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1", "o'clock"]}
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["100"]}
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["100"]}
{"sql": ["100"], "type": "TIME", "cleaned_nl": ["100am"]}
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["100pm"]}
{"sql": ["115"], "type": "TIME", "cleaned_nl": ["115"]}
{"sql": ["1315"], "type": "TIME", "cleaned_nl": ["115"]}
{"sql": ["115"], "type": "TIME", "cleaned_nl": ["115am"]}
{"sql": ["1315"], "type": "TIME", "cleaned_nl": ["115pm"]}
{"sql": ["130"], "type": "TIME", "cleaned_nl": ["130"]}
{"sql": ["1330"], "type": "TIME", "cleaned_nl": ["130"]}
{"sql": ["130"], "type": "TIME", "cleaned_nl": ["130am"]}
{"sql": ["1330"], "type": "TIME", "cleaned_nl": ["130pm"]}
{"sql": ["145"], "type": "TIME", "cleaned_nl": ["145"]}
{"sql": ["1345"], "type": "TIME", "cleaned_nl": ["145"]}
{"sql": ["145"], "type": "TIME", "cleaned_nl": ["145am"]}
{"sql": ["1345"], "type": "TIME", "cleaned_nl": ["145pm"]}
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["2"]}
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["2"]}
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["2pm"]}
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["2am"]}
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["2", "o'clock", "am"]}
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["2", "o'clock", "pm"]}
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["2", "o'clock"]}
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["2", "o'clock"]}
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["200"]}
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["200"]}
{"sql": ["200"], "type": "TIME", "cleaned_nl": ["200am"]}
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["200pm"]}
{"sql": ["215"], "type": "TIME", "cleaned_nl": ["215"]}
{"sql": ["1415"], "type": "TIME", "cleaned_nl": ["215"]}
{"sql": ["215"], "type": "TIME", "cleaned_nl": ["215am"]}
{"sql": ["1415"], "type": "TIME", "cleaned_nl": ["215pm"]}
{"sql": ["230"], "type": "TIME", "cleaned_nl": ["230"]}
{"sql": ["1430"], "type": "TIME", "cleaned_nl": ["230"]}
{"sql": ["230"], "type": "TIME", "cleaned_nl": ["230am"]}
{"sql": ["1430"], "type": "TIME", "cleaned_nl": ["230pm"]}
{"sql": ["245"], "type": "TIME", "cleaned_nl": ["245"]}
{"sql": ["1445"], "type": "TIME", "cleaned_nl": ["245"]}
{"sql": ["245"], "type": "TIME", "cleaned_nl": ["245am"]}
{"sql": ["1445"], "type": "TIME", "cleaned_nl": ["245pm"]}
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["3"]}
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["3"]}
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["3pm"]}
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["3am"]}
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["3", "o'clock", "am"]}
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["3", "o'clock", "pm"]}
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["3", "o'clock"]}
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["3", "o'clock"]}
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["300"]}
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["300"]}
{"sql": ["300"], "type": "TIME", "cleaned_nl": ["300am"]}
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["300pm"]}
{"sql": ["315"], "type": "TIME", "cleaned_nl": ["315"]}
{"sql": ["1515"], "type": "TIME", "cleaned_nl": ["315"]}
{"sql": ["315"], "type": "TIME", "cleaned_nl": ["315am"]}
{"sql": ["1515"], "type": "TIME", "cleaned_nl": ["315pm"]}
{"sql": ["330"], "type": "TIME", "cleaned_nl": ["330"]}
{"sql": ["1530"], "type": "TIME", "cleaned_nl": ["330"]}
{"sql": ["330"], "type": "TIME", "cleaned_nl": ["330am"]}
{"sql": ["1530"], "type": "TIME", "cleaned_nl": ["330pm"]}
{"sql": ["345"], "type": "TIME", "cleaned_nl": ["345"]}
{"sql": ["1545"], "type": "TIME", "cleaned_nl": ["345"]}
{"sql": ["345"], "type": "TIME", "cleaned_nl": ["345am"]}
{"sql": ["1545"], "type": "TIME", "cleaned_nl": ["345pm"]}
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["4"]}
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["4"]}
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["4pm"]}
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["4am"]}
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["4", "o'clock", "am"]}
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["4", "o'clock", "pm"]}
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["4", "o'clock"]}
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["4", "o'clock"]}
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["400"]}
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["400"]}
{"sql": ["400"], "type": "TIME", "cleaned_nl": ["400am"]}
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["400pm"]}
{"sql": ["415"], "type": "TIME", "cleaned_nl": ["415"]}
{"sql": ["1615"], "type": "TIME", "cleaned_nl": ["415"]}
{"sql": ["415"], "type": "TIME", "cleaned_nl": ["415am"]}
{"sql": ["1615"], "type": "TIME", "cleaned_nl": ["415pm"]}
{"sql": ["430"], "type": "TIME", "cleaned_nl": ["430"]}
{"sql": ["1630"], "type": "TIME", "cleaned_nl": ["430"]}
{"sql": ["430"], "type": "TIME", "cleaned_nl": ["430am"]}
{"sql": ["1630"], "type": "TIME", "cleaned_nl": ["430pm"]}
{"sql": ["445"], "type": "TIME", "cleaned_nl": ["445"]}
{"sql": ["1645"], "type": "TIME", "cleaned_nl": ["445"]}
{"sql": ["445"], "type": "TIME", "cleaned_nl": ["445am"]}
{"sql": ["1645"], "type": "TIME", "cleaned_nl": ["445pm"]}
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["5"]}
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["5"]}
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["5pm"]}
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["5am"]}
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["5", "o'clock", "am"]}
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["5", "o'clock", "pm"]}
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["5", "o'clock"]}
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["5", "o'clock"]}
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["500"]}
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["500"]}
{"sql": ["500"], "type": "TIME", "cleaned_nl": ["500am"]}
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["500pm"]}
{"sql": ["515"], "type": "TIME", "cleaned_nl": ["515"]}
{"sql": ["1715"], "type": "TIME", "cleaned_nl": ["515"]}
{"sql": ["515"], "type": "TIME", "cleaned_nl": ["515am"]}
{"sql": ["1715"], "type": "TIME", "cleaned_nl": ["515pm"]}
{"sql": ["530"], "type": "TIME", "cleaned_nl": ["530"]}
{"sql": ["1730"], "type": "TIME", "cleaned_nl": ["530"]}
{"sql": ["530"], "type": "TIME", "cleaned_nl": ["530am"]}
{"sql": ["1730"], "type": "TIME", "cleaned_nl": ["530pm"]}
{"sql": ["545"], "type": "TIME", "cleaned_nl": ["545"]}
{"sql": ["1745"], "type": "TIME", "cleaned_nl": ["545"]}
{"sql": ["545"], "type": "TIME", "cleaned_nl": ["545am"]}
{"sql": ["1745"], "type": "TIME", "cleaned_nl": ["545pm"]}
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["6"]}
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["6"]}
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["6pm"]}
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["6am"]}
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["6", "o'clock", "am"]}
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["6", "o'clock", "pm"]}
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["6", "o'clock"]}
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["6", "o'clock"]}
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["600"]}
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["600"]}
{"sql": ["600"], "type": "TIME", "cleaned_nl": ["600am"]}
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["600pm"]}
{"sql": ["615"], "type": "TIME", "cleaned_nl": ["615"]}
{"sql": ["1815"], "type": "TIME", "cleaned_nl": ["615"]}
{"sql": ["615"], "type": "TIME", "cleaned_nl": ["615am"]}
{"sql": ["1815"], "type": "TIME", "cleaned_nl": ["615pm"]}
{"sql": ["630"], "type": "TIME", "cleaned_nl": ["630"]}
{"sql": ["1830"], "type": "TIME", "cleaned_nl": ["630"]}
{"sql": ["630"], "type": "TIME", "cleaned_nl": ["630am"]}
{"sql": ["1830"], "type": "TIME", "cleaned_nl": ["630pm"]}
{"sql": ["645"], "type": "TIME", "cleaned_nl": ["645"]}
{"sql": ["1845"], "type": "TIME", "cleaned_nl": ["645"]}
{"sql": ["645"], "type": "TIME", "cleaned_nl": ["645am"]}
{"sql": ["1845"], "type": "TIME", "cleaned_nl": ["645pm"]}
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["7"]}
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["7"]}
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["7pm"]}
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["7am"]}
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["7", "o'clock", "am"]}
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["7", "o'clock", "pm"]}
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["7", "o'clock"]}
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["7", "o'clock"]}
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["700"]}
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["700"]}
{"sql": ["700"], "type": "TIME", "cleaned_nl": ["700am"]}
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["700pm"]}
{"sql": ["715"], "type": "TIME", "cleaned_nl": ["715"]}
{"sql": ["1915"], "type": "TIME", "cleaned_nl": ["715"]}
{"sql": ["715"], "type": "TIME", "cleaned_nl": ["715am"]}
{"sql": ["1915"], "type": "TIME", "cleaned_nl": ["715pm"]}
{"sql": ["730"], "type": "TIME", "cleaned_nl": ["730"]}
{"sql": ["1930"], "type": "TIME", "cleaned_nl": ["730"]}
{"sql": ["730"], "type": "TIME", "cleaned_nl": ["730am"]}
{"sql": ["1930"], "type": "TIME", "cleaned_nl": ["730pm"]}
{"sql": ["745"], "type": "TIME", "cleaned_nl": ["745"]}
{"sql": ["1945"], "type": "TIME", "cleaned_nl": ["745"]}
{"sql": ["745"], "type": "TIME", "cleaned_nl": ["745am"]}
{"sql": ["1945"], "type": "TIME", "cleaned_nl": ["745pm"]}
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["8"]}
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["8"]}
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["8pm"]}
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["8am"]}
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["8", "o'clock", "am"]}
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["8", "o'clock", "pm"]}
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["8", "o'clock"]}
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["8", "o'clock"]}
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["800"]}
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["800"]}
{"sql": ["800"], "type": "TIME", "cleaned_nl": ["800am"]}
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["800pm"]}
{"sql": ["815"], "type": "TIME", "cleaned_nl": ["815"]}
{"sql": ["2015"], "type": "TIME", "cleaned_nl": ["815"]}
{"sql": ["815"], "type": "TIME", "cleaned_nl": ["815am"]}
{"sql": ["2015"], "type": "TIME", "cleaned_nl": ["815pm"]}
{"sql": ["830"], "type": "TIME", "cleaned_nl": ["830"]}
{"sql": ["2030"], "type": "TIME", "cleaned_nl": ["830"]}
{"sql": ["830"], "type": "TIME", "cleaned_nl": ["830am"]}
{"sql": ["2030"], "type": "TIME", "cleaned_nl": ["830pm"]}
{"sql": ["845"], "type": "TIME", "cleaned_nl": ["845"]}
{"sql": ["2045"], "type": "TIME", "cleaned_nl": ["845"]}
{"sql": ["845"], "type": "TIME", "cleaned_nl": ["845am"]}
{"sql": ["2045"], "type": "TIME", "cleaned_nl": ["845pm"]}
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["9"]}
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["9"]}
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["9pm"]}
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["9am"]}
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["9", "o'clock", "am"]}
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["9", "o'clock", "pm"]}
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["9", "o'clock"]}
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["9", "o'clock"]}
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["900"]}
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["900"]}
{"sql": ["900"], "type": "TIME", "cleaned_nl": ["900am"]}
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["900pm"]}
{"sql": ["915"], "type": "TIME", "cleaned_nl": ["915"]}
{"sql": ["2115"], "type": "TIME", "cleaned_nl": ["915"]}
{"sql": ["915"], "type": "TIME", "cleaned_nl": ["915am"]}
{"sql": ["2115"], "type": "TIME", "cleaned_nl": ["915pm"]}
{"sql": ["930"], "type": "TIME", "cleaned_nl": ["930"]}
{"sql": ["2130"], "type": "TIME", "cleaned_nl": ["930"]}
{"sql": ["930"], "type": "TIME", "cleaned_nl": ["930am"]}
{"sql": ["2130"], "type": "TIME", "cleaned_nl": ["930pm"]}
{"sql": ["945"], "type": "TIME", "cleaned_nl": ["945"]}
{"sql": ["2145"], "type": "TIME", "cleaned_nl": ["945"]}
{"sql": ["945"], "type": "TIME", "cleaned_nl": ["945am"]}
{"sql": ["2145"], "type": "TIME", "cleaned_nl": ["945pm"]}
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["10"]}
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["10"]}
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["10pm"]}
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["10am"]}
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["10", "o'clock", "am"]}
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["10", "o'clock", "pm"]}
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["10", "o'clock"]}
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["10", "o'clock"]}
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["1000"]}
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["1000"]}
{"sql": ["1000"], "type": "TIME", "cleaned_nl": ["1000am"]}
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["1000pm"]}
{"sql": ["1015"], "type": "TIME", "cleaned_nl": ["1015"]}
{"sql": ["2215"], "type": "TIME", "cleaned_nl": ["1015"]}
{"sql": ["1015"], "type": "TIME", "cleaned_nl": ["1015am"]}
{"sql": ["2215"], "type": "TIME", "cleaned_nl": ["1015pm"]}
{"sql": ["1030"], "type": "TIME", "cleaned_nl": ["1030"]}
{"sql": ["2230"], "type": "TIME", "cleaned_nl": ["1030"]}
{"sql": ["1030"], "type": "TIME", "cleaned_nl": ["1030am"]}
{"sql": ["2230"], "type": "TIME", "cleaned_nl": ["1030pm"]}
{"sql": ["1045"], "type": "TIME", "cleaned_nl": ["1045"]}
{"sql": ["2245"], "type": "TIME", "cleaned_nl": ["1045"]}
{"sql": ["1045"], "type": "TIME", "cleaned_nl": ["1045am"]}
{"sql": ["2245"], "type": "TIME", "cleaned_nl": ["1045pm"]}
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["11"]}
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["11"]}
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["11pm"]}
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["11am"]}
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["11", "o'clock", "am"]}
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["11", "o'clock", "pm"]}
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["11", "o'clock"]}
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["11", "o'clock"]}
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["1100"]}
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["1100"]}
{"sql": ["1100"], "type": "TIME", "cleaned_nl": ["1100am"]}
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["1100pm"]}
{"sql": ["1115"], "type": "TIME", "cleaned_nl": ["1115"]}
{"sql": ["2315"], "type": "TIME", "cleaned_nl": ["1115"]}
{"sql": ["1115"], "type": "TIME", "cleaned_nl": ["1115am"]}
{"sql": ["2315"], "type": "TIME", "cleaned_nl": ["1115pm"]}
{"sql": ["1130"], "type": "TIME", "cleaned_nl": ["1130"]}
{"sql": ["2330"], "type": "TIME", "cleaned_nl": ["1130"]}
{"sql": ["1130"], "type": "TIME", "cleaned_nl": ["1130am"]}
{"sql": ["2330"], "type": "TIME", "cleaned_nl": ["1130pm"]}
{"sql": ["1145"], "type": "TIME", "cleaned_nl": ["1145"]}
{"sql": ["2345"], "type": "TIME", "cleaned_nl": ["1145"]}
{"sql": ["1145"], "type": "TIME", "cleaned_nl": ["1145am"]}
{"sql": ["2345"], "type": "TIME", "cleaned_nl": ["1145pm"]}
{"sql": ["1200"], "type": "TIME", "cleaned_nl": ["12"]}
{"sql": ["1200"], "type": "TIME", "cleaned_nl": ["1200"]}
{"sql": ["2400"], "type": "TIME", "cleaned_nl": ["1200"]}
{"sql": ["1200"], "type": "TIME", "cleaned_nl": ["1200am"]}
{"sql": ["2400"], "type": "TIME", "cleaned_nl": ["1200pm"]}
{"sql": ["1215"], "type": "TIME", "cleaned_nl": ["1215"]}
{"sql": ["2415"], "type": "TIME", "cleaned_nl": ["1215"]}
{"sql": ["1215"], "type": "TIME", "cleaned_nl": ["1215am"]}
{"sql": ["2415"], "type": "TIME", "cleaned_nl": ["1215pm"]}
{"sql": ["1230"], "type": "TIME", "cleaned_nl": ["1230"]}
{"sql": ["2430"], "type": "TIME", "cleaned_nl": ["1230"]}
{"sql": ["1230"], "type": "TIME", "cleaned_nl": ["1230am"]}
{"sql": ["2430"], "type": "TIME", "cleaned_nl": ["1230pm"]}
{"sql": ["1245"], "type": "TIME", "cleaned_nl": ["1245"]}
{"sql": ["2445"], "type": "TIME", "cleaned_nl": ["1245"]}
{"sql": ["1245"], "type": "TIME", "cleaned_nl": ["1245am"]}
{"sql": ["2445"], "type": "TIME", "cleaned_nl": ["1245pm"]}
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["13"]}
{"sql": ["1300"], "type": "TIME", "cleaned_nl": ["1300"]}
{"sql": ["1315"], "type": "TIME", "cleaned_nl": ["1315"]}
{"sql": ["1330"], "type": "TIME", "cleaned_nl": ["1330"]}
{"sql": ["1345"], "type": "TIME", "cleaned_nl": ["1345"]}
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["14"]}
{"sql": ["1400"], "type": "TIME", "cleaned_nl": ["1400"]}
{"sql": ["1415"], "type": "TIME", "cleaned_nl": ["1415"]}
{"sql": ["1430"], "type": "TIME", "cleaned_nl": ["1430"]}
{"sql": ["1445"], "type": "TIME", "cleaned_nl": ["1445"]}
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["15"]}
{"sql": ["1500"], "type": "TIME", "cleaned_nl": ["1500"]}
{"sql": ["1515"], "type": "TIME", "cleaned_nl": ["1515"]}
{"sql": ["1530"], "type": "TIME", "cleaned_nl": ["1530"]}
{"sql": ["1545"], "type": "TIME", "cleaned_nl": ["1545"]}
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["16"]}
{"sql": ["1600"], "type": "TIME", "cleaned_nl": ["1600"]}
{"sql": ["1615"], "type": "TIME", "cleaned_nl": ["1615"]}
{"sql": ["1630"], "type": "TIME", "cleaned_nl": ["1630"]}
{"sql": ["1645"], "type": "TIME", "cleaned_nl": ["1645"]}
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["17"]}
{"sql": ["1700"], "type": "TIME", "cleaned_nl": ["1700"]}
{"sql": ["1715"], "type": "TIME", "cleaned_nl": ["1715"]}
{"sql": ["1730"], "type": "TIME", "cleaned_nl": ["1730"]}
{"sql": ["1745"], "type": "TIME", "cleaned_nl": ["1745"]}
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["18"]}
{"sql": ["1800"], "type": "TIME", "cleaned_nl": ["1800"]}
{"sql": ["1815"], "type": "TIME", "cleaned_nl": ["1815"]}
{"sql": ["1830"], "type": "TIME", "cleaned_nl": ["1830"]}
{"sql": ["1845"], "type": "TIME", "cleaned_nl": ["1845"]}
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["19"]}
{"sql": ["1900"], "type": "TIME", "cleaned_nl": ["1900"]}
{"sql": ["1915"], "type": "TIME", "cleaned_nl": ["1915"]}
{"sql": ["1930"], "type": "TIME", "cleaned_nl": ["1930"]}
{"sql": ["1945"], "type": "TIME", "cleaned_nl": ["1945"]}
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["20"]}
{"sql": ["2000"], "type": "TIME", "cleaned_nl": ["2000"]}
{"sql": ["2015"], "type": "TIME", "cleaned_nl": ["2015"]}
{"sql": ["2030"], "type": "TIME", "cleaned_nl": ["2030"]}
{"sql": ["2045"], "type": "TIME", "cleaned_nl": ["2045"]}
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["21"]}
{"sql": ["2100"], "type": "TIME", "cleaned_nl": ["2100"]}
{"sql": ["2115"], "type": "TIME", "cleaned_nl": ["2115"]}
{"sql": ["2130"], "type": "TIME", "cleaned_nl": ["2130"]}
{"sql": ["2145"], "type": "TIME", "cleaned_nl": ["2145"]}
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["22"]}
{"sql": ["2200"], "type": "TIME", "cleaned_nl": ["2200"]}
{"sql": ["2215"], "type": "TIME", "cleaned_nl": ["2215"]}
{"sql": ["2230"], "type": "TIME", "cleaned_nl": ["2230"]}
{"sql": ["2245"], "type": "TIME", "cleaned_nl": ["2245"]}
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["23"]}
{"sql": ["2300"], "type": "TIME", "cleaned_nl": ["2300"]}
{"sql": ["2315"], "type": "TIME", "cleaned_nl": ["2315"]}
{"sql": ["2330"], "type": "TIME", "cleaned_nl": ["2330"]}
{"sql": ["2345"], "type": "TIME", "cleaned_nl": ["2345"]}
{"sql": ["1200"], "type": "TIME", "cleaned_nl": ["noon"]}
{"sql": ["'BREAKFAST'"], "type": "MEAL_DESCRIPTION", "cleaned_nl": ["breakfast"]}
{"sql": ["'LUNCH'"], "type": "MEAL_DESCRIPTION", "cleaned_nl": ["lunch"]}
{"sql": ["'DINNER'"], "type": "MEAL_DESCRIPTION", "cleaned_nl": ["dinner"]}
{"sql": ["'SNACK'"], "type": "MEAL_DESCRIPTION", "cleaned_nl": ["snack"]}
{"sql": ["'ATL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["atl"]}
{"sql": ["'ATL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["atlanta", "airport"]}
{"sql": ["'BNA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["bna"]}
{"sql": ["'BOS'"], "type": "AIRPORT_CODE", "cleaned_nl": ["bos"]}
{"sql": ["'BOS'"], "type": "AIRPORT_CODE", "cleaned_nl": ["boston", "airport"]}
{"sql": ["'BUR'"], "type": "AIRPORT_CODE", "cleaned_nl": ["bur"]}
{"sql": ["'BWI'"], "type": "AIRPORT_CODE", "cleaned_nl": ["bwi"]}
{"sql": ["'CLE'"], "type": "AIRPORT_CODE", "cleaned_nl": ["cle"]}
{"sql": ["'CLT'"], "type": "AIRPORT_CODE", "cleaned_nl": ["clt"]}
{"sql": ["'CMH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["cmh"]}
{"sql": ["'CVG'"], "type": "AIRPORT_CODE", "cleaned_nl": ["cvg"]}
{"sql": ["'DAL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dal"]}
{"sql": ["'DAL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["love", "field"]}
{"sql": ["'DCA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dca"]}
{"sql": ["'DEN'"], "type": "AIRPORT_CODE", "cleaned_nl": ["den"]}
{"sql": ["'DEN'"], "type": "AIRPORT_CODE", "cleaned_nl": ["denver", "airport"]}
{"sql": ["'DET'"], "type": "AIRPORT_CODE", "cleaned_nl": ["det"]}
{"sql": ["'DFW'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dfw"]}
{"sql": ["'DFW'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dallas", "airport"]}
{"sql": ["'DTW'"], "type": "AIRPORT_CODE", "cleaned_nl": ["dtw"]}
{"sql": ["'EWR'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ewr"]}
{"sql": ["'HOU'"], "type": "AIRPORT_CODE", "cleaned_nl": ["hou"]}
{"sql": ["'HPN'"], "type": "AIRPORT_CODE", "cleaned_nl": ["hpn"]}
{"sql": ["'IAD'"], "type": "AIRPORT_CODE", "cleaned_nl": ["iad"]}
{"sql": ["'IAH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["iah"]}
{"sql": ["'IAH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["houston", "international"]}
{"sql": ["'IAH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["houston", "international", "airport"]}
{"sql": ["'IAH'"], "type": "AIRPORT_CODE", "cleaned_nl": ["houston", "intercontinental", "airport"]}
{"sql": ["'IND'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ind"]}
{"sql": ["'JFK'"], "type": "AIRPORT_CODE", "cleaned_nl": ["jfk"]}
{"sql": ["'LAS'"], "type": "AIRPORT_CODE", "cleaned_nl": ["las"]}
{"sql": ["'LAX'"], "type": "AIRPORT_CODE", "cleaned_nl": ["lax"]}
{"sql": ["'LAX'"], "type": "AIRPORT_CODE", "cleaned_nl": ["los", "angeles", "international", "airport"]}
{"sql": ["'LGA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["lga"]}
{"sql": ["'LGA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["la", "guardia"]}
{"sql": ["'LGB'"], "type": "AIRPORT_CODE", "cleaned_nl": ["lgb"]}
{"sql": ["'MCI'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mci"]}
{"sql": ["'MCO'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mco"]}
{"sql": ["'MDW'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mdw"]}
{"sql": ["'MEM'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mem"]}
{"sql": ["'MIA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mia"]}
{"sql": ["'MKE'"], "type": "AIRPORT_CODE", "cleaned_nl": ["mke"]}
{"sql": ["'MKE'"], "type": "AIRPORT_CODE", "cleaned_nl": ["general", "mitchell", "airport"]}
{"sql": ["'MSP'"], "type": "AIRPORT_CODE", "cleaned_nl": ["msp"]}
{"sql": ["'OAK'"], "type": "AIRPORT_CODE", "cleaned_nl": ["oak"]}
{"sql": ["'ONT'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ont"]}
{"sql": ["'ORD'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ord"]}
{"sql": ["'PHL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["phl"]}
{"sql": ["'PHX'"], "type": "AIRPORT_CODE", "cleaned_nl": ["phx"]}
{"sql": ["'PIE'"], "type": "AIRPORT_CODE", "cleaned_nl": ["pie"]}
{"sql": ["'PIT'"], "type": "AIRPORT_CODE", "cleaned_nl": ["pit"]}
{"sql": ["'PIT'"], "type": "AIRPORT_CODE", "cleaned_nl": ["pittsburgh", "airport"]}
{"sql": ["'SAN'"], "type": "AIRPORT_CODE", "cleaned_nl": ["san"]}
{"sql": ["'SEA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["sea"]}
{"sql": ["'SFO'"], "type": "AIRPORT_CODE", "cleaned_nl": ["sfo"]}
{"sql": ["'SFO'"], "type": "AIRPORT_CODE", "cleaned_nl": ["san", "francisco", "international"]}
{"sql": ["'SFO'"], "type": "AIRPORT_CODE", "cleaned_nl": ["san", "francisco", "international", "airport"]}
{"sql": ["'SJC'"], "type": "AIRPORT_CODE", "cleaned_nl": ["sjc"]}
{"sql": ["'SLC'"], "type": "AIRPORT_CODE", "cleaned_nl": ["slc"]}
{"sql": ["'STL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["stl"]}
{"sql": ["'TPA'"], "type": "AIRPORT_CODE", "cleaned_nl": ["tpa"]}
{"sql": ["'YKZ'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ykz"]}
{"sql": ["'YMX'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ymx"]}
{"sql": ["'YTZ'"], "type": "AIRPORT_CODE", "cleaned_nl": ["ytz"]}
{"sql": ["'YUL'"], "type": "AIRPORT_CODE", "cleaned_nl": ["yul"]}
{"sql": ["'YYZ'"], "type": "AIRPORT_CODE", "cleaned_nl": ["yyz"]}
{"sql": ["'72S'"], "type": "AIRCRAFT_CODE", "cleaned_nl": ["72s"]}
{"sql": ["'734'"], "type": "AIRCRAFT_CODE", "cleaned_nl": ["734"]}
{"sql": ["'M80'"], "type": "AIRCRAFT_CODE", "cleaned_nl": ["m80"]}
{"sql": ["'D9S'"], "type": "AIRCRAFT_CODE", "cleaned_nl": ["d9s"]}
{"cleaned_nl": ["ap/2"], "type": "RESTRICTION_CODE", "sql": ["'AP/2'"]}
{"cleaned_nl": ["ap/6"], "type": "RESTRICTION_CODE", "sql": ["'AP/6'"]}
{"cleaned_nl": ["ap/12"], "type": "RESTRICTION_CODE", "sql": ["'AP/12'"]}
{"cleaned_nl": ["ap/20"], "type": "RESTRICTION_CODE", "sql": ["'AP/20'"]}
{"cleaned_nl": ["ap/21"], "type": "RESTRICTION_CODE", "sql": ["'AP/21'"]}
{"cleaned_nl": ["ap/57"], "type": "RESTRICTION_CODE", "sql": ["'AP/57'"]}
{"cleaned_nl": ["ap", "57"], "type": "RESTRICTION_CODE", "sql": ["'AP/57'"]}
{"cleaned_nl": ["ap/58"], "type": "RESTRICTION_CODE", "sql": ["'AP/58'"]}
{"cleaned_nl": ["ap/60"], "type": "RESTRICTION_CODE", "sql": ["'AP/60'"]}
{"cleaned_nl": ["ap/80"], "type": "RESTRICTION_CODE", "sql": ["'AP/80'"]}
{"cleaned_nl": ["ap", "80"], "type": "RESTRICTION_CODE", "sql": ["'AP/80'"]}
{"cleaned_nl": ["ap/75"], "type": "RESTRICTION_CODE", "sql": ["'AP/75'"]}
{"cleaned_nl": ["ex/9"], "type": "RESTRICTION_CODE", "sql": ["'EX/9'"]}
{"cleaned_nl": ["ex/13"], "type": "RESTRICTION_CODE", "sql": ["'EX/13'"]}
{"cleaned_nl": ["ex/14"], "type": "RESTRICTION_CODE", "sql": ["'EX/14'"]}
{"cleaned_nl": ["ex/17"], "type": "RESTRICTION_CODE", "sql": ["'EX/17'"]}
{"cleaned_nl": ["ex/19"], "type": "RESTRICTION_CODE", "sql": ["'EX/19'"]}


@ -0,0 +1,3 @@
raw/atis2/12-1.1/ATIS2/TEXT/TRAIN/MIT/4L0/1
raw/atis2/12-1.1/ATIS2/TEXT/TEST/NOV92/770/5
raw/atis3_test/17_4.2/atis3/sp_tst/dev94/prelim/g0x/4


@ -0,0 +1,902 @@
{"input": "northeast express regional airlines", "output": ["'2V'"]}
{"input": "air alliance", "output": ["'3J'"]}
{"input": "alpha air", "output": ["'7V'"]}
{"input": "express airlines i", "output": ["'9E'"]}
{"input": "colgan air", "output": ["'9L'"]}
{"input": "trans states airlines", "output": ["'9N'"]}
{"input": "ontario express", "output": ["'9X'"]}
{"input": "american airlines", "output": ["'AA'"]}
{"input": "american airline", "output": ["'AA'"]}
{"input": "american", "output": ["'AA'"]}
{"input": "air canada", "output": ["'AC'"]}
{"input": "aerolineas argentinas", "output": ["'AR'"]}
{"input": "alaska airlines", "output": ["'AS'"]}
{"input": "royal air maroc", "output": ["'AT'"]}
{"input": "british airways", "output": ["'BA'"]}
{"input": "braniff international airlines", "output": ["'BE'"]}
{"input": "continental airlines", "output": ["'CO'"]}
{"input": "continental", "output": ["'CO'"]}
{"input": "canadian airlines international", "output": ["'CP'"]}
{"input": "canadian", "output": ["'CP'"]}
{"input": "canadian airlines", "output": ["'CP'"]}
{"input": "atlantic coast airlines", "output": ["'DH'"]}
{"input": "delta airlines", "output": ["'DL'"]}
{"input": "delta", "output": ["'DL'"]}
{"input": "delta airline", "output": ["'DL'"]}
{"input": "atlantic southeast airlines", "output": ["'EV'"]}
{"input": "tower air", "output": ["'FF'"]}
{"input": "air ontario", "output": ["'GX'"]}
{"input": "america west airlines", "output": ["'HP'"]}
{"input": "america west", "output": ["'HP'"]}
{"input": "business express", "output": ["'HQ'"]}
{"input": "carnival air lines", "output": ["'KW'"]}
{"input": "lufthansa german airlines", "output": ["'LH'"]}
{"input": "lufthansa", "output": ["'LH'"]}
{"input": "lufthansa airlines", "output": ["'LH'"]}
{"input": "mgm grand air", "output": ["'MG'"]}
{"input": "northwest airlines", "output": ["'NW'"]}
{"input": "northwest airline", "output": ["'NW'"]}
{"input": "northwest", "output": ["'NW'"]}
{"input": "nationair", "output": ["'NX'"]}
{"input": "westair airlines", "output": ["'OE'"]}
{"input": "comair", "output": ["'OH'"]}
{"input": "czechoslovak airlines", "output": ["'OK'"]}
{"input": "sky west airlines", "output": ["'OO'"]}
{"input": "grand airways", "output": ["'QD'"]}
{"input": "precision airlines", "output": ["'RP'"]}
{"input": "trans world express", "output": ["'RZ'"]}
{"input": "sabena belgian world airlines", "output": ["'SN'"]}
{"input": "christman air system", "output": ["'SX'"]}
{"input": "thai airways international,ltd.", "output": ["'TG'"]}
{"input": "trans world airlines", "output": ["'TW'"]}
{"input": "twa", "output": ["'TW'"]}
{"input": "american trans air", "output": ["'TZ'"]}
{"input": "united airlines", "output": ["'UA'"]}
{"input": "united", "output": ["'UA'"]}
{"input": "usair", "output": ["'US'"]}
{"input": "us air", "output": ["'US'"]}
{"input": "southwest airlines", "output": ["'WN'"]}
{"input": "southwest air", "output": ["'WN'"]}
{"input": "southwest", "output": ["'WN'"]}
{"input": "mesaba aviation", "output": ["'XJ'"]}
{"input": "midwest express airlines", "output": ["'YX'"]}
{"input": "midwest express", "output": ["'YX'"]}
{"input": "air wisconsin", "output": ["'ZW'"]}
{"input": "2v", "output": ["'2V'"]}
{"input": "3j", "output": ["'3J'"]}
{"input": "7v", "output": ["'7V'"]}
{"input": "9e", "output": ["'9E'"]}
{"input": "9l", "output": ["'9L'"]}
{"input": "9n", "output": ["'9N'"]}
{"input": "9x", "output": ["'9X'"]}
{"input": "aa", "output": ["'AA'"]}
{"input": "ac", "output": ["'AC'"]}
{"input": "ar", "output": ["'AR'"]}
{"input": "as", "output": ["'AS'"]}
{"input": "at", "output": ["'AT'"]}
{"input": "ba", "output": ["'BA'"]}
{"input": "be", "output": ["'BE'"]}
{"input": "co", "output": ["'CO'"]}
{"input": "cp", "output": ["'CP'"]}
{"input": "dh", "output": ["'DH'"]}
{"input": "dl", "output": ["'DL'"]}
{"input": "ev", "output": ["'EV'"]}
{"input": "ff", "output": ["'FF'"]}
{"input": "gx", "output": ["'GX'"]}
{"input": "hp", "output": ["'HP'"]}
{"input": "hq", "output": ["'HQ'"]}
{"input": "kw", "output": ["'KW'"]}
{"input": "lh", "output": ["'LH'"]}
{"input": "mg", "output": ["'MG'"]}
{"input": "nw", "output": ["'NW'"]}
{"input": "nx", "output": ["'NX'"]}
{"input": "oe", "output": ["'OE'"]}
{"input": "oh", "output": ["'OH'"]}
{"input": "ok", "output": ["'OK'"]}
{"input": "oo", "output": ["'OO'"]}
{"input": "qd", "output": ["'QD'"]}
{"input": "rp", "output": ["'RP'"]}
{"input": "rz", "output": ["'RZ'"]}
{"input": "sn", "output": ["'SN'"]}
{"input": "sx", "output": ["'SX'"]}
{"input": "tg", "output": ["'TG'"]}
{"input": "tw", "output": ["'TW'"]}
{"input": "tz", "output": ["'TZ'"]}
{"input": "ua", "output": ["'UA'"]}
{"input": "us", "output": ["'US'"]}
{"input": "wn", "output": ["'WN'"]}
{"input": "xj", "output": ["'XJ'"]}
{"input": "yx", "output": ["'YX'"]}
{"input": "zw", "output": ["'ZW'"]}
{"input": "nashville", "output": ["'NASHVILLE'"]}
{"input": "boston", "output": ["'BOSTON'"]}
{"input": "burbank", "output": ["'BURBANK'"]}
{"input": "baltimore", "output": ["'BALTIMORE'"]}
{"input": "chicago", "output": ["'CHICAGO'"]}
{"input": "cleveland", "output": ["'CLEVELAND'"]}
{"input": "charlotte", "output": ["'CHARLOTTE'"]}
{"input": "columbus", "output": ["'COLUMBUS'"]}
{"input": "cincinnati", "output": ["'CINCINNATI'"]}
{"input": "denver", "output": ["'DENVER'"]}
{"input": "dallas", "output": ["'DALLAS'"]}
{"input": "detroit", "output": ["'DETROIT'"]}
{"input": "fort worth", "output": ["'FORT", "WORTH'"]}
{"input": "houston", "output": ["'HOUSTON'"]}
{"input": "westchester county", "output": ["'WESTCHESTER COUNTY'"]}
{"input": "westchester", "output": ["'WESTCHESTER COUNTY'"]}
{"input": "indianapolis", "output": ["'INDIANAPOLIS'"]}
{"input": "newark", "output": ["'NEWARK'"]}
{"input": "las vegas", "output": ["'LAS VEGAS'"]}
{"input": "los angeles", "output": ["'LOS ANGELES'"]}
{"input": "long beach", "output": ["'LONG BEACH'"]}
{"input": "atlanta", "output": ["'ATLANTA'"]}
{"input": "memphis", "output": ["'MEMPHIS'"]}
{"input": "miami", "output": ["'MIAMI'"]}
{"input": "kansas city", "output": ["'KANSAS CITY'"]}
{"input": "milwaukee", "output": ["'MILWAUKEE'"]}
{"input": "minneapolis", "output": ["'MINNEAPOLIS'"]}
{"input": "new york", "output": ["'NEW YORK'"]}
{"input": "new york city", "output": ["'NEW YORK'"]}
{"input": "oakland", "output": ["'OAKLAND'"]}
{"input": "ontario", "output": ["'ONTARIO'"]}
{"input": "orlando", "output": ["'ORLANDO'"]}
{"input": "philadelphia", "output": ["'PHILADELPHIA'"]}
{"input": "phoenix", "output": ["'PHOENIX'"]}
{"input": "pittsburgh", "output": ["'PITTSBURGH'"]}
{"input": "st. paul", "output": ["'ST. PAUL'"]}
{"input": "saint paul", "output": ["'ST. PAUL'"]}
{"input": "san diego", "output": ["'SAN DIEGO'"]}
{"input": "seattle", "output": ["'SEATTLE'"]}
{"input": "san francisco", "output": ["'SAN FRANCISCO'"]}
{"input": "san jose", "output": ["'SAN JOSE'"]}
{"input": "salt lake city", "output": ["'SALT LAKE CITY'"]}
{"input": "st. louis", "output": ["'ST. LOUIS'"]}
{"input": "saint louis", "output": ["'ST. LOUIS'"]}
{"input": "saint petersburg", "output": ["'ST. PETERSBURG'"]}
{"input": "st. petersburg", "output": ["'ST. PETERSBURG'"]}
{"input": "tacoma", "output": ["'TACOMA'"]}
{"input": "tampa", "output": ["'TAMPA'"]}
{"input": "washington", "output": ["'WASHINGTON'"]}
{"input": "montreal", "output": ["'MONTREAL'"]}
{"input": "toronto", "output": ["'TORONTO'"]}
{"input": "january", "output": ["1"]}
{"input": "february", "output": ["2"]}
{"input": "march", "output": ["3"]}
{"input": "april", "output": ["4"]}
{"input": "may", "output": ["5"]}
{"input": "june", "output": ["6"]}
{"input": "july", "output": ["7"]}
{"input": "august", "output": ["8"]}
{"input": "september", "output": ["9"]}
{"input": "october", "output": ["10"]}
{"input": "november", "output": ["11"]}
{"input": "december", "output": ["12"]}
{"input": "monday", "output": ["'MONDAY'"]}
{"input": "tuesday", "output": ["'TUESDAY'"]}
{"input": "wednesday", "output": ["'WEDNESDAY'"]}
{"input": "thursday", "output": ["'THURSDAY'"]}
{"input": "friday", "output": ["'FRIDAY'"]}
{"input": "saturday", "output": ["'SATURDAY'"]}
{"input": "sunday", "output": ["'SUNDAY'"]}
{"input": "first", "output": ["1"]}
{"input": "second", "output": ["2"]}
{"input": "third", "output": ["3"]}
{"input": "fourth", "output": ["4"]}
{"input": "fifth", "output": ["5"]}
{"input": "sixth", "output": ["6"]}
{"input": "seventh", "output": ["7"]}
{"input": "eighth", "output": ["8"]}
{"input": "ninth", "output": ["9"]}
{"input": "tenth", "output": ["10"]}
{"input": "eleventh", "output": ["11"]}
{"input": "twelfth", "output": ["12"]}
{"input": "thirteenth", "output": ["13"]}
{"input": "fourteenth", "output": ["14"]}
{"input": "fifteenth", "output": ["15"]}
{"input": "sixteenth", "output": ["16"]}
{"input": "seventeenth", "output": ["17"]}
{"input": "eighteenth", "output": ["18"]}
{"input": "nineteenth", "output": ["19"]}
{"input": "twentieth", "output": ["20"]}
{"input": "twenty first", "output": ["21"]}
{"input": "twenty second", "output": ["22"]}
{"input": "twenty third", "output": ["23"]}
{"input": "twenty fourth", "output": ["24"]}
{"input": "twenty fifth", "output": ["25"]}
{"input": "twenty sixth", "output": ["26"]}
{"input": "twenty seventh", "output": ["27"]}
{"input": "twenty eighth", "output": ["28"]}
{"input": "twenty ninth", "output": ["29"]}
{"input": "thirtieth", "output": ["30"]}
{"input": "thirty first", "output": ["31"]}
{"input": "coach", "output": ["'COACH'"]}
{"input": "business", "output": ["'BUSINESS'"]}
{"input": "first", "output": ["'FIRST'"]}
{"input": "first class", "output": ["'FIRST'"]}
{"input": "thrift", "output": ["'THRIFT'"]}
{"input": "standard", "output": ["'STANDARD'"]}
{"input": "shuttle", "output": ["'SHUTTLE'"]}
{"input": "b", "output": ["'B'"]}
{"input": "bh", "output": ["'BH'"]}
{"input": "bhw", "output": ["'BHW'"]}
{"input": "bhx", "output": ["'BHX'"]}
{"input": "bl", "output": ["'BL'"]}
{"input": "blw", "output": ["'BLW'"]}
{"input": "blx", "output": ["'BLX'"]}
{"input": "bn", "output": ["'BN'"]}
{"input": "bow", "output": ["'BOW'"]}
{"input": "box", "output": ["'BOX'"]}
{"input": "bw", "output": ["'BW'"]}
{"input": "bx", "output": ["'BX'"]}
{"input": "c", "output": ["'C'"]}
{"input": "cn", "output": ["'CN'"]}
{"input": "f", "output": ["'F'"]}
{"input": "fn", "output": ["'FN'"]}
{"input": "h", "output": ["'H'"]}
{"input": "hh", "output": ["'HH'"]}
{"input": "hhw", "output": ["'HHW'"]}
{"input": "hhx", "output": ["'HHX'"]}
{"input": "hl", "output": ["'HL'"]}
{"input": "hlw", "output": ["'HLW'"]}
{"input": "hlx", "output": ["'HLX'"]}
{"input": "how", "output": ["'HOW'"]}
{"input": "hox", "output": ["'HOX'"]}
{"input": "j", "output": ["'J'"]}
{"input": "k", "output": ["'K'"]}
{"input": "kh", "output": ["'KH'"]}
{"input": "kl", "output": ["'KL'"]}
{"input": "kn", "output": ["'KN'"]}
{"input": "lx", "output": ["'LX'"]}
{"input": "m", "output": ["'M'"]}
{"input": "mh", "output": ["'MH'"]}
{"input": "ml", "output": ["'ML'"]}
{"input": "mow", "output": ["'MOW'"]}
{"input": "p", "output": ["'P'"]}
{"input": "q", "output": ["'Q'"]}
{"input": "qh", "output": ["'QH'"]}
{"input": "qhw", "output": ["'QHW'"]}
{"input": "qhx", "output": ["'QHX'"]}
{"input": "qlw", "output": ["'QLW'"]}
{"input": "qlx", "output": ["'QLX'"]}
{"input": "qo", "output": ["'QO'"]}
{"input": "qow", "output": ["'QOW'"]}
{"input": "qox", "output": ["'QOX'"]}
{"input": "qw", "output": ["'QW'"]}
{"input": "qx", "output": ["'QX'"]}
{"input": "s", "output": ["'S'"]}
{"input": "u", "output": ["'U'"]}
{"input": "v", "output": ["'V'"]}
{"input": "vhw", "output": ["'VHW'"]}
{"input": "vhx", "output": ["'VHX'"]}
{"input": "vw", "output": ["'VW'"]}
{"input": "vx", "output": ["'VX'"]}
{"input": "y", "output": ["'Y'"]}
{"input": "yh", "output": ["'YH'"]}
{"input": "yl", "output": ["'YL'"]}
{"input": "yn", "output": ["'YN'"]}
{"input": "yw", "output": ["'YW'"]}
{"input": "yx", "output": ["'YX'"]}
{"input": "rental car", "output": ["'RENTAL CAR'"]}
{"input": "air taxi operation", "output": ["'AIR TAXI OPERATION'"]}
{"input": "limousine", "output": ["'LIMOUSINE'"]}
{"input": "taxi", "output": ["'TAXI'"]}
{"input": "arizona", "output": ["'ARIZONA'"]}
{"input": "california", "output": ["'CALIFORNIA'"]}
{"input": "colorado", "output": ["'COLORADO'"]}
{"input": "district of columbia", "output": ["'DISTRICT OF COLUMBIA'"]}
{"input": "dc", "output": ["'DISTRICT OF COLUMBIA'"]}
{"input": "dc", "output": ["'DC'"]}
{"input": "florida", "output": ["'FLORIDA'"]}
{"input": "georgia", "output": ["'GEORGIA'"]}
{"input": "illinois", "output": ["'ILLINOIS'"]}
{"input": "indiana", "output": ["'INDIANA'"]}
{"input": "massachusetts", "output": ["'MASSACHUSETTS'"]}
{"input": "maryland", "output": ["'MARYLAND'"]}
{"input": "michigan", "output": ["'MICHIGAN'"]}
{"input": "minnesota", "output": ["'MINNESOTA'"]}
{"input": "missouri", "output": ["'MISSOURI'"]}
{"input": "north carolina", "output": ["'NORTH CAROLINA'"]}
{"input": "new jersey", "output": ["'NEW JERSEY'"]}
{"input": "nevada", "output": ["'NEVADA'"]}
{"input": "new york", "output": ["'NEW YORK'"]}
{"input": "ohio", "output": ["'OHIO'"]}
{"input": "ontario", "output": ["'ONTARIO'"]}
{"input": "pennsylvania", "output": ["'PENNSYLVANIA'"]}
{"input": "quebec", "output": ["'QUEBEC'"]}
{"input": "tennessee", "output": ["'TENNESSEE'"]}
{"input": "texas", "output": ["'TEXAS'"]}
{"input": "utah", "output": ["'UTAH'"]}
{"input": "washington", "output": ["'WASHINGTON'"]}
{"input": "wisconsin", "output": ["'WISCONSIN'"]}
{"input": "1", "output": ["100"]}
{"input": "1", "output": ["1300"]}
{"input": "1pm", "output": ["1300"]}
{"input": "1am", "output": ["100"]}
{"input": "1 o'clock am", "output": ["100"]}
{"input": "1 o'clock pm", "output": ["1300"]}
{"input": "1 o'clock", "output": ["100"]}
{"input": "1 o'clock", "output": ["1300"]}
{"input": "100", "output": ["100"]}
{"input": "100", "output": ["1300"]}
{"input": "100am", "output": ["100"]}
{"input": "100pm", "output": ["1300"]}
{"input": "115", "output": ["115"]}
{"input": "115", "output": ["1315"]}
{"input": "115am", "output": ["115"]}
{"input": "115pm", "output": ["1315"]}
{"input": "130", "output": ["130"]}
{"input": "130", "output": ["1330"]}
{"input": "130am", "output": ["130"]}
{"input": "130pm", "output": ["1330"]}
{"input": "145", "output": ["145"]}
{"input": "145", "output": ["1345"]}
{"input": "145am", "output": ["145"]}
{"input": "145pm", "output": ["1345"]}
{"input": "2", "output": ["200"]}
{"input": "2", "output": ["1400"]}
{"input": "2pm", "output": ["1400"]}
{"input": "2am", "output": ["200"]}
{"input": "2 o'clock am", "output": ["200"]}
{"input": "2 o'clock pm", "output": ["1400"]}
{"input": "2 o'clock", "output": ["200"]}
{"input": "2 o'clock", "output": ["1400"]}
{"input": "200", "output": ["200"]}
{"input": "200", "output": ["1400"]}
{"input": "200am", "output": ["200"]}
{"input": "200pm", "output": ["1400"]}
{"input": "215", "output": ["215"]}
{"input": "215", "output": ["1415"]}
{"input": "215am", "output": ["215"]}
{"input": "215pm", "output": ["1415"]}
{"input": "230", "output": ["230"]}
{"input": "230", "output": ["1430"]}
{"input": "230am", "output": ["230"]}
{"input": "230pm", "output": ["1430"]}
{"input": "245", "output": ["245"]}
{"input": "245", "output": ["1445"]}
{"input": "245am", "output": ["245"]}
{"input": "245pm", "output": ["1445"]}
{"input": "3", "output": ["300"]}
{"input": "3", "output": ["1500"]}
{"input": "3pm", "output": ["1500"]}
{"input": "3am", "output": ["300"]}
{"input": "3 o'clock am", "output": ["300"]}
{"input": "3 o'clock pm", "output": ["1500"]}
{"input": "3 o'clock", "output": ["300"]}
{"input": "3 o'clock", "output": ["1500"]}
{"input": "300", "output": ["300"]}
{"input": "300", "output": ["1500"]}
{"input": "300am", "output": ["300"]}
{"input": "300pm", "output": ["1500"]}
{"input": "315", "output": ["315"]}
{"input": "315", "output": ["1515"]}
{"input": "315am", "output": ["315"]}
{"input": "315pm", "output": ["1515"]}
{"input": "330", "output": ["330"]}
{"input": "330", "output": ["1530"]}
{"input": "330am", "output": ["330"]}
{"input": "330pm", "output": ["1530"]}
{"input": "345", "output": ["345"]}
{"input": "345", "output": ["1545"]}
{"input": "345am", "output": ["345"]}
{"input": "345pm", "output": ["1545"]}
{"input": "4", "output": ["400"]}
{"input": "4", "output": ["1600"]}
{"input": "4pm", "output": ["1600"]}
{"input": "4am", "output": ["400"]}
{"input": "4 o'clock am", "output": ["400"]}
{"input": "4 o'clock pm", "output": ["1600"]}
{"input": "4 o'clock", "output": ["400"]}
{"input": "4 o'clock", "output": ["1600"]}
{"input": "400", "output": ["400"]}
{"input": "400", "output": ["1600"]}
{"input": "400am", "output": ["400"]}
{"input": "400pm", "output": ["1600"]}
{"input": "415", "output": ["415"]}
{"input": "415", "output": ["1615"]}
{"input": "415am", "output": ["415"]}
{"input": "415pm", "output": ["1615"]}
{"input": "430", "output": ["430"]}
{"input": "430", "output": ["1630"]}
{"input": "430am", "output": ["430"]}
{"input": "430pm", "output": ["1630"]}
{"input": "445", "output": ["445"]}
{"input": "445", "output": ["1645"]}
{"input": "445am", "output": ["445"]}
{"input": "445pm", "output": ["1645"]}
{"input": "5", "output": ["500"]}
{"input": "5", "output": ["1700"]}
{"input": "5pm", "output": ["1700"]}
{"input": "5am", "output": ["500"]}
{"input": "5 o'clock am", "output": ["500"]}
{"input": "5 o'clock pm", "output": ["1700"]}
{"input": "5 o'clock", "output": ["500"]}
{"input": "5 o'clock", "output": ["1700"]}
{"input": "500", "output": ["500"]}
{"input": "500", "output": ["1700"]}
{"input": "500am", "output": ["500"]}
{"input": "500pm", "output": ["1700"]}
{"input": "515", "output": ["515"]}
{"input": "515", "output": ["1715"]}
{"input": "515am", "output": ["515"]}
{"input": "515pm", "output": ["1715"]}
{"input": "530", "output": ["530"]}
{"input": "530", "output": ["1730"]}
{"input": "530am", "output": ["530"]}
{"input": "530pm", "output": ["1730"]}
{"input": "545", "output": ["545"]}
{"input": "545", "output": ["1745"]}
{"input": "545am", "output": ["545"]}
{"input": "545pm", "output": ["1745"]}
{"input": "6", "output": ["600"]}
{"input": "6", "output": ["1800"]}
{"input": "6pm", "output": ["1800"]}
{"input": "6am", "output": ["600"]}
{"input": "6 o'clock am", "output": ["600"]}
{"input": "6 o'clock pm", "output": ["1800"]}
{"input": "6 o'clock", "output": ["600"]}
{"input": "6 o'clock", "output": ["1800"]}
{"input": "600", "output": ["600"]}
{"input": "600", "output": ["1800"]}
{"input": "600am", "output": ["600"]}
{"input": "600pm", "output": ["1800"]}
{"input": "615", "output": ["615"]}
{"input": "615", "output": ["1815"]}
{"input": "615am", "output": ["615"]}
{"input": "615pm", "output": ["1815"]}
{"input": "630", "output": ["630"]}
{"input": "630", "output": ["1830"]}
{"input": "630am", "output": ["630"]}
{"input": "630pm", "output": ["1830"]}
{"input": "645", "output": ["645"]}
{"input": "645", "output": ["1845"]}
{"input": "645am", "output": ["645"]}
{"input": "645pm", "output": ["1845"]}
{"input": "7", "output": ["700"]}
{"input": "7", "output": ["1900"]}
{"input": "7pm", "output": ["1900"]}
{"input": "7am", "output": ["700"]}
{"input": "7 o'clock am", "output": ["700"]}
{"input": "7 o'clock pm", "output": ["1900"]}
{"input": "7 o'clock", "output": ["700"]}
{"input": "7 o'clock", "output": ["1900"]}
{"input": "700", "output": ["700"]}
{"input": "700", "output": ["1900"]}
{"input": "700am", "output": ["700"]}
{"input": "700pm", "output": ["1900"]}
{"input": "715", "output": ["715"]}
{"input": "715", "output": ["1915"]}
{"input": "715am", "output": ["715"]}
{"input": "715pm", "output": ["1915"]}
{"input": "730", "output": ["730"]}
{"input": "730", "output": ["1930"]}
{"input": "730am", "output": ["730"]}
{"input": "730pm", "output": ["1930"]}
{"input": "745", "output": ["745"]}
{"input": "745", "output": ["1945"]}
{"input": "745am", "output": ["745"]}
{"input": "745pm", "output": ["1945"]}
{"input": "8", "output": ["800"]}
{"input": "8", "output": ["2000"]}
{"input": "8pm", "output": ["2000"]}
{"input": "8am", "output": ["800"]}
{"input": "8 o'clock am", "output": ["800"]}
{"input": "8 o'clock pm", "output": ["2000"]}
{"input": "8 o'clock", "output": ["800"]}
{"input": "8 o'clock", "output": ["2000"]}
{"input": "800", "output": ["800"]}
{"input": "800", "output": ["2000"]}
{"input": "800am", "output": ["800"]}
{"input": "800pm", "output": ["2000"]}
{"input": "815", "output": ["815"]}
{"input": "815", "output": ["2015"]}
{"input": "815am", "output": ["815"]}
{"input": "815pm", "output": ["2015"]}
{"input": "830", "output": ["830"]}
{"input": "830", "output": ["2030"]}
{"input": "830am", "output": ["830"]}
{"input": "830pm", "output": ["2030"]}
{"input": "845", "output": ["845"]}
{"input": "845", "output": ["2045"]}
{"input": "845am", "output": ["845"]}
{"input": "845pm", "output": ["2045"]}
{"input": "9", "output": ["900"]}
{"input": "9", "output": ["2100"]}
{"input": "9pm", "output": ["2100"]}
{"input": "9am", "output": ["900"]}
{"input": "9 o'clock am", "output": ["900"]}
{"input": "9 o'clock pm", "output": ["2100"]}
{"input": "9 o'clock", "output": ["900"]}
{"input": "9 o'clock", "output": ["2100"]}
{"input": "900", "output": ["900"]}
{"input": "900", "output": ["2100"]}
{"input": "900am", "output": ["900"]}
{"input": "900pm", "output": ["2100"]}
{"input": "915", "output": ["915"]}
{"input": "915", "output": ["2115"]}
{"input": "915am", "output": ["915"]}
{"input": "915pm", "output": ["2115"]}
{"input": "930", "output": ["930"]}
{"input": "930", "output": ["2130"]}
{"input": "930am", "output": ["930"]}
{"input": "930pm", "output": ["2130"]}
{"input": "945", "output": ["945"]}
{"input": "945", "output": ["2145"]}
{"input": "945am", "output": ["945"]}
{"input": "945pm", "output": ["2145"]}
{"input": "10", "output": ["1000"]}
{"input": "10", "output": ["2200"]}
{"input": "10pm", "output": ["2200"]}
{"input": "10am", "output": ["1000"]}
{"input": "10 o'clock am", "output": ["1000"]}
{"input": "10 o'clock pm", "output": ["2200"]}
{"input": "10 o'clock", "output": ["1000"]}
{"input": "10 o'clock", "output": ["2200"]}
{"input": "1000", "output": ["1000"]}
{"input": "1000", "output": ["2200"]}
{"input": "1000am", "output": ["1000"]}
{"input": "1000pm", "output": ["2200"]}
{"input": "1015", "output": ["1015"]}
{"input": "1015", "output": ["2215"]}
{"input": "1015am", "output": ["1015"]}
{"input": "1015pm", "output": ["2215"]}
{"input": "1030", "output": ["1030"]}
{"input": "1030", "output": ["2230"]}
{"input": "1030am", "output": ["1030"]}
{"input": "1030pm", "output": ["2230"]}
{"input": "1045", "output": ["1045"]}
{"input": "1045", "output": ["2245"]}
{"input": "1045am", "output": ["1045"]}
{"input": "1045pm", "output": ["2245"]}
{"input": "11", "output": ["1100"]}
{"input": "11", "output": ["2300"]}
{"input": "11pm", "output": ["2300"]}
{"input": "11am", "output": ["1100"]}
{"input": "11 o'clock am", "output": ["1100"]}
{"input": "11 o'clock pm", "output": ["2300"]}
{"input": "11 o'clock", "output": ["1100"]}
{"input": "11 o'clock", "output": ["2300"]}
{"input": "1100", "output": ["1100"]}
{"input": "1100", "output": ["2300"]}
{"input": "1100am", "output": ["1100"]}
{"input": "1100pm", "output": ["2300"]}
{"input": "1115", "output": ["1115"]}
{"input": "1115", "output": ["2315"]}
{"input": "1115am", "output": ["1115"]}
{"input": "1115pm", "output": ["2315"]}
{"input": "1130", "output": ["1130"]}
{"input": "1130", "output": ["2330"]}
{"input": "1130am", "output": ["1130"]}
{"input": "1130pm", "output": ["2330"]}
{"input": "1145", "output": ["1145"]}
{"input": "1145", "output": ["2345"]}
{"input": "1145am", "output": ["1145"]}
{"input": "1145pm", "output": ["2345"]}
{"input": "12", "output": ["1200"]}
{"input": "1200", "output": ["1200"]}
{"input": "1200", "output": ["2400"]}
{"input": "1200am", "output": ["1200"]}
{"input": "1200pm", "output": ["2400"]}
{"input": "1215", "output": ["1215"]}
{"input": "1215", "output": ["2415"]}
{"input": "1215am", "output": ["1215"]}
{"input": "1215pm", "output": ["2415"]}
{"input": "1230", "output": ["1230"]}
{"input": "1230", "output": ["2430"]}
{"input": "1230am", "output": ["1230"]}
{"input": "1230pm", "output": ["2430"]}
{"input": "1245", "output": ["1245"]}
{"input": "1245", "output": ["2445"]}
{"input": "1245am", "output": ["1245"]}
{"input": "1245pm", "output": ["2445"]}
{"input": "13", "output": ["1300"]}
{"input": "1300", "output": ["1300"]}
{"input": "1315", "output": ["1315"]}
{"input": "1330", "output": ["1330"]}
{"input": "1345", "output": ["1345"]}
{"input": "14", "output": ["1400"]}
{"input": "1400", "output": ["1400"]}
{"input": "1415", "output": ["1415"]}
{"input": "1430", "output": ["1430"]}
{"input": "1445", "output": ["1445"]}
{"input": "15", "output": ["1500"]}
{"input": "1500", "output": ["1500"]}
{"input": "1515", "output": ["1515"]}
{"input": "1530", "output": ["1530"]}
{"input": "1545", "output": ["1545"]}
{"input": "16", "output": ["1600"]}
{"input": "1600", "output": ["1600"]}
{"input": "1615", "output": ["1615"]}
{"input": "1630", "output": ["1630"]}
{"input": "1645", "output": ["1645"]}
{"input": "17", "output": ["1700"]}
{"input": "1700", "output": ["1700"]}
{"input": "1715", "output": ["1715"]}
{"input": "1730", "output": ["1730"]}
{"input": "1745", "output": ["1745"]}
{"input": "18", "output": ["1800"]}
{"input": "1800", "output": ["1800"]}
{"input": "1815", "output": ["1815"]}
{"input": "1830", "output": ["1830"]}
{"input": "1845", "output": ["1845"]}
{"input": "19", "output": ["1900"]}
{"input": "1900", "output": ["1900"]}
{"input": "1915", "output": ["1915"]}
{"input": "1930", "output": ["1930"]}
{"input": "1945", "output": ["1945"]}
{"input": "20", "output": ["2000"]}
{"input": "2000", "output": ["2000"]}
{"input": "2015", "output": ["2015"]}
{"input": "2030", "output": ["2030"]}
{"input": "2045", "output": ["2045"]}
{"input": "21", "output": ["2100"]}
{"input": "2100", "output": ["2100"]}
{"input": "2115", "output": ["2115"]}
{"input": "2130", "output": ["2130"]}
{"input": "2145", "output": ["2145"]}
{"input": "22", "output": ["2200"]}
{"input": "2200", "output": ["2200"]}
{"input": "2215", "output": ["2215"]}
{"input": "2230", "output": ["2230"]}
{"input": "2245", "output": ["2245"]}
{"input": "23", "output": ["2300"]}
{"input": "2300", "output": ["2300"]}
{"input": "2315", "output": ["2315"]}
{"input": "2330", "output": ["2330"]}
{"input": "2345", "output": ["2345"]}
{"input": "noon", "output": ["1200"]}
{"input": "breakfast", "output": ["'BREAKFAST'"]}
{"input": "lunch", "output": ["'LUNCH'"]}
{"input": "dinner", "output": ["'DINNER'"]}
{"input": "snack", "output": ["'SNACK'"]}
{"input": "atl", "output": ["'ATL'"]}
{"input": "bna", "output": ["'BNA'"]}
{"input": "bos", "output": ["'BOS'"]}
{"input": "bur", "output": ["'BUR'"]}
{"input": "bwi", "output": ["'BWI'"]}
{"input": "cle", "output": ["'CLE'"]}
{"input": "clt", "output": ["'CLT'"]}
{"input": "cmh", "output": ["'CMH'"]}
{"input": "cvg", "output": ["'CVG'"]}
{"input": "dal", "output": ["'DAL'"]}
{"input": "love field", "output": ["'DAL'"]}
{"input": "dca", "output": ["'DCA'"]}
{"input": "den", "output": ["'DEN'"]}
{"input": "det", "output": ["'DET'"]}
{"input": "dfw", "output": ["'DFW'"]}
{"input": "dtw", "output": ["'DTW'"]}
{"input": "ewr", "output": ["'EWR'"]}
{"input": "hou", "output": ["'HOU'"]}
{"input": "hpn", "output": ["'HPN'"]}
{"input": "iad", "output": ["'IAD'"]}
{"input": "iah", "output": ["'IAH'"]}
{"input": "houston international", "output": ["'IAH'"]}
{"input": "houston international airport", "output": ["'IAH'"]}
{"input": "houston intercontinental airport", "output": ["'IAH'"]}
{"input": "ind", "output": ["'IND'"]}
{"input": "jfk", "output": ["'JFK'"]}
{"input": "las", "output": ["'LAS'"]}
{"input": "lax", "output": ["'LAX'"]}
{"input": "lga", "output": ["'LGA'"]}
{"input": "la guardia", "output": ["'LGA'"]}
{"input": "lgb", "output": ["'LGB'"]}
{"input": "mci", "output": ["'MCI'"]}
{"input": "mco", "output": ["'MCO'"]}
{"input": "mdw", "output": ["'MDW'"]}
{"input": "mem", "output": ["'MEM'"]}
{"input": "mia", "output": ["'MIA'"]}
{"input": "mke", "output": ["'MKE'"]}
{"input": "msp", "output": ["'MSP'"]}
{"input": "oak", "output": ["'OAK'"]}
{"input": "ont", "output": ["'ONT'"]}
{"input": "ord", "output": ["'ORD'"]}
{"input": "phl", "output": ["'PHL'"]}
{"input": "phx", "output": ["'PHX'"]}
{"input": "pie", "output": ["'PIE'"]}
{"input": "pit", "output": ["'PIT'"]}
{"input": "san", "output": ["'SAN'"]}
{"input": "sea", "output": ["'SEA'"]}
{"input": "sfo", "output": ["'SFO'"]}
{"input": "sjc", "output": ["'SJC'"]}
{"input": "slc", "output": ["'SLC'"]}
{"input": "stl", "output": ["'STL'"]}
{"input": "tpa", "output": ["'TPA'"]}
{"input": "ykz", "output": ["'YKZ'"]}
{"input": "ymx", "output": ["'YMX'"]}
{"input": "ytz", "output": ["'YTZ'"]}
{"input": "yul", "output": ["'YUL'"]}
{"input": "yyz", "output": ["'YYZ'"]}
{"input": "ap/2", "output": ["'AP/2'"]}
{"input": "ap/6", "output": ["'AP/6'"]}
{"input": "ap/12", "output": ["'AP/12'"]}
{"input": "ap/20", "output": ["'AP/20'"]}
{"input": "ap/21", "output": ["'AP/21'"]}
{"input": "ap/57", "output": ["'AP/57'"]}
{"input": "ap/58", "output": ["'AP/58'"]}
{"input": "ap/60", "output": ["'AP/60'"]}
{"input": "ap/75", "output": ["'AP/75'"]}
{"input": "ex/9", "output": ["'EX/9'"]}
{"input": "ex/13", "output": ["'EX/13'"]}
{"input": "ex/14", "output": ["'EX/14'"]}
{"input": "ex/17", "output": ["'EX/17'"]}
{"input": "ex/19", "output": ["'EX/19'"]}
{"input": "flight", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
{"input": "flights", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
{"input": "fly", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
{"input": "trip", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
{"input": "cost", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "price", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "cheapest", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "expensive", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "fares", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "fare", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "rate", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "cost", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "prices", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "price", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "cheapest", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "expensive", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "fares", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "fare", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "rate", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "cheapest", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "least", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "cheapest one way", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "least expensive one way", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "most expensive one way", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "highest fare", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "highest one way fare", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "cheapest", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
{"input": "round trip", "output": ["fare.round_trip_cost", "IS", "NOT", "NULL"]}
{"input": "least", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
{"input": "cheapest round trip", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
{"input": "least expensive round trip", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
{"input": "aircraft", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
{"input": "airplane", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
{"input": "plane", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
{"input": "ground transportation", "output": ["DISTINCT", "ground_service.city_code", ",", "ground_service.airport_code", ",", "ground_service.transport_type", "FROM", "ground_service"]}
{"input": "airline", "output": ["DISTINCT", "airline.airline_code", "FROM", "airline"]}
{"input": "airlines", "output": ["DISTINCT", "airline.airline_code", "FROM", "airline"]}
{"input": "meals", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "meal", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "breakfast", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "dinner", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "lunch", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "schedule", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
{"input": "where", "output": ["DISTINCT", "airport.airport_code"]}
{"input": "seating capacity", "output": ["DISTINCT", "aircraft.capacity", "FROM", "aircraft"]}
{"input": "capacities", "output": ["DISTINCT", "aircraft.capacity", "FROM", "aircraft"]}
{"input": "aircraft", "output": ["DISTINCT", "aircraft.basic_type", "FROM", "aircraft"]}
{"input": "stop", "output": ["flight.stops"]}
{"input": "schedules", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
{"input": "time", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
{"input": "times", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
{"input": "flights and fares", "output": ["DISTINCT", "flight.flight_id", ",", "fare.fare_id", "FROM", "flight", ",", "flight_fare", ",", "fare"]}
{"input": "fare code", "output": ["DISTINCT", "fare_basis.fare_basis_code", ",", "fare_basis.booking_class", ",", "fare_basis.class_type", ",", "fare_basis.premium", ",", "fare_basis.economy", ",", "fare_basis.discounted", ",", "fare_basis.night", ",", "fare_basis.season", ",", "fare_basis.basis_days", "FROM", "fare_basis"]}
{"input": "information", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
{"input": "tell me about", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
{"input": "earliest", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "first", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "latest", "output": ["MAX", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "last", "output": ["MAX", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "shortest", "output": ["MIN", "(", "flight.time_elapsed", ")", "FROM", "flight"]}
{"input": "how many", "output": ["count", "(", "*", ")"]}
{"input": "number", "output": ["count", "(", "*", ")"]}
{"input": "one way", "output": ["fare.round_trip_required", "=", "'NO'"]}
{"input": "1 way", "output": ["fare.round_trip_required", "=", "'NO'"]}
{"input": "cities", "output": ["DISTINCT", "city.city_code", "FROM", "city"]}
{"input": "city", "output": ["DISTINCT", "city.city_code", "FROM", "city"]}
{"input": "airport", "output": ["DISTINCT", "airport.airport_location"]}
{"input": "economy", "output": ["'ECONOMY'"]}
{"input": "monday", "output": ["date_day.month_number"]}
{"input": "monday", "output": ["date_day.day_number"]}
{"input": "tuesday", "output": ["date_day.month_number"]}
{"input": "tuesday", "output": ["date_day.day_number"]}
{"input": "wednesday", "output": ["date_day.month_number"]}
{"input": "wednesday", "output": ["date_day.day_number"]}
{"input": "thursday", "output": ["date_day.month_number"]}
{"input": "thursday", "output": ["date_day.day_number"]}
{"input": "friday", "output": ["date_day.month_number"]}
{"input": "friday", "output": ["date_day.day_number"]}
{"input": "saturday", "output": ["date_day.month_number"]}
{"input": "saturday", "output": ["date_day.day_number"]}
{"input": "sunday", "output": ["date_day.month_number"]}
{"input": "sunday", "output": ["date_day.day_number"]}
{"input": "class", "output": ["class_of_service.booking_class"]}
{"input": "morning", "output": ["BETWEEN", "0", "AND", "1200"]}
{"input": "what does", "output": ["DISTINCT", "airport.airport_name", "FROM", "airport"]}
{"input": "what does", "output": ["DISTINCT", "class_of_service.class_description", "FROM", "class_of_service"]}
{"input": "aircraft", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
{"input": "meals", "output": ["food_service.meal_code"]}
{"input": "flight", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
{"input": "flights", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
{"input": "fly", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
{"input": "trip", "output": ["DISTINCT", "flight.flight_id", "FROM", "flight"]}
{"input": "cost", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "price", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "cheapest", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "expensive", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "fares", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "fare", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "rate", "output": ["DISTINCT", "ground_service.ground_fare", "FROM", "ground_service"]}
{"input": "cost", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "prices", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "price", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "cheapest", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "expensive", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "fares", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "fare", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "rate", "output": ["DISTINCT", "fare.fare_id", "FROM", "fare"]}
{"input": "cheapest", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "least", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "cheapest one way", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "least expensive one way", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "most expensive one way", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "highest fare", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "highest one way fare", "output": ["MAX", "(", "fare.one_direction_cost", ")", "FROM", "fare"]}
{"input": "cheapest", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
{"input": "round trip", "output": ["fare.round_trip_cost", "IS", "NOT", "NULL"]}
{"input": "least", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
{"input": "cheapest round trip", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
{"input": "least expensive round trip", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]}
{"input": "aircraft", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
{"input": "airplane", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
{"input": "plane", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
{"input": "ground transportation", "output": ["DISTINCT", "ground_service.city_code", ",", "ground_service.airport_code", ",", "ground_service.transport_type", "FROM", "ground_service"]}
{"input": "airline", "output": ["DISTINCT", "airline.airline_code", "FROM", "airline"]}
{"input": "airlines", "output": ["DISTINCT", "airline.airline_code", "FROM", "airline"]}
{"input": "meals", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "meal", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "breakfast", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "dinner", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "lunch", "output": ["DISTINCT", "food_service.meal_code", ",", "food_service.meal_number", ",", "food_service.compartment", "FROM", "food_service"]}
{"input": "schedule", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
{"input": "where", "output": ["DISTINCT", "airport.airport_code"]}
{"input": "seating capacity", "output": ["DISTINCT", "aircraft.capacity", "FROM", "aircraft"]}
{"input": "capacities", "output": ["DISTINCT", "aircraft.capacity", "FROM", "aircraft"]}
{"input": "aircraft", "output": ["DISTINCT", "aircraft.basic_type", "FROM", "aircraft"]}
{"input": "stop", "output": ["flight.stops"]}
{"input": "schedules", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
{"input": "time", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
{"input": "times", "output": ["DISTINCT", "flight.departure_time", "FROM", "flight"]}
{"input": "flights and fares", "output": ["DISTINCT", "flight.flight_id", ",", "fare.fare_id", "FROM", "flight", ",", "flight_fare", ",", "fare"]}
{"input": "fare code", "output": ["DISTINCT", "fare_basis.fare_basis_code", ",", "fare_basis.booking_class", ",", "fare_basis.class_type", ",", "fare_basis.premium", ",", "fare_basis.economy", ",", "fare_basis.discounted", ",", "fare_basis.night", ",", "fare_basis.season", ",", "fare_basis.basis_days", "FROM", "fare_basis"]}
{"input": "information", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
{"input": "tell me about", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
{"input": "earliest", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "first", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "information", "output": ["DISTINCT", "flight.flight_id", ",", "flight", ".", "flight_days", ",", "flight", ".", "from_airport", ",", "flight", ".", "to_airport", ",", "flight", ".", "departure_time", ",", "flight", ".", "arrival_time", ",", "flight", ".", "airline_flight", ",", "flight", ".", "airline_code", ",", "flight", ".", "flight_number", ",", "flight", ".", "aircraft_code_sequence", ",", "flight", ".", "meal_code", ",", "flight", ".", "stops", ",", "flight", ".", "connections", ",", "flight", ".", "dual_carrier", ",", "flight", ".", "time_elapsed", "FROM", "flight"]}
{"input": "tell me about", "output": ["DISTINCT", "flight.flight_id", ",", "flight.flight_days", ",", "flight.from_airport", ",", "flight.to_airport", ",", "flight.departure_time", ",", "flight.arrival_time", ",", "flight.airline_flight", ",", "flight.airline_code", ",", "flight.flight_number", ",", "flight.aircraft_code_sequence", ",", "flight.meal_code", ",", "flight.stops", ",", "flight.connections", ",", "flight.dual_carrier", ",", "flight.time_elapsed", "FROM", "flight"]}
{"input": "earliest", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "first", "output": ["MIN", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "latest", "output": ["MAX", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "last", "output": ["MAX", "(", "flight.departure_time", ")", "FROM", "flight"]}
{"input": "shortest", "output": ["MIN", "(", "flight.time_elapsed", ")", "FROM", "flight"]}
{"input": "how many", "output": ["count", "(", "*", ")"]}
{"input": "number", "output": ["count", "(", "*", ")"]}
{"input": "one way", "output": ["fare.round_trip_required", "=", "'NO'"]}
{"input": "1 way", "output": ["fare.round_trip_required", "=", "'NO'"]}
{"input": "cities", "output": ["DISTINCT", "city.city_code", "FROM", "city"]}
{"input": "city", "output": ["DISTINCT", "city.city_code", "FROM", "city"]}
{"input": "airport", "output": ["DISTINCT", "airport.airport_location"]}
{"input": "economy", "output": ["'ECONOMY'"]}
{"input": "monday", "output": ["date_day.month_number"]}
{"input": "monday", "output": ["date_day.day_number"]}
{"input": "tuesday", "output": ["date_day.month_number"]}
{"input": "tuesday", "output": ["date_day.day_number"]}
{"input": "wednesday", "output": ["date_day.month_number"]}
{"input": "wednesday", "output": ["date_day.day_number"]}
{"input": "thursday", "output": ["date_day.month_number"]}
{"input": "thursday", "output": ["date_day.day_number"]}
{"input": "friday", "output": ["date_day.month_number"]}
{"input": "friday", "output": ["date_day.day_number"]}
{"input": "saturday", "output": ["date_day.month_number"]}
{"input": "saturday", "output": ["date_day.day_number"]}
{"input": "sunday", "output": ["date_day.month_number"]}
{"input": "sunday", "output": ["date_day.day_number"]}
{"input": "class", "output": ["class_of_service.booking_class"]}
{"input": "morning", "output": ["BETWEEN", "0", "AND", "1200"]}
{"input": "what does", "output": ["DISTINCT", "airport.airport_name", "FROM", "airport"]}
{"input": "what does", "output": ["DISTINCT", "class_of_service.class_description", "FROM", "class_of_service"]}
{"input": "aircraft", "output": ["DISTINCT", "aircraft.aircraft_code", "FROM", "aircraft"]}
{"input": "meals", "output": ["food_service.meal_code"]}


@ -0,0 +1,262 @@
"""Code for identifying and anonymizing entities in NL and SQL."""
import copy
import json
from . import util
ENTITY_NAME = "ENTITY"
CONSTANT_NAME = "CONSTANT"
TIME_NAME = "TIME"
SEPARATOR = "#"
def timeval(string):
"""Returns the numeric version of a time.
Inputs:
string (str): String representing a time.
Returns:
String representing the absolute time.
"""
if string.endswith("am") or string.endswith(
"pm") and string[:-2].isdigit():
numval = int(string[:-2])
if len(string) == 3 or len(string) == 4:
numval *= 100
if string.endswith("pm"):
numval += 1200
return str(numval)
return ""
def is_time(string):
"""Returns whether a string represents a time.
Inputs:
string (str): String to check.
Returns:
Whether the string represents a time.
"""
if string.endswith("am") or string.endswith("pm"):
if string[:-2].isdigit():
return True
return False
def deanonymize(sequence, ent_dict, key):
"""Deanonymizes a sequence.
Inputs:
sequence (list of str): List of tokens to deanonymize.
ent_dict (dict str->(dict str->str)): Maps from tokens to the entity dictionary.
key (str): The key to use, in this case either natural language or SQL.
Returns:
Deanonymized sequence of tokens.
"""
new_sequence = []
for token in sequence:
if token in ent_dict:
new_sequence.extend(ent_dict[token][key])
else:
new_sequence.append(token)
return new_sequence
class Anonymizer:
"""Anonymization class for keeping track of entities in this domain and
scripts for anonymizing/deanonymizing.
Members:
anonymization_map (list of dict (str->str)): Containing entities from
the anonymization file.
entity_types (list of str): All entities in the anonymization file.
keys (set of str): Possible keys (types of text handled); in this case it should be
one for natural language and another for SQL.
entity_set (set of str): entity_types as a set.
"""
def __init__(self, filename):
self.anonymization_map = []
self.entity_types = []
self.keys = set()
pairs = [json.loads(line) for line in open(filename).readlines()]
for pair in pairs:
for key in pair:
if key != "type":
self.keys.add(key)
self.anonymization_map.append(pair)
if pair["type"] not in self.entity_types:
self.entity_types.append(pair["type"])
self.entity_types.append(ENTITY_NAME)
self.entity_types.append(CONSTANT_NAME)
self.entity_types.append(TIME_NAME)
self.entity_set = set(self.entity_types)
def get_entity_type_from_token(self, token):
"""Gets the type of an entity given an anonymized token.
Inputs:
token (str): The entity token.
Returns:
str, representing the type of the entity.
"""
# these are in the pattern TYPE#index, so keep only the part before the
# separator
colon_loc = token.index(SEPARATOR)
entity_type = token[:colon_loc]
assert entity_type in self.entity_set
return entity_type
def is_anon_tok(self, token):
"""Returns whether a token is an anonymized token or not.
Input:
token (str): The token to check.
Returns:
bool, whether the token is an anonymized token.
"""
return token.split(SEPARATOR)[0] in self.entity_set
def get_anon_id(self, token):
"""Gets the entity index (unique ID) for a token.
Input:
token (str): The token to get the index from.
Returns:
int, the token ID if it is an anonymized token; otherwise -1.
"""
if self.is_anon_tok(token):
return self.entity_types.index(token.split(SEPARATOR)[0])
else:
return -1
def anonymize(self,
sequence,
tok_to_entity_dict,
key,
add_new_anon_toks=False):
"""Anonymizes a sequence.
Inputs:
sequence (list of str): Sequence to anonymize.
tok_to_entity_dict (dict): Existing dictionary mapping from anonymized
tokens to entities.
key (str): Which kind of text this is (natural language or SQL)
add_new_anon_toks (bool): Whether to add new entities to tok_to_entity_dict.
Returns:
list of str, the anonymized sequence.
"""
# Sort the token-to-entity dict by the length of the modality.
sorted_dict = sorted(tok_to_entity_dict.items(),
key=lambda k: len(k[1][key]))[::-1]
anonymized_sequence = copy.deepcopy(sequence)
if add_new_anon_toks:
type_counts = {}
for entity_type in self.entity_types:
type_counts[entity_type] = 0
for token in tok_to_entity_dict:
entity_type = self.get_entity_type_from_token(token)
type_counts[entity_type] += 1
# First find occurrences of things in the anonymization dictionary.
for token, modalities in sorted_dict:
our_modality = modalities[key]
# Check if this key's version of the anonymized thing is in our
# sequence.
while util.subsequence(our_modality, anonymized_sequence):
found = False
for startidx in range(
len(anonymized_sequence) - len(our_modality) + 1):
if anonymized_sequence[startidx:startidx +
len(our_modality)] == our_modality:
anonymized_sequence = anonymized_sequence[:startidx] + [
token] + anonymized_sequence[startidx + len(our_modality):]
found = True
break
assert found, "Thought " \
+ str(our_modality) + " was in [" \
+ str(anonymized_sequence) + "] but could not find it"
# Now add new keys if they are present.
if add_new_anon_toks:
# For every span in the sequence, check whether it is in the anon map
# for this modality
sorted_anon_map = sorted(self.anonymization_map,
key=lambda k: len(k[key]))[::-1]
for pair in sorted_anon_map:
our_modality = pair[key]
token_type = pair["type"]
new_token = token_type + SEPARATOR + \
str(type_counts[token_type])
while util.subsequence(our_modality, anonymized_sequence):
found = False
for startidx in range(
len(anonymized_sequence) - len(our_modality) + 1):
if anonymized_sequence[startidx:startidx + \
len(our_modality)] == our_modality:
if new_token not in tok_to_entity_dict:
type_counts[token_type] += 1
tok_to_entity_dict[new_token] = pair
anonymized_sequence = anonymized_sequence[:startidx] + [
new_token] + anonymized_sequence[startidx + len(our_modality):]
found = True
break
assert found, "Thought " \
+ str(our_modality) + " was in [" \
+ str(anonymized_sequence) + "] but could not find it"
# Also replace integers with constants
for index, token in enumerate(anonymized_sequence):
if token.isdigit() or is_time(token):
if token.isdigit():
entity_type = CONSTANT_NAME
value = token
if is_time(token):
entity_type = TIME_NAME
value = timeval(token)
# First try to find the constant in the entity dictionary already,
# and get the name if it's found.
new_token = ""
new_dict = {}
found = False
for entity, value in tok_to_entity_dict.items():
if value[key][0] == token:
new_token = entity
new_dict = value
found = True
break
if not found:
new_token = entity_type + SEPARATOR + \
str(type_counts[entity_type])
new_dict = {}
for tempkey in self.keys:
new_dict[tempkey] = [token]
tok_to_entity_dict[new_token] = new_dict
type_counts[entity_type] += 1
anonymized_sequence[index] = new_token
return anonymized_sequence
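if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original file).
    # The key names "nl" and "sql" and the sample entry are hypothetical; the
    # class itself accepts whatever non-"type" keys the anonymization file uses.
    import tempfile
    entry = {"type": "CITY_NAME", "nl": ["boston"], "sql": ["'BOSTON'"]}
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        tmp.write(json.dumps(entry) + "\n")
        anon_path = tmp.name
    demo_anonymizer = Anonymizer(anon_path)
    tok_to_ent = {}
    anon_seq = demo_anonymizer.anonymize(
        ["flights", "to", "boston"], tok_to_ent, "nl", add_new_anon_toks=True)
    print(anon_seq)                                 # ['flights', 'to', 'CITY_NAME#0']
    print(deanonymize(anon_seq, tok_to_ent, "nl"))  # ['flights', 'to', 'boston']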


@ -0,0 +1,350 @@
# TODO: review this entire file and make it much simpler.
import copy
from . import snippets as snip
from . import sql_util
from . import vocabulary as vocab
class UtteranceItem():
def __init__(self, interaction, index):
self.interaction = interaction
self.utterance_index = index
def __str__(self):
return str(self.interaction.utterances[self.utterance_index])
def histories(self, maximum):
if maximum > 0:
history_seqs = []
for utterance in self.interaction.utterances[:self.utterance_index]:
history_seqs.append(utterance.input_seq_to_use)
if len(history_seqs) > maximum:
history_seqs = history_seqs[-maximum:]
return history_seqs
return []
def input_sequence(self):
return self.interaction.utterances[self.utterance_index].input_seq_to_use
def previous_query(self):
if self.utterance_index == 0:
return []
return self.interaction.utterances[self.utterance_index -
1].anonymized_gold_query
def anonymized_gold_query(self):
return self.interaction.utterances[self.utterance_index].anonymized_gold_query
def snippets(self):
return self.interaction.utterances[self.utterance_index].available_snippets
def original_gold_query(self):
return self.interaction.utterances[self.utterance_index].original_gold_query
def contained_entities(self):
return self.interaction.utterances[self.utterance_index].contained_entities
def original_gold_queries(self):
return [
q[0] for q in self.interaction.utterances[self.utterance_index].all_gold_queries]
def gold_tables(self):
return [
q[1] for q in self.interaction.utterances[self.utterance_index].all_gold_queries]
def gold_query(self):
return self.interaction.utterances[self.utterance_index].gold_query_to_use + [
vocab.EOS_TOK]
def gold_edit_sequence(self):
return self.interaction.utterances[self.utterance_index].gold_edit_sequence
def gold_table(self):
return self.interaction.utterances[self.utterance_index].gold_sql_results
def all_snippets(self):
return self.interaction.snippets
def within_limits(self,
max_input_length=float('inf'),
max_output_length=float('inf')):
return self.interaction.utterances[self.utterance_index].length_valid(
max_input_length, max_output_length)
def expand_snippets(self, sequence):
# Remove the EOS
if sequence[-1] == vocab.EOS_TOK:
sequence = sequence[:-1]
# First remove the snippets
no_snippets_sequence = self.interaction.expand_snippets(sequence)
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
return no_snippets_sequence
def flatten_sequence(self, sequence):
# Remove the EOS
if sequence[-1] == vocab.EOS_TOK:
sequence = sequence[:-1]
# First remove the snippets
no_snippets_sequence = self.interaction.expand_snippets(sequence)
# Deanonymize
deanon_sequence = self.interaction.deanonymize(
no_snippets_sequence, "sql")
return deanon_sequence
class UtteranceBatch():
def __init__(self, items):
self.items = items
def __len__(self):
return len(self.items)
def start(self):
self.index = 0
def next(self):
item = self.items[self.index]
self.index += 1
return item
def done(self):
return self.index >= len(self.items)
class PredUtteranceItem():
def __init__(self,
input_sequence,
interaction_item,
previous_query,
index,
available_snippets):
self.input_seq_to_use = input_sequence
self.interaction_item = interaction_item
self.index = index
self.available_snippets = available_snippets
self.prev_pred_query = previous_query
def input_sequence(self):
return self.input_seq_to_use
def histories(self, maximum):
if maximum == 0:
return []
histories = []
for utterance in self.interaction_item.processed_utterances[:self.index]:
histories.append(utterance.input_sequence())
if len(histories) > maximum:
histories = histories[-maximum:]
return histories
def snippets(self):
return self.available_snippets
def previous_query(self):
return self.prev_pred_query
def flatten_sequence(self, sequence):
return self.interaction_item.flatten_sequence(sequence)
def remove_snippets(self, sequence):
return sql_util.fix_parentheses(
self.interaction_item.expand_snippets(sequence))
def set_predicted_query(self, query):
self.anonymized_pred_query = query
# Mocks an Interaction item, but allows for the parameters to be updated during
# the process
class InteractionItem():
def __init__(self,
interaction,
max_input_length=float('inf'),
max_output_length=float('inf'),
nl_to_sql_dict={},
maximum_length=float('inf')):
if maximum_length != float('inf'):
self.interaction = copy.deepcopy(interaction)
self.interaction.utterances = self.interaction.utterances[:maximum_length]
else:
self.interaction = interaction
self.processed_utterances = []
self.snippet_bank = []
self.identifier = self.interaction.identifier
self.max_input_length = max_input_length
self.max_output_length = max_output_length
self.nl_to_sql_dict = nl_to_sql_dict
self.index = 0
def __len__(self):
return len(self.interaction)
def __str__(self):
s = "Utterances, gold queries, and predictions:\n"
for i, utterance in enumerate(self.interaction.utterances):
s += " ".join(utterance.input_seq_to_use) + "\n"
pred_utterance = self.processed_utterances[i]
s += " ".join(pred_utterance.gold_query()) + "\n"
s += " ".join(pred_utterance.anonymized_query()) + "\n"
s += "\n"
s += "Snippets:\n"
for snippet in self.snippet_bank:
s += str(snippet) + "\n"
return s
def start_interaction(self):
assert len(self.snippet_bank) == 0
assert len(self.processed_utterances) == 0
assert self.index == 0
def next_utterance(self):
utterance = self.interaction.utterances[self.index]
self.index += 1
available_snippets = self.available_snippets(snippet_keep_age=1)
return PredUtteranceItem(utterance.input_seq_to_use,
self,
self.processed_utterances[-1].anonymized_pred_query if len(self.processed_utterances) > 0 else [],
self.index - 1,
available_snippets)
def done(self):
return len(self.processed_utterances) == len(self.interaction)
def finish(self):
self.snippet_bank = []
self.processed_utterances = []
self.index = 0
def utterance_within_limits(self, utterance_item):
return utterance_item.within_limits(self.max_input_length,
self.max_output_length)
def available_snippets(self, snippet_keep_age):
return [
snippet for snippet in self.snippet_bank if snippet.index <= snippet_keep_age]
def gold_utterances(self):
utterances = []
for i, utterance in enumerate(self.interaction.utterances):
utterances.append(UtteranceItem(self.interaction, i))
return utterances
def get_schema(self):
return self.interaction.schema
def add_utterance(
self,
utterance,
predicted_sequence,
snippets=None,
previous_snippets=[],
simple=False):
if not snippets:
self.add_snippets(
predicted_sequence,
previous_snippets=previous_snippets, simple=simple)
else:
for snippet in snippets:
snippet.assign_id(len(self.snippet_bank))
self.snippet_bank.append(snippet)
for snippet in self.snippet_bank:
snippet.increase_age()
self.processed_utterances.append(utterance)
def add_snippets(self, sequence, previous_snippets=[], simple=False):
if sequence:
if simple:
snippets = sql_util.get_subtrees_simple(
sequence, oldsnippets=previous_snippets)
else:
snippets = sql_util.get_subtrees(
sequence, oldsnippets=previous_snippets)
for snippet in snippets:
snippet.assign_id(len(self.snippet_bank))
self.snippet_bank.append(snippet)
for snippet in self.snippet_bank:
snippet.increase_age()
def expand_snippets(self, sequence):
return sql_util.fix_parentheses(
snip.expand_snippets(
sequence, self.snippet_bank))
def remove_snippets(self, sequence):
if sequence[-1] == vocab.EOS_TOK:
sequence = sequence[:-1]
no_snippets_sequence = self.expand_snippets(sequence)
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
return no_snippets_sequence
def flatten_sequence(self, sequence, gold_snippets=False):
if sequence[-1] == vocab.EOS_TOK:
sequence = sequence[:-1]
if gold_snippets:
no_snippets_sequence = self.interaction.expand_snippets(sequence)
else:
no_snippets_sequence = self.expand_snippets(sequence)
no_snippets_sequence = sql_util.fix_parentheses(no_snippets_sequence)
deanon_sequence = self.interaction.deanonymize(
no_snippets_sequence, "sql")
return deanon_sequence
def gold_query(self, index):
return self.interaction.utterances[index].gold_query_to_use + [
vocab.EOS_TOK]
def original_gold_query(self, index):
return self.interaction.utterances[index].original_gold_query
def gold_table(self, index):
return self.interaction.utterances[index].gold_sql_results
class InteractionBatch():
def __init__(self, items):
self.items = items
def __len__(self):
return len(self.items)
def start(self):
self.timestep = 0
self.current_interactions = []
def get_next_utterance_batch(self, snippet_keep_age, use_gold=False):
items = []
self.current_interactions = []
for interaction in self.items:
if self.timestep < len(interaction):
utterance_item = interaction.original_utterances(
snippet_keep_age, use_gold)[self.timestep]
self.current_interactions.append(interaction)
items.append(utterance_item)
self.timestep += 1
return UtteranceBatch(items)
def done(self):
finished = True
for interaction in self.items:
if self.timestep < len(interaction):
finished = False
return finished
return finished
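if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original file).
    # UtteranceBatch is just a cursor over a list, so plain strings work here;
    # InteractionBatch follows the same start()/done() protocol per timestep.
    demo_batch = UtteranceBatch(["utt_0", "utt_1", "utt_2"])
    demo_batch.start()
    while not demo_batch.done():
        print(demo_batch.next())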


@ -0,0 +1,376 @@
""" Utility functions for loading and processing ATIS data."""
import os
import random
import json
from . import anonymization as anon
from . import atis_batch
from . import dataset_split as ds
from .interaction import load_function
from .entities import NLtoSQLDict
from .atis_vocab import ATISVocabulary
ENTITIES_FILENAME = './entities.txt'
ANONYMIZATION_FILENAME = 'data/anonymization.txt'
class ATISDataset():
""" Contains the ATIS data. """
def __init__(self, params):
self.anonymizer = None
if params.anonymize:
self.anonymizer = anon.Anonymizer(ANONYMIZATION_FILENAME)
if not os.path.exists(params.data_directory):
os.mkdir(params.data_directory)
self.entities_dictionary = NLtoSQLDict(ENTITIES_FILENAME)
database_schema = None
if params.database_schema_filename:
if 'removefrom' not in params.data_directory:
database_schema, column_names_surface_form, column_names_embedder_input = self.read_database_schema_simple(params.database_schema_filename)
else:
database_schema, column_names_surface_form, column_names_embedder_input = self.read_database_schema(params.database_schema_filename)
int_load_function = load_function(params,
self.entities_dictionary,
self.anonymizer,
database_schema=database_schema)
def collapse_list(the_list):
""" Collapses a list of list into a single list."""
return [s for i in the_list for s in i]
if 'atis' not in params.data_directory:
self.train_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_train_filename),
params.raw_train_filename,
int_load_function)
self.valid_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_validation_filename),
params.raw_validation_filename,
int_load_function)
train_input_seqs = collapse_list(self.train_data.get_ex_properties(lambda i: i.input_seqs()))
valid_input_seqs = collapse_list(self.valid_data.get_ex_properties(lambda i: i.input_seqs()))
all_input_seqs = train_input_seqs + valid_input_seqs
self.input_vocabulary = ATISVocabulary(
all_input_seqs,
os.path.join(params.data_directory, params.input_vocabulary_filename),
params,
is_input='input',
anonymizer=self.anonymizer if params.anonymization_scoring else None)
self.output_vocabulary_schema = ATISVocabulary(
column_names_embedder_input,
os.path.join(params.data_directory, 'schema_'+params.output_vocabulary_filename),
params,
is_input='schema',
anonymizer=self.anonymizer if params.anonymization_scoring else None)
#train_output_seqs = collapse_list(self.train_data.get_ex_properties(lambda i: i.output_seqs()))
#valid_output_seqs = collapse_list(self.valid_data.get_ex_properties(lambda i: i.output_seqs()))
#all_output_seqs = train_output_seqs + valid_output_seqs
sql_keywords = ['.', 't1', 't2', '=', 'select', 'as', 'join', 'on', ')', '(', 'where', 't3', 'by', ',', 'group', 'distinct', 't4', 'and', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 't5', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!', 'union', 'between', 't6', '-', 't7', '+', '/']
sql_keywords += ['count', 'from', 'value', 'order']
sql_keywords += ['group_by', 'order_by', 'limit_value', '!=']
# skip column_names_surface_form but keep sql_keywords
skip_tokens = list(set(column_names_surface_form) - set(sql_keywords))
#print(len(all_output_seqs))
all_output_seqs = []
out_vocab_ordered = ['select', 'value', ')', '(', 'where', '=', ',', 'count', 'group_by', 'order_by', 'limit_value', 'desc', '>', 'distinct', 'and', 'avg', 'having', '<', 'in', 'sum', 'max', 'asc', 'not', 'or', 'like', 'min', 'intersect', 'except', '!=', 'union', 'between', '-', '+']
#if params.data_directory == 'processed_data_sparc_removefrom_test':
#out_vocab_ordered = ['select', 'value', ')', '(', 'where', '=', ',', 'count', 'group_by', 'order_by', 'limit_value', 'desc', '>', 'distinct', 'avg', 'and', 'having', '<', 'in', 'max', 'sum', 'asc', 'like', 'not', 'or', 'min', 'intersect', 'except', '!=', 'union', 'between', '-', '+']
for i in range(len(out_vocab_ordered)):
all_output_seqs.append(out_vocab_ordered[:i+1])
self.output_vocabulary = ATISVocabulary(
all_output_seqs,
os.path.join(params.data_directory, params.output_vocabulary_filename),
params,
is_input='output',
anonymizer=self.anonymizer if params.anonymization_scoring else None,
skip=skip_tokens)
else:
self.train_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_train_filename),
params.raw_train_filename,
int_load_function)
if params.train:
self.valid_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_validation_filename),
params.raw_validation_filename,
int_load_function)
if params.evaluate or params.attention:
self.dev_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_dev_filename),
params.raw_dev_filename,
int_load_function)
if params.enable_testing:
self.test_data = ds.DatasetSplit(
os.path.join(params.data_directory, params.processed_test_filename),
params.raw_test_filename,
int_load_function)
train_input_seqs = []
train_input_seqs = collapse_list(
self.train_data.get_ex_properties(
lambda i: i.input_seqs()))
self.input_vocabulary = ATISVocabulary(
train_input_seqs,
os.path.join(params.data_directory, params.input_vocabulary_filename),
params,
is_input='input',
min_occur=2,
anonymizer=self.anonymizer if params.anonymization_scoring else None)
train_output_seqs = collapse_list(
self.train_data.get_ex_properties(
lambda i: i.output_seqs()))
self.output_vocabulary = ATISVocabulary(
train_output_seqs,
os.path.join(params.data_directory, params.output_vocabulary_filename),
params,
is_input='output',
anonymizer=self.anonymizer if params.anonymization_scoring else None)
self.output_vocabulary_schema = None
def read_database_schema_simple(self, database_schema_filename):
with open(database_schema_filename, "r") as f:
database_schema = json.load(f)
database_schema_dict = {}
column_names_surface_form = []
column_names_embedder_input = []
for table_schema in database_schema:
db_id = table_schema['db_id']
database_schema_dict[db_id] = table_schema
column_names = table_schema['column_names']
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
for i, (table_id, column_name) in enumerate(column_names_original):
column_name_surface_form = column_name
column_names_surface_form.append(column_name_surface_form.lower())
for table_name in table_names_original:
column_names_surface_form.append(table_name.lower())
for i, (table_id, column_name) in enumerate(column_names):
column_name_embedder_input = column_name
column_names_embedder_input.append(column_name_embedder_input.split())
for table_name in table_names:
column_names_embedder_input.append(table_name.split())
database_schema = database_schema_dict
return database_schema, column_names_surface_form, column_names_embedder_input
def read_database_schema(self, database_schema_filename):
with open(database_schema_filename, "r") as f:
database_schema = json.load(f)
database_schema_dict = {}
column_names_surface_form = []
column_names_embedder_input = []
for table_schema in database_schema:
db_id = table_schema['db_id']
database_schema_dict[db_id] = table_schema
column_names = table_schema['column_names']
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
for i, (table_id, column_name) in enumerate(column_names_original):
if table_id >= 0:
table_name = table_names_original[table_id]
column_name_surface_form = '{}.{}'.format(table_name,column_name)
else:
column_name_surface_form = column_name
column_names_surface_form.append(column_name_surface_form.lower())
# also add table_name.*
for table_name in table_names_original:
column_names_surface_form.append('{}.*'.format(table_name.lower()))
for i, (table_id, column_name) in enumerate(column_names):
if table_id >= 0:
table_name = table_names[table_id]
column_name_embedder_input = table_name + ' . ' + column_name
else:
column_name_embedder_input = column_name
column_names_embedder_input.append(column_name_embedder_input.split())
for table_name in table_names:
column_name_embedder_input = table_name + ' . *'
column_names_embedder_input.append(column_name_embedder_input.split())
database_schema = database_schema_dict
return database_schema, column_names_surface_form, column_names_embedder_input
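# Illustration (not part of the original file): for a table "flight" with a
# column whose original name is "flight_id" and natural-language name
# "flight id", this adds the surface forms "flight.flight_id" and "flight.*",
# and the embedder inputs ["flight", ".", "flight", "id"] and ["flight", ".", "*"].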
def get_all_utterances(self,
dataset,
max_input_length=float('inf'),
max_output_length=float('inf')):
""" Returns all utterances in a dataset."""
items = []
for interaction in dataset.examples:
for i, utterance in enumerate(interaction.utterances):
if utterance.length_valid(max_input_length, max_output_length):
items.append(atis_batch.UtteranceItem(interaction, i))
return items
def get_all_interactions(self,
dataset,
max_interaction_length=float('inf'),
max_input_length=float('inf'),
max_output_length=float('inf'),
sorted_by_length=False):
"""Gets all interactions in a dataset that fit the criteria.
Inputs:
dataset (ATISDatasetSplit): The dataset to use.
max_interaction_length (int): Maximum interaction length to keep.
max_input_length (int): Maximum input sequence length to keep.
max_output_length (int): Maximum output sequence length to keep.
sorted_by_length (bool): Whether to sort the examples by interaction length.
"""
ints = [
atis_batch.InteractionItem(
interaction,
max_input_length,
max_output_length,
self.entities_dictionary,
max_interaction_length) for interaction in dataset.examples]
if sorted_by_length:
return sorted(ints, key=lambda x: len(x))[::-1]
else:
return ints
# This defines a standard way of training: each example is an utterance, and
# the batch can contain unrelated utterances.
def get_utterance_batches(self,
batch_size,
max_input_length=float('inf'),
max_output_length=float('inf'),
randomize=True):
"""Gets batches of utterances in the data.
Inputs:
batch_size (int): Batch size to use.
max_input_length (int): Maximum length of input to keep.
max_output_length (int): Maximum length of output to use.
randomize (bool): Whether to randomize the ordering.
"""
# First, get all interactions and the positions of the utterances that are
# possible in them.
items = self.get_all_utterances(self.train_data,
max_input_length,
max_output_length)
if randomize:
random.shuffle(items)
batches = []
current_batch_items = []
for item in items:
if len(current_batch_items) >= batch_size:
batches.append(atis_batch.UtteranceBatch(current_batch_items))
current_batch_items = []
current_batch_items.append(item)
batches.append(atis_batch.UtteranceBatch(current_batch_items))
assert sum([len(batch) for batch in batches]) == len(items)
return batches
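# Illustration (not part of the original file): with 10 kept utterances and
# batch_size 4, the loop above yields batches of sizes 4, 4 and 2; the final
# assert checks that no utterance was dropped or duplicated.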
def get_interaction_batches(self,
batch_size,
max_interaction_length=float('inf'),
max_input_length=float('inf'),
max_output_length=float('inf'),
randomize=True):
"""Gets batches of interactions in the data.
Inputs:
batch_size (int): Batch size to use.
max_interaction_length (int): Maximum length of interaction to keep
max_input_length (int): Maximum length of input to keep.
max_output_length (int): Maximum length of output to keep.
randomize (bool): Whether to randomize the ordering.
"""
items = self.get_all_interactions(self.train_data,
max_interaction_length,
max_input_length,
max_output_length,
sorted_by_length=not randomize)
if randomize:
random.shuffle(items)
batches = []
current_batch_items = []
for item in items:
if len(current_batch_items) >= batch_size:
batches.append(
atis_batch.InteractionBatch(current_batch_items))
current_batch_items = []
current_batch_items.append(item)
batches.append(atis_batch.InteractionBatch(current_batch_items))
assert sum([len(batch) for batch in batches]) == len(items)
return batches
def get_random_utterances(self,
num_samples,
max_input_length=float('inf'),
max_output_length=float('inf')):
"""Gets a random selection of utterances in the data.
Inputs:
num_samples (int): Number of random utterances to get.
max_input_length (int): Limit of input length.
max_output_length (int): Limit on output length.
"""
items = self.get_all_utterances(self.train_data,
max_input_length,
max_output_length)
random.shuffle(items)
return items[:num_samples]
def get_random_interactions(self,
num_samples,
max_interaction_length=float('inf'),
max_input_length=float('inf'),
max_output_length=float('inf')):
"""Gets a random selection of interactions in the data.
Inputs:
num_samples (int): Number of random interactions to get.
max_input_length (int): Limit of input length.
max_output_length (int): Limit on output length.
"""
items = self.get_all_interactions(self.train_data,
max_interaction_length,
max_input_length,
max_output_length)
random.shuffle(items)
return items[:num_samples]
def num_utterances(dataset):
"""Returns the total number of utterances in the dataset."""
return sum([len(interaction) for interaction in dataset.examples])


@ -0,0 +1,74 @@
"""Gets and stores vocabulary for the ATIS data."""
from . import snippets
from .vocabulary import Vocabulary, UNK_TOK, DEL_TOK, EOS_TOK
INPUT_FN_TYPES = [UNK_TOK, DEL_TOK, EOS_TOK]
OUTPUT_FN_TYPES = [UNK_TOK, EOS_TOK]
MIN_INPUT_OCCUR = 1
MIN_OUTPUT_OCCUR = 1
class ATISVocabulary():
""" Stores the vocabulary for the ATIS data.
Attributes:
raw_vocab (Vocabulary): Vocabulary object.
tokens (set of str): Set of all of the strings in the vocabulary.
inorder_tokens (list of str): List of all tokens, with a strict and
unchanging order.
"""
def __init__(self,
token_sequences,
filename,
params,
is_input='input',
min_occur=1,
anonymizer=None,
skip=None):
if is_input=='input':
functional_types = INPUT_FN_TYPES
elif is_input=='output':
functional_types = OUTPUT_FN_TYPES
elif is_input=='schema':
functional_types = [UNK_TOK]
else:
functional_types = []
self.raw_vocab = Vocabulary(
token_sequences,
filename,
functional_types=functional_types,
min_occur=min_occur,
ignore_fn=lambda x: snippets.is_snippet(x) or (
anonymizer and anonymizer.is_anon_tok(x)) or (skip and x in skip) )
self.tokens = set(self.raw_vocab.token_to_id.keys())
self.inorder_tokens = self.raw_vocab.id_to_token
assert len(self.inorder_tokens) == len(self.raw_vocab)
def __len__(self):
return len(self.raw_vocab)
def token_to_id(self, token):
""" Maps from a token to a unique ID.
Inputs:
token (str): The token to look up.
Returns:
int, uniquely identifying the token.
"""
return self.raw_vocab.token_to_id[token]
def id_to_token(self, identifier):
""" Maps from a unique integer to an identifier.
Inputs:
identifier (int): The unique ID.
Returns:
string, representing the token.
"""
return self.raw_vocab.id_to_token[identifier]
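# Minimal usage sketch (illustrative only, not part of the original file).
# For a built vocabulary `vocab = ATISVocabulary(...)`, the two lookup methods
# are inverses of each other:
#   idx = vocab.token_to_id("select")
#   assert vocab.id_to_token(idx) == "select"
#   len(vocab)  # vocabulary size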


@ -0,0 +1,56 @@
""" Utility functions for loading and processing ATIS data.
"""
import os
import pickle
class DatasetSplit:
"""Stores a split of the ATIS dataset.
Attributes:
examples (list of Interaction): Stores the examples in the split.
"""
def __init__(self, processed_filename, raw_filename, load_function):
if os.path.exists(processed_filename):
print("Loading preprocessed data from " + processed_filename)
with open(processed_filename, 'rb') as infile:
self.examples = pickle.load(infile)
else:
print(
"Loading raw data from " +
raw_filename +
" and writing to " +
processed_filename)
infile = open(raw_filename, 'rb')
examples_from_file = pickle.load(infile)
assert isinstance(examples_from_file, list), raw_filename + \
" does not contain a list of examples"
infile.close()
self.examples = []
for example in examples_from_file:
obj, keep = load_function(example)
if keep:
self.examples.append(obj)
print("Loaded " + str(len(self.examples)) + " examples")
outfile = open(processed_filename, 'wb')
pickle.dump(self.examples, outfile)
outfile.close()
def get_ex_properties(self, function):
""" Applies some function to the examples in the dataset.
Inputs:
function: (lambda Interaction -> T): Function to apply to all
examples.
Returns:
list of the function's return value for each example.
"""
elems = []
for example in self.examples:
elems.append(function(example))
return elems
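# Minimal usage sketch (illustrative only, not part of the original file);
# the file paths are hypothetical. The raw file must be a pickled list of
# examples, and load_function maps each raw example to (processed_obj, keep).
# The first run writes the processed examples to processed_filename; later
# runs load that cache directly.
#
#   split = DatasetSplit("processed/train.pkl", "raw/train.pkl",
#                        lambda ex: (ex, True))
#   lengths = split.get_ex_properties(len)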


@ -0,0 +1,63 @@
""" Classes for keeping track of the entities in a natural language string. """
import json
class NLtoSQLDict:
"""
Entity dict file should contain, on each line, a JSON dictionary with
"input" and "output" keys specifying the string for the input and output
pairs. The idea is that the existence of the key in an input sequence
likely corresponds to the existence of the value in the output sequence.
The entity_dict should map keys (input strings) to a list of values (output
strings) where this property holds. This allows keys to map to multiple
output strings (e.g. for times).
"""
def __init__(self, entity_dict_filename):
self.entity_dict = {}
pairs = [json.loads(line)
for line in open(entity_dict_filename).readlines()]
for pair in pairs:
input_seq = pair["input"]
output_seq = pair["output"]
if input_seq not in self.entity_dict:
self.entity_dict[input_seq] = []
self.entity_dict[input_seq].append(output_seq)
def get_sql_entities(self, tokenized_nl_string):
"""
Gets the output-side entities which correspond to the input entities in
the input sequence.
Inputs:
tokenized_nl_string: list of tokens in the input string.
Outputs:
set of output strings.
"""
assert len(tokenized_nl_string) > 0
flat_input_string = " ".join(tokenized_nl_string)
entities = []
# See if any input strings are in our input sequence, and add the
# corresponding output strings if so.
for entry, values in self.entity_dict.items():
in_middle = " " + entry + " " in flat_input_string
leftspace = " " + entry
at_end = leftspace in flat_input_string and flat_input_string.endswith(
leftspace)
rightspace = entry + " "
at_beginning = rightspace in flat_input_string and flat_input_string.startswith(
rightspace)
if in_middle or at_end or at_beginning:
for out_string in values:
entities.append(out_string)
# Also add any integers in the input string (these aren't in the entity
# dict).
for token in tokenized_nl_string:
if token.isnumeric():
entities.append(token)
return entities
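if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original file).
    # The two sample lines mirror the JSON-lines format described in the class
    # docstring: the same input key can map to several output token lists.
    import tempfile
    demo_lines = [
        {"input": "cheapest", "output": ["MIN", "(", "fare.one_direction_cost", ")", "FROM", "fare"]},
        {"input": "cheapest", "output": ["MIN", "(", "fare.round_trip_cost", ")", "FROM", "fare"]},
    ]
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        for demo_line in demo_lines:
            tmp.write(json.dumps(demo_line) + "\n")
        dict_path = tmp.name
    demo_dict = NLtoSQLDict(dict_path)
    # Both stored outputs are returned because "cheapest" occurs in the utterance.
    print(demo_dict.get_sql_entities(["the", "cheapest", "fare"]))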


@ -0,0 +1,312 @@
""" Contains the class for an interaction in ATIS. """
from . import anonymization as anon
from . import sql_util
from .snippets import expand_snippets
from .utterance import Utterance, OUTPUT_KEY, ANON_INPUT_KEY
import torch
class Schema:
def __init__(self, table_schema, simple=False):
if simple:
self.helper1(table_schema)
else:
self.helper2(table_schema)
def helper1(self, table_schema):
self.table_schema = table_schema
column_names = table_schema['column_names']
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original)
column_keep_index = []
self.column_names_surface_form = []
self.column_names_surface_form_to_id = {}
for i, (table_id, column_name) in enumerate(column_names_original):
column_name_surface_form = column_name
column_name_surface_form = column_name_surface_form.lower()
if column_name_surface_form not in self.column_names_surface_form_to_id:
self.column_names_surface_form.append(column_name_surface_form)
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
column_keep_index.append(i)
column_keep_index_2 = []
for i, table_name in enumerate(table_names_original):
column_name_surface_form = table_name.lower()
if column_name_surface_form not in self.column_names_surface_form_to_id:
self.column_names_surface_form.append(column_name_surface_form)
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
column_keep_index_2.append(i)
self.column_names_embedder_input = []
self.column_names_embedder_input_to_id = {}
for i, (table_id, column_name) in enumerate(column_names):
column_name_embedder_input = column_name
if i in column_keep_index:
self.column_names_embedder_input.append(column_name_embedder_input)
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
for i, table_name in enumerate(table_names):
column_name_embedder_input = table_name
if i in column_keep_index_2:
self.column_names_embedder_input.append(column_name_embedder_input)
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
max_id_1 = max(v for k,v in self.column_names_surface_form_to_id.items())
max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items())
assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1
self.num_col = len(self.column_names_surface_form)
def helper2(self, table_schema):
self.table_schema = table_schema
column_names = table_schema['column_names']
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
assert len(column_names) == len(column_names_original) and len(table_names) == len(table_names_original)
column_keep_index = []
self.column_names_surface_form = []
self.column_names_surface_form_to_id = {}
for i, (table_id, column_name) in enumerate(column_names_original):
if table_id >= 0:
table_name = table_names_original[table_id]
column_name_surface_form = '{}.{}'.format(table_name,column_name)
else:
column_name_surface_form = column_name
column_name_surface_form = column_name_surface_form.lower()
if column_name_surface_form not in self.column_names_surface_form_to_id:
self.column_names_surface_form.append(column_name_surface_form)
self.column_names_surface_form_to_id[column_name_surface_form] = len(self.column_names_surface_form) - 1
column_keep_index.append(i)
start_i = len(self.column_names_surface_form_to_id)
for i, table_name in enumerate(table_names_original):
column_name_surface_form = '{}.*'.format(table_name.lower())
self.column_names_surface_form.append(column_name_surface_form)
self.column_names_surface_form_to_id[column_name_surface_form] = i + start_i
self.column_names_embedder_input = []
self.column_names_embedder_input_to_id = {}
for i, (table_id, column_name) in enumerate(column_names):
if table_id >= 0:
table_name = table_names[table_id]
column_name_embedder_input = table_name + ' . ' + column_name
else:
column_name_embedder_input = column_name
if i in column_keep_index:
self.column_names_embedder_input.append(column_name_embedder_input)
self.column_names_embedder_input_to_id[column_name_embedder_input] = len(self.column_names_embedder_input) - 1
start_i = len(self.column_names_embedder_input_to_id)
for i, table_name in enumerate(table_names):
column_name_embedder_input = table_name + ' . *'
self.column_names_embedder_input.append(column_name_embedder_input)
self.column_names_embedder_input_to_id[column_name_embedder_input] = i + start_i
assert len(self.column_names_surface_form) == len(self.column_names_surface_form_to_id) == len(self.column_names_embedder_input) == len(self.column_names_embedder_input_to_id)
max_id_1 = max(v for k,v in self.column_names_surface_form_to_id.items())
max_id_2 = max(v for k,v in self.column_names_embedder_input_to_id.items())
assert (len(self.column_names_surface_form) - 1) == max_id_2 == max_id_1
self.num_col = len(self.column_names_surface_form)
def __len__(self):
return self.num_col
def in_vocabulary(self, column_name, surface_form=False):
if surface_form:
return column_name in self.column_names_surface_form_to_id
else:
return column_name in self.column_names_embedder_input_to_id
def column_name_embedder_bow(self, column_name, surface_form=False, column_name_token_embedder=None):
assert self.in_vocabulary(column_name, surface_form)
if surface_form:
column_name_id = self.column_names_surface_form_to_id[column_name]
column_name_embedder_input = self.column_names_embedder_input[column_name_id]
else:
column_name_embedder_input = column_name
column_name_embeddings = [column_name_token_embedder(token) for token in column_name_embedder_input.split()]
column_name_embeddings = torch.stack(column_name_embeddings, dim=0)
return torch.mean(column_name_embeddings, dim=0)
def set_column_name_embeddings(self, column_name_embeddings):
self.column_name_embeddings = column_name_embeddings
assert len(self.column_name_embeddings) == self.num_col
def column_name_embedder(self, column_name, surface_form=False):
assert self.in_vocabulary(column_name, surface_form)
if surface_form:
column_name_id = self.column_names_surface_form_to_id[column_name]
else:
column_name_id = self.column_names_embedder_input_to_id[column_name]
return self.column_name_embeddings[column_name_id]
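# Minimal illustration (not part of the original file): for a toy schema such as
#   {"db_id": "flights",
#    "table_names_original": ["flight"], "table_names": ["flight"],
#    "column_names_original": [[-1, "*"], [0, "flight_id"]],
#    "column_names": [[-1, "*"], [0, "flight id"]]}
# helper2 yields surface forms ["*", "flight.flight_id", "flight.*"] and
# embedder inputs ["*", "flight . flight id", "flight . *"], with aligned ids.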
class Interaction:
""" ATIS interaction class.
Attributes:
utterances (list of Utterance): The utterances in the interaction.
snippets (list of Snippet): The snippets that appear through the interaction.
anon_tok_to_ent (dict): Maps anonymization tokens to their original entities.
identifier (str): Unique identifier for the interaction in the dataset.
"""
def __init__(self,
utterances,
schema,
snippets,
anon_tok_to_ent,
identifier,
params):
self.utterances = utterances
self.schema = schema
self.snippets = snippets
self.anon_tok_to_ent = anon_tok_to_ent
self.identifier = identifier
# Ensure that each utterance's input and output sequences, when remapped
# without anonymization or snippets, are the same as the original
# version.
for i, utterance in enumerate(self.utterances):
deanon_input = self.deanonymize(utterance.input_seq_to_use,
ANON_INPUT_KEY)
assert deanon_input == utterance.original_input_seq, "Anonymized sequence [" \
+ " ".join(utterance.input_seq_to_use) + "] is not the same as [" \
+ " ".join(utterance.original_input_seq) + "] when deanonymized (is [" \
+ " ".join(deanon_input) + "] instead)"
desnippet_gold = self.expand_snippets(utterance.gold_query_to_use)
deanon_gold = self.deanonymize(desnippet_gold, OUTPUT_KEY)
assert deanon_gold == utterance.original_gold_query, \
"Anonymized and/or snippet'd query " \
+ " ".join(utterance.gold_query_to_use) + " is not the same as " \
+ " ".join(utterance.original_gold_query)
def __str__(self):
string = "Utterances:\n"
for utterance in self.utterances:
string += str(utterance) + "\n"
string += "Anonymization dictionary:\n"
for ent_tok, deanon in self.anon_tok_to_ent.items():
string += ent_tok + "\t" + str(deanon) + "\n"
return string
def __len__(self):
return len(self.utterances)
def deanonymize(self, sequence, key):
""" Deanonymizes a predicted query or an input utterance.
Inputs:
sequence (list of str): The sequence to deanonymize.
key (str): The key in the anonymization table, e.g. NL or SQL.
"""
return anon.deanonymize(sequence, self.anon_tok_to_ent, key)
def expand_snippets(self, sequence):
""" Expands snippets for a sequence.
Inputs:
sequence (list of str): A SQL query.
"""
return expand_snippets(sequence, self.snippets)
def input_seqs(self):
in_seqs = []
for utterance in self.utterances:
in_seqs.append(utterance.input_seq_to_use)
return in_seqs
def output_seqs(self):
out_seqs = []
for utterance in self.utterances:
out_seqs.append(utterance.gold_query_to_use)
return out_seqs
def load_function(parameters,
nl_to_sql_dict,
anonymizer,
database_schema=None):
def fn(interaction_example):
keep = False
raw_utterances = interaction_example["interaction"]
if "database_id" in interaction_example:
database_id = interaction_example["database_id"]
interaction_id = interaction_example["interaction_id"]
identifier = database_id + '/' + str(interaction_id)
else:
identifier = interaction_example["id"]
schema = None
if database_schema:
if 'removefrom' not in parameters.data_directory:
schema = Schema(database_schema[database_id], simple=True)
else:
schema = Schema(database_schema[database_id])
snippet_bank = []
utterance_examples = []
anon_tok_to_ent = {}
for utterance in raw_utterances:
available_snippets = [
snippet for snippet in snippet_bank if snippet.index <= 1]
proc_utterance = Utterance(utterance,
available_snippets,
nl_to_sql_dict,
parameters,
anon_tok_to_ent,
anonymizer)
keep_utterance = proc_utterance.keep
if schema:
assert keep_utterance
if keep_utterance:
keep = True
utterance_examples.append(proc_utterance)
# Update the snippet bank, and age each snippet in it.
if parameters.use_snippets:
if 'atis' in parameters.data_directory:
snippets = sql_util.get_subtrees(
proc_utterance.anonymized_gold_query,
proc_utterance.available_snippets)
else:
snippets = sql_util.get_subtrees_simple(
proc_utterance.anonymized_gold_query,
proc_utterance.available_snippets)
for snippet in snippets:
snippet.assign_id(len(snippet_bank))
snippet_bank.append(snippet)
for snippet in snippet_bank:
snippet.increase_age()
interaction = Interaction(utterance_examples,
schema,
snippet_bank,
anon_tok_to_ent,
identifier,
parameters)
return interaction, keep
return fn
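# Note (not part of the original file): load_function returns a closure
# fn(raw_interaction_example) -> (Interaction, keep); DatasetSplit (see
# dataset_split.py) applies it to every raw example and keeps only those
# interactions for which keep is True.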


@ -0,0 +1,105 @@
""" Contains the Snippet class and methods for handling snippets.
Attributes:
SNIPPET_PREFIX: string prefix for snippets.
"""
SNIPPET_PREFIX = "SNIPPET_"
def is_snippet(token):
""" Determines whether a token is a snippet or not.
Inputs:
token (str): The token to check.
Returns:
bool, indicating whether it's a snippet.
"""
return token.startswith(SNIPPET_PREFIX)
def expand_snippets(sequence, snippets):
""" Given a sequence and a list of snippets, expand the snippets in the sequence.
Inputs:
sequence (list of str): Query containing snippet references.
snippets (list of Snippet): List of available snippets.
return list of str representing the expanded sequence
"""
snippet_id_to_snippet = {}
for snippet in snippets:
assert snippet.name not in snippet_id_to_snippet
snippet_id_to_snippet[snippet.name] = snippet
expanded_seq = []
for token in sequence:
if token in snippet_id_to_snippet:
expanded_seq.extend(snippet_id_to_snippet[token].sequence)
else:
assert not is_snippet(token)
expanded_seq.append(token)
return expanded_seq
def snippet_index(token):
""" Returns the index of a snippet.
Inputs:
token (str): The snippet to check.
Returns:
integer, the index of the snippet.
"""
assert is_snippet(token)
return int(token.split("_")[-1])
class Snippet():
""" Contains a snippet. """
def __init__(self,
sequence,
startpos,
sql,
age=0):
self.sequence = sequence
self.startpos = startpos
self.sql = sql
# TODO: age vs. index?
self.age = age
self.index = 0
self.name = ""
self.embedding = None
self.endpos = self.startpos + len(self.sequence)
assert self.endpos < len(self.sql), "End position of snippet is " + str(
self.endpos) + " which is greater than length of SQL (" + str(len(self.sql)) + ")"
assert self.sequence == self.sql[self.startpos:self.endpos], \
"Value of snippet (" + " ".join(self.sequence) + ") " \
"is not the same as SQL at the same positions (" \
+ " ".join(self.sql[self.startpos:self.endpos]) + ")"
def __str__(self):
return self.name + "\t" + \
str(self.age) + "\t" + " ".join(self.sequence)
def __len__(self):
return len(self.sequence)
def increase_age(self):
""" Ages a snippet by one. """
self.index += 1
def assign_id(self, number):
""" Assigns the name of the snippet to be the prefix + the number. """
self.name = SNIPPET_PREFIX + str(number)
def set_embedding(self, embedding):
""" Sets the embedding of the snippet.
Inputs:
embedding (dy.Expression)
"""
self.embedding = embedding
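if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original file).
    full_sql = ["SELECT", "*", "FROM", "flight", "WHERE",
                "flight.stops", "=", "0", ";"]
    demo_snippet = Snippet(["flight.stops", "=", "0"], 5, full_sql)
    demo_snippet.assign_id(0)  # the snippet is now named "SNIPPET_0"
    query_with_snippet = ["SELECT", "*", "FROM", "flight", "WHERE", "SNIPPET_0"]
    # Expands back to the full WHERE condition.
    print(expand_snippets(query_with_snippet, [demo_snippet]))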


@ -0,0 +1,445 @@
import copy
import pymysql
import random
import signal
import sqlparse
from . import util
from .snippets import Snippet
from sqlparse import tokens as token_types
from sqlparse import sql as sql_types
interesting_selects = ["DISTINCT", "MAX", "MIN", "count"]
ignored_subtrees = [["1", "=", "1"]]
# strip_whitespace_front
# Strips whitespace and punctuation from the front of a SQL token list.
#
# Inputs:
# token_list: the token list.
#
# Outputs:
# new token list.
def strip_whitespace_front(token_list):
new_token_list = []
found_valid = False
for token in token_list:
if not (token.is_whitespace or token.ttype ==
token_types.Punctuation) or found_valid:
found_valid = True
new_token_list.append(token)
return new_token_list
# strip_whitespace
# Strips whitespace from a token list.
#
# Inputs:
# token_list: the token list.
#
# Outputs:
# new token list with no whitespace/punctuation surrounding.
def strip_whitespace(token_list):
subtokens = strip_whitespace_front(token_list)
subtokens = strip_whitespace_front(subtokens[::-1])[::-1]
return subtokens
# token_list_to_seq
# Converts a Token list to a sequence of strings, stripping out surrounding
# punctuation and all whitespace.
#
# Inputs:
# token_list: the list of tokens.
#
# Outputs:
# list of strings
def token_list_to_seq(token_list):
subtokens = strip_whitespace(token_list)
seq = []
flat = sqlparse.sql.TokenList(subtokens).flatten()
for i, token in enumerate(flat):
strip_token = str(token).strip()
if len(strip_token) > 0:
seq.append(strip_token)
if len(seq) > 0:
if seq[0] == "(" and seq[-1] == ")":
seq = seq[1:-1]
return seq
# TODO: clean this up
# find_subtrees
# Finds subtrees for a subsequence of SQL.
#
# Inputs:
# sequence: sequence of SQL tokens.
# current_subtrees: current list of subtrees.
#
# Optional inputs:
# where_parent: whether the parent of the current sequence was a where clause
# keep_conj_subtrees: whether to look for a conjunction in this sequence and
# keep its arguments
def find_subtrees(sequence,
current_subtrees,
where_parent=False,
keep_conj_subtrees=False):
# If the parent of the subsequence was a WHERE clause, keep everything in the
# sequence except for the beginning WHERE and any surrounding parentheses.
if where_parent:
# Strip out the beginning WHERE, and any punctuation or whitespace at the
# beginning or end of the token list.
seq = token_list_to_seq(sequence.tokens[1:])
if len(seq) > 0 and seq not in current_subtrees:
current_subtrees.append(seq)
# If the current sequence has subtokens, i.e. if it's a node that can be
# expanded, check for a conjunction in its subtrees, and expand its subtrees.
# Also check for any SELECT statements and keep track of what follows.
if sequence.is_group:
if keep_conj_subtrees:
subtokens = strip_whitespace(sequence.tokens)
# Check if there is a conjunction in the subsequence. If so, keep the
# children. Also make sure you don't split where AND is used within a
# child -- the subtokens sequence won't treat those ANDs differently (a
# bit hacky but it works)
has_and = False
for i, token in enumerate(subtokens):
if token.value == "OR" or token.value == "AND":
has_and = True
break
if has_and:
and_subtrees = []
current_subtree = []
for i, token in enumerate(subtokens):
if token.value == "OR" or (token.value == "AND" and i - 4 >= 0 and i - 4 < len(
subtokens) and subtokens[i - 4].value != "BETWEEN"):
and_subtrees.append(current_subtree)
current_subtree = []
else:
current_subtree.append(token)
and_subtrees.append(current_subtree)
for subtree in and_subtrees:
seq = token_list_to_seq(subtree)
if len(seq) > 0 and seq[0] == "WHERE":
seq = seq[1:]
if seq not in current_subtrees:
current_subtrees.append(seq)
in_select = False
select_toks = []
for i, token in enumerate(sequence.tokens):
# Mark whether this current token is a WHERE.
is_where = (isinstance(token, sql_types.Where))
# If you are in a SELECT, start recording what follows until you hit a
# FROM
if token.value == "SELECT":
in_select = True
elif in_select:
select_toks.append(token)
if token.value == "FROM":
in_select = False
seq = []
if len(sequence.tokens) > i + 2:
seq = token_list_to_seq(
select_toks + [sequence.tokens[i + 2]])
if seq not in current_subtrees and len(
seq) > 0 and seq[0] in interesting_selects:
current_subtrees.append(seq)
select_toks = []
# Recursively find subtrees in the children of the node.
find_subtrees(token,
current_subtrees,
is_where,
where_parent or keep_conj_subtrees)
# get_subtrees
def get_subtrees(sql, oldsnippets=[]):
parsed = sqlparse.parse(" ".join(sql))[0]
subtrees = []
find_subtrees(parsed, subtrees)
final_subtrees = []
for subtree in subtrees:
if subtree not in ignored_subtrees:
final_version = []
keep = True
parens_counts = 0
for i, token in enumerate(subtree):
if token == ".":
newtoken = final_version[-1] + "." + subtree[i + 1]
final_version = final_version[:-1] + [newtoken]
keep = False
elif keep:
final_version.append(token)
else:
keep = True
if token == "(":
parens_counts -= 1
elif token == ")":
parens_counts += 1
if parens_counts == 0:
final_subtrees.append(final_version)
snippets = []
sql = [str(tok) for tok in sql]
for subtree in final_subtrees:
startpos = -1
for i in range(len(sql) - len(subtree) + 1):
if sql[i:i + len(subtree)] == subtree:
startpos = i
if startpos >= 0 and startpos + len(subtree) < len(sql):
age = 0
for prevsnippet in oldsnippets:
if prevsnippet.sequence == subtree:
age = prevsnippet.age + 1
snippet = Snippet(subtree, startpos, sql, age=age)
snippets.append(snippet)
return snippets
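# --- Illustrative usage sketch (added; not part of the original file) ---
# get_subtrees expects the query as a list of string tokens; the trailing ";" matters
# because a snippet is only kept when it ends strictly before the last token.
def _example_get_subtrees():
    query_tokens = "SELECT flight_id FROM flight WHERE from_airport = 'BOS' AND to_airport = 'DEN' ;".split()
    for snippet in get_subtrees(query_tokens):
        print(snippet.sequence, snippet.age)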
def get_subtrees_simple(sql, oldsnippets=[]):
sql_string = " ".join(sql)
format_sql = sqlparse.format(sql_string, reindent=True)
# get subtrees
subtrees = []
for sub_sql in format_sql.split('\n'):
sub_sql = sub_sql.replace('(', ' ( ').replace(')', ' ) ').replace(',', ' , ')
subtree = sub_sql.strip().split()
if len(subtree) > 1:
subtrees.append(subtree)
final_subtrees = subtrees
snippets = []
sql = [str(tok) for tok in sql]
for subtree in final_subtrees:
startpos = -1
for i in range(len(sql) - len(subtree) + 1):
if sql[i:i + len(subtree)] == subtree:
startpos = i
if startpos >= 0 and startpos + len(subtree) <= len(sql):
age = 0
for prevsnippet in oldsnippets:
if prevsnippet.sequence == subtree:
age = prevsnippet.age + 1
new_sql = sql + [';']
snippet = Snippet(subtree, startpos, new_sql, age=age)
snippets.append(snippet)
return snippets
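# --- Illustrative usage sketch (added; not part of the original file) ---
# get_subtrees_simple leans on sqlparse.format(..., reindent=True): every reindented
# line with more than one token becomes a candidate snippet.
def _example_get_subtrees_simple():
    query_tokens = "SELECT flight_id FROM flight WHERE from_airport = 'BOS'".split()
    for snippet in get_subtrees_simple(query_tokens):
        print(snippet.sequence, snippet.age)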
conjunctions = {"AND", "OR", "WHERE"}
def get_all_in_parens(sequence):
if sequence[-1] == ";":
sequence = sequence[:-1]
if not "(" in sequence:
return []
if sequence[0] == "(" and sequence[-1] == ")":
in_parens = sequence[1:-1]
return [in_parens] + get_all_in_parens(in_parens)
else:
paren_subseqs = []
current_seq = []
num_parens = 0
in_parens = False
for token in sequence:
if in_parens:
current_seq.append(token)
if token == ")":
num_parens -= 1
if num_parens == 0:
in_parens = False
paren_subseqs.append(current_seq)
current_seq = []
elif token == "(":
in_parens = True
current_seq.append(token)
if token == "(":
num_parens += 1
all_subseqs = []
for subseq in paren_subseqs:
all_subseqs.extend(get_all_in_parens(subseq))
return all_subseqs
def split_by_conj(sequence):
num_parens = 0
current_seq = []
subsequences = []
for token in sequence:
if num_parens == 0:
if token in conjunctions:
subsequences.append(current_seq)
current_seq = []
break
current_seq.append(token)
if token == "(":
num_parens += 1
elif token == ")":
num_parens -= 1
assert num_parens >= 0
return subsequences
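# --- Illustrative sketch (added; not part of the original file) ---
# get_all_in_parens pulls out every parenthesized subsequence; split_by_conj then
# cuts a subsequence at the first top-level AND/OR/WHERE. Roughly:
#   seq = "SELECT a FROM t WHERE ( x = 1 AND y = 2 ) ;".split()
#   get_all_in_parens(seq)                    -> [['x', '=', '1', 'AND', 'y', '=', '2']]
#   split_by_conj(get_all_in_parens(seq)[0])  -> [['x', '=', '1']]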
def get_sql_snippets(sequence):
# First, get all subsequences of the sequence that are surrounded by
# parentheses.
all_in_parens = get_all_in_parens(sequence)
all_subseq = []
# Then for each one, split the sequence on conjunctions (AND/OR).
for seq in all_in_parens:
subsequences = split_by_conj(seq)
all_subseq.append(seq)
all_subseq.extend(subsequences)
# Finally, also get "interesting" selects
for i, seq in enumerate(all_subseq):
print(str(i) + "\t" + " ".join(seq))
exit()
# add_snippets_to_query
def add_snippets_to_query(snippets, ignored_entities, query, prob_align=1.):
query_copy = copy.copy(query)
# Replace the longest snippets first, so sort by length descending.
sorted_snippets = sorted(snippets, key=lambda s: len(s.sequence))[::-1]
for snippet in sorted_snippets:
ignore = False
snippet_seq = snippet.sequence
# TODO: continue here
# If it contains an ignored entity, then don't use it.
for entity in ignored_entities:
ignore = ignore or util.subsequence(entity, snippet_seq)
# No NL entities found in snippet, then see if snippet is a substring of
# the gold sequence
if not ignore:
snippet_length = len(snippet_seq)
# Iterate through gold sequence to see if it's a subsequence.
for start_idx in range(len(query_copy) - snippet_length + 1):
if query_copy[start_idx:start_idx +
snippet_length] == snippet_seq:
align = random.random() < prob_align
if align:
prev_length = len(query_copy)
# At the start position of the snippet, replace with an
# identifier.
query_copy[start_idx] = snippet.name
# Then cut out the indices which were collapsed into
# the snippet.
query_copy = query_copy[:start_idx + 1] + \
query_copy[start_idx + snippet_length:]
# Make sure the length is as expected
assert len(query_copy) == prev_length - \
(snippet_length - 1)
return query_copy
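# --- Illustrative sketch (added; not part of the original file) ---
# Snippets typically come from get_subtrees on the previous gold query; prob_align=1.0
# makes the replacement deterministic, collapsing any snippet whose token sequence
# appears verbatim in the query into its single identifier token (snippet.name).
def _example_add_snippets_to_query(snippets):
    gold = "SELECT flight_id FROM flight WHERE from_airport = 'BOS'".split()
    return add_snippets_to_query(snippets, [], gold, prob_align=1.0)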
def execution_results(query, username, password, timeout=3):
connection = pymysql.connect(user=username, password=password)
class TimeoutException(Exception):
pass
def timeout_handler(signum, frame):
raise TimeoutException
signal.signal(signal.SIGALRM, timeout_handler)
syntactic = True
semantic = True
table = []
with connection.cursor() as cursor:
signal.alarm(timeout)
try:
cursor.execute("SET sql_mode='IGNORE_SPACE';")
cursor.execute("use atis3;")
cursor.execute(query)
table = cursor.fetchall()
cursor.close()
except TimeoutException:
signal.alarm(0)
cursor.close()
except pymysql.err.ProgrammingError:
syntactic = False
semantic = False
cursor.close()
except pymysql.err.InternalError:
semantic = False
cursor.close()
except Exception as e:
signal.alarm(0)
signal.alarm(0)
cursor.close()
signal.alarm(0)
connection.close()
return (syntactic, semantic, sorted(table))
def executable(query, username, password, timeout=2):
return execution_results(query, username, password, timeout)[1]
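# --- Illustrative sketch (added; not part of the original file) ---
# execution_results runs the query against a local MySQL server holding the ATIS
# database ("use atis3;"); the credentials below are placeholders, and the
# signal-based timeout only works in the main thread.
def _example_executable():
    ok = executable("SELECT flight_id FROM flight LIMIT 5;",
                    username="user", password="password")  # hypothetical credentials
    print("semantically executable:", ok)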
def fix_parentheses(sequence):
num_left = sequence.count("(")
num_right = sequence.count(")")
if num_right < num_left:
fixed_sequence = sequence[:-1] + \
[")" for _ in range(num_left - num_right)] + [sequence[-1]]
return fixed_sequence
return sequence
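# --- Illustrative sketch (added; not part of the original file) ---
# fix_parentheses pads missing right parentheses just before the final token:
#   fix_parentheses("SELECT count ( * FROM flight ;".split())
#   -> ['SELECT', 'count', '(', '*', 'FROM', 'flight', ')', ';']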


@ -0,0 +1,82 @@
"""Tokenizers for natural language SQL queries, and lambda calculus."""
import nltk
import sqlparse
def nl_tokenize(string):
"""Tokenizes a natural language string into tokens.
Inputs:
string: the string to tokenize.
Outputs:
a list of tokens.
Assumes data is space-separated (this is true of ZC07 data in ATIS2/3).
"""
return nltk.word_tokenize(string)
def sql_tokenize(string):
""" Tokenizes a SQL statement into tokens.
Inputs:
string: string to tokenize.
Outputs:
a list of tokens.
"""
tokens = []
statements = sqlparse.parse(string)
# SQLparse gives you a list of statements.
for statement in statements:
# Flatten the tokens in each statement and add to the tokens list.
flat_tokens = sqlparse.sql.TokenList(statement.tokens).flatten()
for token in flat_tokens:
strip_token = str(token).strip()
if len(strip_token) > 0:
tokens.append(strip_token)
newtokens = []
keep = True
for i, token in enumerate(tokens):
if token == ".":
newtoken = newtokens[-1] + "." + tokens[i + 1]
newtokens = newtokens[:-1] + [newtoken]
keep = False
elif keep:
newtokens.append(token)
else:
keep = True
return newtokens
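# --- Illustrative sketch (added; not part of the original file) ---
# The post-processing loop above merges identifiers that sqlparse splits around ".":
#   sql_tokenize("SELECT flight.flight_id FROM flight")
#   -> ['SELECT', 'flight.flight_id', 'FROM', 'flight']
# (roughly; exact tokens depend on sqlparse's tokenization)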
def lambda_tokenize(string):
""" Tokenizes a lambda-calculus statement into tokens.
Inputs:
string: a lambda-calculus string
Outputs:
a list of tokens.
"""
space_separated = string.split(" ")
new_tokens = []
# Separate the string by spaces, then separate based on existence of ( or
# ).
for token in space_separated:
tokens = []
current_token = ""
for char in token:
if char == ")" or char == "(":
tokens.append(current_token)
tokens.append(char)
current_token = ""
else:
current_token += char
tokens.append(current_token)
new_tokens.extend([tok for tok in tokens if tok])
return new_tokens
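# --- Illustrative sketch (added; not part of the original file) ---
#   lambda_tokenize("(lambda x (flight x))")
#   -> ['(', 'lambda', 'x', '(', 'flight', 'x', ')', ')']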


@ -0,0 +1,16 @@
"""Contains various utility functions."""
def subsequence(first_sequence, second_sequence):
"""
Returns whether the first sequence is a subsequence of the second sequence.
Inputs:
first_sequence (list): A sequence.
second_sequence (list): Another sequence.
Returns:
Boolean indicating whether first_sequence is a subsequence of second_sequence.
"""
for startidx in range(len(second_sequence) - len(first_sequence) + 1):
if second_sequence[startidx:startidx + len(first_sequence)] == first_sequence:
return True
return False
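# --- Illustrative note (added; not part of the original file) ---
# Despite the name, this checks for a *contiguous* match:
#   subsequence(['a', 'b'], ['x', 'a', 'b', 'y'])  ->  True
#   subsequence(['a', 'y'], ['x', 'a', 'b', 'y'])  ->  False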


@ -0,0 +1,115 @@
""" Contains the Utterance class. """
from . import sql_util
from . import tokenizers
ANON_INPUT_KEY = "cleaned_nl"
OUTPUT_KEY = "sql"
class Utterance:
""" Utterance class. """
def process_input_seq(self,
anonymize,
anonymizer,
anon_tok_to_ent):
assert not anon_tok_to_ent or anonymize
assert not anonymize or anonymizer
if anonymize:
assert anonymizer
self.input_seq_to_use = anonymizer.anonymize(
self.original_input_seq, anon_tok_to_ent, ANON_INPUT_KEY, add_new_anon_toks=True)
else:
self.input_seq_to_use = self.original_input_seq
def process_gold_seq(self,
output_sequences,
nl_to_sql_dict,
available_snippets,
anonymize,
anonymizer,
anon_tok_to_ent):
# Get entities in the input sequence:
# anonymized entity types
# other recognized entities (this includes "flight")
entities_in_input = [
[tok] for tok in self.input_seq_to_use if tok in anon_tok_to_ent]
entities_in_input.extend(
nl_to_sql_dict.get_sql_entities(
self.input_seq_to_use))
# Get the shortest gold query (this is what we use to train)
shortest_gold_and_results = min(output_sequences,
key=lambda x: len(x[0]))
# Tokenize and anonymize it if necessary.
self.original_gold_query = shortest_gold_and_results[0]
self.gold_sql_results = shortest_gold_and_results[1]
self.contained_entities = entities_in_input
# Keep track of all gold queries and the resulting tables so that we can
# give credit if it predicts a different correct sequence.
self.all_gold_queries = output_sequences
self.anonymized_gold_query = self.original_gold_query
if anonymize:
self.anonymized_gold_query = anonymizer.anonymize(
self.original_gold_query, anon_tok_to_ent, OUTPUT_KEY, add_new_anon_toks=False)
# Add snippets to it.
self.gold_query_to_use = sql_util.add_snippets_to_query(
available_snippets, entities_in_input, self.anonymized_gold_query)
def __init__(self,
example,
available_snippets,
nl_to_sql_dict,
params,
anon_tok_to_ent={},
anonymizer=None):
# Get output and input sequences from the dictionary representation.
output_sequences = example[OUTPUT_KEY]
self.original_input_seq = tokenizers.nl_tokenize(example[params.input_key])
self.available_snippets = available_snippets
self.keep = False
# pruned_output_sequences = []
# for sequence in output_sequences:
# if len(sequence[0]) > 3:
# pruned_output_sequences.append(sequence)
# output_sequences = pruned_output_sequences
if len(output_sequences) > 0 and len(self.original_input_seq) > 0:
# Only keep this example if there is at least one output sequence.
self.keep = True
if len(output_sequences) == 0 or len(self.original_input_seq) == 0:
return
# Process the input sequence
self.process_input_seq(params.anonymize,
anonymizer,
anon_tok_to_ent)
# Process the gold sequence
self.process_gold_seq(output_sequences,
nl_to_sql_dict,
self.available_snippets,
params.anonymize,
anonymizer,
anon_tok_to_ent)
def __str__(self):
string = "Original input: " + " ".join(self.original_input_seq) + "\n"
string += "Modified input: " + " ".join(self.input_seq_to_use) + "\n"
string += "Original output: " + " ".join(self.original_gold_query) + "\n"
string += "Modified output: " + " ".join(self.gold_query_to_use) + "\n"
string += "Snippets:\n"
for snippet in self.available_snippets:
string += str(snippet) + "\n"
return string
def length_valid(self, input_limit, output_limit):
return (len(self.input_seq_to_use) < input_limit \
and len(self.gold_query_to_use) < output_limit)


@ -0,0 +1,98 @@
"""Contains class and methods for storing and computing a vocabulary from text."""
import operator
import os
import pickle
# Special sequencing tokens.
UNK_TOK = "_UNK" # Replaces out-of-vocabulary words.
EOS_TOK = "_EOS" # Appended to the end of a sequence to indicate its end.
DEL_TOK = ";"
class Vocabulary:
"""Vocabulary class: stores information about words in a corpus.
Members:
functional_types (list of str): Functional vocabulary words, such as EOS.
max_size (int): The maximum size of vocabulary to keep.
min_occur (int): The minimum number of times a word should occur to keep it.
id_to_token (list of str): Ordered list of word types.
token_to_id (dict str->int): Maps from each unique word type to its index.
"""
def get_vocab(self, sequences, ignore_fn):
"""Gets vocabulary from a list of sequences.
Inputs:
sequences (list of list of str): Sequences from which to compute the vocabulary.
ignore_fn (lambda str: bool): Function used to tell whether to ignore a
token during computation of the vocabulary.
Returns:
list of str, representing the unique word types in the vocabulary.
"""
type_counts = {}
for sequence in sequences:
for token in sequence:
if not ignore_fn(token):
if token not in type_counts:
type_counts[token] = 0
type_counts[token] += 1
# Create sorted list of tokens, by their counts. Reverse so it is in order of
# most frequent to least frequent.
sorted_type_counts = sorted(sorted(type_counts.items()),
key=operator.itemgetter(1))[::-1]
sorted_types = [typecount[0]
for typecount in sorted_type_counts if typecount[1] >= self.min_occur]
# Append the necessary functional tokens.
sorted_types = self.functional_types + sorted_types
# Cut off if vocab_size is set (nonnegative)
if self.max_size >= 0:
vocab = sorted_types[:min(self.max_size, len(sorted_types))]
else:
vocab = sorted_types
return vocab
def __init__(self,
sequences,
filename,
functional_types=None,
max_size=-1,
min_occur=0,
ignore_fn=lambda x: False):
self.functional_types = functional_types
self.max_size = max_size
self.min_occur = min_occur
vocab = self.get_vocab(sequences, ignore_fn)
self.id_to_token = []
self.token_to_id = {}
for i, word_type in enumerate(vocab):
self.id_to_token.append(word_type)
self.token_to_id[word_type] = i
# Load the previous vocab, if it exists.
if os.path.exists(filename):
infile = open(filename, 'rb')
loaded_vocab = pickle.load(infile)
infile.close()
print("Loaded vocabulary from " + str(filename))
if loaded_vocab.id_to_token != self.id_to_token \
or loaded_vocab.token_to_id != self.token_to_id:
print("Loaded vocabulary is different than generated vocabulary.")
else:
print("Writing vocabulary to " + str(filename))
outfile = open(filename, 'wb')
pickle.dump(self, outfile)
outfile.close()
def __len__(self):
return len(self.id_to_token)
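# --- Illustrative sketch (added; not part of the original file) ---
# Building a vocabulary from tokenized sequences; "vocab.pkl" is a placeholder path,
# and the constructor writes the pickle there as a side effect.
def _example_vocabulary():
    sequences = [["show", "me", "flights"], ["show", "flights", "to", "boston"]]
    vocab = Vocabulary(sequences,
                       "vocab.pkl",
                       functional_types=[UNK_TOK, EOS_TOK],
                       min_occur=1)
    print(len(vocab), vocab.token_to_id["flights"])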


@ -0,0 +1,36 @@
import argparse
import os
import sqlite3 as db
import pprint
def process_schema(args):
schema_name = args.db
schema_path = os.path.join(args.root_path, schema_name, schema_name + '.sqlite')
print("load db data from", schema_path)
conn = db.connect(schema_path)
cur = conn.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
tables_name = cur.fetchall()
table_name_lst = [row[0] for row in tables_name]  # avoid shadowing the built-in 'tuple'
print(table_name_lst)
for table_name in table_name_lst:
cur.execute("SELECT * FROM %s" % table_name)
col_name_list = [col[0] for col in cur.description]
print(table_name, col_name_list)
tables = cur.fetchall()
#print(tables)
def main(args):
process_schema(args)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--root_path", type=str, default="./data/database")
parser.add_argument("--save_path", type=str, default="./database_content.json")
parser.add_argument("--db", type=str, default="aircraft")
args = parser.parse_args()
main(args)
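# --- Illustrative note (added; not part of the original file) ---
# The same inspection can be driven programmatically with the argparse defaults,
# assuming the Spider-style ./data/database layout:
#   from argparse import Namespace
#   process_schema(Namespace(root_path="./data/database", db="aircraft"))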

1
r2sql/sparc/entities.txt Normal file

@ -0,0 +1 @@
{"input": "northeast express regional airlines", "output": ["'2V'"]}


@ -0,0 +1,866 @@
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
import os, sys
import json
import sqlite3
import traceback
import argparse
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
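# --- Illustrative note (added; not part of the original file) ---
# get_scores is all-or-nothing per component: it returns (1, 1, 1) only when the
# predicted and gold clauses have the same size and every predicted unit matched.
#   get_scores(2, 2, 2)  ->  (1, 1, 1)
#   get_scores(1, 2, 2)  ->  (0, 0, 0)
#   get_scores(2, 2, 3)  ->  (0, 0, 0)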
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]]
label_wo_agg = [unit[2] for unit in label_conds]
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for _, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
return label_tables == pred_tables
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql)
except:
return False
return True
def print_scores(scores, etype):
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
def evaluate(gold, predict, db_dir, etype, kmaps):
with open(gold) as f:
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict) as f:
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
evaluator = Evaluator()
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
eval_err_num = 0
for p, g in zip(plist, glist):
p_str = p[0]
g_str, db = g
db_name = db
db = os.path.join(db_dir, db, db + ".sqlite")
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
scores[hardness]['count'] += 1
scores['all']['count'] += 1
try:
p_sql = get_sql(schema, p_str)
except:
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
p_sql = {
"except": None,
"from": {
"conds": [],
"table_units": []
},
"groupBy": [],
"having": [],
"intersect": None,
"limit": None,
"orderBy": [],
"select": [
False,
[]
],
"union": None,
"where": []
}
eval_err_num += 1
print("eval_err_num:{}".format(eval_err_num))
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if exec_score:
scores[hardness]['exec'] += 1
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
return 1 if the values between prediction and gold are matching
in the corresponding index. Currently does not support multiple col_unit (pairs).
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(p_str)
p_res = cursor.fetchall()
except:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, val_units):
rmap = {}
for idx, val_unit in enumerate(val_units):
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
rmap[key] = [r[idx] for r in res]
return rmap
p_val_units = [unit[1] for unit in pred['select'][1]]
q_val_units = [unit[1] for unit in gold['select'][1]]
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixs = [col_id[:-2] for col_id in col_ids]
valid_col_units= []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixs:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
def build_foreign_key_map_from_json(table):
with open(table) as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gold', dest='gold', type=str)
parser.add_argument('--pred', dest='pred', type=str)
parser.add_argument('--db', dest='db', type=str)
parser.add_argument('--table', dest='table', type=str)
parser.add_argument('--etype', dest='etype', type=str)
args = parser.parse_args()
gold = args.gold
pred = args.pred
db_dir = args.db
table = args.table
etype = args.etype
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps)
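# --- Illustrative invocation (added; not part of the original file) ---
# Assuming this script is saved as evaluation.py, a typical run looks like:
#   python evaluation.py --gold gold.txt --pred pred.txt \
#       --db data/database --table data/tables.json --etype match
# where each gold line is "<SQL>\t<db_id>" and each prediction line is one SQL query,
# in the same order as the gold file.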


@ -0,0 +1,938 @@
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
import os, sys
import json
import sqlite3
import traceback
import argparse
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]]
label_wo_agg = [unit[2] for unit in label_conds]
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for key, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
return label_tables == pred_tables
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql)
except:
return False
return True
def print_scores(scores, etype):
turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
levels = ['easy', 'medium', 'hard', 'extra', 'all', "joint_all"]
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print("\n\n{:20} {:20} {:20} {:20} {:20} {:20}".format("", *turns))
counts = [scores[turn]['count'] for turn in turns]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== TURN EXECUTION ACCURACY =====================')
this_scores = [scores[turn]['exec'] for turn in turns]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== TURN EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[turn]['exact'] for turn in turns]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
def evaluate(gold, predict, db_dir, etype, kmaps):
with open(gold) as f:
glist = []
gseq_one = []
for l in f.readlines():
if len(l.strip()) == 0:
glist.append(gseq_one)
gseq_one = []
else:
lstrip = l.strip().split('\t')
gseq_one.append(lstrip)
#glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict) as f:
plist = []
pseq_one = []
for l in f.readlines():
if len(l.strip()) == 0:
plist.append(pseq_one)
pseq_one = []
else:
pseq_one.append(l.strip().split('\t'))
#plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
# plist = [[("select product_type_code from products group by product_type_code order by count ( * ) desc limit value", "orchestra")]]
# glist = [[("SELECT product_type_code FROM Products GROUP BY product_type_code ORDER BY count(*) DESC LIMIT 1", "customers_and_orders")]]
evaluator = Evaluator()
turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
for turn in turns:
scores[turn] = {'count': 0, 'exact': 0.}
scores[turn]['exec'] = 0
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
eval_err_num = 0
for p, g in zip(plist, glist):
scores['joint_all']['count'] += 1
turn_scores = {"exec": [], "exact": []}
for idx, pg in enumerate(zip(p, g)):
p, g = pg
p_str = p[0]
p_str = p_str.replace("value", "1")
g_str, db = g
db_name = db
db = os.path.join(db_dir, db, db + ".sqlite")
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
if idx > 3:
idx = ">4"
else:
idx += 1
turn_id = "turn " + str(idx)
scores[turn_id]['count'] += 1
scores[hardness]['count'] += 1
scores['all']['count'] += 1
try:
p_sql = get_sql(schema, p_str)
except:
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
p_sql = {
"except": None,
"from": {
"conds": [],
"table_units": []
},
"groupBy": [],
"having": [],
"intersect": None,
"limit": None,
"orderBy": [],
"select": [
False,
[]
],
"union": None,
"where": []
}
eval_err_num += 1
print("eval_err_num:{}".format(eval_err_num))
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if exec_score:
scores[hardness]['exec'] += 1
scores[turn_id]['exec'] += 1
turn_scores['exec'].append(1)
else:
turn_scores['exec'].append(0)
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
turn_scores['exact'].append(0)
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
else:
turn_scores['exact'].append(1)
scores[turn_id]['exact'] += exact_score
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
if all(v == 1 for v in turn_scores["exec"]):
scores['joint_all']['exec'] += 1
if all(v == 1 for v in turn_scores["exact"]):
scores['joint_all']['exact'] += 1
for turn in turns:
if scores[turn]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[turn]['exec'] /= scores[turn]['count']
if etype in ["all", "match"]:
scores[turn]['exact'] /= scores[turn]['count']
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
    return 1 if the values returned by the predicted and gold queries match
    at the corresponding index. Currently does not support multiple col_unit (pairs).
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(p_str)
p_res = cursor.fetchall()
except:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, val_units):
rmap = {}
for idx, val_unit in enumerate(val_units):
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
rmap[key] = [r[idx] for r in res]
return rmap
p_val_units = [unit[1] for unit in pred['select'][1]]
q_val_units = [unit[1] for unit in gold['select'][1]]
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixs = [col_id[:-2] for col_id in col_ids]
valid_col_units= []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixs:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
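# Illustrative output for a hypothetical schema in which orders.customer_id is a
# foreign key referencing customers.id (and customers.id has the smaller column index):
#   {'__customers.id__': '__customers.id__',
#    '__orders.customer_id__': '__customers.id__'}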
def build_foreign_key_map_from_json(table):
with open(table) as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gold', dest='gold', type=str)
parser.add_argument('--pred', dest='pred', type=str)
parser.add_argument('--db', dest='db', type=str)
parser.add_argument('--table', dest='table', type=str)
parser.add_argument('--etype', dest='etype', type=str)
args = parser.parse_args()
gold = args.gold
pred = args.pred
db_dir = args.db
table = args.table
etype = args.etype
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps)
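    # Example invocation (the script name and file paths below are hypothetical):
    #   python evaluation.py --gold gold.txt --pred predictions.txt \
    #       --db data/database --table tables.json --etype match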

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -0,0 +1,53 @@
import json
import sys
predictions = [json.loads(line) for line in open(sys.argv[1]).readlines() if line]
string_count = 0.
sem_count = 0.
syn_count = 0.
table_count = 0.
strict_table_count = 0.
precision_denom = 0.
precision = 0.
recall_denom = 0.
recall = 0.
f1_score = 0.
f1_denom = 0.
time = 0.
for prediction in predictions:
if prediction["correct_string"]:
string_count += 1.
if prediction["semantic"]:
sem_count += 1.
if prediction["syntactic"]:
syn_count += 1.
if prediction["correct_table"]:
table_count += 1.
if prediction["strict_correct_table"]:
strict_table_count += 1.
if prediction["gold_tables"] !="[[]]":
precision += prediction["table_prec"]
precision_denom += 1
if prediction["pred_table"] != "[]":
recall += prediction["table_rec"]
recall_denom += 1
if prediction["gold_tables"] != "[[]]":
f1_score += prediction["table_f1"]
f1_denom += 1
num_p = len(predictions)
print("string precision: " + str(string_count / num_p))
print("% semantic: " + str(sem_count / num_p))
print("% syntactic: " + str(syn_count / num_p))
print("table prec: " + str(table_count / num_p))
print("strict table prec: " + str(strict_table_count / num_p))
print("table row prec: " + str(precision / precision_denom))
print("table row recall: " + str(recall / recall_denom))
print("table row f1: " + str(f1_score / f1_denom))
print("inference time: " + str(time / num_p))


@ -0,0 +1,569 @@
################################
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
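# Illustrative parse of a simple query against a hypothetical schema containing a
# "singer" table (note AGG_OPS.index('count') == 3 and UNIT_OPS.index('none') == 0):
#   get_sql(schema, 'SELECT count(*) FROM singer') ->
#   {'select': (False, [(3, (0, (0, '__all__', False), None))]),
#    'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
#    'where': [], 'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#    'intersect': None, 'except': None, 'union': None}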
import json
import sqlite3
import sys
from nltk import word_tokenize
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema):
self._schema = schema
self._idMap = self._map(self._schema)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema):
idMap = {'*': "__all__"}
id = 1
for key, vals in schema.items():
for val in vals:
idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
id += 1
for key in schema:
idMap[key.lower()] = "__" + key.lower() + "__"
id += 1
return idMap
def get_schema(db):
"""
Get database's schema, which is a dict with table name as key
and list of column names as value
:param db: database path
:return: schema dict
"""
schema = {}
conn = sqlite3.connect(db)
cursor = conn.cursor()
# fetch table names
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [str(table[0].lower()) for table in cursor.fetchall()]
# fetch table info
for table in tables:
cursor.execute("PRAGMA table_info({})".format(table))
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
return schema
def get_schema_from_json(fpath):
with open(fpath) as f:
data = json.load(f)
schema = {}
for entry in data:
table = str(entry['table'].lower())
cols = [str(col['column_name'].lower()) for col in entry['col_data']]
schema[table] = cols
return schema
def tokenize(string):
string = str(string)
    string = string.replace("\'", "\"")  # ensure all string values are wrapped in double quotes (may be problematic if a value itself contains quotes)
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
# keep string value as token
vals = {}
for i in range(len(quote_idxs)-1, -1, -2):
qidx1 = quote_idxs[i-1]
qidx2 = quote_idxs[i]
val = string[qidx1: qidx2+1]
key = "__val_{}_{}__".format(qidx1, qidx2)
string = string[:qidx1] + key + string[qidx2+1:]
vals[key] = val
toks = [word.lower() for word in word_tokenize(string)]
# replace with string value token
for i in range(len(toks)):
if toks[i] in vals:
toks[i] = vals[toks[i]]
# find if there exists !=, >=, <=
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
eq_idxs.reverse()
prefix = ('!', '>', '<')
for eq_idx in eq_idxs:
pre_tok = toks[eq_idx-1]
if pre_tok in prefix:
toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
return toks
def scan_alias(toks):
"""Scan the index of 'as' and build the map for all alias"""
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
alias = {}
for idx in as_idxs:
alias[toks[idx+1]] = toks[idx-1]
return alias
def get_tables_with_alias(schema, toks):
tables = scan_alias(toks)
for key in schema:
        assert key not in tables, "Alias {} has the same name as a table".format(key)
tables[key] = key
return tables
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, column id
"""
tok = toks[start_idx]
if tok == "*":
return start_idx + 1, schema.idMap[tok]
if '.' in tok: # if token is a composite
alias, col = tok.split('.')
key = tables_with_alias[alias] + "." + col
return start_idx+1, schema.idMap[key]
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
for alias in default_tables:
table = tables_with_alias[alias]
if tok in schema.schema[table]:
key = table + "." + tok
return start_idx+1, schema.idMap[key]
assert False, "Error col: {}".format(tok)
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, (agg_op id, col_id)
"""
idx = start_idx
len_ = len(toks)
isBlock = False
isDistinct = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
assert idx < len_ and toks[idx] == '('
idx += 1
if toks[idx] == "distinct":
idx += 1
isDistinct = True
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
assert idx < len_ and toks[idx] == ')'
idx += 1
return idx, (agg_id, col_id, isDistinct)
if toks[idx] == "distinct":
idx += 1
isDistinct = True
agg_id = AGG_OPS.index("none")
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (agg_id, col_id, isDistinct)
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
col_unit1 = None
col_unit2 = None
unit_op = UNIT_OPS.index('none')
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if idx < len_ and toks[idx] in UNIT_OPS:
unit_op = UNIT_OPS.index(toks[idx])
idx += 1
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (unit_op, col_unit1, col_unit2)
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
"""
:returns next idx, table id, table name
"""
idx = start_idx
len_ = len(toks)
key = tables_with_alias[toks[idx]]
if idx + 1 < len_ and toks[idx+1] == "as":
idx += 3
else:
idx += 1
return idx, schema.idMap[key], key
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
elif "\"" in toks[idx]: # token is a string value
val = toks[idx]
idx += 1
else:
try:
val = float(toks[idx])
idx += 1
except:
end_idx = idx
while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
end_idx += 1
idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
idx = end_idx
if isBlock:
assert toks[idx] == ')'
idx += 1
return idx, val
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
conds = []
while idx < len_:
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
not_op = False
if toks[idx] == 'not':
not_op = True
idx += 1
assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
op_id = WHERE_OPS.index(toks[idx])
idx += 1
val1 = val2 = None
if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
assert toks[idx] == 'and'
idx += 1
idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
else: # normal case: single value
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
val2 = None
conds.append((not_op, op_id, val_unit, val1, val2))
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
break
if idx < len_ and toks[idx] in COND_OPS:
conds.append(toks[idx])
idx += 1 # skip and/or
return idx, conds
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
assert toks[idx] == 'select', "'select' not found"
idx += 1
isDistinct = False
if idx < len_ and toks[idx] == 'distinct':
idx += 1
isDistinct = True
val_units = []
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
agg_id = AGG_OPS.index("none")
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append((agg_id, val_unit))
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
return idx, (isDistinct, val_units)
def parse_from(toks, start_idx, tables_with_alias, schema):
"""
Assume in the from clause, all table units are combined with join
"""
assert 'from' in toks[start_idx:], "'from' not found"
len_ = len(toks)
idx = toks.index('from', start_idx) + 1
default_tables = []
table_units = []
conds = []
while idx < len_:
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['sql'], sql))
else:
if idx < len_ and toks[idx] == 'join':
idx += 1 # skip join
idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['table_unit'],table_unit))
default_tables.append(table_name)
if idx < len_ and toks[idx] == "on":
idx += 1 # skip on
idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
if len(conds) > 0:
conds.append('and')
conds.extend(this_conds)
if isBlock:
assert toks[idx] == ')'
idx += 1
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
break
return idx, table_units, conds, default_tables
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'where':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
col_units = []
if idx >= len_ or toks[idx] != 'group':
return idx, col_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
col_units.append(col_unit)
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, col_units
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
val_units = []
order_type = 'asc' # default type is 'asc'
if idx >= len_ or toks[idx] != 'order':
return idx, val_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append(val_unit)
if idx < len_ and toks[idx] in ORDER_OPS:
order_type = toks[idx]
idx += 1
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, (order_type, val_units)
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'having':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_limit(toks, start_idx):
idx = start_idx
len_ = len(toks)
if idx < len_ and toks[idx] == 'limit':
idx += 2
        # make "limit value" work; we cannot assume 1 can always be used as a fake limit number
if type(toks[idx-1]) != int:
return idx, 1
return idx, int(toks[idx-1])
return idx, None
def parse_sql(toks, start_idx, tables_with_alias, schema):
isBlock = False # indicate whether this is a block of sql/sub-sql
len_ = len(toks)
idx = start_idx
sql = {}
if toks[idx] == '(':
isBlock = True
idx += 1
# parse from clause in order to get default tables
from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
sql['from'] = {'table_units': table_units, 'conds': conds}
# select clause
_, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
idx = from_end_idx
sql['select'] = select_col_units
# where clause
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
sql['where'] = where_conds
# group by clause
idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
sql['groupBy'] = group_col_units
# having clause
idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
sql['having'] = having_conds
# order by clause
idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
sql['orderBy'] = order_col_units
# limit clause
idx, limit_val = parse_limit(toks, idx)
sql['limit'] = limit_val
idx = skip_semicolon(toks, idx)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
idx = skip_semicolon(toks, idx)
# intersect/union/except clause
for op in SQL_OPS: # initialize IUE
sql[op] = None
if idx < len_ and toks[idx] in SQL_OPS:
sql_op = toks[idx]
idx += 1
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
sql[sql_op] = IUE_sql
return idx, sql
def load_data(fpath):
with open(fpath) as f:
data = json.load(f)
return data
def get_sql(schema, query):
toks = tokenize(query)
tables_with_alias = get_tables_with_alias(schema.schema, toks)
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
return sql
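# Minimal usage sketch (the database path below is hypothetical; this mirrors how the
# evaluation script builds a Schema and parses a query):
#   schema = Schema(get_schema('data/database/concert_singer/concert_singer.sqlite'))
#   sql_dict = get_sql(schema, 'SELECT count(*) FROM singer')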
def skip_semicolon(toks, start_idx):
idx = start_idx
while idx < len(toks) and toks[idx] == ";":
idx += 1
return idx

58
r2sql/sparc/logger.py Normal file

@ -0,0 +1,58 @@
"""Contains the logging class."""
class Logger():
"""Attributes:
fileptr (file): File pointer for input/output.
lines (list of str): The lines read from the log.
"""
def __init__(self, filename, option):
self.fileptr = open(filename, option)
if option == "r":
self.lines = self.fileptr.readlines()
else:
self.lines = []
def put(self, string):
"""Writes to the file."""
self.fileptr.write(string + "\n")
self.fileptr.flush()
def close(self):
"""Closes the logger."""
self.fileptr.close()
def findlast(self, identifier, default=0.):
"""Finds the last line in the log with a certain value."""
for line in self.lines[::-1]:
if line.lower().startswith(identifier):
string = line.strip().split("\t")[1]
if string.replace(".", "").isdigit():
return float(string)
elif string.lower() == "true":
return True
elif string.lower() == "false":
return False
else:
return string
return default
def contains(self, string):
"""Dtermines whether the string is present in the log."""
for line in self.lines[::-1]:
if string.lower() in line.lower():
return True
return False
def findlast_log_before(self, before_str):
"""Finds the last entry in the log before another entry."""
loglines = []
in_line = False
for line in self.lines[::-1]:
if line.startswith(before_str):
in_line = True
elif in_line:
loglines.append(line)
if line.strip() == "" and in_line:
return "".join(loglines[::-1])
return "".join(loglines[::-1])


@ -0,0 +1,654 @@
import argparse
import os
import sys
import pickle
import json
import shutil
import sqlparse
from postprocess_eval import get_candidate_tables
def write_interaction(interaction_list,split,output_dir):
json_split = os.path.join(output_dir,split+'.json')
pkl_split = os.path.join(output_dir,split+'.pkl')
with open(json_split, 'w') as outfile:
for interaction in interaction_list:
json.dump(interaction, outfile, indent = 4)
outfile.write('\n')
new_objs = []
for i, obj in enumerate(interaction_list):
new_interaction = []
for ut in obj["interaction"]:
sql = ut["sql"]
sqls = [sql]
tok_sql_list = []
for sql in sqls:
results = []
tokenized_sql = sql.split()
tok_sql_list.append((tokenized_sql, results))
ut["sql"] = tok_sql_list
new_interaction.append(ut)
obj["interaction"] = new_interaction
new_objs.append(obj)
with open(pkl_split,'wb') as outfile:
pickle.dump(new_objs, outfile)
return
def read_database_schema(database_schema_filename, schema_tokens, column_names, database_schemas_dict):
with open(database_schema_filename) as f:
database_schemas = json.load(f)
def get_schema_tokens(table_schema):
column_names_surface_form = []
column_names = []
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
for i, (table_id, column_name) in enumerate(column_names_original):
if table_id >= 0:
table_name = table_names_original[table_id]
column_name_surface_form = '{}.{}'.format(table_name,column_name)
else:
# this is just *
column_name_surface_form = column_name
column_names_surface_form.append(column_name_surface_form.lower())
column_names.append(column_name.lower())
# also add table_name.*
for table_name in table_names_original:
column_names_surface_form.append('{}.*'.format(table_name.lower()))
return column_names_surface_form, column_names
for table_schema in database_schemas:
database_id = table_schema['db_id']
database_schemas_dict[database_id] = table_schema
schema_tokens[database_id], column_names[database_id] = get_schema_tokens(table_schema)
return schema_tokens, column_names, database_schemas_dict
def remove_from_with_join(format_sql_2):
used_tables_list = []
format_sql_3 = []
table_to_name = {}
table_list = []
old_table_to_name = {}
old_table_list = []
for sub_sql in format_sql_2.split('\n'):
if 'select ' in sub_sql:
# only replace alias: t1 -> table_name, t2 -> table_name, etc...
if len(table_list) > 0:
for i in range(len(format_sql_3)):
for table, name in table_to_name.items():
format_sql_3[i] = format_sql_3[i].replace(table, name)
old_table_list = table_list
old_table_to_name = table_to_name
table_to_name = {}
table_list = []
format_sql_3.append(sub_sql)
elif sub_sql.startswith('from'):
new_sub_sql = None
sub_sql_tokens = sub_sql.split()
for t_i, t in enumerate(sub_sql_tokens):
if t == 'as':
table_to_name[sub_sql_tokens[t_i+1]] = sub_sql_tokens[t_i-1]
table_list.append(sub_sql_tokens[t_i-1])
elif t == ')' and new_sub_sql is None:
# new_sub_sql keeps some trailing parts after ')'
new_sub_sql = ' '.join(sub_sql_tokens[t_i:])
if len(table_list) > 0:
# if it's a from clause with join
if new_sub_sql is not None:
format_sql_3.append(new_sub_sql)
used_tables_list.append(table_list)
else:
# if it's a from clause without join
table_list = old_table_list
table_to_name = old_table_to_name
assert 'join' not in sub_sql
if new_sub_sql is not None:
sub_sub_sql = sub_sql[:-len(new_sub_sql)].strip()
assert len(sub_sub_sql.split()) == 2
used_tables_list.append([sub_sub_sql.split()[1]])
format_sql_3.append(sub_sub_sql)
format_sql_3.append(new_sub_sql)
elif 'join' not in sub_sql:
assert len(sub_sql.split()) == 2 or len(sub_sql.split()) == 1
if len(sub_sql.split()) == 2:
used_tables_list.append([sub_sql.split()[1]])
format_sql_3.append(sub_sql)
else:
print('bad from clause in remove_from_with_join')
exit()
else:
format_sql_3.append(sub_sql)
if len(table_list) > 0:
for i in range(len(format_sql_3)):
for table, name in table_to_name.items():
format_sql_3[i] = format_sql_3[i].replace(table, name)
used_tables = []
for t in used_tables_list:
for tt in t:
used_tables.append(tt)
used_tables = list(set(used_tables))
return format_sql_3, used_tables, used_tables_list
def remove_from_without_join(format_sql_3, column_names, schema_tokens):
format_sql_4 = []
table_name = None
for sub_sql in format_sql_3.split('\n'):
if 'select ' in sub_sql:
if table_name:
for i in range(len(format_sql_4)):
tokens = format_sql_4[i].split()
for ii, token in enumerate(tokens):
if token in column_names and tokens[ii-1] != '.':
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
if '{}.{}'.format(table_name,token) in schema_tokens:
tokens[ii] = '{} . {}'.format(table_name,token)
format_sql_4[i] = ' '.join(tokens)
format_sql_4.append(sub_sql)
elif sub_sql.startswith('from'):
sub_sql_tokens = sub_sql.split()
if len(sub_sql_tokens) == 1:
table_name = None
elif len(sub_sql_tokens) == 2:
table_name = sub_sql_tokens[1]
else:
print('bad from clause in remove_from_without_join')
print(format_sql_3)
exit()
else:
format_sql_4.append(sub_sql)
if table_name:
for i in range(len(format_sql_4)):
tokens = format_sql_4[i].split()
for ii, token in enumerate(tokens):
if token in column_names and tokens[ii-1] != '.':
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
if '{}.{}'.format(table_name,token) in schema_tokens:
tokens[ii] = '{} . {}'.format(table_name,token)
format_sql_4[i] = ' '.join(tokens)
return format_sql_4
def add_table_name(format_sql_3, used_tables, column_names, schema_tokens):
    # If just one table is used (easy case), replace every column_name with table_name.column_name
if len(used_tables) == 1:
table_name = used_tables[0]
format_sql_4 = []
for sub_sql in format_sql_3.split('\n'):
if sub_sql.startswith('from'):
format_sql_4.append(sub_sql)
continue
tokens = sub_sql.split()
for ii, token in enumerate(tokens):
if token in column_names and tokens[ii-1] != '.':
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
if '{}.{}'.format(table_name,token) in schema_tokens:
tokens[ii] = '{} . {}'.format(table_name,token)
format_sql_4.append(' '.join(tokens))
return format_sql_4
def get_table_name_for(token):
table_names = []
for table_name in used_tables:
if '{}.{}'.format(table_name, token) in schema_tokens:
table_names.append(table_name)
if len(table_names) == 0:
return 'table'
if len(table_names) > 1:
return None
else:
return table_names[0]
format_sql_4 = []
for sub_sql in format_sql_3.split('\n'):
if sub_sql.startswith('from'):
format_sql_4.append(sub_sql)
continue
tokens = sub_sql.split()
for ii, token in enumerate(tokens):
# skip *
if token == '*':
continue
if token in column_names and tokens[ii-1] != '.':
if (ii+1 < len(tokens) and tokens[ii+1] != '.' and tokens[ii+1] != '(') or ii+1 == len(tokens):
table_name = get_table_name_for(token)
if table_name:
tokens[ii] = '{} . {}'.format(table_name, token)
format_sql_4.append(' '.join(tokens))
return format_sql_4
def check_oov(format_sql_final, output_vocab, schema_tokens):
for sql_tok in format_sql_final.split():
if not (sql_tok in schema_tokens or sql_tok in output_vocab):
print('OOV!', sql_tok)
raise Exception('OOV')
def normalize_space(format_sql):
format_sql_1 = [' '.join(sub_sql.strip().replace(',',' , ').replace('.',' . ').replace('(',' ( ').replace(')',' ) ').split()) for sub_sql in format_sql.split('\n')]
format_sql_1 = '\n'.join(format_sql_1)
format_sql_2 = format_sql_1.replace('\njoin',' join').replace(',\n',', ').replace(' where','\nwhere').replace(' intersect', '\nintersect').replace('\nand',' and').replace('order by t2 .\nstart desc', 'order by t2 . start desc')
format_sql_2 = format_sql_2.replace('select\noperator', 'select operator').replace('select\nconstructor', 'select constructor').replace('select\nstart', 'select start').replace('select\ndrop', 'select drop').replace('select\nwork', 'select work').replace('select\ngroup', 'select group').replace('select\nwhere_built', 'select where_built').replace('select\norder', 'select order').replace('from\noperator', 'from operator').replace('from\nforward', 'from forward').replace('from\nfor', 'from for').replace('from\ndrop', 'from drop').replace('from\norder', 'from order').replace('.\nstart', '. start').replace('.\norder', '. order').replace('.\noperator', '. operator').replace('.\nsets', '. sets').replace('.\nwhere_built', '. where_built').replace('.\nwork', '. work').replace('.\nconstructor', '. constructor').replace('.\ngroup', '. group').replace('.\nfor', '. for').replace('.\ndrop', '. drop').replace('.\nwhere', '. where')
format_sql_2 = format_sql_2.replace('group by', 'group_by').replace('order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
return format_sql_2
def normalize_final_sql(format_sql_5):
format_sql_final = format_sql_5.replace('\n', ' ').replace(' . ', '.').replace('group by', 'group_by').replace('order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
# normalize two bad sqls
if 't1' in format_sql_final or 't2' in format_sql_final or 't3' in format_sql_final or 't4' in format_sql_final:
format_sql_final = format_sql_final.replace('t2.dormid', 'dorm.dormid')
# This is the failure case of remove_from_without_join()
format_sql_final = format_sql_final.replace('select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by population desc limit_value', 'select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by city.population desc limit_value')
return format_sql_final
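# Illustrative behaviour on a hypothetical (already space-normalized) input:
#   'select t1 . name\nwhere t1 . age > value order by t1 . age desc limit value'
#   -> 'select t1.name where t1.age > value order_by t1.age desc limit_value'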
def parse_sql(sql_string, db_id, column_names, output_vocab, schema_tokens, schema):
format_sql = sqlparse.format(sql_string, reindent=True)
format_sql_2 = normalize_space(format_sql)
num_from = sum([1 for sub_sql in format_sql_2.split('\n') if sub_sql.startswith('from')])
num_select = format_sql_2.count('select ') + format_sql_2.count('select\n')
format_sql_3, used_tables, used_tables_list = remove_from_with_join(format_sql_2)
format_sql_3 = '\n'.join(format_sql_3)
format_sql_4 = add_table_name(format_sql_3, used_tables, column_names, schema_tokens)
format_sql_4 = '\n'.join(format_sql_4)
format_sql_5 = remove_from_without_join(format_sql_4, column_names, schema_tokens)
format_sql_5 = '\n'.join(format_sql_5)
format_sql_final = normalize_final_sql(format_sql_5)
candidate_tables_id, table_names_original = get_candidate_tables(format_sql_final, schema)
failure = False
if len(candidate_tables_id) != len(used_tables):
failure = True
check_oov(format_sql_final, output_vocab, schema_tokens)
return format_sql_final
def read_spider_split(split_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
with open(split_json) as f:
split_data = json.load(f)
print('read_spider_split', split_json, len(split_data))
for i, ex in enumerate(split_data):
db_id = ex['db_id']
final_sql = []
skip = False
for query_tok in ex['query_toks_no_value']:
if query_tok != '.' and '.' in query_tok:
# invalid sql; didn't use table alias in join
final_sql += query_tok.replace('.',' . ').split()
skip = True
else:
final_sql.append(query_tok)
final_sql = ' '.join(final_sql)
if skip and 'train' in split_json:
continue
if remove_from:
final_sql_parse = parse_sql(final_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
else:
final_sql_parse = final_sql
final_utterance = ' '.join(ex['question_toks'])
if db_id not in interaction_list:
interaction_list[db_id] = []
interaction = {}
interaction['id'] = ''
interaction['scenario'] = ''
interaction['database_id'] = db_id
interaction['interaction_id'] = len(interaction_list[db_id])
interaction['final'] = {}
interaction['final']['utterance'] = final_utterance
interaction['final']['sql'] = final_sql_parse
interaction['interaction'] = [{'utterance': final_utterance, 'sql': final_sql_parse}]
interaction_list[db_id].append(interaction)
return interaction_list
def read_data_json(split_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
with open(split_json) as f:
split_data = json.load(f)
print('read_data_json', split_json, len(split_data))
for interaction_data in split_data:
db_id = interaction_data['database_id']
final_sql = interaction_data['final']['query']
final_utterance = interaction_data['final']['utterance']
if db_id not in interaction_list:
interaction_list[db_id] = []
# no interaction_id in train
if 'interaction_id' in interaction_data['interaction']:
interaction_id = interaction_data['interaction']['interaction_id']
else:
interaction_id = len(interaction_list[db_id])
interaction = {}
interaction['id'] = ''
interaction['scenario'] = ''
interaction['database_id'] = db_id
interaction['interaction_id'] = interaction_id
interaction['final'] = {}
interaction['final']['utterance'] = final_utterance
interaction['final']['sql'] = final_sql
interaction['interaction'] = []
for turn in interaction_data['interaction']:
turn_sql = []
skip = False
for query_tok in turn['query_toks_no_value']:
if query_tok != '.' and '.' in query_tok:
# invalid sql; didn't use table alias in join
turn_sql += query_tok.replace('.',' . ').split()
skip = True
else:
turn_sql.append(query_tok)
turn_sql = ' '.join(turn_sql)
# Correct some human sql annotation error
turn_sql = turn_sql.replace('select f_id from files as t1 join song as t2 on t1 . f_id = t2 . f_id', 'select t1 . f_id from files as t1 join song as t2 on t1 . f_id = t2 . f_id')
turn_sql = turn_sql.replace('select name from climber mountain', 'select name from climber')
turn_sql = turn_sql.replace('select sid from sailors as t1 join reserves as t2 on t1 . sid = t2 . sid join boats as t3 on t3 . bid = t2 . bid', 'select t1 . sid from sailors as t1 join reserves as t2 on t1 . sid = t2 . sid join boats as t3 on t3 . bid = t2 . bid')
turn_sql = turn_sql.replace('select avg ( price ) from goods )', 'select avg ( price ) from goods')
turn_sql = turn_sql.replace('select min ( annual_fuel_cost ) , from vehicles', 'select min ( annual_fuel_cost ) from vehicles')
turn_sql = turn_sql.replace('select * from goods where price < ( select avg ( price ) from goods', 'select * from goods where price < ( select avg ( price ) from goods )')
turn_sql = turn_sql.replace('select distinct id , price from goods where price < ( select avg ( price ) from goods', 'select distinct id , price from goods where price < ( select avg ( price ) from goods )')
turn_sql = turn_sql.replace('select id from goods where price > ( select avg ( price ) from goods', 'select id from goods where price > ( select avg ( price ) from goods )')
if skip and 'train' in split_json:
continue
if remove_from:
try:
turn_sql_parse = parse_sql(turn_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
except:
print('continue')
continue
else:
turn_sql_parse = turn_sql
if 'utterance_toks' in turn:
turn_utterance = ' '.join(turn['utterance_toks']) # not lower()
else:
turn_utterance = turn['utterance']
interaction['interaction'].append({'utterance': turn_utterance, 'sql': turn_sql_parse})
if len(interaction['interaction']) > 0:
interaction_list[db_id].append(interaction)
return interaction_list
def read_spider(spider_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
interaction_list = {}
train_json = os.path.join(spider_dir, 'train.json')
interaction_list = read_spider_split(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
dev_json = os.path.join(spider_dir, 'dev.json')
interaction_list = read_spider_split(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
return interaction_list
def read_sparc(sparc_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
interaction_list = {}
train_json = os.path.join(sparc_dir, 'train_no_value.json')
interaction_list = read_data_json(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
dev_json = os.path.join(sparc_dir, 'dev_no_value.json')
interaction_list = read_data_json(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
return interaction_list
def read_cosql(cosql_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from):
interaction_list = {}
train_json = os.path.join(cosql_dir, 'train.json')
interaction_list = read_data_json(train_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
dev_json = os.path.join(cosql_dir, 'dev.json')
interaction_list = read_data_json(dev_json, interaction_list, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
return interaction_list
def read_db_split(data_dir):
train_database = []
with open(os.path.join(data_dir,'train_db_ids.txt')) as f:
for line in f:
train_database.append(line.strip())
dev_database = []
with open(os.path.join(data_dir,'dev_db_ids.txt')) as f:
for line in f:
dev_database.append(line.strip())
return train_database, dev_database
def preprocess(datasets, remove_from=False):
# Validate output_vocab
output_vocab = ['_UNK', '_EOS', '.', 't1', 't2', '=', 'select', 'from', 'as', 'value', 'join', 'on', ')', '(', 'where', 't3', 'by', ',', 'count', 'group', 'order', 'distinct', 't4', 'and', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 't5', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!', 'union', 'between', 't6', '-', 't7', '+', '/']
if remove_from:
output_vocab = ['_UNK', '_EOS', '=', 'select', 'value', ')', '(', 'where', ',', 'count', 'group_by', 'order_by', 'distinct', 'and', 'limit_value', 'limit', 'desc', '>', 'avg', 'having', 'max', 'in', '<', 'sum', 'intersect', 'not', 'min', 'except', 'or', 'asc', 'like', '!=', 'union', 'between', '-', '+', '/']
print('size of output_vocab', len(output_vocab))
print('output_vocab', output_vocab)
print()
print('datasets', datasets)
if 'spider' in datasets:
spider_dir = 'data/spider/'
database_schema_filename = 'data/spider/tables.json'
output_dir = 'data/spider_data'
if remove_from:
output_dir = 'data/spider_data_removefrom'
spider_train_database, _ = read_db_split(spider_dir)
if 'sparc' in datasets:
sparc_dir = 'data/sparc/'
database_schema_filename = 'data/sparc/tables.json'
output_dir = 'data/sparc_data'
if remove_from:
output_dir = 'data/sparc_data_removefrom'
sparc_train_database, dev_database = read_db_split(sparc_dir)
if 'cosql' in datasets:
cosql_dir = 'data/cosql/'
database_schema_filename = 'data/cosql/tables.json'
output_dir = 'data/cosql_data'
if remove_from:
output_dir = 'data/cosql_data_removefrom'
cosql_train_database, _ = read_db_split(cosql_dir)
#print('sparc_train_database', sparc_train_database)
#print('cosql_train_database', cosql_train_database)
#print('spider_train_database', spider_train_database)
#print('dev_database', dev_database)
#for x in dev_database:
# if x in sparc_train_database:
# print('x1', x)
# if x in spider_train_database:
# print('x2', x)
# if x in cosql_train_database:
# print('x3', x)
output_dir = './data/merge_data_removefrom'
if os.path.isdir(output_dir):
shutil.rmtree(output_dir)
os.mkdir(output_dir)
schema_tokens = {}
column_names = {}
database_schemas = {}
#assert False
print('Reading spider database schema file')
schema_tokens = dict()
column_names = dict()
database_schemas = dict()
assert remove_from
for file in ['data/sparc_data_removefrom', 'data/cosql_data_removefrom', 'data/sparc_data_removefrom']:
s, c, d = read_database_schema(database_schema_filename, schema_tokens, column_names, database_schemas)
schema_tokens.update(s)
column_names.update(c)
database_schemas.update(d)
#x = 0
#for k in schema_tokens.keys():
# x = k
#print('schema_tokens', schema_tokens[x], len(schema_tokens) )
#print('column_names', column_names[x], len(column_names) )
#print('database_schemas', database_schemas[x], len(database_schemas) )
#assert False
num_database = len(schema_tokens)
print('num_database', num_database)
print('total number of schema_tokens / databases:', len(schema_tokens))
#output_dir = 'data/merge_data_removefrom'
output_database_schema_filename = os.path.join(output_dir, 'tables.json')
#print("output_database_schema_filename", output_database_schema_filename)
with open(output_database_schema_filename, 'w') as outfile:
json.dump([v for k,v in database_schemas.items()], outfile, indent=4)
#assert False
if 'spider' in datasets:
interaction_list1 = read_spider(spider_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
if 'sparc' in datasets:
interaction_list2 = read_sparc(sparc_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
if 'cosql' in datasets:
interaction_list3 = read_cosql(cosql_dir, database_schemas, column_names, output_vocab, schema_tokens, remove_from)
print('interaction_list1 length', len(interaction_list1))
print('interaction_list2 length', len(interaction_list2))
print('interaction_list3 length', len(interaction_list3))
#assert False
s1 = 1.0
s2 = 1.0
n1 = 1.0
n2 = 1.0
m1 = 1.0
m2 = 1.0
n3 = 1.0
train_interaction = []
vis = set()
for database_id in interaction_list1:
if database_id not in dev_database:
#Len = len(interaction_list1[database_id])//2
Len = len(interaction_list1[database_id])
train_interaction += interaction_list1[database_id][:Len]
#print('xxx', type(interaction_list1[database_id]))
#print(interaction_list1[database_id][0])
#print(interaction_list2[database_id][0])
#print(interaction_list3[database_id][0])
#assert False
if database_id in interaction_list2:
train_interaction += interaction_list2[database_id]
#for x in interaction_list1[database_id]:
# if len( x['interaction'][0] ) > 300:
# continue
# train_interaction.append(x)
# n3+=1
'''if database_id in interaction_list2:
for x in interaction_list2[database_id]:
# continue
#n1 += 1
#m1 = max(m1, len(x['interaction']))
s = '&'.join([k['utterance'] for k in x['interaction']])
if s in vis:
print('s', s)
vis.add(s)
'''
if database_id in interaction_list3:
for x in interaction_list3[database_id]:
# continue
#n1 += 1
#m1 = max(m1, len(x['interaction']))
#s = '&'.join([k['utterance'] for k in x['interaction']])
#if s in vis:
# print('s1', s)
#vis.add(s)
#print(x['interaction'][0]['utterance'])
#print(x['interaction'][0]['utterance'])
if len(x['interaction']) > 4:
#assert False
continue
sumlen = 0
for one in x['interaction']:
sumlen += len(one['utterance'])
if sumlen > 300:
continue
n1 += 1
train_interaction.append(x)
print('n1', n1)
print('n3', n3)
dev_interaction = []
for database_id in dev_database:
dev_interaction += interaction_list2[database_id]
print('train interaction: ', len(train_interaction))
print('dev interaction: ', len(dev_interaction))
write_interaction(train_interaction, 'train', output_dir)
write_interaction(dev_interaction, 'dev', output_dir)
return
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--datasets", type=str, default='')
parser.add_argument('--remove_from', action='store_true', default=False)
args = parser.parse_args()
assert args.datasets != ''
preprocess(args.datasets, args.remove_from)


@ -0,0 +1,80 @@
"""Contains classes for computing and keeping track of attention distributions.
"""
from collections import namedtuple
import torch
import torch.nn.functional as F
from . import torch_utils
class AttentionResult(namedtuple('AttentionResult',
('scores',
'distribution',
'vector'))):
"""Stores the result of an attention calculation."""
__slots__ = ()
class Attention(torch.nn.Module):
"""Attention mechanism class. Stores parameters for and computes attention.
Attributes:
transform_query (bool): Whether or not to transform the query being
passed in with a weight transformation before computing attentino.
transform_key (bool): Whether or not to transform the key being
passed in with a weight transformation before computing attentino.
transform_value (bool): Whether or not to transform the value being
passed in with a weight transformation before computing attentino.
key_size (int): The size of the key vectors.
value_size (int): The size of the value vectors.
the query or key.
query_weights (dy.Parameters): Weights for transforming the query.
key_weights (dy.Parameters): Weights for transforming the key.
value_weights (dy.Parameters): Weights for transforming the value.
"""
def __init__(self, query_size, key_size, value_size):
super().__init__()
self.key_size = key_size
self.value_size = value_size
self.query_weights = torch_utils.add_params((query_size, self.key_size), "weights-attention-q")
def transform_arguments(self, query, keys, values):
""" Transforms the query/key/value inputs before attention calculations.
Arguments:
query (dy.Expression): Vector representing the query (e.g., hidden state.)
keys (list of dy.Expression): List of vectors representing the key
values.
values (list of dy.Expression): List of vectors representing the values.
Returns:
triple of dy.Expression, where the first represents the (transformed)
query, the second represents the (transformed and concatenated)
keys, and the third represents the (transformed and concatenated)
values.
"""
assert len(keys) == len(values)
all_keys = torch.stack(keys, dim=1)
all_values = torch.stack(values, dim=1)
assert all_keys.size()[0] == self.key_size, "Expected key size of " + str(self.key_size) + " but got " + str(all_keys.size()[0])
assert all_values.size()[0] == self.value_size
query = torch_utils.linear_layer(query, self.query_weights)
return query, all_keys, all_values
def forward(self, query, keys, values=None):
if not values:
values = keys
query_t, keys_t, values_t = self.transform_arguments(query, keys, values)
scores = torch.t(torch.mm(query_t,keys_t)) # len(key) x len(query)
distribution = F.softmax(scores, dim=0) # len(key) x len(query)
context_vector = torch.mm(values_t, distribution).squeeze() # value_size x len(query)
return AttentionResult(scores, distribution, context_vector)


@ -0,0 +1,97 @@
import heapq
import math
import numpy as np
from typing import Callable, Dict, Tuple, List
import torch
class BeamSearchState:
def __init__(self, sequence=[], log_probability=0, state={}):
self.sequence = sequence
self.log_probability = log_probability
self.state = state
def __lt__(self, other):
return self.log_probability < other.log_probability
def __str__(self):
return f"{self.log_probability}: {' '.join(self.sequence)}"
class BeamSearch:
"""
Implements the beam search algorithm for decoding the most likely sequences.
"""
def __init__(
self,
is_end_of_sequence: Callable[[int, str], bool],
max_steps: int = 50,
beam_size: int = 10,
per_node_beam_size: int = None,
) -> None:
self.is_end_of_sequence = is_end_of_sequence
self.max_steps = max_steps
self.beam_size = beam_size
self.per_node_beam_size = per_node_beam_size
if not per_node_beam_size:
self.per_node_beam_size = beam_size
self.beam = []
def append(self, log_probability, new_state):
if len(self.beam) < self.beam_size:
heapq.heappush(self.beam, (log_probability, new_state))
else:
heapq.heapreplace(self.beam, (log_probability, new_state))
def get_largest(self, beam):
return list(heapq.nlargest(1, beam))[0]
def search(
self,
start_state: Dict,
step_function: Callable[[Dict], Tuple[List[str], torch.Tensor, Tuple]],
append_token_function: Callable[[Tuple, Tuple, str, float], Tuple]
) -> Tuple[List, float, Tuple, List]:
state = BeamSearchState(state=start_state)
heapq.heappush(self.beam, (0, state))
done = False
while not done:
current_beam = self.beam
self.beam = []
found_non_terminal = False
for (log_probability, state) in current_beam:
if self.is_end_of_sequence(state.state, self.max_steps, state.sequence):
self.append(log_probability, state)
else:
found_non_terminal = True
tokens, token_probabilities, extra_state = step_function(state.state)
tps = np.array(token_probabilities)
top_indices = heapq.nlargest(self.per_node_beam_size, range(len(tps)), tps.take)
top_tokens = [tokens[i] for i in top_indices]
top_probabilities = [token_probabilities[i] for i in top_indices]
sequence_length = len(state.sequence)
for i in range(len(top_tokens)):
if top_probabilities[i] > 1e-6:
#lp = ((log_probability * sequence_length) + math.log(top_probabilities[i])) / (sequence_length + 1.0)
lp = log_probability + math.log(top_probabilities[i])
t = top_tokens[i]
new_state = append_token_function(state.state, extra_state, t, lp)
new_sequence = state.sequence.copy()
new_sequence.append(t)
ns = BeamSearchState(sequence=new_sequence, log_probability=lp, state=new_state)
self.append(lp, ns)
if not found_non_terminal:
done = True
output = self.get_largest(self.beam)[1]
return output.sequence, output.log_probability, output.state, self.beam
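A toy driver sketch for this class follows; the three callbacks are illustrative stand-ins for the repo's real decoder, which supplies its own scoring and state-update logic.

```python
# Illustrative only: a fixed three-token vocabulary and a stateless "decoder".
def is_end(state, max_steps, sequence):
    return len(sequence) >= max_steps or (len(sequence) > 0 and sequence[-1] == "<EOS>")

def step(state):
    # A real step function would run one decoder step and return token scores.
    return ["a", "b", "<EOS>"], [0.5, 0.3, 0.2], None

def append_token(state, extra_state, token, log_probability):
    return state  # a real decoder would advance its hidden state here

searcher = BeamSearch(is_end_of_sequence=is_end, max_steps=5, beam_size=3)
sequence, log_probability, final_state, beam = searcher.search({}, step, append_token)
print(sequence, log_probability)
```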

View File

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,255 @@
# PyTorch implementation of Google AI's BERT model with a script to load Google's pre-trained models
## Forked for wikisql application
## NSML
### SQuAD1.1 finetuning
```
nsml run -d squad_bert -g 4 -e run_squad.py -a "--do_lower_case --do_train --do_predict --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu"
```
### SQuAD2.0 finetuning
```
nsml run -d squad_bert -g 4 -e run_squad2.py -a "--do_lower_case --do_train --do_predict --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu"
```
### Evaluation
1. Download prediction file from NSML session
```
nsml download -f /app/squad_base [NSML_ID]/squad_bert/[SESSION] .
```
2. Run official evaluation file
```
python3 evaluate-v1.1.py [dev.json] [predictions.json]
python3 evaluate-v2.0.py [dev.json] [predictions.json] -n [na_probs.json]
```
## Introduction
This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular [Google's pre-trained models](https://github.com/google-research/bert)) and a conversion script is provided (see below).
The code to use [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will, in addition, be added later this week (it's actually just the tokenization code that needs to be updated).
## Loading a TensorFlow checkpoint (e.g. [Google's pre-trained models](https://github.com/google-research/bert#pre-trained-models))
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) into a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint into the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`).
You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too.
To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch.
Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model:
```shell
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
python convert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
--bert_config_file $BERT_BASE_DIR/bert_config.json \
--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
```
You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models).
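Once converted, loading the dump back is plain PyTorch. A minimal sketch (the directory name mirrors the conversion command above and is only a placeholder):

```python
import torch
from modeling import BertConfig, BertModel

BERT_BASE_DIR = "/path/to/bert/uncased_L-12_H-768_A-12"  # placeholder path
config = BertConfig.from_json_file(BERT_BASE_DIR + "/bert_config.json")
model = BertModel(config)
model.load_state_dict(torch.load(BERT_BASE_DIR + "/pytorch_model.bin", map_location="cpu"))
model.eval()
```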
## PyTorch models for BERT
We included three PyTorch models in this repository that you will find in [`modeling.py`](modeling.py):
- `BertModel` - the basic BERT Transformer model
- `BertForSequenceClassification` - the BERT model with a sequence classification head on top
- `BertForQuestionAnswering` - the BERT model with a token classification head on top
Here are some details on each class.
### 1. `BertModel`
`BertModel` is the basic BERT Transformer model with a layer of summed token, position and sequence embeddings followed by a series of identical self-attention blocks (12 for BERT-base, 24 for BERT-large).
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
We detail them here. This model takes as inputs:
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), and
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
This model outputs a tuple composed of:
- `all_encoder_layers`: a list of torch.FloatTensor of size [batch_size, sequence_length, hidden_size] which is a list of the full sequences of hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), and
- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated with the first token of the input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
An example on how to use this class is given in the `extract_features.py` script which can be used to extract the hidden states of the model for a given input.
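As a quick sanity check of the input/output contract described above, here is a minimal sketch with a toy configuration and random weights (real usage loads the converted checkpoint instead):

```python
import torch
from modeling import BertConfig, BertModel

config = BertConfig(vocab_size=32000, hidden_size=512,
                    num_hidden_layers=4, num_attention_heads=8,
                    intermediate_size=1024)
model = BertModel(config)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 0, 0]])

all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
print(len(all_encoder_layers))  # 4: one full sequence of hidden states per layer
print(pooled_output.size())     # torch.Size([2, 512])
```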
### 2. `BertForSequenceClassification`
`BertForSequenceClassification` is a fine-tuning model that includes `BertModel` and a sequence-level (sequence or pair of sequences) classifier on top of the `BertModel`.
The sequence-level classifier is a linear layer that takes as input the last hidden state of the first token (`[CLS]`) in the input sequence (see Figures 3a and 3b in the BERT paper).
An example on how to use this class is given in the `run_classifier.py` script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
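A short sketch of the training-time call (toy sizes, random weights): when labels are supplied, the model returns the cross-entropy loss alongside the logits.

```python
import torch
from modeling import BertConfig, BertForSequenceClassification

config = BertConfig(vocab_size=32000, hidden_size=512,
                    num_hidden_layers=4, num_attention_heads=8,
                    intermediate_size=1024)
model = BertForSequenceClassification(config, num_labels=2)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.zeros_like(input_ids)
labels = torch.LongTensor([1, 0])

loss, logits = model(input_ids, token_type_ids, attention_mask, labels)
loss.backward()  # standard PyTorch fine-tuning step from here on
```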
### 3. `BertForQuestionAnswering`
`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifier on top of the full sequence of last hidden states.
The token-level classifier takes as input the full sequence of the last hidden states and computes several (e.g., two) scores for each token, which can, for example, be the scores that a given token is a `start_span` or an `end_span` token (see Figures 3c and 3d in the BERT paper).
An example on how to use this class is given in the `run_squad.py` script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
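The corresponding sketch for span extraction (toy sizes, random weights): without gold spans the call returns per-token start/end logits, with gold spans it returns the training loss.

```python
import torch
from modeling import BertConfig, BertForQuestionAnswering

config = BertConfig(vocab_size=32000, hidden_size=512,
                    num_hidden_layers=4, num_attention_heads=8,
                    intermediate_size=1024)
model = BertForQuestionAnswering(config)

input_ids = torch.LongTensor([[31, 51, 99, 12], [15, 5, 0, 0]])
attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 0, 0]])
token_type_ids = torch.zeros_like(input_ids)

start_logits, end_logits = model(input_ids, token_type_ids, attention_mask)
loss = model(input_ids, token_type_ids, attention_mask,
             start_positions=torch.LongTensor([1, 0]),
             end_positions=torch.LongTensor([2, 1]))
```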
## Installation, requirements, test
This code was tested on Python 3.5+. The requirements are:
- PyTorch (>= 0.4.1)
- tqdm
To install the dependencies:
```bash
pip install -r ./requirements.txt
```
A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`).
You can run the tests with the command:
```bash
python -m pytest -sv tests/
```
## Training on large batches: gradient accumulation, multi-GPU and distributed training
BERT-base and BERT-large are respectively 110M and 340M parameter models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most cases a batch size of 32).
To help with fine-tuning these models, we have included four techniques that you can activate in the fine-tuning scripts `run_classifier.py` and `run_squad.py`: optimize on CPU, gradient-accumulation, multi-gpu and distributed training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
Here is how to use these techniques in our scripts:
- **Optimize on CPU**: The Adam optimizer keeps 2 moving averages of all the weights of the model, which means that if you keep them on GPU 1 (typical behavior), your first GPU will have to store 3 times the size of the model. This is not optimal when using a large model like `BERT-large` and means your batch size is a lot lower than it could be. This option performs the optimization and stores the averages on the CPU to free more room on the GPU(s). As the most computationally intensive operation is the backward pass, this usually doesn't increase the computation time by much. This is the only way to fine-tune `BERT-large` in a reasonable time on GPU(s) (see below). Activate this option with `--optimize_on_cpu` on the `run_squad.py` script.
- **Gradient Accumulation**: Gradient accumulation can be used by supplying an integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradients will be accumulated over `gradient_accumulation_steps` steps (a generic sketch of this pattern is shown after the distributed-training command below).
- **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected; the batches are split across the GPUs.
- **Distributed training**: Distributed training can be activated by supplying an integer greater than or equal to 0 to the `--local_rank` argument. To use distributed training, you will need to run one training script on each of your machines. This can be done, for example, by running the following command on each server (see the above blog post for more details):
```bash
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script)
```
Where `$THIS_MACHINE_INDEX` is a sequential index assigned to each of your machines (0, 1, 2...) and the machine with rank 0 has an IP address `192.168.1.1` and an open port `1234`.
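For reference, the gradient-accumulation pattern behind `--gradient_accumulation_steps` boils down to the generic PyTorch loop below (a self-contained toy sketch, not the scripts' exact code):

```python
import torch
import torch.nn as nn

# Toy model and data; the point is the accumulate-then-step pattern.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
gradient_accumulation_steps = 2

batches = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(8)]

optimizer.zero_grad()
for step, (inputs, labels) in enumerate(batches):
    loss = loss_fn(model(inputs), labels)
    loss = loss / gradient_accumulation_steps  # keep the accumulated gradient on the same scale
    loss.backward()                            # gradients accumulate across backward() calls
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```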
## TPU support and pretraining scripts
TPUs are not supported by the current stable release of PyTorch (0.4.1). However, the next version of PyTorch (v1.0) should support training on TPU and is expected to be released soon (see the recent [official announcement](https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-across-google-cloud)).
We will add TPU support when this next release is published.
The original TensorFlow code further comprises two scripts for pre-training BERT: [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) and [run_pretraining.py](https://github.com/google-research/bert/blob/master/run_pretraining.py).
Since pre-training BERT is a particularly expensive operation that basically requires one or several TPUs to be completed in a reasonable amount of time (see details [here](https://github.com/google-research/bert#pre-training-with-bert)), we have decided to wait for the inclusion of TPU support in PyTorch to convert these pre-training scripts.
## Comparing the PyTorch model and the TensorFlow model predictions
We also include [two Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
- The first notebook ([Comparing TF and PT models.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models.ipynb)) extracts the hidden states of a full sequence on each layer of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden states of the models.
- The second notebook ([Comparing TF and PT models SQuAD predictions.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models%20SQuAD%20predictions.ipynb)) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of `BertForQuestionAnswering` and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models.
Please follow the instructions given in the notebooks to run and modify them. They are also a nice example of how to use the models in a simpler way than the full fine-tuning scripts we provide.
## Fine-tuning with BERT: running the examples
We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD.
Before running these examples you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base`
checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section.
This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase
Corpus (MRPC) and runs in less than 10 minutes on a single K-80.
```shell
export GLUE_DIR=/path/to/glue
python run_classifier.py \
--task_name MRPC \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $GLUE_DIR/MRPC/ \
--vocab_file $BERT_BASE_DIR/vocab.txt \
--bert_config_file $BERT_BASE_DIR/bert_config.json \
--init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \
--max_seq_length 128 \
--train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir /tmp/mrpc_output/
```
Our tests ran on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) and gave evaluation results between 84% and 88%.
The second example fine-tunes `BERT-Base` on the SQuAD question answering task.
The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
```shell
export SQUAD_DIR=/path/to/SQUAD
python run_squad.py \
--vocab_file $BERT_BASE_DIR/vocab.txt \
--bert_config_file $BERT_BASE_DIR/bert_config.json \
--init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \
--do_train \
--do_predict \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2.0 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir ../debug_squad/
```
Training with the previous hyper-parameters gave us the following results:
```bash
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
```
# Fine-tuning BERT-large on GPUs
The options we list above allow you to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.
For example, fine-tuning BERT-large on SQuAD can be done on a server with 4 k-80 (these are pretty old now) in 18 hours. Our results are similar to the TensorFlow implementation results (actually slightly higher):
```bash
{"exact_match": 84.56953642384106, "f1": 91.04028647786927}
```
To get these results we used a combination of:
- multi-GPU training (automatically activated on a multi-GPU server),
- 2 steps of gradient accumulation and
- performing the optimization step on CPU to store Adam's averages in RAM.
Here is the full list of hyper-parameters we used for this run:
```bash
python ./run_squad.py --vocab_file $BERT_LARGE_DIR/vocab.txt --bert_config_file $BERT_LARGE_DIR/bert_config.json --init_checkpoint $BERT_LARGE_DIR/pytorch_model.bin --do_lower_case --do_train --do_predict --train_file $SQUAD_TRAIN --predict_file $SQUAD_EVAL --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --output_dir $OUTPUT_DIR/bert_large_bsz_24 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu
```

View File

@ -0,0 +1,105 @@
# coding=utf-8
# Copyright 2018 The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import argparse
import tensorflow as tf
import torch
import numpy as np
from modeling import BertConfig, BertModel
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
def convert():
# Initialise PyTorch model
config = BertConfig.from_json_file(args.bert_config_file)
model = BertModel(config)
# Load weights from TF model
path = args.tf_checkpoint_path
print("Converting TensorFlow checkpoint from {}".format(path))
init_vars = tf.train.list_variables(path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading {} with shape {}".format(name, shape))
array = tf.train.load_variable(path, name)
print("Numpy array shape {}".format(array.shape))
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name[5:] # skip "bert/"
print("Loading {}".format(name))
name = name.split('/')
# After the 5-character strip above, 'cls/predictions' and 'cls/seq_relationship'
# appear as 'redictions' and 'eq_relationship'; these pre-training heads are not
# part of BertModel, so skip them.
if name[0] in ['redictions', 'eq_relationship']:
print("Skipping")
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'kernel':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
pointer.data = torch.from_numpy(array)
# Save pytorch-model
torch.save(model.state_dict(), args.pytorch_dump_path)
if __name__ == "__main__":
convert()

View File

@ -0,0 +1,13 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 2,
"vocab_size": 30522
}

View File

@ -0,0 +1,668 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import math
import six
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
def print_status(self):
"""
Wonseok added this.
"""
print( f"vocab size: {self.vocab_size}")
print( f"hidden_size: {self.hidden_size}")
print( f"num_hidden_layer: {self.num_hidden_layers}")
print( f"num_attention_heads: {self.num_attention_heads}")
print( f"hidden_act: {self.hidden_act}")
print( f"intermediate_size: {self.intermediate_size}")
print( f"hidden_dropout_prob: {self.hidden_dropout_prob}")
print( f"attention_probs_dropout_prob: {self.attention_probs_dropout_prob}")
print( f"max_position_embeddings: {self.max_position_embeddings}")
print( f"type_vocab_size: {self.type_vocab_size}")
print( f"initializer_range: {self.initializer_range}")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class BERTLayerNorm(nn.Module):
def __init__(self, config, variance_epsilon=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BERTLayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.ones(config.hidden_size))
self.beta = nn.Parameter(torch.zeros(config.hidden_size))
self.variance_epsilon = variance_epsilon
def forward(self, x):
# x = input to the layer.
# Normalize each vector (each token) over the hidden dimension.
# If x follows a Gaussian distribution, it becomes standard normal (i.e., mu=0, std=1).
u = x.mean(-1, keepdim=True) # keepdim preserves the reduced dimension for broadcasting
s = (x - u).pow(2).mean(-1, keepdim=True) # variance
x = (x - u) / torch.sqrt(s + self.variance_epsilon) # standardize
# Gamma & Beta is trainable parameters.
return self.gamma * x + self.beta
class BERTEmbeddings(nn.Module):
def __init__(self, config):
super(BERTEmbeddings, self).__init__()
"""Construct the embedding module from word, position and token_type embeddings.
"""
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BERTLayerNorm(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BERTSelfAttention(nn.Module):
def __init__(self, config):
super(BERTSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
""" From this,
"""
# Maybe: x = [B, seq_len, all_head_size=hidden_size ]
# [B, seq_len] + [num_attention_heads, attention_head_size] ???
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # it is append (for torch.Size tensor, + acts as an concat. operation)
# [B, seq_len, all_head_size=hidden_size ] -> [B, seq_len, num_attention_heads, attention_head_size]
x = x.view(*new_x_shape)
# [B, seq_len, num_attention_heads, attention_head_size] -> [B, num_attention_heads, seq_len, attention_head_size]
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
# [B, num_attention_heads, seq_len, attention_head_size] * [B, num_attention_heads, attention_head_size, seq_len]
# -> [B, num_attention_heads, seq_len, seq_len]
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in the BertModel forward() function).
attention_scores = attention_scores + attention_mask # additive mask: masked positions are ~ -10000, so they vanish after softmax
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# [B, num_attention_heads, seq_len, seq_len] * [B, num_attention_heads, seq_len, attention_head_size]
# -> [B, num_attention_heads, seq_len, attention_head_size]
context_layer = torch.matmul(attention_probs, value_layer)
# -> [B, seq_len, num_attention_heads, attention_head_size]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# [B, seq_len] + [all_head_size=hidden_size] -> [B, seq_len, all_head_size]
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class BERTSelfOutput(nn.Module):
def __init__(self, config):
super(BERTSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BERTLayerNorm(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BERTAttention(nn.Module):
def __init__(self, config):
super(BERTAttention, self).__init__()
self.self = BERTSelfAttention(config)
self.output = BERTSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class BERTIntermediate(nn.Module):
def __init__(self, config):
super(BERTIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = gelu
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BERTOutput(nn.Module):
def __init__(self, config):
super(BERTOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BERTLayerNorm(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BERTLayer(nn.Module):
def __init__(self, config):
super(BERTLayer, self).__init__()
self.attention = BERTAttention(config)
self.intermediate = BERTIntermediate(config)
self.output = BERTOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BERTEncoder(nn.Module):
def __init__(self, config):
super(BERTEncoder, self).__init__()
layer = BERTLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask):
all_encoder_layers = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BERTPooler(nn.Module):
def __init__(self, config):
super(BERTPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertModel(nn.Module):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
model = modeling.BertModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config: BertConfig):
"""Constructor for BertModel.
Args:
config: `BertConfig` instance.
"""
super(BertModel, self).__init__()
self.embeddings = BERTEmbeddings(config)
self.encoder = BERTEncoder(config)
self.pooler = BERTPooler(config)
def forward(self, input_ids, token_type_ids=None, attention_mask=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.float()
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
sequence_output = all_encoder_layers[-1]
pooled_output = self.pooler(sequence_output)
return all_encoder_layers, pooled_output
class BertForSequenceClassification(nn.Module):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
num_labels = 2
model = BertForSequenceClassification(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels):
super(BertForSequenceClassification, self).__init__()
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
def init_weights(module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BERTLayerNorm):
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
if isinstance(module, nn.Linear):
module.bias.data.zero_()
self.apply(init_weights)
def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits, labels)
return loss, logits
else:
return logits
class BertForQuestionAnswering(nn.Module):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
model = BertForQuestionAnswering(config)
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForQuestionAnswering, self).__init__()
self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
def init_weights(module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BERTLayerNorm):
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
if isinstance(module, nn.Linear):
module.bias.data.zero_()
self.apply(init_weights)
def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output = all_encoder_layers[-1]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss
else:
return start_logits, end_logits
class BertNoAnswer(nn.Module):
def __init__(self,hidden_size,context_length=317):
super(BertNoAnswer,self).__init__()
self.context_length = context_length
self.W_no = nn.Linear(hidden_size,1)
self.no_answer = nn.Sequential(nn.Linear(hidden_size*3,hidden_size),
nn.ReLU(),
nn.Linear(hidden_size,2))
def forward(self,sequence_output,start_logit,end_logit,mask=None):
if mask is None:
nbatch,length,_ = sequence_output.size()
mask = torch.ones(nbatch,length)
mask = mask.float()
mask = mask.unsqueeze(-1)[:,1:self.context_length+1]
mask = (1.0 - mask) * -10000.0
sequence_output = sequence_output[:,1:self.context_length+1]
start_logit = start_logit[:,1:self.context_length+1] + mask
end_logit = end_logit[:,1:self.context_length+1] + mask
# No-answer option
pa_1 = nn.functional.softmax(start_logit.transpose(1,2),-1) # B,1,T
v1 = torch.bmm(pa_1,sequence_output).squeeze(1) # B,H
pa_2 = nn.functional.softmax(end_logit.transpose(1,2),-1) # B,1,T
v2 = torch.bmm(pa_2,sequence_output).squeeze(1) # B,H
pa_3 = self.W_no(sequence_output) + mask
pa_3 = nn.functional.softmax(pa_3.transpose(1,2),-1) # B,1,T
v3 = torch.bmm(pa_3,sequence_output).squeeze(1) # B,H
bias = self.no_answer(torch.cat([v1,v2,v3],-1)) # B,1
return bias
class BertForSQuAD2(nn.Module):
def __init__(self, config, context_length=317):
super(BertForSQuAD2, self).__init__()
self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 2)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.na_head = BertNoAnswer(config.hidden_size, context_length)
def init_weights(module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BERTLayerNorm):
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
if isinstance(module, nn.Linear):
module.bias.data.zero_()
self.apply(init_weights)
def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None, labels=None):
all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output = all_encoder_layers[-1]
span_logits = self.qa_outputs(sequence_output)
start_logits, end_logits = span_logits.split(1, dim=-1)
na_logits = self.na_head(sequence_output,start_logits,end_logits,attention_mask)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
logits = (na_logits+logits)/2 # mean
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
unanswerable_loss = loss_fct(logits, labels)
span_loss = (start_loss + end_loss) / 2
total_loss = span_loss + unanswerable_loss
return total_loss
else:
probs = nn.functional.softmax(logits,-1)
_, probs = probs.split(1, dim=-1)
return start_logits, end_logits, probs
class BertForWikiSQL(nn.Module):
def __init__(self, config, context_length=317):
super(BertForWikiSQL, self).__init__()
self.bert = BertModel(config)
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 2)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.na_head = BertNoAnswer(config.hidden_size, context_length)
def init_weights(module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BERTLayerNorm):
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
if isinstance(module, nn.Linear):
module.bias.data.zero_()
self.apply(init_weights)
def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None, labels=None):
all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output = all_encoder_layers[-1]
span_logits = self.qa_outputs(sequence_output)
start_logits, end_logits = span_logits.split(1, dim=-1)
na_logits = self.na_head(sequence_output, start_logits, end_logits, attention_mask)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
logits = (na_logits + logits) / 2 # mean
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
unanswerable_loss = loss_fct(logits, labels)
span_loss = (start_loss + end_loss) / 2
total_loss = span_loss + unanswerable_loss
return total_loss
else:
probs = nn.functional.softmax(logits, -1)
_, probs = probs.split(1, dim=-1)
return start_logits, end_logits, probs
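
For orientation, here is a minimal, self-contained sketch of the joint loss computed in the forward pass above: the span loss (start and end cross-entropy, averaged, with out-of-range gold positions clamped to an ignored index) is added to the unanswerable-classification loss. The tensor shapes and dummy targets below are illustrative assumptions, not values from the actual training pipeline.

```python
import torch
from torch.nn import CrossEntropyLoss

# Assumed toy shapes: batch of 2, sequence length of 8.
batch_size, seq_len = 2, 8
start_logits = torch.randn(batch_size, seq_len)
end_logits = torch.randn(batch_size, seq_len)
cls_logits = torch.randn(batch_size, 2)      # answerable vs. unanswerable

start_positions = torch.tensor([3, 9])       # 9 is deliberately out of range
end_positions = torch.tensor([5, 9])
labels = torch.tensor([0, 1])                # 1 = unanswerable

# Mirror the clamp_/ignore_index handling above: out-of-range gold
# positions are clamped to an index the loss function then ignores.
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
span_loss = (loss_fct(start_logits, start_positions) +
             loss_fct(end_logits, end_positions)) / 2
unanswerable_loss = loss_fct(cls_logits, labels)
total_loss = span_loss + unanswerable_loss
print(total_loss.item())
```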

File diff suppressed because it is too large


@ -0,0 +1,333 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import unicodedata
import six
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r", encoding="utf-8") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_tokens_to_ids(vocab, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
ids.append(vocab[token])
return ids
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_tokens_to_ids(self.vocab, tokens)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
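
As a quick sanity check on the tokenization pipeline above, the sketch below builds a throwaway vocabulary file and runs `FullTokenizer` on it; `BasicTokenizer` lower-cases, strips accents and splits punctuation, and `WordpieceTokenizer` then applies the greedy longest-match-first subword split described in its docstring. The module name `tokenization` and the tiny temporary vocabulary are assumptions for illustration only.

```python
import tempfile
import tokenization  # assumed import name for the file above

vocab_tokens = ["[UNK]", "un", "##aff", "##able", "runs", ",", "!"]
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
    f.write("\n".join(vocab_tokens))
    vocab_file = f.name

tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case=True)

tokens = tokenizer.tokenize("Unaffable runs!")
print(tokens)                                   # ['un', '##aff', '##able', 'runs', '!']
print(tokenizer.convert_tokens_to_ids(tokens))  # [1, 2, 3, 4, 6]
```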


@ -0,0 +1,287 @@
""" Decoder for the SQL generation problem."""
from collections import namedtuple
import copy
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import torch_utils
from .token_predictor import PredictionInput, PredictionInputWithSchema, PredictionStepInputWithSchema
from .beam_search import BeamSearch
import data_util.snippets as snippet_handler
from . import embedder
from data_util.vocabulary import EOS_TOK, UNK_TOK
import math
def flatten_distribution(distribution_map, probabilities):
""" Flattens a probability distribution given a map of "unique" values.
All values in distribution_map with the same value should get the sum
of the probabilities.
Arguments:
distribution_map (list of str): List of values to get the probability for.
probabilities (np.ndarray): Probabilities corresponding to the values in
distribution_map.
Returns:
list, np.ndarray of the same size where probabilities for duplicates
in distribution_map are given the sum of the probabilities in probabilities.
"""
assert len(distribution_map) == len(probabilities)
if len(distribution_map) != len(set(distribution_map)):
idx_first_dup = 0
seen_set = set()
for i, tok in enumerate(distribution_map):
if tok in seen_set:
idx_first_dup = i
break
seen_set.add(tok)
new_dist_map = distribution_map[:idx_first_dup] + list(
set(distribution_map) - set(distribution_map[:idx_first_dup]))
assert len(new_dist_map) == len(set(new_dist_map))
new_probs = np.array(
probabilities[:idx_first_dup] \
+ [0. for _ in range(len(set(distribution_map)) \
- idx_first_dup)])
assert len(new_probs) == len(new_dist_map)
for i, token_name in enumerate(
distribution_map[idx_first_dup:]):
if token_name not in new_dist_map:
new_dist_map.append(token_name)
new_index = new_dist_map.index(token_name)
new_probs[new_index] += probabilities[i +
idx_first_dup]
new_probs = new_probs.tolist()
else:
new_dist_map = distribution_map
new_probs = probabilities
assert len(new_dist_map) == len(new_probs)
return new_dist_map, new_probs
class SQLPrediction(namedtuple('SQLPrediction',
('predictions',
'sequence',
'probability',
'beam'))):
"""Contains prediction for a sequence."""
__slots__ = ()
def __str__(self):
return str(self.probability) + "\t" + " ".join(self.sequence)
class SequencePredictorWithSchema(torch.nn.Module):
""" Predicts a sequence.
Attributes:
lstms (list of dy.RNNBuilder): The RNN used.
token_predictor (TokenPredictor): Used to actually predict tokens.
"""
def __init__(self,
params,
input_size,
output_embedder,
column_name_token_embedder,
token_predictor):
super().__init__()
self.lstms = torch_utils.create_multilayer_lstm_params(params.decoder_num_layers, input_size, params.decoder_state_size, "LSTM-d")
self.token_predictor = token_predictor
self.output_embedder = output_embedder
self.column_name_token_embedder = column_name_token_embedder
self.start_token_embedding = torch_utils.add_params((params.output_embedding_size,), "y-0")
self.input_size = input_size
self.params = params
def _initialize_decoder_lstm(self, encoder_state, utterance_index):
decoder_lstm_states = []
for i, lstm in enumerate(self.lstms):
encoder_layer_num = 0
if len(encoder_state[0]) > 1:
encoder_layer_num = i
# check which one is h_0, which is c_0
c_0 = encoder_state[0][encoder_layer_num].view(1,-1)
h_0 = encoder_state[1][encoder_layer_num].view(1,-1)
decoder_lstm_states.append((h_0, c_0))
return decoder_lstm_states
def get_output_token_embedding(self, output_token, input_schema, snippets):
if self.params.use_snippets and snippet_handler.is_snippet(output_token):
output_token_embedding = embedder.bow_snippets(output_token, snippets, self.output_embedder, input_schema)
else:
if input_schema:
assert self.output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
if self.output_embedder.in_vocabulary(output_token):
output_token_embedding = self.output_embedder(output_token)
else:
output_token_embedding = input_schema.column_name_embedder(output_token, surface_form=True)
else:
output_token_embedding = self.output_embedder(output_token)
return output_token_embedding
def get_decoder_input(self, output_token_embedding, prediction):
if self.params.use_schema_attention and self.params.use_query_attention:
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector, prediction.query_attention_results.vector], dim=0)
elif self.params.use_schema_attention:
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector, prediction.schema_attention_results.vector], dim=0)
else:
decoder_input = torch.cat([output_token_embedding, prediction.utterance_attention_results.vector], dim=0)
return decoder_input
def forward(self,
final_encoder_state,
encoder_states,
schema_states,
max_generation_length,
utterance_index=None,
snippets=None,
gold_sequence=None,
input_sequence=None,
previous_queries=None,
previous_query_states=None,
input_schema=None,
dropout_amount=0.):
""" Generates a sequence. """
index = 0
context_vector_size = self.input_size - self.params.output_embedding_size
# Decoder states: just the initialized decoder.
# Current input to decoder: phi(start_token) ; zeros the size of the
# context vector
predictions = []
sequence = []
probability = 1.
decoder_states = self._initialize_decoder_lstm(final_encoder_state, utterance_index)
if self.start_token_embedding.is_cuda:
decoder_input = torch.cat([self.start_token_embedding, torch.cuda.FloatTensor(context_vector_size).fill_(0)], dim=0)
else:
decoder_input = torch.cat([self.start_token_embedding, torch.zeros(context_vector_size)], dim=0)
beam_search = BeamSearch(is_end_of_sequence=self.is_end_of_sequence, max_steps=max_generation_length, beam_size=10)
prediction_step_input = PredictionStepInputWithSchema(
index=index,
decoder_input=decoder_input,
decoder_states=decoder_states,
encoder_states=encoder_states,
schema_states=schema_states,
snippets=snippets,
gold_sequence=gold_sequence,
input_sequence=input_sequence,
previous_queries=previous_queries,
previous_query_states=previous_query_states,
input_schema=input_schema,
dropout_amount=dropout_amount,
predictions=predictions,
)
sequence, probability, final_state, beam = beam_search.search(start_state=prediction_step_input, step_function=self.beam_search_step_function, append_token_function=self.beam_search_append_token)
return SQLPrediction(final_state.predictions,
sequence,
probability,
beam)
def beam_search_step_function(self, prediction_step_input):
# get decoded token probabilities
prediction, tokens, token_probabilities, decoder_states = self.decode_next_token(prediction_step_input)
return tokens, token_probabilities, (prediction, decoder_states)
def beam_search_append_token(self, prediction_step_input, step_function_output, token_to_append, token_log_probability):
output_token_embedding = self.get_output_token_embedding(token_to_append, prediction_step_input.input_schema, prediction_step_input.snippets)
decoder_input = self.get_decoder_input(output_token_embedding, step_function_output[0])
predictions = prediction_step_input.predictions.copy()
predictions.append(step_function_output[0])
new_state = prediction_step_input._replace(
index=prediction_step_input.index+1,
decoder_input=decoder_input,
predictions=predictions,
decoder_states=step_function_output[1],
)
return new_state
def is_end_of_sequence(self, prediction_step_input, max_generation_length, sequence):
if prediction_step_input.gold_sequence:
return prediction_step_input.index >= len(prediction_step_input.gold_sequence)
else:
return prediction_step_input.index >= max_generation_length or (len(sequence) > 0 and sequence[-1] == EOS_TOK)
def decode_next_token(self, prediction_step_input):
(index,
decoder_input,
decoder_states,
encoder_states,
schema_states,
snippets,
gold_sequence,
input_sequence,
previous_queries,
previous_query_states,
input_schema,
dropout_amount,
predictions) = prediction_step_input
_, decoder_state, decoder_states = torch_utils.forward_one_multilayer(self.lstms, decoder_input, decoder_states, dropout_amount)
prediction_input = PredictionInputWithSchema(decoder_state=decoder_state,
input_hidden_states=encoder_states,
schema_states=schema_states,
snippets=snippets,
input_sequence=input_sequence,
previous_queries=previous_queries,
previous_query_states=previous_query_states,
input_schema=input_schema)
prediction = self.token_predictor(prediction_input, dropout_amount=dropout_amount)
#gold_sequence = None
if gold_sequence:
decoded_token = gold_sequence[index]
distribution_map = [decoded_token]
probabilities = [1.0]
else:
assert prediction.scores.dim() == 1
probabilities = F.softmax(prediction.scores, dim=0).cpu().data.numpy().tolist()
distribution_map = prediction.aligned_tokens
assert len(probabilities) == len(distribution_map)
if self.params.use_previous_query and self.params.use_copy_switch and len(previous_queries) > 0:
assert prediction.query_scores.dim() == 1
query_token_probabilities = F.softmax(prediction.query_scores, dim=0).cpu().data.numpy().tolist()
query_token_distribution_map = prediction.query_tokens
assert len(query_token_probabilities) == len(query_token_distribution_map)
copy_switch = prediction.copy_switch.cpu().data.numpy()
# Merge the two
probabilities = ((np.array(probabilities) * (1 - copy_switch)).tolist() +
(np.array(query_token_probabilities) * copy_switch).tolist()
)
distribution_map = distribution_map + query_token_distribution_map
assert len(probabilities) == len(distribution_map)
# Get a new probabilities and distribution_map consolidating duplicates
distribution_map, probabilities = flatten_distribution(distribution_map, probabilities)
# Modify the probability distribution so that the UNK token can never be produced
probabilities[distribution_map.index(UNK_TOK)] = 0.
return prediction, distribution_map, probabilities, decoder_states
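
The copy-switch branch at the end of `decode_next_token` concatenates the vocabulary distribution with tokens copied from the previous query, which can introduce duplicates; `flatten_distribution` then merges them by summing their probabilities. A minimal sketch of that behaviour, assuming the file above is importable as `decoder` (hypothetical module name):

```python
from decoder import flatten_distribution  # assumed import path

# Duplicates arise when the copy-switch appends previous-query tokens
# to the regular output vocabulary.
tokens = ["select", "from", "t1.name", "select"]
probs = [0.5, 0.2, 0.05, 0.25]

merged_tokens, merged_probs = flatten_distribution(tokens, probs)
print(merged_tokens)  # ['select', 'from', 't1.name']
print(merged_probs)   # [0.75, 0.2, 0.05]
```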


@ -0,0 +1,100 @@
""" Embedder for tokens. """
import torch
import torch.nn.functional as F
import data_util.snippets as snippet_handler
import data_util.vocabulary as vocabulary_handler
class Embedder(torch.nn.Module):
""" Embeds tokens. """
def __init__(self, embedding_size, name="", initializer=None, vocabulary=None, num_tokens=-1, anonymizer=None, freeze=False, use_unk=True):
super().__init__()
if vocabulary:
assert num_tokens < 0, "Specified a vocabulary but also set number of tokens to " + \
str(num_tokens)
self.in_vocabulary = lambda token: token in vocabulary.tokens
self.vocab_token_lookup = lambda token: vocabulary.token_to_id(token)
if use_unk:
self.unknown_token_id = vocabulary.token_to_id(vocabulary_handler.UNK_TOK)
else:
self.unknown_token_id = -1
self.vocabulary_size = len(vocabulary)
else:
def check_vocab(index):
""" Makes sure the index is in the vocabulary."""
assert index < num_tokens, "Passed token ID " + \
str(index) + "; expecting something less than " + str(num_tokens)
return index < num_tokens
self.in_vocabulary = check_vocab
self.vocab_token_lookup = lambda x: x
self.unknown_token_id = num_tokens # Deliberately throws an error here,
# But should crash before this
self.vocabulary_size = num_tokens
self.anonymizer = anonymizer
emb_name = name + "-tokens"
#print("Creating token embedder called " + emb_name + " of size " + str(self.vocabulary_size) + " x " + str(embedding_size))
if initializer is not None:
word_embeddings_tensor = torch.FloatTensor(initializer)
self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(word_embeddings_tensor, freeze=freeze)
else:
init_tensor = torch.empty(self.vocabulary_size, embedding_size).uniform_(-0.1, 0.1)
self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False)
if self.anonymizer:
emb_name = name + "-entities"
entity_size = len(self.anonymizer.entity_types)
#print("Creating entity embedder called " + emb_name + " of size " + str(entity_size) + " x " + str(embedding_size))
init_tensor = torch.empty(entity_size, embedding_size).uniform_(-0.1, 0.1)
self.entity_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False)
def forward(self, token):
assert isinstance(token, int) or not snippet_handler.is_snippet(token), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets"
if self.in_vocabulary(token):
index_list = torch.LongTensor([self.vocab_token_lookup(token)])
if self.token_embedding_matrix.weight.is_cuda:
index_list = index_list.cuda()
return self.token_embedding_matrix(index_list).squeeze()
elif self.anonymizer and self.anonymizer.is_anon_tok(token):
index_list = torch.LongTensor([self.anonymizer.get_anon_id(token)])
if self.token_embedding_matrix.weight.is_cuda:
index_list = index_list.cuda()
return self.entity_embedding_matrix(index_list).squeeze()
else:
index_list = torch.LongTensor([self.unknown_token_id])
if self.token_embedding_matrix.weight.is_cuda:
index_list = index_list.cuda()
return self.token_embedding_matrix(index_list).squeeze()
def bow_snippets(token, snippets, output_embedder, input_schema):
""" Bag of words embedding for snippets"""
assert snippet_handler.is_snippet(token) and snippets
snippet_sequence = []
for snippet in snippets:
if snippet.name == token:
snippet_sequence = snippet.sequence
break
assert snippet_sequence
if input_schema:
snippet_embeddings = []
for output_token in snippet_sequence:
assert output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True)
if output_embedder.in_vocabulary(output_token):
snippet_embeddings.append(output_embedder(output_token))
else:
snippet_embeddings.append(input_schema.column_name_embedder(output_token, surface_form=True))
else:
snippet_embeddings = [output_embedder(subtoken) for subtoken in snippet_sequence]
snippet_embeddings = torch.stack(snippet_embeddings, dim=0) # len(snippet_sequence) x emb_size
return torch.mean(snippet_embeddings, dim=0) # emb_size
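
A minimal usage sketch of the index-based mode of `Embedder` (no vocabulary object, tokens given as integer ids), assuming the file above is importable as `embedder` and that its `data_util` dependencies are on the path; the embedding matrix is randomly initialised in `[-0.1, 0.1]` exactly as in `__init__` above.

```python
from embedder import Embedder  # assumed import path

# Index-based mode: token ids 0..num_tokens-1, no vocabulary object.
emb = Embedder(embedding_size=300, name="demo", num_tokens=50)

vec = emb(7)        # looks up the embedding row for token id 7
print(vec.shape)    # torch.Size([300])
```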

Some files were not shown because too many files have changed in this diff