add: proton

出蛰 2022-08-06 19:43:17 +08:00
parent 890ad4ee51
commit e42cda6d18
71 changed files with 12610 additions and 0 deletions

proton/README.md (new file)

@ -0,0 +1,67 @@
# Proton: Probing Schema Linking Information from Pre-trained Language Models for Text-to-SQL Parsing
This repository contains code for the KDD 2022 paper [Proton: Probing Schema Linking Information from Pre-trained Language Models for Text-to-SQL Parsing](https://arxiv.org/abs/2206.14017).
If you use Proton in your work, please cite it as follows:
```
@inproceedings{wang2022proton,
title={Proton: Probing Schema Linking Information from Pre-trained Language Models for Text-to-SQL Parsing},
author={Wang, Lihan and Qin, Bowen and Hui, Binyuan and Li, Bowen and Yang, Min and Wang, Bailin and Li, Binhua and Huang, Fei and Si, Luo and Li, Yongbin},
booktitle={KDD},
year={2022}
}
```
## Overview
In this work, we propose a novel framework, called **Proton**, which first probes the underlying relational schema-linking structures between an NL query and its database schema from a pre-trained language model, and then effectively injects them into downstream text-to-SQL parsing models.
![Overview](static/proton.jpg)
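As a rough intuition for the kind of structure being probed, the toy snippet below scores each schema item against a question by cosine similarity of mean-pooled contextual embeddings from a pre-trained language model. This is only a simplified illustration of question-schema relevance scoring, **not** the probing procedure used in Proton (see the paper and this repository for the actual method); the model name, pooling strategy, and similarity function are illustrative assumptions.
```python
import torch
from transformers import AutoTokenizer, AutoModel

def embed(texts, tokenizer, model):
    """Mean-pool the last hidden states over non-padding tokens."""
    enc = tokenizer(texts, padding=True, return_tensors='pt')
    with torch.no_grad():
        hidden = model(**enc).last_hidden_state          # (batch, seq_len, dim)
    mask = enc['attention_mask'].unsqueeze(-1).float()   # zero out padding positions
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)

tokenizer = AutoTokenizer.from_pretrained('google/electra-large-discriminator')
model = AutoModel.from_pretrained('google/electra-large-discriminator')

question = 'How many singers do we have?'
schema_items = ['singer', 'singer . name', 'concert', 'stadium']

q_vec = embed([question], tokenizer, model)              # (1, dim)
s_vecs = embed(schema_items, tokenizer, model)           # (num_items, dim)
scores = torch.nn.functional.cosine_similarity(q_vec, s_vecs)
for item, score in zip(schema_items, scores.tolist()):
    print(f'{item:15s} {score:.3f}')
```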
## Codebase
### Prepare Environment
The environment setup is exactly the same as that of [LGESQL](https://github.com/rhythmcao/text2sql-lgesql); the required commands are provided in `setup.sh`:
```bash
sh setup.sh
```
### Download the dataset
Download and unzip [spider.zip](https://drive.google.com/uc?export=download&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0), and rename the extracted folder to `data`. Then merge `data/train_spider.json` and `data/train_others.json` into a single file, `data/train.json`.
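For reference, a minimal way to perform the merge (a sketch assuming both JSON files are lists of examples under `data/`, as described above):
```python
import json

with open('data/train_spider.json') as f:
    train_spider = json.load(f)
with open('data/train_others.json') as f:
    train_others = json.load(f)

# both files are JSON lists of examples, so concatenation is enough
with open('data/train.json', 'w') as f:
    json.dump(train_spider + train_others, f, indent=2)
```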
### Preprocess the dataset
Preprocess the train and dev datasets, including input normalization, schema linking, graph construction, and output action generation.
```bash
./run/run_preprocessing.sh
```
### Training
Train Proton with:
```bash
./run/run_lgesql_plm.sh msde electra-large-discriminator
```
### Evaluation
For evaluation, see `run/run_evaluation.sh` and `run/run_submission.sh` (evaluation from scratch) for reference.
## Results
| Model           | Spider-DK | Spider-SYN | Spider |
| --------------- | --------- | ---------- | ------ |
| LGESQL | 48.4 | 64.6 | 75.3 |
| LGESQL + Proton | 51.0 | 65.6 | 76.3 |
## Acknowledgements
This implementation is based on [RAT-SQL: Relation-Aware Schema Encoding and Linking for Text-to-SQL Parsers](https://github.com/microsoft/rat-sql) and [LGESQL: Line Graph Enhanced Text-to-SQL Model with Mixed Local and Non-Local Relations](https://github.com/rhythmcao/text2sql-lgesql). Thanks to the authors for releasing their code.


@ -0,0 +1,66 @@
# coding=utf-8
from asdl.hypothesis import Hypothesis
from asdl.transition_system import ApplyRuleAction, GenTokenAction
from asdl.sql.sql_transition_system import SelectColumnAction, SelectTableAction
class ActionInfo(object):
"""sufficient statistics for making a prediction of an action at a time step"""
def __init__(self, action=None):
self.t = 0
self.parent_t = -1
self.action = action
self.frontier_prod = None
self.frontier_field = None
# for GenToken actions only
self.copy_from_src = False
self.src_token_position = -1
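# NOTE: the verbose statistics referenced in __repr__ below (action_prob, in_vocab, gen_copy_switch,
# gen_token_prob, copy_token_prob) are assumed to be attached to this object by the decoder at
# prediction time; they are not initialized here.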
def __repr__(self, verbose=False):
repr_str = '%s (t=%d, p_t=%d, frontier_field=%s)' % (repr(self.action),
self.t,
self.parent_t,
self.frontier_field.__repr__(True) if self.frontier_field else 'None')
if verbose:
verbose_repr = 'action_prob=%.4f, ' % self.action_prob
if isinstance(self.action, GenTokenAction):
verbose_repr += 'in_vocab=%s, ' \
'gen_copy_switch=%s, ' \
'p(gen)=%s, p(copy)=%s, ' \
'has_copy=%s, copy_pos=%s' % (self.in_vocab,
self.gen_copy_switch,
self.gen_token_prob, self.copy_token_prob,
self.copy_from_src, self.src_token_position)
repr_str += '\n' + verbose_repr
return repr_str
def get_action_infos(src_query: list = None, tgt_actions: list = [], force_copy=False):
action_infos = []
hyp = Hypothesis()
for t, action in enumerate(tgt_actions):
action_info = ActionInfo(action)
action_info.t = t
if hyp.frontier_node:
action_info.parent_t = hyp.frontier_node.created_time
action_info.frontier_prod = hyp.frontier_node.production
action_info.frontier_field = hyp.frontier_field.field
if isinstance(action, SelectColumnAction) or isinstance(action, SelectTableAction):
pass
elif isinstance(action, GenTokenAction): # GenToken
try:
tok_src_idx = src_query.index(str(action.token))
action_info.copy_from_src = True
action_info.src_token_position = tok_src_idx
except ValueError:
if force_copy: raise ValueError('cannot copy primitive token %s from source' % action.token)
hyp.apply_action(action)
action_infos.append(action_info)
return action_infos
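# Usage sketch (hypothetical paths and field names, for illustration only):
#   from asdl.asdl import ASDLGrammar
#   from asdl.sql.sql_transition_system import SQLTransitionSystem
#   grammar = ASDLGrammar.from_filepath('asdl/sql/grammar/sql_asdl_v2.txt')
#   trans = SQLTransitionSystem(grammar)
#   sql_ast = trans.surface_code_to_ast(example['sql'])      # 'sql' field of a Spider example
#   actions = trans.get_actions(sql_ast)                     # oracle action sequence
#   action_infos = get_action_infos(example['question_toks'], actions)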

proton/asdl/asdl.py (new file)

@ -0,0 +1,323 @@
# coding=utf-8
from collections import OrderedDict, Counter
from itertools import chain
import re, os
def remove_comment(text):
text = re.sub(re.compile("#.*"), "", text)
text = '\n'.join(filter(lambda x: x, text.split('\n')))
return text
class ASDLGrammar(object):
"""
Collection of types, constructors and productions
"""
def __init__(self, productions, file_path):
# productions are indexed by their head types
file_name = os.path.basename(file_path)
grammar_name = file_name[:file_name.index('.txt')] if '.txt' in file_name else file_name
self._grammar_name = grammar_name
self._productions = OrderedDict()
self._constructor_production_map = dict()
for prod in productions:
if prod.type not in self._productions:
self._productions[prod.type] = list()
self._productions[prod.type].append(prod)
self._constructor_production_map[prod.constructor.name] = prod
self.root_type = productions[0].type
# number of constructors
self.size = sum(len(head) for head in self._productions.values())
# get entities to their ids map
self.prod2id = {prod: i for i, prod in enumerate(self.productions)}
self.type2id = {type: i for i, type in enumerate(self.types)}
self.field2id = {field: i for i, field in enumerate(self.fields)}
self.id2prod = {i: prod for i, prod in enumerate(self.productions)}
self.id2type = {i: type for i, type in enumerate(self.types)}
self.id2field = {i: field for i, field in enumerate(self.fields)}
def __len__(self):
return self.size
@property
def productions(self):
return sorted(chain.from_iterable(self._productions.values()), key=lambda x: repr(x))
def __getitem__(self, datum):
if isinstance(datum, str):
return self._productions[ASDLType(datum)]
elif isinstance(datum, ASDLType):
return self._productions[datum]
def get_prod_by_ctr_name(self, name):
return self._constructor_production_map[name]
@property
def types(self):
if not hasattr(self, '_types'):
all_types = set()
for prod in self.productions:
all_types.add(prod.type)
all_types.update(map(lambda x: x.type, prod.constructor.fields))
self._types = sorted(all_types, key=lambda x: x.name)
return self._types
@property
def fields(self):
if not hasattr(self, '_fields'):
all_fields = set()
for prod in self.productions:
all_fields.update(prod.constructor.fields)
self._fields = sorted(all_fields, key=lambda x: (x.name, x.type.name, x.cardinality))
return self._fields
@property
def primitive_types(self):
return filter(lambda x: isinstance(x, ASDLPrimitiveType), self.types)
@property
def composite_types(self):
return filter(lambda x: isinstance(x, ASDLCompositeType), self.types)
def is_composite_type(self, asdl_type):
return asdl_type in self.composite_types
def is_primitive_type(self, asdl_type):
return asdl_type in self.primitive_types
@staticmethod
def from_filepath(file_path):
def _parse_field_from_text(_text):
d = _text.strip().split(' ')
name = d[1].strip()
type_str = d[0].strip()
cardinality = 'single'
if type_str[-1] == '*':
type_str = type_str[:-1]
cardinality = 'multiple'
elif type_str[-1] == '?':
type_str = type_str[:-1]
cardinality = 'optional'
if type_str in primitive_type_names:
return Field(name, ASDLPrimitiveType(type_str), cardinality=cardinality)
else:
return Field(name, ASDLCompositeType(type_str), cardinality=cardinality)
def _parse_constructor_from_text(_text):
_text = _text.strip()
fields = None
if '(' in _text:
name = _text[:_text.find('(')]
field_blocks = _text[_text.find('(') + 1:_text.find(')')].split(',')
fields = map(_parse_field_from_text, field_blocks)
else:
name = _text
if name == '': name = None
return ASDLConstructor(name, fields)
with open(file_path, 'r') as inf:
text = inf.read()
lines = remove_comment(text).split('\n')
lines = list(map(lambda l: l.strip(), lines))
lines = list(filter(lambda l: l, lines))
line_no = 0
# first line is always the primitive types
primitive_type_names = list(map(lambda x: x.strip(), lines[line_no].split(',')))
line_no += 1
all_productions = list()
while True:
type_block = lines[line_no]
type_name = type_block[:type_block.find('=')].strip()
constructors_blocks = type_block[type_block.find('=') + 1:].split('|')
i = line_no + 1
while i < len(lines) and lines[i].strip().startswith('|'):
t = lines[i].strip()
cont_constructors_blocks = t[1:].split('|')
constructors_blocks.extend(cont_constructors_blocks)
i += 1
constructors_blocks = filter(lambda x: x and x.strip(), constructors_blocks)
# parse type name
new_type = ASDLPrimitiveType(type_name) if type_name in primitive_type_names else ASDLCompositeType(type_name)
constructors = map(_parse_constructor_from_text, constructors_blocks)
productions = list(map(lambda c: ASDLProduction(new_type, c), constructors))
all_productions.extend(productions)
line_no = i
if line_no == len(lines):
break
grammar = ASDLGrammar(all_productions, file_path)
return grammar
class ASDLProduction(object):
def __init__(self, type, constructor):
self.type = type
self.constructor = constructor
@property
def fields(self):
return self.constructor.fields
def __getitem__(self, field_name):
return self.constructor[field_name]
def __hash__(self):
h = hash(self.type) ^ hash(self.constructor)
return h
def __eq__(self, other):
return isinstance(other, ASDLProduction) and \
self.type == other.type and \
self.constructor == other.constructor
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return '%s -> %s' % (self.type.__repr__(plain=True), self.constructor.__repr__(plain=True))
class ASDLConstructor(object):
def __init__(self, name, fields=None):
self.name = name
self.fields = []
if fields:
self.fields = list(fields)
def __getitem__(self, field_name):
for field in self.fields:
if field.name == field_name: return field
raise KeyError
def __hash__(self):
h = hash(self.name)
for field in self.fields:
h ^= hash(field)
return h
def __eq__(self, other):
return isinstance(other, ASDLConstructor) and \
self.name == other.name and \
self.fields == other.fields
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self, plain=False):
plain_repr = '%s(%s)' % (self.name,
', '.join(f.__repr__(plain=True) for f in self.fields))
if plain: return plain_repr
else: return 'Constructor(%s)' % plain_repr
class Field(object):
def __init__(self, name, type, cardinality):
self.name = name
self.type = type
assert cardinality in ['single', 'optional', 'multiple']
self.cardinality = cardinality
def __hash__(self):
h = hash(self.name) ^ hash(self.type)
h ^= hash(self.cardinality)
return h
def __eq__(self, other):
return isinstance(other, Field) and \
self.name == other.name and \
self.type == other.type and \
self.cardinality == other.cardinality
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self, plain=False):
plain_repr = '%s%s %s' % (self.type.__repr__(plain=True),
Field.get_cardinality_repr(self.cardinality),
self.name)
if plain: return plain_repr
else: return 'Field(%s)' % plain_repr
@staticmethod
def get_cardinality_repr(cardinality):
return '' if cardinality == 'single' else '?' if cardinality == 'optional' else '*'
class ASDLType(object):
def __init__(self, type_name):
self.name = type_name
def __hash__(self):
return hash(self.name)
def __eq__(self, other):
return isinstance(other, ASDLType) and self.name == other.name
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self, plain=False):
plain_repr = self.name
if plain: return plain_repr
else: return '%s(%s)' % (self.__class__.__name__, plain_repr)
class ASDLCompositeType(ASDLType):
pass
class ASDLPrimitiveType(ASDLType):
pass
if __name__ == '__main__':
asdl_desc = """
var, ent, num, var_type
expr = Variable(var variable)
| Entity(ent entity)
| Number(num number)
| Apply(pred predicate, expr* arguments)
| Argmax(var variable, expr domain, expr body)
| Argmin(var variable, expr domain, expr body)
| Count(var variable, expr body)
| Exists(var variable, expr body)
| Lambda(var variable, var_type type, expr body)
| Max(var variable, expr body)
| Min(var variable, expr body)
| Sum(var variable, expr domain, expr body)
| The(var variable, expr body)
| Not(expr argument)
| And(expr* arguments)
| Or(expr* arguments)
| Compare(cmp_op op, expr left, expr right)
cmp_op = GreaterThan | Equal | LessThan
"""
import tempfile  # ASDLGrammar only provides from_filepath, so dump the inline description to a temp file
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp_f: tmp_f.write(asdl_desc)
grammar = ASDLGrammar.from_filepath(tmp_f.name)
print(ASDLCompositeType('1') == ASDLPrimitiveType('1'))

proton/asdl/asdl_ast.py (new file)

@ -0,0 +1,200 @@
# coding=utf-8
from io import StringIO
from asdl.asdl import *
class AbstractSyntaxTree(object):
def __init__(self, production, realized_fields=None):
self.production = production
# a child is essentially a *realized_field*
self.fields = []
# record its parent field to which it's attached
self.parent_field = None
# used in decoding, record the time step when this node was created
self.created_time = 0
if realized_fields:
assert len(realized_fields) == len(self.production.fields)
for field in realized_fields:
self.add_child(field)
else:
for field in self.production.fields:
self.add_child(RealizedField(field))
def add_child(self, realized_field):
# if isinstance(realized_field.value, AbstractSyntaxTree):
# realized_field.value.parent = self
self.fields.append(realized_field)
realized_field.parent_node = self
def __getitem__(self, field_name):
for field in self.fields:
if field.name == field_name: return field
raise KeyError
def sanity_check(self):
if len(self.production.fields) != len(self.fields):
raise ValueError('field number must match')
for field, realized_field in zip(self.production.fields, self.fields):
assert field == realized_field.field
for child in self.fields:
for child_val in child.as_value_list:
if isinstance(child_val, AbstractSyntaxTree):
child_val.sanity_check()
def copy(self):
new_tree = AbstractSyntaxTree(self.production)
new_tree.created_time = self.created_time
for i, old_field in enumerate(self.fields):
new_field = new_tree.fields[i]
new_field._not_single_cardinality_finished = old_field._not_single_cardinality_finished
if isinstance(old_field.type, ASDLCompositeType):
for value in old_field.as_value_list:
new_field.add_value(value.copy())
else:
for value in old_field.as_value_list:
new_field.add_value(value)
return new_tree
def to_string(self, sb=None):
is_root = False
if sb is None:
is_root = True
sb = StringIO()
sb.write('(')
sb.write(self.production.constructor.name)
for field in self.fields:
sb.write(' ')
sb.write('(')
sb.write(field.type.name)
sb.write(Field.get_cardinality_repr(field.cardinality))
sb.write('-')
sb.write(field.name)
if field.value is not None:
for val_node in field.as_value_list:
sb.write(' ')
if isinstance(field.type, ASDLCompositeType):
val_node.to_string(sb)
else:
sb.write(str(val_node).replace(' ', '-SPACE-'))
sb.write(')') # of field
sb.write(')') # of node
if is_root:
return sb.getvalue()
def __hash__(self):
code = hash(self.production)
for field in self.fields:
code = code + 37 * hash(field)
return code
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
if self.created_time != other.created_time:
return False
if self.production != other.production:
return False
if len(self.fields) != len(other.fields):
return False
for i in range(len(self.fields)):
if self.fields[i] != other.fields[i]: return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return repr(self.production)
@property
def size(self):
node_num = 1
for field in self.fields:
for val in field.as_value_list:
if isinstance(val, AbstractSyntaxTree):
node_num += val.size
else: node_num += 1
return node_num
class RealizedField(Field):
"""wrapper of field realized with values"""
def __init__(self, field, value=None, parent=None):
super(RealizedField, self).__init__(field.name, field.type, field.cardinality)
# record its parent AST node
self.parent_node = None
# FIXME: hack, return the field as a property
self.field = field
# initialize value to correct type
if self.cardinality == 'multiple':
self.value = []
if value is not None:
for child_node in value:
self.add_value(child_node)
else:
self.value = None
# note the value could be 0!
if value is not None: self.add_value(value)
# properties only used in decoding, record if the field is finished generating
# when card in [optional, multiple]
self._not_single_cardinality_finished = False
def add_value(self, value):
if isinstance(value, AbstractSyntaxTree):
value.parent_field = self
if self.cardinality == 'multiple':
self.value.append(value)
else:
self.value = value
@property
def as_value_list(self):
"""get value as an iterable"""
if self.cardinality == 'multiple': return self.value
elif self.value is not None: return [self.value]
else: return []
@property
def finished(self):
if self.cardinality == 'single':
if self.value is None: return False
else: return True
elif self.cardinality == 'optional' and self.value is not None:
return True
else:
if self._not_single_cardinality_finished: return True
else: return False
def set_finish(self):
# assert self.cardinality in ('optional', 'multiple')
self._not_single_cardinality_finished = True
def __eq__(self, other):
if super(RealizedField, self).__eq__(other):
if type(other) == Field: return True # FIXME: hack, Field and RealizedField can compare!
if self.value == other.value: return True
else: return False
else: return False


@ -0,0 +1,37 @@
# coding=utf-8
from asdl.asdl import *
from asdl.hypothesis import Hypothesis
from asdl.transition_system import *
class DecodeHypothesis(Hypothesis):
def __init__(self):
super(DecodeHypothesis, self).__init__()
self.action_infos = []
self.code = None
def clone_and_apply_action_info(self, action_info):
action = action_info.action
new_hyp = self.clone_and_apply_action(action)
new_hyp.action_infos.append(action_info)
return new_hyp
def copy(self):
new_hyp = DecodeHypothesis()
if self.tree:
new_hyp.tree = self.tree.copy()
new_hyp.actions = list(self.actions)
new_hyp.action_infos = list(self.action_infos)
new_hyp.score = self.score
new_hyp._value_buffer = list(self._value_buffer)
new_hyp.t = self.t
new_hyp.code = self.code
new_hyp.update_frontier_info()
return new_hyp

proton/asdl/hypothesis.py (new file)

@ -0,0 +1,122 @@
# coding=utf-8
from asdl.asdl import *
from asdl.asdl_ast import AbstractSyntaxTree
from asdl.transition_system import *
class Hypothesis(object):
def __init__(self):
self.tree = None
self.actions = []
self.score = 0.
self.frontier_node = None
self.frontier_field = None
self._value_buffer = []
# record the current time step
self.t = 0
def apply_action(self, action):
if self.tree is None: # the first action
assert isinstance(action, ApplyRuleAction), 'Invalid action [%s], only ApplyRule action is valid at the beginning of decoding' % action
self.tree = AbstractSyntaxTree(action.production)
self.update_frontier_info()
elif self.frontier_node:
if isinstance(self.frontier_field.type, ASDLCompositeType):
if isinstance(action, ApplyRuleAction):
field_value = AbstractSyntaxTree(action.production)
field_value.created_time = self.t
self.frontier_field.add_value(field_value)
self.update_frontier_info()
elif isinstance(action, ReduceAction):
assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be applied on fields with optional or multiple cardinality'
self.frontier_field.set_finish()
self.update_frontier_info()
else:
raise ValueError('Invalid action [%s] on field [%s]' % (action, self.frontier_field))
else: # fill in a primitive field
if isinstance(action, GenTokenAction):
# only field of type string requires termination signal </primitive>
end_primitive = False
if self.frontier_field.type.name == 'string':
if action.is_stop_signal():
self.frontier_field.add_value(' '.join(self._value_buffer))
self._value_buffer = []
end_primitive = True
else:
self._value_buffer.append(action.token)
else:
self.frontier_field.add_value(action.token)
end_primitive = True
if end_primitive and self.frontier_field.cardinality in ('single', 'optional'):
self.frontier_field.set_finish()
self.update_frontier_info()
elif isinstance(action, ReduceAction):
assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be applied on fields with optional or multiple cardinality'
self.frontier_field.set_finish()
self.update_frontier_info()
else:
raise ValueError('Can only invoke GenToken or Reduce actions on primitive fields')
self.t += 1
self.actions.append(action)
def update_frontier_info(self):
def _find_frontier_node_and_field(tree_node):
# return (unfinished ast node, unrealized field), or None if every field of this node is already realized
if tree_node:
for field in tree_node.fields:
# if it's an intermediate node, check its children
if isinstance(field.type, ASDLCompositeType) and field.value:
if field.cardinality in ('single', 'optional'): iter_values = [field.value]
else: iter_values = field.value
for child_node in iter_values:
result = _find_frontier_node_and_field(child_node)
if result: return result
# now all its possible children are checked
if not field.finished:
return tree_node, field
return None
else: return None
frontier_info = _find_frontier_node_and_field(self.tree)
if frontier_info:
self.frontier_node, self.frontier_field = frontier_info
else:
self.frontier_node, self.frontier_field = None, None
def clone_and_apply_action(self, action):
new_hyp = self.copy()
new_hyp.apply_action(action)
return new_hyp
def copy(self):
new_hyp = Hypothesis()
if self.tree:
new_hyp.tree = self.tree.copy()
new_hyp.actions = list(self.actions)
new_hyp.score = self.score
new_hyp._value_buffer = list(self._value_buffer)
new_hyp.t = self.t
new_hyp.update_frontier_info()
return new_hyp
@property
def completed(self):
return self.tree and self.frontier_field is None
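# Usage sketch (hypothetical): replaying an oracle action sequence rebuilds the AST, e.g.
#   hyp = Hypothesis()
#   for action in actions:        # actions obtained from the transition system for a gold AST
#       hyp.apply_action(action)
#   assert hyp.completed and hyp.tree is not None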


@ -0,0 +1,103 @@
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
# val: value(float/string)/sql(dict)/col_unit(tuple)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, tab_id/sql)
# cond_unit: (not_op(bool), cmp_op, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/integer
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
# CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
# JOIN_KEYWORDS = ('join', 'on', 'as')
# CMP_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
# UNIT_OPS = ('none', '-', '+', "*", '/')
# AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
# TABLE_TYPE = ('sql', 'table_unit')
# COND_OPS = ('and', 'or')
# SQL_OPS = ('intersect', 'union', 'except')
# ORDER_OPS = ('desc', 'asc')
##########################################################################
# 1. eliminate ? by enumerating different number of items
# 2. from conds, distinct and value generation are also not considered
# 3. for select items, we use val_unit instead of (agg, val_unit)
# 4. for orderby items, we use col_unit instead of val_unit
# 5. for groupby items, we use col_unit instead of col_id
# 6. predict from clause first, to obtain an overview of the entire query graph
col_id, tab_id
sql = Intersect(sql_unit sql_unit, sql_unit sql_unit)
| Union(sql_unit sql_unit, sql_unit sql_unit)
| Except(sql_unit sql_unit, sql_unit sql_unit)
| Single(sql_unit sql_unit)
sql_unit = Complete(from from_clause, val_unit* select_clause, cond where_clause, group_by group_by_clause, order_by order_by_clause)
| NoWhere(from from_clause, val_unit* select_clause, group_by group_by_clause, order_by order_by_clause)
| NoGroupBy(from from_clause, val_unit* select_clause, cond where_clause, order_by order_by_clause)
| NoOrderBy(from from_clause, val_unit* select_clause, cond where_clause, group_by group_by_clause)
| OnlyWhere(from from_clause, val_unit* select_clause, cond where_clause)
| OnlyGroupBy(from from_clause, val_unit* select_clause, group_by group_by_clause)
| OnlyOrderBy(from from_clause, val_unit* select_clause, order_by order_by_clause)
| Simple(from from_clause, val_unit* select_clause)
from = FromTable(tab_id* tab_id_list)
| FromSQL(sql from_sql)
group_by = Having(col_unit* col_unit_list, cond having_clause)
| NoHaving(col_unit* col_unit_list)
order_by = Asc(col_unit* col_unit_list)
| Desc(col_unit* col_unit_list)
| AscLimit(col_unit* col_unit_list)
| DescLimit(col_unit* col_unit_list)
cond = And(cond left, cond right)
| Or(cond left, cond right)
| Between(val_unit val_unit)
| Eq(val_unit val_unit)
| Gt(val_unit val_unit)
| Lt(val_unit val_unit)
| Ge(val_unit val_unit)
| Le(val_unit val_unit)
| Neq(val_unit val_unit)
| Like(val_unit val_unit)
| NotLike(val_unit val_unit)
| BetweenSQL(val_unit val_unit, sql cond_sql)
| EqSQL(val_unit val_unit, sql cond_sql)
| GtSQL(val_unit val_unit, sql cond_sql)
| LtSQL(val_unit val_unit, sql cond_sql)
| GeSQL(val_unit val_unit, sql cond_sql)
| LeSQL(val_unit val_unit, sql cond_sql)
| NeqSQL(val_unit val_unit, sql cond_sql)
| InSQL(val_unit val_unit, sql cond_sql)
| NotInSQL(val_unit val_unit, sql cond_sql)
val_unit = Unary(col_unit col_unit)
| Minus(col_unit col_unit, col_unit col_unit)
| Plus(col_unit col_unit, col_unit col_unit)
| Times(col_unit col_unit, col_unit col_unit)
| Divide(col_unit col_unit, col_unit col_unit)
col_unit = None(col_id col_id)
| Max(col_id col_id)
| Min(col_id col_id)
| Count(col_id col_id)
| Sum(col_id col_id)
| Avg(col_id col_id)


@ -0,0 +1,111 @@
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
# val: value(float/string)/sql(dict)/col_unit(tuple)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, tab_id/sql)
# cond_unit: (not_op(bool), cmp_op, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/integer
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
# CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
# JOIN_KEYWORDS = ('join', 'on', 'as')
# CMP_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
# UNIT_OPS = ('none', '-', '+', "*", '/')
# AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
# TABLE_TYPE = ('sql', 'table_unit')
# COND_OPS = ('and', 'or')
# SQL_OPS = ('intersect', 'union', 'except')
# ORDER_OPS = ('desc', 'asc')
##########################################################################
# 1. eliminate * by enumerating different number of items
# 2. from conds, distinct and value generation are also not considered
# 3. for select items, we use val_unit instead of (agg, val_unit)
# 4. for orderby items, we use col_unit instead of val_unit
# 5. for groupby items, we use col_unit instead of col_id
# 6. predict from clause first, to obtain an overview of the entire query graph
col_id, tab_id
sql = Intersect(sql_unit sql_unit, sql_unit sql_unit)
| Union(sql_unit sql_unit, sql_unit sql_unit)
| Except(sql_unit sql_unit, sql_unit sql_unit)
| Single(sql_unit sql_unit)
sql_unit = SQL(from from_clause, select select_clause, cond? where_clause, group_by? group_by_clause, order_by? order_by_clause)
select = SelectOne(val_unit val_unit)
| SelectTwo(val_unit val_unit, val_unit val_unit)
| SelectThree(val_unit val_unit, val_unit val_unit, val_unit val_unit)
| SelectFour(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit)
| SelectFive(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit)
from = FromOneTable(tab_id tab_id)
| FromTwoTable(tab_id tab_id, tab_id tab_id)
| FromThreeTable(tab_id tab_id, tab_id tab_id, tab_id tab_id)
| FromFourTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
| FromFiveTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
| FromSixTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
| FromSQL(sql from_sql)
group_by = GroupByOne(col_unit col_unit, cond? having_clause)
| GroupByTwo(col_unit col_unit, col_unit col_unit, cond? having_clause)
order_by = OneAsc(col_unit col_unit)
| OneDesc(col_unit col_unit)
| OneAscLimit(col_unit col_unit)
| OneDescLimit(col_unit col_unit)
| TwoAsc(col_unit col_unit, col_unit col_unit)
| TwoDesc(col_unit col_unit, col_unit col_unit)
| TwoAscLimit(col_unit col_unit, col_unit col_unit)
| TwoDescLimit(col_unit col_unit, col_unit col_unit)
cond = And(cond left, cond right)
| Or(cond left, cond right)
| Between(val_unit val_unit)
| Eq(val_unit val_unit)
| Gt(val_unit val_unit)
| Lt(val_unit val_unit)
| Ge(val_unit val_unit)
| Le(val_unit val_unit)
| Neq(val_unit val_unit)
| Like(val_unit val_unit)
| NotLike(val_unit val_unit)
| BetweenSQL(val_unit val_unit, sql cond_sql)
| EqSQL(val_unit val_unit, sql cond_sql)
| GtSQL(val_unit val_unit, sql cond_sql)
| LtSQL(val_unit val_unit, sql cond_sql)
| GeSQL(val_unit val_unit, sql cond_sql)
| LeSQL(val_unit val_unit, sql cond_sql)
| NeqSQL(val_unit val_unit, sql cond_sql)
| InSQL(val_unit val_unit, sql cond_sql)
| NotInSQL(val_unit val_unit, sql cond_sql)
val_unit = Unary(col_unit col_unit)
| Minus(col_unit col_unit, col_unit col_unit)
| Plus(col_unit col_unit, col_unit col_unit)
| Times(col_unit col_unit, col_unit col_unit)
| Divide(col_unit col_unit, col_unit col_unit)
col_unit = None(col_id col_id)
| Max(col_id col_id)
| Min(col_id col_id)
| Count(col_id col_id)
| Sum(col_id col_id)
| Avg(col_id col_id)


@ -0,0 +1,120 @@
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
# val: value(float/string)/sql(dict)/col_unit(tuple)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, tab_id/sql)
# cond_unit: (not_op(bool), cmp_op, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/integer
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
# CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
# JOIN_KEYWORDS = ('join', 'on', 'as')
# CMP_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
# UNIT_OPS = ('none', '-', '+', "*", '/')
# AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
# TABLE_TYPE = ('sql', 'table_unit')
# COND_OPS = ('and', 'or')
# SQL_OPS = ('intersect', 'union', 'except')
# ORDER_OPS = ('desc', 'asc')
##########################################################################
# 1. eliminate both ? and * cardinality by enumerating different number of items
# 2. from conds, distinct and value generation are also not considered
# 3. for select items, we use val_unit instead of (agg, val_unit)
# 4. for orderby items, we use col_unit instead of val_unit
# 5. for groupby items, we use col_unit instead of col_id
# 6. predict from clause first, to obtain an overview of the entire query graph
col_id, tab_id
sql = Intersect(sql_unit sql_unit, sql_unit sql_unit)
| Union(sql_unit sql_unit, sql_unit sql_unit)
| Except(sql_unit sql_unit, sql_unit sql_unit)
| Single(sql_unit sql_unit)
sql_unit = Complete(from from_clause, select select_clause, cond where_clause, group_by group_by_clause, order_by order_by_clause)
| NoWhere(from from_clause, select select_clause, group_by group_by_clause, order_by order_by_clause)
| NoGroupBy(from from_clause, select select_clause, cond where_clause, order_by order_by_clause)
| NoOrderBy(from from_clause, select select_clause, cond where_clause, group_by group_by_clause)
| OnlyWhere(from from_clause, select select_clause, cond where_clause)
| OnlyGroupBy(from from_clause, select select_clause, group_by group_by_clause)
| OnlyOrderBy(from from_clause, select select_clause, order_by order_by_clause)
| Simple(from from_clause, select select_clause)
select = SelectOne(val_unit val_unit)
| SelectTwo(val_unit val_unit, val_unit val_unit)
| SelectThree(val_unit val_unit, val_unit val_unit, val_unit val_unit)
| SelectFour(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit)
| SelectFive(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit)
from = FromOneTable(tab_id tab_id)
| FromTwoTable(tab_id tab_id, tab_id tab_id)
| FromThreeTable(tab_id tab_id, tab_id tab_id, tab_id tab_id)
| FromFourTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
| FromFiveTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
| FromSixTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
| FromSQL(sql from_sql)
group_by = OneNoHaving(col_unit col_unit)
| TwoNoHaving(col_unit col_unit, col_unit col_unit)
| OneHaving(col_unit col_unit, cond having_clause)
| TwoHaving(col_unit col_unit, col_unit col_unit, cond having_clause)
order_by = OneAsc(col_unit col_unit)
| OneDesc(col_unit col_unit)
| OneAscLimit(col_unit col_unit)
| OneDescLimit(col_unit col_unit)
| TwoAsc(col_unit col_unit, col_unit col_unit)
| TwoDesc(col_unit col_unit, col_unit col_unit)
| TwoAscLimit(col_unit col_unit, col_unit col_unit)
| TwoDescLimit(col_unit col_unit, col_unit col_unit)
cond = And(cond left, cond right)
| Or(cond left, cond right)
| Between(val_unit val_unit)
| Eq(val_unit val_unit)
| Gt(val_unit val_unit)
| Lt(val_unit val_unit)
| Ge(val_unit val_unit)
| Le(val_unit val_unit)
| Neq(val_unit val_unit)
| Like(val_unit val_unit)
| NotLike(val_unit val_unit)
| BetweenSQL(val_unit val_unit, sql cond_sql)
| EqSQL(val_unit val_unit, sql cond_sql)
| GtSQL(val_unit val_unit, sql cond_sql)
| LtSQL(val_unit val_unit, sql cond_sql)
| GeSQL(val_unit val_unit, sql cond_sql)
| LeSQL(val_unit val_unit, sql cond_sql)
| NeqSQL(val_unit val_unit, sql cond_sql)
| InSQL(val_unit val_unit, sql cond_sql)
| NotInSQL(val_unit val_unit, sql cond_sql)
val_unit = Unary(col_unit col_unit)
| Minus(col_unit col_unit, col_unit col_unit)
| Plus(col_unit col_unit, col_unit col_unit)
| Times(col_unit col_unit, col_unit col_unit)
| Divide(col_unit col_unit, col_unit col_unit)
col_unit = None(col_id col_id)
| Max(col_id col_id)
| Min(col_id col_id)
| Count(col_id col_id)
| Sum(col_id col_id)
| Avg(col_id col_id)
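# Illustrative example (not part of the grammar): "SELECT count(*) FROM singer" roughly maps to
#   sql      -> Single(sql_unit)
#   sql_unit -> Simple(from_clause, select_clause)
#   from     -> FromOneTable(tab_id of table `singer`)
#   select   -> SelectOne(Unary(Count(col_id of `*`)))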


@ -0,0 +1,175 @@
#coding=utf8
from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
class Parser():
""" Parse a sql dict into AbstractSyntaxTree object according to specified grammar rules
Some common methods are implemented in this parent class.
"""
def __init__(self, grammar: ASDLGrammar):
super(Parser, self).__init__()
self.grammar = grammar
@classmethod
def from_grammar(cls, grammar: ASDLGrammar):
grammar_name = grammar._grammar_name
if 'v0' in grammar_name:
from asdl.sql.parser.parser_v0 import ParserV0
return ParserV0(grammar)
elif 'v1' in grammar_name:
from asdl.sql.parser.parser_v1 import ParserV1
return ParserV1(grammar)
elif 'v2' in grammar_name:
from asdl.sql.parser.parser_v2 import ParserV2
return ParserV2(grammar)
else:
raise ValueError('Not recognized grammar name %s' % (grammar_name))
def parse(self, sql_json: dict):
""" sql_json is exactly the 'sql' field of each data sample
return AbstractSyntaxTree of sql
"""
try:
ast_node = self.parse_sql(sql_json)
return ast_node
except Exception as e:
print('Something went wrong while parsing:', e)
# if fail to parse, just return select * from table(id=0)
error_sql = {
"select": [False, [(0, [0, [0, 0, False], None])]],
"from": {'table_units': [('table_unit', 0)], 'conds': []},
"where": [], "groupBy": [], "orderBy": [], "having": [], "limit": None,
"intersect": [], "union": [], "except": []
}
ast_node = self.parse_sql(error_sql)
return ast_node
def parse_sql(self, sql: dict):
""" Determine whether sql has intersect/union/except,
at most one in the current dict
"""
for choice in ['intersect', 'union', 'except']:
if sql[choice]:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(choice.title()))
nested_sql = sql[choice]
sql_field1, sql_field2 = ast_node.fields
sql_field1.add_value(self.parse_sql_unit(sql))
sql_field2.add_value(self.parse_sql_unit(nested_sql))
return ast_node
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Single'))
ast_node.fields[0].add_value(self.parse_sql_unit(sql))
return ast_node
def parse_sql_unit(self, sql: dict):
""" Parse a single sql unit, determine the existence of different clauses
"""
sql_ctr = ['Complete', 'NoWhere', 'NoGroupBy', 'NoOrderBy', 'OnlyWhere', 'OnlyGroupBy', 'OnlyOrderBy', 'Simple']
where_field, groupby_field, orderby_field = [None] * 3
if sql['where'] and sql['groupBy'] and sql['orderBy']:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[0]))
from_field, select_field, where_field, groupby_field, orderby_field = ast_node.fields
elif sql['groupBy'] and sql['orderBy']:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[1]))
from_field, select_field, groupby_field, orderby_field = ast_node.fields
elif sql['where'] and sql['orderBy']:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[2]))
from_field, select_field, where_field, orderby_field = ast_node.fields
elif sql['where'] and sql['groupBy']:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[3]))
from_field, select_field, where_field, groupby_field = ast_node.fields
elif sql['where']:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[4]))
from_field, select_field, where_field = ast_node.fields
elif sql['groupBy']:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[5]))
from_field, select_field, groupby_field = ast_node.fields
elif sql['orderBy']:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[6]))
from_field, select_field, orderby_field = ast_node.fields
else:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[7]))
from_field, select_field = ast_node.fields
self.parse_from(sql['from'], from_field)
self.parse_select(sql['select'], select_field)
if sql['where']:
self.parse_where(sql['where'], where_field)
if sql['groupBy']: # if having clause is not empty, groupBy must exist
self.parse_groupby(sql['groupBy'], sql['having'], groupby_field)
if sql['orderBy']: # if limit is not None, orderBY is not empty
self.parse_orderby(sql['orderBy'], sql['limit'], orderby_field)
return ast_node
def parse_select(self, select_clause: list, select_field: RealizedField):
raise NotImplementedError
def parse_from(self, from_clause: dict, from_field: RealizedField):
raise NotImplementedError
def parse_where(self, where_clause: list, where_field: RealizedField):
where_field.add_value(self.parse_conds(where_clause))
def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
raise NotImplementedError
def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
raise NotImplementedError
def parse_conds(self, conds: list):
assert len(conds) > 0
and_or = (len(conds) - 1) // 2
root_node, left_field, right_field = [None] * 3
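# conds is a flat list [cond, 'and'/'or', cond, ...]; iterating the conjunctions right-to-left
# yields a left-deep AST, e.g. [c0, 'and', c1, 'or', c2] -> Or(And(c0, c1), c2)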
for i in reversed(range(and_or)):
and_or_idx = 2 * i + 1
conj = conds[and_or_idx]
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(conj.title()))
if root_node is None:
root_node = ast_node
if left_field is not None:
left_field.add_value(ast_node)
left_field, right_field = ast_node.fields
right_field.add_value(self.parse_cond(conds[2 * (i + 1)]))
if left_field is None:
root_node = self.parse_cond(conds[0])
else:
left_field.add_value(self.parse_cond(conds[0]))
return root_node
def parse_cond(self, cond: list):
not_op, cmp_op, val_unit, val1, val2 = cond
not_op = '^' if not_op else ''
sql_val = 'sql' if type(val1) == dict else ''
op_list = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
cmp_op = not_op + op_list[cmp_op] + sql_val
op_dict = {
'between': 'Between', '=': 'Eq', '>': 'Gt', '<': 'Lt', '>=': 'Ge', '<=': 'Le', '!=': 'Neq',
'insql': 'InSQL', 'like': 'Like', '^insql': 'NotInSQL', '^like': 'NotLike', 'betweensql': 'BetweenSQL', '=sql': 'EqSQL',
'>sql': 'GtSQL', '<sql': 'LtSQL', '>=sql': 'GeSQL', '<=sql': 'LeSQL', '!=sql': 'NeqSQL'
}
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(op_dict[cmp_op]))
val_unit_field = ast_node.fields[0]
val_unit_field.add_value(self.parse_val_unit(val_unit))
if len(ast_node.fields) == 2:
val_field = ast_node.fields[1]
val_field.add_value(self.parse_sql(val1))
return ast_node
def parse_val_unit(self, val_unit: list):
unit_op, col_unit1, col_unit2 = val_unit
unit_op_list = ['Unary', 'Minus', 'Plus', 'Times', 'Divide']
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(unit_op_list[unit_op]))
if unit_op == 0:
ast_node.fields[0].add_value(self.parse_col_unit(col_unit1))
else:
# ast_node.fields[0].add_value(int(col_unit1[1]))
# ast_node.fields[1].add_value(int(col_unit2[1]))
ast_node.fields[0].add_value(self.parse_col_unit(col_unit1))
ast_node.fields[1].add_value(self.parse_col_unit(col_unit2))
return ast_node
def parse_col_unit(self, col_unit: list):
agg_op, col_id, distinct_flag = col_unit
agg_op_list = ['None', 'Max', 'Min', 'Count', 'Sum', 'Avg']
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(agg_op_list[agg_op]))
ast_node.fields[0].add_value(int(col_id))
return ast_node


@ -0,0 +1,69 @@
#coding=utf8
from asdl.sql.parser.parser_base import Parser
from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
class ParserV0(Parser):
""" In this version, we eliminate all cardinality ? and restrict that * must have at least one item
"""
def parse_select(self, select_clause: list, select_field: RealizedField):
"""
ignore cases agg(col_id1 op col_id2) and agg(col_id1) op agg(col_id2)
"""
select_clause = select_clause[1] # list of (agg, val_unit)
unit_op_list = ['Unary', 'Minus', 'Plus', 'Times', 'Divide']
agg_op_list = ['None', 'Max', 'Min', 'Count', 'Sum', 'Avg']
for agg, val_unit in select_clause:
if agg != 0: # agg col_id
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary'))
col_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(agg_op_list[agg]))
col_node.fields[0].add_value(int(val_unit[1][1]))
ast_node.fields[0].add_value(col_node)
else: # binary_op col_id1 col_id2
ast_node = self.parse_val_unit(val_unit)
select_field.add_value(ast_node)
def parse_from(self, from_clause: dict, from_field: RealizedField):
"""
Ignore from conditions, since they are not evaluated by the evaluation script
"""
table_units = from_clause['table_units']
t = table_units[0][0]
if t == 'table_unit':
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromTable'))
tables_field = ast_node.fields[0]
for _, v in table_units:
tables_field.add_value(int(v))
else:
assert t == 'sql'
v = table_units[0][1]
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL'))
ast_node.fields[0].add_value(self.parse_sql(v))
from_field.add_value(ast_node)
def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
col_ids = []
for col_unit in groupby_clause:
col_ids.append(col_unit[1]) # agg is None and isDistinct False
if having_clause:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Having'))
col_units_field, having_fields = ast_node.fields
having_fields.add_value(self.parse_conds(having_clause))
else:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('NoHaving'))
col_units_field = ast_node.fields[0]
for col_unit in groupby_clause:
col_units_field.add_value(self.parse_col_unit(col_unit))
groupby_field.add_value(ast_node)
def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
if limit is None:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Asc')) if orderby_clause[0] == 'asc' \
else AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Desc'))
else:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('AscLimit')) if orderby_clause[0] == 'asc' \
else AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('DescLimit'))
col_units_field = ast_node.fields[0]
for val_unit in orderby_clause[1]:
col_units_field.add_value(self.parse_col_unit(val_unit[1]))
orderby_field.add_value(ast_node)


@ -0,0 +1,83 @@
#coding=utf8
from asdl.sql.parser.parser_base import Parser
from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
class ParserV1(Parser):
""" In this version, we eliminate all cardinality * and use ?
"""
def parse_sql_unit(self, sql: dict):
""" Parse a single sql unit, determine the existence of different clauses
"""
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('SQL'))
from_field, select_field, where_field, groupby_field, orderby_field = ast_node.fields
self.parse_from(sql['from'], from_field)
self.parse_select(sql['select'], select_field)
if sql['where']:
self.parse_where(sql['where'], where_field)
if sql['groupBy']: # if having clause is not empty, groupBy must exist
self.parse_groupby(sql['groupBy'], sql['having'], groupby_field)
if sql['orderBy']: # if limit is not None, orderBY is not empty
self.parse_orderby(sql['orderBy'], sql['limit'], orderby_field)
return ast_node
def parse_select(self, select_clause: list, select_field: RealizedField):
select_clause = select_clause[1] # list of (agg, val_unit), ignore distinct flag
select_num = min(5, len(select_clause))
select_ctr = ['SelectOne', 'SelectTwo', 'SelectThree', 'SelectFour', 'SelectFive']
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(select_ctr[select_num - 1]))
for i, (agg, val_unit) in enumerate(select_clause):
if i >= 5: break
if agg != 0: # MAX/MIN/COUNT/SUM/AVG
val_unit_ast = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary'))
col_unit = [agg] + val_unit[1][1:]
val_unit_ast.fields[0].add_value(self.parse_col_unit(col_unit))
else:
val_unit_ast = self.parse_val_unit(val_unit)
ast_node.fields[i].add_value(val_unit_ast)
select_field.add_value(ast_node)
def parse_from(self, from_clause: dict, from_field: RealizedField):
""" Ignore from conditions, since it is not evaluated in evaluation script
"""
table_units = from_clause['table_units']
t = table_units[0][0]
if t == 'table_unit':
table_num = min(6, len(table_units))
table_ctr = ['FromOneTable', 'FromTwoTable', 'FromThreeTable', 'FromFourTable', 'FromFiveTable', 'FromSixTable']
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(table_ctr[table_num - 1]))
for i, (_, tab_id) in enumerate(table_units):
if i >= 6: break
ast_node.fields[i].add_value(int(tab_id))
else:
assert t == 'sql'
v = table_units[0][1]
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL'))
ast_node.fields[0].add_value(self.parse_sql(v))
from_field.add_value(ast_node)
def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
groupby_ctr = ['GroupByOne', 'GroupByTwo']
groupby_num = min(2, len(groupby_clause)) - 1
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(groupby_ctr[groupby_num]))
if having_clause:
having_field = ast_node.fields[-1]
having_field.add_value(self.parse_conds(having_clause))
for i, col_unit in enumerate(groupby_clause):
if i >= 2: break
# ast_node.fields[i].add_value(int(col_unit[1]))
ast_node.fields[i].add_value(self.parse_col_unit(col_unit))
groupby_field.add_value(ast_node)
def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
orderby_num = min(2, len(orderby_clause[1]))
num_str = 'One' if orderby_num == 1 else 'Two'
order_str = 'Asc' if orderby_clause[0] == 'asc' else 'Desc'
limit_str = 'Limit' if limit else '' # e.g. OneAsc, TwoDescLimit
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(num_str + order_str + limit_str))
for i, val_unit in enumerate(orderby_clause[1]):
if i >= 2: break
col_unit = val_unit[1]
ast_node.fields[i].add_value(self.parse_col_unit(col_unit))
# ast_node.fields[i].add_value(self.parse_val_unit(val_unit))
orderby_field.add_value(ast_node)


@ -0,0 +1,71 @@
#coding=utf8
from asdl.sql.parser.parser_base import Parser
from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
class ParserV2(Parser):
""" In this version, we remove all cardinality ? or *
by enumerating all different lengths of item list, such as SelectOne, SelectTwo
"""
def parse_select(self, select_clause: list, select_field: RealizedField):
select_clause = select_clause[1] # list of (agg, val_unit), ignore distinct flag
select_num = min(5, len(select_clause))
select_ctr = ['SelectOne', 'SelectTwo', 'SelectThree', 'SelectFour', 'SelectFive']
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(select_ctr[select_num - 1]))
for i, (agg, val_unit) in enumerate(select_clause):
if i >= 5: break
if agg != 0: # MAX/MIN/COUNT/SUM/AVG
val_unit_ast = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary'))
col_unit = [agg] + val_unit[1][1:]
val_unit_ast.fields[0].add_value(self.parse_col_unit(col_unit))
else:
val_unit_ast = self.parse_val_unit(val_unit)
ast_node.fields[i].add_value(val_unit_ast)
select_field.add_value(ast_node)
def parse_from(self, from_clause: dict, from_field: RealizedField):
""" Ignore from conditions, since it is not evaluated in evaluation script
"""
table_units = from_clause['table_units']
t = table_units[0][0]
if t == 'table_unit':
table_num = min(6, len(table_units))
table_ctr = ['FromOneTable', 'FromTwoTable', 'FromThreeTable', 'FromFourTable', 'FromFiveTable', 'FromSixTable']
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(table_ctr[table_num - 1]))
for i, (_, tab_id) in enumerate(table_units):
if i >= 6: break
ast_node.fields[i].add_value(int(tab_id))
else:
assert t == 'sql'
v = table_units[0][1]
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL'))
ast_node.fields[0].add_value(self.parse_sql(v))
from_field.add_value(ast_node)
def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
groupby_ctr = ['OneNoHaving', 'TwoNoHaving', 'OneHaving', 'TwoHaving']
groupby_num = min(2, len(groupby_clause))
if having_clause:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(groupby_ctr[groupby_num + 1]))
having_field = ast_node.fields[-1]
having_field.add_value(self.parse_conds(having_clause))
else:
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(groupby_ctr[groupby_num - 1]))
for i, col_unit in enumerate(groupby_clause):
if i >= 2: break
# ast_node.fields[i].add_value(int(col_unit[1]))
ast_node.fields[i].add_value(self.parse_col_unit(col_unit))
groupby_field.add_value(ast_node)
def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
orderby_num = min(2, len(orderby_clause[1]))
num_str = 'One' if orderby_num == 1 else 'Two'
order_str = 'Asc' if orderby_clause[0] == 'asc' else 'Desc'
limit_str = 'Limit' if limit else '' # e.g. OneAsc, TwoDescLimit
ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(num_str + order_str + limit_str))
for i, val_unit in enumerate(orderby_clause[1]):
if i >= 2: break
col_unit = val_unit[1]
ast_node.fields[i].add_value(self.parse_col_unit(col_unit))
# ast_node.fields[i].add_value(self.parse_val_unit(val_unit))
orderby_field.add_value(ast_node)


@ -0,0 +1,174 @@
# coding=utf-8
import json, os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from asdl.sql.parser.parser_base import Parser
from asdl.sql.unparser.unparser_base import UnParser
from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
from asdl.transition_system import GenTokenAction, TransitionSystem, ApplyRuleAction, ReduceAction
class SelectColumnAction(GenTokenAction):
def __init__(self, column_id):
super(SelectColumnAction, self).__init__(column_id)
@property
def column_id(self):
return self.token
def __repr__(self):
return 'SelectColumnAction[id=%s]' % self.column_id
class SelectTableAction(GenTokenAction):
def __init__(self, table_id):
super(SelectTableAction, self).__init__(table_id)
@property
def table_id(self):
return self.token
def __repr__(self):
return 'SelectTableAction[id=%s]' % self.table_id
class SQLTransitionSystem(TransitionSystem):
def __init__(self, grammar):
self.grammar = grammar
self.parser = Parser.from_grammar(self.grammar)
self.unparser = UnParser.from_grammar(self.grammar)
def ast_to_surface_code(self, asdl_ast, table, *args, **kargs):
return self.unparser.unparse(asdl_ast, table, *args, **kargs)
def compare_ast(self, hyp_ast, ref_ast):
raise NotImplementedError
def tokenize_code(self, code, mode):
raise NotImplementedError
def surface_code_to_ast(self, code):
return self.parser.parse(code)
def get_valid_continuation_types(self, hyp):
if hyp.tree:
if self.grammar.is_composite_type(hyp.frontier_field.type):
if hyp.frontier_field.cardinality == 'single':
return ApplyRuleAction,
elif hyp.frontier_field.cardinality == 'multiple':
if len(hyp.frontier_field.value) == 0:
return ApplyRuleAction,
else:
return ApplyRuleAction, ReduceAction
else:
return ApplyRuleAction, ReduceAction
elif hyp.frontier_field.type.name == 'col_id':
if hyp.frontier_field.cardinality == 'single':
return SelectColumnAction,
elif hyp.frontier_field.cardinality == 'multiple':
if len(hyp.frontier_field.value) == 0:
return SelectColumnAction,
else:
return SelectColumnAction, ReduceAction
else: # optional, not used
return SelectColumnAction, ReduceAction
elif hyp.frontier_field.type.name == 'tab_id':
if hyp.frontier_field.cardinality == 'single':
return SelectTableAction,
elif hyp.frontier_field.cardinality == 'multiple':
if len(hyp.frontier_field.value) == 0:
return SelectTableAction,
else:
return SelectTableAction, ReduceAction
else: # optional, not used
return SelectTableAction, ReduceAction
else: # not used now
return GenTokenAction,
else:
return ApplyRuleAction,
def get_primitive_field_actions(self, realized_field):
if realized_field.type.name == 'col_id':
if realized_field.cardinality == 'multiple':
action_list = []
for idx in realized_field.value:
action_list.append(SelectColumnAction(int(idx)))
return action_list
elif realized_field.value is not None:
return [SelectColumnAction(int(realized_field.value))]
else:
return []
elif realized_field.type.name == 'tab_id':
if realized_field.cardinality == 'multiple':
action_list = []
for idx in realized_field.value:
action_list.append(SelectTableAction(int(idx)))
return action_list
elif realized_field.value is not None:
return [SelectTableAction(int(realized_field.value))]
else:
return []
else:
raise ValueError('unknown primitive field type')
if __name__ == '__main__':
try:
from evaluation import evaluate, build_foreign_key_map_from_json
except Exception:
print('Cannot find evaluator ...')
grammar = ASDLGrammar.from_filepath('asdl/sql/grammar/sql_asdl_v2.txt')
print('Total number of productions:', len(grammar))
for each in grammar.productions:
print(each)
print('Total number of types:', len(grammar.types))
for each in grammar.types:
print(each)
print('Total number of fields:', len(grammar.fields))
for each in grammar.fields:
print(each)
spider_trans = SQLTransitionSystem(grammar)
kmaps = build_foreign_key_map_from_json('data/tables.json')
dbs_list = json.load(open('data/tables.json', 'r'))
dbs = {}
for each in dbs_list:
dbs[each['db_id']] = each
train = json.load(open('data/train.json', 'r'))
train_db = [ex['db_id'] for ex in train]
train = [ex['sql'] for ex in train]
dev = json.load(open('data/dev.json', 'r'))
dev_db = [ex['db_id'] for ex in dev]
dev = [ex['sql'] for ex in dev]
recovered_sqls = []
for idx in range(len(train)):
sql_ast = spider_trans.surface_code_to_ast(train[idx])
sql_ast.sanity_check()
# print(spider_trans.get_actions(sql_ast))
recovered_sql = spider_trans.ast_to_surface_code(sql_ast, dbs[train_db[idx]])
# print(recovered_sql)
recovered_sqls.append(recovered_sql)
with open('data/train_pred.sql', 'w') as of:
for each in recovered_sqls:
of.write(each + '\n')
with open('data/eval_train.log', 'w') as of:
old_print = sys.stdout
sys.stdout = of
evaluate('data/train_gold.sql', 'data/train_pred.sql', 'data/database', 'match', kmaps)
sys.stdout = old_print
recovered_sqls = []
for idx in range(len(dev)):
sql_ast = spider_trans.surface_code_to_ast(dev[idx])
sql_ast.sanity_check()
recovered_sql = spider_trans.ast_to_surface_code(sql_ast, dbs[dev_db[idx]])
recovered_sqls.append(recovered_sql)
with open('data/dev_pred.sql', 'w') as of:
for each in recovered_sqls:
of.write(each + '\n')
with open('data/eval_dev.log', 'w') as of:
old_print = sys.stdout
sys.stdout = of
evaluate('data/dev_gold.sql', 'data/dev_pred.sql', 'data/database', 'match', kmaps)
sys.stdout = old_print

View File

@ -0,0 +1,165 @@
#coding=utf8
from asdl.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
class UnParser():
def __init__(self, grammar: ASDLGrammar):
""" ASDLGrammar """
super(UnParser, self).__init__()
self.grammar = grammar
@classmethod
def from_grammar(cls, grammar: ASDLGrammar):
grammar_name = grammar._grammar_name
if 'v0' in grammar_name:
from asdl.sql.unparser.unparser_v0 import UnParserV0
return UnParserV0(grammar)
elif 'v1' in grammar_name:
from asdl.sql.unparser.unparser_v1 import UnParserV1
return UnParserV1(grammar)
elif 'v2' in grammar_name:
from asdl.sql.unparser.unparser_v2 import UnParserV2
return UnParserV2(grammar)
else:
            raise ValueError('Unrecognized grammar name %s' % (grammar_name))
def unparse(self, sql_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
try:
sql = self.unparse_sql(sql_ast, db, *args, **kargs)
sql = ' '.join([i for i in sql.split(' ') if i != ''])
return sql
except Exception as e:
            print('Something went wrong while unparsing:', e)
return 'SELECT * FROM %s' % (db['table_names_original'][0])
def unparse_sql(self, sql_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
prod_name = sql_ast.production.constructor.name
if prod_name == 'Intersect':
return '%s INTERSECT %s' % (self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs), self.unparse_sql_unit(sql_ast.fields[1], db, *args, **kargs))
elif prod_name == 'Union':
return '%s UNION %s' % (self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs), self.unparse_sql_unit(sql_ast.fields[1], db, *args, **kargs))
elif prod_name == 'Except':
return '%s EXCEPT %s' % (self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs), self.unparse_sql_unit(sql_ast.fields[1], db, *args, **kargs))
else:
return self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs)
def unparse_sql_unit(self, sql_field: RealizedField, db: dict, *args, **kargs):
sql_ast = sql_field.value
prod_name = sql_ast.production.constructor.name
from_str = self.unparse_from(sql_ast.fields[0], db, *args, **kargs)
select_str = self.unparse_select(sql_ast.fields[1], db, *args, **kargs)
if prod_name == 'Complete':
return 'SELECT %s FROM %s WHERE %s GROUP BY %s ORDER BY %s' % (
select_str, from_str,
self.unparse_where(sql_ast.fields[2], db, *args, **kargs),
self.unparse_groupby(sql_ast.fields[3], db, *args, **kargs),
self.unparse_orderby(sql_ast.fields[4], db, *args, **kargs)
)
elif prod_name == 'NoWhere':
return 'SELECT %s FROM %s GROUP BY %s ORDER BY %s' % (
select_str, from_str,
self.unparse_groupby(sql_ast.fields[2], db, *args, **kargs),
self.unparse_orderby(sql_ast.fields[3], db, *args, **kargs),
)
elif prod_name == 'NoGroupBy':
return 'SELECT %s FROM %s WHERE %s ORDER BY %s' % (
select_str, from_str,
self.unparse_where(sql_ast.fields[2], db, *args, **kargs),
self.unparse_orderby(sql_ast.fields[3], db, *args, **kargs),
)
elif prod_name == 'NoOrderBy':
return 'SELECT %s FROM %s WHERE %s GROUP BY %s' % (
select_str, from_str,
self.unparse_where(sql_ast.fields[2], db, *args, **kargs),
self.unparse_groupby(sql_ast.fields[3], db, *args, **kargs),
)
elif prod_name == 'OnlyWhere':
return 'SELECT %s FROM %s WHERE %s' % (
select_str, from_str,
self.unparse_where(sql_ast.fields[2], db, *args, **kargs)
)
elif prod_name == 'OnlyGroupBy':
return 'SELECT %s FROM %s GROUP BY %s' % (
select_str, from_str,
self.unparse_groupby(sql_ast.fields[2], db, *args, **kargs)
)
elif prod_name == 'OnlyOrderBy':
return 'SELECT %s FROM %s ORDER BY %s' % (
select_str, from_str,
self.unparse_orderby(sql_ast.fields[2], db, *args, **kargs)
)
else:
return 'SELECT %s FROM %s' % (select_str, from_str)
def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs):
raise NotImplementedError
def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs):
raise NotImplementedError
def unparse_where(self, where_field: RealizedField, db: dict, *args, **kargs):
return self.unparse_conds(where_field.value, db, *args, **kargs)
def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs):
raise NotImplementedError
def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs):
raise NotImplementedError
def unparse_conds(self, conds_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
ctr_name = conds_ast.production.constructor.name
if ctr_name in ['And', 'Or']:
left_cond, right_cond = conds_ast.fields
return self.unparse_conds(left_cond.value, db, *args, **kargs) + ' ' + ctr_name.upper() + ' ' + \
self.unparse_conds(right_cond.value, db, *args, **kargs)
else:
return self.unparse_cond(conds_ast, db, *args, **kargs)
def unparse_cond(self, cond_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
ctr_name = cond_ast.production.constructor.name
val_unit_str = self.unparse_val_unit(cond_ast.fields[0].value, db, *args, **kargs)
val_str = '( ' + self.unparse_sql(cond_ast.fields[1].value, db, *args, **kargs) + ' )' if len(cond_ast.fields) == 2 else '"value"'
if ctr_name.startswith('Between'):
return val_unit_str + ' BETWEEN ' + val_str + ' AND "value"'
else:
op_dict = {
'Between': ' BETWEEN ', 'Eq': ' = ', 'Gt': ' > ', 'Lt': ' < ', 'Ge': ' >= ', 'Le': ' <= ', 'Neq': ' != ',
'In': ' IN ', 'Like': ' LIKE ', 'NotIn': ' NOT IN ', 'NotLike': ' NOT LIKE '
}
ctr_name = ctr_name if 'SQL' not in ctr_name else ctr_name[:ctr_name.index('SQL')]
op = op_dict[ctr_name]
return op.join([val_unit_str, val_str])
def unparse_val_unit(self, val_unit_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
unit_op = val_unit_ast.production.constructor.name
if unit_op == 'Unary':
return self.unparse_col_unit(val_unit_ast.fields[0].value, db, *args, **kargs)
else:
binary = {'Minus': ' - ', 'Plus': ' + ', 'Times': ' * ', 'Divide': ' / '}
op = binary[unit_op]
return op.join([self.unparse_col_unit(val_unit_ast.fields[0].value, db, *args, **kargs),
self.unparse_col_unit(val_unit_ast.fields[1].value, db, *args, **kargs)])
# col_id1, col_id2 = int(val_unit_ast.fields[0].value), int(val_unit_ast.fields[1].value)
# tab_id1, col_name1 = db['column_names_original'][col_id1]
# if col_id1 != 0:
# tab_name1 = db['table_names_original'][tab_id1]
# col_name1 = tab_name1 + '.' + col_name1
# tab_id2, col_name2 = db['column_names_original'][col_id2]
# if col_id2 != 0:
# tab_name2 = db['table_names_original'][tab_id2]
# col_name2 = tab_name2 + '.' + col_name2
# return op.join([col_name1, col_name2])
def unparse_col_unit(self, col_unit_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
agg = col_unit_ast.production.constructor.name
col_id = int(col_unit_ast.fields[0].value)
tab_id, col_name = db['column_names_original'][col_id]
if col_id != 0:
tab_name = db['table_names_original'][tab_id]
col_name = tab_name + '.' + col_name
if agg == 'None':
return col_name
else: # Max/Min/Count/Sum/Avg
return agg.upper() + '(' + col_name + ')'
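    # Hedged example (illustrative): a col_unit AST whose constructor is 'Count' and whose
    # column id maps to (tab_id, 'Year') with table name 'concert' unparses to
    # 'COUNT(concert.Year)'; for the special column id 0 (i.e. *), no table prefix is added.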

View File

@ -0,0 +1,59 @@
#coding=utf8
from asdl.sql.unparser.unparser_base import UnParser
from asdl.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
class UnParserV0(UnParser):
def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs):
select_list = select_field.value
select_items = []
for val_unit_ast in select_list:
val_unit_str = self.unparse_val_unit(val_unit_ast, db, *args, **kargs)
select_items.append(val_unit_str)
return ' , '.join(select_items)
def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs):
from_ast = from_field.value
ctr_name = from_ast.production.constructor.name
if ctr_name == 'FromTable':
tab_ids = from_ast.fields[0].value
if len(tab_ids) == 1:
return db['table_names_original'][tab_ids[0]]
else:
tab_names = [db['table_names_original'][i] for i in tab_ids]
return ' JOIN '.join(tab_names)
else:
sql_ast = from_ast.fields[0].value
return '( ' + self.unparse_sql(sql_ast, db, *args, **kargs) + ' )'
def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs):
groupby_ast = groupby_field.value
ctr_name = groupby_ast.production.constructor.name
groupby_str = []
for col_unit_ast in groupby_ast.fields[0].value:
groupby_str.append(self.unparse_col_unit(col_unit_ast, db, *args, **kargs))
groupby_str = ' , '.join(groupby_str)
if ctr_name == 'Having':
having = groupby_ast.fields[1].value
having_str = self.unparse_conds(having, db, *args, **kargs)
return groupby_str + ' HAVING ' + having_str
else:
return groupby_str
def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs):
orderby_ast = orderby_field.value
ctr_name = orderby_ast.production.constructor.name.lower()
val_unit_str = []
for val_unit_ast in orderby_ast.fields[0].value:
val_unit_str.append(self.unparse_col_unit(val_unit_ast, db, *args, **kargs))
# val_unit_str.append(self.unparse_val_unit(val_unit_ast, db, *args, **kargs))
val_unit_str = ' , '.join(val_unit_str)
if 'asc' in ctr_name and 'limit' in ctr_name:
return '%s ASC LIMIT 1' % (val_unit_str)
elif 'asc' in ctr_name:
return '%s ASC' % (val_unit_str)
elif 'desc' in ctr_name and 'limit' in ctr_name:
return '%s DESC LIMIT 1' % (val_unit_str)
else:
return '%s DESC' % (val_unit_str)

View File

@ -0,0 +1,80 @@
#coding=utf8
from asdl.sql.unparser.unparser_base import UnParser
from asdl.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
class UnParserV1(UnParser):
def unparse_sql_unit(self, sql_field: RealizedField, db: dict, *args, **kargs):
sql_ast = sql_field.value
from_field, select_field, where_field, groupby_field, orderby_field = sql_ast.fields
from_str = 'FROM ' + self.unparse_from(from_field, db, *args, **kargs)
select_str = 'SELECT ' + self.unparse_select(select_field, db, *args, **kargs)
where_str, groupby_str, orderby_str = '', '', ''
if where_field.value is not None:
where_str = 'WHERE ' + self.unparse_where(where_field, db, *args, **kargs)
if groupby_field.value is not None:
groupby_str = 'GROUP BY ' + self.unparse_groupby(groupby_field, db, *args, **kargs)
if orderby_field.value is not None:
orderby_str = 'ORDER BY ' + self.unparse_orderby(orderby_field, db, *args, **kargs)
return ' '.join([select_str, from_str, where_str, groupby_str, orderby_str])
def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs):
select_ast = select_field.value
select_list = select_ast.fields
select_items = []
for val_unit_field in select_list:
val_unit_str = self.unparse_val_unit(val_unit_field.value, db, *args, **kargs)
select_items.append(val_unit_str)
return ' , '.join(select_items)
def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs):
from_ast = from_field.value
ctr_name = from_ast.production.constructor.name
if 'Table' in ctr_name:
tab_names = []
for tab_field in from_ast.fields:
tab_name = db['table_names_original'][int(tab_field.value)]
tab_names.append(tab_name)
return ' JOIN '.join(tab_names)
else:
return '( ' + self.unparse_sql(from_ast.fields[0].value, db, *args, **kargs) + ' )'
def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs):
groupby_ast = groupby_field.value
ctr_name = groupby_ast.production.constructor.name
groupby_str = []
num = len(groupby_ast.fields) - 1
for col_id_field in groupby_ast.fields[:num]:
# col_id = int(col_id_field.value)
# tab_id, col_name = db['column_names_original'][col_id]
# if col_id != 0:
# tab_name = db['table_names_original'][tab_id]
# col_name = tab_name + '.' + col_name
col_name = self.unparse_col_unit(col_id_field.value, db, *args, **kargs)
groupby_str.append(col_name)
groupby_str = ' , '.join(groupby_str)
if groupby_ast.fields[-1].value is not None:
having = groupby_ast.fields[-1].value
having_str = self.unparse_conds(having, db, *args, **kargs)
return groupby_str + ' HAVING ' + having_str
else:
return groupby_str
def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs):
orderby_ast = orderby_field.value
ctr_name = orderby_ast.production.constructor.name.lower()
val_unit_str = []
for val_unit_field in orderby_ast.fields:
val_unit_ast = val_unit_field.value
val_unit_str.append(self.unparse_col_unit(val_unit_ast, db, *args, **kargs))
# val_unit_str.append(self.unparse_val_unit(val_unit_ast, db, *args, **kargs))
val_unit_str = ' , '.join(val_unit_str)
if 'asc' in ctr_name and 'limit' in ctr_name:
return '%s ASC LIMIT 1' % (val_unit_str)
elif 'asc' in ctr_name:
return '%s ASC' % (val_unit_str)
elif 'desc' in ctr_name and 'limit' in ctr_name:
return '%s DESC LIMIT 1' % (val_unit_str)
else:
return '%s DESC' % (val_unit_str)

View File

@ -0,0 +1,66 @@
#coding=utf8
from asdl.sql.unparser.unparser_base import UnParser
from asdl.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
class UnParserV2(UnParser):
def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs):
select_ast = select_field.value
select_list = select_ast.fields
select_items = []
for val_unit_field in select_list:
val_unit_str = self.unparse_val_unit(val_unit_field.value, db, *args, **kargs)
select_items.append(val_unit_str)
return ' , '.join(select_items)
def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs):
from_ast = from_field.value
ctr_name = from_ast.production.constructor.name
if 'Table' in ctr_name:
tab_names = []
for tab_field in from_ast.fields:
tab_name = db['table_names_original'][int(tab_field.value)]
tab_names.append(tab_name)
return ' JOIN '.join(tab_names)
else:
return '( ' + self.unparse_sql(from_ast.fields[0].value, db, *args, **kargs) + ' )'
def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs):
groupby_ast = groupby_field.value
ctr_name = groupby_ast.production.constructor.name
groupby_str = []
num = len(groupby_ast.fields) if 'NoHaving' in ctr_name else len(groupby_ast.fields) - 1
for col_id_field in groupby_ast.fields[:num]:
# col_id = int(col_id_field.value)
# tab_id, col_name = db['column_names_original'][col_id]
# if col_id != 0:
# tab_name = db['table_names_original'][tab_id]
# col_name = tab_name + '.' + col_name
col_name = self.unparse_col_unit(col_id_field.value, db, *args, **kargs)
groupby_str.append(col_name)
groupby_str = ' , '.join(groupby_str)
if 'NoHaving' in ctr_name:
return groupby_str
else:
having = groupby_ast.fields[-1].value
having_str = self.unparse_conds(having, db, *args, **kargs)
return groupby_str + ' HAVING ' + having_str
def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs):
orderby_ast = orderby_field.value
ctr_name = orderby_ast.production.constructor.name.lower()
val_unit_str = []
for val_unit_field in orderby_ast.fields:
val_unit_ast = val_unit_field.value
val_unit_str.append(self.unparse_col_unit(val_unit_ast, db, *args, **kargs))
# val_unit_str.append(self.unparse_val_unit(val_unit_ast, db, *args, **kargs))
val_unit_str = ' , '.join(val_unit_str)
if 'asc' in ctr_name and 'limit' in ctr_name:
return '%s ASC LIMIT 1' % (val_unit_str)
elif 'asc' in ctr_name:
return '%s ASC' % (val_unit_str)
elif 'desc' in ctr_name and 'limit' in ctr_name:
return '%s DESC LIMIT 1' % (val_unit_str)
else:
return '%s DESC' % (val_unit_str)

View File

@ -0,0 +1,136 @@
# coding=utf-8
class Action(object):
pass
class ApplyRuleAction(Action):
def __init__(self, production):
self.production = production
def __hash__(self):
return hash(self.production)
def __eq__(self, other):
return isinstance(other, ApplyRuleAction) and self.production == other.production
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return 'ApplyRule[%s]' % self.production.__repr__()
class GenTokenAction(Action):
def __init__(self, token):
self.token = token
def is_stop_signal(self):
return self.token == '</primitive>'
def __repr__(self):
return 'GenToken[%s]' % self.token
class ReduceAction(Action):
def __repr__(self):
return 'Reduce'
class TransitionSystem(object):
def __init__(self, grammar):
self.grammar = grammar
def get_actions(self, asdl_ast):
"""
generate action sequence given the ASDL Syntax Tree
"""
actions = []
parent_action = ApplyRuleAction(asdl_ast.production)
actions.append(parent_action)
for field in asdl_ast.fields:
# is a composite field
if self.grammar.is_composite_type(field.type):
if field.cardinality == 'single':
field_actions = self.get_actions(field.value)
else:
field_actions = []
if field.value is not None:
if field.cardinality == 'multiple':
for val in field.value:
cur_child_actions = self.get_actions(val)
field_actions.extend(cur_child_actions)
elif field.cardinality == 'optional':
field_actions = self.get_actions(field.value)
# if an optional field is filled, then do not need Reduce action
if field.cardinality == 'multiple' or field.cardinality == 'optional' and not field_actions:
field_actions.append(ReduceAction())
else: # is a primitive field
field_actions = self.get_primitive_field_actions(field)
# if an optional field is filled, then do not need Reduce action
if field.cardinality == 'multiple' or field.cardinality == 'optional' and not field_actions:
# reduce action
field_actions.append(ReduceAction())
actions.extend(field_actions)
return actions
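    # Illustrative sketch (assumption, not part of the original file): for a small AST the
    # returned sequence looks like
    #   [ApplyRule[<root production>], ApplyRule[<child production>], <primitive action>, Reduce, ...]
    # i.e. one ApplyRule per composite node in pre-order, the actions from
    # get_primitive_field_actions for each primitive field, and a ReduceAction closing
    # every 'multiple' field and every unfilled 'optional' field.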
def tokenize_code(self, code, mode):
raise NotImplementedError
def compare_ast(self, hyp_ast, ref_ast):
raise NotImplementedError
def ast_to_surface_code(self, asdl_ast):
raise NotImplementedError
def surface_code_to_ast(self, code):
raise NotImplementedError
def get_primitive_field_actions(self, realized_field):
raise NotImplementedError
def get_valid_continuation_types(self, hyp):
if hyp.tree:
if self.grammar.is_composite_type(hyp.frontier_field.type):
if hyp.frontier_field.cardinality == 'single':
return ApplyRuleAction,
else: # optional, multiple
return ApplyRuleAction, ReduceAction
else:
if hyp.frontier_field.cardinality == 'single':
return GenTokenAction,
elif hyp.frontier_field.cardinality == 'optional':
if hyp._value_buffer:
return GenTokenAction,
else:
return GenTokenAction, ReduceAction
else:
return GenTokenAction, ReduceAction
else:
return ApplyRuleAction,
def get_valid_continuating_productions(self, hyp):
if hyp.tree:
if self.grammar.is_composite_type(hyp.frontier_field.type):
return self.grammar[hyp.frontier_field.type]
else:
raise ValueError
else:
return self.grammar[self.grammar.root_type]
@staticmethod
def get_class_by_lang(lang):
if lang == 'sql':
from asdl.sql.sql_transition_system import SQLTransitionSystem
else:
raise ValueError('unknown language %s' % lang)
return SQLTransitionSystem

71
proton/eval.py Normal file
View File

@ -0,0 +1,71 @@
#coding=utf8
import sys, os, json, pickle, argparse, time, torch
from argparse import Namespace
# These three lines are used in the docker environment rhythmcao/text2sql:v2.0
# os.environ['NLTK_DATA'] = os.path.join(os.path.sep, 'root', 'nltk_data')
# os.environ["STANZA_RESOURCES_DIR"] = os.path.join(os.path.sep, 'root', 'stanza_resources')
# os.environ['EMBEDDINGS_ROOT'] = os.path.join(os.path.sep, 'root', '.embeddings')
from preprocess.process_dataset import process_tables, process_dataset
from preprocess.process_graphs import process_dataset_graph
from preprocess.common_utils import Preprocessor
from preprocess.graph_utils import GraphProcessor
from utils.example import Example
from utils.batch import Batch
from model.model_utils import Registrable
from model.model_constructor import *
def preprocess_database_and_dataset(db_dir='database/', table_path='data/tables.json', dataset_path='data/dev.json', method='lgesql'):
tables = json.load(open(table_path, 'r'))
dataset = json.load(open(dataset_path, 'r'))
processor = Preprocessor(db_dir=db_dir, db_content=True)
output_tables = process_tables(processor, tables)
output_dataset = process_dataset(processor, dataset, output_tables)
graph_processor = GraphProcessor()
output_dataset = process_dataset_graph(graph_processor, output_dataset, output_tables, method=method)
return output_dataset, output_tables
def load_examples(dataset, tables):
ex_list = []
for ex in dataset:
ex_list.append(Example(ex, tables[ex['db_id']]))
return ex_list
parser = argparse.ArgumentParser()
parser.add_argument('--db_dir', default='database', help='path to db dir')
parser.add_argument('--table_path', default='data/tables.json', help='path to tables json file')
parser.add_argument('--dataset_path', default='data/dev.json', help='path to raw dataset json file')
parser.add_argument('--saved_model', default='saved_models/glove42B', help='path to the saved model directory, which must contain at least params.json and model.bin')
parser.add_argument('--output_path', default='predicted_sql.txt', help='output predicted sql file')
parser.add_argument('--batch_size', default=20, type=int, help='batch size for evaluation')
parser.add_argument('--beam_size', default=5, type=int, help='beam search size')
parser.add_argument('--use_gpu', action='store_true', help='whether use gpu')
args = parser.parse_args(sys.argv[1:])
params = json.load(open(os.path.join(args.saved_model, 'params.json'), 'r'), object_hook=lambda d: Namespace(**d))
params.lazy_load = True # load PLM from AutoConfig instead of AutoModel.from_pretrained(...)
dataset, tables = preprocess_database_and_dataset(db_dir=args.db_dir, table_path=args.table_path, dataset_path=args.dataset_path, method=params.model)
Example.configuration(plm=params.plm, method=params.model, tables=tables, table_path=args.table_path, db_dir=args.db_dir)
dataset = load_examples(dataset, tables)
device = torch.device("cuda:0") if torch.cuda.is_available() and args.use_gpu else torch.device("cpu")
model = Registrable.by_name('text2sql')(params, Example.trans).to(device)
check_point = torch.load(open(os.path.join(args.saved_model, 'model.bin'), 'rb'), map_location=device)
model.load_state_dict(check_point['model'])
start_time = time.time()
print('Start evaluating ...')
model.eval()
all_hyps = []
with torch.no_grad():
for i in range(0, len(dataset), args.batch_size):
current_batch = Batch.from_example_list(dataset[i: i + args.batch_size], device, train=False)
hyps = model.parse(current_batch, args.beam_size)
all_hyps.extend(hyps)
with open(args.output_path, 'w', encoding='utf8') as of:
evaluator = Example.evaluator
for idx, hyp in enumerate(all_hyps):
pred_sql = evaluator.obtain_sql(hyp, dataset[idx].db)
# best_ast = hyp[0].tree # by default, the top beam prediction
# pred_sql = Example.trans.ast_to_surface_code(best_ast, dataset[idx].db)
of.write(pred_sql + '\n')
print('Evaluation costs %.4fs .' % (time.time() - start_time))
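# Example invocation (illustrative; paths and the saved-model directory are placeholders):
#   python eval.py --db_dir database --table_path data/tables.json \
#       --dataset_path data/dev.json --saved_model saved_models/electra \
#       --output_path predicted_sql.txt --batch_size 20 --beam_size 5 --use_gpu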

877
proton/evaluation.py Normal file
View File

@ -0,0 +1,877 @@
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
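# Hedged, illustrative example (not from the dataset): the parsed form of
#   SELECT count(*) FROM concert
# roughly looks like the nested dict below, where agg_id indexes AGG_OPS (3 = count)
# and column/table ids are the normalized keys produced by Schema.idMap:
#   {
#     'select': (False, [(3, (0, (0, '__all__', False), None))]),
#     'from': {'table_units': [('table_unit', '__concert__')], 'conds': []},
#     'where': [], 'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#     'intersect': None, 'union': None, 'except': None
#   }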
import os, sys
import json
import sqlite3
import traceback
import argparse
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]] # val1 is also considered
label_wo_agg = [unit[2] for unit in label_conds] # val_unit
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for _, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
if label['from']['table_units'][0][0] == 'sql' and pred['from']['table_units'][0][0] == 'sql':
return self.eval_exact_match(pred['from']['table_units'][0][1], label['from']['table_units'][0][1]) # still wrong
else:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
return label_tables == pred_tables
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql)
except:
return False
return True
def print_scores(scores, etype):
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
def evaluate(gold, predict, db_dir, etype, kmaps):
with open(gold) as f:
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict) as f:
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
evaluator = Evaluator()
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
eval_err_num = 0
for p, g in zip(plist, glist):
p_str = p[0]
g_str, db = g
db_name = db
db = os.path.join(db_dir, db, db + ".sqlite")
schema = Schema(get_schema(db))
# .schema: map lowercased raw tab name to lowercased raw col name list
# .idMap: map tab name to __tab__, tab.col to __tab.col__, * to __all__, all lowercased
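        # e.g. (illustrative) schema.idMap: {'*': '__all__', 'concert': '__concert__',
        #                                    'concert.year': '__concert.year__', ...}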
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
scores[hardness]['count'] += 1
scores['all']['count'] += 1
try:
p_sql = get_sql(schema, p_str)
except:
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
p_sql = {
"except": None,
"from": {
"conds": [],
"table_units": []
},
"groupBy": [],
"having": [],
"intersect": None,
"limit": None,
"orderBy": [],
"select": [
False,
[]
],
"union": None,
"where": []
}
eval_err_num += 1
print("eval_err_num:{}".format(eval_err_num))
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        # extract all __tab.col__ whose table appears in the FROM clause, excluding __all__
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) # kmap: map __tab.col__ to pivot __tab.col__
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if exec_score:
scores[hardness]['exec'] += 1
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
return scores
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
    Return 1 if the execution results of the predicted and gold queries match
    column by column. Currently does not support multiple col_unit pairs.
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(p_str)
p_res = cursor.fetchall()
except:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, val_units):
rmap = {}
for idx, val_unit in enumerate(val_units):
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
rmap[key] = [r[idx] for r in res]
return rmap
p_val_units = [unit[1] for unit in pred['select'][1]]
q_val_units = [unit[1] for unit in gold['select'][1]]
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
if len(sql['from']['table_units']) > 0 and sql['from']['table_units'][0][0] == 'sql':
sql['from']['table_units'][0] = ('sql', rebuild_sql_val(sql['from']['table_units'][0][1]))
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixs = [col_id[:-2] for col_id in col_ids]
valid_col_units= []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixs:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
def build_foreign_key_map_from_json(table):
with open(table) as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gold', dest='gold', type=str)
parser.add_argument('--pred', dest='pred', type=str)
parser.add_argument('--db', dest='db', type=str)
parser.add_argument('--table', dest='table', type=str)
parser.add_argument('--etype', dest='etype', type=str)
args = parser.parse_args()
gold = args.gold
pred = args.pred
db_dir = args.db
table = args.table
etype = args.etype
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
kmaps = build_foreign_key_map_from_json(table)
"""
foreign key map(dict): str -> str
map each lowercased column to a pivot column in its cluster set
normalized column: __all__, __t.lower() + . + c.lower()__
"""
evaluate(gold, pred, db_dir, etype, kmaps)

View File

@ -0,0 +1,215 @@
#coding=utf8
""" ONLSTM and traditional LSTM with locked dropout """
import torch
import torch.nn as nn
import torch.nn.functional as F
def cumsoftmax(x, dim=-1):
return torch.cumsum(F.softmax(x, dim=dim), dim=dim)
class LinearDropConnect(nn.Linear):
""" Used in recurrent connection dropout """
def __init__(self, in_features, out_features, bias=True, dropconnect=0.):
super(LinearDropConnect, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
self.dropconnect = dropconnect
def sample_mask(self):
if self.dropconnect == 0.:
self._weight = self.weight.clone()
else:
mask = self.weight.new_zeros(self.weight.size(), dtype=torch.bool)
mask.bernoulli_(self.dropconnect)
self._weight = self.weight.masked_fill(mask, 0.)
def forward(self, inputs, sample_mask=False):
if self.training:
if sample_mask:
self.sample_mask()
return F.linear(inputs, self._weight, self.bias) # apply the same mask to weight matrix in linear module
else:
return F.linear(inputs, self.weight * (1 - self.dropconnect), self.bias)
class LockedDropout(nn.Module):
""" Used in dropout between layers """
def __init__(self, hidden_size, num_layers=1, dropout=0.2):
super(LockedDropout, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.dropout = dropout
def sample_masks(self, x):
self.masks = []
for _ in range(self.num_layers - 1):
mask = x.new_zeros(x.size(0), 1, self.hidden_size).bernoulli_(1 - self.dropout)
mask = mask.div_(1 - self.dropout)
mask.requires_grad = False
self.masks.append(mask)
def forward(self, x, layer=0):
""" x: bsize x seqlen x hidden_size """
if (not self.training) or self.dropout == 0. or layer == self.num_layers - 1: # output hidden states, no dropout
return x
mask = self.masks[layer]
mask = mask.expand_as(x)
return mask * x
class RecurrentNeuralNetwork(nn.Module):
def init_hiddens(self, x):
return x.new_zeros(self.num_layers, x.size(0), self.hidden_size), \
x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
def forward(self, inputs, hiddens=None, start=False, layerwise=False):
"""
@args:
            start: whether to sample locked dropout masks for the recurrent connections and between layers
            layerwise: whether to additionally return a list of intermediate layer outputs
@return:
outputs: bsize x seqlen x hidden_size
final_hiddens: hT and cT, each of size: num_layers x bsize x hidden_size
"""
assert inputs.dim() == 3
if hiddens is None:
hiddens = self.init_hiddens(inputs)
bsize, seqlen, _ = list(inputs.size())
prev_state = list(hiddens) # each of size: num_layers, bsize, hidden_size
prev_layer = inputs # size: bsize, seqlen, input_size
each_layer_outputs, final_h, final_c = [], [], []
if self.training and start:
for c in self.cells:
c.sample_masks()
self.locked_dropout.sample_masks(inputs)
for l in range(len(self.cells)):
curr_layer = [None] * seqlen
curr_inputs = self.cells[l].ih(prev_layer)
next_h, next_c = prev_state[0][l], prev_state[1][l]
for t in range(seqlen):
hidden, cell = self.cells[l](None, (next_h, next_c), transformed_inputs=curr_inputs[:, t])
next_h, next_c = hidden, cell # overwritten every timestep
curr_layer[t] = hidden
prev_layer = torch.stack(curr_layer, dim=1) # bsize x seqlen x hidden_size
each_layer_outputs.append(prev_layer)
final_h.append(next_h)
final_c.append(next_c)
prev_layer = self.locked_dropout(prev_layer, layer=l)
outputs, final_hiddens = prev_layer, (torch.stack(final_h, dim=0), torch.stack(final_c, dim=0))
if layerwise:
return outputs, final_hiddens, each_layer_outputs
else:
return outputs, final_hiddens
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, bias=True, dropconnect=0.):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.ih = nn.Linear(input_size, hidden_size * 4, bias=bias)
self.hh = LinearDropConnect(hidden_size, hidden_size * 4, bias=bias, dropconnect=dropconnect)
self.drop_weight_modules = [self.hh]
def sample_masks(self):
for m in self.drop_weight_modules:
m.sample_mask()
def forward(self, inputs, hiddens, transformed_inputs=None):
"""
inputs: bsize x input_size
hiddens: tuple of h0 (bsize x hidden_size) and c0 (bsize x hidden_size)
        transformed_inputs: shortcut for the transformed inputs, saves time if the whole sequence has already been transformed during training
return tuple of h1 (bsize x hidden_size) and c1 (bsize x hidden_size)
"""
if transformed_inputs is None:
transformed_inputs = self.ih(inputs)
h0, c0 = hiddens
gates = transformed_inputs + self.hh(h0)
ingate, forgetgate, outgate, cell = gates.contiguous().\
view(-1, 4, self.hidden_size).chunk(4, 1)
forgetgate = torch.sigmoid(forgetgate.squeeze(1))
ingate = torch.sigmoid(ingate.squeeze(1))
cell = torch.tanh(cell.squeeze(1))
outgate = torch.sigmoid(outgate.squeeze(1))
c1 = forgetgate * c0 + ingate * cell
h1 = outgate * torch.tanh(c1)
return h1, c1
class LSTM(RecurrentNeuralNetwork):
def __init__(self, input_size, hidden_size, num_layers=1, chunk_num=1, bias=True, dropout=0., dropconnect=0.):
super(LSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.cells = nn.ModuleList(
[LSTMCell(input_size, hidden_size, bias, dropconnect)] +
[LSTMCell(hidden_size, hidden_size, bias, dropconnect) for i in range(num_layers - 1)]
)
self.locked_dropout = LockedDropout(hidden_size, num_layers, dropout) # dropout rate between layers
class ONLSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, chunk_num=8, bias=True, dropconnect=0.2):
super(ONLSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
        self.chunk_num = chunk_num # hidden_size must be divisible by chunk_num
        if self.hidden_size % self.chunk_num != 0:
            raise ValueError('[Error]: hidden size must be divisible by chunk number in ONLSTM Cell')
self.chunk_size = int(hidden_size / chunk_num)
self.ih = nn.Linear(input_size, self.chunk_size * 2 + hidden_size * 4, bias=bias)
self.hh = LinearDropConnect(hidden_size, self.chunk_size * 2 + hidden_size * 4, bias=bias, dropconnect=dropconnect)
self.drop_weight_modules = [self.hh]
def sample_masks(self):
for m in self.drop_weight_modules:
m.sample_mask()
def forward(self, inputs, hiddens, transformed_inputs=None):
"""
inputs: bsize x input_size
hiddens: tuple of h0 (bsize x hidden_size) and c0 (bsize x hidden_size)
        transformed_inputs: shortcut for the transformed inputs, saves time if the whole sequence has already been transformed during training
return tuple of h1 (bsize x hidden_size) and c1 (bsize x hidden_size)
"""
if transformed_inputs is None:
transformed_inputs = self.ih(inputs)
h0, c0 = hiddens
gates = transformed_inputs + self.hh(h0)
cingate, cforgetgate = gates[:, :self.chunk_size * 2].chunk(2, 1)
ingate, forgetgate, outgate, cell = gates[:, self.chunk_size * 2:].contiguous().\
view(-1, self.chunk_size * 4, self.chunk_num).chunk(4, 1)
cingate = 1. - cumsoftmax(cingate)
cforgetgate = cumsoftmax(cforgetgate)
cingate = cingate[:, :, None]
cforgetgate = cforgetgate[:, :, None]
forgetgate = torch.sigmoid(forgetgate)
ingate = torch.sigmoid(ingate)
cell = torch.tanh(cell)
outgate = torch.sigmoid(outgate)
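# combine master and standard gates: within the overlap of the two master gates the standard gates take effect; outside it, cells are either copied from c0 or overwritten by the new candidate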
overlap = cforgetgate * cingate
forgetgate = forgetgate * overlap + (cforgetgate - overlap)
ingate = ingate * overlap + (cingate - overlap)
c0 = c0.contiguous().view(-1, self.chunk_size, self.chunk_num)
c1 = forgetgate * c0 + ingate * cell
h1 = outgate * torch.tanh(c1)
return h1.contiguous().view(-1, self.hidden_size), c1.contiguous().view(-1, self.hidden_size)
class ONLSTM(RecurrentNeuralNetwork):
def __init__(self, input_size, hidden_size, num_layers=1, chunk_num=8, bias=True, dropout=0., dropconnect=0.):
super(ONLSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.cells = nn.ModuleList(
[ONLSTMCell(input_size, hidden_size, chunk_num, bias, dropconnect)] +
[ONLSTMCell(hidden_size, hidden_size, chunk_num, bias, dropconnect) for i in range(num_layers - 1)]
)
self.locked_dropout = LockedDropout(hidden_size, num_layers, dropout) # dropout rate between layers

View File

@ -0,0 +1,397 @@
# coding=utf-8
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from asdl.sql.sql_transition_system import SelectColumnAction, SelectTableAction
from asdl.transition_system import ApplyRuleAction, ReduceAction
from asdl.action_info import ActionInfo
from asdl.hypothesis import Hypothesis
from asdl.decode_hypothesis import DecodeHypothesis
from utils.batch import Batch
from model.decoder.onlstm import LSTM, ONLSTM
from model.model_utils import Registrable, MultiHeadAttention
@Registrable.register('decoder_tranx')
class SqlParser(nn.Module):
def __init__(self, args, transition_system):
super(SqlParser, self).__init__()
self.args = args
self.transition_system = transition_system
self.grammar = self.transition_system.grammar
# the last entry is the embedding for Reduce action
self.production_embed = nn.Embedding(len(transition_system.grammar) + 1, args.action_embed_size)
# embedding table for ASDL fields in constructors
self.field_embed = nn.Embedding(len(transition_system.grammar.fields), args.field_embed_size)
# embedding table for ASDL types
self.type_embed = nn.Embedding(len(transition_system.grammar.types), args.type_embed_size)
# input of decoder lstm
input_dim = args.action_embed_size # previous action
# parent node
cxt_num = 2 if args.sep_cxt else 1
input_dim += args.gnn_hidden_size * cxt_num * int(not args.no_context_feeding)
input_dim += args.action_embed_size * int(not args.no_parent_production_embed)
input_dim += args.field_embed_size * int(not args.no_parent_field_embed)
input_dim += args.type_embed_size * int(not args.no_parent_field_type_embed)
input_dim += args.lstm_hidden_size * int(not args.no_parent_state)
cell_constructor = ONLSTM if args.lstm == 'onlstm' else LSTM
self.decoder_lstm = cell_constructor(input_dim, args.lstm_hidden_size, num_layers=args.lstm_num_layers,
chunk_num=args.chunk_size, dropout=args.dropout, dropconnect=args.drop_connect)
# transform column embedding to production embedding space
self.column_lstm_input = nn.Linear(args.gnn_hidden_size, args.action_embed_size)
# transform table embedding to production embedding space
self.table_lstm_input = self.column_lstm_input # nn.Linear(args.gnn_hidden_size, args.action_embed_size)
self.context_attn = MultiHeadAttention(args.gnn_hidden_size, args.lstm_hidden_size, args.gnn_hidden_size, args.gnn_hidden_size,
num_heads=args.num_heads, feat_drop=args.dropout)
if args.sep_cxt: # calculate separate context vectors for the question and the database schema
self.schema_attn = MultiHeadAttention(args.gnn_hidden_size, args.lstm_hidden_size, args.gnn_hidden_size, args.gnn_hidden_size,
num_heads=args.num_heads, feat_drop=args.dropout)
# feature vector before ApplyRule or SelectColumn/Table
self.att_vec_linear = nn.Sequential(nn.Linear(args.lstm_hidden_size + cxt_num * args.gnn_hidden_size, args.att_vec_size), nn.Tanh())
self.apply_rule_affine = nn.Linear(args.att_vec_size, args.action_embed_size, bias=False)
self.apply_rule = lambda x: F.linear(x, self.production_embed.weight) # re-use the action embedding matrix
self.select_column = MultiHeadAttention(args.gnn_hidden_size, args.att_vec_size, args.gnn_hidden_size, args.gnn_hidden_size,
num_heads=args.num_heads, feat_drop=args.dropout)
self.select_table = MultiHeadAttention(args.gnn_hidden_size, args.att_vec_size, args.gnn_hidden_size, args.gnn_hidden_size,
num_heads=args.num_heads, feat_drop=args.dropout)
def score(self, encodings, mask, h0, batch):
""" Training function
@input:
encodings: encoded representations and mask matrix from encoder
bsize x seqlen x gnn_hidden_size
batch: see utils.batch, we use fields
batch.examples, batch.get_frontier_prod_idx(t),
batch.get_frontier_field_idx(t), batch.get_frontier_field_type_idx(t),
batch.max_action_num, example.tgt_action (ActionInfo)
@output:
loss: loss summed over all decoding actions in the training batch
"""
args = self.args
zero_action_embed = encodings.new_zeros(args.action_embed_size)
split_len = [batch.max_question_len, batch.max_table_len, batch.max_column_len]
q, tab, col = encodings.split(split_len, dim=1)
q_mask, tab_mask, col_mask = mask.split(split_len, dim=1)
if args.sep_cxt:
encodings, mask = q, q_mask
schema, schema_mask = torch.cat([tab, col], dim=1), torch.cat([tab_mask, col_mask], dim=1)
schema_context, _ = self.schema_attn(schema, h0, schema_mask)
context, _ = self.context_attn(encodings, h0, mask)
h0 = h0.unsqueeze(0).repeat(args.lstm_num_layers, 1, 1)
# h0 = h0.new_zeros(h0.size()) # init decoder with 0-vector
h_c = (h0, h0.new_zeros(h0.size()))
action_probs = [[] for _ in range(encodings.size(0))]
history_states = []
for t in range(batch.max_action_num):
# x: [prev_action_embed, [prev_context, parent_production, parent_field, parent_type, parent_state]]
if t == 0:
x = encodings.new_zeros(encodings.size(0), self.decoder_lstm.input_size)
offset = args.action_embed_size
if not args.no_context_feeding:
x[:, offset: offset + args.gnn_hidden_size] = context
offset += args.gnn_hidden_size
if args.sep_cxt:
x[:, offset: offset + args.gnn_hidden_size] = schema_context
offset += args.gnn_hidden_size
offset += args.action_embed_size * int(not args.no_parent_production_embed)
offset += args.field_embed_size * int(not args.no_parent_field_embed)
if not args.no_parent_field_type_embed:
start_rule = torch.tensor([self.grammar.type2id[self.grammar.root_type]] * x.size(0), dtype=torch.long, device=x.device)
x[:, offset: offset + args.type_embed_size] = self.type_embed(start_rule)
else:
prev_action_embed = []
for e_id, example in enumerate(batch.examples):
if t < len(example.tgt_action):
prev_action = example.tgt_action[t - 1].action
if isinstance(prev_action, ApplyRuleAction):
prev_action_embed.append(self.production_embed.weight[self.grammar.prod2id[prev_action.production]])
elif isinstance(prev_action, ReduceAction):
prev_action_embed.append(self.production_embed.weight[len(self.grammar)])
elif isinstance(prev_action, SelectColumnAction):
# map schema item into prod embed space
col_embed = self.column_lstm_input(col[e_id, prev_action.column_id])
prev_action_embed.append(col_embed)
elif isinstance(prev_action, SelectTableAction):
tab_embed = self.table_lstm_input(tab[e_id, prev_action.table_id])
prev_action_embed.append(tab_embed)
else:
raise ValueError('Unrecognized previous action object!')
else:
prev_action_embed.append(zero_action_embed)
inputs = [torch.stack(prev_action_embed)]
if not args.no_context_feeding:
inputs.append(context)
if args.sep_cxt:
inputs.append(schema_context)
if not args.no_parent_production_embed:
inputs.append(self.production_embed(batch.get_frontier_prod_idx(t)))
if not args.no_parent_field_embed:
inputs.append(self.field_embed(batch.get_frontier_field_idx(t)))
if not args.no_parent_field_type_embed:
inputs.append(self.type_embed(batch.get_frontier_field_type_idx(t)))
if not args.no_parent_state:
actions_t = [e.tgt_action[t] if t < len(e.tgt_action) else None for e in batch.examples]
parent_state = torch.stack([history_states[p_t][e_id]
for e_id, p_t in
enumerate([a_t.parent_t if a_t else 0 for a_t in actions_t])])
inputs.append(parent_state)
x = torch.cat(inputs, dim=-1)
# advance decoder lstm and attention calculation
out, (h_t, c_t) = self.decoder_lstm(x.unsqueeze(1), h_c, start=(t==0))
out = out.squeeze(1)
context, _ = self.context_attn(encodings, out, mask)
if args.sep_cxt:
schema_context, _ = self.schema_attn(schema, out, schema_mask)
att_vec = self.att_vec_linear(torch.cat([out, context, schema_context], dim=-1)) # bsize x args.att_vec_size
else:
att_vec = self.att_vec_linear(torch.cat([out, context], dim=-1))
# action logprobs
apply_rule_logprob = F.log_softmax(self.apply_rule(self.apply_rule_affine(att_vec)), dim=-1) # bsize x prod_num
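# attention weights over tables/columns directly serve as selection distributions; a tiny epsilon is added before the log for numerical stability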
_, select_tab_prob = self.select_table(tab, att_vec, tab_mask)
select_tab_logprob = torch.log(select_tab_prob + 1e-32)
_, select_col_prob = self.select_column(col, att_vec, col_mask)
select_col_logprob = torch.log(select_col_prob + 1e-32)
for e_id, example in enumerate(batch.examples):
if t < len(example.tgt_action):
action_t = example.tgt_action[t].action
if isinstance(action_t, ApplyRuleAction):
logprob_t = apply_rule_logprob[e_id, self.grammar.prod2id[action_t.production]]
# print('Rule %s with prob %s' %(action_t.production, logprob_t.item()))
elif isinstance(action_t, ReduceAction):
logprob_t = apply_rule_logprob[e_id, len(self.grammar)]
# print('Rule %s with prob %s' % ('Reduce', logprob_t.item()))
elif isinstance(action_t, SelectColumnAction):
logprob_t = select_col_logprob[e_id, action_t.column_id]
# print('SelectColumn %s with prob %s' % (action_t.column_id, logprob_t.item()))
elif isinstance(action_t, SelectTableAction):
logprob_t = select_tab_logprob[e_id, action_t.table_id]
# print('SelectTable %s with prob %s' % (action_t.table_id, logprob_t.item()))
else:
raise ValueError('Unrecognized action object!')
action_probs[e_id].append(logprob_t)
h_c = (h_t, c_t)
history_states.append(h_t[-1])
# loss is negative sum of all the action probabilities
loss = - torch.stack([torch.stack(logprob_i).sum() for logprob_i in action_probs]).sum()
return loss
def parse(self, encodings, mask, h0, batch, beam_size=5):
""" Parse one by one, batch size for each args in encodings is 1
"""
args = self.args
zero_action_embed = encodings.new_zeros(args.action_embed_size)
assert encodings.size(0) == 1 and mask.size(0) == 1
encodings, mask = encodings.repeat(beam_size, 1, 1), mask.repeat(beam_size, 1)
split_len = [batch.max_question_len, batch.max_table_len, batch.max_column_len]
q, tab, col = encodings.split(split_len, dim=1)
q_mask, tab_mask, col_mask = mask.split(split_len, dim=1)
if args.sep_cxt:
encodings, mask = q, q_mask
schema, schema_mask = torch.cat([tab, col], dim=1), torch.cat([tab_mask, col_mask], dim=1)
schema_context, _ = self.schema_attn(schema[:1], h0, schema_mask[:1])
context, _ = self.context_attn(encodings[:1], h0, mask[:1])
h0 = h0.unsqueeze(0).repeat(args.lstm_num_layers, 1, 1)
# h0 = h0.new_zeros(h0.size()) # init decoder with 0-vector
h_c = (h0, h0.new_zeros(h0.size()))
hypotheses = [DecodeHypothesis()]
hyp_states, completed_hypotheses, t = [[]], [], 0
while t < args.decode_max_step:
hyp_num = len(hypotheses)
cur_encodings, cur_mask = encodings[:hyp_num], mask[:hyp_num]
if args.sep_cxt:
cur_schema, cur_schema_mask = schema[:hyp_num], schema_mask[:hyp_num]
cur_tab, cur_tab_mask, cur_col, cur_col_mask = tab[:hyp_num], tab_mask[:hyp_num], col[:hyp_num], col_mask[:hyp_num]
# x: [prev_action_embed, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_state]
if t == 0:
x = encodings.new_zeros(hyp_num, self.decoder_lstm.input_size)
offset = args.action_embed_size
if not args.no_context_feeding:
x[:, offset: offset + args.gnn_hidden_size] = context
offset += args.gnn_hidden_size
if args.sep_cxt:
x[:, offset: offset + args.gnn_hidden_size] = schema_context
offset += args.gnn_hidden_size
offset += args.action_embed_size * int(not args.no_parent_production_embed)
offset += args.field_embed_size * int(not args.no_parent_field_embed)
if not args.no_parent_field_type_embed:
start_rule = torch.tensor([self.grammar.type2id[self.grammar.root_type]] * hyp_num, dtype=torch.long, device=x.device)
x[:, offset: offset + args.type_embed_size] = self.type_embed(start_rule)
else:
prev_action_embed = []
for e_id, hyp in enumerate(hypotheses):
prev_action = hyp.actions[-1]
if isinstance(prev_action, ApplyRuleAction):
prev_action_embed.append(self.production_embed.weight[self.grammar.prod2id[prev_action.production]])
elif isinstance(prev_action, ReduceAction):
prev_action_embed.append(self.production_embed.weight[len(self.grammar)])
elif isinstance(prev_action, SelectColumnAction): # need to first map schema item into prod embed space
col_embed = self.column_lstm_input(cur_col[e_id, prev_action.column_id])
prev_action_embed.append(col_embed)
elif isinstance(prev_action, SelectTableAction):
tab_embed = self.table_lstm_input(cur_tab[e_id, prev_action.table_id])
prev_action_embed.append(tab_embed)
else:
raise ValueError('Unrecognized previous action object!')
inputs = [torch.stack(prev_action_embed)]
if not args.no_context_feeding:
inputs.append(context)
if args.sep_cxt:
inputs.append(schema_context)
if not args.no_parent_production_embed:
frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
frontier_prod_embeds = self.production_embed(
torch.tensor([self.grammar.prod2id[prod] for prod in frontier_prods], dtype=torch.long, device=encodings.device))
inputs.append(frontier_prod_embeds)
if not args.no_parent_field_embed:
frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
frontier_field_embeds = self.field_embed(
torch.tensor([self.grammar.field2id[field] for field in frontier_fields], dtype=torch.long, device=encodings.device))
inputs.append(frontier_field_embeds)
if not args.no_parent_field_type_embed:
frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
frontier_field_type_embeds = self.type_embed(
torch.tensor([self.grammar.type2id[tp] for tp in frontier_field_types], dtype=torch.long, device=encodings.device))
inputs.append(frontier_field_type_embeds)
if not args.no_parent_state:
p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
parent_states = torch.stack([hyp_states[hyp_id][p_t] for hyp_id, p_t in enumerate(p_ts)])
inputs.append(parent_states)
x = torch.cat(inputs, dim=-1) # hyp_num x input_size
# advance decoder lstm and attention calculation
out, (h_t, c_t) = self.decoder_lstm(x.unsqueeze(1), h_c)
out = out.squeeze(1)
context, _ = self.context_attn(cur_encodings, out, cur_mask)
if args.sep_cxt:
schema_context, _ = self.schema_attn(cur_schema, out, cur_schema_mask)
att_vec = self.att_vec_linear(torch.cat([out, context, schema_context], dim=-1)) # hyp_num x args.att_vec_size
else:
att_vec = self.att_vec_linear(torch.cat([out, context], dim=-1))
# action logprobs
apply_rule_logprob = F.log_softmax(self.apply_rule(self.apply_rule_affine(att_vec)), dim=-1) # remaining_sents*beam x prod_num
_, select_tab_prob = self.select_table(cur_tab, att_vec, cur_tab_mask)
select_tab_logprob = torch.log(select_tab_prob + 1e-32)
_, select_col_prob = self.select_column(cur_col, att_vec, cur_col_mask)
select_col_logprob = torch.log(select_col_prob + 1e-32)
new_hyp_meta = []
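# enumerate every valid continuation (grammar rule, Reduce, column/table selection) of each live hypothesis; only the top-scoring beam_size candidates are kept below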
for hyp_id, hyp in enumerate(hypotheses):
action_types = self.transition_system.get_valid_continuation_types(hyp)
for action_type in action_types:
if action_type == ApplyRuleAction:
productions = self.transition_system.get_valid_continuating_productions(hyp)
for production in productions:
prod_id = self.grammar.prod2id[production]
prod_score = apply_rule_logprob[hyp_id, prod_id]
new_hyp_score = hyp.score + prod_score
meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id,
'score': prod_score, 'new_hyp_score': new_hyp_score,
'prev_hyp_id': hyp_id}
new_hyp_meta.append(meta_entry)
elif action_type == ReduceAction:
action_score = apply_rule_logprob[hyp_id, len(self.grammar)]
new_hyp_score = hyp.score + action_score
meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar),
'score': action_score, 'new_hyp_score': new_hyp_score,
'prev_hyp_id': hyp_id}
new_hyp_meta.append(meta_entry)
elif action_type == SelectColumnAction:
col_num = cur_col_mask[hyp_id].int().sum().item()
for col_id in range(col_num):
col_sel_score = select_col_logprob[hyp_id, col_id]
new_hyp_score = hyp.score + col_sel_score
meta_entry = {'action_type': 'sel_col', 'col_id': col_id,
'score': col_sel_score, 'new_hyp_score': new_hyp_score,
'prev_hyp_id': hyp_id}
new_hyp_meta.append(meta_entry)
elif action_type == SelectTableAction:
tab_num = cur_tab_mask[hyp_id].int().sum().item()
for tab_id in range(tab_num):
tab_sel_score = select_tab_logprob[hyp_id, tab_id]
new_hyp_score = hyp.score + tab_sel_score
meta_entry = {'action_type': 'sel_tab', 'tab_id': tab_id,
'score': tab_sel_score, 'new_hyp_score': new_hyp_score,
'prev_hyp_id': hyp_id}
new_hyp_meta.append(meta_entry)
else:
raise ValueError('Unrecognized action type while decoding!')
if not new_hyp_meta: break
# rank and pick
new_hyp_scores = torch.stack([x['new_hyp_score'] for x in new_hyp_meta])
top_new_hyp_scores, meta_ids = new_hyp_scores.topk(min(beam_size, new_hyp_scores.size(0)))
live_hyp_ids, new_hypotheses = [], []
for new_hyp_score, meta_id in zip(top_new_hyp_scores.tolist(), meta_ids.tolist()):
action_info = ActionInfo()
hyp_meta_entry = new_hyp_meta[meta_id]
prev_hyp_id = hyp_meta_entry['prev_hyp_id']
prev_hyp = hypotheses[prev_hyp_id]
action_type_str = hyp_meta_entry['action_type']
if action_type_str == 'apply_rule':
prod_id = hyp_meta_entry['prod_id']
if prod_id < len(self.grammar):
production = self.grammar.id2prod[prod_id]
action = ApplyRuleAction(production)
else:
action = ReduceAction()
elif action_type_str == 'sel_col':
action = SelectColumnAction(hyp_meta_entry['col_id'])
elif action_type_str == 'sel_tab':
action = SelectTableAction(hyp_meta_entry['tab_id'])
else:
raise ValueError('Unrecognized action type str!')
action_info.action = action
action_info.t = t
if t > 0:
action_info.parent_t = prev_hyp.frontier_node.created_time
action_info.frontier_prod = prev_hyp.frontier_node.production
action_info.frontier_field = prev_hyp.frontier_field.field
new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
new_hyp.score = new_hyp_score
if new_hyp.completed:
completed_hypotheses.append(new_hyp)
else:
new_hypotheses.append(new_hyp)
live_hyp_ids.append(prev_hyp_id)
if live_hyp_ids:
if not args.no_parent_state:
hyp_states = [hyp_states[i] + [h_t[-1, i]] for i in live_hyp_ids]
h_c = (h_t[:, live_hyp_ids], c_t[:, live_hyp_ids])
if not args.no_context_feeding:
context = context[live_hyp_ids]
if args.sep_cxt:
schema_context = schema_context[live_hyp_ids]
hypotheses = new_hypotheses
t += 1
else:
break
# for idx, hyp in enumerate(new_hypotheses):
# print('Idx %d' % (idx))
# print('Tree:', hyp.tree.to_string())
# print('Action:', hyp.action_infos[-1])
# print('============================================')
# input('********************Next parse step ....************************')
if len(completed_hypotheses) == 0: # no completed sql
completed_hypotheses.append(DecodeHypothesis())
else:
completed_hypotheses.sort(key=lambda hyp: - hyp.score) # / hyp.tree.size
return completed_hypotheses

View File

@ -0,0 +1,33 @@
#coding=utf8
import dgl, math, torch
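# DGL user-defined message/reduce functions shared by the relational attention layers in the encoder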
def src_dot_dst(src_field, dst_field, out_field):
def func(edges):
return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
return func
def src_sum_edge_mul_dst(src_field, dst_field, e_field, out_field):
def func(edges):
return {out_field: ((edges.src[src_field] + edges.data[e_field]) * edges.dst[dst_field]).sum(-1, keepdim=True)}
return func
def scaled_exp(field, scale_constant):
def func(edges):
# clamp for softmax numerical stability
return {field: torch.exp((edges.data[field] / scale_constant).clamp(-10, 10))}
return func
def src_sum_edge_mul_edge(src_field, e_field1, e_field2, out_field):
def func(edges):
return {out_field: (edges.src[src_field] + edges.data[e_field1]) * edges.data[e_field2]}
return func
def div_by_z(in_field, norm_field, out_field):
def func(nodes):
return {out_field: nodes.data[in_field] / nodes.data[norm_field]}
return func

View File

@ -0,0 +1,25 @@
#coding=utf8
import torch
import torch.nn as nn
from model.encoder.graph_input import *
from model.encoder.rgatsql import RGATSQL
from model.encoder.lgesql import LGESQL
from model.encoder.graph_output import *
from model.model_utils import Registrable
@Registrable.register('encoder_text2sql')
class Text2SQLEncoder(nn.Module):
def __init__(self, args):
super(Text2SQLEncoder, self).__init__()
lazy_load = args.lazy_load if hasattr(args, 'lazy_load') else False
self.input_layer = GraphInputLayer(args.embed_size, args.gnn_hidden_size, args.word_vocab, dropout=args.dropout, schema_aggregation=args.schema_aggregation) \
if args.plm is None else GraphInputLayerPLM(args.plm, args.gnn_hidden_size, dropout=args.dropout,
subword_aggregation=args.subword_aggregation, schema_aggregation=args.schema_aggregation, lazy_load=lazy_load)
self.hidden_layer = Registrable.by_name(args.model)(args)
self.output_layer = Registrable.by_name(args.output_model)(args)
def forward(self, batch):
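# three stages: embed question/table/column nodes, propagate over the schema graph, then re-pad (and optionally prune) the outputs for the decoder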
outputs = self.input_layer(batch)
outputs = self.hidden_layer(outputs, batch)
return self.output_layer(outputs, batch)

View File

@ -0,0 +1,153 @@
#coding=utf8
import os, math
import torch
import torch.nn as nn
from model.model_utils import rnn_wrapper, lens2mask, PoolingFunction
from transformers import AutoModel, AutoConfig
class GraphInputLayer(nn.Module):
def __init__(self, embed_size, hidden_size, word_vocab, dropout=0.2, fix_grad_idx=60, schema_aggregation='head+tail'):
super(GraphInputLayer, self).__init__()
self.embed_size = embed_size
self.hidden_size = hidden_size
self.word_vocab = word_vocab
self.fix_grad_idx = fix_grad_idx
self.word_embed = nn.Embedding(self.word_vocab, self.embed_size, padding_idx=0)
self.dropout_layer = nn.Dropout(p=dropout)
self.rnn_layer = InputRNNLayer(self.embed_size, self.hidden_size, cell='lstm', schema_aggregation=schema_aggregation)
def pad_embedding_grad_zero(self, index=None):
self.word_embed.weight.grad[0].zero_() # padding symbol is always 0
if index is not None:
if not torch.is_tensor(index):
index = torch.tensor(index, dtype=torch.long, device=self.word_embed.weight.grad.device)
self.word_embed.weight.grad.index_fill_(0, index, 0.)
else:
self.word_embed.weight.grad[self.fix_grad_idx:].zero_()
def forward(self, batch):
question, table, column = self.word_embed(batch.questions), self.word_embed(batch.tables), self.word_embed(batch.columns)
if batch.question_unk_mask is not None:
question = question.masked_scatter_(batch.question_unk_mask.unsqueeze(-1), batch.question_unk_embeddings[:, :self.embed_size])
if batch.table_unk_mask is not None:
table = table.masked_scatter_(batch.table_unk_mask.unsqueeze(-1), batch.table_unk_embeddings[:, :self.embed_size])
if batch.column_unk_mask is not None:
column = column.masked_scatter_(batch.column_unk_mask.unsqueeze(-1), batch.column_unk_embeddings[:, :self.embed_size])
input_dict = {
"question": self.dropout_layer(question),
"table": self.dropout_layer(table),
"column": self.dropout_layer(column)
}
inputs = self.rnn_layer(input_dict, batch)
return inputs
class GraphInputLayerPLM(nn.Module):
def __init__(self, plm='bert-base-uncased', hidden_size=256, dropout=0., subword_aggregation='mean',
schema_aggregation='head+tail', lazy_load=False):
super(GraphInputLayerPLM, self).__init__()
self.plm_model = AutoModel.from_config(AutoConfig.from_pretrained(os.path.join('./pretrained_models', plm))) \
if lazy_load else AutoModel.from_pretrained(os.path.join('./pretrained_models', plm))
self.config = self.plm_model.config
self.subword_aggregation = SubwordAggregation(self.config.hidden_size, subword_aggregation=subword_aggregation)
self.dropout_layer = nn.Dropout(p=dropout)
self.rnn_layer = InputRNNLayer(self.config.hidden_size, hidden_size, cell='lstm', schema_aggregation=schema_aggregation)
def pad_embedding_grad_zero(self, index=None):
pass
def forward(self, batch):
outputs = self.plm_model(**batch.inputs)[0] # final layer hidden states
question, table, column = self.subword_aggregation(outputs, batch)
input_dict = {
"question": self.dropout_layer(question),
"table": self.dropout_layer(table),
"column": self.dropout_layer(column)
}
inputs = self.rnn_layer(input_dict, batch)
return inputs
class SubwordAggregation(nn.Module):
""" Map subword or wordpieces into one fixed size vector based on aggregation method
"""
def __init__(self, hidden_size, subword_aggregation='mean-pooling'):
super(SubwordAggregation, self).__init__()
self.hidden_size = hidden_size
self.aggregation = PoolingFunction(self.hidden_size, self.hidden_size, method=subword_aggregation)
def forward(self, inputs, batch):
""" Transform pretrained model outputs into our desired format
questions: bsize x max_question_len x hidden_size
tables: bsize x max_table_word_len x hidden_size
columns: bsize x max_column_word_len x hidden_size
"""
old_questions, old_tables, old_columns = inputs.masked_select(batch.question_mask_plm.unsqueeze(-1)), \
inputs.masked_select(batch.table_mask_plm.unsqueeze(-1)), inputs.masked_select(batch.column_mask_plm.unsqueeze(-1))
questions = old_questions.new_zeros(batch.question_subword_lens.size(0), batch.max_question_subword_len, self.hidden_size)
questions = questions.masked_scatter_(batch.question_subword_mask.unsqueeze(-1), old_questions)
tables = old_tables.new_zeros(batch.table_subword_lens.size(0), batch.max_table_subword_len, self.hidden_size)
tables = tables.masked_scatter_(batch.table_subword_mask.unsqueeze(-1), old_tables)
columns = old_columns.new_zeros(batch.column_subword_lens.size(0), batch.max_column_subword_len, self.hidden_size)
columns = columns.masked_scatter_(batch.column_subword_mask.unsqueeze(-1), old_columns)
questions = self.aggregation(questions, mask=batch.question_subword_mask)
tables = self.aggregation(tables, mask=batch.table_subword_mask)
columns = self.aggregation(columns, mask=batch.column_subword_mask)
new_questions, new_tables, new_columns = questions.new_zeros(len(batch), batch.max_question_len, self.hidden_size),\
tables.new_zeros(batch.table_word_mask.size(0), batch.max_table_word_len, self.hidden_size), \
columns.new_zeros(batch.column_word_mask.size(0), batch.max_column_word_len, self.hidden_size)
new_questions = new_questions.masked_scatter_(batch.question_mask.unsqueeze(-1), questions)
new_tables = new_tables.masked_scatter_(batch.table_word_mask.unsqueeze(-1), tables)
new_columns = new_columns.masked_scatter_(batch.column_word_mask.unsqueeze(-1), columns)
return new_questions, new_tables, new_columns
class InputRNNLayer(nn.Module):
def __init__(self, input_size, hidden_size, cell='lstm', schema_aggregation='head+tail', share_lstm=False):
super(InputRNNLayer, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.cell = cell.upper()
self.question_lstm = getattr(nn, self.cell)(self.input_size, self.hidden_size // 2, num_layers=1, bidirectional=True, batch_first=True)
self.schema_lstm = self.question_lstm if share_lstm else \
getattr(nn, self.cell)(self.input_size, self.hidden_size // 2, num_layers=1, bidirectional=True, batch_first=True)
self.schema_aggregation = schema_aggregation
if self.schema_aggregation != 'head+tail':
self.aggregation = PoolingFunction(self.hidden_size, self.hidden_size, method=schema_aggregation)
def forward(self, input_dict, batch):
"""
for the question sentence, run a bidirectional LSTM to capture contextual information and sequential dependencies;
for each schema phrase, extract a single representation, by default by concatenating its head and tail hidden vectors;
batch.question_lens, batch.table_word_lens and batch.column_word_lens are used
"""
questions, _ = rnn_wrapper(self.question_lstm, input_dict['question'], batch.question_lens, cell=self.cell)
questions = questions.contiguous().view(-1, self.hidden_size)[lens2mask(batch.question_lens).contiguous().view(-1)]
table_outputs, table_hiddens = rnn_wrapper(self.schema_lstm, input_dict['table'], batch.table_word_lens, cell=self.cell)
if self.schema_aggregation != 'head+tail':
tables = self.aggregation(table_outputs, mask=batch.table_word_mask)
else:
table_hiddens = table_hiddens[0].transpose(0, 1) if self.cell == 'LSTM' else table_hiddens.transpose(0, 1)
tables = table_hiddens.contiguous().view(-1, self.hidden_size)
column_outputs, column_hiddens = rnn_wrapper(self.schema_lstm, input_dict['column'], batch.column_word_lens, cell=self.cell)
if self.schema_aggregation != 'head+tail':
columns = self.aggregation(column_outputs, mask=batch.column_word_mask)
else:
column_hiddens = column_hiddens[0].transpose(0, 1) if self.cell == 'LSTM' else column_hiddens.transpose(0, 1)
columns = column_hiddens.contiguous().view(-1, self.hidden_size)
questions = questions.split(batch.question_lens.tolist(), dim=0)
tables = tables.split(batch.table_lens.tolist(), dim=0)
columns = columns.split(batch.column_lens.tolist(), dim=0)
# dgl graph node feats format: q11 q12 ... t11 t12 ... c11 c12 ... q21 q22 ...
outputs = [th for q_t_c in zip(questions, tables, columns) for th in q_t_c]
outputs = torch.cat(outputs, dim=0)
# transformer input format: bsize x max([q1 q2 ... t1 t2 ... c1 c2 ...]) x hidden_size
# outputs = []
# for q, t, c in zip(questions, tables, columns):
# zero_paddings = q.new_zeros((batch.max_len - q.size(0) - t.size(0) - c.size(0), q.size(1)))
# cur_outputs = torch.cat([q, t, c, zero_paddings], dim=0)
# outputs.append(cur_outputs)
# outputs = torch.stack(outputs, dim=0) # bsize x max_len x hidden_size
return outputs

View File

@ -0,0 +1,138 @@
#coding=utf8
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from model.model_utils import Registrable
from model.encoder.functions import scaled_exp, div_by_z, src_dot_dst
class ScoreFunction(nn.Module):
def __init__(self, hidden_size, mlp=1, method='biaffine'):
super(ScoreFunction, self).__init__()
assert method in ['dot', 'bilinear', 'affine', 'biaffine']
self.mlp = int(mlp)
self.hidden_size = hidden_size // self.mlp
if self.mlp > 1: # use mlp to perform dim reduction
self.mlp_q = nn.Sequential(nn.Linear(hidden_size, self.hidden_size), nn.Tanh())
self.mlp_s = nn.Sequential(nn.Linear(hidden_size, self.hidden_size), nn.Tanh())
self.method = method
if self.method == 'bilinear':
self.W = nn.Linear(self.hidden_size, self.hidden_size)
elif self.method == 'affine':
self.affine = nn.Linear(self.hidden_size * 2, 1)
elif self.method == 'biaffine':
self.W = nn.Linear(self.hidden_size, self.hidden_size)
self.affine = nn.Linear(self.hidden_size * 2, 1)
def forward(self, context, node):
"""
@args:
context(torch.FloatTensor): num_nodes x hidden_size
node(torch.FloatTensor): num_nodes x hidden_size
@return:
scores(torch.FloatTensor): num_nodes
"""
if self.mlp > 1:
context, node = self.mlp_q(context), self.mlp_s(node)
if self.method == 'dot':
scores = (context * node).sum(dim=-1)
elif self.method == 'bilinear':
scores = (context * self.W(node)).sum(dim=-1)
elif self.method == 'affine':
scores = self.affine(torch.cat([context, node], dim=-1)).squeeze(-1)
elif self.method == 'biaffine':
scores = (context * self.W(node)).sum(dim=-1)
scores += self.affine(torch.cat([context, node], dim=-1)).squeeze(-1)
else:
raise ValueError('[Error]: Unrecognized score function method %s!' % (self.method))
return scores
@Registrable.register('without_pruning')
class GraphOutputLayer(nn.Module):
def __init__(self, args):
super(GraphOutputLayer, self).__init__()
self.hidden_size = args.gnn_hidden_size
def forward(self, inputs, batch):
""" Re-scatter data format:
inputs: sum(q_len + t_len + c_len) x hidden_size
outputs: bsize x (max_q_len + max_t_len + max_c_len) x hidden_size
"""
outputs = inputs.new_zeros(len(batch), batch.mask.size(1), self.hidden_size)
outputs = outputs.masked_scatter_(batch.mask.unsqueeze(-1), inputs)
if self.training:
return outputs, batch.mask, torch.tensor(0., dtype=torch.float).to(outputs.device)
else:
return outputs, batch.mask
@Registrable.register('with_pruning')
class GraphOutputLayerWithPruning(nn.Module):
def __init__(self, args):
super(GraphOutputLayerWithPruning, self).__init__()
self.hidden_size = args.gnn_hidden_size
self.graph_pruning = GraphPruning(self.hidden_size, args.num_heads, args.dropout, args.score_function)
def forward(self, inputs, batch):
outputs = inputs.new_zeros(len(batch), batch.mask.size(1), self.hidden_size)
outputs = outputs.masked_scatter_(batch.mask.unsqueeze(-1), inputs)
if self.training:
g = batch.graph
question = inputs.masked_select(g.question_mask.unsqueeze(-1)).view(-1, self.hidden_size)
schema = inputs.masked_select(g.schema_mask.unsqueeze(-1)).view(-1, self.hidden_size)
loss = self.graph_pruning(question, schema, g.gp, g.node_label)
return outputs, batch.mask, loss
else:
return outputs, batch.mask
class GraphPruning(nn.Module):
def __init__(self, hidden_size, num_heads=8, feat_drop=0.2, score_function='affine'):
super(GraphPruning, self).__init__()
self.hidden_size = hidden_size
self.node_mha = DGLMHA(hidden_size, hidden_size, num_heads, feat_drop)
self.node_score_function = ScoreFunction(self.hidden_size, mlp=2, method=score_function)
self.loss_function = nn.BCEWithLogitsLoss(reduction='sum')
def forward(self, question, schema, graph, node_label):
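# each schema node attends over the question tokens; the aggregated context is scored against the node representation and supervised with the binary schema-linking labels in node_label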
node_context = self.node_mha(question, schema, graph)
node_score = self.node_score_function(node_context, schema)
loss = self.loss_function(node_score, node_label)
return loss
class DGLMHA(nn.Module):
""" Multi-head attention implemented with DGL lib
"""
def __init__(self, hidden_size, output_size, num_heads=8, feat_drop=0.2):
super(DGLMHA, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.num_heads = num_heads
self.d_k = self.hidden_size // self.num_heads
self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.output_size, self.hidden_size),\
nn.Linear(self.hidden_size, self.hidden_size, bias=False), nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.affine_o = nn.Linear(self.hidden_size, self.output_size)
self.feat_dropout = nn.Dropout(p=feat_drop)
def forward(self, context, node, g):
q, k, v = self.affine_q(self.feat_dropout(node)), self.affine_k(self.feat_dropout(context)), self.affine_v(self.feat_dropout(context))
with g.local_scope():
g.nodes['schema'].data['q'] = q.view(-1, self.num_heads, self.d_k)
g.nodes['question'].data['k'] = k.view(-1, self.num_heads, self.d_k)
g.nodes['question'].data['v'] = v.view(-1, self.num_heads, self.d_k)
out_x = self.propagate_attention(g)
return self.affine_o(out_x.view(-1, self.hidden_size))
def propagate_attention(self, g):
# Compute attention score
g.apply_edges(src_dot_dst('k', 'q', 'score'))
g.apply_edges(scaled_exp('score', math.sqrt(self.d_k)))
# Update node state
g.update_all(fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv'))
g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o'))
out_x = g.nodes['schema'].data['o']
return out_x

View File

@ -0,0 +1,98 @@
#coding=utf8
import copy, math
import torch, dgl
import torch.nn as nn
import dgl.function as fn
from model.model_utils import Registrable, FFN
from model.encoder.rgatsql import RGATLayer, MultiViewRGATLayer
from model.encoder.functions import *
@Registrable.register('lgesql')
class LGESQL(nn.Module):
""" Compared with RGAT, we utilize a line graph to explicitly model the propagation among edges:
1. aggregate info from in-edges
2. aggregate info from src nodes
"""
def __init__(self, args):
super(LGESQL, self).__init__()
self.num_layers, self.num_heads = args.gnn_num_layers, args.num_heads
self.relation_share_heads = args.relation_share_heads
self.graph_view = args.local_and_nonlocal
self.ndim = args.gnn_hidden_size # node feature dim
self.edim = self.ndim // self.num_heads if self.relation_share_heads else \
self.ndim // 2 if self.graph_view == 'mmc' else self.ndim
self.relation_num = args.relation_num
self.relation_embed = nn.Embedding(self.relation_num, self.edim)
self.gnn_layers = nn.ModuleList([
DualRGATLayer(self.ndim, self.edim, num_heads=args.num_heads, feat_drop=args.dropout, graph_view=self.graph_view)
for _ in range(self.num_layers)])
def forward(self, x, batch):
global_lgx = self.relation_embed(batch.graph.global_edges)
mask = batch.graph.local_mask
local_lgx = global_lgx[mask]
local_g, global_g, lg = batch.graph.local_g, batch.graph.global_g, batch.graph.lg
src_ids, dst_ids = batch.graph.src_ids, batch.graph.dst_ids
for i in range(self.num_layers):
x, local_lgx = self.gnn_layers[i](x, local_lgx, global_lgx, local_g, global_g, lg, src_ids, dst_ids)
if self.graph_view == 'msde':
# update local edge embeddings in the global edge embeddings matrix
global_lgx = global_lgx.masked_scatter_(mask.unsqueeze(-1), local_lgx)
return x
class DualRGATLayer(nn.Module):
def __init__(self, ndim, edim, num_heads=8, feat_drop=0.2, graph_view='mmc'):
super(DualRGATLayer, self).__init__()
self.ndim, self.edim = ndim, edim
self.num_heads, self.graph_view = num_heads, graph_view
NodeRGATLayer = MultiViewRGATLayer if self.graph_view == 'mmc' else RGATLayer
self.node_update = NodeRGATLayer(self.ndim, self.edim, self.num_heads, feat_drop=feat_drop)
self.edge_update = EdgeRGATLayer(self.edim, self.ndim, self.num_heads, feat_drop=feat_drop)
def forward(self, x, local_lgx, global_lgx, local_g, global_g, lg, src_ids, dst_ids):
if self.graph_view == 'mmc':
out_x, _ = self.node_update(x, local_lgx, global_lgx, local_g, global_g)
elif self.graph_view == 'msde':
out_x, _ = self.node_update(x, global_lgx, global_g)
else:
out_x, _ = self.node_update(x, local_lgx, local_g)
src_x = torch.index_select(x, dim=0, index=src_ids)
dst_x = torch.index_select(x, dim=0, index=dst_ids)
out_local_lgx, _ = self.edge_update(local_lgx, src_x, dst_x, lg)
return out_x, out_local_lgx
class EdgeRGATLayer(nn.Module):
def __init__(self, edim, ndim, num_heads=8, feat_drop=0.2):
super(EdgeRGATLayer, self).__init__()
self.edim, self.ndim = edim, ndim
self.num_heads = num_heads
self.d_k = self.ndim // self.num_heads
self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.edim, self.ndim), \
nn.Linear(self.edim, self.ndim, bias=False), nn.Linear(self.edim, self.ndim, bias=False)
self.affine_o = nn.Linear(self.ndim, self.edim)
self.layernorm = nn.LayerNorm(self.edim)
self.feat_dropout = nn.Dropout(p=feat_drop)
self.ffn = FFN(self.edim)
def forward(self, x, src_x, dst_x, g):
q, k, v = self.affine_q(self.feat_dropout(x)), self.affine_k(self.feat_dropout(x)), self.affine_v(self.feat_dropout(x))
with g.local_scope():
g.ndata['q'] = (q + src_x).view(-1, self.num_heads, self.d_k)
g.ndata['k'] = k.view(-1, self.num_heads, self.d_k)
g.ndata['v'] = (v + dst_x).view(-1, self.num_heads, self.d_k)
out_x = self.propagate_attention(g)
out_x = self.layernorm(x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k)))
out_x = self.ffn(out_x)
return out_x, (src_x, dst_x)
def propagate_attention(self, g):
# Compute attention score
g.apply_edges(src_dot_dst('k', 'q', 'score'))
g.apply_edges(scaled_exp('score', math.sqrt(self.d_k)))
# Update node state
g.update_all(fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv'))
g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o'))
out_x = g.ndata['o']
return out_x

View File

@ -0,0 +1,125 @@
#coding=utf8
import copy, math
import torch, dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from model.model_utils import Registrable, FFN
from model.encoder.functions import *
@Registrable.register('rgatsql')
class RGATSQL(nn.Module):
def __init__(self, args):
super(RGATSQL, self).__init__()
self.num_layers = args.gnn_num_layers
self.relation_num = args.relation_num
self.relation_share_layers, self.relation_share_heads = args.relation_share_layers, args.relation_share_heads
self.graph_view = 'multiview' if args.local_and_nonlocal in ['mmc', 'msde'] else args.local_and_nonlocal
edim = args.gnn_hidden_size // args.num_heads if self.relation_share_heads else \
args.gnn_hidden_size // 2 if self.graph_view == 'multiview' else args.gnn_hidden_size
if self.relation_share_layers:
self.relation_embed = nn.Embedding(args.relation_num, edim)
else:
self.relation_embed = nn.ModuleList([nn.Embedding(args.relation_num, edim) for _ in range(self.num_layers)])
gnn_layer = MultiViewRGATLayer if self.graph_view == 'multiview' else RGATLayer
self.gnn_layers = nn.ModuleList([gnn_layer(args.gnn_hidden_size, edim, num_heads=args.num_heads, feat_drop=args.dropout)
for _ in range(self.num_layers)])
def forward(self, x, batch):
if self.graph_view == 'multiview':
# multi-view multi-head concatenation
global_edges, mask = batch.graph.global_edges, batch.graph.local_mask
local_g, global_g = batch.graph.local_g, batch.graph.global_g
if self.relation_share_layers:
global_lgx = self.relation_embed(global_edges)
local_lgx = global_lgx[mask]
for i in range(self.num_layers):
global_lgx = self.relation_embed[i](global_edges) if not self.relation_share_layers else global_lgx
local_lgx = global_lgx[mask] if not self.relation_share_layers else local_lgx
x, (local_lgx, global_lgx) = self.gnn_layers[i](x, local_lgx, global_lgx, local_g, global_g)
else:
graph = batch.graph.local_g if self.graph_view == 'local' else batch.graph.global_g
edges, mask = batch.graph.global_edges, batch.graph.local_mask
if self.relation_share_layers:
lgx = self.relation_embed(edges)
lgx = lgx if self.graph_view == 'global' else lgx[mask]
for i in range(self.num_layers):
lgx = lgx if self.relation_share_layers else self.relation_embed[i](edges) if self.graph_view == 'global' else self.relation_embed[i](edges)[mask]
x, lgx = self.gnn_layers[i](x, lgx, graph)
return x
class RGATLayer(nn.Module):
def __init__(self, ndim, edim, num_heads=8, feat_drop=0.2):
super(RGATLayer, self).__init__()
self.ndim, self.edim = ndim, edim
self.num_heads = num_heads
dim = max([ndim, edim])
self.d_k = dim // self.num_heads
self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.ndim, dim),\
nn.Linear(self.ndim, dim, bias=False), nn.Linear(self.ndim, dim, bias=False)
self.affine_o = nn.Linear(dim, self.ndim)
self.layernorm = nn.LayerNorm(self.ndim)
self.feat_dropout = nn.Dropout(p=feat_drop)
self.ffn = FFN(self.ndim)
def forward(self, x, lgx, g):
""" @Params:
x: node feats, num_nodes x ndim
lgx: edge feats, num_edges x edim
g: dgl.graph
"""
# pre-mapping q/k/v affine
q, k, v = self.affine_q(self.feat_dropout(x)), self.affine_k(self.feat_dropout(x)), self.affine_v(self.feat_dropout(x))
e = lgx.view(-1, self.num_heads, self.d_k) if lgx.size(-1) == q.size(-1) else \
lgx.unsqueeze(1).expand(-1, self.num_heads, -1)
with g.local_scope():
g.ndata['q'], g.ndata['k'] = q.view(-1, self.num_heads, self.d_k), k.view(-1, self.num_heads, self.d_k)
g.ndata['v'] = v.view(-1, self.num_heads, self.d_k)
g.edata['e'] = e
out_x = self.propagate_attention(g)
out_x = self.layernorm(x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k)))
out_x = self.ffn(out_x)
return out_x, lgx
def propagate_attention(self, g):
# Compute attention score
g.apply_edges(src_sum_edge_mul_dst('k', 'q', 'e', 'score'))
g.apply_edges(scaled_exp('score', math.sqrt(self.d_k)))
# Update node state
g.update_all(src_sum_edge_mul_edge('v', 'e', 'score', 'v'), fn.sum('v', 'wv'))
g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o'))
out_x = g.ndata['o']
return out_x
class MultiViewRGATLayer(RGATLayer):
def forward(self, x, local_lgx, global_lgx, local_g, global_g):
""" @Params:
x: node feats, num_nodes x ndim
local_lgx: local edge feats, local_edge_num x edim
global_lgx: all edge feats, global_edge_num x edim
local_g: dgl.graph, a local graph for node update
global_g: dgl.graph, a complete graph for node update
"""
# pre-mapping q/k/v affine
q, k, v = self.affine_q(self.feat_dropout(x)), self.affine_k(self.feat_dropout(x)), self.affine_v(self.feat_dropout(x))
q, k, v = q.view(-1, self.num_heads, self.d_k), k.view(-1, self.num_heads, self.d_k), v.view(-1, self.num_heads, self.d_k)
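# multi-view: the first half of the attention heads operates on the local graph, the second half on the global (complete) graph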
with local_g.local_scope():
local_g.ndata['q'], local_g.ndata['k'] = q[:, :self.num_heads // 2], k[:, :self.num_heads // 2]
local_g.ndata['v'] = v[:, :self.num_heads // 2]
local_g.edata['e'] = local_lgx.view(-1, self.num_heads // 2, self.d_k) if local_lgx.size(-1) == self.d_k * self.num_heads // 2 else \
local_lgx.unsqueeze(1).expand(-1, self.num_heads // 2, -1)
out_x1 = self.propagate_attention(local_g)
with global_g.local_scope():
global_g.ndata['q'], global_g.ndata['k'] = q[:, self.num_heads // 2:], k[:, self.num_heads // 2:]
global_g.ndata['v'] = v[:, self.num_heads // 2:]
global_g.edata['e'] = global_lgx.view(-1, self.num_heads // 2, self.d_k) if global_lgx.size(-1) == self.d_k * self.num_heads // 2 else \
global_lgx.unsqueeze(1).expand(-1, self.num_heads // 2, -1)
out_x2 = self.propagate_attention(global_g)
out_x = torch.cat([out_x1, out_x2], dim=1)
out_x = self.layernorm(x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k)))
out_x = self.ffn(out_x)
return out_x, (local_lgx, global_lgx)

View File

@ -0,0 +1,41 @@
#coding=utf8
import torch
import torch.nn as nn
from model.model_utils import Registrable, PoolingFunction
from model.encoder.graph_encoder import *
from model.decoder.sql_parser import *
@Registrable.register('text2sql')
class Text2SQL(nn.Module):
def __init__(self, args, transition_system):
super(Text2SQL, self).__init__()
self.encoder = Registrable.by_name('encoder_text2sql')(args)
self.encoder2decoder = PoolingFunction(args.gnn_hidden_size, args.lstm_hidden_size, method='attentive-pooling')
self.decoder = Registrable.by_name('decoder_tranx')(args, transition_system)
def forward(self, batch):
""" This function is used during training, which returns the entire training loss
"""
encodings, mask, gp_loss = self.encoder(batch)
h0 = self.encoder2decoder(encodings, mask=mask)
loss = self.decoder.score(encodings, mask, h0, batch)
return loss, gp_loss
def parse(self, batch, beam_size):
""" This function is used for decoding, which returns a batch of [DecodeHypothesis()] * beam_size
"""
encodings, mask = self.encoder(batch)
h0 = self.encoder2decoder(encodings, mask=mask)
hyps = []
for i in range(len(batch)):
""" table_mappings and column_mappings are used to map original database ids to local ids,
while reverse_mappings perform the opposite function, mapping local ids to database ids
"""
hyps.append(self.decoder.parse(encodings[i:i+1], mask[i:i+1], h0[i:i+1], batch, beam_size))
return hyps
def pad_embedding_grad_zero(self, index=None):
""" For glove.42B.300d word vectors, gradients for <pad> symbol is always 0;
Most words (starting from index) in the word vocab are also fixed except most frequent words
"""
self.encoder.input_layer.pad_embedding_grad_zero(index)

214
proton/model/model_utils.py Normal file
View File

@ -0,0 +1,214 @@
#coding=utf8
import copy, math
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
def clones(module, N):
"Produce N identical layers."
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def lens2mask(lens):
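# convert a length vector (bsize,) into a boolean mask (bsize, max_len) that is True at valid positions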
bsize = lens.numel()
max_len = lens.max()
masks = torch.arange(0, max_len).type_as(lens).to(lens.device).repeat(bsize, 1).lt(lens.unsqueeze(1))
masks.requires_grad = False
return masks
def mask2matrix(mask):
col_mask, row_mask = mask.unsqueeze(-1), mask.unsqueeze(-2)
return col_mask & row_mask
def tile(x, count, dim=0):
"""
Tiles x on dimension dim count times.
E.g. [1, 2, 3], count=2 ==> [1, 1, 2, 2, 3, 3]
[[1, 2], [3, 4]], count=3, dim=1 ==> [[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]]
Different from torch.repeat
"""
if x is None:
return x
elif type(x) in [list, tuple]:
return type(x)([tile(each, count, dim) for each in x])
else:
perm = list(range(len(x.size())))
if dim != 0:
perm[0], perm[dim] = perm[dim], perm[0]
x = x.permute(perm).contiguous()
out_size = list(x.size())
out_size[0] *= count
batch = x.size(0)
x = x.contiguous().view(batch, -1) \
.transpose(0, 1) \
.repeat(count, 1) \
.transpose(0, 1) \
.contiguous() \
.view(*out_size)
if dim != 0:
x = x.permute(perm).contiguous()
return x
def rnn_wrapper(encoder, inputs, lens, cell='lstm'):
"""
@args:
encoder(nn.Module): rnn series bidirectional encoder, batch_first=True
inputs(torch.FloatTensor): rnn inputs, [bsize x max_seq_len x in_dim]
lens(torch.LongTensor): seq len for each sample, allow length=0, padding with 0-vector, [bsize]
@return:
out(torch.FloatTensor): output of encoder, bsize x max_seq_len x hidden_dim*2
hidden_states([tuple of ]torch.FloatTensor): final hidden states, num_layers*2 x bsize x hidden_dim
"""
# rerank according to lens and remove empty inputs
sorted_lens, sort_key = torch.sort(lens, descending=True)
nonzero_num, total_num = torch.sum(sorted_lens > 0).item(), sorted_lens.size(0)
sort_key = sort_key[:nonzero_num]
sorted_inputs = torch.index_select(inputs, dim=0, index=sort_key)
# forward non empty inputs
packed_inputs = rnn_utils.pack_padded_sequence(sorted_inputs, sorted_lens[:nonzero_num].tolist(), batch_first=True)
packed_out, sorted_h = encoder(packed_inputs) # bsize x srclen x dim
sorted_out, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)
if cell.upper() == 'LSTM':
sorted_h, sorted_c = sorted_h
# rerank according to sort_key
out_shape = list(sorted_out.size())
out_shape[0] = total_num
out = sorted_out.new_zeros(*out_shape).scatter_(0, sort_key.unsqueeze(-1).unsqueeze(-1).repeat(1, *out_shape[1:]), sorted_out)
h_shape = list(sorted_h.size())
h_shape[1] = total_num
h = sorted_h.new_zeros(*h_shape).scatter_(1, sort_key.unsqueeze(0).unsqueeze(-1).repeat(h_shape[0], 1, h_shape[-1]), sorted_h)
if cell.upper() == 'LSTM':
c = sorted_c.new_zeros(*h_shape).scatter_(1, sort_key.unsqueeze(0).unsqueeze(-1).repeat(h_shape[0], 1, h_shape[-1]), sorted_c)
return out, (h.contiguous(), c.contiguous())
return out, h.contiguous()
class MultiHeadAttention(nn.Module):
def __init__(self, hidden_size, q_size, kv_size, output_size, num_heads=8, bias=True, feat_drop=0.2, attn_drop=0.0):
super(MultiHeadAttention, self).__init__()
self.num_heads = int(num_heads)
self.hidden_size = hidden_size
assert self.hidden_size % self.num_heads == 0, 'Hidden size %d must be divisible by head num %d' % (hidden_size, num_heads)
self.d_k = self.hidden_size // self.num_heads
self.feat_drop = nn.Dropout(p=feat_drop)
self.attn_drop = nn.Dropout(p=attn_drop)
self.W_q = nn.Linear(q_size, self.hidden_size, bias=bias)
self.W_k = nn.Linear(kv_size, self.hidden_size, bias=False)
self.W_v = nn.Linear(kv_size, self.hidden_size, bias=False)
self.W_o = nn.Linear(self.hidden_size, output_size, bias=bias)
def forward(self, hiddens, query_hiddens, mask=None):
''' @params:
hiddens : encoded sequence representations, bsize x seqlen x hidden_size
query_hiddens : bsize [x tgtlen ]x hidden_size
mask : length mask for hiddens, ByteTensor, bsize x seqlen
@return:
context : bsize x[ tgtlen x] hidden_size
'''
remove_flag = False
if query_hiddens.dim() == 2:
query_hiddens, remove_flag = query_hiddens.unsqueeze(1), True
Q, K, V = self.W_q(self.feat_drop(query_hiddens)), self.W_k(self.feat_drop(hiddens)), self.W_v(self.feat_drop(hiddens))
Q, K, V = Q.reshape(-1, Q.size(1), 1, self.num_heads, self.d_k), K.reshape(-1, 1, K.size(1), self.num_heads, self.d_k), V.reshape(-1, 1, V.size(1), self.num_heads, self.d_k)
e = (Q * K).sum(-1) / math.sqrt(self.d_k) # bsize x tgtlen x seqlen x num_heads
if mask is not None:
e = e + ((1 - mask.float()) * (-1e20)).unsqueeze(1).unsqueeze(-1)
a = torch.softmax(e, dim=2)
concat = (a.unsqueeze(-1) * V).sum(dim=2).reshape(-1, query_hiddens.size(1), self.hidden_size)
context = self.W_o(concat)
if remove_flag:
return context.squeeze(dim=1), a.mean(dim=-1).squeeze(dim=1)
else:
return context, a.mean(dim=-1)
class PoolingFunction(nn.Module):
""" Map a sequence of hidden_size dim vectors into one fixed size vector with dimension output_size """
def __init__(self, hidden_size=256, output_size=256, bias=True, method='attentive-pooling'):
super(PoolingFunction, self).__init__()
assert method in ['mean-pooling', 'max-pooling', 'attentive-pooling']
self.method = method
if self.method == 'attentive-pooling':
self.attn = nn.Sequential(
nn.Linear(hidden_size, hidden_size, bias=bias),
nn.Tanh(),
nn.Linear(hidden_size, 1, bias=bias)
)
self.mapping_function = nn.Sequential(nn.Linear(hidden_size, output_size, bias=bias), nn.Tanh()) \
if hidden_size != output_size else lambda x: x
def forward(self, inputs, mask=None):
""" @args:
inputs(torch.FloatTensor): features, batch_size x seq_len x hidden_size
mask(torch.BoolTensor): mask for inputs, batch_size x seq_len
@return:
outputs(torch.FloatTensor): aggregate seq_len dim for inputs, batch_size x output_size
"""
if self.method == 'max-pooling':
outputs = inputs.masked_fill(~ mask.unsqueeze(-1), -1e8)
outputs = outputs.max(dim=1)[0]
elif self.method == 'mean-pooling':
mask_float = mask.float().unsqueeze(-1)
outputs = (inputs * mask_float).sum(dim=1) / mask_float.sum(dim=1)
elif self.method == 'attentive-pooling':
e = self.attn(inputs).squeeze(-1)
e = e + (1 - mask.float()) * (-1e20)
a = torch.softmax(e, dim=1).unsqueeze(1)
outputs = torch.bmm(a, inputs).squeeze(1)
else:
raise ValueError('[Error]: Unrecognized pooling method %s !' % (self.method))
outputs = self.mapping_function(outputs)
return outputs
class FFN(nn.Module):
def __init__(self, input_size):
super(FFN, self).__init__()
self.input_size = input_size
self.feedforward = nn.Sequential(
nn.Linear(self.input_size, self.input_size * 4),
nn.ReLU(inplace=True),
nn.Linear(self.input_size * 4, self.input_size)
)
self.layernorm = nn.LayerNorm(self.input_size)
def forward(self, inputs):
return self.layernorm(inputs + self.feedforward(inputs))
class Registrable(object):
"""
A class that collects all registered components,
adapted from `common.registrable.Registrable` from AllenNLP
"""
registered_components = dict()
@staticmethod
def register(name):
def register_class(cls):
if name in Registrable.registered_components:
raise RuntimeError('class %s already registered' % name)
Registrable.registered_components[name] = cls
return cls
return register_class
@staticmethod
def by_name(name):
return Registrable.registered_components[name]
class cached_property(object):
""" A property that is only computed once per instance and then replaces
itself with an ordinary attribute. Deleting the attribute resets the
property.
Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76
"""
def __init__(self, func):
self.__doc__ = getattr(func, '__doc__')
self.func = func
def __get__(self, obj, cls):
if obj is None:
return self
value = obj.__dict__[self.func.__name__] = self.func(obj)
return value

View File

@ -0,0 +1,15 @@
## Data Preprocess
#### Get Parsed SQL Output
The SQL parsing script is `process_sql.py` in the main directory. Please refer to `parsed_sql_examples.sql` for the explanation of some parsed SQL output examples.
If you would like to use `process_sql.py` to parse SQL queries by yourself, `parse_sql_one.py` provides an example of how the script is called. Or you can use `parse_raw_json.py` to update all parsed SQL results (value for `sql`) in `train.json` and `dev.json`.
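For a quick sanity check, the snippet below sketches how a single query could be parsed programmatically. It assumes `process_sql.py` exposes the usual Spider helpers (`get_schema`, `Schema`, `get_sql`) and uses an illustrative database path; `parse_sql_one.py` remains the authoritative reference for the exact invocation.
```python
from process_sql import get_schema, Schema, get_sql  # helper names assumed from Spider's process_sql.py

db_path = 'database/concert_singer/concert_singer.sqlite'  # illustrative path under the database/ dir
schema = Schema(get_schema(db_path))                       # table/column name mapping read from the sqlite file
parsed = get_sql(schema, 'SELECT count(*) FROM singer')    # nested dict, i.e. the value stored under `sql`
print(parsed['select'])
```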
#### Get Table Info from Database
To generate the final `tables.json` file, run the command below. It reads the sqlite files from the `database/` dir and a previous `tables.json` with hand-corrected names:
```
python process/get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]
```

View File

@ -0,0 +1,57 @@
#coding=utf8
import argparse, os, sys, pickle, json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from collections import Counter
def construct_vocab_from_dataset(*data_paths, table_path='data/tables.bin', mwf=4, reference_file=None, output_path=None, sep='\t'):
words = []
tables = pickle.load(open(table_path, 'rb'))
for fp in data_paths:
dataset = pickle.load(open(fp, 'rb'))
for ex in dataset:
words.extend(ex['processed_question_toks'])
db = tables[ex['db_id']]
words.extend(['table'] * len(db['table_names']))
words.extend(db['column_types'])
for c in db['processed_column_toks']:
words.extend(c)
for t in db['processed_table_toks']:
words.extend(t)
cnt = Counter(words)
vocab = sorted(list(cnt.items()), key=lambda x: - x[1])
glove_vocab = set()
with open(reference_file, 'r', encoding='utf-8') as inf:
for line in inf:
line = line.strip()
if line == '': continue
glove_vocab.add(line)
oov_words, oov_but_freq_words = set(), []
for w, c in vocab:
if w not in glove_vocab:
oov_words.add(w)
if c >= mwf:
oov_but_freq_words.append((w, c))
print('Out-of-GloVe vocabulary size: %d\nAmong them, %d words occur at least %d times in the training dataset.' % (len(oov_words), len(oov_but_freq_words), mwf))
with open(output_path, 'w') as of:
# first serialize OOV but frequent words, allowing them to be fine-tuned during training
for w, c in oov_but_freq_words:
of.write(w + sep + str(c) + '\n')
# next serialize words in both train vocab and glove vocab according to decreasing frequency
for w, c in vocab:
if w not in oov_words:
of.write(w + sep + str(c) + '\n')
return len(vocab)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_paths', nargs='+', type=str, help='input preprocessed dataset file')
parser.add_argument('--table_path', type=str, default='data/tables.bin', help='preprocessed table file')
parser.add_argument('--output_path', type=str, required=True, help='output word vocabulary path')
parser.add_argument('--reference_file', type=str, default='pretrained_models/glove-42b-300d/vocab_glove.txt',
help='discard words that are not in the GloVe vocabulary unless they occur at least mwf times')
parser.add_argument('--mwf', default=4, type=int,
help='minimum frequency for keeping words that appear in the training dataset but not in the GloVe vocabulary')
args = parser.parse_args()
construct_vocab_from_dataset(*args.data_paths, table_path=args.table_path, mwf=args.mwf, reference_file=args.reference_file, output_path=args.output_path)
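# Example invocation (the script path and file names below are illustrative):
#   python preprocess/construct_vocab.py --data_paths data/train.bin data/dev.bin \
#       --table_path data/tables.bin --output_path data/vocab_glove.txt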


@ -0,0 +1,527 @@
#coding=utf8
import os, sqlite3
import numpy as np
import stanza, torch
from nltk.corpus import stopwords
from itertools import product, combinations
import torch.nn.functional as F
from utils.constants import MAX_RELATIVE_DIST
from transformers import AutoModel, AutoConfig, AutoTokenizer
import geoopt as gt
def is_number(s):
try:
float(s)
return True
except ValueError:
return False
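# Mean-pool the sub-token representations along dim=1, keeping the dimension so the
# pooled vectors can later be concatenated per table/column name.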
def agg(input):
# if input.size(0)==1:
# return input.squeeze()
# else :
return torch.sum(input,dim=1,keepdim=True)/input.size(1)
def quote_normalization(question):
""" Normalize all usage of quotation marks into a separate \" """
new_question, quotation_marks = [], ["'", '"', '`', '‘', '’', '“', '”', '``', "''", '‘‘', '’’']
for idx, tok in enumerate(question):
if len(tok) > 2 and tok[0] in quotation_marks and tok[-1] in quotation_marks:
new_question += ["\"", tok[1:-1], "\""]
elif len(tok) > 2 and tok[0] in quotation_marks:
new_question += ["\"", tok[1:]]
elif len(tok) > 2 and tok[-1] in quotation_marks:
new_question += [tok[:-1], "\"" ]
elif tok in quotation_marks:
new_question.append("\"")
elif len(tok) == 2 and tok[0] in quotation_marks:
# special case: the length of entity value is 1
if idx + 1 < len(question) and question[idx + 1] in quotation_marks:
new_question += ["\"", tok[1]]
else:
new_question.append(tok)
else:
new_question.append(tok)
return new_question
class Preprocessor():
def __init__(self, db_dir='data/database', db_content=True):
super(Preprocessor, self).__init__()
self.db_dir = db_dir
self.db_content = db_content
self.nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')#, use_gpu=False)
self.stopwords = stopwords.words("english")
self.device = torch.device("cuda:0")
self.hidden_size=1024
self.max_batch_size = 8
self.plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(self.device)
self.plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
self.config = self.plm_model.config
self.ball = gt.Stereographic(-1)
self.threshold = 0.4
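# Probing setup: ELECTRA encodes the question together with all table/column names;
# each question word is masked in turn, and the Poincare-ball (hyperbolic) distance
# between the unmasked and masked schema-item representations serves as a semantic
# schema-linking score, thresholded after min-max normalization.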
def pipeline(self, entry: dict, db: dict, verbose: bool = False):
""" db should be preprocessed """
entry = self.preprocess_question(entry, db, verbose=verbose)
entry = self.schema_linking(entry, db, verbose=verbose)
entry = self.extract_subgraph(entry, db, verbose=verbose)
return entry
def preprocess_database(self, db: dict, verbose: bool = False):
""" Tokenize, lemmatize, lowercase table and column names for each database """
table_toks, table_names = [], []
for tab in db['table_names']:
doc = self.nlp(tab)
tab = [w.lemma.lower() for s in doc.sentences for w in s.words]
table_toks.append(tab)
table_names.append(" ".join(tab))
db['processed_table_toks'], db['processed_table_names'] = table_toks, table_names
column_toks, column_names = [], []
for _, c in db['column_names']:
doc = self.nlp(c)
c = [w.lemma.lower() for s in doc.sentences for w in s.words]
column_toks.append(c)
column_names.append(" ".join(c))
db['processed_column_toks'], db['processed_column_names'] = column_toks, column_names
column2table = list(map(lambda x: x[0], db['column_names'])) # from column id to table id
table2columns = [[] for _ in range(len(table_names))] # from table id to column ids list
for col_id, col in enumerate(db['column_names']):
if col_id == 0: continue
table2columns[col[0]].append(col_id)
db['column2table'], db['table2columns'] = column2table, table2columns
t_num, c_num, dtype = len(db['table_names']), len(db['column_names']), '<U100'
# relations in tables, tab_num * tab_num
tab_mat = np.array([['table-table-generic'] * t_num for _ in range(t_num)], dtype=dtype)
table_fks = set(map(lambda pair: (column2table[pair[0]], column2table[pair[1]]), db['foreign_keys']))
for (tab1, tab2) in table_fks:
if (tab2, tab1) in table_fks:
tab_mat[tab1, tab2], tab_mat[tab2, tab1] = 'table-table-fkb', 'table-table-fkb'
else:
tab_mat[tab1, tab2], tab_mat[tab2, tab1] = 'table-table-fk', 'table-table-fkr'
tab_mat[list(range(t_num)), list(range(t_num))] = 'table-table-identity'
# relations in columns, c_num * c_num
col_mat = np.array([['column-column-generic'] * c_num for _ in range(c_num)], dtype=dtype)
for i in range(t_num):
col_ids = [idx for idx, t in enumerate(column2table) if t == i]
col1, col2 = list(zip(*list(product(col_ids, col_ids))))
col_mat[col1, col2] = 'column-column-sametable'
col_mat[list(range(c_num)), list(range(c_num))] = 'column-column-identity'
if len(db['foreign_keys']) > 0:
col1, col2 = list(zip(*db['foreign_keys']))
col_mat[col1, col2], col_mat[col2, col1] = 'column-column-fk', 'column-column-fkr'
col_mat[0, list(range(c_num))] = '*-column-generic'
col_mat[list(range(c_num)), 0] = 'column-*-generic'
col_mat[0, 0] = '*-*-identity'
# relations between tables and columns, t_num*c_num and c_num*t_num
tab_col_mat = np.array([['table-column-generic'] * c_num for _ in range(t_num)], dtype=dtype)
col_tab_mat = np.array([['column-table-generic'] * t_num for _ in range(c_num)], dtype=dtype)
cols, tabs = list(zip(*list(map(lambda x: (x, column2table[x]), range(1, c_num))))) # ignore *
col_tab_mat[cols, tabs], tab_col_mat[tabs, cols] = 'column-table-has', 'table-column-has'
if len(db['primary_keys']) > 0:
cols, tabs = list(zip(*list(map(lambda x: (x, column2table[x]), db['primary_keys']))))
col_tab_mat[cols, tabs], tab_col_mat[tabs, cols] = 'column-table-pk', 'table-column-pk'
col_tab_mat[0, list(range(t_num))] = '*-table-generic'
tab_col_mat[list(range(t_num)), 0] = 'table-*-generic'
relations = np.concatenate([
np.concatenate([tab_mat, tab_col_mat], axis=1),
np.concatenate([col_tab_mat, col_mat], axis=1)
], axis=0)
db['relations'] = relations.tolist()
if verbose:
print('Tables:', ', '.join(db['table_names']))
print('Lemmatized:', ', '.join(table_names))
print('Columns:', ', '.join(list(map(lambda x: x[1], db['column_names']))))
print('Lemmatized:', ', '.join(column_names), '\n')
return db
def preprocess_question(self, entry: dict, db: dict, verbose: bool = False):
""" Tokenize, lemmatize, lowercase question"""
# stanza tokenize, lemmatize and POS tag
question = ' '.join(quote_normalization(entry['question_toks']))
doc = self.nlp(question)
raw_toks = [w.text.lower() for s in doc.sentences for w in s.words]
toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
pos_tags = [w.xpos for s in doc.sentences for w in s.words]
entry['raw_question_toks'] = raw_toks
entry['processed_question_toks'] = toks
entry['pos_tags'] = pos_tags
# relations in questions, q_num * q_num
q_num, dtype = len(toks), '<U100'
if q_num <= MAX_RELATIVE_DIST + 1:
dist_vec = ['question-question-dist' + str(i) if i != 0 else 'question-question-identity'
for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)]
starting = MAX_RELATIVE_DIST
else:
dist_vec = ['question-question-generic'] * (q_num - MAX_RELATIVE_DIST - 1) + \
['question-question-dist' + str(i) if i != 0 else 'question-question-identity' \
for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)] + \
['question-question-generic'] * (q_num - MAX_RELATIVE_DIST - 1)
starting = q_num - 1
q_mat = np.array([dist_vec[starting - i: starting - i + q_num] for i in range(q_num)], dtype=dtype)
entry['relations'] = q_mat.tolist()
if verbose:
print('Question:', entry['question'])
print('Tokenized:', ' '.join(entry['raw_question_toks']))
print('Lemmatized:', ' '.join(entry['processed_question_toks']))
print('Pos tags:', ' '.join(entry['pos_tags']), '\n')
return entry
def extract_subgraph(self, entry: dict, db: dict, verbose: bool = False):
sql = entry['sql']
used_schema = {'table': set(), 'column': set()}
used_schema = self.extract_subgraph_from_sql(sql, used_schema)
entry['used_tables'] = sorted(list(used_schema['table']))
entry['used_columns'] = sorted(list(used_schema['column']))
if verbose:
print('Used tables:', entry['used_tables'])
print('Used columns:', entry['used_columns'], '\n')
return entry
def extract_subgraph_from_sql(self, sql: dict, used_schema: dict):
select_items = sql['select'][1]
# select clause
for _, val_unit in select_items:
if val_unit[0] == 0:
col_unit = val_unit[1]
used_schema['column'].add(col_unit[1])
else:
col_unit1, col_unit2 = val_unit[1:]
used_schema['column'].add(col_unit1[1])
used_schema['column'].add(col_unit2[1])
# from clause conds
table_units = sql['from']['table_units']
for _, t in table_units:
if type(t) == dict:
used_schema = self.extract_subgraph_from_sql(t, used_schema)
else:
used_schema['table'].add(t)
# from, where and having conds
used_schema = self.extract_subgraph_from_conds(sql['from']['conds'], used_schema)
used_schema = self.extract_subgraph_from_conds(sql['where'], used_schema)
used_schema = self.extract_subgraph_from_conds(sql['having'], used_schema)
# groupBy and orderBy clause
groupBy = sql['groupBy']
for col_unit in groupBy:
used_schema['column'].add(col_unit[1])
orderBy = sql['orderBy']
if len(orderBy) > 0:
orderBy = orderBy[1]
for val_unit in orderBy:
if val_unit[0] == 0:
col_unit = val_unit[1]
used_schema['column'].add(col_unit[1])
else:
col_unit1, col_unit2 = val_unit[1:]
used_schema['column'].add(col_unit1[1])
used_schema['column'].add(col_unit2[1])
# union, intersect and except clause
if sql['intersect']:
used_schema = self.extract_subgraph_from_sql(sql['intersect'], used_schema)
if sql['union']:
used_schema = self.extract_subgraph_from_sql(sql['union'], used_schema)
if sql['except']:
used_schema = self.extract_subgraph_from_sql(sql['except'], used_schema)
return used_schema
def extract_subgraph_from_conds(self, conds: list, used_schema: dict):
if len(conds) == 0:
return used_schema
for cond in conds:
if cond in ['and', 'or']:
continue
val_unit, val1, val2 = cond[2:]
if val_unit[0] == 0:
col_unit = val_unit[1]
used_schema['column'].add(col_unit[1])
else:
col_unit1, col_unit2 = val_unit[1:]
used_schema['column'].add(col_unit1[1])
used_schema['column'].add(col_unit2[1])
if type(val1) == list:
used_schema['column'].add(val1[1])
elif type(val1) == dict:
used_schema = self.extract_subgraph_from_sql(val1, used_schema)
if type(val2) == list:
used_schema['column'].add(val2[1])
elif type(val2) == dict:
used_schema = self.extract_subgraph_from_sql(val2, used_schema)
return used_schema
def schema_linking(self, entry: dict, db: dict, verbose: bool = False):
""" Perform schema linking: both question and database need to be preprocessed """
raw_question_toks, question_toks = entry['raw_question_toks'], entry['processed_question_toks']
table_toks, column_toks = db['processed_table_toks'], db['processed_column_toks']
table_names, column_names = db['processed_table_names'], db['processed_column_names']
q_num, t_num, c_num, dtype = len(question_toks), len(table_toks), len(column_toks), '<U100'
assert len(column_names)==len(column_toks) and len(table_names) == len(table_toks) and len(raw_question_toks)==len(question_toks)
question_id = [self.plm_tokenizer.cls_token_id]
question = [q.lower() for q in question_toks]
question_subword_len = []
for w in question:
toks = self.plm_tokenizer.convert_tokens_to_ids(self.plm_tokenizer.tokenize(w))
question_id.extend(toks)
question_subword_len.append(len(toks))
question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
exact_question_token = len(question_id) - 1
question_id.append(self.plm_tokenizer.sep_token_id)
masked_question_id = [question_id]
start = 1
for i, sub_len in enumerate(question_subword_len):
tmp_question_id = question_id.copy()
for m in range(start, start + sub_len):
tmp_question_id[m] = self.plm_tokenizer.mask_token_id
masked_question_id.append(tmp_question_id)
start += sub_len
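# masked_question_id[0] is the unmasked sequence; entry i+1 replaces the sub-tokens of
# the i-th question word with [MASK].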
table = [t.lower().split() for t in table_names]
table_id, table_mask_plm, table_subword_len = [], [], []
table_word_len = []
for s in table:
l = 0
for w in s:
toks = self.plm_tokenizer.convert_tokens_to_ids(self.plm_tokenizer.tokenize(w))
table_id.extend(toks)
table_subword_len.append(len(toks))
l += len(toks)
table_word_len.append(l)
table_mask_plm = [1] * len(table_id)
column = [t.lower().split() for t in column_names]
column_id, column_mask_plm, column_subword_len = [], [], []
column_word_len = []
for s in column:
l = 0
for w in s:
toks = self.plm_tokenizer.convert_tokens_to_ids(self.plm_tokenizer.tokenize(w))
column_id.extend(toks)
column_subword_len.append(len(toks))
l += len(toks)
column_word_len.append(l)
column_mask_plm = [1] * len(column_id) + [0]
exact_column_token = len(column_id)
column_id.append(self.plm_tokenizer.sep_token_id)
question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm
input_id = []
segment_id = []
atten_mask = []
for i, msk_q_id in enumerate(masked_question_id):
input_id.append(msk_q_id + table_id + column_id)
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
atten_mask.append([1] * len(input_id[-1]))
start = 0
total_size = len(input_id)
store_arr = []
if total_size <= self.max_batch_size:
ii = torch.tensor(input_id, dtype=torch.long, device=self.device)
im = torch.tensor(atten_mask, dtype=torch.float, device=self.device)
si = torch.tensor(segment_id, dtype=torch.long, device=self.device)
outputs = self.plm_model(ii, im)[0].squeeze()
store_arr.append(outputs)
else:
while start < len(input_id):
if start + self.max_batch_size <= len(input_id):
ii = torch.tensor(input_id[start: start + self.max_batch_size], dtype=torch.long, device=self.device)
im = torch.tensor(atten_mask[start: start + self.max_batch_size], dtype=torch.float, device=self.device)
si = torch.tensor(segment_id[start: start + self.max_batch_size], dtype=torch.long, device=self.device)
outputs = self.plm_model(ii, im)[0] # .squeeze()
store_arr.append(outputs)
else:
ii = torch.tensor(input_id[start: len(input_id)], dtype=torch.long, device=self.device)
im = torch.tensor(atten_mask[start: len(input_id)], dtype=torch.float, device=self.device)
si = torch.tensor(segment_id[start: len(input_id)], dtype=torch.long, device=self.device)
outputs = self.plm_model(ii, im)[0] # .squeeze()
store_arr.append(outputs)
start += self.max_batch_size
assert len(store_arr) > 0
if len(store_arr) == 1:
outputs = store_arr[0]
else:
outputs = store_arr[0]
for t in store_arr[1:]:
outputs = torch.cat((outputs, t), dim=0)
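# outputs[0] comes from the unmasked input; outputs[1:] come from the inputs in which one
# question word is masked, so row i+1 reflects the effect of masking question word i.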
q_tab_mat = outputs.new_zeros(len(raw_question_toks), len(table_names))
old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=self.device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), self.hidden_size)
start = 0
new_table_arr = []
for i, sub_len in enumerate(table_word_len):
curr = old_tables[:, start:start + sub_len]
new_table_arr.append(agg(curr))
start += sub_len
new_tables = torch.cat(new_table_arr, 1)
tbl_cmp = new_tables[0:1]
tbl_msk = new_tables[1:]
assert tbl_msk.size(0) == len(raw_question_toks)
for i in range(len(table_word_len)):
a = self.ball.expmap0(tbl_cmp[:, i])
b = self.ball.expmap0(tbl_msk[:, i])
dis=self.ball.dist(a,b)
q_tab_mat[:, i] = dis
q_col_mat = outputs.new_zeros(len(raw_question_toks), len(column_names))
old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=self.device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),exact_column_token, self.hidden_size)
new_column_arr = []
start = 0
for i, sub_len in enumerate(column_word_len):
curr = old_columns[:, start:start + sub_len]
new_column_arr.append(agg(curr))
start += sub_len
new_column = torch.cat(new_column_arr, 1)
col_cmp = new_column[0:1]
col_msk = new_column[1:]
assert col_msk.size(0) == len(raw_question_toks)
for i in range(len(column_word_len)):
a = self.ball.expmap0(col_cmp[:, i])
b = self.ball.expmap0(col_msk[:, i])
dis=self.ball.dist(a,b)
q_col_mat[:, i] = dis
use_matrix = torch.cat([q_tab_mat,q_col_mat], dim=1)
matrix_min=torch.min(use_matrix)
matrix_max=torch.max(use_matrix)
use_matrix=(use_matrix-matrix_min)/(matrix_max-matrix_min)
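# Min-max normalize the distances jointly over tables and columns so that a single
# threshold (self.threshold) can be applied to both when deciding semantic matches below.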
use_q_tab_mat = use_matrix[:, :q_tab_mat.size(1)]
use_q_col_mat = use_matrix[:, q_tab_mat.size(1):]
assert use_q_tab_mat.size(1) == t_num and use_q_col_mat.size(1)== c_num
use_tab_q_mat = use_q_tab_mat.transpose(0,1).cpu().detach().numpy()
use_col_q_mat = use_q_col_mat.transpose(0,1).cpu().detach().numpy()
table_matched_pairs = {'partial': [], 'exact': []}
q_tab_mat = np.array([['question-table-nomatch'] * t_num for _ in range(q_num)], dtype=dtype)
tab_q_mat = np.array([['table-question-nomatch'] * q_num for _ in range(t_num)], dtype=dtype)
max_len = max([len(t) for t in table_toks])
index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
for i, j in index_pairs:
phrase = ' '.join(question_toks[i: j])
if phrase in self.stopwords: continue
for idx, name in enumerate(table_names):
if phrase == name: # fully match will overwrite partial match due to sort
q_tab_mat[range(i, j), idx] = 'question-table-exactmatch'
tab_q_mat[idx, range(i, j)] = 'table-question-exactmatch'
if verbose:
table_matched_pairs['exact'].append(str((name, idx, phrase, i, j)))
elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
q_tab_mat[range(i, j), idx] = 'question-table-partialmatch'
tab_q_mat[idx, range(i, j)] = 'table-question-partialmatch'
if verbose:
table_matched_pairs['partial'].append(str((name, idx, phrase, i, j)))
assert use_tab_q_mat.shape[0]==t_num and use_tab_q_mat.shape[1]==q_num
for x in range(t_num):
for y in range(q_num):
if question_toks[y] in self.stopwords or question_toks[y] in '."?,':
continue
if use_tab_q_mat[x,y]>self.threshold:
if 'partialmatch' in tab_q_mat[x,y]:
tab_q_mat[x,y] = 'table-question-partialsemanticmatch'
q_tab_mat[y,x] = 'question-table-partialsemanticmatch'
elif 'exact' in tab_q_mat[x,y]:
continue
elif 'nomatch' in tab_q_mat[x,y]:
tab_q_mat[x,y] = 'table-question-semanticmatch'
q_tab_mat[y,x] = 'question-table-semanticmatch'
# relations between questions and columns
column_matched_pairs = {'partial': [], 'exact': [], 'value': []}
q_col_mat = np.array([['question-column-nomatch'] * c_num for _ in range(q_num)], dtype=dtype)
col_q_mat = np.array([['column-question-nomatch'] * q_num for _ in range(c_num)], dtype=dtype)
max_len = max([len(c) for c in column_toks])
index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
for i, j in index_pairs:
phrase = ' '.join(question_toks[i: j])
if phrase in self.stopwords: continue
for idx, name in enumerate(column_names):
if phrase == name: # fully match will overwrite partial match due to sort
q_col_mat[range(i, j), idx] = 'question-column-exactmatch'
col_q_mat[idx, range(i, j)] = 'column-question-exactmatch'
if verbose:
column_matched_pairs['exact'].append(str((name, idx, phrase, i, j)))
elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
q_col_mat[range(i, j), idx] = 'question-column-partialmatch'
col_q_mat[idx, range(i, j)] = 'column-question-partialmatch'
if verbose:
column_matched_pairs['partial'].append(str((name, idx, phrase, i, j)))
if self.db_content:
db_file = os.path.join(self.db_dir, db['db_id'], db['db_id'] + '.sqlite')
if not os.path.exists(db_file):
raise ValueError('[ERROR]: database file %s not found ...' % (db_file))
conn = sqlite3.connect(db_file)
conn.text_factory = lambda b: b.decode(errors='ignore')
conn.execute('pragma foreign_keys=ON')
for i, (tab_id, col_name) in enumerate(db['column_names_original']):
if i == 0 or 'id' in column_toks[i]: # ignore * and special token 'id'
continue
tab_name = db['table_names_original'][tab_id]
try:
cursor = conn.execute("SELECT DISTINCT \"%s\" FROM \"%s\";" % (col_name, tab_name))
cell_values = cursor.fetchall()
cell_values = [str(each[0]) for each in cell_values]
cell_values = [[str(float(each))] if is_number(each) else each.lower().split() for each in cell_values]
except Exception as e:
print(e)
cell_values = []
for j, word in enumerate(raw_question_toks):
word = str(float(word)) if is_number(word) else word
for c in cell_values:
if word in c and 'nomatch' in q_col_mat[j, i] and word not in self.stopwords:
q_col_mat[j, i] = 'question-column-valuematch'
col_q_mat[i, j] = 'column-question-valuematch'
if verbose:
column_matched_pairs['value'].append(str((column_names[i], i, word, j, j + 1)))
break
conn.close()
assert use_col_q_mat.shape[0]==c_num and use_col_q_mat.shape[1]==q_num
for x in range(c_num):
for y in range(q_num):
if question_toks[y] in self.stopwords or question_toks[y] in '."?,':
continue
if use_col_q_mat[x,y]>self.threshold:
if 'partialmatch' in col_q_mat[x,y]:
col_q_mat[x,y] = 'column-question-partialsemanticmatch'
q_col_mat[y,x] = 'question-column-partialsemanticmatch'
elif 'exact' in col_q_mat[x,y] or 'value' in col_q_mat[x,y]:
continue
elif 'nomatch' in col_q_mat[x,y]:
col_q_mat[x,y] = 'column-question-semanticmatch'
q_col_mat[y,x] = 'question-column-semanticmatch'
# two symmetric schema linking matrix: q_num x (t_num + c_num), (t_num + c_num) x q_num
q_col_mat[:, 0] = 'question-*-generic'
col_q_mat[0] = '*-question-generic'
q_schema = np.concatenate([q_tab_mat, q_col_mat], axis=1)
schema_q = np.concatenate([tab_q_mat, col_q_mat], axis=0)
entry['schema_linking'] = (q_schema.tolist(), schema_q.tolist())
if verbose:
print('Question:', ' '.join(question_toks))
print('Table matched: (table name, column id, question span, start id, end id)')
print('Exact match:', ', '.join(table_matched_pairs['exact']) if table_matched_pairs['exact'] else 'empty')
print('Partial match:', ', '.join(table_matched_pairs['partial']) if table_matched_pairs['partial'] else 'empty')
print('Column matched: (column name, column id, question span, start id, end id)')
print('Exact match:', ', '.join(column_matched_pairs['exact']) if column_matched_pairs['exact'] else 'empty')
print('Partial match:', ', '.join(column_matched_pairs['partial']) if column_matched_pairs['partial'] else 'empty')
print('Value match:', ', '.join(column_matched_pairs['value']) if column_matched_pairs['value'] else 'empty', '\n')
return entry

779
proton/preprocess/cu.py Normal file

@ -0,0 +1,779 @@
#coding=utf8
import os, sqlite3
import numpy as np
import stanza, torch
from nltk.corpus import stopwords
from itertools import product, combinations
import torch.nn.functional as F
from utils.constants import MAX_RELATIVE_DIST
from transformers import AutoModel, AutoConfig,AutoTokenizer
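# Standalone debugging script for the masked-probing idea: it rebuilds the masked inputs
# by hand and measures plain Euclidean (pairwise) distances between unmasked and masked
# representations, instead of the hyperbolic distance used in the main preprocessor.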
def is_number(s):
try:
float(s)
return True
except ValueError:
return False
def agg(input):
# if input.size(0)==1:
# return input.squeeze()
# else :
return torch.sum(input,dim=1)/input.size(1)
def quote_normalization(question):
""" Normalize all usage of quotation marks into a separate \" """
new_question, quotation_marks = [], ["'", '"', '`', '‘', '’', '“', '”', '``', "''", '‘‘', '’’']
for idx, tok in enumerate(question):
if len(tok) > 2 and tok[0] in quotation_marks and tok[-1] in quotation_marks:
new_question += ["\"", tok[1:-1], "\""]
elif len(tok) > 2 and tok[0] in quotation_marks:
new_question += ["\"", tok[1:]]
elif len(tok) > 2 and tok[-1] in quotation_marks:
new_question += [tok[:-1], "\"" ]
elif tok in quotation_marks:
new_question.append("\"")
elif len(tok) == 2 and tok[0] in quotation_marks:
# special case: the length of entity value is 1
if idx + 1 < len(question) and question[idx + 1] in quotation_marks:
new_question += ["\"", tok[1]]
else:
new_question.append(tok)
else:
new_question.append(tok)
return new_question
class Preprocessor():
def __init__(self, db_dir='data/database', db_content=True):
super(Preprocessor, self).__init__()
self.db_dir = db_dir
self.db_content = db_content
self.nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')#, use_gpu=False)
self.stopwords = stopwords.words("english")
self.device = torch.device("cuda:0")
self.hidden_size=1024
self.max_batch_size = 8
self.plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(self.device)
self.plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
self.config = self.plm_model.config
# raw_question_toks=['What', 'is', 'the', 'model', 'for', 'the', 'car', 'with', 'a', 'weight', 'smaller', 'than', 'the', 'average', '?']
# question_id = [self.plm_tokenizer.cls_token_id]
# question = [q.lower() for q in raw_question_toks]
# question_subword_len = []
# for w in question:
# toks = self.plm_tokenizer.convert_tokens_to_ids(self.plm_tokenizer.tokenize(w))
# question_id.extend(toks)
# question_subword_len.append(len(toks))
# print(question_id)
#self.subword_aggregation = SubwordAggregation(self.config.hidden_size, subword_aggregation=subword_aggregation)
def pipeline(self, entry: dict, db: dict, verbose: bool = False):
""" db should be preprocessed """
entry = self.preprocess_question(entry, db, verbose=verbose)
entry = self.schema_linking(entry, db, verbose=verbose)
entry = self.extract_subgraph(entry, db, verbose=verbose)
return entry
def preprocess_database(self, db: dict, verbose: bool = False):
""" Tokenize, lemmatize, lowercase table and column names for each database """
table_toks, table_names = [], []
for tab in db['table_names']:
doc = self.nlp(tab)
tab = [w.lemma.lower() for s in doc.sentences for w in s.words]
table_toks.append(tab)
table_names.append(" ".join(tab))
db['processed_table_toks'], db['processed_table_names'] = table_toks, table_names
column_toks, column_names = [], []
for _, c in db['column_names']:
doc = self.nlp(c)
c = [w.lemma.lower() for s in doc.sentences for w in s.words]
column_toks.append(c)
column_names.append(" ".join(c))
db['processed_column_toks'], db['processed_column_names'] = column_toks, column_names
column2table = list(map(lambda x: x[0], db['column_names'])) # from column id to table id
table2columns = [[] for _ in range(len(table_names))] # from table id to column ids list
for col_id, col in enumerate(db['column_names']):
if col_id == 0: continue
table2columns[col[0]].append(col_id)
db['column2table'], db['table2columns'] = column2table, table2columns
t_num, c_num, dtype = len(db['table_names']), len(db['column_names']), '<U100'
# relations in tables, tab_num * tab_num
tab_mat = np.array([['table-table-generic'] * t_num for _ in range(t_num)], dtype=dtype)
table_fks = set(map(lambda pair: (column2table[pair[0]], column2table[pair[1]]), db['foreign_keys']))
for (tab1, tab2) in table_fks:
if (tab2, tab1) in table_fks:
tab_mat[tab1, tab2], tab_mat[tab2, tab1] = 'table-table-fkb', 'table-table-fkb'
else:
tab_mat[tab1, tab2], tab_mat[tab2, tab1] = 'table-table-fk', 'table-table-fkr'
tab_mat[list(range(t_num)), list(range(t_num))] = 'table-table-identity'
# relations in columns, c_num * c_num
col_mat = np.array([['column-column-generic'] * c_num for _ in range(c_num)], dtype=dtype)
for i in range(t_num):
col_ids = [idx for idx, t in enumerate(column2table) if t == i]
col1, col2 = list(zip(*list(product(col_ids, col_ids))))
col_mat[col1, col2] = 'column-column-sametable'
col_mat[list(range(c_num)), list(range(c_num))] = 'column-column-identity'
if len(db['foreign_keys']) > 0:
col1, col2 = list(zip(*db['foreign_keys']))
col_mat[col1, col2], col_mat[col2, col1] = 'column-column-fk', 'column-column-fkr'
col_mat[0, list(range(c_num))] = '*-column-generic'
col_mat[list(range(c_num)), 0] = 'column-*-generic'
col_mat[0, 0] = '*-*-identity'
# relations between tables and columns, t_num*c_num and c_num*t_num
tab_col_mat = np.array([['table-column-generic'] * c_num for _ in range(t_num)], dtype=dtype)
col_tab_mat = np.array([['column-table-generic'] * t_num for _ in range(c_num)], dtype=dtype)
cols, tabs = list(zip(*list(map(lambda x: (x, column2table[x]), range(1, c_num))))) # ignore *
col_tab_mat[cols, tabs], tab_col_mat[tabs, cols] = 'column-table-has', 'table-column-has'
if len(db['primary_keys']) > 0:
cols, tabs = list(zip(*list(map(lambda x: (x, column2table[x]), db['primary_keys']))))
col_tab_mat[cols, tabs], tab_col_mat[tabs, cols] = 'column-table-pk', 'table-column-pk'
col_tab_mat[0, list(range(t_num))] = '*-table-generic'
tab_col_mat[list(range(t_num)), 0] = 'table-*-generic'
relations = np.concatenate([
np.concatenate([tab_mat, tab_col_mat], axis=1),
np.concatenate([col_tab_mat, col_mat], axis=1)
], axis=0)
db['relations'] = relations.tolist()
if verbose:
print('Tables:', ', '.join(db['table_names']))
print('Lemmatized:', ', '.join(table_names))
print('Columns:', ', '.join(list(map(lambda x: x[1], db['column_names']))))
print('Lemmatized:', ', '.join(column_names), '\n')
return db
def preprocess_question(self, entry: dict, db: dict, verbose: bool = False):
""" Tokenize, lemmatize, lowercase question"""
# stanza tokenize, lemmatize and POS tag
question = ' '.join(quote_normalization(entry['question_toks']))
doc = self.nlp(question)
raw_toks = [w.text.lower() for s in doc.sentences for w in s.words]
toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
pos_tags = [w.xpos for s in doc.sentences for w in s.words]
entry['raw_question_toks'] = raw_toks
entry['processed_question_toks'] = toks
entry['pos_tags'] = pos_tags
# relations in questions, q_num * q_num
q_num, dtype = len(toks), '<U100'
if q_num <= MAX_RELATIVE_DIST + 1:
dist_vec = ['question-question-dist' + str(i) if i != 0 else 'question-question-identity'
for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)]
starting = MAX_RELATIVE_DIST
else:
dist_vec = ['question-question-generic'] * (q_num - MAX_RELATIVE_DIST - 1) + \
['question-question-dist' + str(i) if i != 0 else 'question-question-identity' \
for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)] + \
['question-question-generic'] * (q_num - MAX_RELATIVE_DIST - 1)
starting = q_num - 1
q_mat = np.array([dist_vec[starting - i: starting - i + q_num] for i in range(q_num)], dtype=dtype)
entry['relations'] = q_mat.tolist()
if verbose:
print('Question:', entry['question'])
print('Tokenized:', ' '.join(entry['raw_question_toks']))
print('Lemmatized:', ' '.join(entry['processed_question_toks']))
print('Pos tags:', ' '.join(entry['pos_tags']), '\n')
return entry
def extract_subgraph(self, entry: dict, db: dict, verbose: bool = False):
sql = entry['sql']
used_schema = {'table': set(), 'column': set()}
used_schema = self.extract_subgraph_from_sql(sql, used_schema)
entry['used_tables'] = sorted(list(used_schema['table']))
entry['used_columns'] = sorted(list(used_schema['column']))
if verbose:
print('Used tables:', entry['used_tables'])
print('Used columns:', entry['used_columns'], '\n')
return entry
def extract_subgraph_from_sql(self, sql: dict, used_schema: dict):
select_items = sql['select'][1]
# select clause
for _, val_unit in select_items:
if val_unit[0] == 0:
col_unit = val_unit[1]
used_schema['column'].add(col_unit[1])
else:
col_unit1, col_unit2 = val_unit[1:]
used_schema['column'].add(col_unit1[1])
used_schema['column'].add(col_unit2[1])
# from clause conds
table_units = sql['from']['table_units']
for _, t in table_units:
if type(t) == dict:
used_schema = self.extract_subgraph_from_sql(t, used_schema)
else:
used_schema['table'].add(t)
# from, where and having conds
used_schema = self.extract_subgraph_from_conds(sql['from']['conds'], used_schema)
used_schema = self.extract_subgraph_from_conds(sql['where'], used_schema)
used_schema = self.extract_subgraph_from_conds(sql['having'], used_schema)
# groupBy and orderBy clause
groupBy = sql['groupBy']
for col_unit in groupBy:
used_schema['column'].add(col_unit[1])
orderBy = sql['orderBy']
if len(orderBy) > 0:
orderBy = orderBy[1]
for val_unit in orderBy:
if val_unit[0] == 0:
col_unit = val_unit[1]
used_schema['column'].add(col_unit[1])
else:
col_unit1, col_unit2 = val_unit[1:]
used_schema['column'].add(col_unit1[1])
used_schema['column'].add(col_unit2[1])
# union, intersect and except clause
if sql['intersect']:
used_schema = self.extract_subgraph_from_sql(sql['intersect'], used_schema)
if sql['union']:
used_schema = self.extract_subgraph_from_sql(sql['union'], used_schema)
if sql['except']:
used_schema = self.extract_subgraph_from_sql(sql['except'], used_schema)
return used_schema
def extract_subgraph_from_conds(self, conds: list, used_schema: dict):
if len(conds) == 0:
return used_schema
for cond in conds:
if cond in ['and', 'or']:
continue
val_unit, val1, val2 = cond[2:]
if val_unit[0] == 0:
col_unit = val_unit[1]
used_schema['column'].add(col_unit[1])
else:
col_unit1, col_unit2 = val_unit[1:]
used_schema['column'].add(col_unit1[1])
used_schema['column'].add(col_unit2[1])
if type(val1) == list:
used_schema['column'].add(val1[1])
elif type(val1) == dict:
used_schema = self.extract_subgraph_from_sql(val1, used_schema)
if type(val2) == list:
used_schema['column'].add(val2[1])
elif type(val2) == dict:
used_schema = self.extract_subgraph_from_sql(val2, used_schema)
return used_schema
def schema_linking(self, entry: dict, db: dict, verbose: bool = False):
""" Perform schema linking: both question and database need to be preprocessed """
raw_question_toks, question_toks = entry['raw_question_toks'], entry['processed_question_toks']
table_toks, column_toks = db['processed_table_toks'], db['processed_column_toks']
table_names, column_names = db['processed_table_names'], db['processed_column_names']
q_num, t_num, c_num, dtype = len(question_toks), len(table_toks), len(column_toks), '<U100'
print(raw_question_toks)
print(column_names)
print(table_names)
assert len(column_names)==len(column_toks) and len(table_names) == len(table_toks) and len(raw_question_toks)==len(question_toks)
# relations between questions and tables, q_num*t_num and t_num*q_num
table_matched_pairs = {'partial': [], 'exact': []}
q_tab_mat = np.array([['question-table-nomatch'] * t_num for _ in range(q_num)], dtype=dtype)
tab_q_mat = np.array([['table-question-nomatch'] * q_num for _ in range(t_num)], dtype=dtype)
max_len = max([len(t) for t in table_toks])
index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
for i, j in index_pairs:
phrase = ' '.join(question_toks[i: j])
if phrase in self.stopwords: continue
for idx, name in enumerate(table_names):
if phrase == name: # fully match will overwrite partial match due to sort
q_tab_mat[range(i, j), idx] = 'question-table-exactmatch'
tab_q_mat[idx, range(i, j)] = 'table-question-exactmatch'
if verbose:
table_matched_pairs['exact'].append(str((name, idx, phrase, i, j)))
elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
q_tab_mat[range(i, j), idx] = 'question-table-partialmatch'
tab_q_mat[idx, range(i, j)] = 'table-question-partialmatch'
if verbose:
table_matched_pairs['partial'].append(str((name, idx, phrase, i, j)))
# relations between questions and columns
column_matched_pairs = {'partial': [], 'exact': [], 'value': []}
q_col_mat = np.array([['question-column-nomatch'] * c_num for _ in range(q_num)], dtype=dtype)
col_q_mat = np.array([['column-question-nomatch'] * q_num for _ in range(c_num)], dtype=dtype)
max_len = max([len(c) for c in column_toks])
index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
for i, j in index_pairs:
phrase = ' '.join(question_toks[i: j])
if phrase in self.stopwords: continue
for idx, name in enumerate(column_names):
if phrase == name: # fully match will overwrite partial match due to sort
q_col_mat[range(i, j), idx] = 'question-column-exactmatch'
col_q_mat[idx, range(i, j)] = 'column-question-exactmatch'
if verbose:
column_matched_pairs['exact'].append(str((name, idx, phrase, i, j)))
elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
q_col_mat[range(i, j), idx] = 'question-column-partialmatch'
col_q_mat[idx, range(i, j)] = 'column-question-partialmatch'
if verbose:
column_matched_pairs['partial'].append(str((name, idx, phrase, i, j)))
if self.db_content:
db_file = os.path.join(self.db_dir, db['db_id'], db['db_id'] + '.sqlite')
if not os.path.exists(db_file):
raise ValueError('[ERROR]: database file %s not found ...' % (db_file))
conn = sqlite3.connect(db_file)
conn.text_factory = lambda b: b.decode(errors='ignore')
conn.execute('pragma foreign_keys=ON')
for i, (tab_id, col_name) in enumerate(db['column_names_original']):
if i == 0 or 'id' in column_toks[i]: # ignore * and special token 'id'
continue
tab_name = db['table_names_original'][tab_id]
try:
cursor = conn.execute("SELECT DISTINCT \"%s\" FROM \"%s\";" % (col_name, tab_name))
cell_values = cursor.fetchall()
cell_values = [str(each[0]) for each in cell_values]
cell_values = [[str(float(each))] if is_number(each) else each.lower().split() for each in cell_values]
except Exception as e:
print(e)
cell_values = []
for j, word in enumerate(raw_question_toks):
word = str(float(word)) if is_number(word) else word
for c in cell_values:
if word in c and 'nomatch' in q_col_mat[j, i] and word not in self.stopwords:
q_col_mat[j, i] = 'question-column-valuematch'
col_q_mat[i, j] = 'column-question-valuematch'
if verbose:
column_matched_pairs['value'].append(str((column_names[i], i, word, j, j + 1)))
break
conn.close()
# two symmetric schema linking matrix: q_num x (t_num + c_num), (t_num + c_num) x q_num
q_col_mat[:, 0] = 'question-*-generic'
col_q_mat[0] = '*-question-generic'
q_schema = np.concatenate([q_tab_mat, q_col_mat], axis=1)
schema_q = np.concatenate([tab_q_mat, col_q_mat], axis=0)
entry['schema_linking'] = (q_schema.tolist(), schema_q.tolist())
if verbose:
print('Question:', ' '.join(question_toks))
print('Table matched: (table name, column id, question span, start id, end id)')
print('Exact match:', ', '.join(table_matched_pairs['exact']) if table_matched_pairs['exact'] else 'empty')
print('Partial match:', ', '.join(table_matched_pairs['partial']) if table_matched_pairs['partial'] else 'empty')
print('Column matched: (column name, column id, question span, start id, end id)')
print('Exact match:', ', '.join(column_matched_pairs['exact']) if column_matched_pairs['exact'] else 'empty')
print('Partial match:', ', '.join(column_matched_pairs['partial']) if column_matched_pairs['partial'] else 'empty')
print('Value match:', ', '.join(column_matched_pairs['value']) if column_matched_pairs['value'] else 'empty', '\n')
return entry
if __name__=='__main__':
device = torch.device("cuda:0")
hidden_size=1024
max_batch_size = 8
plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(device)
plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
config = plm_model.config
raw_question_toks=['list', 'the', 'service', 'id', 'and', 'details', 'for', 'the', 'events', '.']
column_names=['*', 'service id', 'service type code', 'participant id', 'participant type code', 'participant detail', 'event id', 'service id', 'event detail', 'event id', 'participant id']
table_names=['service', 'participant', 'event', 'participant in event']
print('Q', len(raw_question_toks), raw_question_toks)
print('C', len(column_names), column_names)
print('T', len(table_names), table_names)
question_id = [plm_tokenizer.cls_token_id]
question = [q.lower() for q in raw_question_toks]
question_subword_len = []
for w in question:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
question_id.extend(toks)
question_subword_len.append(len(toks))
question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
exact_question_token=len(question_id) - 1
question_id.append(plm_tokenizer.sep_token_id)
#print(question_id)
masked_question_id = []
#print(len(masked_question_id))
start=1
#print(question_subword_len)
for i,sub_len in enumerate(question_subword_len):
#print(question_id)
tmp_question_id=question_id.copy()
#print(list(range(start,start+sub_len)))
for m in range(start,start+sub_len):
# print('index',index+i)
#print('before', tmp_question_id)
tmp_question_id[m]=plm_tokenizer.mask_token_id
#print('after', tmp_question_id)
# print('next', masked_question_id[index+i+1])
masked_question_id.append(tmp_question_id)
start+=sub_len
#print('end')
# index+=1
# for a in masked_question_id:
# print(a)
table = [t.lower().split() for t in table_names]
table_id, table_mask_plm, table_subword_len = [], [], []
table_word_len = []
for s in table:
# s: each table name
l = 0
for w in s:
# each word in the table name
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
table_id.extend(toks)
table_subword_len.append(len(toks))
l += len(toks)
table_word_len.append(l)
table_mask_plm = [1] * len(table_id)
# print(table_id)
# print(table_subword_len)
# print(table_word_len)
masked_table_id = []
#print(len(masked_question_id))
start=0
#print(question_subword_len)
for i,sub_len in enumerate(table_word_len):
#print(question_id)
tmp_table_id=table_id.copy()
#print(list(range(start,start+sub_len)))
for m in range(start,start+sub_len):
# print('index',index+i)
#print('before', tmp_question_id)
tmp_table_id[m]=plm_tokenizer.mask_token_id
#print('after', tmp_question_id)
# print('next', masked_question_id[index+i+1])
masked_table_id.append(tmp_table_id)
start+=sub_len
#print('end')
# index+=1
# for a in masked_table_id:
# print(a)
column = [t.lower().split() for t in column_names]
column_id, column_mask_plm, column_subword_len = [], [], []
column_word_len = []
for s in column:
l = 0
for w in s:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
column_id.extend(toks)
column_subword_len.append(len(toks))
l += len(toks)
column_word_len.append(l)
column_mask_plm = [1] * len(column_id) + [0]
exact_column_token=len(column_id)
column_id.append(plm_tokenizer.sep_token_id)
#print(column_id)
#print(column_subword_len)
#print(column_word_len)
masked_column_id = []
#print(len(masked_question_id))
start=0
#print(question_subword_len)
for i,sub_len in enumerate(column_word_len):
#print(question_id)
tmp_column_id=column_id.copy()
#print(list(range(start,start+sub_len)))
for m in range(start,start+sub_len):
# print('index',index+i)
#print('before', tmp_question_id)
tmp_column_id[m]=plm_tokenizer.mask_token_id
#print('after', tmp_question_id)
# print('next', masked_question_id[index+i+1])
masked_column_id.append(tmp_column_id)
start+=sub_len
#print('end')
# index+=1
# for a in masked_column_id:
# print(a)
# input_id = question_id + table_id + column_id
# segment_id = [0] * len(question_id) + [1] * (len(table_id) + len(column_id))
question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm
input_id=[]
segment_id=[]
atten_mask=[]
print(len(table_id))
for i, msk_q_id in enumerate(masked_question_id):
input_id.append(msk_q_id + table_id + column_id)
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
atten_mask.append([1]*len(input_id[-1]))
#device = torch.device("cuda:1")
# print(len(input_id))
# print(len(segment_id))
# print(len(question_mask_plm))
#inputs = {"input_ids": torch.tensor(input_id, dtype=torch.long, device=device), "attention_mask": torch.tensor([1]*len(input_id), dtype=torch.float, device=device), "token_type_ids": torch.tensor(segment_id, dtype=torch.long, device=device), "position_ids": None}
start=0
total_size=len(input_id)
store_arr=[]
if total_size <= max_batch_size:
ii=torch.tensor(input_id, dtype=torch.long, device=device)
im=torch.tensor(atten_mask, dtype=torch.float, device=device)
si=torch.tensor(segment_id, dtype=torch.long, device=device)
#print(ii.size())
#print(im.size())
outputs = plm_model(ii,im,si)[0].squeeze()
store_arr.append(outputs)
#print(outputs.size())
else:
while start < len(input_id):
if start + max_batch_size <= len(input_id):
ii=torch.tensor(input_id[start : start + max_batch_size], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : start + max_batch_size], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : start + max_batch_size], dtype=torch.long, device=device)
#print(ii.size())
#print(im.size())
outputs = plm_model(ii,im,si)[0].squeeze()
store_arr.append(outputs)
#print(outputs.size())
else:
ii=torch.tensor(input_id[start : len(input_id)], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : len(input_id)], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : len(input_id)], dtype=torch.long, device=device)
#print(ii.size())
#print(im.size())
outputs = plm_model(ii,im,si)[0].squeeze()
store_arr.append(outputs)
#print(outputs.size())
start+=max_batch_size
assert len(store_arr)>0
if len(store_arr)==1:
outputs=store_arr[0]
else:
outputs=store_arr[0]
for t in store_arr[1:]:
outputs= torch.cat((outputs, t), dim=0)
qu_output=outputs
#print(qu_output.size())
q_tab_mat = outputs.new_zeros(len(raw_question_toks),len(table_names))
q_col_mat = outputs.new_zeros(len(raw_question_toks),len(column_names))
for j, msk_t_id in enumerate(masked_table_id):
input_id=[]
segment_id=[]
atten_mask=[]
for i, msk_q_id in enumerate(masked_question_id):
input_id.append(msk_q_id + msk_t_id + column_id)
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
atten_mask.append([1]*len(input_id[-1]))
start=0
total_size=len(input_id)
store_arr=[]
if total_size <= max_batch_size:
ii=torch.tensor(input_id, dtype=torch.long, device=device)
im=torch.tensor(atten_mask, dtype=torch.float, device=device)
si=torch.tensor(segment_id, dtype=torch.long, device=device)
outputs = plm_model(ii,im,si)[0].squeeze()
store_arr.append(outputs)
else:
while start < len(input_id):
if start + max_batch_size <= len(input_id):
ii=torch.tensor(input_id[start : start + max_batch_size], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : start + max_batch_size], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : start + max_batch_size], dtype=torch.long, device=device)
outputs = plm_model(ii,im,si)[0].squeeze()
store_arr.append(outputs)
else:
ii=torch.tensor(input_id[start : len(input_id)], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : len(input_id)], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : len(input_id)], dtype=torch.long, device=device)
outputs = plm_model(ii,im,si)[0].squeeze()
store_arr.append(outputs)
start+=max_batch_size
assert len(store_arr)>0
if len(store_arr)==1:
outputs=store_arr[0]
else:
outputs=store_arr[0]
for t in store_arr[1:]:
outputs= torch.cat((outputs, t), dim=0)
#print(outputs.size())
old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), hidden_size)
#print(old_tables.size())
cmp_tables = qu_output.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), hidden_size)
#print(old_tables.size()) #[16, 6, 1024]
# table_compare=old_tables[0] # [6, 1024]
# table_mask = old_tables[1:] # [15, 6, 1024]
start=0
new_table_arr = []
cmp_table_arr = []
for m,sub_len in enumerate(table_word_len):
#print(start, start+sub_len-1)
if m==j:
curr=agg(old_tables[:, start:start+sub_len])
cmp_curr = agg(cmp_tables[:, start:start+sub_len])
# print(curr.size())
# print(cmp_curr.size())
dis=F.pairwise_distance(cmp_curr,curr,p=2)
q_tab_mat[:,j]=dis
start+=sub_len
break
else:
start+=sub_len
#print(q_tab_mat)
# curr=old_tables[:, start:start+sub_len]
# #print(curr.size())
# new_table_arr.append(agg(curr))
# curr=cmp_tables[:, start:start+sub_len]
# #print(curr.size())
# cmp_table_arr.append(agg(curr))
# # new_tables[i]=agg(curr)
# #print(new_questions[i][:10])
# #print(curr.size())
# start+=sub_len
# new_tables = torch.cat(new_table_arr, 1)
# #print(new_tables.size()) # [16, 2, 1024]
# tbl_msk=new_tables[:,j] #15, 2, 1024
# print('tbl_msk',tbl_msk.size())
# # assert tbl_msk.size(0) == len(raw_question_toks)
# for i, msk_q_id in enumerate(masked_question_id):
# a=qu_output[i]
# b=tbl_msk[i]
# dis=F.pairwise_distance(a,b,p=2)
# q_tab_mat[i][j]=dis
# # for i in range(len(table_word_len)):
# # a=tbl_cmp[:,i]
# # #print(a.size())
# # b=tbl_msk[:,i]
# # #print(b.size())
# # dis=F.pairwise_distance(a,b,p=2)
# # q_tab_mat[:,i]=dis
# # #print(dis)
#print(q_tab_mat.transpose(0,1).size())
print(q_tab_mat.transpose(0,1))
# old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),exact_column_token, hidden_size)
# #print(old_columns.size())
# new_column_arr = []
# start=0
# for i,sub_len in enumerate(column_word_len):
# #print(start, start+sub_len-1)
# curr=old_columns[:,start:start+sub_len]
# new_column_arr.append(agg(curr))
# #print(new_questions[i][:10])
# #print(curr.size())
# start+=sub_len
# new_column = torch.cat(new_column_arr, 1)
# #print(new_column.size()) #[16, 8, 1024]
# col_cmp=new_column[0:1] #1,2,1024
# col_msk=new_column[1:] #15, 2, 1024
# assert col_msk.size(0) == len(raw_question_toks)
# for i in range(len(column_word_len)):
# a=col_cmp[:,i]
# #print(a.size())
# b=col_msk[:,i]
# #print(b.size())
# dis=F.pairwise_distance(a,b,p=2)
# q_col_mat[:,i]=dis
# #print(dis)
# print(q_col_mat.transpose(0,1).size())
# print(q_col_mat.transpose(0,1))
# print(q_col_mat.transpose(0,1))
#for i, curr_table in enumerate(table_names):
# use_whole = outputs[0] #35, 1024]
# use_mask = outputs[1:] #[15, 35, 1024]
# print(use_whole.size())
# print(use_mask.size())
# assert use_mask.size(0)==len(raw_question_toks)
# for i, sub_len in enumerate(raw_question_toks):
# outputs=use_mask[i]
# print(outputs.size())
# old_questions = outputs.masked_select(torch.tensor(question_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1)).view(exact_question_token, hidden_size)
# print(old_questions.size())
# new_questions = old_questions.new_zeros(len(raw_question_toks), hidden_size)
# start=0
# print(question_subword_len)
# assert len(question_subword_len) == len(raw_question_toks)
# for i,sub_len in enumerate(question_subword_len):
# #print(start, start+sub_len-1)
# curr=old_questions[start:start+sub_len]
# new_questions[i]=agg(curr)
# #print(new_questions[i][:10])
# #print(curr.size())
# start+=sub_len
# old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1)).view(len(table_id), hidden_size)
# print(old_tables.size())
# new_tables = old_tables.new_zeros(len(table_names), hidden_size)
# start=0
# print(table_word_len)
# assert len(table_word_len) == len(table_names)
# for i,sub_len in enumerate(table_word_len):
# #print(start, start+sub_len-1)
# curr=old_tables[start:start+sub_len]
# new_tables[i]=agg(curr)
# #print(new_questions[i][:10])
# #print(curr.size())
# start+=sub_len
# old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1)).view(exact_column_token, hidden_size)
# print(old_columns.size())
# new_column = old_columns.new_zeros(len(column_names),hidden_size)
# start=0
# print(column_word_len)
# assert len(column_word_len) == len(column_names)
# for i,sub_len in enumerate(column_word_len):
# #print(start, start+sub_len-1)
# curr=old_columns[start:start+sub_len]
# new_column[i]=agg(curr)
# #print(new_questions[i][:10])
# print(curr.size())
# start+=sub_len


@ -0,0 +1,123 @@
import os
import sys
import json
import sqlite3
from os import listdir, makedirs
from os.path import isfile, isdir, join, split, exists, splitext
import traceback
EXIST = {"atis", "geo", "advising", "yelp", "restaurants", "imdb", "academic"}
def convert_fk_index(data):
fk_holder = []
for fk in data["foreign_keys"]:
tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1]
ref_cid, cid = None, None
try:
tid = data['table_names_original'].index(tn)
ref_tid = data['table_names_original'].index(ref_tn)
for i, (tab_id, col_org) in enumerate(data['column_names_original']):
if tab_id == ref_tid and ref_col == col_org:
ref_cid = i
elif tid == tab_id and col == col_org:
cid = i
if ref_cid is not None and cid is not None:
fk_holder.append([cid, ref_cid])
except:
traceback.print_exc()
print("table_names_original: ", data['table_names_original'])
print("finding tab name: ", tn, ref_tn)
sys.exit()
return fk_holder
def dump_db_json_schema(db, f):
'''read table and column info'''
conn = sqlite3.connect(db)
conn.execute('pragma foreign_keys=ON')
cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
data = {'db_id': f,
'table_names_original': [],
'table_names': [],
'column_names_original': [(-1, '*')],
'column_names': [(-1, '*')],
'column_types': ['text'],
'primary_keys': [],
'foreign_keys': []}
fk_holder = []
for i, item in enumerate(cursor.fetchall()):
table_name = item[0]
data['table_names_original'].append(table_name)
data['table_names'].append(table_name.lower().replace("_", ' '))
fks = conn.execute("PRAGMA foreign_key_list('{}') ".format(table_name)).fetchall()
fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks])
cur = conn.execute("PRAGMA table_info('{}') ".format(table_name))
for j, col in enumerate(cur.fetchall()):
data['column_names_original'].append((i, col[1]))
data['column_names'].append((i, col[1].lower().replace("_", " ")))
col_type = col[2].lower()
if 'char' in col_type or col_type == '' or 'text' in col_type or 'var' in col_type:
data['column_types'].append('text')
elif 'int' in col_type or 'numeric' in col_type or 'decimal' in col_type or 'number' in col_type\
or 'id' in col_type or 'real' in col_type or 'double' in col_type or 'float' in col_type:
data['column_types'].append('number')
elif 'date' in col_type or 'time' in col_type or 'year' in col_type:
data['column_types'].append('time')
elif 'boolean' in col_type:
data['column_types'].append('boolean')
else:
data['column_types'].append('others')
if col[5] == 1:
data['primary_keys'].append(len(data['column_names'])-1)
data["foreign_keys"] = fk_holder
data['foreign_keys'] = convert_fk_index(data)
return data
if __name__ == '__main__':
if len(sys.argv) < 2:
print("Usage: python get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]")
sys.exit()
input_dir = sys.argv[1]
output_file = sys.argv[2]
ex_tab_file = sys.argv[3]
all_fs = [df for df in listdir(input_dir) if exists(join(input_dir, df, df+'.sqlite'))]
with open(ex_tab_file) as f:
ex_tabs = json.load(f)
#for tab in ex_tabs:
# tab["foreign_keys"] = convert_fk_index(tab)
ex_tabs = {tab["db_id"]: tab for tab in ex_tabs if tab["db_id"] in all_fs}
# print "precessed file num: ", len(ex_tabs)
not_fs = [df for df in listdir(input_dir) if not exists(join(input_dir, df, df+'.sqlite'))]
for d in not_fs:
print("no sqlite file found in: ", d)
db_files = [(df+'.sqlite', df) for df in listdir(input_dir) if exists(join(input_dir, df, df+'.sqlite'))]
tables = []
for f, df in db_files:
#if df in ex_tabs.keys():
#print 'reading old db: ', df
# tables.append(ex_tabs[df])
db = join(input_dir, df, f)
# print('\nreading new db: ', df)
table = dump_db_json_schema(db, df)
prev_tab_num = len(ex_tabs[df]["table_names"]) if df in ex_tabs else 0
prev_col_num = len(ex_tabs[df]["column_names"]) if df in ex_tabs else 0
cur_tab_num = len(table["table_names"])
cur_col_num = len(table["column_names"])
if df in ex_tabs and prev_tab_num == cur_tab_num and prev_col_num == cur_col_num and prev_tab_num != 0 and prev_col_num > 1:
table["table_names"] = ex_tabs[df]["table_names"]
table["column_names"] = ex_tabs[df]["column_names"]
else:
print("\n----------------------------------problem db: ", df)
tables.append(table)
print("final db num: ", len(tables))
with open(output_file, 'wt') as out:
json.dump(tables, out, sort_keys=True, indent=2, separators=(',', ': '))

View File

@ -0,0 +1,90 @@
#coding=utf8
import math, dgl, torch
import numpy as np
from utils.constants import MAX_RELATIVE_DIST
from utils.graph_example import GraphExample
# mapping special column * as an ordinary column
special_column_mapping_dict = {
'question-*-generic': 'question-column-nomatch',
'*-question-generic': 'column-question-nomatch',
'table-*-generic': 'table-column-has',
'*-table-generic': 'column-table-has',
'*-column-generic': 'column-column-generic',
'column-*-generic': 'column-column-generic',
'*-*-identity': 'column-column-identity'
}
nonlocal_relations = [
'question-question-generic', 'table-table-generic', 'column-column-generic', 'table-column-generic', 'column-table-generic',
'table-table-fk', 'table-table-fkr', 'table-table-fkb', 'column-column-sametable',
'*-column-generic', 'column-*-generic', '*-*-identity', '*-table-generic',
'question-question-identity', 'table-table-identity', 'column-column-identity'] + [
'question-question-dist' + str(i) for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1) if i not in [-1, 0, 1]
]
class GraphProcessor():
def process_rgatsql(self, ex: dict, db: dict, relation: list):
graph = GraphExample()
num_nodes = int(math.sqrt(len(relation)))
local_edges = [(idx // num_nodes, idx % num_nodes, (special_column_mapping_dict[r] if r in special_column_mapping_dict else r))
for idx, r in enumerate(relation) if r not in nonlocal_relations]
nonlocal_edges = [(idx // num_nodes, idx % num_nodes, (special_column_mapping_dict[r] if r in special_column_mapping_dict else r))
for idx, r in enumerate(relation) if r in nonlocal_relations]
global_edges = local_edges + nonlocal_edges
src_ids, dst_ids = list(map(lambda r: r[0], global_edges)), list(map(lambda r: r[1], global_edges))
graph.global_g = dgl.graph((src_ids, dst_ids), num_nodes=num_nodes, idtype=torch.int32)
graph.global_edges = global_edges
src_ids, dst_ids = list(map(lambda r: r[0], local_edges)), list(map(lambda r: r[1], local_edges))
graph.local_g = dgl.graph((src_ids, dst_ids), num_nodes=num_nodes, idtype=torch.int32)
graph.local_edges = local_edges
# graph pruning for nodes
q_num = len(ex['processed_question_toks'])
s_num = num_nodes - q_num
graph.question_mask = [1] * q_num + [0] * s_num
graph.schema_mask = [0] * q_num + [1] * s_num
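        # gp: a complete bipartite question -> schema heterograph (every question
        # token connected to every schema node), consumed by the graph-pruning module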
graph.gp = dgl.heterograph({
('question', 'to', 'schema'): (list(range(q_num)) * s_num,
[i for i in range(s_num) for _ in range(q_num)])
}, num_nodes_dict={'question': q_num, 'schema': s_num}, idtype=torch.int32
)
t_num = len(db['processed_table_toks'])
def check_node(i):
if i < t_num and i in ex['used_tables']:
return 1.0
elif i >= t_num and i - t_num in ex['used_columns']:
return 1.0
else: return 0.0
graph.node_label = list(map(check_node, range(s_num)))
ex['graph'] = graph
return ex
def process_lgesql(self, ex: dict, db: dict, relation: list):
ex = self.process_rgatsql(ex, db, relation)
graph = ex['graph']
lg = graph.local_g.line_graph(backtracking=False)
# prevent information propagate through matching edges
match_ids = [idx for idx, r in enumerate(graph.global_edges) if 'match' in r[2]]
src, dst, eids = lg.edges(form='all', order='eid')
eids = [e for u, v, e in zip(src.tolist(), dst.tolist(), eids.tolist()) if not (u in match_ids and v in match_ids)]
graph.lg = lg.edge_subgraph(eids, preserve_nodes=True).remove_self_loop().add_self_loop()
ex['graph'] = graph
return ex
def process_graph_utils(self, ex: dict, db: dict, method: str = 'rgatsql'):
""" Example should be preprocessed by self.pipeline
"""
q = np.array(ex['relations'], dtype='<U100')
s = np.array(db['relations'], dtype='<U100')
q_s = np.array(ex['schema_linking'][0], dtype='<U100')
s_q = np.array(ex['schema_linking'][1], dtype='<U100')
relation = np.concatenate([
np.concatenate([q, q_s], axis=1),
np.concatenate([s_q, s], axis=1)
], axis=0)
relation = relation.flatten().tolist()
if method == 'rgatsql':
ex = self.process_rgatsql(ex, db, relation)
elif method == 'lgesql':
ex = self.process_lgesql(ex, db, relation)
return ex

View File

@ -0,0 +1,112 @@
import os, sys
import json
import sqlite3
import traceback
import argparse
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from process_sql import get_sql
#TODO: update the following dirs
sql_path = 'data/train.json'
db_dir = 'data/database/'
output_file = 'train.json'
table_file = 'data/tables.json'
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema, table):
self._schema = schema
self._table = table
self._idMap = self._map(self._schema, self._table)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema, table):
column_names_original = table['column_names_original']
table_names_original = table['table_names_original']
#print 'column_names_original: ', column_names_original
#print 'table_names_original: ', table_names_original
        idMap = {}
        for i, (tab_id, col) in enumerate(column_names_original):
            if tab_id == -1:
                # the special '*' column
                idMap['*'] = i
            else:
                key = table_names_original[tab_id].lower()
                val = col.lower()
                idMap[key + "." + val] = i
for i, tab in enumerate(table_names_original):
key = tab.lower()
idMap[key] = i
return idMap
def get_schemas_from_json(fpath):
with open(fpath) as f:
data = json.load(f)
db_names = [db['db_id'] for db in data]
tables = {}
schemas = {}
for db in data:
db_id = db['db_id']
schema = {} #{'table': [col.lower, ..., ]} * -> __all__
column_names_original = db['column_names_original']
table_names_original = db['table_names_original']
tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original}
for i, tabn in enumerate(table_names_original):
table = str(tabn.lower())
cols = [str(col.lower()) for td, col in column_names_original if td == i]
schema[table] = cols
schemas[db_id] = schema
return schemas, db_names, tables
schemas, db_names, tables = get_schemas_from_json(table_file)
with open(sql_path) as inf:
sql_data = json.load(inf)
sql_data_new = []
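# The loop below first patches two known annotation errors in the original Spider
# training data (a query selecting company_name instead of company_type, and a
# dormid sub-query with a wrong table alias), then re-parses every example with
# get_sql to attach the structured 'sql' label.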
for data in sql_data:
try:
if data['query'] == 'SELECT T1.company_name FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id = T2.maintenance_contract_company_id JOIN Ref_Company_Types AS T3 ON T1.company_type_code = T3.company_type_code ORDER BY T2.contract_end_date DESC LIMIT 1':
data['query'] = 'SELECT T1.company_type FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id = T2.maintenance_contract_company_id ORDER BY T2.contract_end_date DESC LIMIT 1'
data['query_toks'] = ['SELECT', 'T1.company_type', 'FROM', 'Third_Party_Companies', 'AS', 'T1', 'JOIN', 'Maintenance_Contracts', 'AS', 'T2', 'ON', 'T1.company_id', '=', 'T2.maintenance_contract_company_id', 'ORDER', 'BY', 'T2.contract_end_date', 'DESC', 'LIMIT', '1']
data['query_toks_no_value'] = ['select', 't1', '.', 'company_type', 'from', 'third_party_companies', 'as', 't1', 'join', 'maintenance_contracts', 'as', 't2', 'on', 't1', '.', 'company_id', '=', 't2', '.', 'maintenance_contract_company_id', 'order', 'by', 't2', '.', 'contract_end_date', 'desc', 'limit', 'value']
data['question'] = 'What is the type of the company who concluded its contracts most recently?'
data['question_toks'] = ['What', 'is', 'the', 'type', 'of', 'the', 'company', 'who', 'concluded', 'its', 'contracts', 'most', 'recently', '?']
if data['query'].startswith('SELECT T1.fname FROM student AS T1 JOIN lives_in AS T2 ON T1.stuid = T2.stuid WHERE T2.dormid IN'):
data['query'] = data['query'].replace('IN (SELECT T2.dormid)', 'IN (SELECT T3.dormid)')
index = data['query_toks'].index('(') + 2
assert data['query_toks'][index] == 'T2.dormid'
data['query_toks'][index] = 'T3.dormid'
index = data['query_toks_no_value'].index('(') + 2
assert data['query_toks_no_value'][index] == 't2'
data['query_toks_no_value'][index] = 't3'
db_id = data["db_id"]
schema = schemas[db_id]
table = tables[db_id]
schema = Schema(schema, table)
sql = data["query"]
sql_label = get_sql(schema, sql)
data["sql"] = sql_label
sql_data_new.append(data)
    except Exception:
print("db_id: ", db_id)
print("sql: ", sql)
with open(output_file, 'wt') as out:
json.dump(sql_data_new, out, sort_keys=True, indent=4, separators=(',', ': '))

View File

@ -0,0 +1,88 @@
import os
import traceback
import re
import sys
import json
import sqlite3
import sqlparse
import random
from os import listdir, makedirs
from collections import OrderedDict
from nltk import word_tokenize, tokenize
from os.path import isfile, isdir, join, split, exists, splitext
from process_sql import get_sql
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema, table):
self._schema = schema
self._table = table
self._idMap = self._map(self._schema, self._table)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema, table):
column_names_original = table['column_names_original']
table_names_original = table['table_names_original']
#print 'column_names_original: ', column_names_original
#print 'table_names_original: ', table_names_original
        idMap = {}
        for i, (tab_id, col) in enumerate(column_names_original):
            if tab_id == -1:
                # the special '*' column
                idMap['*'] = i
            else:
                key = table_names_original[tab_id].lower()
                val = col.lower()
                idMap[key + "." + val] = i
for i, tab in enumerate(table_names_original):
key = tab.lower()
idMap[key] = i
return idMap
def get_schemas_from_json(fpath):
with open(fpath) as f:
data = json.load(f)
db_names = [db['db_id'] for db in data]
tables = {}
schemas = {}
for db in data:
db_id = db['db_id']
schema = {} #{'table': [col.lower, ..., ]} * -> __all__
column_names_original = db['column_names_original']
table_names_original = db['table_names_original']
tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original}
for i, tabn in enumerate(table_names_original):
table = str(tabn.lower())
cols = [str(col.lower()) for td, col in column_names_original if td == i]
schema[table] = cols
schemas[db_id] = schema
return schemas, db_names, tables
if __name__ == '__main__':
sql = "SELECT name , country , age FROM singer ORDER BY age DESC"
db_id = "concert_singer"
table_file = "tables.json"
schemas, db_names, tables = get_schemas_from_json(table_file)
schema = schemas[db_id]
table = tables[db_id]
schema = Schema(schema, table)
sql_label = get_sql(schema, sql)

View File

@ -0,0 +1,73 @@
#coding=utf8
import os, json, pickle, argparse, sys, time
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from asdl.asdl import ASDLGrammar
from asdl.transition_system import TransitionSystem
from asdl.action_info import get_action_infos
from preprocess.common_utils import Preprocessor
def process_example(processor, entry, db, trans, verbose=False):
# preprocess raw tokens, schema linking and subgraph extraction
entry = processor.pipeline(entry, db, verbose=verbose)
# generate target output actions
ast = trans.surface_code_to_ast(entry['sql'])
actions = trans.get_actions(ast)
entry['ast'] = ast
entry['actions'] = get_action_infos(tgt_actions=actions)
return entry
def process_tables(processor, tables_list, output_path=None, verbose=False):
tables = {}
for each in tables_list:
if verbose:
print('*************** Processing database %s **************' % (each['db_id']))
tables[each['db_id']] = processor.preprocess_database(each, verbose=verbose)
    print('In total, processed %d databases.' % (len(tables)))
if output_path is not None:
pickle.dump(tables, open(output_path, 'wb'))
return tables
def process_dataset(processor, dataset, tables, output_path=None, skip_large=False, verbose=False):
from utils.constants import GRAMMAR_FILEPATH
grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH)
trans = TransitionSystem.get_class_by_lang('sql')(grammar)
processed_dataset = []
for idx, entry in enumerate(dataset):
if skip_large and len(tables[entry['db_id']]['column_names']) > 100: continue
if verbose:
print('*************** Processing %d-th sample **************' % (idx))
entry = process_example(processor, entry, tables[entry['db_id']], trans, verbose=verbose)
processed_dataset.append(entry)
    print('In total, processed %d samples, skipped %d samples from extremely large databases.' % (len(processed_dataset), len(dataset) - len(processed_dataset)))
if output_path is not None:
# serialize preprocessed dataset
pickle.dump(processed_dataset, open(output_path, 'wb'))
return processed_dataset
if __name__ == '__main__':
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--db_dir', type=str, default='data/database')
arg_parser.add_argument('--dataset_path', type=str, required=True, help='dataset path')
arg_parser.add_argument('--raw_table_path', type=str, help='raw tables path')
arg_parser.add_argument('--table_path', type=str, required=True, help='processed table path')
arg_parser.add_argument('--output_path', type=str, required=True, help='output preprocessed dataset')
arg_parser.add_argument('--skip_large', action='store_true', help='whether skip large databases')
arg_parser.add_argument('--verbose', action='store_true', help='whether print processing information')
args = arg_parser.parse_args()
processor = Preprocessor(db_dir=args.db_dir, db_content=True)
# loading database and dataset
if args.raw_table_path:
# need to preprocess database items
tables_list = json.load(open(args.raw_table_path, 'r'))
print('Firstly, preprocess the original databases ...')
start_time = time.time()
tables = process_tables(processor, tables_list, args.table_path, args.verbose)
print('Databases preprocessing costs %.4fs .' % (time.time() - start_time))
else:
tables = pickle.load(open(args.table_path, 'rb'))
dataset = json.load(open(args.dataset_path, 'r'))
start_time = time.time()
dataset = process_dataset(processor, dataset, tables, args.output_path, args.skip_large, verbose=args.verbose)
print('Dataset preprocessing costs %.4fs .' % (time.time() - start_time))

View File

@ -0,0 +1,37 @@
#coding=utf8
import os, json, pickle, argparse, sys, time
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from preprocess.graph_utils import GraphProcessor
def process_dataset_graph(processor, dataset, tables, method, output_path=None, skip_large=False):
processed_dataset = []
for idx, entry in enumerate(dataset):
db = tables[entry['db_id']]
if skip_large and len(db['column_names']) > 100:
continue
if (idx + 1) % 500 == 0:
print('Processing the %d-th example ...' % (idx + 1))
entry = processor.process_graph_utils(entry, db, method=method)
processed_dataset.append(entry)
    print('In total, processed %d samples, skipped %d samples.' % (len(processed_dataset), len(dataset) - len(processed_dataset)))
if output_path is not None:
# serialize preprocessed dataset
pickle.dump(processed_dataset, open(output_path, 'wb'))
return processed_dataset
if __name__ == '__main__':
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--dataset_path', type=str, required=True, help='dataset path')
arg_parser.add_argument('--table_path', type=str, required=True, help='processed table path')
arg_parser.add_argument('--method', type=str, default='lgesql', choices=['rgatsql', 'lgesql'])
arg_parser.add_argument('--output_path', type=str, required=True, help='output preprocessed dataset')
args = arg_parser.parse_args()
processor = GraphProcessor()
# loading database and dataset
tables = pickle.load(open(args.table_path, 'rb'))
dataset = pickle.load(open(args.dataset_path, 'rb'))
start_time = time.time()
dataset = process_dataset_graph(processor, dataset, tables, args.method, args.output_path)
print('Dataset preprocessing costs %.4fs .' % (time.time() - start_time))

20
proton/preprocess/test.py Normal file
View File

@ -0,0 +1,20 @@
# coding=utf8
# import os
# import numpy as np
import torch  # needed by agg() below
# import torch.nn.functional as F
# from transformers import AutoModel, AutoConfig, AutoTokenizer
# import matplotlib.pyplot as plt
from nltk.corpus import stopwords
def agg(input):
return torch.sum(input, dim=1, keepdim=True) / input.size(1)
if __name__ == '__main__':
stopwords = stopwords.words("english")
b='."?,'
a='the'
print(a in stopwords)

235
proton/preprocess/test_g.py Normal file
View File

@ -0,0 +1,235 @@
#coding=utf8
import os
import numpy as np
import torch
import stanza
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig,AutoTokenizer
import matplotlib.pyplot as plt
from nltk.corpus import stopwords
def agg(input):
assert input.size(1)>=2
input=input[:,1:]
return torch.sum(input,dim=1,keepdim=True)/input.size(1)
if __name__=='__main__':
nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')#, use_gpu=False)
stopwords = stopwords.words("english")
a='.'
print(a in stopwords)
device = torch.device("cuda:0")
hidden_size=1024
max_batch_size = 8
plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'grappa_large_jnt')).to(device)
plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'grappa_large_jnt'))
config = plm_model.config
raw_question_toks=['what', 'is', 'the', 'number', 'of', 'employees', '?']
column_names=['*', 'flight number', 'origin', 'destination', 'distance', 'departure date', 'arrival date', 'price', 'airline id', 'airline id', 'name', 'distance', 'employee id', 'name', 'salary', 'employee id', 'airline id']
table_names=['flight', 'aircraft', 'employee', 'certificate']
question = " ".join(raw_question_toks)
doc = nlp(question)
raw_question_toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
print('Q', len(raw_question_toks), raw_question_toks)
print('C', len(column_names), column_names)
print('T', len(table_names), table_names)
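    # Probing sketch implemented below: the question + schema is encoded once unmasked
    # and once per question word with that word's subwords replaced by the mask token;
    # the L2 distance between a schema item's mean-pooled representation with and
    # without masking is used as its linking score for that word.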
question_id = [plm_tokenizer.cls_token_id]
question = [q.lower() for q in raw_question_toks]
question_subword_len = []
for w in question:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
question_id.extend(toks)
question_subword_len.append(len(toks))
question_mask_plm = [0] + [1] * (len(question_id) - 1) #+ [0]
#question_id.append(plm_tokenizer.sep_token_id)
masked_question_id = [question_id]
start=1
for i,sub_len in enumerate(question_subword_len):
tmp_question_id=question_id.copy()
for m in range(start,start+sub_len):
tmp_question_id[m]=plm_tokenizer.mask_token_id
masked_question_id.append(tmp_question_id)
start+=sub_len
table = [t.lower().split() for t in table_names]
table_id, table_mask_plm, table_subword_len = [], [], []
table_word_len = []
for s in table:
#l = 1
toks = [plm_tokenizer.sep_token_id]
for w in s:
sub_toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
toks.extend(sub_toks)
#table_subword_len.append(len(toks))
#l += len(sub_toks)
table_id.extend(toks)
table_word_len.append(len(toks))
table_mask_plm = [1] * len(table_id)
column = [t.lower().split() for t in column_names]
column_id, column_mask_plm, column_subword_len = [], [], []
column_word_len = []
for s in column:
#l = 1
toks = [plm_tokenizer.sep_token_id]
for w in s:
sub_toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
toks.extend(sub_toks)
#column_subword_len.append(len(toks))
#l += len(sub_toks)
column_id.extend(toks)
column_word_len.append(len(toks))
column_mask_plm = [1] * len(column_id) #+ [0]
#exact_column_token=len(column_id)
#column_id.append(plm_tokenizer.sep_token_id)
question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm
input_id=[]
segment_id=[]
atten_mask=[]
for i, msk_q_id in enumerate(masked_question_id):
input_id.append(msk_q_id + table_id + column_id)
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
atten_mask.append([1]*len(input_id[-1]))
start=0
total_size=len(input_id)
#print(total_size)
store_arr=[]
if total_size <= max_batch_size:
ii=torch.tensor(input_id, dtype=torch.long, device=device)
im=torch.tensor(atten_mask, dtype=torch.float, device=device)
si=torch.tensor(segment_id, dtype=torch.long, device=device)
outputs = plm_model(ii,im)[0].squeeze()
store_arr.append(outputs)
else:
while start < len(input_id):
if start + max_batch_size <= len(input_id):
ii=torch.tensor(input_id[start : start + max_batch_size], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : start + max_batch_size], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : start + max_batch_size], dtype=torch.long, device=device)
outputs = plm_model(ii,im)[0]#.squeeze()
store_arr.append(outputs)
else:
ii=torch.tensor(input_id[start : len(input_id)], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : len(input_id)], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : len(input_id)], dtype=torch.long, device=device)
outputs = plm_model(ii,im)[0]#.squeeze()
store_arr.append(outputs)
#print(outputs.size())
start+=max_batch_size
assert len(store_arr)>0
if len(store_arr)==1:
outputs=store_arr[0]
else:
outputs=store_arr[0]
print(outputs.size())
for t in store_arr[1:]:
print(t.size())
outputs= torch.cat((outputs, t), dim=0)
q_tab_mat = outputs.new_zeros(len(raw_question_toks),len(table_names))
old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), hidden_size)
start=0
new_table_arr = []
for i,sub_len in enumerate(table_word_len):
curr=old_tables[:, start:start+sub_len]
new_table_arr.append(agg(curr))
start+=sub_len
new_tables = torch.cat(new_table_arr, 1)
tbl_cmp=new_tables[0:1]
tbl_msk=new_tables[1:]
assert tbl_msk.size(0) == len(raw_question_toks)
for i in range(len(table_word_len)):
a=tbl_cmp[:,i]
b=tbl_msk[:,i]
dis=F.pairwise_distance(a,b,p=2)
q_tab_mat[:,i]=dis
#print(q_tab_mat.transpose(0,1).size())
#print(q_tab_mat.transpose(0,1).cpu().detach().numpy())
print(q_tab_mat.transpose(0,1))
q_col_mat = outputs.new_zeros(len(raw_question_toks),len(column_names))
old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(column_id), hidden_size)
new_column_arr = []
start=0
for i,sub_len in enumerate(column_word_len):
curr=old_columns[:,start:start+sub_len]
new_column_arr.append(agg(curr))
start+=sub_len
new_column = torch.cat(new_column_arr, 1)
col_cmp=new_column[0:1]
col_msk=new_column[1:]
assert col_msk.size(0) == len(raw_question_toks)
for i in range(len(column_word_len)):
a=col_cmp[:,i]
b=col_msk[:,i]
dis=F.pairwise_distance(a,b,p=2)
q_col_mat[:,i]=dis
#print(dis)
#print(q_col_mat.transpose(0,1).size())
#print(q_col_mat.transpose(0,1).cpu().detach().numpy())
print(q_col_mat.transpose(0,1))
use_matrix = torch.cat([q_tab_mat,q_col_mat], dim=1)
matrix_min=torch.min(use_matrix)
print(matrix_min)
matrix_max=torch.max(use_matrix)
print(matrix_max)
matrix_mean=torch.mean(use_matrix)
print(matrix_mean)
matrix_var=torch.sqrt(torch.mean(use_matrix))
print(matrix_var)
use_matrix=(use_matrix-matrix_min)/(matrix_max-matrix_min)
#use_matrix=(use_matrix-matrix_mean)/matrix_var
print(use_matrix.size())
print(use_matrix)
# vegetables = table_names+column_names
# farmers = raw_question_toks
vegetables = raw_question_toks
farmers = table_names+column_names
harvest = np.around(use_matrix.cpu().detach().numpy(),3)
fig, ax = plt.subplots(figsize=(15,15))
im = ax.imshow(harvest,cmap="YlGn")
#threshold = im.norm(use_matrix.max())/2
# We want to show all ticks...
ax.set_xticks(np.arange(len(farmers)))
ax.set_yticks(np.arange(len(vegetables)))
# ... and label them with the respective list entries
ax.set_xticklabels(farmers)
ax.set_yticklabels(vegetables)
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
for i in range(len(vegetables)):
for j in range(len(farmers)):
text = ax.text(j, i, harvest[i, j],
ha="center", va="center", color="w")
ax.set_title("Probing Matrix")
fig.tight_layout()
plt.show()
plt.savefig('1.jpg')

View File

@ -0,0 +1,247 @@
#coding=utf8
import os
import numpy as np
import torch
import stanza
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig,AutoTokenizer
import matplotlib.pyplot as plt
def agg(input):
return torch.sum(input,dim=1,keepdim=True)/input.size(1)
if __name__=='__main__':
nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')#, use_gpu=False)
device = torch.device("cuda:1")
hidden_size=1024
max_batch_size = 8
plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(device)
plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
config = plm_model.config
raw_question_toks=['How', 'many', 'flights', 'depart', 'from', 'Jackson']
column_names=['*', 'airline id', 'airline name', 'abbreviation', 'country', 'city', 'airport code', 'airport name', 'country', 'country abbrev', 'airline', 'flight number', 'source airport', 'destination airport']
table_names=['airlines', 'airports', 'flights']
question = " ".join(raw_question_toks)
doc = nlp(question)
raw_question_toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
print('Q', len(raw_question_toks), raw_question_toks)
print('C', len(column_names), column_names)
print('T', len(table_names), table_names)
question_id = [plm_tokenizer.cls_token_id]
question = [q.lower() for q in raw_question_toks]
question_subword_len = []
for w in question:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
question_id.extend(toks)
question_subword_len.append(len(toks))
question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
exact_question_token=len(question_id) - 1
question_id.append(plm_tokenizer.sep_token_id)
masked_question_id = [question_id]
start=1
for i,sub_len in enumerate(question_subword_len):
tmp_question_id=question_id.copy()
for m in range(start,start+sub_len):
tmp_question_id[m]=plm_tokenizer.mask_token_id
masked_question_id.append(tmp_question_id)
start+=sub_len
table = [t.lower().split() for t in table_names]
table_id, table_mask_plm, table_subword_len = [], [], []
table_word_len = []
for s in table:
l = 0
for w in s:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
table_id.extend(toks)
table_subword_len.append(len(toks))
l += len(toks)
table_word_len.append(l)
table_mask_plm = [1] * len(table_id)
column = [t.lower().split() for t in column_names]
column_id, column_mask_plm, column_subword_len = [], [], []
column_word_len = []
for s in column:
l = 0
for w in s:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
column_id.extend(toks)
column_subword_len.append(len(toks))
l += len(toks)
column_word_len.append(l)
column_mask_plm = [1] * len(column_id) + [0]
exact_column_token=len(column_id)
column_id.append(plm_tokenizer.sep_token_id)
question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm
input_id=[]
segment_id=[]
atten_mask=[]
for i, msk_q_id in enumerate(masked_question_id):
input_id.append(msk_q_id + table_id + column_id)
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
atten_mask.append([1]*len(input_id[-1]))
start=0
total_size=len(input_id)
#print(total_size)
store_arr=[]
if total_size <= max_batch_size:
ii=torch.tensor(input_id, dtype=torch.long, device=device)
im=torch.tensor(atten_mask, dtype=torch.float, device=device)
si=torch.tensor(segment_id, dtype=torch.long, device=device)
outputs = plm_model(ii,im)[0].squeeze()
store_arr.append(outputs)
else:
while start < len(input_id):
if start + max_batch_size <= len(input_id):
ii=torch.tensor(input_id[start : start + max_batch_size], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : start + max_batch_size], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : start + max_batch_size], dtype=torch.long, device=device)
outputs = plm_model(ii,im)[0]#.squeeze()
store_arr.append(outputs)
else:
ii=torch.tensor(input_id[start : len(input_id)], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : len(input_id)], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : len(input_id)], dtype=torch.long, device=device)
outputs = plm_model(ii,im)[0]#.squeeze()
store_arr.append(outputs)
#print(outputs.size())
start+=max_batch_size
assert len(store_arr)>0
if len(store_arr)==1:
outputs=store_arr[0]
else:
outputs=store_arr[0]
#print(outputs.size())
for t in store_arr[1:]:
#print(t.size())
outputs= torch.cat((outputs, t), dim=0)
q_tab_mat = outputs.new_zeros(len(raw_question_toks),len(table_names))
old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), hidden_size)
start=0
new_table_arr = []
maybe_use=[]
for i,sub_len in enumerate(table_word_len):
curr=old_tables[:, start:start+sub_len]
new_table_arr.append(agg(curr))
maybe_use.append(curr)
start+=sub_len
new_tables = torch.cat(new_table_arr, 1)
tbl_cmp=new_tables[0:1]
tbl_msk=new_tables[1:]
assert tbl_msk.size(0) == len(raw_question_toks)
for i in range(len(table_word_len)):
a=tbl_cmp[:,i] #1,hidden
b=tbl_msk[:,i] #qu_len,hidden
#maybe_use[i] qu_len,sublen,hidden
dis=F.pairwise_distance(a,b,p=2)
q_tab_mat[:,i]=dis
# for m in range(maybe_use[i].size(1)):
# tmp=maybe_use[i][1:,m]
# pre_dis=F.pairwise_distance(a,tmp,p=2) #qu_len
# for k in range(pre_dis.size(0)):
# if pre_dis[k]>q_tab_mat[k,i]:
# q_tab_mat[k,i]=pre_dis[k]
#q_tab_mat[:,i]=pre_dis
#print(q_tab_mat.transpose(0,1).size())
#print(q_tab_mat.transpose(0,1).cpu().detach().numpy())
print(q_tab_mat.transpose(0,1))
q_col_mat = outputs.new_zeros(len(raw_question_toks),len(column_names))
old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),exact_column_token, hidden_size)
new_column_arr = []
start=0
maybe_use=[]
for i,sub_len in enumerate(column_word_len):
curr=old_columns[:,start:start+sub_len]
new_column_arr.append(agg(curr))
maybe_use.append(curr)
start+=sub_len
new_column = torch.cat(new_column_arr, 1)
col_cmp=new_column[0:1]
col_msk=new_column[1:]
assert col_msk.size(0) == len(raw_question_toks)
for i in range(len(column_word_len)):
a=col_cmp[:,i]
b=col_msk[:,i]
dis=F.pairwise_distance(a,b,p=2)
q_col_mat[:,i]=dis
# for m in range(maybe_use[i].size(1)):
# tmp=maybe_use[i][1:,m]
# pre_dis=F.pairwise_distance(a,tmp,p=2) #qu_len
# for k in range(pre_dis.size(0)):
# if pre_dis[k]>q_col_mat[k,i]:
# q_col_mat[k,i]=pre_dis[k]
#print(dis)
#print(q_col_mat.transpose(0,1).size())
#print(q_col_mat.transpose(0,1).cpu().detach().numpy())
print(q_col_mat.transpose(0,1))
use_matrix = torch.cat([q_tab_mat,q_col_mat], dim=1)
matrix_min=torch.min(use_matrix)
#print(matrix_min)
matrix_max=torch.max(use_matrix)
#print(matrix_max)
matrix_mean=torch.mean(use_matrix)
#print(matrix_mean)
matrix_var=torch.sqrt(torch.mean(use_matrix))
#print(matrix_var)
use_matrix=(use_matrix-matrix_min)/(matrix_max-matrix_min)
#use_matrix=(use_matrix-matrix_mean)/matrix_var
#print(use_matrix.size())
#print(use_matrix)
# vegetables = table_names+column_names
# farmers = raw_question_toks
vegetables = raw_question_toks
farmers = table_names+column_names
harvest = np.around(use_matrix.cpu().detach().numpy(),3)
print()
fig, ax = plt.subplots(figsize=(25,25))
im = ax.imshow(harvest,cmap="YlGn")
#threshold = im.norm(use_matrix.max())/2
# We want to show all ticks...
ax.set_xticks(np.arange(len(farmers)))
ax.set_yticks(np.arange(len(vegetables)))
# ... and label them with the respective list entries
ax.set_xticklabels(farmers)
ax.set_yticklabels(vegetables)
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
for i in range(len(vegetables)):
for j in range(len(farmers)):
text = ax.text(j, i, harvest[i, j],
ha="center", va="center", color="w")
ax.set_title("Probing Matrix")
fig.tight_layout()
plt.show()
plt.savefig('1.jpg')

View File

@ -0,0 +1,269 @@
#coding=utf8
import os
import numpy as np
import torch
import stanza
import torch.nn.functional as F
from itertools import product, combinations
from transformers import AutoModel, AutoConfig,AutoTokenizer
import matplotlib.pyplot as plt
def agg(input):
return torch.sum(input,dim=1,keepdim=True)/input.size(1)
if __name__=='__main__':
nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')#, use_gpu=False)
device = torch.device("cuda:0")
hidden_size=1024
max_batch_size = 8
plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'grappa_large_jnt')).to(device)
plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'grappa_large_jnt'))
config = plm_model.config
raw_question_toks=['list', 'the', 'most', 'common', 'result', 'of', 'the', 'musicals', '.']
column_names=['*', 'musical id', 'name', 'year', 'award', 'category', 'nominee', 'result', 'actor id', 'name', 'musical id', 'character', 'duration', 'age']
table_names=['musical', 'actor']
question = " ".join(raw_question_toks)
doc = nlp(question)
raw_question_toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
q_num, t_num, c_num = len(raw_question_toks), len(table_names), len(column_names)
t_q_dict = {}
c_q_dict = {}
max_len = max([len(t.split()) for t in table_names])
index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
for i, j in index_pairs:
phrase = ' '.join(raw_question_toks[i: j])
for idx, name in enumerate(table_names):
if phrase == name:
if idx in t_q_dict.keys():
t_q_dict[idx].append([phrase,i,j])
else:
t_q_dict[idx]=[[phrase,i,j]]
elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
if idx in t_q_dict.keys():
t_q_dict[idx].append([phrase,i,j])
else:
t_q_dict[idx]=[[phrase,i,j]]
print(t_q_dict)
print('Q', len(raw_question_toks), raw_question_toks)
print('C', len(column_names), column_names)
print('T', len(table_names), table_names)
question_id = [plm_tokenizer.cls_token_id]
question = [q.lower() for q in raw_question_toks]
question_subword_len = []
for w in question:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
question_id.extend(toks)
question_subword_len.append(len(toks))
question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
exact_question_token=len(question_id) - 1
question_id.append(plm_tokenizer.sep_token_id)
masked_question_id = [question_id]
start=1
for i,sub_len in enumerate(question_subword_len):
tmp_question_id=question_id.copy()
for m in range(start,start+sub_len):
tmp_question_id[m]=plm_tokenizer.mask_token_id
masked_question_id.append(tmp_question_id)
start+=sub_len
table = [t.lower().split() for t in table_names]
table_id, table_mask_plm, table_subword_len = [], [], []
table_word_len = []
for s in table:
l = 0
for w in s:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
table_id.extend(toks)
table_subword_len.append(len(toks))
l += len(toks)
table_word_len.append(l)
table_mask_plm = [1] * len(table_id)
column = [t.lower().split() for t in column_names]
column_id, column_mask_plm, column_subword_len = [], [], []
column_word_len = []
for s in column:
l = 0
for w in s:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
column_id.extend(toks)
column_subword_len.append(len(toks))
l += len(toks)
column_word_len.append(l)
column_mask_plm = [1] * len(column_id) + [0]
exact_column_token=len(column_id)
column_id.append(plm_tokenizer.sep_token_id)
question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm
input_id=[]
segment_id=[]
atten_mask=[]
for i, msk_q_id in enumerate(masked_question_id):
input_id.append(msk_q_id + table_id + column_id)
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
atten_mask.append([1]*len(input_id[-1]))
start=0
total_size=len(input_id)
#print(total_size)
store_arr=[]
if total_size <= max_batch_size:
ii=torch.tensor(input_id, dtype=torch.long, device=device)
im=torch.tensor(atten_mask, dtype=torch.float, device=device)
si=torch.tensor(segment_id, dtype=torch.long, device=device)
outputs = plm_model(ii,im)[0].squeeze()
store_arr.append(outputs)
else:
while start < len(input_id):
if start + max_batch_size <= len(input_id):
ii=torch.tensor(input_id[start : start + max_batch_size], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : start + max_batch_size], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : start + max_batch_size], dtype=torch.long, device=device)
outputs = plm_model(ii,im)[0]#.squeeze()
store_arr.append(outputs)
else:
ii=torch.tensor(input_id[start : len(input_id)], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : len(input_id)], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : len(input_id)], dtype=torch.long, device=device)
outputs = plm_model(ii,im)[0]#.squeeze()
store_arr.append(outputs)
#print(outputs.size())
start+=max_batch_size
assert len(store_arr)>0
if len(store_arr)==1:
outputs=store_arr[0]
else:
outputs=store_arr[0]
#print(outputs.size())
for t in store_arr[1:]:
#print(t.size())
outputs= torch.cat((outputs, t), dim=0)
q_tab_mat = outputs.new_zeros(len(raw_question_toks),len(table_names))
old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), hidden_size)
start=0
new_table_arr = []
maybe_use=[]
for i,sub_len in enumerate(table_word_len):
curr=old_tables[:, start:start+sub_len]
new_table_arr.append(agg(curr))
maybe_use.append(curr)
start+=sub_len
new_tables = torch.cat(new_table_arr, 1)
tbl_cmp=new_tables[0:1]
tbl_msk=new_tables[1:]
assert tbl_msk.size(0) == len(raw_question_toks)
for i in range(len(table_word_len)):
a=tbl_cmp[:,i] #1,hidden
b=tbl_msk[:,i] #qu_len,hidden
#maybe_use[i] qu_len,sublen,hidden
dis=F.pairwise_distance(a,b,p=2)
q_tab_mat[:,i]=dis
# for m in range(maybe_use[i].size(1)):
# tmp=maybe_use[i][1:,m]
# pre_dis=F.pairwise_distance(a,tmp,p=2) #qu_len
# for k in range(pre_dis.size(0)):
# if pre_dis[k]>q_tab_mat[k,i]:
# q_tab_mat[k,i]=pre_dis[k]
#q_tab_mat[:,i]=pre_dis
#print(q_tab_mat.transpose(0,1).size())
#print(q_tab_mat.transpose(0,1).cpu().detach().numpy())
print(q_tab_mat.transpose(0,1))
q_col_mat = outputs.new_zeros(len(raw_question_toks),len(column_names))
old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),exact_column_token, hidden_size)
new_column_arr = []
start=0
maybe_use=[]
for i,sub_len in enumerate(column_word_len):
curr=old_columns[:,start:start+sub_len]
new_column_arr.append(agg(curr))
maybe_use.append(curr)
start+=sub_len
new_column = torch.cat(new_column_arr, 1)
col_cmp=new_column[0:1]
col_msk=new_column[1:]
assert col_msk.size(0) == len(raw_question_toks)
for i in range(len(column_word_len)):
a=col_cmp[:,i]
b=col_msk[:,i]
dis=F.pairwise_distance(a,b,p=2)
q_col_mat[:,i]=dis
# for m in range(maybe_use[i].size(1)):
# tmp=maybe_use[i][1:,m]
# pre_dis=F.pairwise_distance(a,tmp,p=2) #qu_len
# for k in range(pre_dis.size(0)):
# if pre_dis[k]>q_col_mat[k,i]:
# q_col_mat[k,i]=pre_dis[k]
#print(dis)
#print(q_col_mat.transpose(0,1).size())
#print(q_col_mat.transpose(0,1).cpu().detach().numpy())
print(q_col_mat.transpose(0,1))
use_matrix = torch.cat([q_tab_mat,q_col_mat], dim=1)
matrix_min=torch.min(use_matrix)
#print(matrix_min)
matrix_max=torch.max(use_matrix)
#print(matrix_max)
matrix_mean=torch.mean(use_matrix)
#print(matrix_mean)
matrix_var=torch.sqrt(torch.mean(use_matrix))
#print(matrix_var)
use_matrix=(use_matrix-matrix_min)/(matrix_max-matrix_min)
#use_matrix=(use_matrix-matrix_mean)/matrix_var
#print(use_matrix.size())
#print(use_matrix)
# vegetables = table_names+column_names
# farmers = raw_question_toks
vegetables = raw_question_toks
farmers = table_names+column_names
harvest = np.around(use_matrix.cpu().detach().numpy(),3)
print()
fig, ax = plt.subplots(figsize=(15,15))
im = ax.imshow(harvest,cmap="YlGn")
#threshold = im.norm(use_matrix.max())/2
# We want to show all ticks...
ax.set_xticks(np.arange(len(farmers)))
ax.set_yticks(np.arange(len(vegetables)))
# ... and label them with the respective list entries
ax.set_xticklabels(farmers)
ax.set_yticklabels(vegetables)
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
for i in range(len(vegetables)):
for j in range(len(farmers)):
text = ax.text(j, i, harvest[i, j],
ha="center", va="center", color="w")
ax.set_title("Probing Matrix")
fig.tight_layout()
plt.show()
plt.savefig('1.jpg')

View File

@ -0,0 +1,223 @@
#coding=utf8
import os
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig,AutoTokenizer
import matplotlib.pyplot as plt
def agg(input):
return torch.sum(input,dim=1,keepdim=True)/input.size(1)
if __name__=='__main__':
device = torch.device("cuda:0")
hidden_size=512
max_batch_size = 8
plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(device)
plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
config = plm_model.config
raw_question_toks=['list', 'the', 'most', 'common', 'result', 'of', 'the', 'musicals', '.']
column_names=['*', 'musical id', 'name', 'year', 'award', 'category', 'nominee', 'result', 'actor id', 'name', 'musical id', 'character', 'duration', 'age']
table_names=['musical', 'actor']
print('Q', len(raw_question_toks), raw_question_toks)
print('C', len(column_names), column_names)
print('T', len(table_names), table_names)
question_id = [plm_tokenizer.cls_token_id]
question = [q.lower() for q in raw_question_toks]
question_subword_len = []
for w in question:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
question_id.extend(toks)
question_subword_len.append(len(toks))
question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
exact_question_token=len(question_id) - 1
question_id.append(plm_tokenizer.sep_token_id)
masked_question_id = [question_id]
start=1
for i,sub_len in enumerate(question_subword_len):
tmp_question_id=question_id.copy()
for m in range(start,start+sub_len):
tmp_question_id[m]=plm_tokenizer.mask_token_id
masked_question_id.append(tmp_question_id)
start+=sub_len
table = [t.lower().split() for t in table_names]
table_id, table_mask_plm, table_subword_len = [], [], []
table_word_len = []
for s in table:
l = 0
for w in s:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
table_id.extend(toks)
table_subword_len.append(len(toks))
l += len(toks)
table_word_len.append(l)
table_mask_plm = [1] * len(table_id)
column = [t.lower().split() for t in column_names]
column_id, column_mask_plm, column_subword_len = [], [], []
column_word_len = []
for s in column:
l = 0
for w in s:
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
column_id.extend(toks)
column_subword_len.append(len(toks))
l += len(toks)
column_word_len.append(l)
column_mask_plm = [1] * len(column_id) + [0]
exact_column_token=len(column_id)
column_id.append(plm_tokenizer.sep_token_id)
question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm
input_id=[]
segment_id=[]
atten_mask=[]
for i, msk_q_id in enumerate(masked_question_id):
input_id.append(msk_q_id + table_id + column_id)
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
atten_mask.append([1]*len(input_id[-1]))
start=0
total_size=len(input_id)
#print(total_size)
store_arr=[]
if total_size <= max_batch_size:
ii=torch.tensor(input_id, dtype=torch.long, device=device)
im=torch.tensor(atten_mask, dtype=torch.float, device=device)
si=torch.tensor(segment_id, dtype=torch.long, device=device)
outputs = plm_model(ii,im,si)[0].squeeze()
store_arr.append(outputs)
else:
while start < len(input_id):
if start + max_batch_size <= len(input_id):
ii=torch.tensor(input_id[start : start + max_batch_size], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : start + max_batch_size], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : start + max_batch_size], dtype=torch.long, device=device)
outputs = plm_model(ii,im,si)[0]#.squeeze()
store_arr.append(outputs)
else:
ii=torch.tensor(input_id[start : len(input_id)], dtype=torch.long, device=device)
im=torch.tensor(atten_mask[start : len(input_id)], dtype=torch.float, device=device)
si=torch.tensor(segment_id[start : len(input_id)], dtype=torch.long, device=device)
outputs = plm_model(ii,im,si)[0]#.squeeze()
store_arr.append(outputs)
#print(outputs.size())
start+=max_batch_size
assert len(store_arr)>0
if len(store_arr)==1:
outputs=store_arr[0]
else:
outputs=store_arr[0]
print(outputs.size())
for t in store_arr[1:]:
print(t.size())
outputs= torch.cat((outputs, t), dim=0)
q_tab_mat = outputs.new_zeros(len(raw_question_toks),len(table_names))
old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), hidden_size)
start=0
new_table_arr = []
for i,sub_len in enumerate(table_word_len):
curr=old_tables[:, start:start+sub_len]
new_table_arr.append(agg(curr))
start+=sub_len
new_tables = torch.cat(new_table_arr, 1)
tbl_cmp=new_tables[0:1]
tbl_msk=new_tables[1:]
assert tbl_msk.size(0) == len(raw_question_toks)
for i in range(len(table_word_len)):
a=tbl_cmp[:,i]
b=tbl_msk[:,i]
dis=F.pairwise_distance(a,b,p=2)
q_tab_mat[:,i]=dis
#print(q_tab_mat.transpose(0,1).size())
#print(q_tab_mat.transpose(0,1).cpu().detach().numpy())
print(q_tab_mat.transpose(0,1))
q_col_mat = outputs.new_zeros(len(raw_question_toks),len(column_names))
old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),exact_column_token, hidden_size)
new_column_arr = []
start=0
for i,sub_len in enumerate(column_word_len):
curr=old_columns[:,start:start+sub_len]
new_column_arr.append(agg(curr))
start+=sub_len
new_column = torch.cat(new_column_arr, 1)
col_cmp=new_column[0:1]
col_msk=new_column[1:]
assert col_msk.size(0) == len(raw_question_toks)
for i in range(len(column_word_len)):
a=col_cmp[:,i]
b=col_msk[:,i]
dis=F.pairwise_distance(a,b,p=2)
q_col_mat[:,i]=dis
#print(dis)
#print(q_col_mat.transpose(0,1).size())
#print(q_col_mat.transpose(0,1).cpu().detach().numpy())
print(q_col_mat.transpose(0,1))
use_matrix = torch.cat([q_tab_mat,q_col_mat], dim=1)
matrix_min=torch.min(use_matrix)
print(matrix_min)
matrix_max=torch.max(use_matrix)
print(matrix_max)
matrix_mean=torch.mean(use_matrix)
print(matrix_mean)
matrix_var=torch.sqrt(torch.mean(use_matrix))
print(matrix_var)
use_matrix=(use_matrix-matrix_min)/(matrix_max-matrix_min)
#use_matrix=(use_matrix-matrix_mean)/matrix_var
print(use_matrix.size())
print(use_matrix)
# vegetables = table_names+column_names
# farmers = raw_question_toks
vegetables = raw_question_toks
farmers = table_names+column_names
harvest = np.around(use_matrix.cpu().detach().numpy(),3)
fig, ax = plt.subplots(figsize=(15,15))
im = ax.imshow(harvest,cmap="YlGn")
#threshold = im.norm(use_matrix.max())/2
# We want to show all ticks...
ax.set_xticks(np.arange(len(farmers)))
ax.set_yticks(np.arange(len(vegetables)))
# ... and label them with the respective list entries
ax.set_xticklabels(farmers)
ax.set_yticklabels(vegetables)
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
for i in range(len(vegetables)):
for j in range(len(farmers)):
text = ax.text(j, i, harvest[i, j],
ha="center", va="center", color="w")
ax.set_title("Probing Matrix")
fig.tight_layout()
plt.show()
plt.savefig('1.jpg')

665
proton/process_sql.py Normal file
View File

@ -0,0 +1,665 @@
################################
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
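# Illustrative example of the structure above (hypothetical ids resolved via schema.idMap):
# get_sql(schema, "SELECT name FROM singer WHERE age > 20") returns roughly
# {
#   'select': (False, [(0, (0, (0, schema.idMap['singer.name'], False), None))]),
#   'from': {'table_units': [('table_unit', schema.idMap['singer'])], 'conds': []},
#   'where': [(False, WHERE_OPS.index('>'), (0, (0, schema.idMap['singer.age'], False), None), 20.0, None)],
#   'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#   'intersect': None, 'union': None, 'except': None
# }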
import json
import sqlite3
from nltk import word_tokenize
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema):
self._schema = schema
self._idMap = self._map(self._schema)
@property
def schema(self): # map lowercased raw tab name to lowercased raw col name list
return self._schema
@property
def idMap(self): # map tab name to __tab__, tab.col to __tab.col__, * to __all__, all lowercased
return self._idMap
def _map(self, schema):
idMap = {'*': "__all__"}
id = 1
for key, vals in schema.items():
for val in vals:
idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
id += 1
for key in schema:
idMap[key.lower()] = "__" + key.lower() + "__"
id += 1
return idMap
def get_schema(db):
"""
Get database's schema, which is a dict with table name as key
and list of column names as value
:param db: database path
:return: schema dict
"""
schema = {}
conn = sqlite3.connect(db)
cursor = conn.cursor()
# fetch table names
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [str(table[0].lower()) for table in cursor.fetchall()]
# fetch table info
for table in tables:
cursor.execute("PRAGMA table_info({})".format(table))
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
return schema
def get_schema_from_json(fpath):
with open(fpath) as f:
data = json.load(f)
schema = {}
for entry in data:
table = str(entry['table'].lower())
cols = [str(col['column_name'].lower()) for col in entry['col_data']]
schema[table] = cols
return schema
def tokenize(string):
string = str(string)
    string = string.replace("\'", "\"")  # ensure all string values are wrapped in double quotes (an apostrophe inside a value would break the quote pairing below)
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
# keep string value as token
vals = {}
for i in range(len(quote_idxs)-1, -1, -2):
qidx1 = quote_idxs[i-1]
qidx2 = quote_idxs[i]
val = string[qidx1: qidx2+1]
key = "__val_{}_{}__".format(qidx1, qidx2)
string = string[:qidx1] + key + string[qidx2+1:]
vals[key] = val
toks = [word.lower() for word in word_tokenize(string)]
# replace with string value token
for i in range(len(toks)):
if toks[i] in vals:
toks[i] = vals[toks[i]]
# find if there exists !=, >=, <=
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
eq_idxs.reverse()
prefix = ('!', '>', '<')
for eq_idx in eq_idxs:
pre_tok = toks[eq_idx-1]
if pre_tok in prefix:
toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
return toks
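# e.g. tokenize('SELECT name FROM singer WHERE age >= 20') gives roughly
#      ['select', 'name', 'from', 'singer', 'where', 'age', '>=', '20']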
def scan_alias(toks):
"""Scan the index of 'as' and build the map for all alias"""
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
alias = {}
for idx in as_idxs:
alias[toks[idx+1]] = toks[idx-1] # overwritten problem
return alias
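# check_contradiction: returns True iff the same alias is bound to two different
# tables somewhere in the token stream (only possible across nested sub-queries),
# in which case normalize_table_alias reassigns globally unique aliases below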
def check_contradiction(toks):
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
alias = {}
for idx in as_idxs:
a = toks[idx+1]
if a in alias and alias[a] != toks[idx-1]:
return True
alias[a] = toks[idx-1]
return False
def toks2nested(toks):
"""
Determine the scope for each sub-sql
mapping [select, count, (, c1, ), from, (, select c1, c2, from, t, ), ... ] into
[select, count, (, c1, ), from, [select, c1, c2, from, t], ... ]
"""
def detect_sql(idx):
count, sql_list = 0, []
while idx < len(toks):
if toks[idx] == '(':
count += 1
if toks[idx + 1] == 'select':
sub_sql_list, idx = detect_sql(idx + 1)
count -= 1
sql_list.append(sub_sql_list)
else:
sql_list.append('(')
elif toks[idx] == ')':
count -= 1
if count < 0:
return sql_list, idx
else:
sql_list.append(')')
else:
sql_list.append(toks[idx])
idx += 1
return sql_list, idx
def intersect_union_except(tok_list):
for idx, tok in enumerate(tok_list):
if type(tok) == list:
new_tok = intersect_union_except(tok)
tok_list[idx] = new_tok
for op in ['intersect', 'union', 'except']:
if op in tok_list:
idx = tok_list.index(op)
tok_list = [tok_list[:idx]] + [op] + [tok_list[idx+1:]]
break
return tok_list
try:
nested_toks, _ = detect_sql(0)
# not all sqls are wrapped with (), e.g. sql1 intersect sql2
# add wrapper for each sql on the left and right handside of intersect/union/except
nested_toks = intersect_union_except(nested_toks)
return nested_toks
except:
print('Something unknown happened when transforming %s' % (' '.join(toks)))
return None
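# reassign_table_alias: walk the nested token lists and rename table aliases to
# globally unique t1, t2, ... (in order of appearance), rewriting the corresponding
# alias.column references inside each sub-query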
def reassign_table_alias(nested_toks, index):
current_map = {} # map old alias in the current sql to new alias in global map
as_idxs = [idx for idx, tok in enumerate(nested_toks) if tok == 'as']
for idx in as_idxs:
index += 1 # add 1 to global index for table alias before assignment
assert nested_toks[idx+1] not in current_map
current_map[nested_toks[idx+1]] = 't' + str(index)
nested_toks[idx+1] = 't' + str(index)
for j, tok in enumerate(nested_toks):
if type(tok) == list:
new_tok, index = reassign_table_alias(tok, index)
nested_toks[j] = new_tok
elif '.' in tok:
for alias in current_map.keys():
if tok.startswith(alias + '.'):
nested_toks[j] = current_map[alias] + '.' + tok[tok.index('.')+1:]
break
return nested_toks, index
def normalize_table_alias(toks):
""" Make sure that different table alias are assigned to different tables """
if toks.count('select') == 1:
return toks # no nested sql, don't worry
elif toks.count('select') > 1:
flag = check_contradiction(toks)
if flag: # avoid unnecessary normalization process
nested_toks = toks2nested(toks)
index = 0 # global index for table alias
nested_toks, _ = reassign_table_alias(nested_toks, index)
def flatten(x):
if type(x) == list and ('intersect' in x or 'union' in x or 'except' in x):
assert len(x) == 3 # sql1 union sql2
return ['('] + flatten(x[0])[1:-1] + [x[1]] + flatten(x[2])[1:-1] + [')']
elif type(x) == list:
return ['('] + [y for l in x for y in flatten(l)] + [')']
else:
return [x]
toks = flatten(nested_toks)[1:-1]
return toks
else:
raise ValueError('Something wrong in sql because no select clause is found!')
def get_tables_with_alias(schema, toks):
toks = normalize_table_alias(toks)
tables = scan_alias(toks)
for key in schema:
assert key not in tables, "Alias {} has the same name in table".format(key)
tables[key] = key
return tables, toks
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, column id
"""
tok = toks[start_idx]
if tok == "*":
return start_idx + 1, schema.idMap[tok]
if '.' in tok: # if token is a composite
alias, col = tok.split('.')
key = tables_with_alias[alias] + "." + col
return start_idx+1, schema.idMap[key]
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
for alias in default_tables:
table = tables_with_alias[alias]
if tok in schema.schema[table]:
key = table + "." + tok
return start_idx+1, schema.idMap[key]
assert False, "Error col: {}".format(tok)
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, (agg_op id, col_id)
"""
idx = start_idx
len_ = len(toks)
isBlock = False
isDistinct = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
assert idx < len_ and toks[idx] == '('
idx += 1
if toks[idx] == "distinct":
idx += 1
isDistinct = True
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
assert idx < len_ and toks[idx] == ')'
idx += 1
return idx, (agg_id, col_id, isDistinct)
if toks[idx] == "distinct":
idx += 1
isDistinct = True
agg_id = AGG_OPS.index("none")
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (agg_id, col_id, isDistinct)
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    :returns next idx, (unit_op id, col_unit1, col_unit2)
    """
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
col_unit1 = None
col_unit2 = None
unit_op = UNIT_OPS.index('none')
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if idx < len_ and toks[idx] in UNIT_OPS:
unit_op = UNIT_OPS.index(toks[idx])
idx += 1
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (unit_op, col_unit1, col_unit2)
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
"""
:returns next idx, table id, table name
"""
idx = start_idx
len_ = len(toks)
key = tables_with_alias[toks[idx]]
if idx + 1 < len_ and toks[idx+1] == "as":
idx += 3
else:
idx += 1
return idx, schema.idMap[key], key
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    :returns next idx, value (a nested sql dict, a quoted string, a float, or a col_unit)
    """
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
elif "\"" in toks[idx]: # token is a string value
val = toks[idx]
idx += 1
else:
try:
val = float(toks[idx])
idx += 1
except:
end_idx = idx
while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
end_idx += 1
idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
idx = end_idx
if isBlock:
assert toks[idx] == ')'
idx += 1
return idx, val
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    :returns next idx, conds; conds is a list of (not_op, op_id, val_unit, val1, val2)
    tuples interleaved with the conjunction tokens 'and'/'or'
    """
idx = start_idx
len_ = len(toks)
conds = []
while idx < len_:
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
not_op = False
if toks[idx] == 'not':
not_op = True
idx += 1
assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
op_id = WHERE_OPS.index(toks[idx])
idx += 1
val1 = val2 = None
if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
assert toks[idx] == 'and'
idx += 1
idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
else: # normal case: single value
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
val2 = None
conds.append((not_op, op_id, val_unit, val1, val2))
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
break
if idx < len_ and toks[idx] in COND_OPS:
conds.append(toks[idx])
idx += 1 # skip and/or
return idx, conds
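# Illustrative sketch of the returned structure (example tokens): for a condition span such as
#   ['t1.age', '>', '20', 'and', 't1.name', '=', '"Joe"']
# parse_condition yields
#   [(False, WHERE_OPS.index('>'), <val_unit for t1.age>, 20.0, None), 'and',
#    (False, WHERE_OPS.index('='), <val_unit for t1.name>, '"Joe"', None)]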
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    :returns next idx, (isDistinct, [(agg_id, val_unit), ...])
    """
idx = start_idx
len_ = len(toks)
assert toks[idx] == 'select', "'select' not found"
idx += 1
isDistinct = False
if idx < len_ and toks[idx] == 'distinct':
idx += 1
isDistinct = True
val_units = []
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
agg_id = AGG_OPS.index("none")
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append((agg_id, val_unit))
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
return idx, (isDistinct, val_units)
def parse_from(toks, start_idx, tables_with_alias, schema):
"""
Assume in the from clause, all table units are combined with join
"""
assert 'from' in toks[start_idx:], "'from' not found"
len_ = len(toks)
idx = toks.index('from', start_idx) + 1
default_tables = []
table_units = []
conds = []
while idx < len_:
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['sql'], sql))
else:
if idx < len_ and toks[idx] == 'join':
idx += 1 # skip join
idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['table_unit'],table_unit))
default_tables.append(table_name)
if idx < len_ and toks[idx] == "on":
idx += 1 # skip on
idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
if len(conds) > 0:
conds.append('and')
conds.extend(this_conds)
if isBlock:
assert toks[idx] == ')'
idx += 1
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
break
return idx, table_units, conds, default_tables
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'where':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
col_units = []
if idx >= len_ or toks[idx] != 'group':
return idx, col_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
col_units.append(col_unit)
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, col_units
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
val_units = []
order_type = 'asc' # default type is 'asc'
if idx >= len_ or toks[idx] != 'order':
return idx, val_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append(val_unit)
if idx < len_ and toks[idx] in ORDER_OPS:
order_type = toks[idx]
idx += 1
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, (order_type, val_units)
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'having':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_limit(toks, start_idx):
idx = start_idx
len_ = len(toks)
if idx < len_ and toks[idx] == 'limit':
idx += 2
return idx, int(toks[idx-1])
return idx, None
def parse_sql(toks, start_idx, tables_with_alias, schema):
isBlock = False # indicate whether this is a block of sql/sub-sql
len_ = len(toks)
idx = start_idx
sql = {}
if toks[idx] == '(':
isBlock = True
idx += 1
# parse from clause in order to get default tables
from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
sql['from'] = {'table_units': table_units, 'conds': conds}
# select clause
_, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
idx = from_end_idx
sql['select'] = select_col_units
# where clause
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
sql['where'] = where_conds
# group by clause
idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
sql['groupBy'] = group_col_units
# having clause
idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
sql['having'] = having_conds
# order by clause
idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
sql['orderBy'] = order_col_units
# limit clause
idx, limit_val = parse_limit(toks, idx)
sql['limit'] = limit_val
idx = skip_semicolon(toks, idx)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
idx = skip_semicolon(toks, idx)
# intersect/union/except clause
for op in SQL_OPS: # initialize IUE
sql[op] = None
if idx < len_ and toks[idx] in SQL_OPS:
sql_op = toks[idx]
idx += 1
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
sql[sql_op] = IUE_sql
return idx, sql
def load_data(fpath):
with open(fpath) as f:
data = json.load(f)
return data
def get_sql(schema, query):
toks = tokenize(query)
tables_with_alias, toks = get_tables_with_alias(schema.schema, toks)
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
return sql
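# Usage sketch (assumes the Schema wrapper, get_schema helper and tokenize()
# defined earlier in this file, as in Spider's original process_sql.py):
#   schema = Schema(get_schema('path/to/concert_singer.sqlite'))  # hypothetical db path
#   sql = get_sql(schema, 'select name from singer where age > 20')
#   # sql['select'] -> (isDistinct, [(agg_id, (unit_op, col_unit1, col_unit2)), ...])
#   # sql['where']  -> [(not_op, op_id, val_unit, val1, val2), 'and'/'or', ...]
#   # sql['limit']  -> int or None; sql['intersect'/'union'/'except'] -> nested sql dict or None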
def skip_semicolon(toks, start_idx):
idx = start_idx
while idx < len(toks) and toks[idx] == ";":
idx += 1
return idx

5
proton/requirements.txt Normal file
View File

@ -0,0 +1,5 @@
dgl-cu101==0.5.3
embeddings==0.0.8
transformers==3.5.1
stanza==1.1.1
nltk==3.5

View File

@ -0,0 +1,825 @@
medium pred: SELECT singer.Name , singer.Country FROM singer ORDER BY singer.Country DESC
medium gold: SELECT name , country FROM singer ORDER BY birthday ASC
medium pred: SELECT singer.Name , singer.Country FROM singer ORDER BY singer.Birthday DESC
medium gold: SELECT name , country FROM singer ORDER BY birthday ASC
medium pred: SELECT singer.Song_Name , singer.Song_release_year FROM singer ORDER BY singer.Birthday ASC LIMIT 1
medium gold: SELECT song_name , song_release_year FROM singer ORDER BY birthday desc LIMIT 1
medium pred: SELECT singer.Country FROM singer WHERE singer.Birthday = "value"
medium gold: SELECT DISTINCT country FROM singer WHERE birthday like '2001%'
medium pred: SELECT singer.Country FROM singer WHERE singer.Birthday = "value"
medium gold: SELECT DISTINCT country FROM singer WHERE birthday like '2001%'
hard pred: SELECT singer.Song_Name FROM singer WHERE singer.Song_Name > ( SELECT AVG(singer.Song_Name) FROM singer )
hard gold: SELECT song_name FROM singer WHERE birthday < (SELECT avg(birthday) FROM singer)
hard pred: SELECT singer.Song_Name FROM singer WHERE singer.Is_male > ( SELECT AVG(singer.Is_male) FROM singer )
hard gold: SELECT song_name FROM singer WHERE birthday < (SELECT avg(birthday) FROM singer)
medium pred: SELECT AVG(stadium.Average) , MAX(stadium.Capacity) FROM stadium
medium gold: SELECT avg(capacity) , max(capacity) FROM stadium
medium pred: SELECT AVG(stadium.Capacity) , MAX(stadium.Highest) FROM stadium
medium gold: SELECT avg(capacity) , max(capacity) FROM stadium
medium pred: SELECT stadium.Stadium_ID , COUNT(*) FROM stadium JOIN concert GROUP BY stadium.Stadium_ID
medium gold: SELECT T2.name , count(*) FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id GROUP BY T1.stadium_id
extra pred: SELECT stadium.Name , stadium.Highest FROM stadium JOIN concert WHERE concert.Year > "value" GROUP BY concert.Stadium_ID ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T2.name , T2.capacity FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.year >= 2014 GROUP BY T2.stadium_id ORDER BY count(*) DESC LIMIT 1
hard pred: SELECT MAX(stadium.Highest) FROM stadium WHERE stadium.Stadium_ID NOT IN ( SELECT concert.Stadium_ID FROM concert )
hard gold: SELECT highest FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
hard pred: SELECT MIN(stadium.Lowest) FROM stadium WHERE stadium.Stadium_ID NOT IN ( SELECT concert.Stadium_ID FROM concert )
hard gold: SELECT lowest FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
extra pred: SELECT singer.Country FROM singer WHERE singer.Birthday = "value" OR singer.Birthday = "value"
extra gold: SELECT country FROM singer WHERE birthday like '1981%' or birthday like '1991%'
hard pred: SELECT MIN(stadium.Lowest) FROM stadium WHERE stadium.Stadium_ID NOT IN ( SELECT concert.Stadium_ID FROM concert WHERE concert.Year = "value" )
hard gold: SELECT lowest FROM stadium EXCEPT SELECT T2.lowest FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.year = 2014
medium pred: SELECT singer.Name , COUNT(*) FROM singer_in_concert JOIN singer GROUP BY singer.Name
medium gold: SELECT T2.name , count(*) FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id = T2.singer_id GROUP BY T2.singer_id
hard pred: SELECT singer.Name FROM singer_in_concert JOIN singer JOIN concert WHERE concert.Year < "value"
hard gold: SELECT T2.name FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id = T2.singer_id JOIN concert AS T3 ON T1.concert_id = T3.concert_id WHERE T3.year <= 2014
extra pred: SELECT stadium.Lowest , stadium.Highest FROM concert JOIN stadium WHERE concert.Year = "value" OR concert.Year = "value"
extra gold: SELECT T2.lowest , T2.highest FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.Year = 2014 INTERSECT SELECT T2.name , T2.location FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.Year = 2015
extra pred: SELECT stadium.Lowest , stadium.Highest FROM concert JOIN stadium WHERE concert.Year = "value" INTERSECT SELECT stadium.Lowest , stadium.Highest FROM concert JOIN stadium WHERE concert.Year = "value"
extra gold: SELECT T2.lowest , T2.highest FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.Year = 2014 INTERSECT SELECT T2.name , T2.location FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.Year = 2015
medium pred: SELECT Pets.weight FROM Pets ORDER BY Pets.birthdate ASC LIMIT 1
medium gold: SELECT weight FROM pets ORDER BY birthdate desc LIMIT 1
medium pred: SELECT Pets.weight FROM Pets ORDER BY Pets.birthdate ASC LIMIT 1
medium gold: SELECT weight FROM pets ORDER BY birthdate desc LIMIT 1
medium pred: SELECT MAX(Pets.weight) , MAX(Pets.weight) , Pets.PetType FROM Pets GROUP BY Pets.PetType
medium gold: SELECT max(weight) , petType FROM pets GROUP BY petType
hard pred: SELECT COUNT(*) FROM Student JOIN Has_Pet JOIN Student WHERE Student.Sex = "value"
hard gold: SELECT count(*) FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T2.petid = T3.petid WHERE T1.sex = 'F' AND T3.pettype = 'dog'
hard pred: SELECT COUNT(*) FROM Has_Pet JOIN Student JOIN Has_Pet WHERE Student.Sex = "value"
hard gold: SELECT count(*) FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T2.petid = T3.petid WHERE T1.sex = 'F' AND T3.pettype = 'dog'
extra pred: SELECT Student.LName FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value" INTERSECT SELECT Student.LName FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
extra gold: SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog'
extra pred: SELECT Student.Major , Student.Age FROM Student EXCEPT SELECT Student.Major , Student.Age FROM Student JOIN Has_Pet WHERE Has_Pet.StuID = "value"
extra gold: SELECT major , age FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
extra pred: SELECT Student.Major , Student.Age FROM Student EXCEPT SELECT Student.Major , Student.Age FROM Student JOIN Has_Pet WHERE Has_Pet.StuID = "value"
extra gold: SELECT major , age FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
hard pred: SELECT Student.StuID FROM Student EXCEPT SELECT Has_Pet.StuID FROM Has_Pet WHERE Has_Pet.StuID = "value"
hard gold: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat'
hard pred: SELECT Student.StuID FROM Student EXCEPT SELECT Has_Pet.StuID FROM Has_Pet WHERE Has_Pet.PetID = "value"
hard gold: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat'
extra pred: SELECT Student.Fname , Student.Age FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
extra gold: SELECT T1.fname , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog' AND T1.stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
extra pred: SELECT Student.Fname FROM Student WHERE Student.StuID = "value" EXCEPT SELECT Has_Pet.StuID FROM Has_Pet WHERE Has_Pet.PetID = "value"
extra gold: SELECT T1.fname , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog' AND T1.stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
medium pred: SELECT Pets.PetType , Pets.weight FROM Pets ORDER BY Pets.birthdate ASC LIMIT 1
medium gold: SELECT pettype , weight FROM pets ORDER BY birthdate desc LIMIT 1
medium pred: SELECT Pets.PetType , Pets.weight FROM Pets ORDER BY Pets.birthdate ASC LIMIT 1
medium gold: SELECT pettype , weight FROM pets ORDER BY birthdate desc LIMIT 1
medium pred: SELECT Pets.PetID , Pets.weight FROM Pets WHERE Pets.birthdate > "value"
medium gold: SELECT petid , weight FROM pets WHERE birthdate < '2020-01-01'
medium pred: SELECT Pets.PetID , Pets.weight FROM Pets WHERE Pets.birthdate > "value"
medium gold: SELECT petid , weight FROM pets WHERE birthdate < '2020-05-01'
medium pred: SELECT Student.Fname , Student.LName FROM Student JOIN Has_Pet
medium gold: SELECT DISTINCT T1.fname , T1.LName , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid
medium pred: SELECT Student.Fname , Student.LName FROM Student JOIN Has_Pet
medium gold: SELECT DISTINCT T1.fname , T1.LName T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid
medium pred: SELECT COUNT(*) , Has_Pet.StuID FROM Has_Pet GROUP BY Has_Pet.StuID
medium gold: SELECT count(*) , T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid GROUP BY T1.stuid
medium pred: SELECT Has_Pet.StuID , COUNT(*) FROM Has_Pet GROUP BY Has_Pet.StuID
medium gold: SELECT count(*) , T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid GROUP BY T1.stuid
extra pred: SELECT Student.LName FROM Pets JOIN Has_Pet JOIN Student WHERE Pets.birthdate = "value"
extra gold: SELECT T1.lname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.birthdate like '2001%' AND T3.pettype = 'cat'
extra pred: SELECT Student.LName FROM Pets JOIN Has_Pet JOIN Student WHERE Pets.birthdate = "value"
extra gold: SELECT T1.lname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.birthdate like '2001%' AND T3.pettype = 'cat'
extra pred: SELECT AVG(Student.Age) FROM Student WHERE Student.StuID NOT IN ( SELECT Has_Pet.StuID FROM Has_Pet )
extra gold: SELECT avg(age) FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid)
extra pred: SELECT AVG(Student.Age) FROM Student WHERE Student.StuID NOT IN ( SELECT Has_Pet.StuID FROM Has_Pet )
extra gold: SELECT avg(age) FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid)
extra pred: SELECT car_makers.Maker FROM cars_data JOIN car_makers WHERE cars_data.Year = "value"
extra gold: SELECT DISTINCT T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model JOIN CARS_DATA AS T4 ON T3.MakeId = T4.id WHERE T4.year >= '2019';
extra pred: SELECT car_makers.FullName FROM cars_data JOIN car_makers WHERE cars_data.Year = "value"
extra gold: SELECT DISTINCT T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model JOIN CARS_DATA AS T4 ON T3.MakeId = T4.id WHERE T4.year >= '2019';
extra pred: SELECT car_names.Make , cars_data.Year FROM cars_data JOIN car_names ORDER BY cars_data.Year ASC LIMIT 1
extra gold: SELECT T2.Make , T1.Year FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Year = (SELECT min(YEAR) FROM CARS_DATA);
extra pred: SELECT car_makers.Maker , cars_data.Year FROM cars_data JOIN car_names JOIN model_list JOIN car_makers WHERE cars_data.Year = "value" ORDER BY cars_data.Year ASC LIMIT 1
extra gold: SELECT T2.Make , T1.Year FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Year = (SELECT min(YEAR) FROM CARS_DATA);
hard pred: SELECT car_names.Model FROM cars_data JOIN car_names WHERE cars_data.Year <= "value"
hard gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.model = T2.model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.id WHERE T3.year <= 1980;
medium pred: SELECT COUNT(*) , car_makers.FullName FROM car_makers JOIN model_list GROUP BY car_makers.Maker
medium gold: SELECT Count(*) , T2.FullName , T2.id FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker = T2.Id GROUP BY T2.id;
medium pred: SELECT cars_data.Accelerate FROM cars_data JOIN car_names WHERE car_names.Model = "value" AND car_names.Make = "value"
medium gold: SELECT T1.Accelerate FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Make = 'amc hornet sportabout (sw)';
medium pred: SELECT cars_data.Accelerate FROM cars_data JOIN car_names WHERE car_names.Model = "value" AND car_names.Make = "value"
medium gold: SELECT T1.Accelerate FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Make = 'amc hornet sportabout (sw)';
medium pred: SELECT COUNT(*) FROM car_makers WHERE car_makers.Country = "value"
medium gold: SELECT count(*) FROM CAR_MAKERS AS T1 JOIN COUNTRIES AS T2 ON T1.Country = T2.CountryId WHERE T2.CountryName = 'Japan';
medium pred: SELECT COUNT(*) FROM car_makers WHERE car_makers.Country = "value"
medium gold: SELECT count(*) FROM CAR_MAKERS AS T1 JOIN COUNTRIES AS T2 ON T1.Country = T2.CountryId WHERE T2.CountryName = 'Japan';
hard pred: SELECT COUNT(*) FROM model_list JOIN car_makers WHERE car_makers.Country = "value"
hard gold: SELECT count(*) FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker = T2.Id JOIN COUNTRIES AS T3 ON T2.Country = T3.CountryId WHERE T3.CountryName = 'usa';
hard pred: SELECT COUNT(*) FROM model_list JOIN car_makers WHERE car_makers.Country = "value"
hard gold: SELECT count(*) FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker = T2.Id JOIN COUNTRIES AS T3 ON T2.Country = T3.CountryId WHERE T3.CountryName = 'usa';
hard pred: SELECT MIN(cars_data.Weight) FROM cars_data WHERE cars_data.Year = "value" AND cars_data.Cylinders = "value"
hard gold: SELECT Weight FROM CARS_DATA WHERE Cylinders = 4 AND YEAR = 1974 ORDER BY Weight ASC LIMIT 1;
hard pred: SELECT MIN(cars_data.Weight) FROM cars_data WHERE cars_data.Year = "value" AND cars_data.Cylinders = "value"
hard gold: SELECT Weight FROM CARS_DATA WHERE Cylinders = 4 AND YEAR = 1974 ORDER BY Weight ASC LIMIT 1;
medium pred: SELECT car_makers.Maker , model_list.Model FROM car_makers JOIN model_list
medium gold: SELECT Maker , Model FROM MODEL_LIST;
medium pred: SELECT countries.CountryName , car_makers.Id FROM car_makers JOIN countries
medium gold: SELECT T1.CountryName , T1.CountryId FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) >= 1;
medium pred: SELECT countries.CountryName , countries.CountryId FROM car_makers JOIN countries
medium gold: SELECT T1.CountryName , T1.CountryId FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) >= 1;
medium pred: SELECT AVG(cars_data.Weight) , MAX(cars_data.Weight) , cars_data.Year FROM cars_data GROUP BY cars_data.Year
medium gold: SELECT avg(Weight) , YEAR FROM CARS_DATA GROUP BY YEAR;
extra pred: SELECT countries.CountryName FROM countries JOIN car_makers WHERE countries.CountryName = "value" GROUP BY countries.CountryName HAVING COUNT(*) >= "value"
extra gold: SELECT T1.CountryName FROM COUNTRIES AS T1 JOIN CONTINENTS AS T2 ON T1.Continent = T2.ContId JOIN CAR_MAKERS AS T3 ON T1.CountryId = T3.Country WHERE T2.Continent = 'europe' GROUP BY T1.CountryName HAVING count(*) >= 3;
extra pred: SELECT countries.CountryName FROM car_makers JOIN countries WHERE countries.Continent = "value" GROUP BY car_makers.Country HAVING COUNT(*) >= "value"
extra gold: SELECT T1.CountryName FROM COUNTRIES AS T1 JOIN CONTINENTS AS T2 ON T1.Continent = T2.ContId JOIN CAR_MAKERS AS T3 ON T1.CountryId = T3.Country WHERE T2.Continent = 'europe' GROUP BY T1.CountryName HAVING count(*) >= 3;
extra pred: SELECT MAX(cars_data.Horsepower) , MAX(car_names.Make) FROM cars_data JOIN car_names WHERE cars_data.Cylinders = "value"
extra gold: SELECT T2.horsepower , T1.Make FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T2.cylinders = 3 ORDER BY T2.horsepower DESC LIMIT 1;
extra pred: SELECT MAX(car_names.Make) , MAX(cars_data.Horsepower) FROM cars_data JOIN car_names WHERE cars_data.Cylinders = "value"
extra gold: SELECT T2.horsepower , T1.Make FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T2.cylinders = 3 ORDER BY T2.horsepower DESC LIMIT 1;
hard pred: SELECT cars_data.MPG FROM cars_data ORDER BY cars_data.MPG DESC LIMIT 1
hard gold: SELECT T1.Model FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id ORDER BY T2.mpg DESC LIMIT 1;
easy pred: SELECT COUNT(*) FROM cars_data ORDER BY cars_data.Year DESC LIMIT 1
easy gold: SELECT count(*) FROM CARS_DATA WHERE YEAR >= 2019;
easy pred: SELECT COUNT(*) FROM cars_data ORDER BY cars_data.Year DESC LIMIT 1
easy gold: SELECT count(*) FROM CARS_DATA WHERE YEAR >= 2019;
medium pred: SELECT car_makers.FullName , car_makers.Id , COUNT(*) FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) > "value"
medium gold: SELECT T1.FullName , T1.Id FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) > 3;
extra pred: SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_makers.FullName = "value" OR cars_data.Weight > "value"
extra gold: SELECT DISTINCT T2.Model FROM CAR_NAMES AS T1 JOIN MODEL_LIST AS T2 ON T1.Model = T2.Model JOIN CAR_MAKERS AS T3 ON T2.Maker = T3.Id JOIN CARS_DATA AS T4 ON T1.MakeId = T4.Id WHERE T3.FullName = 'General Motors' OR T4.weight > 3500;
extra pred: SELECT model_list.Model FROM model_list JOIN car_makers JOIN cars_data WHERE car_makers.FullName = "value" OR cars_data.Weight > "value"
extra gold: SELECT DISTINCT T2.Model FROM CAR_NAMES AS T1 JOIN MODEL_LIST AS T2 ON T1.Model = T2.Model JOIN CAR_MAKERS AS T3 ON T2.Maker = T3.Id JOIN CARS_DATA AS T4 ON T1.MakeId = T4.Id WHERE T3.FullName = 'General Motors' OR T4.weight > 3500;
medium pred: SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight BETWEEN "value" AND "value"
medium gold: SELECT DISTINCT T1.Year FROM CARS_DATA AS T1 WHERE T1.Weight > 3000 AND T1.weight < 4000;
medium pred: SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight < "value" INTERSECT SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight > "value"
medium gold: SELECT DISTINCT T1.Year FROM CARS_DATA AS T1 WHERE T1.Weight > 3000 AND T1.weight < 4000;
extra pred: SELECT cars_data.Cylinders FROM cars_data WHERE cars_data.Accelerate = "value" ORDER BY cars_data.Cylinders ASC LIMIT 1
extra gold: SELECT T1.cylinders FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Model = 'tesla' ORDER BY T1.accelerate ASC LIMIT 1;
hard pred: SELECT COUNT(*) FROM cars_data WHERE cars_data.Accelerate > ( SELECT MAX(cars_data.Accelerate) FROM cars_data ) ORDER BY cars_data.Horsepower DESC LIMIT 1
hard gold: SELECT COUNT(*) FROM CARS_DATA WHERE Accelerate > ( SELECT Accelerate FROM CARS_DATA ORDER BY Horsepower DESC LIMIT 1 );
hard pred: SELECT COUNT(*) FROM cars_data WHERE cars_data.Accelerate > ( SELECT MAX(cars_data.Accelerate) FROM cars_data ) ORDER BY cars_data.Horsepower DESC LIMIT 1
hard gold: SELECT COUNT(*) FROM CARS_DATA WHERE Accelerate > ( SELECT Accelerate FROM CARS_DATA ORDER BY Horsepower DESC LIMIT 1 );
easy pred: SELECT COUNT(*) FROM ( SELECT car_makers.Country FROM car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" )
easy gold: SELECT COUNT(*) FROM ( SELECT T1.CountryId , COUNT(*) FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) > 2 );
easy pred: SELECT COUNT(*) FROM ( SELECT car_makers.Country FROM car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" )
easy gold: SELECT COUNT(*) FROM ( SELECT T1.CountryId , COUNT(*) FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) > 2 );
extra pred: SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders > "value" EXCEPT SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders > "value"
extra gold: SELECT T2.MakeId , T2.Make FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Horsepower > (SELECT min(Horsepower) FROM CARS_DATA) AND T1.Cylinders <= 3;
extra pred: SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders < "value"
extra gold: SELECT T2.MakeId , T2.Make FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Horsepower > (SELECT min(Horsepower) FROM CARS_DATA) AND T1.Cylinders <= 3;
extra pred: SELECT MAX(cars_data.MPG) FROM cars_data WHERE cars_data.Cylinders = "value" OR cars_data.Year < "value"
extra gold: SELECT mpg FROM CARS_DATA WHERE Cylinders = 8 OR YEAR <= 1980 ORDER BY mpg DESC LIMIT 1;
extra pred: SELECT MAX(cars_data.MPG) FROM cars_data WHERE cars_data.Cylinders = "value" OR cars_data.Year < "value"
extra gold: SELECT mpg FROM CARS_DATA WHERE Cylinders = 8 OR YEAR <= 1980 ORDER BY mpg DESC LIMIT 1;
extra pred: SELECT model_list.Model FROM cars_data JOIN car_names JOIN car_makers WHERE cars_data.Weight < "value" EXCEPT SELECT model_list.Model FROM cars_data JOIN car_names JOIN car_makers WHERE car_makers.FullName = "value"
extra gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.Model = T2.Model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.Id JOIN CAR_MAKERS AS T4 ON T1.Maker = T4.Id WHERE T3.weight < 3500 AND T4.FullName != 'Ford Motor Company';
extra pred: SELECT model_list.Model FROM model_list JOIN car_names WHERE cars_data.Weight < "value" EXCEPT SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_names.Make = "value"
extra gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.Model = T2.Model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.Id JOIN CAR_MAKERS AS T4 ON T1.Maker = T4.Id WHERE T3.weight < 3500 AND T4.FullName != 'Ford Motor Company';
hard pred: SELECT countries.CountryName FROM countries WHERE countries.CountryId NOT IN ( SELECT car_makers.Country FROM car_makers )
hard gold: SELECT CountryName FROM countries EXCEPT SELECT T1.CountryName FROM countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.countryId = T2.Country;
hard pred: SELECT countries.CountryName FROM countries WHERE countries.CountryId NOT IN ( SELECT car_makers.Country FROM car_makers )
hard gold: SELECT CountryName FROM countries EXCEPT SELECT T1.CountryName FROM countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.countryId = T2.Country;
extra pred: SELECT car_makers.Id , car_makers.Maker FROM car_makers JOIN model_list GROUP BY car_makers.Id HAVING COUNT(*) >= "value" INTERSECT SELECT car_makers.Id , car_makers.Maker FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) > "value"
extra gold: SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) >= 2 INTERSECT SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model GROUP BY T1.Id HAVING count(*) > 3;
extra pred: SELECT car_makers.Id , car_makers.Maker FROM model_list JOIN car_makers GROUP BY car_makers.Id HAVING COUNT(*) >= "value" INTERSECT SELECT car_makers.Id , car_makers.Maker FROM model_list JOIN car_makers GROUP BY model_list.Maker HAVING COUNT(*) > "value"
extra gold: SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) >= 2 INTERSECT SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model GROUP BY T1.Id HAVING count(*) > 3;
extra pred: SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN countries GROUP BY car_makers.Country HAVING COUNT(*) > "value" UNION SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN model_list JOIN countries JOIN countries JOIN model_list WHERE model_list.Model = "value"
extra gold: SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.countryId HAVING count(*) > 3 UNION SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country JOIN MODEL_LIST AS T3 ON T2.Id = T3.Maker WHERE T3.Model = 'tesla';
extra pred: SELECT countries.CountryId , countries.CountryName FROM countries JOIN car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" UNION SELECT countries.CountryId , countries.CountryName FROM countries JOIN car_makers WHERE car_makers.Country = "value"
extra gold: SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.countryId HAVING count(*) > 3 UNION SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country JOIN MODEL_LIST AS T3 ON T2.Id = T3.Maker WHERE T3.Model = 'tesla';
medium pred: SELECT COUNT(*) FROM flights WHERE flights.SourceAirport = "value"
medium gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.SourceAirport = T2.AirportCode WHERE T2.City = "Jackson"
hard pred: SELECT COUNT(*) FROM airports JOIN flights WHERE airports.City = "value" AND airports.AirportName = "value"
hard gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode JOIN AIRPORTS AS T3 ON T1.SourceAirport = T3.AirportCode WHERE T2.City = "Ashley" AND T3.City = "Syracuse"
hard pred: SELECT COUNT(*) FROM airports JOIN flights WHERE airports.City = "value" AND airports.AirportName = "value"
hard gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode JOIN AIRPORTS AS T3 ON T1.SourceAirport = T3.AirportCode WHERE T2.City = "Ashley" AND T3.City = "Syracuse"
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE flights.DestAirport = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "JetBlue Airways" AND T2.DestAirport = "ASY"
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airports.AirportName = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "JetBlue Airways" AND T2.DestAirport = "ASY"
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE flights.SourceAirport = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "JetBlue Airways" AND T2.SourceAirport = "AHD"
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airports.AirportName = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "JetBlue Airways" AND T2.SourceAirport = "AHD"
extra pred: SELECT airports.City FROM airports JOIN flights GROUP BY flights.SourceAirport ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.City FROM AIRPORTS AS T1 JOIN FLIGHTS AS T2 ON T1.AirportCode = T2.SourceAirport GROUP BY T1.City ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT airlines.Abbreviation , airlines.Country FROM airlines JOIN flights GROUP BY flights.Airline ORDER BY COUNT(*) ASC LIMIT 1
extra gold: SELECT T1.Abbreviation , T1.Country FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline ORDER BY count(*) LIMIT 1
extra pred: SELECT airlines.Abbreviation , airports.Country FROM airlines JOIN airports JOIN flights JOIN flights GROUP BY airlines.Abbreviation ORDER BY COUNT(*) ASC LIMIT 1
extra gold: SELECT T1.Abbreviation , T1.Country FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline ORDER BY count(*) LIMIT 1
medium pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airlines WHERE flights.SourceAirport = "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "AHD"
medium pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airlines WHERE flights.DestAirport = "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.DestAirport = "AHD"
extra pred: SELECT airlines.Airline FROM airports JOIN flights WHERE flights.SourceAirport = "value" EXCEPT SELECT airlines.Airline FROM airlines WHERE flights.SourceAirport = "value"
extra gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "CVO" EXCEPT SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "APG"
extra pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airports JOIN airlines JOIN flights WHERE flights.SourceAirport = "value" EXCEPT SELECT airlines.Airline FROM airlines WHERE airlines.Airline = "value"
extra gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "CVO" EXCEPT SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "APG"
medium pred: SELECT airlines.Airline FROM airlines JOIN flights GROUP BY airlines.Airline HAVING COUNT(*) >= "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline HAVING count(*) > 10
medium pred: SELECT airlines.Airline FROM flights JOIN airlines GROUP BY airlines.Airline HAVING COUNT(*) >= "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline HAVING count(*) > 10
easy pred: SELECT flights.FlightNo FROM airports JOIN flights WHERE airports.AirportName = "value"
easy gold: SELECT FlightNo FROM FLIGHTS WHERE DestAirport = "APG"
medium pred: SELECT flights.FlightNo FROM airports JOIN flights WHERE airports.AirportName = "value"
medium gold: SELECT T1.FlightNo FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.SourceAirport = T2.AirportCode WHERE T2.City = "Jackson"
medium pred: SELECT flights.FlightNo FROM airports JOIN flights WHERE airports.AirportName = "value"
medium gold: SELECT T1.FlightNo FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode WHERE T2.City = "Jackson"
hard pred: SELECT airports.AirportName FROM airports WHERE airports.AirportCode NOT IN ( SELECT flights.SourceAirport FROM flights )
hard gold: SELECT AirportName FROM Airports WHERE AirportCode NOT IN (SELECT SourceAirport FROM Flights UNION SELECT DestAirport FROM Flights)
hard pred: SELECT airports.AirportName FROM airports EXCEPT SELECT airports.AirportName FROM airports JOIN flights
hard gold: SELECT AirportName FROM Airports WHERE AirportCode NOT IN (SELECT SourceAirport FROM Flights UNION SELECT DestAirport FROM Flights)
medium pred: SELECT Documents.Document_ID , Documents.Document_Name FROM Documents
medium gold: SELECT document_id , document_name , document_description FROM Documents
medium pred: SELECT Documents.Document_Name , Documents.Template_ID FROM Documents WHERE Documents.Document_Description LIKE "value"
medium gold: SELECT document_name , Document_ID, template_id FROM Documents WHERE Document_Description LIKE "%w%"
medium pred: SELECT Documents.Document_Description , Documents.Template_ID FROM Documents WHERE Documents.Document_Name = "value"
medium gold: SELECT document_id , template_id , Document_Description FROM Documents WHERE document_name = "Robbin CV"
medium pred: SELECT Documents.Document_Description , Documents.Template_ID , Documents.Document_Name FROM Documents WHERE Documents.Document_Name = "value"
medium gold: SELECT document_id , template_id , Document_Description FROM Documents WHERE document_name = "Robbin CV"
extra pred: SELECT Templates.Date_Effective_To FROM Templates JOIN Documents GROUP BY Documents.Template_ID ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T2.Date_Effective_From , T2.Date_Effective_To FROM Documents AS T1 JOIN Templates AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_id ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Templates.Date_Effective_To FROM Templates JOIN Documents GROUP BY Documents.Template_ID ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T2.Date_Effective_From , T2.Date_Effective_To FROM Documents AS T1 JOIN Templates AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_id ORDER BY count(*) DESC LIMIT 1
medium pred: SELECT Templates.Date_Effective_From , Templates.Template_Type_Code FROM Templates
medium gold: SELECT Date_Effective_From , Date_Effective_To , template_type_code FROM Templates
medium pred: SELECT Templates.Date_Effective_To , Templates.Template_Type_Code FROM Templates
medium gold: SELECT Date_Effective_From , Date_Effective_To , template_type_code FROM Templates
medium pred: SELECT Templates.Date_Effective_To FROM Templates
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates
medium pred: SELECT Templates.Date_Effective_From FROM Templates
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates
extra pred: SELECT Templates.Date_Effective_To FROM Templates WHERE Templates.Template_Type_Code = "value" OR Templates.Template_Type_Code = "value"
extra gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates WHERE template_type_code = "PP" OR template_type_code = "PPT"
extra pred: SELECT Templates.Date_Effective_To FROM Templates WHERE Templates.Template_Type_Code = "value" OR Templates.Template_Type_Code = "value"
extra gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates WHERE template_type_code = "PP" OR template_type_code = "PPT"
medium pred: SELECT Templates.Date_Effective_To FROM Templates WHERE Templates.Template_Type_Code = "value"
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates WHERE template_type_code = "CV"
medium pred: SELECT Templates.Date_Effective_To FROM Templates WHERE Templates.Template_Type_Code = "value"
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates WHERE template_type_code = "CV"
medium pred: SELECT Templates.Date_Effective_To , Templates.Template_Type_Code FROM Templates WHERE Templates.Version_Number > "value"
medium gold: SELECT Date_Effective_From , Date_Effective_To , template_type_code FROM Templates WHERE version_number > 5
medium pred: SELECT Templates.Date_Effective_From , Templates.Template_Type_Code FROM Templates WHERE Templates.Version_Number > "value"
medium gold: SELECT Date_Effective_From , Date_Effective_To , template_type_code FROM Templates WHERE version_number > 5
medium pred: SELECT Templates.Date_Effective_To , COUNT(*) FROM Templates GROUP BY Templates.Date_Effective_To
medium gold: SELECT Date_Effective_From , Date_Effective_To , count(*) FROM Templates GROUP BY template_type_code
medium pred: SELECT Templates.Date_Effective_To , COUNT(*) FROM Templates GROUP BY Templates.Date_Effective_To
medium gold: SELECT Date_Effective_From , Date_Effective_To , count(*) FROM Templates GROUP BY template_type_code
hard pred: SELECT Templates.Date_Effective_From FROM Templates GROUP BY Templates.Date_Effective_To ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates GROUP BY template_type_code ORDER BY count(*) DESC LIMIT 1
hard pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates GROUP BY template_type_code ORDER BY count(*) DESC LIMIT 1
medium pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To HAVING COUNT(*) < "value"
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates GROUP BY template_type_code HAVING count(*) < 3
medium pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To HAVING COUNT(*) < "value"
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates GROUP BY template_type_code HAVING count(*) < 3
medium pred: SELECT MIN(Templates.Date_Effective_From) , MIN(Templates.Date_Effective_To) , Templates.Version_Number FROM Templates
medium gold: SELECT min(Version_Number) , Date_Effective_From , Date_Effective_To FROM Templates
medium pred: SELECT MIN(Templates.Date_Effective_From) , AVG(Templates.Date_Effective_To) , Templates.Version_Number FROM Templates
medium gold: SELECT min(Version_Number) , Date_Effective_From , Date_Effective_To FROM Templates
medium pred: SELECT Templates.Date_Effective_To FROM Documents JOIN Templates WHERE Documents.Document_Name = "value"
medium gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id WHERE T2.document_name = "Data base"
medium pred: SELECT Templates.Date_Effective_To FROM Templates JOIN Documents WHERE Documents.Document_Name = "value"
medium gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id WHERE T2.document_name = "Data base"
medium pred: SELECT Documents.Document_Name , Documents.Template_ID , Templates.Template_Type_Code FROM Documents JOIN Templates WHERE Templates.Template_Type_Code = "value"
medium gold: SELECT T2.document_name, T2.Document_ID, T2.template_id FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id WHERE T1.template_type_code = "BK"
medium pred: SELECT Documents.Document_Name , Documents.Template_ID FROM Documents JOIN Templates WHERE Templates.Template_Type_Code = "value"
medium gold: SELECT T2.document_name, T2.Document_ID, T2.template_id FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id WHERE T1.template_type_code = "BK"
medium pred: SELECT Templates.Date_Effective_From , Templates.Date_Effective_To , COUNT(*) FROM Templates GROUP BY Templates.Template_Type_Code
medium gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To , count(*) FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code
medium pred: SELECT Templates.Date_Effective_From , Templates.Date_Effective_To , COUNT(*) FROM Templates GROUP BY Templates.Template_Type_Code
medium gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To , count(*) FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code
extra pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code ORDER BY count(*) DESC LIMIT 1
medium pred: SELECT Templates.Date_Effective_From FROM Templates JOIN Ref_Template_Types WHERE Ref_Template_Types.Template_Type_Description = "value"
medium gold: SELECT T2.Date_Effective_From , T2.Date_Effective_To FROM Ref_template_types AS T1 JOIN Templates AS T2 ON T1.template_type_code = T2.template_type_code WHERE T1.template_type_description = "Presentation"
medium pred: SELECT Templates.Date_Effective_To FROM Templates JOIN Ref_Template_Types WHERE Ref_Template_Types.Template_Type_Description = "value"
medium gold: SELECT T2.Date_Effective_From , T2.Date_Effective_To FROM Ref_template_types AS T1 JOIN Templates AS T2 ON T1.template_type_code = T2.template_type_code WHERE T1.template_type_description = "Presentation"
medium pred: SELECT players.first_name FROM players JOIN matches
medium gold: SELECT loser_name, winner_name FROM matches
medium pred: SELECT players.first_name FROM players JOIN matches
medium gold: SELECT loser_name, winner_name FROM matches
medium pred: SELECT matches.loser_age FROM matches
medium gold: SELECT winner_age, loser_age FROM matches
medium pred: SELECT matches.loser_age FROM matches
medium gold: SELECT winner_age, loser_age FROM matches
medium pred: SELECT AVG(matches.winner_rank) FROM matches
medium gold: SELECT avg(winner_rank), avg(loser_rank) FROM matches
medium pred: SELECT MAX(matches.winner_rank) FROM matches
medium gold: SELECT min(winner_rank), min(loser_rank) FROM matches
medium pred: SELECT MAX(matches.winner_rank_points) FROM matches
medium gold: SELECT min(winner_rank), min(loser_rank) FROM matches
easy pred: SELECT matches.tourney_name FROM matches GROUP BY matches.tourney_id HAVING COUNT(*) > "value"
easy gold: SELECT tourney_name FROM matches GROUP BY tourney_name HAVING count(*) > 10
extra pred: SELECT matches.winner_name , players.last_name , matches.loser_name FROM players JOIN matches WHERE matches.year = "value" INTERSECT SELECT matches.winner_name , players.last_name , matches.loser_name FROM players JOIN matches WHERE matches.year = "value"
extra gold: SELECT loser_name, winner_name FROM matches WHERE YEAR = 2013 INTERSECT SELECT loser_name, winner_name FROM matches WHERE YEAR = 2016
extra pred: SELECT players.first_name , players.last_name , matches.loser_name FROM players JOIN matches WHERE matches.year = "value" INTERSECT SELECT players.first_name , players.last_name , matches.loser_name FROM players JOIN matches WHERE matches.year = "value"
extra gold: SELECT loser_name, winner_name FROM matches WHERE YEAR = 2013 INTERSECT SELECT loser_name, winner_name FROM matches WHERE YEAR = 2016
extra pred: SELECT players.first_name FROM players JOIN matches WHERE matches.year = "value" OR matches.year = "value"
extra gold: SELECT loser_name, winner_name FROM matches WHERE YEAR = 2013 OR YEAR = 2016
medium pred: SELECT players.first_name , players.country_code FROM players ORDER BY players.birth_date DESC LIMIT 1
medium gold: SELECT first_name , country_code FROM players ORDER BY birth_date LIMIT 1
medium pred: SELECT players.first_name , players.country_code FROM players ORDER BY players.birth_date DESC LIMIT 1
medium gold: SELECT first_name , country_code FROM players ORDER BY birth_date LIMIT 1
medium pred: SELECT players.first_name , players.last_name FROM players ORDER BY players.last_name DESC
medium gold: SELECT first_name , last_name FROM players ORDER BY birth_date
medium pred: SELECT players.first_name , players.last_name FROM players WHERE players.hand = "value" ORDER BY players.birth_date DESC
medium gold: SELECT first_name , last_name FROM players WHERE hand = 'L' ORDER BY birth_date
medium pred: SELECT players.first_name FROM players WHERE players.hand = "value" ORDER BY players.birth_date ASC
medium gold: SELECT first_name , last_name FROM players WHERE hand = 'L' ORDER BY birth_date desc
hard pred: SELECT players.first_name , players.country_code FROM players JOIN rankings GROUP BY rankings.player_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT T1.country_code , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id ORDER BY T2.tours DESC LIMIT 1
hard pred: SELECT players.first_name , players.country_code FROM players JOIN rankings GROUP BY rankings.player_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT T1.country_code , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id ORDER BY T2.tours DESC LIMIT 1
hard pred: SELECT matches.winner_name , matches.winner_rank_points FROM players JOIN matches GROUP BY players.first_name ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT winner_name , winner_rank_points FROM matches GROUP BY winner_name ORDER BY count(*) DESC LIMIT 1
hard pred: SELECT players.first_name , players.last_name FROM players JOIN matches GROUP BY players.first_name ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT winner_name , winner_rank_points FROM matches GROUP BY winner_name ORDER BY count(*) DESC LIMIT 1
hard pred: SELECT matches.winner_name , matches.loser_name , matches.winner_rank_points FROM matches WHERE matches.tourney_name = "value" ORDER BY matches.winner_rank_points DESC LIMIT 1
hard gold: SELECT loser_name, winner_name FROM matches WHERE tourney_name = 'Australian Open' ORDER BY winner_rank_points DESC LIMIT 1
hard pred: SELECT matches.winner_name , matches.loser_name , matches.winner_rank_points FROM matches WHERE matches.tourney_name = "value" ORDER BY matches.winner_rank_points DESC LIMIT 1
hard gold: SELECT loser_name, winner_name FROM matches WHERE tourney_name = 'Australian Open' ORDER BY winner_rank_points DESC LIMIT 1
medium pred: SELECT players.first_name , players.last_name FROM players JOIN matches ORDER BY matches.minutes DESC LIMIT 1
medium gold: SELECT winner_name , loser_name FROM matches ORDER BY minutes DESC LIMIT 1
medium pred: SELECT players.first_name , AVG(rankings.ranking_points) FROM players JOIN rankings GROUP BY players.first_name
medium gold: SELECT avg(ranking) , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id GROUP BY T1.first_name
medium pred: SELECT SUM(rankings.ranking_points) , players.last_name FROM players JOIN rankings GROUP BY players.first_name
medium gold: SELECT sum(ranking_points) , T1.first_name, T1.last_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id GROUP BY T1.first_name
medium pred: SELECT players.first_name , SUM(players.last_name) FROM players JOIN rankings GROUP BY players.last_name
medium gold: SELECT sum(ranking_points) , T1.first_name, T1.last_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id GROUP BY T1.first_name
medium pred: SELECT COUNT(*) , rankings.ranking_date FROM rankings GROUP BY rankings.ranking_date
medium gold: SELECT sum(tours) , ranking_date FROM rankings GROUP BY ranking_date
medium pred: SELECT COUNT(matches.winner_entry) FROM matches WHERE matches.tourney_name = "value" AND matches.winner_hand = "value"
medium gold: SELECT count(DISTINCT winner_name) FROM matches WHERE tourney_name = 'WTA Championships' AND winner_hand = 'L'
medium pred: SELECT COUNT(*) FROM matches WHERE matches.tourney_name = "value" AND matches.winner_hand = "value"
medium gold: SELECT count(DISTINCT winner_name) FROM matches WHERE tourney_name = 'WTA Championships' AND winner_hand = 'L'
hard pred: SELECT matches.winner_name , players.birth_date FROM players JOIN matches ORDER BY matches.winner_rank_points DESC LIMIT 1
hard gold: SELECT T1.first_name , T1.last_name , T1.birth_date FROM players AS T1 JOIN matches AS T2 ON T1.player_id = T2.winner_id ORDER BY T2.winner_rank_points DESC LIMIT 1
hard pred: SELECT players.first_name , players.birth_date FROM players JOIN matches ORDER BY matches.winner_rank_points DESC LIMIT 1
hard gold: SELECT T1.first_name , T1.last_name , T1.birth_date FROM players AS T1 JOIN matches AS T2 ON T1.player_id = T2.winner_id ORDER BY T2.winner_rank_points DESC LIMIT 1
medium pred: SELECT Addresses.line_1 FROM Addresses
medium gold: SELECT line_1 , line_2, line_3 FROM addresses
medium pred: SELECT Addresses.line_3 FROM Addresses
medium gold: SELECT line_1 , line_2, line_3 FROM addresses
easy pred: SELECT COUNT(*) FROM show
easy gold: SELECT count(*) FROM show where If_first_show = 'T'
easy pred: SELECT COUNT(*) FROM show
easy gold: SELECT count(*) FROM show where If_first_show = 'T'
easy pred: SELECT conductor.Name FROM conductor ORDER BY conductor.birthday ASC
easy gold: SELECT Name FROM conductor ORDER BY birthday desc
easy pred: SELECT conductor.Name FROM conductor ORDER BY conductor.birthday ASC
easy gold: SELECT Name FROM conductor ORDER BY birthday desc
easy pred: SELECT AVG(show.Attendance) FROM show WHERE show.If_first_show NOT IN ( SELECT show.If_first_show FROM show )
easy gold: SELECT avg(Attendance) FROM SHOW where If_first_show = 'F'
easy pred: SELECT AVG(show.Attendance) FROM show WHERE show.If_first_show != "value"
easy gold: SELECT avg(Attendance) FROM SHOW where If_first_show = 'F'
easy pred: SELECT conductor.Name FROM conductor ORDER BY conductor.birthday ASC
easy gold: SELECT Name FROM conductor ORDER BY birthday DESC
easy pred: SELECT conductor.Name FROM conductor ORDER BY conductor.birthday ASC
easy gold: SELECT Name FROM conductor ORDER BY birthday DESC
hard pred: SELECT conductor.Name FROM orchestra JOIN conductor ORDER BY orchestra.Year_of_Founded DESC LIMIT 1
hard gold: SELECT T1.Name FROM conductor AS T1 JOIN orchestra AS T2 ON T1.Conductor_ID = T2.Conductor_ID order by Year_of_Founded asc limit 1
hard pred: SELECT conductor.Name FROM conductor JOIN orchestra ORDER BY orchestra.Orchestra DESC LIMIT 1
hard gold: SELECT T1.Name FROM conductor AS T1 JOIN orchestra AS T2 ON T1.Conductor_ID = T2.Conductor_ID order by Year_of_Founded desc limit 1
easy pred: SELECT orchestra.Major_Record_Format FROM orchestra ORDER BY orchestra.Year_of_Founded ASC
easy gold: SELECT Major_Record_Format FROM orchestra ORDER BY Year_of_Founded desc
easy pred: SELECT orchestra.Major_Record_Format FROM orchestra ORDER BY orchestra.Year_of_Founded ASC
easy gold: SELECT Major_Record_Format FROM orchestra ORDER BY Year_of_Founded desc
medium pred: SELECT orchestra.Record_Company FROM orchestra ORDER BY orchestra.Record_Company DESC LIMIT 1
medium gold: SELECT Record_Company FROM orchestra ORDER BY Year_of_Founded desc LIMIT 1
easy pred: SELECT orchestra.Record_Company FROM orchestra WHERE orchestra.Year_of_Founded >= "value" INTERSECT SELECT orchestra.Record_Company FROM orchestra WHERE orchestra.Year_of_Founded >= "value"
easy gold: SELECT Record_Company FROM orchestra WHERE Year_of_Founded >= 2003
medium pred: SELECT COUNT(*) FROM show WHERE show.Result = "value"
medium gold: SELECT COUNT(*) FROM show WHERE Result = "Glebe Park" and If_first_show = "T"
medium pred: SELECT COUNT(*) FROM show WHERE show.Result = "value"
medium gold: SELECT COUNT(*) FROM show WHERE Result = "Glebe Park" and If_first_show = "T"
hard pred: SELECT performance.Type FROM performance JOIN show WHERE show.If_first_show != "value" GROUP BY show.Performance_ID HAVING COUNT(*) > "value"
hard gold: SELECT Type FROM performance AS T1 JOIN show AS T2 ON T1.Performance_ID = T2.Performance_ID WHERE If_first_show = 'F' GROUP BY T2.Performance_ID HAVING COUNT(*) > 1
hard pred: SELECT performance.Type FROM performance JOIN show WHERE show.If_first_show != "value" GROUP BY show.Performance_ID HAVING COUNT(*) > "value"
hard gold: SELECT Type FROM performance AS T1 JOIN show AS T2 ON T1.Performance_ID = T2.Performance_ID Where If_first_show = 'F' GROUP BY T2.Performance_ID HAVING COUNT(*) > 1
hard pred: SELECT AVG(Dogs.age) FROM Dogs JOIN Treatments
hard gold: SELECT avg(age) FROM Dogs WHERE dog_id IN ( SELECT dog_id FROM Treatments )
hard pred: SELECT AVG(Dogs.age) FROM Treatments JOIN Dogs
hard gold: SELECT avg(age) FROM Dogs WHERE dog_id IN ( SELECT dog_id FROM Treatments )
extra pred: SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments WHERE Professionals.state = "value" UNION SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments GROUP BY Treatments.professional_id HAVING COUNT(*) > "value"
extra gold: SELECT professional_id , last_name , cell_number FROM Professionals WHERE state = 'Indiana' UNION SELECT T1.professional_id , T1.last_name , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) > 2
extra pred: SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments WHERE Professionals.state = "value" UNION SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments GROUP BY Professionals.state HAVING COUNT(*) > "value"
extra gold: SELECT professional_id , last_name , cell_number FROM Professionals WHERE state = 'Indiana' UNION SELECT T1.professional_id , T1.last_name , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) > 2
hard pred: SELECT Dogs.name FROM Dogs EXCEPT SELECT Dogs.name FROM Dogs JOIN Treatments WHERE Treatments.cost_of_treatment > "value"
hard gold: SELECT name FROM Dogs WHERE dog_id NOT IN( SELECT dog_id FROM Treatments GROUP BY dog_id HAVING sum(cost_of_treatment) > 1000 )
hard pred: SELECT Dogs.name FROM Dogs JOIN Treatments WHERE Treatments.cost_of_treatment > "value"
hard gold: SELECT name FROM Dogs WHERE dog_id NOT IN( SELECT dog_id FROM Treatments GROUP BY dog_id HAVING sum(cost_of_treatment) > 1000 )
hard pred: SELECT Professionals.first_name FROM Professionals JOIN Owners EXCEPT SELECT Owners.first_name FROM Owners
hard gold: SELECT first_name FROM Professionals UNION SELECT first_name FROM Owners EXCEPT SELECT name FROM Dogs
hard pred: SELECT Professionals.first_name FROM Professionals JOIN Owners EXCEPT SELECT Owners.first_name FROM Owners
hard gold: SELECT first_name FROM Professionals UNION SELECT first_name FROM Owners EXCEPT SELECT name FROM Dogs
extra pred: SELECT Professionals.professional_id , Professionals.last_name FROM Professionals WHERE Professionals.professional_id NOT IN ( SELECT Treatments.professional_id FROM Treatments )
extra gold: SELECT professional_id , first_name , last_name FROM Professionals EXCEPT SELECT T1.professional_id , T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id
extra pred: SELECT Professionals.professional_id , Professionals.last_name FROM Professionals WHERE Professionals.professional_id NOT IN ( SELECT Treatments.professional_id FROM Treatments )
extra gold: SELECT professional_id , first_name , last_name FROM Professionals EXCEPT SELECT T1.professional_id , T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id
extra pred: SELECT Owners.owner_id , Owners.last_name FROM Owners JOIN Dogs GROUP BY Owners.owner_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T2.first_name , T2.last_name FROM Dogs AS T1 JOIN Owners AS T2 ON T1.owner_id = T2.owner_id GROUP BY T1.owner_id ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Owners.owner_id , Owners.last_name FROM Owners JOIN Dogs GROUP BY Owners.owner_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T2.first_name , T2.last_name FROM Dogs AS T1 JOIN Owners AS T2 ON T1.owner_id = T2.owner_id GROUP BY T1.owner_id ORDER BY count(*) DESC LIMIT 1
medium pred: SELECT Treatments.professional_id , Professionals.home_phone , COUNT(*) FROM Treatments JOIN Professionals GROUP BY Treatments.professional_id HAVING COUNT(*) >= "value"
medium gold: SELECT T1.professional_id , T1.home_phone , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) >= 2
medium pred: SELECT Treatments.professional_id , Professionals.home_phone , COUNT(*) FROM Treatments JOIN Professionals GROUP BY Treatments.professional_id HAVING COUNT(*) >= "value"
medium gold: SELECT T1.professional_id , T1.home_phone , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) >= 2
extra pred: SELECT Breeds.breed_name FROM Dogs JOIN Breeds GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.breed_name FROM Breeds AS T1 JOIN Dogs AS T2 ON T1.breed_code = T2.breed_code where abandoned_yn = 1 GROUP BY T1.breed_name ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Breeds.breed_name FROM Dogs JOIN Breeds GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.breed_name FROM Breeds AS T1 JOIN Dogs AS T2 ON T1.breed_code = T2.breed_code where abandoned_yn = 1 GROUP BY T1.breed_name ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Owners.owner_id , Owners.zip_code FROM Owners JOIN Dogs JOIN Charges GROUP BY Owners.owner_id ORDER BY SUM(Charges.charge_amount) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T1.zip_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id JOIN Treatments AS T3 ON T2.dog_id = T3.dog_id GROUP BY T1.owner_id ORDER BY sum(T3.cost_of_treatment) DESC LIMIT 1
extra pred: SELECT Owners.owner_id , Owners.zip_code FROM Owners JOIN Dogs GROUP BY Owners.owner_id ORDER BY SUM(Dogs.date_adopted) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T1.zip_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id JOIN Treatments AS T3 ON T2.dog_id = T3.dog_id GROUP BY T1.owner_id ORDER BY sum(T3.cost_of_treatment) DESC LIMIT 1
medium pred: SELECT Treatments.professional_id , Professionals.last_name FROM Professionals JOIN Treatments GROUP BY Treatments.professional_id HAVING COUNT(*) >= "value"
medium gold: SELECT T1.professional_id , T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) >= 2
medium pred: SELECT Treatments.professional_id , Professionals.home_phone , COUNT(*) FROM Professionals JOIN Treatments GROUP BY Professionals.professional_id HAVING COUNT(*) >= "value"
medium gold: SELECT T1.professional_id , T1.home_phone , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) >= 2
extra pred: SELECT Professionals.first_name FROM Professionals JOIN Treatments WHERE Professionals.last_name < ( SELECT AVG(Treatments.cost_of_treatment) FROM Treatments )
extra gold: SELECT DISTINCT T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 WHERE cost_of_treatment < ( SELECT avg(cost_of_treatment) FROM Treatments )
extra pred: SELECT Professionals.first_name FROM Professionals JOIN Treatments WHERE Professionals.last_name < ( SELECT AVG(Treatments.cost_of_treatment) FROM Treatments )
extra gold: SELECT DISTINCT T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 WHERE cost_of_treatment < ( SELECT avg(cost_of_treatment) FROM Treatments )
medium pred: SELECT Treatments.date_of_treatment , Professionals.last_name FROM Treatments JOIN Professionals
medium gold: SELECT T1.date_of_treatment , T2.first_name , T2.last_name FROM Treatments AS T1 JOIN Professionals AS T2 ON T1.professional_id = T2.professional_id
medium pred: SELECT Treatments.date_of_treatment , Professionals.last_name FROM Treatments JOIN Professionals
medium gold: SELECT T1.date_of_treatment , T2.first_name , T2.last_name FROM Treatments AS T1 JOIN Professionals AS T2 ON T1.professional_id = T2.professional_id
medium pred: SELECT Owners.first_name , Owners.last_name FROM Owners JOIN Dogs
medium gold: SELECT T1.first_name , T1.last_name , T2.size_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id
medium pred: SELECT Owners.first_name , Owners.last_name FROM Owners JOIN Dogs
medium gold: SELECT T1.first_name , T1.last_name , T2.size_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id
medium pred: SELECT Owners.first_name , Owners.last_name FROM Owners JOIN Dogs
medium gold: SELECT T1.first_name , T1.last_name , T2.name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id
medium pred: SELECT Owners.first_name , Owners.last_name FROM Owners JOIN Dogs
medium gold: SELECT T1.first_name , T1.last_name , T2.name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id
extra pred: SELECT Dogs.name , Treatments.date_of_treatment FROM Treatments JOIN Dogs GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.name , T2.date_of_treatment FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id WHERE T1.breed_code = ( SELECT breed_code FROM Dogs GROUP BY breed_code ORDER BY count(*) ASC LIMIT 1 )
extra pred: SELECT Dogs.name , Treatments.date_of_treatment FROM Dogs JOIN Treatments GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.name , T2.date_of_treatment FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id WHERE T1.breed_code = ( SELECT breed_code FROM Dogs GROUP BY breed_code ORDER BY count(*) ASC LIMIT 1 )
medium pred: SELECT Dogs.date_arrived FROM Dogs
medium gold: SELECT DISTINCT T1.date_arrived , T1.date_departed FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id
medium pred: SELECT Dogs.date_arrived FROM Dogs ORDER BY Dogs.date_departed DESC
medium gold: SELECT DISTINCT T1.date_arrived , T1.date_departed FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id
extra pred: SELECT Owners.last_name FROM Owners JOIN Dogs ORDER BY Dogs.age ASC LIMIT 1
extra gold: SELECT T1.last_name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id WHERE T2.age = ( SELECT max(age) FROM Dogs )
extra pred: SELECT Owners.last_name FROM Owners JOIN Dogs ORDER BY Dogs.age ASC LIMIT 1
extra gold: SELECT T1.last_name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id WHERE T2.age = ( SELECT max(age) FROM Dogs )
medium pred: SELECT Dogs.date_arrived , Dogs.date_departed FROM Dogs
medium gold: SELECT date_arrived , date_departed FROM Dogs where abandoned_yn = 1
medium pred: SELECT Dogs.date_arrived , Dogs.date_departed FROM Dogs
medium gold: SELECT date_arrived , date_departed FROM Dogs where abandoned_yn = 1
medium pred: SELECT Professionals.first_name FROM Professionals WHERE Professionals.city LIKE "value"
medium gold: SELECT first_name , last_name FROM professionals WHERE city LIKE '%West%'
medium pred: SELECT Professionals.first_name FROM Professionals WHERE Professionals.city LIKE "value"
medium gold: SELECT first_name , last_name FROM professionals WHERE city LIKE '%West%'
medium pred: SELECT Owners.last_name FROM Owners WHERE Owners.state LIKE "value"
medium gold: SELECT first_name , last_name FROM Owners WHERE state LIKE '%North%'
medium pred: SELECT Owners.first_name FROM Owners WHERE Owners.state LIKE "value"
medium gold: SELECT first_name , last_name FROM Owners WHERE state LIKE '%North%'
extra pred: SELECT COUNT(*) FROM Dogs WHERE Dogs.age < ( SELECT AVG(Dogs.age) FROM Dogs )
extra gold: SELECT count(*) FROM Dogs WHERE abandoned_yn = 1 and age < ( SELECT avg(age) FROM Dogs )
extra pred: SELECT COUNT(*) FROM Dogs WHERE Dogs.age < ( SELECT AVG(Dogs.age) FROM Dogs )
extra gold: SELECT count(*) FROM Dogs WHERE abandoned_yn = 1 and age < ( SELECT avg(age) FROM Dogs )
extra pred: SELECT COUNT(*) FROM Dogs WHERE Dogs.dog_id NOT IN ( SELECT Treatments.dog_id FROM Treatments )
extra gold: SELECT count(*) FROM Dogs WHERE abandoned_yn = 1 and dog_id NOT IN ( SELECT dog_id FROM Treatments )
extra pred: SELECT COUNT(Treatments.dog_id) FROM Treatments
extra gold: SELECT count(*) FROM Dogs WHERE dog_id NOT IN ( SELECT dog_id FROM Treatments )
extra pred: SELECT COUNT(*) FROM Owners WHERE Owners.owner_id NOT IN ( SELECT Dogs.owner_id FROM Dogs )
extra gold: SELECT count(*) FROM Owners WHERE owner_id NOT IN ( SELECT owner_id FROM Dogs where abandoned_yn = 1 )
extra pred: SELECT COUNT(*) FROM Owners WHERE Owners.owner_id NOT IN ( SELECT Dogs.owner_id FROM Dogs )
extra gold: SELECT count(*) FROM Owners WHERE owner_id NOT IN ( SELECT owner_id FROM Dogs where abandoned_yn = 1 )
easy pred: SELECT AVG(Dogs.age) FROM Dogs
easy gold: SELECT avg(age) FROM Dogs where abandoned_yn = 1
easy pred: SELECT AVG(Dogs.age) FROM Dogs
easy gold: SELECT avg(age) FROM Dogs where abandoned_yn = 1
medium pred: SELECT Charges.charge_type , SUM(Charges.charge_amount) FROM Charges GROUP BY Charges.charge_type
medium gold: SELECT charge_type , charge_amount FROM Charges
medium pred: SELECT Charges.charge_type , SUM(Charges.charge_amount) FROM Charges GROUP BY Charges.charge_type
medium gold: SELECT charge_type , charge_amount FROM Charges
easy pred: SELECT Charges.charge_type FROM Charges ORDER BY Charges.charge_amount DESC LIMIT 1
easy gold: SELECT max(charge_amount) FROM Charges
easy pred: SELECT Charges.charge_amount FROM Charges ORDER BY Charges.charge_type DESC LIMIT 1
easy gold: SELECT max(charge_amount) FROM Charges
medium pred: SELECT Professionals.email_address , Professionals.last_name FROM Professionals
medium gold: SELECT email_address , first_name , last_name FROM professionals
medium pred: SELECT Professionals.email_address , Professionals.last_name FROM Professionals
medium gold: SELECT email_address , first_name , last_name FROM professionals
medium pred: SELECT Sizes.size_code , Sizes.size_code FROM Sizes
medium gold: SELECT DISTINCT breed_code , size_code FROM dogs
medium pred: SELECT Professionals.first_name , Professionals.last_name FROM Professionals JOIN Treatments JOIN Treatment_Types
medium gold: SELECT DISTINCT T1.first_name , T1.last_name , T3.treatment_type_description FROM professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id JOIN Treatment_types AS T3 ON T2.treatment_type_code = T3.treatment_type_code
medium pred: SELECT Professionals.first_name , Professionals.last_name FROM Professionals JOIN Treatments JOIN Treatment_Types
medium gold: SELECT DISTINCT T1.first_name , T1.last_name , T3.treatment_type_description FROM professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id JOIN Treatment_types AS T3 ON T2.treatment_type_code = T3.treatment_type_code
easy pred: SELECT singer.Name FROM singer ORDER BY singer.Birth_Year ASC
easy gold: SELECT Name FROM singer ORDER BY Birth_Year desc
easy pred: SELECT singer.Name FROM singer ORDER BY singer.Birth_Year ASC
easy gold: SELECT Name FROM singer ORDER BY Birth_Year desc
easy pred: SELECT singer.Name FROM singer WHERE singer.Birth_Year < ( SELECT MAX(singer.Birth_Year) FROM singer WHERE singer.Birth_Year = "value" )
easy gold: SELECT Name FROM singer WHERE Birth_Year <= 1948
easy pred: SELECT singer.Name FROM singer WHERE singer.Birth_Year < ( SELECT MAX(singer.Birth_Year) FROM singer WHERE singer.Birth_Year = "value" )
easy gold: SELECT Name FROM singer WHERE Birth_Year <= 1948
medium pred: SELECT singer.Name FROM singer ORDER BY singer.Birth_Year ASC LIMIT 1
medium gold: SELECT Name FROM singer ORDER BY Birth_Year DESC LIMIT 1
medium pred: SELECT singer.Name FROM singer ORDER BY singer.Birth_Year ASC LIMIT 1
medium gold: SELECT Name FROM singer ORDER BY Birth_Year DESC LIMIT 1
medium pred: SELECT singer.Name FROM singer JOIN song GROUP BY song.Singer_ID HAVING COUNT(*) > "value"
medium gold: SELECT T1.Name FROM singer AS T1 JOIN song AS T2 ON T1.Singer_ID = T2.Singer_ID GROUP BY T1.Name HAVING COUNT(*) > 1
easy pred: SELECT singer.Citizenship FROM singer WHERE singer.Birth_Year < "value" INTERSECT SELECT singer.Citizenship FROM singer WHERE singer.Birth_Year <= "value"
easy gold: SELECT Citizenship FROM singer WHERE Birth_Year <= 1945
easy pred: SELECT singer.Citizenship FROM singer WHERE singer.Birth_Year < "value" INTERSECT SELECT singer.Citizenship FROM singer WHERE singer.Birth_Year = "value"
easy gold: SELECT Citizenship FROM singer WHERE Birth_Year <= 1945
easy medium hard extra all
count 110 246 74 105 535
====================== EXACT MATCHING ACCURACY =====================
exact match 0.755 0.508 0.405 0.333 0.510
---------------------PARTIAL MATCHING ACCURACY----------------------
select 0.964 0.695 0.824 0.752 0.779
select(no AGG) 0.973 0.711 0.892 0.790 0.806
where 0.860 0.807 0.486 0.508 0.692
where(no OP) 0.980 0.852 0.657 0.689 0.808
group(no Having) 0.625 0.861 0.667 0.775 0.797
group 0.625 0.833 0.667 0.750 0.775
order 0.143 0.543 0.786 0.786 0.639
and/or 1.000 0.984 0.973 0.894 0.968
IUEN 0.000 0.000 0.455 0.320 0.325
keywords 0.778 0.896 0.792 0.760 0.828
---------------------- PARTIAL MATCHING RECALL ----------------------
select 0.964 0.695 0.824 0.752 0.779
select(no AGG) 0.973 0.711 0.892 0.790 0.806
where 0.768 0.789 0.472 0.449 0.645
where(no OP) 0.875 0.833 0.639 0.609 0.753
group(no Having) 0.625 0.861 0.750 0.816 0.821
group 0.625 0.833 0.750 0.789 0.799
order 0.200 0.559 0.786 0.825 0.679
and/or 1.000 0.992 1.000 0.989 0.994
IUEN 0.000 0.000 0.417 0.364 0.382
keywords 0.757 0.892 0.770 0.752 0.817
---------------------- PARTIAL MATCHING F1 --------------------------
select 0.964 0.695 0.824 0.752 0.779
select(no AGG) 0.973 0.711 0.892 0.790 0.806
where 0.811 0.798 0.479 0.477 0.668
where(no OP) 0.925 0.843 0.648 0.646 0.779
group(no Having) 0.625 0.861 0.706 0.795 0.809
group 0.625 0.833 0.706 0.769 0.787
order 0.167 0.551 0.786 0.805 0.658
and/or 1.000 0.988 0.986 0.939 0.981
IUEN 1.000 1.000 0.435 0.340 0.351
keywords 0.767 0.894 0.781 0.756 0.822
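
For reference, the PARTIAL MATCHING F1 rows above combine the corresponding accuracy (precision) and recall rows per clause type. A minimal sketch of that combination, assuming the standard Spider evaluation convention in which F1 is reported as 1.0 when both precision and recall are 0 (which is consistent with the IUEN row on the easy and medium splits):

```python
# Sketch (assumption): how a PARTIAL MATCHING F1 entry is derived from the
# accuracy (precision) and recall entries, following the usual Spider
# evaluation convention where F1 defaults to 1.0 when both inputs are 0.
def partial_f1(acc: float, rec: float) -> float:
    if acc == 0.0 and rec == 0.0:
        return 1.0
    return 2.0 * acc * rec / (acc + rec)

# Example: the 'where' row on the hard split above.
print(round(partial_f1(0.486, 0.472), 3))  # ~0.479
```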
View File

@ -0,0 +1,774 @@
medium pred: SELECT MAX(stadium.Capacity) , AVG(stadium.Average) FROM stadium
medium gold: select max(capacity), average from stadium
medium pred: SELECT AVG(stadium.Average) , MAX(stadium.Capacity) FROM stadium
medium gold: select avg(capacity) , max(capacity) from stadium
medium pred: SELECT stadium.Stadium_ID , COUNT(*) FROM stadium JOIN concert GROUP BY stadium.Stadium_ID
medium gold: SELECT T2.name , count(*) FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id GROUP BY T1.stadium_id
medium pred: SELECT singer.Name , COUNT(*) FROM singer_in_concert JOIN singer GROUP BY singer.Name
medium gold: SELECT T2.name , count(*) FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id = T2.singer_id GROUP BY T2.singer_id
hard pred: SELECT COUNT(*) FROM concert JOIN stadium ORDER BY stadium.Capacity DESC LIMIT 1
hard gold: select count(*) from concert where stadium_id = (select stadium_id from stadium order by capacity desc limit 1)
hard pred: SELECT COUNT(*) FROM concert JOIN stadium ORDER BY stadium.Capacity DESC LIMIT 1
hard gold: select count(*) from concert where stadium_id = (select stadium_id from stadium order by capacity desc limit 1)
medium pred: SELECT MAX(Pets.weight) , MAX(Pets.weight) , Pets.PetType FROM Pets GROUP BY Pets.PetType
medium gold: SELECT max(weight) , petType FROM pets GROUP BY petType
medium pred: SELECT COUNT(*) FROM Pets JOIN Has_Pet JOIN Student WHERE Student.Age > "value"
medium gold: SELECT count(*) FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid WHERE T1.age > 20
extra pred: SELECT Student.Fname FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value" INTERSECT SELECT Student.Fname FROM Student JOIN Has_Pet WHERE Has_Pet.PetID = "value"
extra gold: SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog'
extra pred: SELECT Student.Major , Student.Age FROM Student WHERE Student.StuID NOT IN ( SELECT Has_Pet.StuID FROM Has_Pet JOIN Pets WHERE Pets.PetType = "value" )
extra gold: SELECT major , age FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
extra pred: SELECT Student.Major , Student.Age FROM Student EXCEPT SELECT Student.Major , Student.Age FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
extra gold: SELECT major , age FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
hard pred: SELECT Student.StuID FROM Student EXCEPT SELECT Has_Pet.StuID FROM Has_Pet JOIN Pets WHERE Pets.PetType = "value"
hard gold: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat'
hard pred: SELECT Student.StuID FROM Student EXCEPT SELECT Has_Pet.StuID FROM Has_Pet WHERE Has_Pet.PetID = "value"
hard gold: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat'
extra pred: SELECT Student.Fname , Student.Age FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
extra gold: SELECT T1.fname , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog' AND T1.stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
extra pred: SELECT Student.Fname FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value" EXCEPT SELECT Student.Fname FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
extra gold: SELECT T1.fname , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog' AND T1.stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
medium pred: SELECT COUNT(*) , Has_Pet.StuID FROM Has_Pet GROUP BY Has_Pet.StuID
medium gold: SELECT count(*) , T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid GROUP BY T1.stuid
medium pred: SELECT Has_Pet.StuID , COUNT(*) FROM Has_Pet WHERE Has_Pet.StuID != "value" GROUP BY Has_Pet.StuID
medium gold: select count(*) , t1.stuid from student as t1 join has_pet as t2 on t1.stuid = t2.stuid group by t1.stuid
hard pred: SELECT Student.LName FROM Pets JOIN Has_Pet JOIN Student WHERE Pets.pet_age = "value"
hard gold: SELECT T1.lname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pet_age = 3 AND T3.pettype = 'cat'
hard pred: SELECT Student.LName FROM Pets JOIN Has_Pet JOIN Student WHERE Pets.pet_age = "value"
hard gold: SELECT T1.lname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pet_age = 3 AND T3.pettype = 'cat'
extra pred: SELECT car_makers.Maker FROM cars_data JOIN car_makers WHERE cars_data.Year = "value"
extra gold: SELECT DISTINCT T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model JOIN CARS_DATA AS T4 ON T3.MakeId = T4.id WHERE T4.year = '1970';
extra pred: SELECT car_makers.FullName FROM cars_data JOIN car_makers WHERE cars_data.Year = "value"
extra gold: SELECT DISTINCT T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model JOIN CARS_DATA AS T4 ON T3.MakeId = T4.id WHERE T4.year = '1970';
extra pred: SELECT car_names.Make , cars_data.Year FROM cars_data JOIN car_names ORDER BY cars_data.Year ASC LIMIT 1
extra gold: SELECT T2.Make , T1.Year FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Year = (SELECT min(YEAR) FROM CARS_DATA);
extra pred: SELECT car_makers.Maker , cars_data.Year FROM cars_data JOIN car_names JOIN model_list JOIN car_makers WHERE cars_data.Year = "value" ORDER BY cars_data.Year ASC LIMIT 1
extra gold: SELECT T2.Make , T1.Year FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Year = (SELECT min(YEAR) FROM CARS_DATA);
hard pred: SELECT car_names.Model FROM cars_data JOIN car_names WHERE cars_data.Year > "value"
hard gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.model = T2.model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.id WHERE T3.year > 1980;
medium pred: SELECT COUNT(*) , car_makers.FullName FROM car_makers JOIN model_list GROUP BY car_makers.Maker
medium gold: select count(*) , t2.fullname from model_list as t1 join car_makers as t2 on t1.maker = t2.id group by t2.id;
medium pred: SELECT cars_data.Accelerate FROM cars_data JOIN car_names WHERE car_names.Model = "value" AND car_names.Make = "value"
medium gold: SELECT T1.Accelerate FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Make = 'amc hornet sportabout (sw)';
medium pred: SELECT cars_data.Accelerate FROM cars_data JOIN car_names WHERE car_names.Model = "value" AND car_names.Make = "value"
medium gold: SELECT T1.Accelerate FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Make = 'amc hornet sportabout (sw)';
hard pred: SELECT COUNT(*) FROM model_list JOIN car_makers WHERE car_makers.Country = "value"
hard gold: SELECT count(*) FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker = T2.Id JOIN COUNTRIES AS T3 ON T2.Country = T3.CountryId WHERE T3.CountryName = 'usa';
medium pred: SELECT car_makers.Maker , model_list.Model FROM car_makers JOIN model_list
medium gold: SELECT Maker , Model FROM MODEL_LIST;
medium pred: SELECT countries.CountryName , car_makers.Id FROM car_makers JOIN countries
medium gold: SELECT T1.CountryName , T1.CountryId FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) >= 1;
medium pred: SELECT countries.CountryName , countries.CountryId FROM car_makers JOIN countries
medium gold: SELECT T1.CountryName , T1.CountryId FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) >= 1;
medium pred: SELECT AVG(cars_data.Weight) , MAX(cars_data.Weight) , cars_data.Year FROM cars_data GROUP BY cars_data.Year
medium gold: SELECT avg(Weight) , YEAR FROM CARS_DATA GROUP BY YEAR;
extra pred: SELECT countries.CountryName FROM car_makers JOIN countries WHERE countries.Continent = "value" GROUP BY car_makers.Country HAVING COUNT(*) >= "value"
extra gold: SELECT T1.CountryName FROM COUNTRIES AS T1 JOIN CONTINENTS AS T2 ON T1.Continent = T2.ContId JOIN CAR_MAKERS AS T3 ON T1.CountryId = T3.Country WHERE T2.Continent = 'europe' GROUP BY T1.CountryName HAVING count(*) >= 3;
extra pred: SELECT countries.CountryName FROM car_makers JOIN countries WHERE countries.Continent = "value" GROUP BY car_makers.Country HAVING COUNT(*) >= "value"
extra gold: SELECT T1.CountryName FROM COUNTRIES AS T1 JOIN CONTINENTS AS T2 ON T1.Continent = T2.ContId JOIN CAR_MAKERS AS T3 ON T1.CountryId = T3.Country WHERE T2.Continent = 'europe' GROUP BY T1.CountryName HAVING count(*) >= 3;
extra pred: SELECT MAX(cars_data.Horsepower) , MAX(car_names.Make) FROM cars_data JOIN car_names WHERE cars_data.Cylinders = "value"
extra gold: SELECT T2.horsepower , T1.Make FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T2.cylinders = 3 ORDER BY T2.horsepower DESC LIMIT 1;
extra pred: SELECT MAX(car_names.Make) , MAX(cars_data.Horsepower) FROM cars_data JOIN car_names WHERE cars_data.Cylinders = "value"
extra gold: SELECT T2.horsepower , T1.Make FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T2.cylinders = 3 ORDER BY T2.horsepower DESC LIMIT 1;
medium pred: SELECT AVG(cars_data.Edispl) FROM cars_data JOIN car_names WHERE car_names.Make = "value"
medium gold: SELECT avg(T2.edispl) FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T1.Model = 'volvo';
medium pred: SELECT car_makers.FullName , car_makers.Id , COUNT(*) FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) > "value"
medium gold: SELECT T1.FullName , T1.Id FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) > 3;
extra pred: SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_makers.FullName = "value" OR cars_data.Weight > "value"
extra gold: SELECT DISTINCT T2.Model FROM CAR_NAMES AS T1 JOIN MODEL_LIST AS T2 ON T1.Model = T2.Model JOIN CAR_MAKERS AS T3 ON T2.Maker = T3.Id JOIN CARS_DATA AS T4 ON T1.MakeId = T4.Id WHERE T3.FullName = 'General Motors' OR T4.weight > 3500;
extra pred: SELECT model_list.Model FROM model_list JOIN car_makers JOIN cars_data WHERE car_makers.FullName = "value" OR cars_data.Weight > "value"
extra gold: SELECT DISTINCT T2.Model FROM CAR_NAMES AS T1 JOIN MODEL_LIST AS T2 ON T1.Model = T2.Model JOIN CAR_MAKERS AS T3 ON T2.Maker = T3.Id JOIN CARS_DATA AS T4 ON T1.MakeId = T4.Id WHERE T3.FullName = 'General Motors' OR T4.weight > 3500;
easy pred: SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight < "value" INTERSECT SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight > "value"
easy gold: select distinct year from cars_data where weight between 3000 and 4000;
extra pred: SELECT cars_data.Cylinders FROM cars_data JOIN car_names WHERE cars_data.Accelerate = "value" ORDER BY cars_data.Cylinders ASC LIMIT 1
extra gold: SELECT T1.cylinders FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Model = 'volvo' ORDER BY T1.accelerate ASC LIMIT 1;
extra pred: SELECT cars_data.Cylinders FROM cars_data JOIN car_names WHERE cars_data.Accelerate = "value" ORDER BY cars_data.Cylinders ASC LIMIT 1
extra gold: SELECT T1.cylinders FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Model = 'volvo' ORDER BY T1.accelerate ASC LIMIT 1;
hard pred: SELECT COUNT(*) FROM cars_data WHERE cars_data.Accelerate > ( SELECT MAX(cars_data.Accelerate) FROM cars_data ) ORDER BY cars_data.Horsepower DESC LIMIT 1
hard gold: SELECT COUNT(*) FROM CARS_DATA WHERE Accelerate > ( SELECT Accelerate FROM CARS_DATA ORDER BY Horsepower DESC LIMIT 1 );
hard pred: SELECT COUNT(*) FROM cars_data WHERE cars_data.Accelerate > ( SELECT MAX(cars_data.Accelerate) FROM cars_data ) ORDER BY cars_data.Horsepower DESC LIMIT 1
hard gold: SELECT COUNT(*) FROM CARS_DATA WHERE Accelerate > ( SELECT Accelerate FROM CARS_DATA ORDER BY Horsepower DESC LIMIT 1 );
medium pred: SELECT COUNT(*) FROM ( SELECT car_makers.Country FROM car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" )
medium gold: select count(*) from countries as t1 join car_makers as t2 on t1.countryid = t2.country group by t1.countryid having count(*) > 2
medium pred: SELECT COUNT(*) FROM ( SELECT car_makers.Country FROM car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" )
medium gold: select count(*) from countries as t1 join car_makers as t2 on t1.countryid = t2.country group by t1.countryid having count(*) > 2
extra pred: SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders > "value" EXCEPT SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders > "value"
extra gold: SELECT T2.MakeId , T2.Make FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Horsepower > (SELECT min(Horsepower) FROM CARS_DATA) AND T1.Cylinders <= 3;
extra pred: SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders < "value"
extra gold: select t2.makeid , t2.make from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t1.horsepower > (select min(horsepower) from cars_data) and t1.cylinders < 4;
extra pred: SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE cars_data.Weight < "value" EXCEPT SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_makers.FullName = "value"
extra gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.Model = T2.Model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.Id JOIN CAR_MAKERS AS T4 ON T1.Maker = T4.Id WHERE T3.weight < 3500 AND T4.FullName != 'Ford Motor Company';
extra pred: SELECT model_list.Model FROM model_list JOIN car_names WHERE cars_data.Weight < "value" EXCEPT SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_names.Make = "value"
extra gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.Model = T2.Model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.Id JOIN CAR_MAKERS AS T4 ON T1.Maker = T4.Id WHERE T3.weight < 3500 AND T4.FullName != 'Ford Motor Company';
hard pred: SELECT countries.CountryName FROM countries WHERE countries.CountryId NOT IN ( SELECT car_makers.Country FROM car_makers )
hard gold: SELECT CountryName FROM countries EXCEPT SELECT T1.CountryName FROM countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.countryId = T2.Country;
hard pred: SELECT countries.CountryName FROM countries WHERE countries.CountryId NOT IN ( SELECT car_makers.Country FROM car_makers )
hard gold: SELECT CountryName FROM countries EXCEPT SELECT T1.CountryName FROM countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.countryId = T2.Country;
extra pred: SELECT car_makers.Id , car_makers.Maker FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) >= "value" INTERSECT SELECT car_makers.Id , car_makers.Maker FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) > "value"
extra gold: select t1.id , t1.maker from car_makers as t1 join model_list as t2 on t1.id = t2.maker group by t1.id having count(*) >= 2 intersect select t1.id , t1.maker from car_makers as t1 join model_list as t2 on t1.id = t2.maker join car_names as t3 on t2.model = t3.model group by t1.id having count(*) > 3;
extra pred: SELECT car_makers.Id , car_makers.Maker FROM model_list JOIN car_makers GROUP BY car_makers.Id HAVING COUNT(*) >= "value" INTERSECT SELECT car_makers.Id , car_makers.Maker FROM model_list JOIN car_makers GROUP BY model_list.Maker HAVING COUNT(*) > "value"
extra gold: SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) >= 2 INTERSECT SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model GROUP BY T1.Id HAVING count(*) > 3;
extra pred: SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN countries GROUP BY car_makers.Country HAVING COUNT(*) > "value" UNION SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN model_list JOIN countries JOIN model_list JOIN countries WHERE model_list.Model = "value"
extra gold: SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.countryId HAVING count(*) > 3 UNION SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country JOIN MODEL_LIST AS T3 ON T2.Id = T3.Maker WHERE T3.Model = 'fiat';
extra pred: SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN countries GROUP BY car_makers.Country HAVING COUNT(*) > "value" UNION SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN countries WHERE car_makers.Maker = "value"
extra gold: select t1.countryid , t1.countryname from countries as t1 join car_makers as t2 on t1.countryid = t2.country group by t1.countryid having count(*) > 3 union select t1.countryid , t1.countryname from countries as t1 join car_makers as t2 on t1.countryid = t2.country join model_list as t3 on t2.id = t3.maker where t3.model = 'fiat';
hard pred: SELECT COUNT(*) FROM airports JOIN flights WHERE airports.City = "value" AND airports.AirportName = "value"
hard gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode JOIN AIRPORTS AS T3 ON T1.SourceAirport = T3.AirportCode WHERE T2.City = "Ashley" AND T3.City = "Aberdeen"
hard pred: SELECT COUNT(*) FROM airports JOIN flights WHERE airports.City = "value" AND airports.AirportName = "value"
hard gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode JOIN AIRPORTS AS T3 ON T1.SourceAirport = T3.AirportCode WHERE T2.City = "Ashley" AND T3.City = "Aberdeen"
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airlines.Airline = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "United Airlines" AND T2.DestAirport = "ASY"
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airlines.Airline = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "United Airlines" AND T2.DestAirport = "ASY"
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airlines.Airline = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "United Airlines" AND T2.SourceAirport = "AHD"
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airlines.Airline = "value" AND airports.AirportName = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "United Airlines" AND T2.SourceAirport = "AHD"
extra pred: SELECT airports.City FROM airports JOIN flights GROUP BY flights.SourceAirport ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.City FROM AIRPORTS AS T1 JOIN FLIGHTS AS T2 ON T1.AirportCode = T2.SourceAirport GROUP BY T1.City ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT airlines.Abbreviation , airlines.Country FROM airlines JOIN flights GROUP BY flights.Airline ORDER BY COUNT(*) ASC LIMIT 1
extra gold: SELECT T1.Abbreviation , T1.Country FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline ORDER BY count(*) LIMIT 1
extra pred: SELECT airlines.Abbreviation , airports.Country FROM airlines JOIN airports JOIN flights JOIN flights GROUP BY airlines.Abbreviation ORDER BY COUNT(*) ASC LIMIT 1
extra gold: SELECT T1.Abbreviation , T1.Country FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline ORDER BY count(*) LIMIT 1
medium pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airlines WHERE flights.SourceAirport = "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "AHD"
medium pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airlines WHERE flights.DestAirport = "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.DestAirport = "AHD"
extra pred: SELECT airlines.Airline FROM airports JOIN flights WHERE flights.SourceAirport = "value" EXCEPT SELECT airlines.Airline FROM airlines WHERE flights.SourceAirport = "value"
extra gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "CVO" EXCEPT SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "APG"
extra pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airports JOIN airlines JOIN flights WHERE flights.SourceAirport = "value" EXCEPT SELECT airlines.Airline FROM airlines WHERE airlines.Airline = "value"
extra gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "CVO" EXCEPT SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "APG"
medium pred: SELECT airlines.Airline FROM airlines JOIN flights GROUP BY airlines.Airline HAVING COUNT(*) >= "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline HAVING count(*) > 10
medium pred: SELECT airlines.Airline FROM flights JOIN airlines GROUP BY airlines.Airline HAVING COUNT(*) >= "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline HAVING count(*) > 10
easy pred: SELECT flights.FlightNo FROM airports JOIN flights WHERE airports.AirportName = "value"
easy gold: SELECT FlightNo FROM FLIGHTS WHERE DestAirport = "APG"
hard pred: SELECT airports.AirportName FROM airports WHERE airports.AirportCode NOT IN ( SELECT flights.SourceAirport FROM flights )
hard gold: SELECT AirportName FROM Airports WHERE AirportCode NOT IN (SELECT SourceAirport FROM Flights UNION SELECT DestAirport FROM Flights)
hard pred: SELECT airports.AirportName FROM airports EXCEPT SELECT airports.AirportName FROM airports JOIN flights
hard gold: SELECT AirportName FROM Airports WHERE AirportCode NOT IN (SELECT SourceAirport FROM Flights UNION SELECT DestAirport FROM Flights)
medium pred: SELECT COUNT(*) , shop.Name FROM hiring JOIN shop GROUP BY hiring.Shop_ID
medium gold: SELECT count(*) , t2.name FROM hiring AS t1 JOIN shop AS t2 ON t1.shop_id = t2.shop_id GROUP BY t2.name
medium pred: SELECT Templates.Version_Number , Templates.Template_Type_Code FROM Templates ORDER BY Templates.Version_Number ASC LIMIT 1
medium gold: SELECT min(Version_Number) , template_type_code FROM Templates
medium pred: SELECT Templates.Version_Number , Templates.Template_Type_Code FROM Templates ORDER BY Templates.Version_Number ASC LIMIT 1
medium gold: SELECT min(Version_Number) , template_type_code FROM Templates
medium pred: SELECT Templates.Template_Type_Code , COUNT(*) FROM Templates GROUP BY Templates.Template_Type_Code
medium gold: SELECT T1.template_type_code , count(*) FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code
extra pred: SELECT Templates.Template_Type_Code FROM Templates GROUP BY Templates.Template_Type_Code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.template_type_code FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Templates.Template_Type_Code FROM Templates GROUP BY Templates.Template_Type_Code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.template_type_code FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code ORDER BY count(*) DESC LIMIT 1
medium pred: SELECT Paragraphs.Document_ID , COUNT(*) FROM Paragraphs GROUP BY Paragraphs.Document_ID ORDER BY COUNT(*) ASC
medium gold: SELECT document_id , count(*) FROM Paragraphs GROUP BY document_id ORDER BY document_id
medium pred: SELECT Paragraphs.Document_ID , COUNT(*) FROM Paragraphs GROUP BY Paragraphs.Document_ID ORDER BY COUNT(*) ASC
medium gold: SELECT document_id , count(*) FROM Paragraphs GROUP BY document_id ORDER BY document_id
medium pred: SELECT teacher.Name , COUNT(*) FROM course_arrange JOIN teacher GROUP BY course_arrange.Teacher_ID
medium gold: SELECT T2.Name , COUNT(*) FROM course_arrange AS T1 JOIN teacher AS T2 ON T1.Teacher_ID = T2.Teacher_ID GROUP BY T2.Name
medium pred: SELECT teacher.Name , COUNT(*) FROM course_arrange JOIN teacher GROUP BY course_arrange.Teacher_ID
medium gold: SELECT T2.Name , COUNT(*) FROM course_arrange AS T1 JOIN teacher AS T2 ON T1.Teacher_ID = T2.Teacher_ID GROUP BY T2.Name
medium pred: SELECT teacher.Name FROM course_arrange JOIN teacher GROUP BY course_arrange.Teacher_ID HAVING COUNT(*) >= "value"
medium gold: SELECT T2.Name FROM course_arrange AS T1 JOIN teacher AS T2 ON T1.Teacher_ID = T2.Teacher_ID GROUP BY T2.Name HAVING COUNT(*) >= 2
medium pred: SELECT teacher.Name FROM course_arrange JOIN teacher GROUP BY course_arrange.Teacher_ID HAVING COUNT(*) >= "value"
medium gold: SELECT T2.Name FROM course_arrange AS T1 JOIN teacher AS T2 ON T1.Teacher_ID = T2.Teacher_ID GROUP BY T2.Name HAVING COUNT(*) >= 2
medium pred: SELECT visitor.Name FROM visitor WHERE visitor.Level_of_membership > "value" ORDER BY visitor.Name DESC
medium gold: SELECT name FROM visitor WHERE Level_of_membership > 4 ORDER BY Level_of_membership DESC
easy pred: SELECT AVG(visitor.Age) FROM visitor WHERE visitor.Level_of_membership > "value"
easy gold: SELECT avg(age) FROM visitor WHERE Level_of_membership <= 4
medium pred: SELECT visitor.Name , visitor.Level_of_membership FROM visitor WHERE visitor.Age > "value" ORDER BY visitor.Age ASC
medium gold: SELECT name , Level_of_membership FROM visitor WHERE Level_of_membership > 4 ORDER BY age DESC
hard pred: SELECT visitor.Name , visitor.Age FROM visitor JOIN visit GROUP BY visit.visitor_ID ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT t1.name , t1.age FROM visitor AS t1 JOIN visit AS t2 ON t1.id = t2.visitor_id ORDER BY t2.num_of_ticket DESC LIMIT 1
extra pred: SELECT COUNT(*) FROM visitor WHERE visitor.ID NOT IN ( SELECT visit.visitor_ID FROM visit JOIN museum WHERE museum.Open_Year > "value" )
extra gold: SELECT count(*) FROM visitor WHERE id NOT IN (SELECT t2.visitor_id FROM museum AS t1 JOIN visit AS t2 ON t1.Museum_ID = t2.Museum_ID WHERE t1.open_year > 2010)
easy pred: SELECT AVG(matches.winner_rank_points) FROM matches
easy gold: SELECT avg(winner_rank) FROM matches
easy pred: SELECT MAX(matches.loser_rank_points) FROM matches
easy gold: SELECT min(loser_rank) FROM matches
easy pred: SELECT MAX(matches.loser_rank_points) FROM matches
easy gold: SELECT min(loser_rank) FROM matches
easy pred: SELECT matches.tourney_name FROM matches GROUP BY matches.tourney_id HAVING COUNT(*) > "value"
easy gold: SELECT tourney_name FROM matches GROUP BY tourney_name HAVING count(*) > 10
hard pred: SELECT players.last_name FROM players JOIN matches WHERE matches.year = "value" INTERSECT SELECT matches.winner_name FROM players JOIN matches WHERE matches.year = "value"
hard gold: SELECT winner_name FROM matches WHERE YEAR = 2013 INTERSECT SELECT winner_name FROM matches WHERE YEAR = 2016
medium pred: SELECT players.first_name , players.country_code FROM players ORDER BY players.birth_date DESC LIMIT 1
medium gold: SELECT first_name , country_code FROM players ORDER BY birth_date LIMIT 1
medium pred: SELECT players.first_name , players.country_code FROM players ORDER BY players.birth_date DESC LIMIT 1
medium gold: SELECT first_name , country_code FROM players ORDER BY birth_date LIMIT 1
hard pred: SELECT players.first_name , players.country_code FROM players JOIN rankings GROUP BY rankings.player_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT T1.country_code , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id ORDER BY T2.tours DESC LIMIT 1
hard pred: SELECT players.first_name , players.country_code FROM players JOIN rankings GROUP BY rankings.player_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT T1.country_code , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id ORDER BY T2.tours DESC LIMIT 1
hard pred: SELECT matches.winner_name , matches.winner_rank_points FROM matches GROUP BY matches.winner_rank ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT winner_name , winner_rank_points FROM matches GROUP BY winner_name ORDER BY count(*) DESC LIMIT 1
medium pred: SELECT players.first_name , AVG(rankings.ranking_points) FROM players JOIN rankings GROUP BY players.first_name
medium gold: SELECT avg(ranking) , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id GROUP BY T1.first_name
medium pred: SELECT COUNT(*) , rankings.ranking_date FROM rankings GROUP BY rankings.ranking_date
medium gold: SELECT sum(tours) , ranking_date FROM rankings GROUP BY ranking_date
medium pred: SELECT COUNT(matches.winner_entry) FROM matches WHERE matches.tourney_name = "value" AND matches.winner_hand = "value"
medium gold: SELECT count(DISTINCT winner_name) FROM matches WHERE tourney_name = 'WTA Championships' AND winner_hand = 'L'
medium pred: SELECT COUNT(*) FROM matches WHERE matches.tourney_name = "value" AND matches.winner_hand = "value"
medium gold: SELECT count(DISTINCT winner_name) FROM matches WHERE tourney_name = 'WTA Championships' AND winner_hand = 'L'
medium pred: SELECT battle.name , battle.date , battle.result FROM battle
medium gold: SELECT name , date FROM battle
medium pred: SELECT MAX(death.killed) , MIN(death.killed) , death.caused_by_ship_id FROM death GROUP BY death.caused_by_ship_id
medium gold: SELECT max(killed) , min(killed) FROM death
easy pred: SELECT AVG(death.injured) , * FROM death GROUP BY *
easy gold: SELECT avg(injured) FROM death
hard pred: SELECT battle.id , battle.name FROM battle JOIN death GROUP BY battle.id HAVING SUM(death.killed) > "value"
hard gold: SELECT T1.id , T1.name FROM battle AS T1 JOIN ship AS T2 ON T1.id = T2.lost_in_battle JOIN death AS T3 ON T2.id = T3.caused_by_ship_id GROUP BY T1.id HAVING sum(T3.killed) > 10
extra pred: SELECT ship.id , ship.name FROM ship JOIN death GROUP BY death.caused_by_ship_id ORDER BY SUM(death.injured) DESC LIMIT 1
extra gold: SELECT T2.id , T2.name FROM death AS T1 JOIN ship AS t2 ON T1.caused_by_ship_id = T2.id GROUP BY T2.id ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT battle.name , battle.result , battle.bulgarian_commander FROM battle WHERE battle.id NOT IN ( SELECT ship.lost_in_battle FROM ship WHERE ship.location = "value" )
extra gold: SELECT name , RESULT , bulgarian_commander FROM battle EXCEPT SELECT T1.name , T1.result , T1.bulgarian_commander FROM battle AS T1 JOIN ship AS T2 ON T1.id = T2.lost_in_battle WHERE T2.location = 'English Channel'
medium pred: SELECT Addresses.line_1 FROM Addresses WHERE Addresses.line_2 BETWEEN "value" AND "value"
medium gold: SELECT line_1 , line_2 FROM addresses
easy pred: SELECT COUNT(Degree_Programs.degree_program_id) FROM Degree_Programs
easy gold: SELECT count(DISTINCT degree_summary_name) FROM Degree_Programs
medium pred: SELECT Courses.course_name , Courses.course_id FROM Courses JOIN Sections GROUP BY Sections.course_id HAVING COUNT(*) < "value"
medium gold: SELECT T1.course_name , T1.course_id FROM Courses AS T1 JOIN Sections AS T2 ON T1.course_id = T2.course_id GROUP BY T1.course_id HAVING count(*) <= 2
medium pred: SELECT Students.first_name , Students.middle_name , Students.last_name , COUNT(*) FROM Student_Enrolment JOIN Students GROUP BY Student_Enrolment.semester_id HAVING COUNT(*) = "value"
medium gold: SELECT T1.first_name , T1.middle_name , T1.last_name , T1.student_id FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id GROUP BY T1.student_id HAVING count(*) = 2
medium pred: SELECT Students.first_name , Students.middle_name , Students.last_name FROM Student_Enrolment JOIN Students GROUP BY Student_Enrolment.student_id HAVING COUNT(*) = "value"
medium gold: SELECT T1.first_name , T1.middle_name , T1.last_name , T1.student_id FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id GROUP BY T1.student_id HAVING count(*) = 2
hard pred: SELECT Students.first_name , Students.middle_name , Students.last_name FROM Student_Enrolment JOIN Students WHERE Student_Enrolment.degree_program_id = "value"
hard gold: SELECT DISTINCT T1.first_name , T1.middle_name , T1.last_name FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id JOIN Degree_Programs AS T3 ON T2.degree_program_id = T3.degree_program_id WHERE T3.degree_summary_name = 'Bachelor'
hard pred: SELECT Students.first_name , Students.middle_name , Students.last_name FROM Student_Enrolment JOIN Students WHERE Student_Enrolment.degree_program_id = "value"
hard gold: SELECT DISTINCT T1.first_name , T1.middle_name , T1.last_name FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id JOIN Degree_Programs AS T3 ON T2.degree_program_id = T3.degree_program_id WHERE T3.degree_summary_name = 'Bachelor'
extra pred: SELECT Student_Enrolment.degree_program_id FROM Student_Enrolment GROUP BY Student_Enrolment.degree_program_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.degree_summary_name FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id GROUP BY T1.degree_summary_name ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Student_Enrolment.student_id , Students.first_name , Students.middle_name , Students.last_name FROM Student_Enrolment JOIN Students GROUP BY Student_Enrolment.student_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.student_id , T1.first_name , T1.middle_name , T1.last_name , count(*) , T1.student_id FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Students.first_name , Students.middle_name , COUNT(*) FROM Student_Enrolment JOIN Students GROUP BY Student_Enrolment.student_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.student_id , T1.first_name , T1.middle_name , T1.last_name , count(*) , T1.student_id FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Courses.course_name FROM Courses JOIN Student_Enrolment_Courses GROUP BY Student_Enrolment_Courses.course_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.course_name FROM Courses AS T1 JOIN Student_Enrolment_Courses AS T2 ON T1.course_id = T2.course_id GROUP BY T1.course_name ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Courses.course_name FROM Courses JOIN Student_Enrolment_Courses GROUP BY Student_Enrolment_Courses.course_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.course_name FROM Courses AS T1 JOIN Student_Enrolment_Courses AS T2 ON T1.course_id = T2.course_id GROUP BY T1.course_name ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Addresses.address_id , Addresses.line_1 FROM Addresses JOIN Students GROUP BY Addresses.address_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.address_id , T1.line_1 , T1.line_2 FROM Addresses AS T1 JOIN Students AS T2 ON T1.address_id = T2.current_address_id GROUP BY T1.address_id ORDER BY count(*) DESC LIMIT 1
medium pred: SELECT MAX(Transcripts.transcript_date) FROM Transcripts
medium gold: SELECT transcript_date FROM Transcripts ORDER BY transcript_date DESC LIMIT 1
hard pred: SELECT Transcript_Contents.transcript_id , COUNT(*) FROM Transcript_Contents JOIN Student_Enrolment_Courses GROUP BY Student_Enrolment_Courses.student_enrolment_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT count(*) , student_course_id FROM Transcript_Contents GROUP BY student_course_id ORDER BY count(*) DESC LIMIT 1
hard pred: SELECT Student_Enrolment_Courses.course_id , COUNT(Transcript_Contents.transcript_id) FROM Student_Enrolment_Courses JOIN Transcript_Contents GROUP BY Student_Enrolment_Courses.course_id ORDER BY COUNT(Transcript_Contents.transcript_id) DESC LIMIT 1
hard gold: SELECT count(*) , student_course_id FROM Transcript_Contents GROUP BY student_course_id ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Semesters.semester_name FROM Student_Enrolment JOIN Semesters JOIN Degree_Programs WHERE Degree_Programs.degree_summary_name = "value" INTERSECT SELECT Semesters.semester_name FROM Student_Enrolment JOIN Semesters JOIN Semesters WHERE Student_Enrolment.student_id = "value"
extra gold: SELECT DISTINCT T2.semester_id FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id WHERE degree_summary_name = 'Master' INTERSECT SELECT DISTINCT T2.semester_id FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id WHERE degree_summary_name = 'Bachelor'
extra pred: SELECT Student_Enrolment.semester_id FROM Student_Enrolment WHERE Student_Enrolment.degree_program_id = "value" INTERSECT SELECT Student_Enrolment.semester_id FROM Student_Enrolment WHERE Student_Enrolment.degree_program_id = "value"
extra gold: SELECT DISTINCT T2.semester_id FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id WHERE degree_summary_name = 'Master' INTERSECT SELECT DISTINCT T2.semester_id FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id WHERE degree_summary_name = 'Bachelor'
easy pred: SELECT Students.current_address_id FROM Students
easy gold: SELECT count(DISTINCT current_address_id) FROM Students
medium pred: SELECT COUNT(TV_Channel.series_name) , TV_Channel.Content FROM TV_Channel
medium gold: SELECT count(DISTINCT series_name) , count(DISTINCT content) FROM TV_Channel;
medium pred: SELECT TV_Channel.series_name FROM TV_series JOIN TV_Channel JOIN Cartoon WHERE Cartoon.Title = "value"
medium gold: SELECT T1.series_name FROM TV_Channel AS T1 JOIN Cartoon AS T2 ON T1.id = T2.Channel WHERE T2.Title = "The Rise of the Blue Beetle!";
medium pred: SELECT TV_Channel.series_name FROM TV_series JOIN TV_Channel JOIN Cartoon WHERE Cartoon.Title = "value"
medium gold: SELECT T1.series_name FROM TV_Channel AS T1 JOIN Cartoon AS T2 ON T1.id = T2.Channel WHERE T2.Title = "The Rise of the Blue Beetle!";
medium pred: SELECT Cartoon.Title FROM Cartoon JOIN TV_Channel JOIN TV_series WHERE TV_Channel.series_name = "value"
medium gold: SELECT T2.Title FROM TV_Channel AS T1 JOIN Cartoon AS T2 ON T1.id = T2.Channel WHERE T1.series_name = "Sky Radio";
easy pred: SELECT TV_Channel.id FROM TV_Channel GROUP BY TV_Channel.id HAVING COUNT(*) > "value"
easy gold: SELECT id FROM tv_channel GROUP BY country HAVING count(*) > 2
hard pred: SELECT TV_Channel.Package_Option FROM TV_Channel EXCEPT SELECT TV_Channel.Package_Option FROM TV_Channel JOIN Cartoon WHERE Cartoon.Directed_by = "value"
hard gold: SELECT package_option FROM TV_Channel WHERE id NOT IN (SELECT channel FROM cartoon WHERE directed_by = 'Ben Jones')
hard pred: SELECT TV_Channel.Package_Option FROM TV_Channel EXCEPT SELECT TV_Channel.Package_Option FROM TV_Channel JOIN Cartoon WHERE Cartoon.Directed_by = "value"
hard gold: SELECT package_option FROM TV_Channel WHERE id NOT IN (SELECT channel FROM cartoon WHERE directed_by = 'Ben Jones')
medium pred: SELECT people.Name FROM poker_player JOIN people GROUP BY people.Name ORDER BY COUNT(*) ASC
medium gold: SELECT T1.Name FROM people AS T1 JOIN poker_player AS T2 ON T1.People_ID = T2.People_ID ORDER BY T2.Final_Table_Made
extra pred: SELECT CONTESTANTS.contestant_number , CONTESTANTS.contestant_name FROM CONTESTANTS JOIN VOTES ORDER BY VOTES.contestant_number ASC LIMIT 1
extra gold: SELECT T1.contestant_number , T1.contestant_name FROM contestants AS T1 JOIN votes AS T2 ON T1.contestant_number = T2.contestant_number GROUP BY T1.contestant_number ORDER BY count(*) ASC LIMIT 1
extra pred: SELECT AREA_CODE_STATE.area_code FROM VOTES JOIN CONTESTANTS WHERE CONTESTANTS.contestant_name = "value" INTERSECT SELECT AREA_CODE_STATE.area_code FROM VOTES JOIN CONTESTANTS WHERE CONTESTANTS.contestant_name = "value"
extra gold: SELECT T3.area_code FROM contestants AS T1 JOIN votes AS T2 ON T1.contestant_number = T2.contestant_number JOIN area_code_state AS T3 ON T2.state = T3.state WHERE T1.contestant_name = 'Tabatha Gehling' INTERSECT SELECT T3.area_code FROM contestants AS T1 JOIN votes AS T2 ON T1.contestant_number = T2.contestant_number JOIN area_code_state AS T3 ON T2.state = T3.state WHERE T1.contestant_name = 'Kelly Clauss'
easy pred: SELECT SUM(country.SurfaceArea) FROM country WHERE country.Continent = "value"
easy gold: SELECT sum(SurfaceArea) FROM country WHERE Region = "Caribbean"
easy pred: SELECT country.Continent FROM country WHERE country.LocalName = "value"
easy gold: SELECT Continent FROM country WHERE Name = "Anguilla"
extra pred: SELECT countrylanguage.Language FROM countrylanguage JOIN country WHERE country.Name = "value" GROUP BY countrylanguage.Language ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Aruba" ORDER BY Percentage DESC LIMIT 1
extra pred: SELECT countrylanguage.Language FROM countrylanguage JOIN country WHERE country.Name = "value" GROUP BY countrylanguage.Language ORDER BY COUNT(country.Code) DESC LIMIT 1
extra gold: SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Aruba" ORDER BY Percentage DESC LIMIT 1
medium pred: SELECT country.Population , country.GNP FROM country WHERE country.Continent = "value" ORDER BY country.GNP DESC LIMIT 1
medium gold: SELECT sum(Population) , max(GNP) FROM country WHERE Continent = "Asia"
medium pred: SELECT AVG(country.LifeExpectancy) FROM country WHERE country.Continent = "value" AND country.Name = "value"
medium gold: SELECT avg(LifeExpectancy) FROM country WHERE Continent = "Africa" AND GovernmentForm = "Republic"
medium pred: SELECT AVG(country.LifeExpectancy) FROM country WHERE country.Continent = "value" AND country.Continent = "value"
medium gold: SELECT avg(LifeExpectancy) FROM country WHERE Continent = "Africa" AND GovernmentForm = "Republic"
easy pred: SELECT city.Population FROM city WHERE city.District = "value"
easy gold: SELECT sum(Population) FROM city WHERE District = "Gelderland"
medium pred: SELECT AVG(country.GNP) , AVG(country.Population) FROM country WHERE country.GovernmentForm = "value"
medium gold: SELECT avg(GNP) , sum(population) FROM country WHERE GovernmentForm = "US Territory"
medium pred: SELECT AVG(country.GNP) , AVG(country.Population) FROM country WHERE country.LocalName = "value"
medium gold: SELECT avg(GNP) , sum(population) FROM country WHERE GovernmentForm = "US Territory"
medium pred: SELECT COUNT(countrylanguage.Language) FROM countrylanguage JOIN country WHERE country.Name = "value"
medium gold: SELECT COUNT(*) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Afghanistan" AND IsOfficial = "T"
medium pred: SELECT COUNT(countrylanguage.Language) FROM countrylanguage JOIN country WHERE country.Name = "value"
medium gold: SELECT COUNT(*) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Afghanistan" AND IsOfficial = "T"
extra pred: SELECT country.Name FROM countrylanguage JOIN country GROUP BY countrylanguage.CountryCode ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode GROUP BY T1.Name ORDER BY COUNT(*) DESC LIMIT 1
extra pred: SELECT country.Continent FROM country GROUP BY country.Continent ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.Continent FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode GROUP BY T1.Continent ORDER BY COUNT(*) DESC LIMIT 1
easy pred: SELECT COUNT(*) FROM countrylanguage WHERE countrylanguage.Language = "value" AND countrylanguage.Language = "value"
easy gold: SELECT COUNT(*) FROM (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" INTERSECT SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "Dutch")
easy pred: SELECT COUNT(*) FROM countrylanguage WHERE countrylanguage.Language = "value" AND countrylanguage.Language = "value"
easy gold: SELECT COUNT(*) FROM (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" INTERSECT SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "Dutch")
extra pred: SELECT country.Name FROM country JOIN countrylanguage WHERE countrylanguage.Language = "value" INTERSECT SELECT country.Name FROM country JOIN countrylanguage WHERE countrylanguage.Language = "value"
extra gold: SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND T2.IsOfficial = "T" INTERSECT SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "French" AND T2.IsOfficial = "T"
extra pred: SELECT country.Name FROM country JOIN countrylanguage WHERE countrylanguage.Language = "value" INTERSECT SELECT country.Name FROM country JOIN countrylanguage WHERE countrylanguage.Language = "value"
extra gold: SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND T2.IsOfficial = "T" INTERSECT SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "French" AND T2.IsOfficial = "T"
extra pred: SELECT country.Name FROM countrylanguage JOIN country WHERE countrylanguage.Language = "value" OR countrylanguage.Language = "value"
extra gold: select t1.name from country as t1 join countrylanguage as t2 on t1.code = t2.countrycode where t2.language = "english" and isofficial = "t" union select t1.name from country as t1 join countrylanguage as t2 on t1.code = t2.countrycode where t2.language = "dutch" and isofficial = "t"
extra pred: SELECT country.Name FROM countrylanguage JOIN country WHERE countrylanguage.Language = "value" OR countrylanguage.Language = "value"
extra gold: SELECT * FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND IsOfficial = "T" UNION SELECT * FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "Dutch" AND IsOfficial = "T"
extra pred: SELECT city.Name FROM city WHERE city.Population = "value" ORDER BY city.Population DESC LIMIT 1
extra gold: SELECT T1.Name , T1.Population FROM city AS T1 JOIN countrylanguage AS T2 ON T1.CountryCode = T2.CountryCode WHERE T2.Language = "English" ORDER BY T1.Population DESC LIMIT 1
extra pred: SELECT city.Name FROM city JOIN countrylanguage WHERE countrylanguage.Language = "value" ORDER BY city.Population DESC LIMIT 1
extra gold: SELECT T1.Name , T1.Population FROM city AS T1 JOIN countrylanguage AS T2 ON T1.CountryCode = T2.CountryCode WHERE T2.Language = "English" ORDER BY T1.Population DESC LIMIT 1
hard pred: SELECT country.Name , country.Population , country.LifeExpectancy FROM country WHERE country.SurfaceArea = "value" ORDER BY country.Continent DESC LIMIT 1
hard gold: SELECT Name , Population , LifeExpectancy FROM country WHERE Continent = "Asia" ORDER BY SurfaceArea DESC LIMIT 1
extra pred: SELECT AVG(country.LifeExpectancy) FROM country JOIN countrylanguage WHERE countrylanguage.IsOfficial = "value" AND countrylanguage.Language != "value"
extra gold: SELECT avg(LifeExpectancy) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND T2.IsOfficial = "T")
extra pred: SELECT AVG(country.LifeExpectancy) FROM country JOIN countrylanguage WHERE countrylanguage.IsOfficial = "value" AND countrylanguage.Language != "value"
extra gold: SELECT avg(LifeExpectancy) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND T2.IsOfficial = "T")
extra pred: SELECT SUM(country.Population) FROM country JOIN countrylanguage WHERE countrylanguage.Language != "value"
extra gold: SELECT sum(Population) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English")
extra pred: SELECT country.Population FROM country JOIN countrylanguage WHERE countrylanguage.Language != "value"
extra gold: SELECT sum(Population) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English")
medium pred: SELECT countrylanguage.Language FROM countrylanguage JOIN country WHERE country.HeadOfState = "value"
medium gold: SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.HeadOfState = "Beatrix" AND T2.IsOfficial = "T"
medium pred: SELECT countrylanguage.Language FROM countrylanguage JOIN country WHERE country.HeadOfState = "value"
medium gold: SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.HeadOfState = "Beatrix" AND T2.IsOfficial = "T"
medium pred: SELECT COUNT(countrylanguage.Language) FROM countrylanguage JOIN country WHERE country.IndepYear < "value"
medium gold: SELECT count(DISTINCT T2.Language) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE IndepYear < 1930 AND T2.IsOfficial = "T"
medium pred: SELECT COUNT(countrylanguage.Language) FROM countrylanguage JOIN country WHERE country.IndepYear < "value"
medium gold: SELECT count(DISTINCT T2.Language) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE IndepYear < 1930 AND T2.IsOfficial = "T"
hard pred: SELECT country.Name FROM country WHERE country.SurfaceArea > ( SELECT MAX(country.SurfaceArea) FROM country WHERE country.Continent = "value" )
hard gold: SELECT Name FROM country WHERE SurfaceArea > (SELECT min(SurfaceArea) FROM country WHERE Continent = "Europe")
hard pred: SELECT country.Name FROM country WHERE country.SurfaceArea > ( SELECT MAX(country.SurfaceArea) FROM country WHERE country.Continent = "value" )
hard gold: SELECT Name FROM country WHERE SurfaceArea > (SELECT min(SurfaceArea) FROM country WHERE Continent = "Europe")
extra pred: SELECT country.Name FROM country WHERE country.Population < ( SELECT MAX(country.Population) FROM country WHERE country.Continent = "value" )
extra gold: SELECT Name FROM country WHERE Continent = "Africa" AND population < (SELECT max(population) FROM country WHERE Continent = "Asia")
extra pred: SELECT country.Name FROM country WHERE country.Population < ( SELECT MAX(country.Population) FROM country WHERE country.Continent = "value" )
extra gold: SELECT Name FROM country WHERE Continent = "Africa" AND population < (SELECT min(population) FROM country WHERE Continent = "Asia")
extra pred: SELECT country.Name FROM country WHERE country.Population > ( SELECT MAX(country.Population) FROM country WHERE country.Continent = "value" )
extra gold: SELECT Name FROM country WHERE Continent = "Asia" AND population > (SELECT max(population) FROM country WHERE Continent = "Africa")
extra pred: SELECT country.Name FROM country WHERE country.Population > ( SELECT MAX(country.Population) FROM country WHERE country.Continent = "value" )
extra gold: SELECT Name FROM country WHERE Continent = "Asia" AND population > (SELECT min(population) FROM country WHERE Continent = "Africa")
hard pred: SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language != "value"
hard gold: SELECT CountryCode FROM countrylanguage EXCEPT SELECT CountryCode FROM countrylanguage WHERE LANGUAGE = "English"
hard pred: SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language != "value"
hard gold: SELECT CountryCode FROM countrylanguage EXCEPT SELECT CountryCode FROM countrylanguage WHERE LANGUAGE = "English"
hard pred: SELECT country.Code FROM country WHERE country.GovernmentForm = "value" AND country.GovernmentForm != "value"
hard gold: SELECT Code FROM country WHERE GovernmentForm != "Republic" EXCEPT SELECT CountryCode FROM countrylanguage WHERE LANGUAGE = "English"
hard pred: SELECT country.Code FROM country WHERE country.GovernmentForm = "value" EXCEPT SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language != "value"
hard gold: SELECT Code FROM country WHERE GovernmentForm != "Republic" EXCEPT SELECT CountryCode FROM countrylanguage WHERE LANGUAGE = "English"
extra pred: SELECT city.Name FROM city WHERE city.CountryCode = "value" AND city.CountryCode != "value"
extra gold: SELECT DISTINCT T2.Name FROM country AS T1 JOIN city AS T2 ON T2.CountryCode = T1.Code WHERE T1.Continent = 'Europe' AND T1.Name NOT IN (SELECT T3.Name FROM country AS T3 JOIN countrylanguage AS T4 ON T3.Code = T4.CountryCode WHERE T4.IsOfficial = 'T' AND T4.Language = 'English')
extra pred: SELECT city.Name FROM city JOIN country WHERE countrylanguage.IsOfficial = "value" AND country.Continent != "value"
extra gold: SELECT DISTINCT T2.Name FROM country AS T1 JOIN city AS T2 ON T2.CountryCode = T1.Code WHERE T1.Continent = 'Europe' AND T1.Name NOT IN (SELECT T3.Name FROM country AS T3 JOIN countrylanguage AS T4 ON T3.Code = T4.CountryCode WHERE T4.IsOfficial = 'T' AND T4.Language = 'English')
hard pred: SELECT city.Name FROM city JOIN country JOIN countrylanguage WHERE countrylanguage.IsOfficial = "value" AND countrylanguage.Language = "value"
hard gold: select distinct t3.name from country as t1 join countrylanguage as t2 on t1.code = t2.countrycode join city as t3 on t1.code = t3.countrycode where t2.isofficial = 't' and t2.language = 'chinese' and t1.continent = "asia"
hard pred: SELECT city.Name FROM city JOIN country JOIN country JOIN countrylanguage WHERE country.Continent = "value" AND countrylanguage.Language = "value"
hard gold: SELECT DISTINCT T3.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode JOIN city AS T3 ON T1.Code = T3.CountryCode WHERE T2.IsOfficial = 'T' AND T2.Language = 'Chinese' AND T1.Continent = "Asia"
medium pred: SELECT country.Population , country.Name , country.SurfaceArea FROM country ORDER BY country.SurfaceArea DESC LIMIT 1
medium gold: SELECT Name , population , HeadOfState FROM country ORDER BY SurfaceArea DESC LIMIT 1
medium pred: SELECT country.Name , COUNT(*) FROM countrylanguage JOIN country GROUP BY country.Name HAVING COUNT(*) >= "value"
medium gold: SELECT COUNT(T2.Language) , T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode GROUP BY T1.Name HAVING COUNT(*) > 2
medium pred: SELECT country.Name , COUNT(*) FROM countrylanguage JOIN country GROUP BY country.Name HAVING COUNT(*) > "value"
medium gold: SELECT COUNT(T2.Language) , T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode GROUP BY T1.Name HAVING COUNT(*) > 2
medium pred: SELECT country.GovernmentForm , AVG(country.Population) FROM country GROUP BY country.GovernmentForm HAVING AVG(country.LifeExpectancy) > "value"
medium gold: SELECT sum(Population) , GovernmentForm FROM country GROUP BY GovernmentForm HAVING avg(LifeExpectancy) > 72
medium pred: SELECT country.GovernmentForm , SUM(country.Population) FROM country WHERE country.LifeExpectancy > "value" GROUP BY country.GovernmentForm
medium gold: SELECT sum(Population) , GovernmentForm FROM country GROUP BY GovernmentForm HAVING avg(LifeExpectancy) > 72
medium pred: SELECT AVG(country.LifeExpectancy) , AVG(country.Population) , country.Continent FROM country GROUP BY country.Continent HAVING AVG(country.Population) < "value"
medium gold: SELECT sum(Population) , avg(LifeExpectancy) , Continent FROM country GROUP BY Continent HAVING avg(LifeExpectancy) < 72
medium pred: SELECT country.Continent , SUM(country.Population) , SUM(country.LifeExpectancy) FROM country WHERE country.LifeExpectancy < "value" GROUP BY country.Continent
medium gold: SELECT sum(Population) , avg(LifeExpectancy) , Continent FROM country GROUP BY Continent HAVING avg(LifeExpectancy) < 72
medium pred: SELECT country.Name FROM country WHERE country.Continent = "value" AND country.Population > "value"
medium gold: SELECT Name FROM country WHERE continent = "Europe" AND Population = "80000"
medium pred: SELECT country.Name FROM country WHERE country.Continent = "value" AND country.Population > "value"
medium gold: SELECT Name FROM country WHERE continent = "Europe" AND Population = "80000"
hard pred: SELECT SUM(country.Population) , AVG(country.SurfaceArea) FROM country WHERE country.Continent = "value"
hard gold: select sum(population) , avg(surfacearea) from country where continent = "north america" and surfacearea > 3000
hard pred: SELECT SUM(country.Population) , AVG(country.SurfaceArea) FROM country WHERE country.Continent = "value"
hard gold: select sum(population) , avg(surfacearea) from country where continent = "north america" and surfacearea > 3000
medium pred: SELECT countrylanguage.Language , MAX(countrylanguage.Percentage) FROM countrylanguage GROUP BY countrylanguage.Language ORDER BY SUM(countrylanguage.Percentage) DESC LIMIT 1
medium gold: SELECT LANGUAGE , CountryCode , max(Percentage) FROM countrylanguage GROUP BY CountryCode
medium pred: SELECT countrylanguage.CountryCode , SUM(countrylanguage.Language) FROM countrylanguage GROUP BY countrylanguage.CountryCode ORDER BY SUM(countrylanguage.Percentage) DESC LIMIT 1
medium gold: SELECT LANGUAGE , CountryCode , max(Percentage) FROM countrylanguage GROUP BY CountryCode
extra pred: SELECT COUNT(*) FROM countrylanguage WHERE countrylanguage.Language = "value" ORDER BY countrylanguage.Percentage DESC LIMIT 1
extra gold: SELECT count(*) , max(Percentage) FROM countrylanguage WHERE LANGUAGE = "Spanish" GROUP BY CountryCode
extra pred: SELECT COUNT(*) FROM countrylanguage WHERE countrylanguage.Language = "value" GROUP BY * ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT count(*) , max(Percentage) FROM countrylanguage WHERE LANGUAGE = "Spanish" GROUP BY CountryCode
medium pred: SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language = "value" ORDER BY countrylanguage.Percentage DESC LIMIT 1
medium gold: SELECT CountryCode , max(Percentage) FROM countrylanguage WHERE LANGUAGE = "Spanish" GROUP BY CountryCode
medium pred: SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language = "value" GROUP BY countrylanguage.CountryCode ORDER BY AVG(countrylanguage.Percentage) DESC LIMIT 1
medium gold: SELECT CountryCode , max(Percentage) FROM countrylanguage WHERE LANGUAGE = "Spanish" GROUP BY CountryCode
medium pred: SELECT COUNT(*) FROM Friend
medium gold: SELECT student_id , count(*) FROM Friend GROUP BY student_id
medium pred: SELECT Highschooler.name , COUNT(*) FROM Highschooler JOIN Friend GROUP BY Highschooler.name
medium gold: SELECT T2.name , count(*) FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id GROUP BY T1.student_id
hard pred: SELECT Highschooler.name FROM Highschooler JOIN Friend WHERE Highschooler.name = "value"
hard gold: SELECT T3.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id JOIN Highschooler AS T3 ON T1.friend_id = T3.id WHERE T2.name = "Kyle"
hard pred: SELECT Friend.friend_id FROM Highschooler JOIN Friend WHERE Highschooler.name = "value"
hard gold: SELECT T3.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id JOIN Highschooler AS T3 ON T1.friend_id = T3.id WHERE T2.name = "Kyle"
medium pred: SELECT COUNT(Friend.friend_id) FROM Highschooler JOIN Friend WHERE Highschooler.name = "value"
medium gold: SELECT count(*) FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id WHERE T2.name = "Kyle"
hard pred: SELECT Highschooler.ID FROM Highschooler INTERSECT SELECT Likes.student_id FROM Likes
hard gold: SELECT student_id FROM Friend INTERSECT SELECT liked_id FROM Likes
hard pred: SELECT Likes.student_id FROM Likes INTERSECT SELECT Likes.liked_id FROM Likes
hard gold: SELECT student_id FROM Friend INTERSECT SELECT liked_id FROM Likes
hard pred: SELECT Highschooler.name FROM Highschooler JOIN Likes INTERSECT SELECT Highschooler.name FROM Highschooler JOIN Likes
hard gold: SELECT T2.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id INTERSECT SELECT T2.name FROM Likes AS T1 JOIN Highschooler AS T2 ON T1.liked_id = T2.id
medium pred: SELECT COUNT(*) FROM Likes GROUP BY Likes.student_id
medium gold: SELECT student_id , count(*) FROM Likes GROUP BY student_id
extra pred: SELECT Highschooler.name FROM Highschooler JOIN Likes ORDER BY Highschooler.name DESC LIMIT 1
extra gold: SELECT T2.name FROM Likes AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Highschooler.name FROM Highschooler JOIN Likes ORDER BY Likes.student_id DESC LIMIT 1
extra gold: SELECT T2.name FROM Likes AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1
hard pred: SELECT Highschooler.name FROM Highschooler JOIN Friend WHERE Highschooler.grade > "value" GROUP BY Friend.student_id ORDER BY COUNT(*) ASC
hard gold: SELECT T2.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id WHERE T2.grade > 5 GROUP BY T1.student_id HAVING count(*) >= 2
hard pred: SELECT Highschooler.name FROM Highschooler JOIN Friend WHERE Highschooler.grade > "value" GROUP BY Friend.student_id ORDER BY COUNT(*) ASC
hard gold: SELECT T2.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id WHERE T2.grade > 5 GROUP BY T1.student_id HAVING count(*) >= 2
medium pred: SELECT Likes.liked_id FROM Highschooler JOIN Likes WHERE Highschooler.name = "value"
medium gold: SELECT count(*) FROM Likes AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id WHERE T2.name = "Kyle"
hard pred: SELECT AVG(Highschooler.grade) FROM Highschooler WHERE Highschooler.ID IN ( SELECT Friend.student_id FROM Friend )
hard gold: SELECT avg(grade) FROM Highschooler WHERE id IN (SELECT T1.student_id FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id)
hard pred: SELECT AVG(Highschooler.grade) FROM Highschooler WHERE Highschooler.ID IN ( SELECT Friend.student_id FROM Friend )
hard gold: SELECT avg(grade) FROM Highschooler WHERE id IN (SELECT T1.student_id FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id)
extra pred: SELECT MIN(Highschooler.grade) FROM Highschooler WHERE Highschooler.ID NOT IN ( SELECT Friend.student_id FROM Friend )
extra gold: SELECT min(grade) FROM Highschooler WHERE id NOT IN (SELECT T1.student_id FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id)
extra pred: SELECT MIN(Highschooler.grade) FROM Highschooler WHERE Highschooler.ID NOT IN ( SELECT Friend.student_id FROM Friend )
extra gold: SELECT min(grade) FROM Highschooler WHERE id NOT IN (SELECT T1.student_id FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id)
hard pred: SELECT AVG(Dogs.age) FROM Dogs JOIN Treatments
hard gold: SELECT avg(age) FROM Dogs WHERE dog_id IN ( SELECT dog_id FROM Treatments )
hard pred: SELECT AVG(Dogs.age) FROM Treatments JOIN Dogs
hard gold: SELECT avg(age) FROM Dogs WHERE dog_id IN ( SELECT dog_id FROM Treatments )
extra pred: SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments WHERE Professionals.state = "value" UNION SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments GROUP BY Treatments.professional_id HAVING COUNT(*) > "value"
extra gold: SELECT professional_id , last_name , cell_number FROM Professionals WHERE state = 'Indiana' UNION SELECT T1.professional_id , T1.last_name , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) > 2
extra pred: SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments WHERE Professionals.state = "value" UNION SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments GROUP BY Professionals.state HAVING COUNT(*) > "value"
extra gold: SELECT professional_id , last_name , cell_number FROM Professionals WHERE state = 'Indiana' UNION SELECT T1.professional_id , T1.last_name , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) > 2
hard pred: SELECT Dogs.name FROM Dogs EXCEPT SELECT Dogs.name FROM Dogs JOIN Treatments WHERE Treatments.cost_of_treatment > "value"
hard gold: select name from dogs where dog_id not in ( select dog_id from treatments group by dog_id having sum(cost_of_treatment) > 1000 )
hard pred: SELECT Dogs.name FROM Dogs EXCEPT SELECT Dogs.name FROM Dogs JOIN Treatments WHERE Treatments.cost_of_treatment > "value"
hard gold: select name from dogs where dog_id not in ( select dog_id from treatments group by dog_id having sum(cost_of_treatment) > 1000 )
hard pred: SELECT Professionals.first_name FROM Professionals JOIN Owners EXCEPT SELECT Owners.first_name FROM Owners
hard gold: SELECT first_name FROM Professionals UNION SELECT first_name FROM Owners EXCEPT SELECT name FROM Dogs
hard pred: SELECT Professionals.first_name FROM Professionals JOIN Owners EXCEPT SELECT Owners.first_name FROM Owners
hard gold: SELECT first_name FROM Professionals UNION SELECT first_name FROM Owners EXCEPT SELECT name FROM Dogs
extra pred: SELECT Breeds.breed_name FROM Dogs JOIN Breeds GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.breed_name FROM Breeds AS T1 JOIN Dogs AS T2 ON T1.breed_code = T2.breed_code GROUP BY T1.breed_name ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Breeds.breed_name FROM Dogs JOIN Breeds GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.breed_name FROM Breeds AS T1 JOIN Dogs AS T2 ON T1.breed_code = T2.breed_code GROUP BY T1.breed_name ORDER BY count(*) DESC LIMIT 1
extra pred: SELECT Owners.owner_id , Owners.zip_code FROM Owners JOIN Dogs JOIN Charges GROUP BY Owners.owner_id ORDER BY SUM(Charges.charge_amount) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T1.zip_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id JOIN Treatments AS T3 ON T2.dog_id = T3.dog_id GROUP BY T1.owner_id ORDER BY sum(T3.cost_of_treatment) DESC LIMIT 1
extra pred: SELECT Owners.owner_id , Owners.zip_code FROM Owners JOIN Dogs GROUP BY Owners.owner_id ORDER BY SUM(Dogs.date_adopted) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T1.zip_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id JOIN Treatments AS T3 ON T2.dog_id = T3.dog_id GROUP BY T1.owner_id ORDER BY sum(T3.cost_of_treatment) DESC LIMIT 1
extra pred: SELECT Dogs.name , Treatments.date_of_treatment FROM Treatments JOIN Dogs GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.name , T2.date_of_treatment FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id WHERE T1.breed_code = ( SELECT breed_code FROM Dogs GROUP BY breed_code ORDER BY count(*) ASC LIMIT 1 )
extra pred: SELECT Dogs.name , Treatments.date_of_treatment FROM Dogs JOIN Treatments GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.name , T2.date_of_treatment FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id WHERE T1.breed_code = ( SELECT breed_code FROM Dogs GROUP BY breed_code ORDER BY count(*) ASC LIMIT 1 )
extra pred: SELECT Owners.last_name FROM Owners JOIN Dogs ORDER BY Dogs.age ASC LIMIT 1
extra gold: SELECT T1.last_name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id WHERE T2.age = ( SELECT max(age) FROM Dogs )
extra pred: SELECT Owners.last_name FROM Owners JOIN Dogs ORDER BY Dogs.age ASC LIMIT 1
extra gold: SELECT T1.last_name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id WHERE T2.age = ( SELECT max(age) FROM Dogs )
medium pred: SELECT Charges.charge_type , SUM(Charges.charge_amount) FROM Charges GROUP BY Charges.charge_type
medium gold: SELECT charge_type , charge_amount FROM Charges
medium pred: SELECT Charges.charge_type , SUM(Charges.charge_amount) FROM Charges GROUP BY Charges.charge_type
medium gold: SELECT charge_type , charge_amount FROM Charges
easy pred: SELECT Charges.charge_type FROM Charges ORDER BY Charges.charge_amount DESC LIMIT 1
easy gold: SELECT max(charge_amount) FROM Charges
easy pred: SELECT Charges.charge_amount FROM Charges ORDER BY Charges.charge_type DESC LIMIT 1
easy gold: SELECT max(charge_amount) FROM Charges
medium pred: SELECT Sizes.size_code , Sizes.size_code FROM Sizes
medium gold: SELECT DISTINCT breed_code , size_code FROM dogs
medium pred: SELECT singer.Name FROM singer JOIN song GROUP BY song.Singer_ID HAVING COUNT(*) > "value"
medium gold: SELECT T1.Name FROM singer AS T1 JOIN song AS T2 ON T1.Singer_ID = T2.Singer_ID GROUP BY T1.Name HAVING COUNT(*) > 1
medium pred: SELECT Ref_Property_Types.property_type_description FROM Ref_Property_Types JOIN Properties ORDER BY Properties.property_type_code DESC LIMIT 1
medium gold: SELECT T2.property_type_description FROM Properties AS T1 JOIN Ref_Property_Types AS T2 ON T1.property_type_code = T2.property_type_code GROUP BY T1.property_type_code
hard pred: SELECT Properties.property_name FROM Properties WHERE Properties.property_type_code = "value" UNION SELECT Properties.property_name FROM Properties WHERE Properties.room_count > "value"
hard gold: SELECT property_name FROM Properties WHERE property_type_code = "House" UNION SELECT property_name FROM Properties WHERE property_type_code = "Apartment" AND room_count > 1
easy medium hard extra all
count 248 446 174 166 1034
====================== EXACT MATCHING ACCURACY =====================
exact match 0.927 0.796 0.684 0.512 0.763
---------------------PARTIAL MATCHING ACCURACY----------------------
select 0.964 0.904 0.977 0.898 0.929
select(no AGG) 0.976 0.933 0.977 0.916 0.948
where 0.936 0.876 0.719 0.640 0.816
where(no OP) 0.955 0.887 0.809 0.697 0.852
group(no Having) 0.857 0.885 0.857 0.810 0.857
group 0.857 0.831 0.810 0.797 0.820
order 0.917 0.817 0.820 0.788 0.817
and/or 1.000 0.986 0.977 0.915 0.977
IUEN 0.000 0.000 0.667 0.500 0.582
keywords 0.961 0.950 0.878 0.819 0.913
---------------------- PARTIAL MATCHING RECALL ----------------------
select 0.964 0.904 0.977 0.898 0.929
select(no AGG) 0.976 0.933 0.977 0.916 0.948
where 0.954 0.896 0.681 0.606 0.810
where(no OP) 0.972 0.907 0.766 0.660 0.845
group(no Having) 0.900 0.865 0.923 0.810 0.860
group 0.900 0.812 0.872 0.797 0.823
order 1.000 0.893 0.909 0.848 0.892
and/or 0.992 0.995 0.994 0.974 0.991
IUEN 0.000 0.000 0.667 0.529 0.605
keywords 0.993 0.950 0.868 0.819 0.916
---------------------- PARTIAL MATCHING F1 --------------------------
select 0.964 0.904 0.977 0.898 0.929
select(no AGG) 0.976 0.933 0.977 0.916 0.948
where 0.945 0.886 0.699 0.623 0.813
where(no OP) 0.963 0.897 0.787 0.678 0.849
group(no Having) 0.878 0.875 0.889 0.810 0.858
group 0.878 0.821 0.840 0.797 0.821
order 0.957 0.854 0.862 0.817 0.853
and/or 0.996 0.991 0.985 0.943 0.984
IUEN 1.000 1.000 0.667 0.514 0.594
keywords 0.977 0.950 0.873 0.819 0.914
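The breakdown above follows the report format of the Spider evaluation script that run/run_submission.sh (shown later in this diff) invokes with --etype match, scoring the predicted dev queries against the gold ones listed earlier in this log. A minimal sketch of that call, assuming a prediction file has already been written and using a checkpoint directory name that is only an example:

```bash
# Hypothetical checkpoint name; the paths mirror those used in run/run_submission.sh.
python evaluation.py --gold data/dev_gold.sql --pred saved_models/electra-msde-75.3/predicted_sql.txt \
    --db data/database --table data/tables.json --etype match
```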

File diff suppressed because it is too large

9
proton/run/run_evaluation.sh Executable file
View File

@ -0,0 +1,9 @@
#!/bin/bash
task=evaluation
read_model_path=saved_models/$1
batch_size=20
beam_size=5
device=0
python scripts/text2sql.py --task $task --testing --read_model_path $read_model_path \
--batch_size $batch_size --beam_size $beam_size --device $device
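A usage sketch for the script above; the single argument names a checkpoint sub-directory under saved_models/ (the directory name here is only an example) from which params.json and model.bin are reloaded before decoding the dev set:

```bash
./run/run_evaluation.sh electra-msde-75.3
```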

58
proton/run/run_gan.sh Normal file
View File

@ -0,0 +1,58 @@
task=lgesql_large
seed=999
device=0
testing='' #'--testing'
read_model_path='saved_models/electra-msde-75.3/'
model=lgesql
output_model=with_pruning # without_pruning
local_and_nonlocal=$1 # mmc, msde, local
plm=$2
subword_aggregation=attentive-pooling
schema_aggregation=head+tail
gnn_hidden_size=512
gnn_num_layers=8
relation_share_heads='--relation_share_heads'
score_function='affine'
num_heads=8
dropout=0.2
attn_drop=0.0
drop_connect=0.2
lstm=onlstm
chunk_size=8
att_vec_size=512
sep_cxt=''
lstm_hidden_size=512
lstm_num_layers=1
action_embed_size=128
field_embed_size=64
type_embed_size=64
no_context_feeding='--no_context_feeding'
no_parent_production_embed=''
no_parent_field_embed=''
no_parent_field_type_embed=''
no_parent_state=''
batch_size=20
grad_accumulate=5
lr=1e-4
layerwise_decay=0.8
l2=0.1
smoothing=0.15
warmup_ratio=0.1
lr_schedule=linear
eval_after_epoch=1
max_epoch=200
max_norm=5
beam_size=5
python scripts/text2gan.py --task $task --seed $seed --device $device $testing --read_model_path $read_model_path \
--plm $plm --gnn_hidden_size $gnn_hidden_size --dropout $dropout --attn_drop $attn_drop --att_vec_size $att_vec_size \
--model $model --output_model $output_model --local_and_nonlocal $local_and_nonlocal --score_function $score_function $relation_share_heads \
--subword_aggregation $subword_aggregation --schema_aggregation $schema_aggregation --gnn_num_layers $gnn_num_layers --num_heads $num_heads $sep_cxt \
--lstm $lstm --chunk_size $chunk_size --drop_connect $drop_connect --lstm_hidden_size $lstm_hidden_size --lstm_num_layers $lstm_num_layers \
--action_embed_size $action_embed_size --field_embed_size $field_embed_size --type_embed_size $type_embed_size \
$no_context_feeding $no_parent_production_embed $no_parent_field_embed $no_parent_field_type_embed $no_parent_state \
--batch_size $batch_size --grad_accumulate $grad_accumulate --lr $lr --l2 $l2 --warmup_ratio $warmup_ratio --lr_schedule $lr_schedule --eval_after_epoch $eval_after_epoch \
--smoothing $smoothing --layerwise_decay $layerwise_decay --max_epoch $max_epoch --max_norm $max_norm --beam_size $beam_size
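A usage sketch for run_gan.sh above, assuming the checkpoint referenced by read_model_path has already been trained; $1 selects the relation encoding (mmc, msde or local) and $2 the pre-trained language model:

```bash
./run/run_gan.sh msde electra-large-discriminator
```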

56
proton/run/run_lgesql_glove.sh Executable file
View File

@ -0,0 +1,56 @@
task=lgesql_glove
seed=999
device=0
testing='' #'--testing'
read_model_path=''
model=lgesql
output_model=with_pruning # without_pruning
local_and_nonlocal=$1 # mmc, msde, local
embed_size=300
schema_aggregation=head+tail
gnn_hidden_size=256
gnn_num_layers=8
relation_share_heads='' # '--relation_share_heads'
score_function='affine'
num_heads=8
dropout=0.2
attn_drop=0.0
drop_connect=0.2
lstm=onlstm
chunk_size=8
att_vec_size=512
sep_cxt=''
lstm_hidden_size=512
lstm_num_layers=1
action_embed_size=128
field_embed_size=64
type_embed_size=64
no_context_feeding='--no_context_feeding'
no_parent_production_embed=''
no_parent_field_embed=''
no_parent_field_type_embed=''
no_parent_state=''
batch_size=20
grad_accumulate=2
lr=5e-4
l2=1e-4
smoothing=0.15
warmup_ratio=0.1
lr_schedule=linear
eval_after_epoch=60
max_epoch=100
max_norm=5
beam_size=5
python scripts/text2sql.py --task $task --seed $seed --device $device $testing $read_model_path \
--gnn_hidden_size $gnn_hidden_size --dropout $dropout --attn_drop $attn_drop --att_vec_size $att_vec_size \
--model $model --output_model $output_model --local_and_nonlocal $local_and_nonlocal --score_function $score_function $relation_share_heads \
--schema_aggregation $schema_aggregation --embed_size $embed_size --gnn_num_layers $gnn_num_layers --num_heads $num_heads $sep_cxt \
--lstm $lstm --chunk_size $chunk_size --drop_connect $drop_connect --lstm_hidden_size $lstm_hidden_size --lstm_num_layers $lstm_num_layers \
--action_embed_size $action_embed_size --field_embed_size $field_embed_size --type_embed_size $type_embed_size \
$no_context_feeding $no_parent_production_embed $no_parent_field_embed $no_parent_field_type_embed $no_parent_state \
--batch_size $batch_size --grad_accumulate $grad_accumulate --lr $lr --l2 $l2 --warmup_ratio $warmup_ratio --lr_schedule $lr_schedule --eval_after_epoch $eval_after_epoch \
--smoothing $smoothing --max_epoch $max_epoch --max_norm $max_norm --beam_size $beam_size
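A usage sketch for the GloVe configuration above; the single positional argument picks one of the relation encodings listed in the comment next to local_and_nonlocal:

```bash
./run/run_lgesql_glove.sh msde
```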

58
proton/run/run_lgesql_plm.sh Executable file
View File

@ -0,0 +1,58 @@
task=lgesql_large
seed=999
device=0
testing='' #'--testing'
read_model_path=''
model=lgesql
output_model=with_pruning # without_pruning
local_and_nonlocal=$1 # mmc, msde, local
plm=$2
subword_aggregation=attentive-pooling
schema_aggregation=head+tail
gnn_hidden_size=512
gnn_num_layers=8
relation_share_heads='--relation_share_heads'
score_function='affine'
num_heads=8
dropout=0.2
attn_drop=0.0
drop_connect=0.2
lstm=onlstm
chunk_size=8
att_vec_size=512
sep_cxt=''
lstm_hidden_size=512
lstm_num_layers=1
action_embed_size=128
field_embed_size=64
type_embed_size=64
no_context_feeding='--no_context_feeding'
no_parent_production_embed=''
no_parent_field_embed=''
no_parent_field_type_embed=''
no_parent_state=''
batch_size=20
grad_accumulate=5
lr=1e-4
layerwise_decay=0.8
l2=0.1
smoothing=0.15
warmup_ratio=0.1
lr_schedule=linear
eval_after_epoch=120
max_epoch=200
max_norm=5
beam_size=5
python scripts/text2sql.py --task $task --seed $seed --device $device $testing $read_model_path \
--plm $plm --gnn_hidden_size $gnn_hidden_size --dropout $dropout --attn_drop $attn_drop --att_vec_size $att_vec_size \
--model $model --output_model $output_model --local_and_nonlocal $local_and_nonlocal --score_function $score_function $relation_share_heads \
--subword_aggregation $subword_aggregation --schema_aggregation $schema_aggregation --gnn_num_layers $gnn_num_layers --num_heads $num_heads $sep_cxt \
--lstm $lstm --chunk_size $chunk_size --drop_connect $drop_connect --lstm_hidden_size $lstm_hidden_size --lstm_num_layers $lstm_num_layers \
--action_embed_size $action_embed_size --field_embed_size $field_embed_size --type_embed_size $type_embed_size \
$no_context_feeding $no_parent_production_embed $no_parent_field_embed $no_parent_field_type_embed $no_parent_state \
--batch_size $batch_size --grad_accumulate $grad_accumulate --lr $lr --l2 $l2 --warmup_ratio $warmup_ratio --lr_schedule $lr_schedule --eval_after_epoch $eval_after_epoch \
--smoothing $smoothing --layerwise_decay $layerwise_decay --max_epoch $max_epoch --max_norm $max_norm --beam_size $beam_size
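A usage sketch for the PLM configuration above; $1 again picks the relation encoding and $2 one of the pre-trained models cloned into pretrained_models/ by setup.sh:

```bash
./run/run_lgesql_plm.sh mmc bert-large-uncased-whole-word-masking
```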

20
proton/run/run_preprocessing.sh Executable file
View File

@ -0,0 +1,20 @@
#!/bin/bash
train_data='data/train.json'
dev_data='data/dev.json'
table_data='data/tables.json'
train_out='data/train.lgesql.bin'
dev_out='data/dev.lgesql.bin'
table_out='data/tables.bin'
vocab_glove='pretrained_models/glove.42b.300d/vocab_glove.txt'
vocab='pretrained_models/glove.42b.300d/vocab.txt'
echo "Start to preprocess the original train dataset ..."
python3 -u preprocess/process_dataset.py --dataset_path ${train_data} --raw_table_path ${table_data} --table_path ${table_out} --output_path 'data/train.bin' --skip_large #--verbose > train.log
echo "Start to preprocess the original dev dataset ..."
python3 -u preprocess/process_dataset.py --dataset_path ${dev_data} --raw_table_path ${table_data} --table_path ${table_out} --output_path 'data/dev.bin' #--verbose > dev.log
echo "Start to build word vocab for the dataset ..."
python3 -u preprocess/build_glove_vocab.py --data_paths 'data/train.bin' --table_path ${table_out} --reference_file ${vocab_glove} --mwf 4 --output_path ${vocab}
echo "Start to construct graphs for the dataset ..."
python3 -u preprocess/process_graphs.py --dataset_path 'data/train.bin' --table_path ${table_out} --method 'lgesql' --output_path ${train_out}
python3 -u preprocess/process_graphs.py --dataset_path 'data/dev.bin' --table_path ${table_out} --method 'lgesql' --output_path ${dev_out}
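Assuming both splits are processed as above, a quick sanity check that the serialized tables and graph files written by these commands are in place:

```bash
ls -lh data/tables.bin data/train.lgesql.bin data/dev.lgesql.bin
```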

7
proton/run/run_submission.sh Executable file
View File

@ -0,0 +1,7 @@
saved_model=saved_models/$1
output_path=saved_models/$1/predicted_sql.txt
batch_size=10
beam_size=5
python eval.py --db_dir 'data/database' --table_path 'data/tables.json' --dataset_path 'data/dev.json' --saved_model $saved_model --output_path $output_path --batch_size $batch_size --beam_size $beam_size
python evaluation.py --gold data/dev_gold.sql --pred $output_path --db data/database --table data/tables.json --etype match > $saved_model/evaluation.log
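A usage sketch; $1 again names a checkpoint directory under saved_models/ (example name only), and both the predicted SQL file and the evaluation log are written back into that directory:

```bash
./run/run_submission.sh electra-msde-75.3
cat saved_models/electra-msde-75.3/evaluation.log
```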

124
proton/scripts/text2sql.py Normal file
View File

@ -0,0 +1,124 @@
#coding=utf8
import sys, os, time, json, gc
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from argparse import Namespace
from utils.args import init_args
from utils.hyperparams import hyperparam_path
from utils.initialization import *
from utils.example import Example
from utils.batch import Batch
from utils.optimization import set_optimizer
from model.model_utils import Registrable
from model.model_constructor import *

# initialization params, output path, logger, random seed and torch.device
args = init_args(sys.argv[1:])
exp_path = hyperparam_path(args)
logger = set_logger(exp_path, args.testing)
set_random_seed(args.seed)
device = set_torch_device(args.device)
logger.info("Initialization finished ...")
logger.info("Output path is %s" % (exp_path))
logger.info("Random seed is set to %d" % (args.seed))
logger.info("Use GPU with index %s" % (args.device) if args.device >= 0 else "Use CPU as target torch device")

# load dataset and vocabulary
start_time = time.time()
if args.read_model_path:
    params = json.load(open(os.path.join(args.read_model_path, 'params.json')), object_hook=lambda d: Namespace(**d))
    params.lazy_load = True
else:
    params = args
# set up the grammar, transition system, evaluator, etc.
Example.configuration(plm=params.plm, method=params.model)
train_dataset, dev_dataset = Example.load_dataset('train'), Example.load_dataset('dev')
logger.info("Load dataset and database finished, cost %.4fs ..." % (time.time() - start_time))
logger.info("Dataset size: train -> %d ; dev -> %d" % (len(train_dataset), len(dev_dataset)))
sql_trans, evaluator = Example.trans, Example.evaluator
args.word_vocab, args.relation_num = len(Example.word_vocab), len(Example.relation_vocab)

# model init, set optimizer
model = Registrable.by_name('text2sql')(params, sql_trans).to(device)
if args.read_model_path:
    check_point = torch.load(open(os.path.join(args.read_model_path, 'model.bin'), 'rb'), map_location=device)
    model.load_state_dict(check_point['model'])
    logger.info("Load saved model from path: %s" % (args.read_model_path))
else:
    json.dump(vars(params), open(os.path.join(exp_path, 'params.json'), 'w'), indent=4)
    if params.plm is None:
        ratio = Example.word2vec.load_embeddings(model.encoder.input_layer.word_embed, Example.word_vocab, device=device)
        logger.info("Init model and word embedding layer with a coverage %.2f" % (ratio))
# logger.info(str(model))


def decode(choice, output_path, acc_type='sql', use_checker=False):
    assert acc_type in ['beam', 'ast', 'sql'] and choice in ['train', 'dev']
    model.eval()
    dataset = train_dataset if choice == 'train' else dev_dataset
    all_hyps = []
    with torch.no_grad():
        for i in range(0, len(dataset), args.batch_size):
            current_batch = Batch.from_example_list(dataset[i: i + args.batch_size], device, train=False)
            hyps = model.parse(current_batch, args.beam_size)
            all_hyps.extend(hyps)
    acc = evaluator.acc(all_hyps, dataset, output_path, acc_type=acc_type, etype='match', use_checker=use_checker)
    torch.cuda.empty_cache()
    gc.collect()
    return acc


if not args.testing:
    num_training_steps = ((len(train_dataset) + args.batch_size - 1) // args.batch_size) * args.max_epoch
    num_warmup_steps = int(num_training_steps * args.warmup_ratio)
    logger.info('Total training steps: %d;\t Warmup steps: %d' % (num_training_steps, num_warmup_steps))
    optimizer, scheduler = set_optimizer(model, args, num_warmup_steps, num_training_steps)
    start_epoch, nsamples, best_result = 0, len(train_dataset), {'dev_acc': 0.}
    train_index, step_size = np.arange(nsamples), args.batch_size // args.grad_accumulate
    if args.read_model_path and args.load_optimizer:
        optimizer.load_state_dict(check_point['optim'])
        scheduler.load_state_dict(check_point['scheduler'])
        start_epoch = check_point['epoch'] + 1
    logger.info('Start training ......')
    for i in range(start_epoch, args.max_epoch):
        start_time = time.time()
        epoch_loss, epoch_gp_loss, count = 0, 0, 0
        np.random.shuffle(train_index)
        model.train()
        for j in range(0, nsamples, step_size):
            count += 1
            cur_dataset = [train_dataset[k] for k in train_index[j: j + step_size]]
            current_batch = Batch.from_example_list(cur_dataset, device, train=True, smoothing=args.smoothing)
            loss, gp_loss = model(current_batch) # see utils/batch.py for batch elements
            epoch_loss += loss.item()
            epoch_gp_loss += gp_loss.item()
            loss += gp_loss
            loss.backward()
            if count == args.grad_accumulate or j + step_size >= nsamples:
                count = 0
                model.pad_embedding_grad_zero()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        logger.info('Training: \tEpoch: %d\tTime: %.4f\tTraining loss: %.4f/%.4f' % (i, time.time() - start_time, epoch_loss, epoch_gp_loss))
        torch.cuda.empty_cache()
        gc.collect()
        if i < args.eval_after_epoch: # avoid unnecessary evaluation
            continue
        start_time = time.time()
        dev_acc = decode('dev', os.path.join(exp_path, 'dev.iter' + str(i)), acc_type='sql')
        logger.info('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f' % (i, time.time() - start_time, dev_acc))
        if dev_acc > best_result['dev_acc']:
            best_result['dev_acc'], best_result['iter'] = dev_acc, i
            torch.save({
                'epoch': i, 'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }, open(os.path.join(exp_path, 'model.bin'), 'wb'))
            logger.info('NEW BEST MODEL: \tEpoch: %d\tDev acc: %.4f' % (i, dev_acc))
    logger.info('FINAL BEST RESULT: \tEpoch: %d\tDev acc: %.4f' % (best_result['iter'], best_result['dev_acc']))
else:
    start_time = time.time()
    dev_acc = decode('dev', output_path=os.path.join(args.read_model_path, 'dev.eval'), acc_type='sql')
    dev_acc_checker = decode('dev', output_path=os.path.join(args.read_model_path, 'dev.eval.checker'), acc_type='sql', use_checker=True)
    logger.info("Evaluation costs %.2fs ; Dev dataset exact match/checker is %.4f/%.4f ." % (time.time() - start_time, dev_acc, dev_acc_checker))

14
proton/setup.sh Normal file
View File

@ -0,0 +1,14 @@
conda create -n text2sql python=3.6
source activate text2sql
pip install torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
python -c "import stanza; stanza.download('en')"
python -c "from embeddings import GloveEmbedding; emb = GloveEmbedding('common_crawl_48', d_emb=300)"
python -c "import nltk; nltk.download('stopwords')"
mkdir -p pretrained_models && cd pretrained_models
git lfs install
git clone https://huggingface.co/bert-large-uncased-whole-word-masking
git clone https://huggingface.co/google/electra-large-discriminator
mkdir -p glove.42b.300d && cd glove.42b.300d
wget -c http://nlp.stanford.edu/data/glove.42B.300d.zip && unzip glove.42B.300d.zip
awk -v FS=' ' '{print $1}' glove.42B.300d.txt > vocab_glove.txt
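After setup, a quick check that the pinned PyTorch build and the downloaded resources are usable (a sketch only; adjust paths if the clones were placed elsewhere):

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
ls pretrained_models/electra-large-discriminator pretrained_models/glove.42b.300d/vocab_glove.txt
```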

BIN
proton/static/proton.jpg Normal file

Binary file not shown.


84
proton/utils/args.py Normal file
View File

@ -0,0 +1,84 @@
#coding=utf-8
import argparse
import sys
def init_args(params=sys.argv[1:]):
arg_parser = argparse.ArgumentParser()
arg_parser = add_argument_base(arg_parser)
arg_parser = add_argument_encoder(arg_parser)
arg_parser = add_argument_decoder(arg_parser)
opt = arg_parser.parse_args(params)
if opt.model == 'rgatsql' and opt.local_and_nonlocal == 'msde':
opt.local_and_nonlocal = 'global'
if opt.model == 'lgesql' and opt.local_and_nonlocal == 'global':
opt.local_and_nonlocal = 'msde'
return opt
def add_argument_base(arg_parser):
#### General configuration ####
arg_parser.add_argument('--task', default='text2sql', help='task name')
arg_parser.add_argument('--seed', default=999, type=int, help='Random seed')
arg_parser.add_argument('--device', type=int, default=0, help='Use which device: -1 -> cpu ; the index of gpu o.w.')
arg_parser.add_argument('--testing', action='store_true', help='training or evaluation mode')
arg_parser.add_argument('--read_model_path', type=str, help='read pretrained model path')
#### Training Hyperparams ####
arg_parser.add_argument('--batch_size', default=20, type=int, help='Batch size')
arg_parser.add_argument('--grad_accumulate', default=1, type=int, help='accumulate grad and update once every x steps')
arg_parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
arg_parser.add_argument('--layerwise_decay', type=float, default=1.0, help='layerwise decay rate for lr, used for PLM')
arg_parser.add_argument('--l2', type=float, default=1e-4, help='weight decay coefficient')
arg_parser.add_argument('--warmup_ratio', type=float, default=0.1, help='warmup steps proportion')
arg_parser.add_argument('--lr_schedule', default='linear', choices=['constant', 'linear', 'ratsql', 'cosine'], help='lr scheduler')
arg_parser.add_argument('--eval_after_epoch', default=40, type=int, help='Start to evaluate after x epoch')
arg_parser.add_argument('--load_optimizer', action='store_true', default=False, help='Whether to load optimizer state')
arg_parser.add_argument('--max_epoch', type=int, default=100, help='terminate after maximum epochs')
arg_parser.add_argument('--max_norm', default=5., type=float, help='clip gradients')
return arg_parser
def add_argument_encoder(arg_parser):
# Encoder Hyperparams
arg_parser.add_argument('--model', choices=['rgatsql', 'lgesql'], default='lgesql', help='which text2sql model to use')
arg_parser.add_argument('--local_and_nonlocal', choices=['mmc', 'msde', 'local', 'global'], default='mmc',
help='how to integrate local and non-local relations: mmc -> multi-head multi-view concatenation ; msde -> mixed static and dynamic embeddings')
arg_parser.add_argument('--output_model', choices=['without_pruning', 'with_pruning'], default='without_pruning', help='whether add graph pruning')
arg_parser.add_argument('--plm', type=str, choices=['bert-base-uncased', 'bert-large-uncased', 'bert-large-uncased-whole-word-masking',
'roberta-base', 'roberta-large', 'grappa_large_jnt', 'electra-base-discriminator', 'electra-large-discriminator'], help='pretrained model name')
arg_parser.add_argument('--subword_aggregation', choices=['mean-pooling', 'max-pooling', 'attentive-pooling'], default='attentive-pooling', help='aggregate subword feats from PLM')
arg_parser.add_argument('--schema_aggregation', choices=['mean-pooling', 'max-pooling', 'attentive-pooling', 'head+tail'], default='head+tail', help='aggregate schema words feats')
arg_parser.add_argument('--dropout', type=float, default=0.2, help='feature dropout rate')
arg_parser.add_argument('--attn_drop', type=float, default=0., help='dropout rate of attention weights')
arg_parser.add_argument('--embed_size', default=300, type=int, help='size of word embeddings, only used in glove.42B.300d')
arg_parser.add_argument('--gnn_num_layers', default=8, type=int, help='num of GNN layers in encoder')
arg_parser.add_argument('--gnn_hidden_size', default=256, type=int, help='size of GNN layers hidden states')
arg_parser.add_argument('--num_heads', default=8, type=int, help='num of heads in multihead attn')
arg_parser.add_argument('--relation_share_layers', action='store_true')
arg_parser.add_argument('--relation_share_heads', action='store_true')
arg_parser.add_argument('--score_function', choices=['affine', 'bilinear', 'biaffine', 'dot'], default='affine', help='graph pruning score function')
arg_parser.add_argument('--smoothing', type=float, default=0.15, help='label smoothing factor for graph pruning')
return arg_parser
def add_argument_decoder(arg_parser):
# Decoder Hyperparams
arg_parser.add_argument('--lstm', choices=['lstm', 'onlstm'], default='onlstm', help='Type of LSTM used, ONLSTM or traditional LSTM')
arg_parser.add_argument('--chunk_size', default=8, type=int, help='parameter of ONLSTM')
arg_parser.add_argument('--att_vec_size', default=512, type=int, help='size of attentional vector')
arg_parser.add_argument('--sep_cxt', action='store_true', help='when calculating context vectors, use separate contexts for question and schema')
arg_parser.add_argument('--drop_connect', type=float, default=0.2, help='recurrent connection dropout rate in decoder lstm')
arg_parser.add_argument('--lstm_num_layers', type=int, default=1, help='num_layers of decoder')
arg_parser.add_argument('--lstm_hidden_size', default=512, type=int, help='Size of LSTM hidden states')
arg_parser.add_argument('--action_embed_size', default=128, type=int, help='Size of ApplyRule/GenToken action embeddings')
arg_parser.add_argument('--field_embed_size', default=64, type=int, help='Embedding size of ASDL fields')
arg_parser.add_argument('--type_embed_size', default=64, type=int, help='Embedding size of ASDL types')
arg_parser.add_argument('--no_context_feeding', action='store_true', default=False,
help='Do not use embedding of context vectors')
arg_parser.add_argument('--no_parent_production_embed', default=False, action='store_true',
help='Do not use embedding of parent ASDL production to update decoder LSTM state')
arg_parser.add_argument('--no_parent_field_embed', default=False, action='store_true',
help='Do not use embedding of parent field to update decoder LSTM state')
arg_parser.add_argument('--no_parent_field_type_embed', default=False, action='store_true',
help='Do not use embedding of the ASDL type of parent field to update decoder LSTM state')
arg_parser.add_argument('--no_parent_state', default=False, action='store_true',
help='Do not use the parent hidden state to update decoder LSTM state')
arg_parser.add_argument('--beam_size', default=5, type=int, help='Beam size for beam search')
arg_parser.add_argument('--decode_max_step', default=100, type=int, help='Maximum number of time steps used in decoding')
return arg_parser
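
As a quick illustration of the view coercion in `init_args` above, the parser can be driven with an explicit argument list; a minimal usage sketch (flag values are illustrative only):

```python
# Usage sketch: init_args accepts an argv-style list, so the msde/global coercion
# can be inspected without touching sys.argv.
from utils.args import init_args

opt = init_args(['--model', 'lgesql', '--local_and_nonlocal', 'global',
                 '--plm', 'electra-large-discriminator'])
# lgesql only supports the mixed static/dynamic view, so 'global' is mapped back to 'msde'
print(opt.model, opt.local_and_nonlocal)   # lgesql msde
```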

230
proton/utils/batch.py Normal file
View File

@ -0,0 +1,230 @@
#coding=utf8
import torch
import numpy as np
from utils.example import Example, get_position_ids
from utils.constants import PAD, UNK
from model.model_utils import lens2mask, cached_property
import torch.nn.functional as F
def from_example_list_base(ex_list, device='cpu', train=True):
"""
question_lens: torch.long, bsize
questions: torch.long, bsize x max_question_len, include [CLS] if add_cls
table_lens: torch.long, bsize, number of tables for each example
table_word_lens: torch.long, number of words for each table name
tables: torch.long, sum_of_tables x max_table_word_len
column_lens: torch.long, bsize, number of columns for each example
column_word_lens: torch.long, number of words for each column name
columns: torch.long, sum_of_columns x max_column_word_len
"""
batch = Batch(ex_list, device)
plm = Example.plm
pad_idx = Example.word_vocab[PAD] if plm is None else Example.tokenizer.pad_token_id
question_lens = [len(ex.question) for ex in ex_list]
batch.question_lens = torch.tensor(question_lens, dtype=torch.long, device=device)
batch.table_lens = torch.tensor([len(ex.table) for ex in ex_list], dtype=torch.long, device=device)
table_word_lens = [len(t) for ex in ex_list for t in ex.table]
batch.table_word_lens = torch.tensor(table_word_lens, dtype=torch.long, device=device)
batch.column_lens = torch.tensor([len(ex.column) for ex in ex_list], dtype=torch.long, device=device)
column_word_lens = [len(c) for ex in ex_list for c in ex.column]
batch.column_word_lens = torch.tensor(column_word_lens, dtype=torch.long, device=device)
if plm is None: # glove.42B.300d
questions = [ex.question_id + [pad_idx] * (batch.max_question_len - len(ex.question_id)) for ex in ex_list]
batch.questions = torch.tensor(questions, dtype=torch.long, device=device)
tables = [t + [pad_idx] * (batch.max_table_word_len - len(t)) for ex in ex_list for t in ex.table_id]
batch.tables = torch.tensor(tables, dtype=torch.long, device=device)
columns = [c + [pad_idx] * (batch.max_column_word_len - len(c)) for ex in ex_list for c in ex.column_id]
batch.columns = torch.tensor(columns, dtype=torch.long, device=device)
else:
# prepare inputs for pretrained models
batch.inputs = {"input_ids": None, "attention_mask": None, "token_type_ids": None, "position_ids": None}
input_lens = [len(ex.input_id) for ex in ex_list]
max_len = max(input_lens)
input_ids = [ex.input_id + [pad_idx] * (max_len - len(ex.input_id)) for ex in ex_list]
batch.inputs["input_ids"] = torch.tensor(input_ids, dtype=torch.long, device=device)
attention_mask = [[1] * l + [0] * (max_len - l) for l in input_lens]
batch.inputs["attention_mask"] = torch.tensor(attention_mask, dtype=torch.float, device=device)
token_type_ids = [ex.segment_id + [0] * (max_len - len(ex.segment_id)) for ex in ex_list]
batch.inputs["token_type_ids"] = torch.tensor(token_type_ids, dtype=torch.long, device=device)
position_ids = [get_position_ids(ex, shuffle=train) + [0] * (max_len - len(ex.input_id)) for ex in ex_list]
batch.inputs["position_ids"] = torch.tensor(position_ids, dtype=torch.long, device=device)
# extract representations after plm, remove [SEP]
question_mask_plm = [ex.question_mask_plm + [0] * (max_len - len(ex.question_mask_plm)) for ex in ex_list]
batch.question_mask_plm = torch.tensor(question_mask_plm, dtype=torch.bool, device=device)
table_mask_plm = [ex.table_mask_plm + [0] * (max_len - len(ex.table_mask_plm)) for ex in ex_list]
batch.table_mask_plm = torch.tensor(table_mask_plm, dtype=torch.bool, device=device)
column_mask_plm = [ex.column_mask_plm + [0] * (max_len - len(ex.column_mask_plm)) for ex in ex_list]
batch.column_mask_plm = torch.tensor(column_mask_plm, dtype=torch.bool, device=device)
# subword aggregation
question_subword_lens = [l for ex in ex_list for l in ex.question_subword_len]
batch.question_subword_lens = torch.tensor(question_subword_lens, dtype=torch.long, device=device)
table_subword_lens = [l for ex in ex_list for l in ex.table_subword_len]
batch.table_subword_lens = torch.tensor(table_subword_lens, dtype=torch.long, device=device)
column_subword_lens = [l for ex in ex_list for l in ex.column_subword_len]
batch.column_subword_lens = torch.tensor(column_subword_lens, dtype=torch.long, device=device)
batch.question_unk_mask, batch.table_unk_mask, batch.column_unk_mask = None, None, None
if not train and plm is None:
# during evaluation, for words not in the vocab but in the GloVe vocab, extract their corresponding embeddings
word2vec, unk_idx = Example.word2vec, Example.word_vocab[UNK]
question_unk_mask = (batch.questions == unk_idx).cpu()
if question_unk_mask.any().item():
raw_questions = np.array([ex.question + [PAD] * (batch.max_question_len - len(ex.question)) for ex in ex_list], dtype='<U100')
unk_words = raw_questions[question_unk_mask.numpy()].tolist()
unk_word_embeddings = [word2vec.emb(w) for w in unk_words]
oov_flag = torch.tensor([True if e is not None else False for e in unk_word_embeddings], dtype=torch.bool)
if oov_flag.any().item():
batch.question_unk_mask = question_unk_mask.masked_scatter_(torch.clone(question_unk_mask), oov_flag).to(device)
batch.question_unk_embeddings = torch.tensor([e for e in unk_word_embeddings if e is not None], dtype=torch.float, device=device)
table_unk_mask = (batch.tables == unk_idx).cpu()
if table_unk_mask.any().item():
raw_tables = np.array([t + [PAD] * (batch.max_table_word_len - len(t)) for ex in ex_list for t in ex.table], dtype='<U100')
unk_words = raw_tables[table_unk_mask.numpy()].tolist()
unk_word_embeddings = [word2vec.emb(w) for w in unk_words]
oov_flag = torch.tensor([True if e is not None else False for e in unk_word_embeddings], dtype=torch.bool)
if oov_flag.any().item():
batch.table_unk_mask = table_unk_mask.masked_scatter_(torch.clone(table_unk_mask), oov_flag).to(device)
batch.table_unk_embeddings = torch.tensor([e for e in unk_word_embeddings if e is not None], dtype=torch.float, device=device)
column_unk_mask = (batch.columns == unk_idx).cpu()
if column_unk_mask.any().item():
raw_columns = np.array([c + [PAD] * (batch.max_column_word_len - len(c)) for ex in ex_list for c in ex.column], dtype='<U100')
unk_words = raw_columns[column_unk_mask.numpy()].tolist()
unk_word_embeddings = [word2vec.emb(w) for w in unk_words]
oov_flag = torch.tensor([True if e is not None else False for e in unk_word_embeddings], dtype=torch.bool)
if oov_flag.any().item():
batch.column_unk_mask = column_unk_mask.masked_scatter_(torch.clone(column_unk_mask), oov_flag).to(device)
batch.column_unk_embeddings = torch.tensor([e for e in unk_word_embeddings if e is not None], dtype=torch.float, device=device)
return batch
def from_example_list_text2sql(ex_list, device='cpu', train=True, **kwargs):
""" New fields: batch.lens, batch.max_len, batch.relations, batch.relations_mask
"""
batch = from_example_list_base(ex_list, device, train)
batch.graph = Example.graph_factory.batch_graphs(ex_list, device, train=train, **kwargs)
if train:
batch.max_action_num = max([len(ex.tgt_action) for ex in ex_list])
return batch
class Batch():
def __init__(self, examples, device='cpu'):
super(Batch, self).__init__()
self.examples = examples
self.device = device
@classmethod
def from_example_list(cls, ex_list, device='cpu', train=True, method='text2sql', **kwargs):
method_dict = {
"text2sql": from_example_list_text2sql,
}
return method_dict[method](ex_list, device, train=train, **kwargs)
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
return self.examples[idx]
@cached_property
def max_question_len(self):
return torch.max(self.question_lens).item()
@cached_property
def max_table_len(self):
return torch.max(self.table_lens).item()
@cached_property
def max_column_len(self):
return torch.max(self.column_lens).item()
@cached_property
def max_table_word_len(self):
return torch.max(self.table_word_lens).item()
@cached_property
def max_column_word_len(self):
return torch.max(self.column_word_lens).item()
@cached_property
def max_question_subword_len(self):
return torch.max(self.question_subword_lens).item()
@cached_property
def max_table_subword_len(self):
return torch.max(self.table_subword_lens).item()
@cached_property
def max_column_subword_len(self):
return torch.max(self.column_subword_lens).item()
""" Different types of nodes are seperated instead of concatenated together """
@cached_property
def mask(self):
return torch.cat([self.question_mask, self.table_mask, self.column_mask], dim=1)
@cached_property
def question_mask(self):
return lens2mask(self.question_lens)
@cached_property
def table_mask(self):
return lens2mask(self.table_lens)
@cached_property
def column_mask(self):
return lens2mask(self.column_lens)
@cached_property
def table_word_mask(self):
return lens2mask(self.table_word_lens)
@cached_property
def column_word_mask(self):
return lens2mask(self.column_word_lens)
@cached_property
def question_subword_mask(self):
return lens2mask(self.question_subword_lens)
@cached_property
def table_subword_mask(self):
return lens2mask(self.table_subword_lens)
@cached_property
def column_subword_mask(self):
return lens2mask(self.column_subword_lens)
def get_frontier_field_idx(self, t):
ids = []
for e in self.examples:
if t < len(e.tgt_action):
ids.append(Example.grammar.field2id[e.tgt_action[t].frontier_field])
# assert self.grammar.id2field[ids[-1]] == e.tgt_action[t].frontier_field
else:
ids.append(0)
return torch.tensor(ids, dtype=torch.long, device=self.device)
def get_frontier_prod_idx(self, t):
ids = []
for e in self.examples:
if t < len(e.tgt_action):
ids.append(Example.grammar.prod2id[e.tgt_action[t].frontier_prod])
# assert self.grammar.id2prod[ids[-1]] == e.tgt_action[t].frontier_prod
else:
ids.append(0)
return torch.tensor(ids, dtype=torch.long, device=self.device)
def get_frontier_field_type_idx(self, t):
ids = []
for e in self.examples:
if t < len(e.tgt_action):
ids.append(Example.grammar.type2id[e.tgt_action[t].frontier_field.type])
# assert self.grammar.id2type[ids[-1]] == e.tgt_action[t].frontier_field.type
else:
ids.append(0)
return torch.tensor(ids, dtype=torch.long, device=self.device)
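
A minimal sketch of how these batching helpers are typically driven, assuming `Example.configuration` has been run and the preprocessed `data/train.lgesql.bin` from the preprocessing script exists:

```python
# Sketch only: build one small training batch on CPU.
from utils.example import Example
from utils.batch import Batch

Example.configuration(plm='electra-large-discriminator', method='lgesql')
train_dataset = Example.load_dataset('train')
batch = Batch.from_example_list(train_dataset[:4], device='cpu', train=True, smoothing=0.15)
print(len(batch), batch.max_question_len, batch.max_action_num)
```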

19
proton/utils/constants.py Normal file
View File

@ -0,0 +1,19 @@
PAD = '[PAD]'
BOS = '[CLS]'
EOS = '[SEP]'
UNK = '[UNK]'
GRAMMAR_FILEPATH = 'asdl/sql/grammar/sql_asdl_v2.txt'
SCHEMA_TYPES = ['table', 'others', 'text', 'time', 'number', 'boolean']
MAX_RELATIVE_DIST = 2
# relations: type_1-type_2-rel_name, r represents reverse edge, b represents bidirectional edge
RELATIONS = ['question-question-dist' + str(i) if i != 0 else 'question-question-identity' for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1)] + \
['table-table-identity', 'table-table-fk', 'table-table-fkr', 'table-table-fkb'] + \
['column-column-identity', 'column-column-sametable', 'column-column-fk', 'column-column-fkr'] + \
['table-column-pk', 'column-table-pk', 'table-column-has', 'column-table-has'] + \
['question-column-exactmatch', 'question-column-partialmatch','question-column-partialsemanticmatch', 'question-column-nomatch','question-column-semanticmatch', 'question-column-valuematch',
'column-question-exactmatch', 'column-question-partialmatch','column-question-partialsemanticmatch', 'column-question-nomatch','column-question-semanticmatch', 'column-question-valuematch'] + \
['question-table-exactmatch', 'question-table-partialmatch','question-table-partialsemanticmatch', 'question-table-nomatch','question-table-semanticmatch',
'table-question-exactmatch', 'table-question-partialmatch', 'table-question-partialsemanticmatch','table-question-semanticmatch','table-question-nomatch'] + \
['question-question-generic', 'table-table-generic', 'column-column-generic', 'table-column-generic', 'column-table-generic'] + \
['*-*-identity', '*-question-generic', 'question-*-generic', '*-table-generic', 'table-*-generic', '*-column-generic', 'column-*-generic']
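
These relation names become integer edge labels through the relation vocabulary built in `utils/example.py`; a minimal sketch of that mapping:

```python
# Minimal sketch: relation ids are simply positions in the RELATIONS list,
# mirroring how Example.configuration builds its relation_vocab.
from utils.constants import RELATIONS
from utils.vocab import Vocab

relation_vocab = Vocab(padding=False, unk=False, boundary=False, iterable=RELATIONS, default=None)
print(len(relation_vocab))                            # number of distinct relation types
print(relation_vocab['question-column-exactmatch'])   # id used on graph edges
```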

375
proton/utils/evaluator.py Normal file
View File

@ -0,0 +1,375 @@
#coding=utf8
import sys, tempfile, os
import pickle, json
import numpy as np
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from evaluation import evaluate, build_foreign_key_map_from_json, build_valid_col_units, rebuild_sql_val, rebuild_sql_col, eval_exec_match
from evaluation import Evaluator as Engine
from process_sql import get_schema, Schema, get_sql
class Evaluator():
def __init__(self, transition_system, table_path='data/tables.json', database_dir='data/database'):
super(Evaluator, self).__init__()
self.transition_system = transition_system
self.kmaps = build_foreign_key_map_from_json(table_path)
self.database_dir = database_dir
self.engine = Engine()
self.checker = Checker(table_path, database_dir)
self.acc_dict = {
"sql": self.sql_acc, # use golden sql as references
"ast": self.ast_acc, # compare ast accuracy, ast may be incorrect when constructed from raw sql
"beam": self.beam_acc, # if the correct answer exist in the beam, assume the result is true
}
def acc(self, pred_hyps, dataset, output_path=None, acc_type='sql', etype='match', use_checker=False):
assert len(pred_hyps) == len(dataset) and acc_type in self.acc_dict and etype in ['match', 'exec']
acc_method = self.acc_dict[acc_type]
return acc_method(pred_hyps, dataset, output_path, etype, use_checker)
def beam_acc(self, pred_hyps, dataset, output_path, etype, use_checker):
scores, results = {}, []
for each in ['easy', 'medium', 'hard', 'extra', 'all']:
scores[each] = [0, 0.] # first is count, second is total score
for idx, pred in enumerate(pred_hyps):
question, gold_sql, db = dataset[idx].ex['question'], dataset[idx].query, dataset[idx].db
for b_id, hyp in enumerate(pred):
pred_sql = self.transition_system.ast_to_surface_code(hyp.tree, db)
score, hardness = self.single_acc(pred_sql, gold_sql, db['db_id'], etype)
if int(score) == 1:
scores[hardness][0] += 1
scores[hardness][1] += 1.0
scores['all'][0] += 1
scores['all'][1] += 1.0
results.append((hardness, question, gold_sql, pred_sql, b_id, True))
break
else:
scores[hardness][0] += 1
scores['all'][0] += 1
pred_sql = self.transition_system.ast_to_surface_code(pred[0].tree, db)
results.append((hardness, question, gold_sql, pred_sql, 0, False))
for each in scores:
accuracy = scores[each][1] / float(scores[each][0]) if scores[each][0] != 0 else 0.
scores[each].append(accuracy)
of = open(output_path, 'w', encoding='utf8') if output_path is not None else \
tempfile.TemporaryFile('w+t')
for item in results:
of.write('Level: %s\n' % (item[0]))
of.write('Question: %s\n' % (item[1]))
of.write('Gold SQL: %s\n' %(item[2]))
of.write('Pred SQL (%s): %s\n' % (item[4], item[3]))
of.write('Correct: %s\n\n' % (item[5]))
for each in scores:
of.write('Level %s: %s\n' % (each, scores[each]))
of.close()
return scores['all'][2]
def single_acc(self, pred_sql, gold_sql, db, etype):
"""
@return:
score(float): 0 or 1, etype score
hardness(str): one of 'easy', 'medium', 'hard', 'extra'
"""
db_name = db
db = os.path.join(self.database_dir, db, db + ".sqlite")
schema = Schema(get_schema(db))
g_sql = get_sql(schema, gold_sql)
hardness = self.engine.eval_hardness(g_sql)
try:
p_sql = get_sql(schema, pred_sql)
except:
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
p_sql = {
"except": None,
"from": {
"conds": [],
"table_units": []
},
"groupBy": [],
"having": [],
"intersect": None,
"limit": None,
"orderBy": [],
"select": [
False,
[]
],
"union": None,
"where": []
}
kmap = self.kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) # kmap: map __tab.col__ to pivot __tab.col__
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
if etype == 'exec':
score = float(eval_exec_match(db, pred_sql, gold_sql, p_sql, g_sql))
if etype == 'match':
score = float(self.engine.eval_exact_match(p_sql, g_sql))
return score, hardness
def ast_acc(self, pred_hyps, dataset, output_path, etype, use_checker):
pred_asts = [hyp[0].tree for hyp in pred_hyps]
ref_asts = [ex.ast for ex in dataset]
dbs = [ex.db for ex in dataset]
pred_sqls, ref_sqls = [], []
for pred, ref, db in zip(pred_asts, ref_asts, dbs):
pred_sql = self.transition_system.ast_to_surface_code(pred, db)
ref_sql = self.transition_system.ast_to_surface_code(ref, db)
pred_sqls.append(pred_sql)
ref_sqls.append(ref_sql)
with tempfile.NamedTemporaryFile('w+t', encoding='utf8', suffix='.sql') as tmp_pred, \
tempfile.NamedTemporaryFile('w+t', encoding='utf8', suffix='.sql') as tmp_ref:
of = open(output_path, 'w', encoding='utf8') if output_path is not None \
else tempfile.TemporaryFile('w+t')
# write pred and ref sqls
for s in pred_sqls:
tmp_pred.write(s + '\n')
tmp_pred.flush()
for s, db in zip(ref_sqls, dbs):
tmp_ref.write(s + '\t' + db['db_id'] + '\n')
tmp_ref.flush()
# calculate ast accuracy
old_print = sys.stdout
sys.stdout = of
result_type = 'exact' if etype == 'match' else 'exec'
all_exact_acc = evaluate(tmp_ref.name, tmp_pred.name, self.database_dir, etype, self.kmaps)['all'][result_type]
sys.stdout = old_print
of.close()
return float(all_exact_acc)
def sql_acc(self, pred_hyps, dataset, output_path, etype, use_checker):
pred_sqls, ref_sqls = [], [ex.query for ex in dataset]
dbs = [ex.db for ex in dataset]
for idx, hyp in enumerate(pred_hyps):
if use_checker:
pred_sql = self.obtain_sql(hyp, dbs[idx])
else:
best_ast = hyp[0].tree # by default, the top beam prediction
pred_sql = self.transition_system.ast_to_surface_code(best_ast, dbs[idx])
pred_sqls.append(pred_sql)
with tempfile.NamedTemporaryFile('w+t', encoding='utf8', suffix='.sql') as tmp_pred, \
tempfile.NamedTemporaryFile('w+t', encoding='utf8', suffix='.sql') as tmp_ref:
of = open(output_path, 'w', encoding='utf8') if output_path is not None \
else tempfile.TemporaryFile('w+t')
# write pred and ref sqls
for s in pred_sqls:
tmp_pred.write(s + '\n')
tmp_pred.flush()
for s, db in zip(ref_sqls, dbs):
tmp_ref.write(s + '\t' + db['db_id'] + '\n')
tmp_ref.flush()
# calculate sql accuracy
old_print = sys.stdout
sys.stdout = of
result_type = 'exact' if etype == 'match' else 'exec'
all_exact_acc = evaluate(tmp_ref.name, tmp_pred.name, self.database_dir, etype, self.kmaps)['all'][result_type]
sys.stdout = old_print
of.close()
return float(all_exact_acc)
def obtain_sql(self, hyps, db):
beam = len(hyps)
for hyp in hyps:
cur_ast = hyp.tree
pred_sql = self.transition_system.ast_to_surface_code(cur_ast, db)
if self.checker.validity_check(pred_sql, db['db_id']):
sql = pred_sql
break
else:
best_ast = hyps[0].tree
sql = self.transition_system.ast_to_surface_code(best_ast, db)
return sql
class Checker():
def __init__(self, table_path='data/tables.json', db_dir='data/database'):
super(Checker, self).__init__()
self.table_path = table_path
self.db_dir = db_dir
self.schemas, self.database, self.tables = self._get_schemas_from_json(self.table_path)
def _get_schemas_from_json(self, table_path):
database_list = json.load(open(table_path, 'r'))
database = {}
for db in database_list:
database[db['db_id']] = db
tables, schemas = {}, {}
for db in database_list:
db_id = db['db_id']
schema = {}
column_names_original = db['column_names_original']
table_names_original = db['table_names_original']
tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original}
for i, tabn in enumerate(table_names_original):
table = str(tabn.lower())
cols = [str(col.lower()) for td, col in column_names_original if td == i]
schema[table] = cols
schemas[db_id] = schema
return schemas, database, tables
def validity_check(self, sql: str, db: str):
""" Check whether the given sql query is valid, including:
1. only use columns in tables mentioned in FROM clause
2. comparison operator or MAX/MIN/SUM/AVG only applied to columns of type number/time
@params:
sql(str): SQL query
db(str): db_id field, database name
@return:
flag(boolean)
"""
schema, table = self.schemas[db], self.tables[db]
schema = SchemaID(schema, table)
try:
sql = get_sql(schema, sql)
return self.sql_check(sql, self.database[db])
except Exception as e:
print('Runtime error occurs:', e)
return False
def sql_check(self, sql: dict, db: dict):
if sql['intersect']:
return self.sqlunit_check(sql, db) & self.sqlunit_check(sql['intersect'], db)
if sql['union']:
return self.sqlunit_check(sql, db) & self.sqlunit_check(sql['union'], db)
if sql['except']:
return self.sqlunit_check(sql, db) & self.sqlunit_check(sql['except'], db)
return self.sqlunit_check(sql, db)
def sqlunit_check(self, sql: dict, db: dict):
if sql['from']['table_units'][0][0] == 'sql':
if not self.sql_check(sql['from']['table_units'][0][1], db): return False
table_ids = []
else:
table_ids = list(map(lambda table_unit: table_unit[1], sql['from']['table_units']))
return self.select_check(sql['select'], table_ids, db) & \
self.cond_check(sql['where'], table_ids, db) & \
self.groupby_check(sql['groupBy'], table_ids, db) & \
self.cond_check(sql['having'], table_ids, db) & \
self.orderby_check(sql['orderBy'], table_ids, db)
def select_check(self, select, table_ids: list, db: dict):
select = select[1]
for agg_id, val_unit in select:
if not self.valunit_check(val_unit, table_ids, db): return False
# MAX/MIN/SUM/AVG
# if agg_id in [1, 2, 4, 5] and (self.valunit_type(val_unit, db) not in ['number', 'time']):
# return False
return True
def cond_check(self, cond, table_ids: list, db: dict):
if len(cond) == 0:
return True
for idx in range(0, len(cond), 2):
cond_unit = cond[idx]
_, cmp_op, val_unit, val1, val2 = cond_unit
flag = self.valunit_check(val_unit, table_ids, db)
# if cmp_op in [3, 4, 5, 6]: # >, <, >=, <=
# flag &= (self.valunit_type(val_unit, db) in ['number', 'time'])
if type(val1) == dict:
flag &= self.sql_check(val1, db)
if type(val2) == dict:
flag &= self.sql_check(val2, db)
if not flag: return False
return True
def groupby_check(self, groupby, table_ids: list, db: dict):
if not groupby: return True
for col_unit in groupby:
if not self.colunit_check(col_unit, table_ids, db): return False
return True
def orderby_check(self, orderby, table_ids: list, db: dict):
if not orderby: return True
orderby = orderby[1]
for val_unit in orderby:
if not self.valunit_check(val_unit, table_ids, db): return False
return True
def colunit_check(self, col_unit: list, table_ids: list, db: dict):
""" Check from the following aspects:
1. column belongs to the tables in FROM clause
2. column type is valid for AGG_OP
"""
agg_id, col_id, _ = col_unit
if col_id == 0: return True
tab_id = db['column_names'][col_id][0]
if tab_id not in table_ids: return False
col_type = db['column_types'][col_id]
if agg_id in [1, 2, 4, 5]: # MAX, MIN, SUM, AVG
return (col_type in ['time', 'number'])
return True
def valunit_check(self, val_unit: list, table_ids: list, db: dict):
unit_op, col_unit1, col_unit2 = val_unit
if unit_op == 0:
return self.colunit_check(col_unit1, table_ids, db)
if not (self.colunit_check(col_unit1, table_ids, db) and self.colunit_check(col_unit2, table_ids, db)):
return False
# COUNT/SUM/AVG -> number
agg_id1, col_id1, _ = col_unit1
agg_id2, col_id2, _ = col_unit2
t1 = 'number' if agg_id1 > 2 else db['column_types'][col_id1]
t2 = 'number' if agg_id2 > 2 else db['column_types'][col_id2]
if (t1 not in ['number', 'time']) or (t2 not in ['number', 'time']) or t1 != t2:
return False
return True
def valunit_type(self, val_unit: list, db: dict):
unit_op, col_unit1, col_unit2 = val_unit
if unit_op == 0:
agg_id, col_id, _ = col_unit1
if agg_id > 2: return 'number'
else: return ('number' if col_id == 0 else db['column_types'][col_id])
else:
return 'number'
class SchemaID():
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema, table):
self._schema = schema
self._table = table
self._idMap = self._map(self._schema, self._table)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema, table):
column_names_original = table['column_names_original']
table_names_original = table['table_names_original']
#print 'column_names_original: ', column_names_original
#print 'table_names_original: ', table_names_original
for i, (tab_id, col) in enumerate(column_names_original):
if tab_id == -1:
idMap = {'*': i}
else:
key = table_names_original[tab_id].lower()
val = col.lower()
idMap[key + "." + val] = i
for i, tab in enumerate(table_names_original):
key = tab.lower()
idMap[key] = i
return idMap
if __name__ == '__main__':
checker = Checker('data/tables.json', 'data/database')
train, dev = json.load(open('data/train.json', 'r')), json.load(open('data/dev.json', 'r'))
test = json.load(open('data/test.json', 'r'))
count = 0
for idx, ex in enumerate(train):
sql, db = ex['query'].strip(), ex['db_id']
flag = checker.validity_check(sql, db)
if not flag:
print(idx, ': ' + sql + '\t' + db)
count += 1
print('Total invalid is %d' % (count))
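
Besides the validity check exercised in `__main__`, the core scoring path is `single_acc`; a hedged sketch, assuming the Spider `data/tables.json` and `data/database` layout from the README and a database id such as `concert_singer` that ships with Spider:

```python
# Sketch only: score a single prediction against its gold SQL with exact-set match.
from asdl.asdl import ASDLGrammar
from asdl.transition_system import TransitionSystem
from utils.constants import GRAMMAR_FILEPATH
from utils.evaluator import Evaluator

grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH)
trans = TransitionSystem.get_class_by_lang('sql')(grammar)
evaluator = Evaluator(trans, table_path='data/tables.json', database_dir='data/database')
score, hardness = evaluator.single_acc("SELECT count(*) FROM singer",
                                       "SELECT count(*) FROM singer",
                                       "concert_singer", etype='match')
print(score, hardness)   # 1.0 and a hardness bucket such as 'easy'
```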

151
proton/utils/example.py Normal file
View File

@ -0,0 +1,151 @@
#coding=utf8
import os, pickle, json
import torch, random
import numpy as np
from asdl.asdl import ASDLGrammar
from asdl.transition_system import TransitionSystem
from utils.constants import UNK, GRAMMAR_FILEPATH, SCHEMA_TYPES, RELATIONS
from utils.graph_example import GraphFactory
from utils.vocab import Vocab
from utils.word2vec import Word2vecUtils
from transformers import AutoTokenizer
from utils.evaluator import Evaluator
from itertools import chain
class Example():
@classmethod
def configuration(cls, plm=None, method='lgesql', table_path='data/tables.json', tables='data/tables.bin', db_dir='data/database'):
cls.plm, cls.method = plm, method
cls.grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH)
cls.trans = TransitionSystem.get_class_by_lang('sql')(cls.grammar)
cls.tables = pickle.load(open(tables, 'rb')) if type(tables) == str else tables
cls.evaluator = Evaluator(cls.trans, table_path, db_dir)
if plm is None:
cls.word2vec = Word2vecUtils()
cls.tokenizer = lambda x: x
cls.word_vocab = Vocab(padding=True, unk=True, boundary=True, default=UNK,
filepath='./pretrained_models/glove.42b.300d/vocab.txt', specials=SCHEMA_TYPES) # word vocab for glove.42B.300d
else:
cls.tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', plm))
cls.word_vocab = cls.tokenizer.get_vocab()
cls.relation_vocab = Vocab(padding=False, unk=False, boundary=False, iterable=RELATIONS, default=None)
cls.graph_factory = GraphFactory(cls.method, cls.relation_vocab)
@classmethod
def load_dataset(cls, choice, debug=False):
assert choice in ['train', 'dev']
fp = os.path.join('data', choice + '.' + cls.method + '.bin')
datasets = pickle.load(open(fp, 'rb'))
# question_lens = [len(ex['processed_question_toks']) for ex in datasets]
# print('Max/Min/Avg question length in %s dataset is: %d/%d/%.2f' % (choice, max(question_lens), min(question_lens), float(sum(question_lens))/len(question_lens)))
# action_lens = [len(ex['actions']) for ex in datasets]
# print('Max/Min/Avg action length in %s dataset is: %d/%d/%.2f' % (choice, max(action_lens), min(action_lens), float(sum(action_lens))/len(action_lens)))
examples, outliers = [], 0
for ex in datasets:
if choice == 'train' and len(cls.tables[ex['db_id']]['column_names']) > 100:
outliers += 1
continue
examples.append(cls(ex, cls.tables[ex['db_id']]))
if debug and len(examples) >= 100:
return examples
if choice == 'train':
print("Skip %d extremely large samples in training dataset ..." % (outliers))
return examples
def __init__(self, ex: dict, db: dict):
super(Example, self).__init__()
self.ex = ex
self.db = db
""" Mapping word to corresponding index """
if Example.plm is None:
self.question = ex['processed_question_toks']
self.question_id = [Example.word_vocab[w] for w in self.question]
self.column = [[db['column_types'][idx].lower()] + c for idx, c in enumerate(db['processed_column_toks'])]
self.column_id = [[Example.word_vocab[w] for w in c] for c in self.column]
self.table = [['table'] + t for t in db['processed_table_toks']]
self.table_id = [[Example.word_vocab[w] for w in t] for t in self.table]
else:
t = Example.tokenizer
self.question = [q.lower() for q in ex['raw_question_toks']]
self.question_id = [t.cls_token_id] # map token to id
self.question_mask_plm = [] # remove SEP token in our case
self.question_subword_len = [] # subword len for each word, exclude SEP token
for w in self.question:
toks = t.convert_tokens_to_ids(t.tokenize(w))
self.question_id.extend(toks)
self.question_subword_len.append(len(toks))
self.question_mask_plm = [0] + [1] * (len(self.question_id) - 1) + [0]
self.question_id.append(t.sep_token_id)
self.table = [['table'] + t.lower().split() for t in db['table_names']]
self.table_id, self.table_mask_plm, self.table_subword_len = [], [], []
self.table_word_len = []
for s in self.table:
l = 0
for w in s:
toks = t.convert_tokens_to_ids(t.tokenize(w))
self.table_id.extend(toks)
self.table_subword_len.append(len(toks))
l += len(toks)
self.table_word_len.append(l)
self.table_mask_plm = [1] * len(self.table_id)
self.column = [[db['column_types'][idx].lower()] + c.lower().split() for idx, (_, c) in enumerate(db['column_names'])]
self.column_id, self.column_mask_plm, self.column_subword_len = [], [], []
self.column_word_len = []
for s in self.column:
l = 0
for w in s:
toks = t.convert_tokens_to_ids(t.tokenize(w))
self.column_id.extend(toks)
self.column_subword_len.append(len(toks))
l += len(toks)
self.column_word_len.append(l)
self.column_mask_plm = [1] * len(self.column_id) + [0]
self.column_id.append(t.sep_token_id)
self.input_id = self.question_id + self.table_id + self.column_id
self.segment_id = [0] * len(self.question_id) + [1] * (len(self.table_id) + len(self.column_id)) \
if Example.plm != 'grappa_large_jnt' and not Example.plm.startswith('roberta') \
else [0] * (len(self.question_id) + len(self.table_id) + len(self.column_id))
self.question_mask_plm = self.question_mask_plm + [0] * (len(self.table_id) + len(self.column_id))
self.table_mask_plm = [0] * len(self.question_id) + self.table_mask_plm + [0] * len(self.column_id)
self.column_mask_plm = [0] * (len(self.question_id) + len(self.table_id)) + self.column_mask_plm
self.graph = Example.graph_factory.graph_construction(ex, db)
# outputs
self.query = ' '.join(ex['query'].split('\t'))
self.ast = ex['ast']
self.tgt_action = ex['actions']
self.used_tables, self.used_columns = ex['used_tables'], ex['used_columns']
def get_position_ids(ex, shuffle=True):
# cluster columns with their corresponding table and randomly shuffle tables and columns
# [CLS] q1 q2 ... [SEP] * t1 c1 c2 c3 t2 c4 c5 ... [SEP]
db, table_word_len, column_word_len = ex.db, ex.table_word_len, ex.column_word_len
table_num, column_num = len(db['table_names']), len(db['column_names'])
question_position_id = list(range(len(ex.question_id)))
start = len(question_position_id)
table_position_id, column_position_id = [None] * table_num, [None] * column_num
column_position_id[0] = list(range(start, start + column_word_len[0]))
start += column_word_len[0] # special symbol * first
table_idxs = list(range(table_num))
if shuffle:
random.shuffle(table_idxs)
for idx in table_idxs:
col_idxs = db['table2columns'][idx]
table_position_id[idx] = list(range(start, start + table_word_len[idx]))
start += table_word_len[idx]
if shuffle:
random.shuffle(col_idxs)
for col_id in col_idxs:
column_position_id[col_id] = list(range(start, start + column_word_len[col_id]))
start += column_word_len[col_id]
position_id = question_position_id + list(chain.from_iterable(table_position_id)) + \
list(chain.from_iterable(column_position_id)) + [start]
assert len(position_id) == len(ex.input_id)
return position_id

View File

@ -0,0 +1,76 @@
#coding=utf8
import numpy as np
import dgl, torch, math
class GraphExample():
pass
class BatchedGraph():
pass
class GraphFactory():
def __init__(self, method='rgatsql', relation_vocab=None):
super(GraphFactory, self).__init__()
self.method = eval('self.' + method)
self.batch_method = eval('self.batch_' + method)
self.relation_vocab = relation_vocab
def graph_construction(self, ex: dict, db: dict):
return self.method(ex, db)
def rgatsql(self, ex, db):
graph = GraphExample()
local_edges = ex['graph'].local_edges
rel_ids = list(map(lambda r: self.relation_vocab[r[2]], local_edges))
graph.local_edges = torch.tensor(rel_ids, dtype=torch.long)
global_edges = ex['graph'].global_edges
rel_ids = list(map(lambda r: self.relation_vocab[r[2]], global_edges))
graph.global_edges = torch.tensor(rel_ids, dtype=torch.long)
graph.local_g, graph.global_g = ex['graph'].local_g, ex['graph'].global_g
graph.gp = ex['graph'].gp
graph.question_mask = torch.tensor(ex['graph'].question_mask, dtype=torch.bool)
graph.schema_mask = torch.tensor(ex['graph'].schema_mask, dtype=torch.bool)
graph.node_label = torch.tensor(ex['graph'].node_label, dtype=torch.float)
# extract local relations (used in msde), global_edges = local_edges + nonlocal_edges
local_enum, global_enum = graph.local_edges.size(0), graph.global_edges.size(0)
graph.local_mask = torch.tensor([1] * local_enum + [0] * (global_enum - local_enum), dtype=torch.bool)
return graph
def lgesql(self, ex, db):
graph = self.rgatsql(ex, db)
# add line graph
graph.lg = ex['graph'].lg
return graph
def batch_graphs(self, ex_list, device, train=True, **kwargs):
""" Batch graphs in example list """
return self.batch_method(ex_list, device, train=train, **kwargs)
def batch_lgesql(self, ex_list, device, train=True, **kwargs):
bg = self.batch_rgatsql(ex_list, device, train=train, **kwargs)
src_ids, dst_ids = bg.local_g.edges(order='eid')
bg.src_ids, bg.dst_ids = src_ids.long(), dst_ids.long()
bg.lg = dgl.batch([ex.graph.lg for ex in ex_list]).to(device)
return bg
def batch_rgatsql(self, ex_list, device, train=True, **kwargs):
# method = kwargs.pop('local_and_nonlocal', 'global')
graph_list = [ex.graph for ex in ex_list]
bg = BatchedGraph()
bg.local_g = dgl.batch([ex.local_g for ex in graph_list]).to(device)
# bg.local_edges = torch.cat([ex.local_edges for ex in graph_list], dim=0).to(device)
bg.local_mask = torch.cat([ex.graph.local_mask for ex in ex_list], dim=0).to(device)
bg.global_g = dgl.batch([ex.global_g for ex in graph_list]).to(device)
bg.global_edges = torch.cat([ex.global_edges for ex in graph_list], dim=0).to(device)
if train:
bg.question_mask = torch.cat([ex.question_mask for ex in graph_list], dim=0).to(device)
bg.schema_mask = torch.cat([ex.schema_mask for ex in graph_list], dim=0).to(device)
smoothing = kwargs.pop('smoothing', 0.0)
node_label = torch.cat([ex.node_label for ex in graph_list], dim=0)
node_label = node_label.masked_fill_(~ node_label.bool(), 2 * smoothing) - smoothing
bg.node_label = node_label.to(device)
bg.gp = dgl.batch([ex.gp for ex in graph_list]).to(device)
return bg

View File

@ -0,0 +1,48 @@
#coding=utf8
import sys, os
EXP_PATH = 'exp'
def hyperparam_path(args):
if args.read_model_path and args.testing:
return args.read_model_path
exp_path = hyperparam_path_text2sql(args)
if not os.path.exists(exp_path):
os.makedirs(exp_path)
return exp_path
def hyperparam_path_text2sql(args):
task = 'task_%s__model_%s_view_%s' % (args.task, args.model, args.local_and_nonlocal)
task += '' if 'without' in args.output_model else '_gp_%s' % (args.smoothing)
# encoder params
exp_path = 'emb_%s' % (args.embed_size) if args.plm is None else 'plm_%s' % (args.plm)
exp_path += '__gnn_%s_x_%s' % (args.gnn_hidden_size, args.gnn_num_layers)
exp_path += '__share' if args.relation_share_layers else ''
exp_path += '__head_%s' % (args.num_heads)
exp_path += '__share' if args.relation_share_heads else ''
exp_path += '__dp_%s' % (args.dropout)
exp_path += '__dpa_%s' % (args.attn_drop)
exp_path += '__dpc_%s' % (args.drop_connect)
# decoder params
# exp_path += '__cell_%s_%s_x_%s' % (args.lstm, args.lstm_hidden_size, args.lstm_num_layers)
# exp_path += '_chunk_%s' % (args.chunk_size) if args.lstm == 'onlstm' else ''
# exp_path += '_no' if args.no_parent_state else ''
# exp_path += '__attvec_%s' % (args.att_vec_size)
# exp_path += '__sepcxt' if args.sep_cxt else '__jointcxt'
# exp_path += '_no' if args.no_context_feeding else ''
# exp_path += '__ae_%s' % (args.action_embed_size)
# exp_path += '_no' if args.no_parent_production_embed else ''
# exp_path += '__fe_%s' % ('no' if args.no_parent_field_embed else args.field_embed_size)
# exp_path += '__te_%s' % ('no' if args.no_parent_field_type_embed else args.type_embed_size)
# training params
exp_path += '__bs_%s' % (args.batch_size)
exp_path += '__lr_%s' % (args.lr) if args.plm is None else '__lr_%s_ld_%s' % (args.lr, args.layerwise_decay)
exp_path += '__l2_%s' % (args.l2)
exp_path += '__wp_%s' % (args.warmup_ratio)
exp_path += '__sd_%s' % (args.lr_schedule)
exp_path += '__me_%s' % (args.max_epoch)
exp_path += '__mn_%s' % (args.max_norm)
exp_path += '__bm_%s' % (args.beam_size)
exp_path += '__seed_%s' % (args.seed)
exp_path = os.path.join(EXP_PATH, task, exp_path)
return exp_path
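
The returned directory name encodes the main hyperparameters; a sketch of what it looks like for the README's lgesql + msde setting (the module path `utils.hyperparams` is an assumption, since this file's header line is not shown in the diff):

```python
# Sketch only: print the experiment directory for a given configuration.
from utils.args import init_args
from utils.hyperparams import hyperparam_path_text2sql  # assumed module path

args = init_args(['--model', 'lgesql', '--local_and_nonlocal', 'msde',
                  '--plm', 'electra-large-discriminator'])
print(hyperparam_path_text2sql(args))
# e.g. exp/task_text2sql__model_lgesql_view_msde/plm_electra-large-discriminator__gnn_256_x_8__head_8__dp_0.2__...
```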

View File

@ -0,0 +1,46 @@
#coding=utf8
""" Utility functions include:
1. set output logging path
2. set random seed for all libs
3. select torch.device
"""
import sys, os, logging
import random, torch, dgl
import numpy as np
def set_logger(exp_path, testing=False):
logFormatter = logging.Formatter('%(asctime)s - %(message)s') #('%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger('mylogger')
logger.setLevel(logging.DEBUG)
if testing:
fileHandler = logging.FileHandler('%s/log_test.txt' % (exp_path), mode='w')
else:
fileHandler = logging.FileHandler('%s/log_train.txt' % (exp_path), mode='w')
fileHandler.setFormatter(logFormatter)
logger.addHandler(fileHandler)
consoleHandler = logging.StreamHandler(sys.stdout)
consoleHandler.setFormatter(logFormatter)
logger.addHandler(consoleHandler)
return logger
def set_random_seed(random_seed=999):
random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(random_seed)
np.random.seed(random_seed)
dgl.random.seed(random_seed)
def set_torch_device(deviceId):
if deviceId < 0:
device = torch.device("cpu")
else:
assert torch.cuda.device_count() >= deviceId + 1
device = torch.device("cuda:%d" % (deviceId))
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # used when debug
## These two lines are used to ensure reproducibility with the cudnn backend
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
return device

View File

@ -0,0 +1,243 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""
import logging
import re, math
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_
from collections import defaultdict
logger = logging.getLogger(__name__)
def set_optimizer(model, args, num_warmup_steps, num_training_steps, last_epoch=-1):
plm = hasattr(model.encoder.input_layer, 'plm_model')
if plm and args.layerwise_decay <= 0.: # fix plm params
for n, p in model.named_parameters():
if 'plm_model' in n:
p.requires_grad = False
params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
no_decay = ['bias', 'LayerNorm.weight']
if plm and 0. < args.layerwise_decay <= 0.5: # separate lr for plm
grouped_params = [
{'params': list(set([p for n, p in params if 'plm_model' in n and not any(nd in n for nd in no_decay)])), 'lr': args.layerwise_decay * args.lr, 'weight_decay': args.l2},
{'params': list(set([p for n, p in params if 'plm_model' in n and any(nd in n for nd in no_decay)])), 'lr': args.layerwise_decay * args.lr, 'weight_decay': 0.0},
{'params': list(set([p for n, p in params if 'plm_model' not in n and not any(nd in n for nd in no_decay)])), 'weight_decay': args.l2},
{'params': list(set([p for n, p in params if 'plm_model' not in n and any(nd in n for nd in no_decay)])), 'weight_decay': 0.0},
]
print('Use separate lr %f for pretrained model ...' % (args.lr * args.layerwise_decay))
elif plm and 0.5 < args.layerwise_decay < 1.: # lr decay layerwise for plm
pattern = r'encoder\.layer\.(.*?)\.'
num_layers = int(model.encoder.input_layer.plm_model.config.num_hidden_layers)
groups = {"decay": defaultdict(list), "no_decay": defaultdict(list)} # record grouped params
for n, p in params:
res = re.search(pattern, n) if 'plm_model' in n else None
depth = int(res.group(1)) if res is not None else 0 if 'plm_model' in n else num_layers
if any(nd in n for nd in no_decay):
groups["no_decay"][int(depth)].append(p)
else:
groups["decay"][int(depth)].append(p)
grouped_params = []
for d in groups["decay"]:
lr = args.lr * (args.layerwise_decay ** (num_layers - d))
grouped_params.append({'params': list(set(groups["decay"][d])), 'lr': lr, 'weight_decay': args.l2})
for d in groups["no_decay"]:
lr = args.lr * (args.layerwise_decay ** (num_layers - d))
grouped_params.append({'params': list(set(groups["no_decay"][d])), 'lr': lr, 'weight_decay': 0.0})
print('Use layerwise decay (rate %f) lr %f for pretrained model ...' % (args.layerwise_decay, args.lr))
else: # the same lr for plm and other modules
grouped_params = [
{'params': list(set([p for n, p in params if not any(nd in n for nd in no_decay)])), 'weight_decay': args.l2},
{'params': list(set([p for n, p in params if any(nd in n for nd in no_decay)])), 'weight_decay': 0.0},
]
print('Use the same lr %f for all parameters ...' % (args.lr))
optimizer = AdamW(grouped_params, lr=args.lr, max_grad_norm=args.max_norm)
schedule_func = schedule_dict[args.lr_schedule]
scheduler = schedule_func(optimizer, num_warmup_steps, num_training_steps, last_epoch=last_epoch)
return optimizer, scheduler
def get_ratsql_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
""" Create a schedule with a learning rate that decreases according to the formular
in RATSQL model
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1.0, num_warmup_steps))
return max(0.0, math.sqrt((num_training_steps - current_step) / float(num_training_steps - num_warmup_steps)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_constant_schedule(optimizer, *args, last_epoch=-1):
""" Create a schedule with a constant learning rate.
"""
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
""" Create a schedule with a constant learning rate preceded by a warmup
period during which the learning rate increases linearly between 0 and 1.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1.0, num_warmup_steps))
return 1.0
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
""" Create a schedule with a learning rate that decreases linearly after
linearly increasing during a warmup period.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
""" Create a schedule with a learning rate that decreases following the
values of the cosine function between 0 and `pi * cycles` after a warmup
period during which it increases linearly between 0 and 1.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer, num_warmup_steps, num_training_steps, num_cycles=1.0, last_epoch=-1
):
""" Create a schedule with a learning rate that decreases following the
values of the cosine function with several hard restarts, after a warmup
period during which it increases linearly between 0 and 1.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
if progress >= 1.0:
return 0.0
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
return LambdaLR(optimizer, lr_lambda, last_epoch)
schedule_dict = {
"constant": get_constant_schedule,
"linear": get_linear_schedule_with_warmup,
"ratsql": get_ratsql_schedule_with_warmup,
"cosine": get_cosine_schedule_with_warmup,
}
class AdamW(Optimizer):
""" Implements Adam algorithm with weight decay fix.
Parameters:
lr (float): learning rate. Default 1e-3.
betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
eps (float): Adam's epsilon. Default: 1e-6
weight_decay (float): Weight decay. Default: 0.0
correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, max_grad_norm=-1, correct_bias=True):
if lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, max_grad_norm=max_grad_norm, correct_bias=correct_bias)
super().__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Add grad clipping
if group['max_grad_norm'] > 0:
clip_grad_norm_(p, group['max_grad_norm'])
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(exp_avg, denom, value=-step_size)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
return loss
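
The schedule factories above return plain `LambdaLR` objects, so they can be sanity-checked on a dummy optimizer; a minimal sketch (the module path `utils.optimization` is an assumption, since this file's header line is not shown in the diff):

```python
# Sketch only: linear warmup for 10 steps, then linear decay to zero over 90 more.
import torch
from utils.optimization import get_linear_schedule_with_warmup  # assumed module path

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)
lrs = []
for _ in range(100):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
print(lrs[9], lrs[99])   # 1.0 at the end of warmup, 0.0 at the final step
```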

70
proton/utils/vocab.py Normal file
View File

@ -0,0 +1,70 @@
#coding=utf8
from utils.constants import PAD, UNK, BOS, EOS
class Vocab():
def __init__(self, padding=False, unk=False, boundary=False, min_freq=1,
filepath=None, iterable=None, default=UNK, specials=[]):
super(Vocab, self).__init__()
self.word2id = dict()
self.id2word = dict()
self.default = default # if default is None, ensure that there are no OOV words
if padding:
idx = len(self.word2id)
self.word2id[PAD], self.id2word[idx] = idx, PAD
if unk:
idx = len(self.word2id)
self.word2id[UNK], self.id2word[idx] = idx, UNK
if boundary:
idx = len(self.word2id)
self.word2id[BOS], self.id2word[idx] = idx, BOS
self.word2id[EOS], self.id2word[idx + 1] = idx + 1, EOS
for w in specials:
if w not in self.word2id:
idx = len(self.word2id)
self.word2id[w], self.id2word[idx] = idx, w
if filepath is not None:
self.from_filepath(filepath, min_freq=min_freq)
elif iterable is not None:
self.from_iterable(iterable)
assert (self.default is None) or (self.default in self.word2id)
def from_filepath(self, filepath, min_freq=1):
with open(filepath, 'r', encoding='utf-8') as inf:
for line in inf:
line = line.strip()
if line == '': continue
line = line.split('\t') # ignore count or frequency
if len(line) == 1:
word, freq = line[0], min_freq
else:
assert len(line) == 2
word, freq = line
word = word.lower()
if word not in self.word2id and int(freq) >= min_freq:
idx = len(self.word2id)
self.word2id[word] = idx
self.id2word[idx] = word
def from_iterable(self, iterable):
for item in iterable:
if item not in self.word2id:
idx = len(self.word2id)
self.word2id[item] = idx
self.id2word[idx] = item
def __len__(self):
return len(self.word2id)
@property
def vocab_size(self):
return len(self.word2id)
def __getitem__(self, key):
""" If self.default is None, it means we do not allow out of vocabulary token;
If self.default is not None, we get the idx of self.default if key does not exist.
"""
if self.default is None:
return self.word2id[key]
else:
return self.word2id.get(key, self.word2id[self.default])
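
A minimal sketch of the lookup semantics described in `__getitem__`:

```python
# Sketch: with a default of UNK, out-of-vocabulary words fall back to the UNK id.
from utils.constants import UNK
from utils.vocab import Vocab

v = Vocab(padding=True, unk=True, iterable=['select', 'count'], default=UNK)
print(len(v))            # 4: PAD, UNK, 'select', 'count'
print(v['select'])       # 2
print(v['nonexistent'])  # 1, the id of UNK
```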

37
proton/utils/word2vec.py Normal file
View File

@ -0,0 +1,37 @@
#coding=utf8
from embeddings import GloveEmbedding
import numpy as np
from utils.constants import PAD
import torch, random
class Word2vecUtils():
def __init__(self):
super(Word2vecUtils, self).__init__()
self.word_embed = GloveEmbedding('common_crawl_48', d_emb=300)
self.initializer = lambda: np.random.normal(size=300).tolist()
def load_embeddings(self, module, vocab, device='cpu'):
""" Initialize the embedding with glove and char embedding
"""
emb_size = module.weight.data.size(-1)
assert emb_size == 300, 'Embedding size is not 300, cannot be initialized by GLOVE'
outliers = 0
for word in vocab.word2id:
if word == PAD: # PAD symbol is always 0-vector
module.weight.data[vocab[PAD]] = torch.zeros(emb_size, dtype=torch.float, device=device)
continue
word_emb = self.word_embed.emb(word, default='none')
if word_emb[0] is None: # oov
word_emb = self.initializer()
outliers += 1
module.weight.data[vocab[word]] = torch.tensor(word_emb, dtype=torch.float, device=device)
return 1 - outliers / float(len(vocab))
def emb(self, word):
word_emb = self.word_embed.emb(word, default='none')
if word_emb[0] is None:
return None
else:
return word_emb