add: proton

parent 890ad4ee51
commit e42cda6d18

@@ -0,0 +1,67 @@
# Proton: Probing Schema Linking Information from Pre-trained Language Models for Text-to-SQL Parsing

This repository contains code for the KDD 2022 paper [Proton: Probing Schema Linking Information from Pre-trained Language Models for Text-to-SQL Parsing](https://arxiv.org/abs/2206.14017).

If you use Proton in your work, please cite it as follows:

```
@inproceedings{wang2022proton,
  title={Proton: Probing Schema Linking Information from Pre-trained Language Models for Text-to-SQL Parsing},
  author={Wang, Lihan and Qin, Bowen and Hui, Binyuan and Li, Bowen and Yang, Min and Wang, Bailin and Li, Binhua and Huang, Fei and Si, Luo and Li, Yongbin},
  booktitle={KDD},
  year={2022}
}
```

## Overview

In this work, we propose a novel framework, called **Proton**, which first probes the underlying relational schema-linking structures between an NL query and its database schema from a pre-trained language model, and then effectively injects them into downstream text-to-SQL parsing models.
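
The snippet below is an illustrative sketch only, not the paper's exact probing procedure: one common way to elicit pairwise linking structure from an off-the-shelf PLM is to mask one input token at a time and measure how far every other token's contextual representation moves. The model name and the perturbation-based score are assumptions for demonstration.

```python
# Illustrative sketch of perturbation-based probing (assumptions: generic
# BERT checkpoint, L2-norm shift as the pairwise score).
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased').eval()

def probe_pairwise_impact(text: str) -> torch.Tensor:
    enc = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        base = model(**enc).last_hidden_state[0]      # (seq_len, hidden)
    n = enc['input_ids'].shape[1]
    scores = torch.zeros(n, n)
    for i in range(1, n - 1):                         # skip [CLS] and [SEP]
        masked = {k: v.clone() for k, v in enc.items()}
        masked['input_ids'][0, i] = tokenizer.mask_token_id
        with torch.no_grad():
            pert = model(**masked).last_hidden_state[0]
        # impact of masking token i on every other position j
        scores[i] = (pert - base).norm(dim=-1)
    return scores

# scores[i, j] could then be thresholded to induce question-schema links
# for a serialized "question | table : col1 , col2" style input.
```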

![Overview](static/proton.jpg)

## Codebase

### Prepare Environment

The environment setup is exactly the same as that of [LGESQL](https://github.com/rhythmcao/text2sql-lgesql). The environment-related commands are provided in `setup.sh`:

```bash
sh setup.sh
```

### Download Dataset

Download, unzip, and rename [spider.zip](https://drive.google.com/uc?export=download&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0) into the directory `data`. Then merge `data/train_spider.json` and `data/train_others.json` into a single dataset `data/train.json`.
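
The merge step can be done with a few lines of Python, for example (a minimal sketch; it assumes the two files already sit under `data/`):

```python
import json

# concatenate the two Spider training splits into a single train.json
with open('data/train_spider.json') as f:
    examples = json.load(f)
with open('data/train_others.json') as f:
    examples += json.load(f)
with open('data/train.json', 'w') as f:
    json.dump(examples, f, indent=4)
```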

### Preprocess Dataset

Preprocess the train and dev datasets, including input normalization, schema linking, graph construction, and output action generation:

```bash
./run/run_preprocessing.sh
```

### Training

Train Proton with:

```bash
./run/run_lgesql_plm.sh msde electra-large-discriminator
```

### Evaluation

For evaluation, see `run/run_evaluation.sh` and `run/run_submission.sh` (evaluation from scratch) for reference.

## Results

| Model           | DK   | SYN  | Spider |
| --------------- | ---- | ---- | ------ |
| LGESQL          | 48.4 | 64.6 | 75.3   |
| LGESQL + Proton | 51.0 | 65.6 | 76.3   |

## Acknowledgements

This implementation is based on [RAT-SQL: Relation-Aware Schema Encoding and Linking for Text-to-SQL Parsers](https://github.com/microsoft/rat-sql) and [LGESQL: Line Graph Enhanced Text-to-SQL Model with Mixed Local and Non-Local Relations](https://github.com/rhythmcao/text2sql-lgesql). Thanks to the authors for releasing their code.
@@ -0,0 +1,66 @@
# coding=utf-8
from asdl.hypothesis import Hypothesis
from asdl.transition_system import ApplyRuleAction, GenTokenAction
from asdl.sql.sql_transition_system import SelectColumnAction, SelectTableAction


class ActionInfo(object):
    """Sufficient statistics for predicting an action at a time step."""

    def __init__(self, action=None):
        self.t = 0
        self.parent_t = -1
        self.action = action
        self.frontier_prod = None
        self.frontier_field = None

        # for GenToken actions only
        self.copy_from_src = False
        self.src_token_position = -1

    def __repr__(self, verbose=False):
        repr_str = '%s (t=%d, p_t=%d, frontier_field=%s)' % (repr(self.action),
                                                             self.t,
                                                             self.parent_t,
                                                             self.frontier_field.__repr__(True) if self.frontier_field else 'None')

        if verbose:
            # attributes such as action_prob and in_vocab are set externally by the decoder
            verbose_repr = 'action_prob=%.4f, ' % self.action_prob
            if isinstance(self.action, GenTokenAction):
                verbose_repr += 'in_vocab=%s, ' \
                                'gen_copy_switch=%s, ' \
                                'p(gen)=%s, p(copy)=%s, ' \
                                'has_copy=%s, copy_pos=%s' % (self.in_vocab,
                                                              self.gen_copy_switch,
                                                              self.gen_token_prob, self.copy_token_prob,
                                                              self.copy_from_src, self.src_token_position)

            repr_str += '\n' + verbose_repr

        return repr_str


def get_action_infos(src_query: list = None, tgt_actions: list = None, force_copy=False):
    # avoid a mutable default argument for tgt_actions
    tgt_actions = tgt_actions if tgt_actions is not None else []
    action_infos = []
    hyp = Hypothesis()
    for t, action in enumerate(tgt_actions):
        action_info = ActionInfo(action)
        action_info.t = t
        if hyp.frontier_node:
            action_info.parent_t = hyp.frontier_node.created_time
            action_info.frontier_prod = hyp.frontier_node.production
            action_info.frontier_field = hyp.frontier_field.field

        if isinstance(action, SelectColumnAction) or isinstance(action, SelectTableAction):
            pass
        elif isinstance(action, GenTokenAction):  # GenToken
            try:
                tok_src_idx = src_query.index(str(action.token))
                action_info.copy_from_src = True
                action_info.src_token_position = tok_src_idx
            except ValueError:
                if force_copy: raise ValueError('cannot copy primitive token %s from source' % action.token)

        hyp.apply_action(action)
        action_infos.append(action_info)

    return action_infos
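
# Hypothetical usage: given a tokenized question and the oracle action sequence
# produced by the transition system, e.g.
#   action_infos = get_action_infos(src_query=['show', 'all', 'singers'],
#                                   tgt_actions=oracle_actions)
# each ActionInfo records its time step, parent time step, frontier field, and
# whether the generated token can be copied from the input question.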
@@ -0,0 +1,323 @@
# coding=utf-8
from collections import OrderedDict
from itertools import chain
import re, os


def remove_comment(text):
    text = re.sub(re.compile("#.*"), "", text)
    text = '\n'.join(filter(lambda x: x, text.split('\n')))
    return text


class ASDLGrammar(object):
    """
    Collection of types, constructors and productions
    """
    def __init__(self, productions, file_path):
        # productions are indexed by their head types
        file_name = os.path.basename(file_path)
        grammar_name = file_name[:file_name.index('.txt')] if '.txt' in file_name else file_name
        self._grammar_name = grammar_name
        self._productions = OrderedDict()
        self._constructor_production_map = dict()
        for prod in productions:
            if prod.type not in self._productions:
                self._productions[prod.type] = list()
            self._productions[prod.type].append(prod)
            self._constructor_production_map[prod.constructor.name] = prod

        self.root_type = productions[0].type
        # number of constructors
        self.size = sum(len(head) for head in self._productions.values())

        # get entities to their ids map
        self.prod2id = {prod: i for i, prod in enumerate(self.productions)}
        self.type2id = {type: i for i, type in enumerate(self.types)}
        self.field2id = {field: i for i, field in enumerate(self.fields)}

        self.id2prod = {i: prod for i, prod in enumerate(self.productions)}
        self.id2type = {i: type for i, type in enumerate(self.types)}
        self.id2field = {i: field for i, field in enumerate(self.fields)}

    def __len__(self):
        return self.size

    @property
    def productions(self):
        return sorted(chain.from_iterable(self._productions.values()), key=lambda x: repr(x))

    def __getitem__(self, datum):
        if isinstance(datum, str):
            return self._productions[ASDLType(datum)]
        elif isinstance(datum, ASDLType):
            return self._productions[datum]

    def get_prod_by_ctr_name(self, name):
        return self._constructor_production_map[name]

    @property
    def types(self):
        if not hasattr(self, '_types'):
            all_types = set()
            for prod in self.productions:
                all_types.add(prod.type)
                all_types.update(map(lambda x: x.type, prod.constructor.fields))

            self._types = sorted(all_types, key=lambda x: x.name)

        return self._types

    @property
    def fields(self):
        if not hasattr(self, '_fields'):
            all_fields = set()
            for prod in self.productions:
                all_fields.update(prod.constructor.fields)

            self._fields = sorted(all_fields, key=lambda x: (x.name, x.type.name, x.cardinality))

        return self._fields

    @property
    def primitive_types(self):
        return filter(lambda x: isinstance(x, ASDLPrimitiveType), self.types)

    @property
    def composite_types(self):
        return filter(lambda x: isinstance(x, ASDLCompositeType), self.types)

    def is_composite_type(self, asdl_type):
        return asdl_type in self.composite_types

    def is_primitive_type(self, asdl_type):
        return asdl_type in self.primitive_types

    @staticmethod
    def from_filepath(file_path):
        def _parse_field_from_text(_text):
            d = _text.strip().split(' ')
            name = d[1].strip()
            type_str = d[0].strip()
            cardinality = 'single'
            if type_str[-1] == '*':
                type_str = type_str[:-1]
                cardinality = 'multiple'
            elif type_str[-1] == '?':
                type_str = type_str[:-1]
                cardinality = 'optional'

            if type_str in primitive_type_names:
                return Field(name, ASDLPrimitiveType(type_str), cardinality=cardinality)
            else:
                return Field(name, ASDLCompositeType(type_str), cardinality=cardinality)

        def _parse_constructor_from_text(_text):
            _text = _text.strip()
            fields = None
            if '(' in _text:
                name = _text[:_text.find('(')]
                field_blocks = _text[_text.find('(') + 1:_text.find(')')].split(',')
                fields = map(_parse_field_from_text, field_blocks)
            else:
                name = _text

            if name == '': name = None

            return ASDLConstructor(name, fields)

        with open(file_path, 'r') as inf:
            text = inf.read()
        lines = remove_comment(text).split('\n')
        lines = list(map(lambda l: l.strip(), lines))
        lines = list(filter(lambda l: l, lines))
        line_no = 0

        # first line is always the primitive types
        primitive_type_names = list(map(lambda x: x.strip(), lines[line_no].split(',')))
        line_no += 1

        all_productions = list()

        while True:
            type_block = lines[line_no]
            type_name = type_block[:type_block.find('=')].strip()
            constructors_blocks = type_block[type_block.find('=') + 1:].split('|')
            i = line_no + 1
            while i < len(lines) and lines[i].strip().startswith('|'):
                t = lines[i].strip()
                cont_constructors_blocks = t[1:].split('|')
                constructors_blocks.extend(cont_constructors_blocks)

                i += 1

            constructors_blocks = filter(lambda x: x and x.strip(), constructors_blocks)

            # parse type name
            new_type = ASDLPrimitiveType(type_name) if type_name in primitive_type_names else ASDLCompositeType(type_name)
            constructors = map(_parse_constructor_from_text, constructors_blocks)

            productions = list(map(lambda c: ASDLProduction(new_type, c), constructors))
            all_productions.extend(productions)

            line_no = i
            if line_no == len(lines):
                break

        grammar = ASDLGrammar(all_productions, file_path)

        return grammar


class ASDLProduction(object):
    def __init__(self, type, constructor):
        self.type = type
        self.constructor = constructor

    @property
    def fields(self):
        return self.constructor.fields

    def __getitem__(self, field_name):
        return self.constructor[field_name]

    def __hash__(self):
        h = hash(self.type) ^ hash(self.constructor)

        return h

    def __eq__(self, other):
        return isinstance(other, ASDLProduction) and \
               self.type == other.type and \
               self.constructor == other.constructor

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return '%s -> %s' % (self.type.__repr__(plain=True), self.constructor.__repr__(plain=True))


class ASDLConstructor(object):
    def __init__(self, name, fields=None):
        self.name = name
        self.fields = []
        if fields:
            self.fields = list(fields)

    def __getitem__(self, field_name):
        for field in self.fields:
            if field.name == field_name: return field

        raise KeyError

    def __hash__(self):
        h = hash(self.name)
        for field in self.fields:
            h ^= hash(field)

        return h

    def __eq__(self, other):
        return isinstance(other, ASDLConstructor) and \
               self.name == other.name and \
               self.fields == other.fields

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self, plain=False):
        plain_repr = '%s(%s)' % (self.name,
                                 ', '.join(f.__repr__(plain=True) for f in self.fields))
        if plain: return plain_repr
        else: return 'Constructor(%s)' % plain_repr


class Field(object):
    def __init__(self, name, type, cardinality):
        self.name = name
        self.type = type

        assert cardinality in ['single', 'optional', 'multiple']
        self.cardinality = cardinality

    def __hash__(self):
        h = hash(self.name) ^ hash(self.type)
        h ^= hash(self.cardinality)

        return h

    def __eq__(self, other):
        return isinstance(other, Field) and \
               self.name == other.name and \
               self.type == other.type and \
               self.cardinality == other.cardinality

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self, plain=False):
        plain_repr = '%s%s %s' % (self.type.__repr__(plain=True),
                                  Field.get_cardinality_repr(self.cardinality),
                                  self.name)
        if plain: return plain_repr
        else: return 'Field(%s)' % plain_repr

    @staticmethod
    def get_cardinality_repr(cardinality):
        return '' if cardinality == 'single' else '?' if cardinality == 'optional' else '*'


class ASDLType(object):
    def __init__(self, type_name):
        self.name = type_name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return isinstance(other, ASDLType) and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self, plain=False):
        plain_repr = self.name
        if plain: return plain_repr
        else: return '%s(%s)' % (self.__class__.__name__, plain_repr)


class ASDLCompositeType(ASDLType):
    pass


class ASDLPrimitiveType(ASDLType):
    pass


if __name__ == '__main__':
    asdl_desc = """
    var, ent, num, var_type

    expr = Variable(var variable)
    | Entity(ent entity)
    | Number(num number)
    | Apply(pred predicate, expr* arguments)
    | Argmax(var variable, expr domain, expr body)
    | Argmin(var variable, expr domain, expr body)
    | Count(var variable, expr body)
    | Exists(var variable, expr body)
    | Lambda(var variable, var_type type, expr body)
    | Max(var variable, expr body)
    | Min(var variable, expr body)
    | Sum(var variable, expr domain, expr body)
    | The(var variable, expr body)
    | Not(expr argument)
    | And(expr* arguments)
    | Or(expr* arguments)
    | Compare(cmp_op op, expr left, expr right)

    cmp_op = GreaterThan | Equal | LessThan
    """

    # ASDLGrammar only exposes from_filepath, so dump the grammar description
    # to a temporary file before parsing it
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp_file:
        tmp_file.write(asdl_desc)
    grammar = ASDLGrammar.from_filepath(tmp_file.name)
    print(ASDLCompositeType('1') == ASDLPrimitiveType('1'))
@@ -0,0 +1,200 @@
# coding=utf-8
from io import StringIO
from asdl.asdl import *

class AbstractSyntaxTree(object):
    def __init__(self, production, realized_fields=None):
        self.production = production

        # a child is essentially a *realized_field*
        self.fields = []

        # record its parent field to which it's attached
        self.parent_field = None

        # used in decoding, record the time step when this node was created
        self.created_time = 0

        if realized_fields:
            assert len(realized_fields) == len(self.production.fields)

            for field in realized_fields:
                self.add_child(field)
        else:
            for field in self.production.fields:
                self.add_child(RealizedField(field))

    def add_child(self, realized_field):
        # if isinstance(realized_field.value, AbstractSyntaxTree):
        #     realized_field.value.parent = self
        self.fields.append(realized_field)
        realized_field.parent_node = self

    def __getitem__(self, field_name):
        for field in self.fields:
            if field.name == field_name: return field
        raise KeyError

    def sanity_check(self):
        if len(self.production.fields) != len(self.fields):
            raise ValueError('field number must match')
        for field, realized_field in zip(self.production.fields, self.fields):
            assert field == realized_field.field
        for child in self.fields:
            for child_val in child.as_value_list:
                if isinstance(child_val, AbstractSyntaxTree):
                    child_val.sanity_check()

    def copy(self):
        new_tree = AbstractSyntaxTree(self.production)
        new_tree.created_time = self.created_time
        for i, old_field in enumerate(self.fields):
            new_field = new_tree.fields[i]
            new_field._not_single_cardinality_finished = old_field._not_single_cardinality_finished
            if isinstance(old_field.type, ASDLCompositeType):
                for value in old_field.as_value_list:
                    new_field.add_value(value.copy())
            else:
                for value in old_field.as_value_list:
                    new_field.add_value(value)

        return new_tree

    def to_string(self, sb=None):
        is_root = False
        if sb is None:
            is_root = True
            sb = StringIO()

        sb.write('(')
        sb.write(self.production.constructor.name)

        for field in self.fields:
            sb.write(' ')
            sb.write('(')
            sb.write(field.type.name)
            sb.write(Field.get_cardinality_repr(field.cardinality))
            sb.write('-')
            sb.write(field.name)

            if field.value is not None:
                for val_node in field.as_value_list:
                    sb.write(' ')
                    if isinstance(field.type, ASDLCompositeType):
                        val_node.to_string(sb)
                    else:
                        sb.write(str(val_node).replace(' ', '-SPACE-'))

            sb.write(')')  # of field

        sb.write(')')  # of node

        if is_root:
            return sb.getvalue()

    def __hash__(self):
        code = hash(self.production)
        for field in self.fields:
            code = code + 37 * hash(field)

        return code

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        if self.created_time != other.created_time:
            return False

        if self.production != other.production:
            return False

        if len(self.fields) != len(other.fields):
            return False

        for i in range(len(self.fields)):
            if self.fields[i] != other.fields[i]: return False

        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return repr(self.production)

    @property
    def size(self):
        node_num = 1
        for field in self.fields:
            for val in field.as_value_list:
                if isinstance(val, AbstractSyntaxTree):
                    node_num += val.size
                else: node_num += 1

        return node_num


class RealizedField(Field):
    """wrapper of field realized with values"""
    def __init__(self, field, value=None, parent=None):
        super(RealizedField, self).__init__(field.name, field.type, field.cardinality)

        # record its parent AST node
        self.parent_node = None

        # FIXME: hack, return the field as a property
        self.field = field

        # initialize value to correct type
        if self.cardinality == 'multiple':
            self.value = []
            if value is not None:
                for child_node in value:
                    self.add_value(child_node)
        else:
            self.value = None
            # note the value could be 0!
            if value is not None: self.add_value(value)

        # properties only used in decoding, record if the field is finished generating
        # when card in [optional, multiple]
        self._not_single_cardinality_finished = False

    def add_value(self, value):
        if isinstance(value, AbstractSyntaxTree):
            value.parent_field = self

        if self.cardinality == 'multiple':
            self.value.append(value)
        else:
            self.value = value

    @property
    def as_value_list(self):
        """get value as an iterable"""
        if self.cardinality == 'multiple': return self.value
        elif self.value is not None: return [self.value]
        else: return []

    @property
    def finished(self):
        if self.cardinality == 'single':
            if self.value is None: return False
            else: return True
        elif self.cardinality == 'optional' and self.value is not None:
            return True
        else:
            if self._not_single_cardinality_finished: return True
            else: return False

    def set_finish(self):
        # assert self.cardinality in ('optional', 'multiple')
        self._not_single_cardinality_finished = True

    def __eq__(self, other):
        if super(RealizedField, self).__eq__(other):
            if type(other) == Field: return True  # FIXME: hack, Field and RealizedField can compare!
            if self.value == other.value: return True
            else: return False
        else: return False
@@ -0,0 +1,37 @@
# coding=utf-8

from asdl.asdl import *
from asdl.hypothesis import Hypothesis
from asdl.transition_system import *


class DecodeHypothesis(Hypothesis):
    def __init__(self):
        super(DecodeHypothesis, self).__init__()

        self.action_infos = []
        self.code = None

    def clone_and_apply_action_info(self, action_info):
        action = action_info.action

        new_hyp = self.clone_and_apply_action(action)
        new_hyp.action_infos.append(action_info)

        return new_hyp

    def copy(self):
        new_hyp = DecodeHypothesis()
        if self.tree:
            new_hyp.tree = self.tree.copy()

        new_hyp.actions = list(self.actions)
        new_hyp.action_infos = list(self.action_infos)
        new_hyp.score = self.score
        new_hyp._value_buffer = list(self._value_buffer)
        new_hyp.t = self.t
        new_hyp.code = self.code

        new_hyp.update_frontier_info()

        return new_hyp
@@ -0,0 +1,122 @@
# coding=utf-8

from asdl.asdl import *
from asdl.asdl_ast import AbstractSyntaxTree
from asdl.transition_system import *

class Hypothesis(object):
    def __init__(self):
        self.tree = None
        self.actions = []
        self.score = 0.
        self.frontier_node = None
        self.frontier_field = None
        self._value_buffer = []

        # record the current time step
        self.t = 0

    def apply_action(self, action):
        if self.tree is None:  # the first action
            assert isinstance(action, ApplyRuleAction), 'Invalid action [%s], only ApplyRule action is valid ' \
                                                        'at the beginning of decoding' % action

            self.tree = AbstractSyntaxTree(action.production)
            self.update_frontier_info()
        elif self.frontier_node:
            if isinstance(self.frontier_field.type, ASDLCompositeType):
                if isinstance(action, ApplyRuleAction):
                    field_value = AbstractSyntaxTree(action.production)
                    field_value.created_time = self.t
                    self.frontier_field.add_value(field_value)
                    self.update_frontier_info()
                elif isinstance(action, ReduceAction):
                    assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \
                                                                                        'applied on field with optional ' \
                                                                                        'or multiple cardinality'
                    self.frontier_field.set_finish()
                    self.update_frontier_info()
                else:
                    raise ValueError('Invalid action [%s] on field [%s]' % (action, self.frontier_field))
            else:  # fill in a primitive field
                if isinstance(action, GenTokenAction):
                    # only field of type string requires termination signal </primitive>
                    end_primitive = False
                    if self.frontier_field.type.name == 'string':
                        if action.is_stop_signal():
                            self.frontier_field.add_value(' '.join(self._value_buffer))
                            self._value_buffer = []

                            end_primitive = True
                        else:
                            self._value_buffer.append(action.token)
                    else:
                        self.frontier_field.add_value(action.token)
                        end_primitive = True

                    if end_primitive and self.frontier_field.cardinality in ('single', 'optional'):
                        self.frontier_field.set_finish()
                        self.update_frontier_info()

                elif isinstance(action, ReduceAction):
                    assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \
                                                                                        'applied on field with optional ' \
                                                                                        'or multiple cardinality'
                    self.frontier_field.set_finish()
                    self.update_frontier_info()
                else:
                    raise ValueError('Can only invoke GenToken or Reduce actions on primitive fields')

        self.t += 1
        self.actions.append(action)
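
    # Example trace: decoding always starts with an ApplyRuleAction that
    # instantiates the root production; every subsequent ApplyRule, GenToken
    # or Reduce action fills the left-most unrealized frontier field, until
    # no frontier field remains and the hypothesis is completed.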

    def update_frontier_info(self):
        def _find_frontier_node_and_field(tree_node):
            # return None if each field of this ast node is realized, else (unfinished ast node, unrealized field)
            if tree_node:
                for field in tree_node.fields:
                    # if it's an intermediate node, check its children
                    if isinstance(field.type, ASDLCompositeType) and field.value:
                        if field.cardinality in ('single', 'optional'): iter_values = [field.value]
                        else: iter_values = field.value

                        for child_node in iter_values:
                            result = _find_frontier_node_and_field(child_node)
                            if result: return result

                    # now all its possible children are checked
                    if not field.finished:
                        return tree_node, field

                return None
            else: return None

        frontier_info = _find_frontier_node_and_field(self.tree)
        if frontier_info:
            self.frontier_node, self.frontier_field = frontier_info
        else:
            self.frontier_node, self.frontier_field = None, None

    def clone_and_apply_action(self, action):
        new_hyp = self.copy()
        new_hyp.apply_action(action)

        return new_hyp

    def copy(self):
        new_hyp = Hypothesis()
        if self.tree:
            new_hyp.tree = self.tree.copy()

        new_hyp.actions = list(self.actions)
        new_hyp.score = self.score
        new_hyp._value_buffer = list(self._value_buffer)
        new_hyp.t = self.t

        new_hyp.update_frontier_info()

        return new_hyp

    @property
    def completed(self):
        return self.tree and self.frontier_field is None
@@ -0,0 +1,103 @@
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except

# val: value(float/string)/sql(dict)/col_unit(tuple)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, tab_id/sql)
# cond_unit: (not_op(bool), cmp_op, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/integer
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
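#
# Example: "SELECT count(*) FROM singer" is encoded roughly as
#   {'select': (False, [(3, (0, (0, 0, False), None))]),   # agg_id 3 = count, col_id 0 = *
#    'from': {'table_units': [('table_unit', 1)], 'conds': []},  # table id 1 assumed for illustration
#    'where': [], 'groupBy': [], 'orderBy': [], 'having': [],
#    'limit': None, 'intersect': None, 'union': None, 'except': None}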

# CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
# JOIN_KEYWORDS = ('join', 'on', 'as')
# CMP_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
# UNIT_OPS = ('none', '-', '+', "*", '/')
# AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
# TABLE_TYPE = ('sql', 'table_unit')
# COND_OPS = ('and', 'or')
# SQL_OPS = ('intersect', 'union', 'except')
# ORDER_OPS = ('desc', 'asc')

##########################################################################

# 1. eliminate ? by enumerating different number of items
# 2. from conds, distinct and value generation are also not considered
# 3. for select items, we use val_unit instead of (agg, val_unit)
# 4. for orderby items, we use col_unit instead of val_unit
# 5. for groupby items, we use col_unit instead of col_id
# 6. predict from clause first, to obtain an overview of the entire query graph

col_id, tab_id

sql = Intersect(sql_unit sql_unit, sql_unit sql_unit)
    | Union(sql_unit sql_unit, sql_unit sql_unit)
    | Except(sql_unit sql_unit, sql_unit sql_unit)
    | Single(sql_unit sql_unit)

sql_unit = Complete(from from_clause, val_unit* select_clause, cond where_clause, group_by group_by_clause, order_by order_by_clause)
    | NoWhere(from from_clause, val_unit* select_clause, group_by group_by_clause, order_by order_by_clause)
    | NoGroupBy(from from_clause, val_unit* select_clause, cond where_clause, order_by order_by_clause)
    | NoOrderBy(from from_clause, val_unit* select_clause, cond where_clause, group_by group_by_clause)
    | OnlyWhere(from from_clause, val_unit* select_clause, cond where_clause)
    | OnlyGroupBy(from from_clause, val_unit* select_clause, group_by group_by_clause)
    | OnlyOrderBy(from from_clause, val_unit* select_clause, order_by order_by_clause)
    | Simple(from from_clause, val_unit* select_clause)

from = FromTable(tab_id* tab_id_list)
    | FromSQL(sql from_sql)

group_by = Having(col_unit* col_unit_list, cond having_clause)
    | NoHaving(col_unit* col_unit_list)

order_by = Asc(col_unit* col_unit_list)
    | Desc(col_unit* col_unit_list)
    | AscLimit(col_unit* col_unit_list)
    | DescLimit(col_unit* col_unit_list)

cond = And(cond left, cond right)
    | Or(cond left, cond right)
    | Between(val_unit val_unit)
    | Eq(val_unit val_unit)
    | Gt(val_unit val_unit)
    | Lt(val_unit val_unit)
    | Ge(val_unit val_unit)
    | Le(val_unit val_unit)
    | Neq(val_unit val_unit)
    | Like(val_unit val_unit)
    | NotLike(val_unit val_unit)
    | BetweenSQL(val_unit val_unit, sql cond_sql)
    | EqSQL(val_unit val_unit, sql cond_sql)
    | GtSQL(val_unit val_unit, sql cond_sql)
    | LtSQL(val_unit val_unit, sql cond_sql)
    | GeSQL(val_unit val_unit, sql cond_sql)
    | LeSQL(val_unit val_unit, sql cond_sql)
    | NeqSQL(val_unit val_unit, sql cond_sql)
    | InSQL(val_unit val_unit, sql cond_sql)
    | NotInSQL(val_unit val_unit, sql cond_sql)

val_unit = Unary(col_unit col_unit)
    | Minus(col_unit col_unit, col_unit col_unit)
    | Plus(col_unit col_unit, col_unit col_unit)
    | Times(col_unit col_unit, col_unit col_unit)
    | Divide(col_unit col_unit, col_unit col_unit)

col_unit = None(col_id col_id)
    | Max(col_id col_id)
    | Min(col_id col_id)
    | Count(col_id col_id)
    | Sum(col_id col_id)
    | Avg(col_id col_id)
@@ -0,0 +1,111 @@
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except

# val: value(float/string)/sql(dict)/col_unit(tuple)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, tab_id/sql)
# cond_unit: (not_op(bool), cmp_op, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/integer
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }

# CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
# JOIN_KEYWORDS = ('join', 'on', 'as')
# CMP_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
# UNIT_OPS = ('none', '-', '+', "*", '/')
# AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
# TABLE_TYPE = ('sql', 'table_unit')
# COND_OPS = ('and', 'or')
# SQL_OPS = ('intersect', 'union', 'except')
# ORDER_OPS = ('desc', 'asc')

##########################################################################

# 1. eliminate * by enumerating different number of items
# 2. from conds, distinct and value generation are also not considered
# 3. for select items, we use val_unit instead of (agg, val_unit)
# 4. for orderby items, we use col_unit instead of val_unit
# 5. for groupby items, we use col_unit instead of col_id
# 6. predict from clause first, to obtain an overview of the entire query graph

col_id, tab_id

sql = Intersect(sql_unit sql_unit, sql_unit sql_unit)
    | Union(sql_unit sql_unit, sql_unit sql_unit)
    | Except(sql_unit sql_unit, sql_unit sql_unit)
    | Single(sql_unit sql_unit)

sql_unit = SQL(from from_clause, select select_clause, cond? where_clause, group_by? group_by_clause, order_by? order_by_clause)

select = SelectOne(val_unit val_unit)
    | SelectTwo(val_unit val_unit, val_unit val_unit)
    | SelectThree(val_unit val_unit, val_unit val_unit, val_unit val_unit)
    | SelectFour(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit)
    | SelectFive(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit)

from = FromOneTable(tab_id tab_id)
    | FromTwoTable(tab_id tab_id, tab_id tab_id)
    | FromThreeTable(tab_id tab_id, tab_id tab_id, tab_id tab_id)
    | FromFourTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
    | FromFiveTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
    | FromSixTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
    | FromSQL(sql from_sql)

group_by = GroupByOne(col_unit col_unit, cond? having_clause)
    | GroupByTwo(col_unit col_unit, col_unit col_unit, cond? having_clause)

order_by = OneAsc(col_unit col_unit)
    | OneDesc(col_unit col_unit)
    | OneAscLimit(col_unit col_unit)
    | OneDescLimit(col_unit col_unit)
    | TwoAsc(col_unit col_unit, col_unit col_unit)
    | TwoDesc(col_unit col_unit, col_unit col_unit)
    | TwoAscLimit(col_unit col_unit, col_unit col_unit)
    | TwoDescLimit(col_unit col_unit, col_unit col_unit)

cond = And(cond left, cond right)
    | Or(cond left, cond right)
    | Between(val_unit val_unit)
    | Eq(val_unit val_unit)
    | Gt(val_unit val_unit)
    | Lt(val_unit val_unit)
    | Ge(val_unit val_unit)
    | Le(val_unit val_unit)
    | Neq(val_unit val_unit)
    | Like(val_unit val_unit)
    | NotLike(val_unit val_unit)
    | BetweenSQL(val_unit val_unit, sql cond_sql)
    | EqSQL(val_unit val_unit, sql cond_sql)
    | GtSQL(val_unit val_unit, sql cond_sql)
    | LtSQL(val_unit val_unit, sql cond_sql)
    | GeSQL(val_unit val_unit, sql cond_sql)
    | LeSQL(val_unit val_unit, sql cond_sql)
    | NeqSQL(val_unit val_unit, sql cond_sql)
    | InSQL(val_unit val_unit, sql cond_sql)
    | NotInSQL(val_unit val_unit, sql cond_sql)

val_unit = Unary(col_unit col_unit)
    | Minus(col_unit col_unit, col_unit col_unit)
    | Plus(col_unit col_unit, col_unit col_unit)
    | Times(col_unit col_unit, col_unit col_unit)
    | Divide(col_unit col_unit, col_unit col_unit)

col_unit = None(col_id col_id)
    | Max(col_id col_id)
    | Min(col_id col_id)
    | Count(col_id col_id)
    | Sum(col_id col_id)
    | Avg(col_id col_id)
@@ -0,0 +1,120 @@
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except

# val: value(float/string)/sql(dict)/col_unit(tuple)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, tab_id/sql)
# cond_unit: (not_op(bool), cmp_op, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/integer
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }

# CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
# JOIN_KEYWORDS = ('join', 'on', 'as')
# CMP_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
# UNIT_OPS = ('none', '-', '+', "*", '/')
# AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
# TABLE_TYPE = ('sql', 'table_unit')
# COND_OPS = ('and', 'or')
# SQL_OPS = ('intersect', 'union', 'except')
# ORDER_OPS = ('desc', 'asc')

##########################################################################

# 1. eliminate both ? and * cardinality by enumerating different number of items
# 2. from conds, distinct and value generation are also not considered
# 3. for select items, we use val_unit instead of (agg, val_unit)
# 4. for orderby items, we use col_unit instead of val_unit
# 5. for groupby items, we use col_unit instead of col_id
# 6. predict from clause first, to obtain an overview of the entire query graph

col_id, tab_id

sql = Intersect(sql_unit sql_unit, sql_unit sql_unit)
    | Union(sql_unit sql_unit, sql_unit sql_unit)
    | Except(sql_unit sql_unit, sql_unit sql_unit)
    | Single(sql_unit sql_unit)

sql_unit = Complete(from from_clause, select select_clause, cond where_clause, group_by group_by_clause, order_by order_by_clause)
    | NoWhere(from from_clause, select select_clause, group_by group_by_clause, order_by order_by_clause)
    | NoGroupBy(from from_clause, select select_clause, cond where_clause, order_by order_by_clause)
    | NoOrderBy(from from_clause, select select_clause, cond where_clause, group_by group_by_clause)
    | OnlyWhere(from from_clause, select select_clause, cond where_clause)
    | OnlyGroupBy(from from_clause, select select_clause, group_by group_by_clause)
    | OnlyOrderBy(from from_clause, select select_clause, order_by order_by_clause)
    | Simple(from from_clause, select select_clause)

select = SelectOne(val_unit val_unit)
    | SelectTwo(val_unit val_unit, val_unit val_unit)
    | SelectThree(val_unit val_unit, val_unit val_unit, val_unit val_unit)
    | SelectFour(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit)
    | SelectFive(val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit, val_unit val_unit)

from = FromOneTable(tab_id tab_id)
    | FromTwoTable(tab_id tab_id, tab_id tab_id)
    | FromThreeTable(tab_id tab_id, tab_id tab_id, tab_id tab_id)
    | FromFourTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
    | FromFiveTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
    | FromSixTable(tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id, tab_id tab_id)
    | FromSQL(sql from_sql)

group_by = OneNoHaving(col_unit col_unit)
    | TwoNoHaving(col_unit col_unit, col_unit col_unit)
    | OneHaving(col_unit col_unit, cond having_clause)
    | TwoHaving(col_unit col_unit, col_unit col_unit, cond having_clause)

order_by = OneAsc(col_unit col_unit)
    | OneDesc(col_unit col_unit)
    | OneAscLimit(col_unit col_unit)
    | OneDescLimit(col_unit col_unit)
    | TwoAsc(col_unit col_unit, col_unit col_unit)
    | TwoDesc(col_unit col_unit, col_unit col_unit)
    | TwoAscLimit(col_unit col_unit, col_unit col_unit)
    | TwoDescLimit(col_unit col_unit, col_unit col_unit)

cond = And(cond left, cond right)
    | Or(cond left, cond right)
    | Between(val_unit val_unit)
    | Eq(val_unit val_unit)
    | Gt(val_unit val_unit)
    | Lt(val_unit val_unit)
    | Ge(val_unit val_unit)
    | Le(val_unit val_unit)
    | Neq(val_unit val_unit)
    | Like(val_unit val_unit)
    | NotLike(val_unit val_unit)
    | BetweenSQL(val_unit val_unit, sql cond_sql)
    | EqSQL(val_unit val_unit, sql cond_sql)
    | GtSQL(val_unit val_unit, sql cond_sql)
    | LtSQL(val_unit val_unit, sql cond_sql)
    | GeSQL(val_unit val_unit, sql cond_sql)
    | LeSQL(val_unit val_unit, sql cond_sql)
    | NeqSQL(val_unit val_unit, sql cond_sql)
    | InSQL(val_unit val_unit, sql cond_sql)
    | NotInSQL(val_unit val_unit, sql cond_sql)

val_unit = Unary(col_unit col_unit)
    | Minus(col_unit col_unit, col_unit col_unit)
    | Plus(col_unit col_unit, col_unit col_unit)
    | Times(col_unit col_unit, col_unit col_unit)
    | Divide(col_unit col_unit, col_unit col_unit)

col_unit = None(col_id col_id)
    | Max(col_id col_id)
    | Min(col_id col_id)
    | Count(col_id col_id)
    | Sum(col_id col_id)
    | Avg(col_id col_id)
@@ -0,0 +1,175 @@
#coding=utf8

from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree

class Parser():
    """ Parse a sql dict into an AbstractSyntaxTree object according to the specified grammar rules.
    Some common methods are implemented in this parent class.
    """
    def __init__(self, grammar: ASDLGrammar):
        super(Parser, self).__init__()
        self.grammar = grammar

    @classmethod
    def from_grammar(cls, grammar: ASDLGrammar):
        grammar_name = grammar._grammar_name
        if 'v0' in grammar_name:
            from asdl.sql.parser.parser_v0 import ParserV0
            return ParserV0(grammar)
        elif 'v1' in grammar_name:
            from asdl.sql.parser.parser_v1 import ParserV1
            return ParserV1(grammar)
        elif 'v2' in grammar_name:
            from asdl.sql.parser.parser_v2 import ParserV2
            return ParserV2(grammar)
        else:
            raise ValueError('Unrecognized grammar name %s' % (grammar_name))

    def parse(self, sql_json: dict):
        """ sql_json is exactly the 'sql' field of each data sample;
        returns the AbstractSyntaxTree of the sql
        """
        try:
            ast_node = self.parse_sql(sql_json)
            return ast_node
        except Exception as e:
            print('Error happened while parsing:', e)
            # if parsing fails, just return select * from table(id=0)
            error_sql = {
                "select": [False, [(0, [0, [0, 0, False], None])]],
                "from": {'table_units': [('table_unit', 0)], 'conds': []},
                "where": [], "groupBy": [], "orderBy": [], "having": [], "limit": None,
                "intersect": [], "union": [], "except": []
            }
            ast_node = self.parse_sql(error_sql)
            return ast_node

    def parse_sql(self, sql: dict):
        """ Determine whether sql has intersect/union/except;
        at most one of them appears in the current dict
        """
        for choice in ['intersect', 'union', 'except']:
            if sql[choice]:
                ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(choice.title()))
                nested_sql = sql[choice]
                sql_field1, sql_field2 = ast_node.fields
                sql_field1.add_value(self.parse_sql_unit(sql))
                sql_field2.add_value(self.parse_sql_unit(nested_sql))
                return ast_node
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Single'))
        ast_node.fields[0].add_value(self.parse_sql_unit(sql))
        return ast_node

    def parse_sql_unit(self, sql: dict):
        """ Parse a single sql unit; determine the existence of the different clauses
        """
        sql_ctr = ['Complete', 'NoWhere', 'NoGroupBy', 'NoOrderBy', 'OnlyWhere', 'OnlyGroupBy', 'OnlyOrderBy', 'Simple']
        where_field, groupby_field, orderby_field = [None] * 3
        if sql['where'] and sql['groupBy'] and sql['orderBy']:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[0]))
            from_field, select_field, where_field, groupby_field, orderby_field = ast_node.fields
        elif sql['groupBy'] and sql['orderBy']:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[1]))
            from_field, select_field, groupby_field, orderby_field = ast_node.fields
        elif sql['where'] and sql['orderBy']:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[2]))
            from_field, select_field, where_field, orderby_field = ast_node.fields
        elif sql['where'] and sql['groupBy']:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[3]))
            from_field, select_field, where_field, groupby_field = ast_node.fields
        elif sql['where']:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[4]))
            from_field, select_field, where_field = ast_node.fields
        elif sql['groupBy']:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[5]))
            from_field, select_field, groupby_field = ast_node.fields
        elif sql['orderBy']:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[6]))
            from_field, select_field, orderby_field = ast_node.fields
        else:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(sql_ctr[7]))
            from_field, select_field = ast_node.fields
        self.parse_from(sql['from'], from_field)
        self.parse_select(sql['select'], select_field)
        if sql['where']:
            self.parse_where(sql['where'], where_field)
        if sql['groupBy']:  # if having clause is not empty, groupBy must exist
            self.parse_groupby(sql['groupBy'], sql['having'], groupby_field)
        if sql['orderBy']:  # if limit is not None, orderBy is not empty
            self.parse_orderby(sql['orderBy'], sql['limit'], orderby_field)
        return ast_node

    def parse_select(self, select_clause: list, select_field: RealizedField):
        raise NotImplementedError

    def parse_from(self, from_clause: dict, from_field: RealizedField):
        raise NotImplementedError

    def parse_where(self, where_clause: list, where_field: RealizedField):
        where_field.add_value(self.parse_conds(where_clause))

    def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
        raise NotImplementedError

    def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
        raise NotImplementedError

    def parse_conds(self, conds: list):
        assert len(conds) > 0
        and_or = (len(conds) - 1) // 2
        root_node, left_field, right_field = [None] * 3
        for i in reversed(range(and_or)):
            and_or_idx = 2 * i + 1
            conj = conds[and_or_idx]
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(conj.title()))
            if root_node is None:
                root_node = ast_node
            if left_field is not None:
                left_field.add_value(ast_node)
            left_field, right_field = ast_node.fields
            right_field.add_value(self.parse_cond(conds[2 * (i + 1)]))
        if left_field is None:
            root_node = self.parse_cond(conds[0])
        else:
            left_field.add_value(self.parse_cond(conds[0]))
        return root_node
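
    # For example, conds = [c0, 'and', c1, 'or', c2] is parsed into
    # Or(And(c0, c1), c2): the flat condition list becomes a left-deep
    # tree that preserves the original left-to-right evaluation order.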

    def parse_cond(self, cond: list):
        not_op, cmp_op, val_unit, val1, val2 = cond
        not_op = '^' if not_op else ''
        sql_val = 'sql' if type(val1) == dict else ''
        op_list = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
        cmp_op = not_op + op_list[cmp_op] + sql_val
        op_dict = {
            'between': 'Between', '=': 'Eq', '>': 'Gt', '<': 'Lt', '>=': 'Ge', '<=': 'Le', '!=': 'Neq',
            'insql': 'InSQL', 'like': 'Like', '^insql': 'NotInSQL', '^like': 'NotLike', 'betweensql': 'BetweenSQL', '=sql': 'EqSQL',
            '>sql': 'GtSQL', '<sql': 'LtSQL', '>=sql': 'GeSQL', '<=sql': 'LeSQL', '!=sql': 'NeqSQL'
        }
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(op_dict[cmp_op]))
        val_unit_field = ast_node.fields[0]
        val_unit_field.add_value(self.parse_val_unit(val_unit))
        if len(ast_node.fields) == 2:
            val_field = ast_node.fields[1]
            val_field.add_value(self.parse_sql(val1))
        return ast_node

    def parse_val_unit(self, val_unit: list):
        unit_op, col_unit1, col_unit2 = val_unit
        unit_op_list = ['Unary', 'Minus', 'Plus', 'Times', 'Divide']
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(unit_op_list[unit_op]))
        if unit_op == 0:
            ast_node.fields[0].add_value(self.parse_col_unit(col_unit1))
        else:
            # ast_node.fields[0].add_value(int(col_unit1[1]))
            # ast_node.fields[1].add_value(int(col_unit2[1]))
            ast_node.fields[0].add_value(self.parse_col_unit(col_unit1))
            ast_node.fields[1].add_value(self.parse_col_unit(col_unit2))
        return ast_node

    def parse_col_unit(self, col_unit: list):
        agg_op, col_id, distinct_flag = col_unit
        agg_op_list = ['None', 'Max', 'Min', 'Count', 'Sum', 'Avg']
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(agg_op_list[agg_op]))
        ast_node.fields[0].add_value(int(col_id))
        return ast_node
@@ -0,0 +1,69 @@
#coding=utf8
from asdl.sql.parser.parser_base import Parser
from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree

class ParserV0(Parser):
    """ In this version, we eliminate all cardinality ? and restrict that * must have at least one item
    """
    def parse_select(self, select_clause: list, select_field: RealizedField):
        """
        Ignore the cases agg(col_id1 op col_id2) and agg(col_id1) op agg(col_id2)
        """
        select_clause = select_clause[1]  # list of (agg, val_unit)
        unit_op_list = ['Unary', 'Minus', 'Plus', 'Times', 'Divide']
        agg_op_list = ['None', 'Max', 'Min', 'Count', 'Sum', 'Avg']
        for agg, val_unit in select_clause:
            if agg != 0:  # agg col_id
                ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary'))
                col_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(agg_op_list[agg]))
                col_node.fields[0].add_value(int(val_unit[1][1]))
                ast_node.fields[0].add_value(col_node)
            else:  # binary_op col_id1 col_id2
                ast_node = self.parse_val_unit(val_unit)
            select_field.add_value(ast_node)

    def parse_from(self, from_clause: dict, from_field: RealizedField):
        """
        Ignore from conditions, since they are not evaluated in the evaluation script
        """
        table_units = from_clause['table_units']
        t = table_units[0][0]
        if t == 'table_unit':
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromTable'))
            tables_field = ast_node.fields[0]
            for _, v in table_units:
                tables_field.add_value(int(v))
        else:
            assert t == 'sql'
            v = table_units[0][1]
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL'))
            ast_node.fields[0].add_value(self.parse_sql(v))
        from_field.add_value(ast_node)

    def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
        col_ids = []
        for col_unit in groupby_clause:
            col_ids.append(col_unit[1])  # agg is None and isDistinct False
        if having_clause:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Having'))
            col_units_field, having_fields = ast_node.fields
            having_fields.add_value(self.parse_conds(having_clause))
        else:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('NoHaving'))
            col_units_field = ast_node.fields[0]
        for col_unit in groupby_clause:
            col_units_field.add_value(self.parse_col_unit(col_unit))
        groupby_field.add_value(ast_node)

    def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
        if limit is None:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Asc')) if orderby_clause[0] == 'asc' \
                else AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Desc'))
        else:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('AscLimit')) if orderby_clause[0] == 'asc' \
                else AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('DescLimit'))
        col_units_field = ast_node.fields[0]
        for val_unit in orderby_clause[1]:
            col_units_field.add_value(self.parse_col_unit(val_unit[1]))
        orderby_field.add_value(ast_node)
@ -0,0 +1,83 @@
|
|||
#coding=utf8
from asdl.sql.parser.parser_base import Parser
from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree


class ParserV1(Parser):
    """ In this version, we eliminate all cardinality * and use ? instead
    """
    def parse_sql_unit(self, sql: dict):
        """ Parse a single sql unit, determine the existence of different clauses
        """
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('SQL'))
        from_field, select_field, where_field, groupby_field, orderby_field = ast_node.fields
        self.parse_from(sql['from'], from_field)
        self.parse_select(sql['select'], select_field)
        if sql['where']:
            self.parse_where(sql['where'], where_field)
        if sql['groupBy']: # if having clause is not empty, groupBy must exist
            self.parse_groupby(sql['groupBy'], sql['having'], groupby_field)
        if sql['orderBy']: # if limit is not None, orderBy is not empty
            self.parse_orderby(sql['orderBy'], sql['limit'], orderby_field)
        return ast_node

    def parse_select(self, select_clause: list, select_field: RealizedField):
        select_clause = select_clause[1] # list of (agg, val_unit), ignore distinct flag
        select_num = min(5, len(select_clause))
        select_ctr = ['SelectOne', 'SelectTwo', 'SelectThree', 'SelectFour', 'SelectFive']
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(select_ctr[select_num - 1]))
        for i, (agg, val_unit) in enumerate(select_clause):
            if i >= 5: break
            if agg != 0: # MAX/MIN/COUNT/SUM/AVG
                val_unit_ast = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary'))
                col_unit = [agg] + val_unit[1][1:]
                val_unit_ast.fields[0].add_value(self.parse_col_unit(col_unit))
            else:
                val_unit_ast = self.parse_val_unit(val_unit)
            ast_node.fields[i].add_value(val_unit_ast)
        select_field.add_value(ast_node)

    def parse_from(self, from_clause: dict, from_field: RealizedField):
        """ Ignore from conditions, since they are not evaluated in the evaluation script
        """
        table_units = from_clause['table_units']
        t = table_units[0][0]
        if t == 'table_unit':
            table_num = min(6, len(table_units))
            table_ctr = ['FromOneTable', 'FromTwoTable', 'FromThreeTable', 'FromFourTable', 'FromFiveTable', 'FromSixTable']
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(table_ctr[table_num - 1]))
            for i, (_, tab_id) in enumerate(table_units):
                if i >= 6: break
                ast_node.fields[i].add_value(int(tab_id))
        else:
            assert t == 'sql'
            v = table_units[0][1]
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL'))
            ast_node.fields[0].add_value(self.parse_sql(v))
        from_field.add_value(ast_node)

    def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
        groupby_ctr = ['GroupByOne', 'GroupByTwo']
        groupby_num = min(2, len(groupby_clause)) - 1
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(groupby_ctr[groupby_num]))
        if having_clause:
            having_field = ast_node.fields[-1]
            having_field.add_value(self.parse_conds(having_clause))
        for i, col_unit in enumerate(groupby_clause):
            if i >= 2: break
            # ast_node.fields[i].add_value(int(col_unit[1]))
            ast_node.fields[i].add_value(self.parse_col_unit(col_unit))
        groupby_field.add_value(ast_node)

    def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
        orderby_num = min(2, len(orderby_clause[1]))
        num_str = 'One' if orderby_num == 1 else 'Two'
        order_str = 'Asc' if orderby_clause[0] == 'asc' else 'Desc'
        limit_str = 'Limit' if limit else '' # e.g. OneAsc, TwoDescLimit
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(num_str + order_str + limit_str))
        for i, val_unit in enumerate(orderby_clause[1]):
            if i >= 2: break
            col_unit = val_unit[1]
            ast_node.fields[i].add_value(self.parse_col_unit(col_unit))
            # ast_node.fields[i].add_value(self.parse_val_unit(val_unit))
        orderby_field.add_value(ast_node)
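
# Worked example of the constructor-name composition in ParserV1.parse_orderby
# (comment added for illustration; the clause values are hypothetical): for
# orderby_clause = ('desc', [val_unit1, val_unit2]) with limit = 1, we get
# orderby_num = 2, num_str = 'Two', order_str = 'Desc', limit_str = 'Limit',
# so the production looked up is 'TwoDescLimit'.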
@@ -0,0 +1,71 @@
#coding=utf8
from asdl.sql.parser.parser_base import Parser
from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree


class ParserV2(Parser):
    """ In this version, we remove all cardinality ? or *
    by enumerating all the different lengths of item lists, such as SelectOne, SelectTwo
    """
    def parse_select(self, select_clause: list, select_field: RealizedField):
        select_clause = select_clause[1] # list of (agg, val_unit), ignore distinct flag
        select_num = min(5, len(select_clause))
        select_ctr = ['SelectOne', 'SelectTwo', 'SelectThree', 'SelectFour', 'SelectFive']
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(select_ctr[select_num - 1]))
        for i, (agg, val_unit) in enumerate(select_clause):
            if i >= 5: break
            if agg != 0: # MAX/MIN/COUNT/SUM/AVG
                val_unit_ast = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('Unary'))
                col_unit = [agg] + val_unit[1][1:]
                val_unit_ast.fields[0].add_value(self.parse_col_unit(col_unit))
            else:
                val_unit_ast = self.parse_val_unit(val_unit)
            ast_node.fields[i].add_value(val_unit_ast)
        select_field.add_value(ast_node)

    def parse_from(self, from_clause: dict, from_field: RealizedField):
        """ Ignore from conditions, since they are not evaluated in the evaluation script
        """
        table_units = from_clause['table_units']
        t = table_units[0][0]
        if t == 'table_unit':
            table_num = min(6, len(table_units))
            table_ctr = ['FromOneTable', 'FromTwoTable', 'FromThreeTable', 'FromFourTable', 'FromFiveTable', 'FromSixTable']
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(table_ctr[table_num - 1]))
            for i, (_, tab_id) in enumerate(table_units):
                if i >= 6: break
                ast_node.fields[i].add_value(int(tab_id))
        else:
            assert t == 'sql'
            v = table_units[0][1]
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name('FromSQL'))
            ast_node.fields[0].add_value(self.parse_sql(v))
        from_field.add_value(ast_node)

    def parse_groupby(self, groupby_clause: list, having_clause: list, groupby_field: RealizedField):
        groupby_ctr = ['OneNoHaving', 'TwoNoHaving', 'OneHaving', 'TwoHaving']
        groupby_num = min(2, len(groupby_clause))
        if having_clause:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(groupby_ctr[groupby_num + 1]))
            having_field = ast_node.fields[-1]
            having_field.add_value(self.parse_conds(having_clause))
        else:
            ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(groupby_ctr[groupby_num - 1]))
        for i, col_unit in enumerate(groupby_clause):
            if i >= 2: break
            # ast_node.fields[i].add_value(int(col_unit[1]))
            ast_node.fields[i].add_value(self.parse_col_unit(col_unit))
        groupby_field.add_value(ast_node)
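
    # Worked example of the index arithmetic above (comment added for clarity):
    # with one groupBy column and a HAVING clause, groupby_num = 1 and
    # groupby_ctr[1 + 1] = 'OneHaving'; without HAVING, groupby_ctr[1 - 1]
    # = 'OneNoHaving'. Two or more columns map to the 'Two*' constructors.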

    def parse_orderby(self, orderby_clause: list, limit: int, orderby_field: RealizedField):
        orderby_num = min(2, len(orderby_clause[1]))
        num_str = 'One' if orderby_num == 1 else 'Two'
        order_str = 'Asc' if orderby_clause[0] == 'asc' else 'Desc'
        limit_str = 'Limit' if limit else '' # e.g. OneAsc, TwoDescLimit
        ast_node = AbstractSyntaxTree(self.grammar.get_prod_by_ctr_name(num_str + order_str + limit_str))
        for i, val_unit in enumerate(orderby_clause[1]):
            if i >= 2: break
            col_unit = val_unit[1]
            ast_node.fields[i].add_value(self.parse_col_unit(col_unit))
            # ast_node.fields[i].add_value(self.parse_val_unit(val_unit))
        orderby_field.add_value(ast_node)
@@ -0,0 +1,174 @@
# coding=utf-8
import json, os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from asdl.sql.parser.parser_base import Parser
from asdl.sql.unparser.unparser_base import UnParser
from asdl.asdl import ASDLGrammar
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree
from asdl.transition_system import GenTokenAction, TransitionSystem, ApplyRuleAction, ReduceAction


class SelectColumnAction(GenTokenAction):
    def __init__(self, column_id):
        super(SelectColumnAction, self).__init__(column_id)

    @property
    def column_id(self):
        return self.token

    def __repr__(self):
        return 'SelectColumnAction[id=%s]' % self.column_id


class SelectTableAction(GenTokenAction):
    def __init__(self, table_id):
        super(SelectTableAction, self).__init__(table_id)

    @property
    def table_id(self):
        return self.token

    def __repr__(self):
        return 'SelectTableAction[id=%s]' % self.table_id


class SQLTransitionSystem(TransitionSystem):

    def __init__(self, grammar):
        self.grammar = grammar
        self.parser = Parser.from_grammar(self.grammar)
        self.unparser = UnParser.from_grammar(self.grammar)

    def ast_to_surface_code(self, asdl_ast, table, *args, **kargs):
        return self.unparser.unparse(asdl_ast, table, *args, **kargs)

    def compare_ast(self, hyp_ast, ref_ast):
        raise NotImplementedError

    def tokenize_code(self, code, mode):
        raise NotImplementedError

    def surface_code_to_ast(self, code):
        return self.parser.parse(code)

    def get_valid_continuation_types(self, hyp):
        if hyp.tree:
            if self.grammar.is_composite_type(hyp.frontier_field.type):
                if hyp.frontier_field.cardinality == 'single':
                    return ApplyRuleAction,
                elif hyp.frontier_field.cardinality == 'multiple':
                    if len(hyp.frontier_field.value) == 0:
                        return ApplyRuleAction,
                    else:
                        return ApplyRuleAction, ReduceAction
                else:
                    return ApplyRuleAction, ReduceAction
            elif hyp.frontier_field.type.name == 'col_id':
                if hyp.frontier_field.cardinality == 'single':
                    return SelectColumnAction,
                elif hyp.frontier_field.cardinality == 'multiple':
                    if len(hyp.frontier_field.value) == 0:
                        return SelectColumnAction,
                    else:
                        return SelectColumnAction, ReduceAction
                else: # optional, not used
                    return SelectColumnAction, ReduceAction
            elif hyp.frontier_field.type.name == 'tab_id':
                if hyp.frontier_field.cardinality == 'single':
                    return SelectTableAction,
                elif hyp.frontier_field.cardinality == 'multiple':
                    if len(hyp.frontier_field.value) == 0:
                        return SelectTableAction,
                    else:
                        return SelectTableAction, ReduceAction
                else: # optional, not used
                    return SelectTableAction, ReduceAction
            else: # not used now
                return GenTokenAction,
        else:
            return ApplyRuleAction,
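
    # Illustrative continuation check (comment added; not original code): when
    # the frontier field has type 'col_id' with cardinality 'multiple' and
    # already holds at least one value, the decoder may either emit another
    # SelectColumnAction or close the field with a ReduceAction.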

    def get_primitive_field_actions(self, realized_field):
        if realized_field.type.name == 'col_id':
            if realized_field.cardinality == 'multiple':
                action_list = []
                for idx in realized_field.value:
                    action_list.append(SelectColumnAction(int(idx)))
                return action_list
            elif realized_field.value is not None:
                return [SelectColumnAction(int(realized_field.value))]
            else:
                return []
        elif realized_field.type.name == 'tab_id':
            if realized_field.cardinality == 'multiple':
                action_list = []
                for idx in realized_field.value:
                    action_list.append(SelectTableAction(int(idx)))
                return action_list
            elif realized_field.value is not None:
                return [SelectTableAction(int(realized_field.value))]
            else:
                return []
        else:
            raise ValueError('unknown primitive field type')


if __name__ == '__main__':

    try:
        from evaluation import evaluate, build_foreign_key_map_from_json
    except Exception:
        print('Cannot find evaluator ...')
    grammar = ASDLGrammar.from_filepath('asdl/sql/grammar/sql_asdl_v2.txt')
    print('Total number of productions:', len(grammar))
    for each in grammar.productions:
        print(each)
    print('Total number of types:', len(grammar.types))
    for each in grammar.types:
        print(each)
    print('Total number of fields:', len(grammar.fields))
    for each in grammar.fields:
        print(each)

    spider_trans = SQLTransitionSystem(grammar)
    kmaps = build_foreign_key_map_from_json('data/tables.json')
    dbs_list = json.load(open('data/tables.json', 'r'))
    dbs = {}
    for each in dbs_list:
        dbs[each['db_id']] = each

    train = json.load(open('data/train.json', 'r'))
    train_db = [ex['db_id'] for ex in train]
    train = [ex['sql'] for ex in train]
    dev = json.load(open('data/dev.json', 'r'))
    dev_db = [ex['db_id'] for ex in dev]
    dev = [ex['sql'] for ex in dev]

    recovered_sqls = []
    for idx in range(len(train)):
        sql_ast = spider_trans.surface_code_to_ast(train[idx])
        sql_ast.sanity_check()
        # print(spider_trans.get_actions(sql_ast))
        recovered_sql = spider_trans.ast_to_surface_code(sql_ast, dbs[train_db[idx]])
        # print(recovered_sql)
        recovered_sqls.append(recovered_sql)

    with open('data/train_pred.sql', 'w') as of:
        for each in recovered_sqls:
            of.write(each + '\n')
    with open('data/eval_train.log', 'w') as of:
        old_print = sys.stdout
        sys.stdout = of
        evaluate('data/train_gold.sql', 'data/train_pred.sql', 'data/database', 'match', kmaps)
        sys.stdout = old_print

    recovered_sqls = []
    for idx in range(len(dev)):
        sql_ast = spider_trans.surface_code_to_ast(dev[idx])
        sql_ast.sanity_check()
        recovered_sql = spider_trans.ast_to_surface_code(sql_ast, dbs[dev_db[idx]])
        recovered_sqls.append(recovered_sql)
    with open('data/dev_pred.sql', 'w') as of:
        for each in recovered_sqls:
            of.write(each + '\n')
    with open('data/eval_dev.log', 'w') as of:
        old_print = sys.stdout
        sys.stdout = of
        evaluate('data/dev_gold.sql', 'data/dev_pred.sql', 'data/database', 'match', kmaps)
        sys.stdout = old_print
@@ -0,0 +1,165 @@
#coding=utf8
from asdl.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree


class UnParser():

    def __init__(self, grammar: ASDLGrammar):
        """ ASDLGrammar """
        super(UnParser, self).__init__()
        self.grammar = grammar

    @classmethod
    def from_grammar(cls, grammar: ASDLGrammar):
        grammar_name = grammar._grammar_name
        if 'v0' in grammar_name:
            from asdl.sql.unparser.unparser_v0 import UnParserV0
            return UnParserV0(grammar)
        elif 'v1' in grammar_name:
            from asdl.sql.unparser.unparser_v1 import UnParserV1
            return UnParserV1(grammar)
        elif 'v2' in grammar_name:
            from asdl.sql.unparser.unparser_v2 import UnParserV2
            return UnParserV2(grammar)
        else:
            raise ValueError('Unrecognized grammar name %s' % (grammar_name))

    def unparse(self, sql_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
        try:
            sql = self.unparse_sql(sql_ast, db, *args, **kargs)
            sql = ' '.join([i for i in sql.split(' ') if i != ''])
            return sql
        except Exception as e:
            print('Something went wrong while unparsing:', e)
            return 'SELECT * FROM %s' % (db['table_names_original'][0])

    def unparse_sql(self, sql_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
        prod_name = sql_ast.production.constructor.name
        if prod_name == 'Intersect':
            return '%s INTERSECT %s' % (self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs), self.unparse_sql_unit(sql_ast.fields[1], db, *args, **kargs))
        elif prod_name == 'Union':
            return '%s UNION %s' % (self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs), self.unparse_sql_unit(sql_ast.fields[1], db, *args, **kargs))
        elif prod_name == 'Except':
            return '%s EXCEPT %s' % (self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs), self.unparse_sql_unit(sql_ast.fields[1], db, *args, **kargs))
        else:
            return self.unparse_sql_unit(sql_ast.fields[0], db, *args, **kargs)

    def unparse_sql_unit(self, sql_field: RealizedField, db: dict, *args, **kargs):
        sql_ast = sql_field.value
        prod_name = sql_ast.production.constructor.name
        from_str = self.unparse_from(sql_ast.fields[0], db, *args, **kargs)
        select_str = self.unparse_select(sql_ast.fields[1], db, *args, **kargs)
        if prod_name == 'Complete':
            return 'SELECT %s FROM %s WHERE %s GROUP BY %s ORDER BY %s' % (
                select_str, from_str,
                self.unparse_where(sql_ast.fields[2], db, *args, **kargs),
                self.unparse_groupby(sql_ast.fields[3], db, *args, **kargs),
                self.unparse_orderby(sql_ast.fields[4], db, *args, **kargs)
            )
        elif prod_name == 'NoWhere':
            return 'SELECT %s FROM %s GROUP BY %s ORDER BY %s' % (
                select_str, from_str,
                self.unparse_groupby(sql_ast.fields[2], db, *args, **kargs),
                self.unparse_orderby(sql_ast.fields[3], db, *args, **kargs),
            )
        elif prod_name == 'NoGroupBy':
            return 'SELECT %s FROM %s WHERE %s ORDER BY %s' % (
                select_str, from_str,
                self.unparse_where(sql_ast.fields[2], db, *args, **kargs),
                self.unparse_orderby(sql_ast.fields[3], db, *args, **kargs),
            )
        elif prod_name == 'NoOrderBy':
            return 'SELECT %s FROM %s WHERE %s GROUP BY %s' % (
                select_str, from_str,
                self.unparse_where(sql_ast.fields[2], db, *args, **kargs),
                self.unparse_groupby(sql_ast.fields[3], db, *args, **kargs),
            )
        elif prod_name == 'OnlyWhere':
            return 'SELECT %s FROM %s WHERE %s' % (
                select_str, from_str,
                self.unparse_where(sql_ast.fields[2], db, *args, **kargs)
            )
        elif prod_name == 'OnlyGroupBy':
            return 'SELECT %s FROM %s GROUP BY %s' % (
                select_str, from_str,
                self.unparse_groupby(sql_ast.fields[2], db, *args, **kargs)
            )
        elif prod_name == 'OnlyOrderBy':
            return 'SELECT %s FROM %s ORDER BY %s' % (
                select_str, from_str,
                self.unparse_orderby(sql_ast.fields[2], db, *args, **kargs)
            )
        else:
            return 'SELECT %s FROM %s' % (select_str, from_str)

    def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs):
        raise NotImplementedError

    def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs):
        raise NotImplementedError

    def unparse_where(self, where_field: RealizedField, db: dict, *args, **kargs):
        return self.unparse_conds(where_field.value, db, *args, **kargs)

    def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs):
        raise NotImplementedError

    def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs):
        raise NotImplementedError

    def unparse_conds(self, conds_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
        ctr_name = conds_ast.production.constructor.name
        if ctr_name in ['And', 'Or']:
            left_cond, right_cond = conds_ast.fields
            return self.unparse_conds(left_cond.value, db, *args, **kargs) + ' ' + ctr_name.upper() + ' ' + \
                self.unparse_conds(right_cond.value, db, *args, **kargs)
        else:
            return self.unparse_cond(conds_ast, db, *args, **kargs)

    def unparse_cond(self, cond_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
        ctr_name = cond_ast.production.constructor.name
        val_unit_str = self.unparse_val_unit(cond_ast.fields[0].value, db, *args, **kargs)
        val_str = '( ' + self.unparse_sql(cond_ast.fields[1].value, db, *args, **kargs) + ' )' if len(cond_ast.fields) == 2 else '"value"'
        if ctr_name.startswith('Between'):
            return val_unit_str + ' BETWEEN ' + val_str + ' AND "value"'
        else:
            op_dict = {
                'Between': ' BETWEEN ', 'Eq': ' = ', 'Gt': ' > ', 'Lt': ' < ', 'Ge': ' >= ', 'Le': ' <= ', 'Neq': ' != ',
                'In': ' IN ', 'Like': ' LIKE ', 'NotIn': ' NOT IN ', 'NotLike': ' NOT LIKE '
            }
            ctr_name = ctr_name if 'SQL' not in ctr_name else ctr_name[:ctr_name.index('SQL')]
            op = op_dict[ctr_name]
            return op.join([val_unit_str, val_str])

    def unparse_val_unit(self, val_unit_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
        unit_op = val_unit_ast.production.constructor.name
        if unit_op == 'Unary':
            return self.unparse_col_unit(val_unit_ast.fields[0].value, db, *args, **kargs)
        else:
            binary = {'Minus': ' - ', 'Plus': ' + ', 'Times': ' * ', 'Divide': ' / '}
            op = binary[unit_op]
            return op.join([self.unparse_col_unit(val_unit_ast.fields[0].value, db, *args, **kargs),
                            self.unparse_col_unit(val_unit_ast.fields[1].value, db, *args, **kargs)])
            # col_id1, col_id2 = int(val_unit_ast.fields[0].value), int(val_unit_ast.fields[1].value)
            # tab_id1, col_name1 = db['column_names_original'][col_id1]
            # if col_id1 != 0:
            #     tab_name1 = db['table_names_original'][tab_id1]
            #     col_name1 = tab_name1 + '.' + col_name1
            # tab_id2, col_name2 = db['column_names_original'][col_id2]
            # if col_id2 != 0:
            #     tab_name2 = db['table_names_original'][tab_id2]
            #     col_name2 = tab_name2 + '.' + col_name2
            # return op.join([col_name1, col_name2])

    def unparse_col_unit(self, col_unit_ast: AbstractSyntaxTree, db: dict, *args, **kargs):
        agg = col_unit_ast.production.constructor.name
        col_id = int(col_unit_ast.fields[0].value)
        tab_id, col_name = db['column_names_original'][col_id]
        if col_id != 0:
            tab_name = db['table_names_original'][tab_id]
            col_name = tab_name + '.' + col_name
        if agg == 'None':
            return col_name
        else: # Max/Min/Count/Sum/Avg
            return agg.upper() + '(' + col_name + ')'
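
# Example of the dispatch above (comment added for illustration; exact
# constructor names come from the grammar files): a sql unit rooted at
# 'OnlyWhere' unparses to 'SELECT <select> FROM <from> WHERE <conds>', and in
# unparse_cond a nested constructor such as 'GtSQL' is truncated to 'Gt' before
# the op_dict lookup, yielding '<val_unit> > ( <nested sql> )'.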
@@ -0,0 +1,59 @@
#coding=utf8
from asdl.sql.unparser.unparser_base import UnParser
from asdl.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree


class UnParserV0(UnParser):

    def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs):
        select_list = select_field.value
        select_items = []
        for val_unit_ast in select_list:
            val_unit_str = self.unparse_val_unit(val_unit_ast, db, *args, **kargs)
            select_items.append(val_unit_str)
        return ' , '.join(select_items)

    def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs):
        from_ast = from_field.value
        ctr_name = from_ast.production.constructor.name
        if ctr_name == 'FromTable':
            tab_ids = from_ast.fields[0].value
            if len(tab_ids) == 1:
                return db['table_names_original'][tab_ids[0]]
            else:
                tab_names = [db['table_names_original'][i] for i in tab_ids]
                return ' JOIN '.join(tab_names)
        else:
            sql_ast = from_ast.fields[0].value
            return '( ' + self.unparse_sql(sql_ast, db, *args, **kargs) + ' )'

    def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs):
        groupby_ast = groupby_field.value
        ctr_name = groupby_ast.production.constructor.name
        groupby_str = []
        for col_unit_ast in groupby_ast.fields[0].value:
            groupby_str.append(self.unparse_col_unit(col_unit_ast, db, *args, **kargs))
        groupby_str = ' , '.join(groupby_str)
        if ctr_name == 'Having':
            having = groupby_ast.fields[1].value
            having_str = self.unparse_conds(having, db, *args, **kargs)
            return groupby_str + ' HAVING ' + having_str
        else:
            return groupby_str

    def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs):
        orderby_ast = orderby_field.value
        ctr_name = orderby_ast.production.constructor.name.lower()
        val_unit_str = []
        for val_unit_ast in orderby_ast.fields[0].value:
            val_unit_str.append(self.unparse_col_unit(val_unit_ast, db, *args, **kargs))
            # val_unit_str.append(self.unparse_val_unit(val_unit_ast, db, *args, **kargs))
        val_unit_str = ' , '.join(val_unit_str)
        if 'asc' in ctr_name and 'limit' in ctr_name:
            return '%s ASC LIMIT 1' % (val_unit_str)
        elif 'asc' in ctr_name:
            return '%s ASC' % (val_unit_str)
        elif 'desc' in ctr_name and 'limit' in ctr_name:
            return '%s DESC LIMIT 1' % (val_unit_str)
        else:
            return '%s DESC' % (val_unit_str)
@@ -0,0 +1,80 @@
#coding=utf8
from asdl.sql.unparser.unparser_base import UnParser
from asdl.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree


class UnParserV1(UnParser):

    def unparse_sql_unit(self, sql_field: RealizedField, db: dict, *args, **kargs):
        sql_ast = sql_field.value
        from_field, select_field, where_field, groupby_field, orderby_field = sql_ast.fields
        from_str = 'FROM ' + self.unparse_from(from_field, db, *args, **kargs)
        select_str = 'SELECT ' + self.unparse_select(select_field, db, *args, **kargs)
        where_str, groupby_str, orderby_str = '', '', ''
        if where_field.value is not None:
            where_str = 'WHERE ' + self.unparse_where(where_field, db, *args, **kargs)
        if groupby_field.value is not None:
            groupby_str = 'GROUP BY ' + self.unparse_groupby(groupby_field, db, *args, **kargs)
        if orderby_field.value is not None:
            orderby_str = 'ORDER BY ' + self.unparse_orderby(orderby_field, db, *args, **kargs)
        return ' '.join([select_str, from_str, where_str, groupby_str, orderby_str])

    def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs):
        select_ast = select_field.value
        select_list = select_ast.fields
        select_items = []
        for val_unit_field in select_list:
            val_unit_str = self.unparse_val_unit(val_unit_field.value, db, *args, **kargs)
            select_items.append(val_unit_str)
        return ' , '.join(select_items)

    def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs):
        from_ast = from_field.value
        ctr_name = from_ast.production.constructor.name
        if 'Table' in ctr_name:
            tab_names = []
            for tab_field in from_ast.fields:
                tab_name = db['table_names_original'][int(tab_field.value)]
                tab_names.append(tab_name)
            return ' JOIN '.join(tab_names)
        else:
            return '( ' + self.unparse_sql(from_ast.fields[0].value, db, *args, **kargs) + ' )'

    def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs):
        groupby_ast = groupby_field.value
        ctr_name = groupby_ast.production.constructor.name
        groupby_str = []
        num = len(groupby_ast.fields) - 1
        for col_id_field in groupby_ast.fields[:num]:
            # col_id = int(col_id_field.value)
            # tab_id, col_name = db['column_names_original'][col_id]
            # if col_id != 0:
            #     tab_name = db['table_names_original'][tab_id]
            #     col_name = tab_name + '.' + col_name
            col_name = self.unparse_col_unit(col_id_field.value, db, *args, **kargs)
            groupby_str.append(col_name)
        groupby_str = ' , '.join(groupby_str)
        if groupby_ast.fields[-1].value is not None:
            having = groupby_ast.fields[-1].value
            having_str = self.unparse_conds(having, db, *args, **kargs)
            return groupby_str + ' HAVING ' + having_str
        else:
            return groupby_str

    def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs):
        orderby_ast = orderby_field.value
        ctr_name = orderby_ast.production.constructor.name.lower()
        val_unit_str = []
        for val_unit_field in orderby_ast.fields:
            val_unit_ast = val_unit_field.value
            val_unit_str.append(self.unparse_col_unit(val_unit_ast, db, *args, **kargs))
            # val_unit_str.append(self.unparse_val_unit(val_unit_ast, db, *args, **kargs))
        val_unit_str = ' , '.join(val_unit_str)
        if 'asc' in ctr_name and 'limit' in ctr_name:
            return '%s ASC LIMIT 1' % (val_unit_str)
        elif 'asc' in ctr_name:
            return '%s ASC' % (val_unit_str)
        elif 'desc' in ctr_name and 'limit' in ctr_name:
            return '%s DESC LIMIT 1' % (val_unit_str)
        else:
            return '%s DESC' % (val_unit_str)
@@ -0,0 +1,66 @@
#coding=utf8
from asdl.sql.unparser.unparser_base import UnParser
from asdl.asdl import ASDLGrammar, ASDLConstructor, ASDLProduction
from asdl.asdl_ast import RealizedField, AbstractSyntaxTree


class UnParserV2(UnParser):

    def unparse_select(self, select_field: RealizedField, db: dict, *args, **kargs):
        select_ast = select_field.value
        select_list = select_ast.fields
        select_items = []
        for val_unit_field in select_list:
            val_unit_str = self.unparse_val_unit(val_unit_field.value, db, *args, **kargs)
            select_items.append(val_unit_str)
        return ' , '.join(select_items)

    def unparse_from(self, from_field: RealizedField, db: dict, *args, **kargs):
        from_ast = from_field.value
        ctr_name = from_ast.production.constructor.name
        if 'Table' in ctr_name:
            tab_names = []
            for tab_field in from_ast.fields:
                tab_name = db['table_names_original'][int(tab_field.value)]
                tab_names.append(tab_name)
            return ' JOIN '.join(tab_names)
        else:
            return '( ' + self.unparse_sql(from_ast.fields[0].value, db, *args, **kargs) + ' )'

    def unparse_groupby(self, groupby_field: RealizedField, db: dict, *args, **kargs):
        groupby_ast = groupby_field.value
        ctr_name = groupby_ast.production.constructor.name
        groupby_str = []
        num = len(groupby_ast.fields) if 'NoHaving' in ctr_name else len(groupby_ast.fields) - 1
        for col_id_field in groupby_ast.fields[:num]:
            # col_id = int(col_id_field.value)
            # tab_id, col_name = db['column_names_original'][col_id]
            # if col_id != 0:
            #     tab_name = db['table_names_original'][tab_id]
            #     col_name = tab_name + '.' + col_name
            col_name = self.unparse_col_unit(col_id_field.value, db, *args, **kargs)
            groupby_str.append(col_name)
        groupby_str = ' , '.join(groupby_str)
        if 'NoHaving' in ctr_name:
            return groupby_str
        else:
            having = groupby_ast.fields[-1].value
            having_str = self.unparse_conds(having, db, *args, **kargs)
            return groupby_str + ' HAVING ' + having_str

    def unparse_orderby(self, orderby_field: RealizedField, db: dict, *args, **kargs):
        orderby_ast = orderby_field.value
        ctr_name = orderby_ast.production.constructor.name.lower()
        val_unit_str = []
        for val_unit_field in orderby_ast.fields:
            val_unit_ast = val_unit_field.value
            val_unit_str.append(self.unparse_col_unit(val_unit_ast, db, *args, **kargs))
            # val_unit_str.append(self.unparse_val_unit(val_unit_ast, db, *args, **kargs))
        val_unit_str = ' , '.join(val_unit_str)
        if 'asc' in ctr_name and 'limit' in ctr_name:
            return '%s ASC LIMIT 1' % (val_unit_str)
        elif 'asc' in ctr_name:
            return '%s ASC' % (val_unit_str)
        elif 'desc' in ctr_name and 'limit' in ctr_name:
            return '%s DESC LIMIT 1' % (val_unit_str)
        else:
            return '%s DESC' % (val_unit_str)
@@ -0,0 +1,136 @@
# coding=utf-8


class Action(object):
    pass


class ApplyRuleAction(Action):
    def __init__(self, production):
        self.production = production

    def __hash__(self):
        return hash(self.production)

    def __eq__(self, other):
        return isinstance(other, ApplyRuleAction) and self.production == other.production

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return 'ApplyRule[%s]' % self.production.__repr__()


class GenTokenAction(Action):
    def __init__(self, token):
        self.token = token

    def is_stop_signal(self):
        return self.token == '</primitive>'

    def __repr__(self):
        return 'GenToken[%s]' % self.token


class ReduceAction(Action):
    def __repr__(self):
        return 'Reduce'


class TransitionSystem(object):
    def __init__(self, grammar):
        self.grammar = grammar

    def get_actions(self, asdl_ast):
        """
        Generate the action sequence given the ASDL syntax tree
        """

        actions = []

        parent_action = ApplyRuleAction(asdl_ast.production)
        actions.append(parent_action)

        for field in asdl_ast.fields:
            # is a composite field
            if self.grammar.is_composite_type(field.type):
                if field.cardinality == 'single':
                    field_actions = self.get_actions(field.value)
                else:
                    field_actions = []

                    if field.value is not None:
                        if field.cardinality == 'multiple':
                            for val in field.value:
                                cur_child_actions = self.get_actions(val)
                                field_actions.extend(cur_child_actions)
                        elif field.cardinality == 'optional':
                            field_actions = self.get_actions(field.value)

                    # if an optional field is filled, then we do not need a Reduce action
                    if field.cardinality == 'multiple' or (field.cardinality == 'optional' and not field_actions):
                        field_actions.append(ReduceAction())
            else: # is a primitive field
                field_actions = self.get_primitive_field_actions(field)

                # if an optional field is filled, then we do not need a Reduce action
                if field.cardinality == 'multiple' or (field.cardinality == 'optional' and not field_actions):
                    # reduce action
                    field_actions.append(ReduceAction())

            actions.extend(field_actions)

        return actions

    def tokenize_code(self, code, mode):
        raise NotImplementedError

    def compare_ast(self, hyp_ast, ref_ast):
        raise NotImplementedError

    def ast_to_surface_code(self, asdl_ast):
        raise NotImplementedError

    def surface_code_to_ast(self, code):
        raise NotImplementedError

    def get_primitive_field_actions(self, realized_field):
        raise NotImplementedError

    def get_valid_continuation_types(self, hyp):
        if hyp.tree:
            if self.grammar.is_composite_type(hyp.frontier_field.type):
                if hyp.frontier_field.cardinality == 'single':
                    return ApplyRuleAction,
                else: # optional, multiple
                    return ApplyRuleAction, ReduceAction
            else:
                if hyp.frontier_field.cardinality == 'single':
                    return GenTokenAction,
                elif hyp.frontier_field.cardinality == 'optional':
                    if hyp._value_buffer:
                        return GenTokenAction,
                    else:
                        return GenTokenAction, ReduceAction
                else:
                    return GenTokenAction, ReduceAction
        else:
            return ApplyRuleAction,

    def get_valid_continuating_productions(self, hyp):
        if hyp.tree:
            if self.grammar.is_composite_type(hyp.frontier_field.type):
                return self.grammar[hyp.frontier_field.type]
            else:
                raise ValueError
        else:
            return self.grammar[self.grammar.root_type]

    @staticmethod
    def get_class_by_lang(lang):
        if lang == 'sql':
            from asdl.sql.sql_transition_system import SQLTransitionSystem
        else:
            raise ValueError('unknown language %s' % lang)
        return SQLTransitionSystem
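
# A sketch of what get_actions produces (comment added for illustration; the
# production names here are hypothetical, not taken from a real grammar): for a
# tree Root(expr=BinOp(left=Name, right=Name)) whose fields all have
# cardinality 'single', the sequence is
#   [ApplyRule[Root], ApplyRule[BinOp], ApplyRule[Name], ApplyRule[Name]]
# with a trailing Reduce appended for every 'multiple' field and for every
# empty 'optional' field encountered along the way.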
@@ -0,0 +1,71 @@
#coding=utf8
import sys, os, json, pickle, argparse, time, torch
from argparse import Namespace
# These three lines are used in the docker environment rhythmcao/text2sql:v2.0
# os.environ['NLTK_DATA'] = os.path.join(os.path.sep, 'root', 'nltk_data')
# os.environ["STANZA_RESOURCES_DIR"] = os.path.join(os.path.sep, 'root', 'stanza_resources')
# os.environ['EMBEDDINGS_ROOT'] = os.path.join(os.path.sep, 'root', '.embeddings')
from preprocess.process_dataset import process_tables, process_dataset
from preprocess.process_graphs import process_dataset_graph
from preprocess.common_utils import Preprocessor
from preprocess.graph_utils import GraphProcessor
from utils.example import Example
from utils.batch import Batch
from model.model_utils import Registrable
from model.model_constructor import *

def preprocess_database_and_dataset(db_dir='database/', table_path='data/tables.json', dataset_path='data/dev.json', method='lgesql'):
    tables = json.load(open(table_path, 'r'))
    dataset = json.load(open(dataset_path, 'r'))
    processor = Preprocessor(db_dir=db_dir, db_content=True)
    output_tables = process_tables(processor, tables)
    output_dataset = process_dataset(processor, dataset, output_tables)
    graph_processor = GraphProcessor()
    output_dataset = process_dataset_graph(graph_processor, output_dataset, output_tables, method=method)
    return output_dataset, output_tables

def load_examples(dataset, tables):
    ex_list = []
    for ex in dataset:
        ex_list.append(Example(ex, tables[ex['db_id']]))
    return ex_list

parser = argparse.ArgumentParser()
parser.add_argument('--db_dir', default='database', help='path to db dir')
parser.add_argument('--table_path', default='data/tables.json', help='path to tables json file')
parser.add_argument('--dataset_path', default='data/dev.json', help='path to raw dataset json file')
parser.add_argument('--saved_model', default='saved_models/glove42B', help='path to saved model dir, which must contain at least params.json and model.bin')
parser.add_argument('--output_path', default='predicted_sql.txt', help='output file for the predicted sql')
parser.add_argument('--batch_size', default=20, type=int, help='batch size for evaluation')
parser.add_argument('--beam_size', default=5, type=int, help='beam search size')
parser.add_argument('--use_gpu', action='store_true', help='whether to use gpu')
args = parser.parse_args(sys.argv[1:])

params = json.load(open(os.path.join(args.saved_model, 'params.json'), 'r'), object_hook=lambda d: Namespace(**d))
params.lazy_load = True # load PLM from AutoConfig instead of AutoModel.from_pretrained(...)
dataset, tables = preprocess_database_and_dataset(db_dir=args.db_dir, table_path=args.table_path, dataset_path=args.dataset_path, method=params.model)
Example.configuration(plm=params.plm, method=params.model, tables=tables, table_path=args.table_path, db_dir=args.db_dir)
dataset = load_examples(dataset, tables)

device = torch.device("cuda:0") if torch.cuda.is_available() and args.use_gpu else torch.device("cpu")
model = Registrable.by_name('text2sql')(params, Example.trans).to(device)
check_point = torch.load(open(os.path.join(args.saved_model, 'model.bin'), 'rb'), map_location=device)
model.load_state_dict(check_point['model'])

start_time = time.time()
print('Start evaluating ...')
model.eval()
all_hyps = []
with torch.no_grad():
    for i in range(0, len(dataset), args.batch_size):
        current_batch = Batch.from_example_list(dataset[i: i + args.batch_size], device, train=False)
        hyps = model.parse(current_batch, args.beam_size)
        all_hyps.extend(hyps)
with open(args.output_path, 'w', encoding='utf8') as of:
    evaluator = Example.evaluator
    for idx, hyp in enumerate(all_hyps):
        pred_sql = evaluator.obtain_sql(hyp, dataset[idx].db)
        # best_ast = hyp[0].tree # by default, the top beam prediction
        # pred_sql = Example.trans.ast_to_surface_code(best_ast, dataset[idx].db)
        of.write(pred_sql + '\n')
print('Evaluation costs %.4fs .' % (time.time() - start_time))
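
# Example invocation (comment added for reference; the script name and the
# saved-model path are illustrative, use whatever this file is named in the repo):
#   python eval_from_scratch.py --db_dir database --table_path data/tables.json \
#       --dataset_path data/dev.json --saved_model saved_models/electra \
#       --output_path predicted_sql.txt --batch_size 20 --beam_size 5 --use_gpu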
@@ -0,0 +1,877 @@
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
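
# Illustrative instance of the encoding above (comment added for clarity; the
# ids are hypothetical and depend on the database schema): the query
#   SELECT count(*) FROM singer WHERE age > 20
# is parsed into roughly
# {
#     'select': (False, [(3, (0, (0, 0, False), None))]),  # agg_id 3 = 'count', col_id 0 = '*'
#     'from': {'table_units': [('table_unit', 1)], 'conds': []},
#     'where': [(False, 3, (0, (0, 8, False), None), 20.0, None)],  # op_id 3 = '>'
#     'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#     'intersect': None, 'except': None, 'union': None
# }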

import os, sys
import json
import sqlite3
import traceback
import argparse
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True


CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')


HARDNESS = {
    "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
    "component2": ('except', 'union', 'intersect')
}


def condition_has_or(conds):
    return 'or' in conds[1::2]


def condition_has_like(conds):
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False


def val_has_op(val_unit):
    return val_unit[0] != UNIT_OPS.index('none')


def has_agg(unit):
    return unit[0] != AGG_OPS.index('none')


def accuracy(count, total):
    if count == total:
        return 1
    return 0


def recall(count, total):
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0
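
# Note on get_scores (examples added for illustration): scoring is
# all-or-nothing per component. get_scores(2, 2, 2) returns (1, 1, 1), while a
# partial overlap like get_scores(1, 2, 2) returns (0, 0, 0), as does any size
# mismatch between prediction and gold such as get_scores(1, 1, 2).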


def eval_sel(pred, label):
    pred_sel = pred['select'][1]
    label_sel = label['select'][1]
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
            if unit[1] in label_wo_agg:
                cnt_wo_agg += 1
                label_wo_agg.remove(unit[1])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_where(pred, label):
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]] # val1 is also considered
    label_wo_agg = [unit[2] for unit in label_conds] # val_unit
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
            if unit[2] in label_wo_agg:
                cnt_wo_agg += 1
                label_wo_agg.remove(unit[2])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
    label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred['groupBy']) > 0:
        pred_total = 1
    if len(label['groupBy']) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    if pred_total == label_total == 1 \
            and pred_cols == label_cols \
            and pred['having'] == label['having']:
        cnt = 1

    return label_total, pred_total, cnt


def eval_order(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred['orderBy']) > 0:
        pred_total = 1
    if len(label['orderBy']) > 0:
        label_total = 1
    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
            ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    pred_ao = pred['where'][1::2]
    label_ao = label['where'][1::2]
    pred_ao = set(pred_ao)
    label_ao = set(label_ao)

    if pred_ao == label_ao:
        return 1, 1, 1
    return len(pred_ao), len(label_ao), 0


def get_nestedSQL(sql):
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql['intersect'] is not None:
        nested.append(sql['intersect'])
    if sql['except'] is not None:
        nested.append(sql['except'])
    if sql['union'] is not None:
        nested.append(sql['union'])
    return nested


def eval_nested(pred, label):
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        cnt += Evaluator().eval_exact_match(pred, label)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        res.add(sql['orderBy'][0])
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    if sql['except'] is not None:
        res.add('except')
    if sql['union'] is not None:
        res.add('union')
    if sql['intersect'] is not None:
        res.add('intersect')

    # or keyword
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    if len([token for token in ao if token == 'or']) > 0:
        res.add('or')

    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add('not')

    # in keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
        res.add('in')

    # like keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
        res.add('like')

    return res


def eval_keywords(pred, label):
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    count = 0
    if len(sql['where']) > 0:
        count += 1
    if len(sql['groupBy']) > 0:
        count += 1
    if len(sql['orderBy']) > 0:
        count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0: # JOIN
        count += len(sql['from']['table_units']) - 1

    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += len([token for token in ao if token == 'or'])
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])

    return count


def count_component2(sql):
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    count = 0
    # number of aggregations
    agg_count = count_agg(sql['select'][1])
    agg_count += count_agg(sql['where'][::2])
    agg_count += count_agg(sql['groupBy'])
    if len(sql['orderBy']) > 0:
        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                               [unit[2] for unit in sql['orderBy'][1] if unit[2]])
    agg_count += count_agg(sql['having'])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql['select'][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql['where']) > 1:
        count += 1

    # number of group by clauses
    if len(sql['groupBy']) > 1:
        count += 1

    return count


class Evaluator:
    """A simple evaluator"""
    def __init__(self):
        self.partial_scores = None

    def eval_hardness(self, sql):
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"
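
    # Worked examples for eval_hardness (comment added; the counts follow the
    # helper functions above):
    #   SELECT count(*) FROM singer
    #     -> component1 = 0, component2 = 0, others = 0  => "easy"
    #   SELECT name FROM singer WHERE age > 20 ORDER BY age LIMIT 1
    #     -> component1 = 3 (where, order, limit), others = 0 => "hard"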
|
||||
def eval_exact_match(self, pred, label):
|
||||
partial_scores = self.eval_partial_match(pred, label)
|
||||
self.partial_scores = partial_scores
|
||||
for _, score in partial_scores.items():
|
||||
if score['f1'] != 1:
|
||||
return 0
|
||||
if len(label['from']['table_units']) > 0:
|
||||
if label['from']['table_units'][0][0] == 'sql' and pred['from']['table_units'][0][0] == 'sql':
|
||||
return self.eval_exact_match(pred['from']['table_units'][0][1], label['from']['table_units'][0][1]) # still wrong
|
||||
else:
|
||||
label_tables = sorted(label['from']['table_units'])
|
||||
pred_tables = sorted(pred['from']['table_units'])
|
||||
return label_tables == pred_tables
|
||||
return 1
|
||||
|
||||
def eval_partial_match(self, pred, label):
|
||||
res = {}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_group(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_having(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_order(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_and_or(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_IUEN(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_keywords(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def isValidSQL(sql, db):
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(sql)
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def print_scores(scores, etype):
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']

    print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
    counts = [scores[level]['count'] for level in levels]
    print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))

    if etype in ["all", "exec"]:
        print('===================== EXECUTION ACCURACY =====================')
        this_scores = [scores[level]['exec'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))

    if etype in ["all", "match"]:
        print('\n====================== EXACT MATCHING ACCURACY =====================')
        exact_scores = [scores[level]['exact'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
        print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING RECALL ----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING F1 --------------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

def evaluate(gold, predict, db_dir, etype, kmaps):
    with open(gold) as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict) as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
    # glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
    evaluator = Evaluator()

    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']
    entries = []
    scores = {}

    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0., 'acc_count': 0, 'rec_count': 0}
    eval_err_num = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db, db + ".sqlite")
        schema = Schema(get_schema(db))
        # .schema: map each lowercased raw table name to its list of lowercased raw column names
        # .idMap: map table name to __tab__, tab.col to __tab.col__ and * to __all__, all lowercased
        g_sql = get_sql(schema, g_str)
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        try:
            p_sql = get_sql(schema, p_str)
        except:
            # if p_sql cannot be parsed, evaluate the gold sql against an empty sql instead
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [
                    False,
                    []
                ],
                "union": None,
                "where": []
            }
            eval_err_num += 1
            print("eval_err_num:{}".format(eval_err_num))

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        # extract every __tab.col__ whose table appears in the from clause, excluding __all__
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)  # kmap: map __tab.col__ to its pivot __tab.col__
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
            if exec_score:
                scores[hardness]['exec'] += 1

        if etype in ["all", "match"]:
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                print("{} pred: {}".format(hardness, p_str))
                print("{} gold: {}".format(hardness, g_str))
                print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            for type_ in partial_types:
                if partial_scores[type_]['pred_total'] > 0:
                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores[hardness]['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores[hardness]['partial'][type_]['rec_count'] += 1
                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                if partial_scores[type_]['pred_total'] > 0:
                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores['all']['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores['all']['partial'][type_]['rec_count'] += 1
                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']

            entries.append({
                'predictSQL': p_str,
                'goldSQL': g_str,
                'hardness': hardness,
                'exact': exact_score,
                'partial': partial_scores
            })

    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                        scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                        scores[level]['partial'][type_]['rec_count'] * 1.0
                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                            scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    print_scores(scores, etype)
    return scores

def eval_exec_match(db, p_str, g_str, pred, gold):
    """
    Return True if the execution results of the prediction and the gold query
    match at the corresponding indices. Multiple col_units (pairs) are currently not supported.
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(p_str)
        p_res = cursor.fetchall()
    except:
        return False

    cursor.execute(g_str)
    q_res = cursor.fetchall()

    def res_map(res, val_units):
        rmap = {}
        for idx, val_unit in enumerate(val_units):
            key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
            rmap[key] = [r[idx] for r in res]
        return rmap

    p_val_units = [unit[1] for unit in pred['select'][1]]
    q_val_units = [unit[1] for unit in gold['select'][1]]
    return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)

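# Editor note (not part of the original script): res_map keys each selected
# val_unit by its column spec and maps it to the full column of fetched values,
# so the comparison above is roughly insensitive to the order of SELECT items
# ("SELECT a, b" vs "SELECT b, a" over identical rows compare equal) while
# remaining sensitive to the order and multiplicity of the returned rows.
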
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    if condition is None or not DISABLE_VALUE:
        return condition

    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    if sql is None or not DISABLE_VALUE:
        return sql
    if len(sql['from']['table_units']) > 0 and sql['from']['table_units'][0][0] == 'sql':
        sql['from']['table_units'][0] = ('sql', rebuild_sql_val(sql['from']['table_units'][0][1]))
    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
    sql['having'] = rebuild_condition_val(sql['having'])
    sql['where'] = rebuild_condition_val(sql['where'])
    sql['intersect'] = rebuild_sql_val(sql['intersect'])
    sql['except'] = rebuild_sql_val(sql['except'])
    sql['union'] = rebuild_sql_val(sql['union'])

    return sql

# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
    col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
    prefixs = [col_id[:-2] for col_id in col_ids]
    valid_col_units = []
    for value in schema.idMap.values():
        if '.' in value and value[:value.index('.')] in prefixs:
            valid_col_units.append(value)
    return valid_col_units


def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    if col_unit is None:
        return col_unit

    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    if val_unit is None:
        return val_unit

    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    if table_unit is None:
        return table_unit

    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    if cond_unit is None:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    for idx in range(len(condition)):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    if from_ is None:
        return from_

    from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
    from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    if group_by is None:
        return group_by

    return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    if order_by is None or len(order_by) == 0:
        return order_by

    direction, val_units = order_by
    new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    if sql is None:
        return sql

    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
    sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)

    return sql

def build_foreign_key_map(entry):
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idMap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]

    return foreign_key_map

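# Editor sketch (not part of the original script): foreign keys are grouped into
# clusters, and every member column is mapped to the cluster's smallest column
# index as its pivot. E.g. with foreign_keys = [[3, 7], [7, 9]], columns 3, 7
# and 9 fall into one set, and all three map to cols[3].
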
def build_foreign_key_map_from_json(table):
    with open(table) as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry['db_id']] = build_foreign_key_map(entry)
    return tables


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gold', dest='gold', type=str)
    parser.add_argument('--pred', dest='pred', type=str)
    parser.add_argument('--db', dest='db', type=str)
    parser.add_argument('--table', dest='table', type=str)
    parser.add_argument('--etype', dest='etype', type=str)
    args = parser.parse_args()

    gold = args.gold
    pred = args.pred
    db_dir = args.db
    table = args.table
    etype = args.etype

    assert etype in ["all", "exec", "match"], "Unknown evaluation method"

    kmaps = build_foreign_key_map_from_json(table)

    """
    foreign key map (dict): str -> str
    maps each lowercased column to the pivot column of its cluster set
    normalized column: __all__, or __t.lower() + . + c.lower()__
    """
    evaluate(gold, pred, db_dir, etype, kmaps)
@@ -0,0 +1,215 @@
#coding=utf8
""" ONLSTM and traditional LSTM with locked dropout """
import torch
import torch.nn as nn
import torch.nn.functional as F


def cumsoftmax(x, dim=-1):
    return torch.cumsum(F.softmax(x, dim=dim), dim=dim)

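# Editor note (not in the original file): cumsoftmax is the cumulative sum of a
# softmax, i.e. a monotonically non-decreasing vector in [0, 1] that ends at 1;
# ONLSTM builds its soft master gates from it. Quick sanity check on a 1-D tensor:
#   >>> cumsoftmax(torch.zeros(4))
#   tensor([0.2500, 0.5000, 0.7500, 1.0000])
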
class LinearDropConnect(nn.Linear):
    """ Used for recurrent connection dropout (DropConnect) """
    def __init__(self, in_features, out_features, bias=True, dropconnect=0.):
        super(LinearDropConnect, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
        self.dropconnect = dropconnect

    def sample_mask(self):
        if self.dropconnect == 0.:
            self._weight = self.weight.clone()
        else:
            mask = self.weight.new_zeros(self.weight.size(), dtype=torch.bool)
            mask.bernoulli_(self.dropconnect)
            self._weight = self.weight.masked_fill(mask, 0.)

    def forward(self, inputs, sample_mask=False):
        if self.training:
            if sample_mask:
                self.sample_mask()
            return F.linear(inputs, self._weight, self.bias)  # apply the same sampled mask to the weight matrix
        else:
            return F.linear(inputs, self.weight * (1 - self.dropconnect), self.bias)

class LockedDropout(nn.Module):
    """ Used for dropout between layers; the same mask is reused across timesteps """
    def __init__(self, hidden_size, num_layers=1, dropout=0.2):
        super(LockedDropout, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout

    def sample_masks(self, x):
        self.masks = []
        for _ in range(self.num_layers - 1):
            mask = x.new_zeros(x.size(0), 1, self.hidden_size).bernoulli_(1 - self.dropout)
            mask = mask.div_(1 - self.dropout)
            mask.requires_grad = False
            self.masks.append(mask)

    def forward(self, x, layer=0):
        """ x: bsize x seqlen x hidden_size """
        if (not self.training) or self.dropout == 0. or layer == self.num_layers - 1:  # output hidden states, no dropout
            return x
        mask = self.masks[layer]
        mask = mask.expand_as(x)
        return mask * x

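# Editor sketch (not in the original file): locked (variational) dropout samples
# one bsize x 1 x hidden_size mask per layer and broadcasts it over the time
# dimension, so the same units are dropped at every timestep. Assumed usage:
#   drop = LockedDropout(hidden_size=8, num_layers=2, dropout=0.5)
#   drop.train()
#   x = torch.randn(4, 10, 8)          # bsize x seqlen x hidden_size
#   drop.sample_masks(x)               # sample once per batch of sequences
#   y = drop(x, layer=0)               # the mask is shared across all 10 timesteps
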
class RecurrentNeuralNetwork(nn.Module):
    def init_hiddens(self, x):
        return x.new_zeros(self.num_layers, x.size(0), self.hidden_size), \
            x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

    def forward(self, inputs, hiddens=None, start=False, layerwise=False):
        """
        @args:
            start: whether to sample locked masks for recurrent connections and between-layer dropout
            layerwise: whether to additionally return a list of intermediate layer outputs
        @return:
            outputs: bsize x seqlen x hidden_size
            final_hiddens: hT and cT, each of size: num_layers x bsize x hidden_size
        """
        assert inputs.dim() == 3
        if hiddens is None:
            hiddens = self.init_hiddens(inputs)
        bsize, seqlen, _ = list(inputs.size())
        prev_state = list(hiddens)  # each of size: num_layers, bsize, hidden_size
        prev_layer = inputs  # size: bsize, seqlen, input_size
        each_layer_outputs, final_h, final_c = [], [], []

        if self.training and start:
            for c in self.cells:
                c.sample_masks()
            self.locked_dropout.sample_masks(inputs)

        for l in range(len(self.cells)):
            curr_layer = [None] * seqlen
            curr_inputs = self.cells[l].ih(prev_layer)
            next_h, next_c = prev_state[0][l], prev_state[1][l]
            for t in range(seqlen):
                hidden, cell = self.cells[l](None, (next_h, next_c), transformed_inputs=curr_inputs[:, t])
                next_h, next_c = hidden, cell  # overwritten at every timestep
                curr_layer[t] = hidden

            prev_layer = torch.stack(curr_layer, dim=1)  # bsize x seqlen x hidden_size
            each_layer_outputs.append(prev_layer)
            final_h.append(next_h)
            final_c.append(next_c)
            prev_layer = self.locked_dropout(prev_layer, layer=l)

        outputs, final_hiddens = prev_layer, (torch.stack(final_h, dim=0), torch.stack(final_c, dim=0))
        if layerwise:
            return outputs, final_hiddens, each_layer_outputs
        else:
            return outputs, final_hiddens

class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True, dropconnect=0.):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.ih = nn.Linear(input_size, hidden_size * 4, bias=bias)
        self.hh = LinearDropConnect(hidden_size, hidden_size * 4, bias=bias, dropconnect=dropconnect)
        self.drop_weight_modules = [self.hh]

    def sample_masks(self):
        for m in self.drop_weight_modules:
            m.sample_mask()

    def forward(self, inputs, hiddens, transformed_inputs=None):
        """
        inputs: bsize x input_size
        hiddens: tuple of h0 (bsize x hidden_size) and c0 (bsize x hidden_size)
        transformed_inputs: shortcut for inputs; saves time if they are already provided for the whole sequence during training
        return: tuple of h1 (bsize x hidden_size) and c1 (bsize x hidden_size)
        """
        if transformed_inputs is None:
            transformed_inputs = self.ih(inputs)
        h0, c0 = hiddens
        gates = transformed_inputs + self.hh(h0)
        ingate, forgetgate, outgate, cell = gates.contiguous().\
            view(-1, 4, self.hidden_size).chunk(4, 1)
        forgetgate = torch.sigmoid(forgetgate.squeeze(1))
        ingate = torch.sigmoid(ingate.squeeze(1))
        cell = torch.tanh(cell.squeeze(1))
        outgate = torch.sigmoid(outgate.squeeze(1))
        c1 = forgetgate * c0 + ingate * cell
        h1 = outgate * torch.tanh(c1)
        return h1, c1


class LSTM(RecurrentNeuralNetwork):

    def __init__(self, input_size, hidden_size, num_layers=1, chunk_num=1, bias=True, dropout=0., dropconnect=0.):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cells = nn.ModuleList(
            [LSTMCell(input_size, hidden_size, bias, dropconnect)] +
            [LSTMCell(hidden_size, hidden_size, bias, dropconnect) for i in range(num_layers - 1)]
        )
        self.locked_dropout = LockedDropout(hidden_size, num_layers, dropout)  # dropout rate between layers

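# Editor sketch (not in the original file): assumed end-to-end usage of the LSTM
# wrapper above with DropConnect and locked dropout.
#   rnn = LSTM(input_size=16, hidden_size=32, num_layers=2, dropout=0.2, dropconnect=0.2)
#   rnn.train()
#   x = torch.randn(4, 10, 16)             # bsize x seqlen x input_size
#   out, (hT, cT) = rnn(x, start=True)     # start=True samples fresh dropout masks
#   assert out.shape == (4, 10, 32) and hT.shape == (2, 4, 32)
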
class ONLSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, chunk_num=8, bias=True, dropconnect=0.2):
        super(ONLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.chunk_num = chunk_num  # hidden_size must be divisible by chunk_num
        if self.hidden_size % self.chunk_num != 0:
            raise ValueError('[Error]: hidden size must be divisible by chunk number in ONLSTM Cell')
        self.chunk_size = int(hidden_size / chunk_num)

        self.ih = nn.Linear(input_size, self.chunk_size * 2 + hidden_size * 4, bias=bias)
        self.hh = LinearDropConnect(hidden_size, self.chunk_size * 2 + hidden_size * 4, bias=bias, dropconnect=dropconnect)
        self.drop_weight_modules = [self.hh]

    def sample_masks(self):
        for m in self.drop_weight_modules:
            m.sample_mask()

    def forward(self, inputs, hiddens, transformed_inputs=None):
        """
        inputs: bsize x input_size
        hiddens: tuple of h0 (bsize x hidden_size) and c0 (bsize x hidden_size)
        transformed_inputs: shortcut for inputs; saves time if they are already provided for the whole sequence during training
        return: tuple of h1 (bsize x hidden_size) and c1 (bsize x hidden_size)
        """
        if transformed_inputs is None:
            transformed_inputs = self.ih(inputs)
        h0, c0 = hiddens
        gates = transformed_inputs + self.hh(h0)
        cingate, cforgetgate = gates[:, :self.chunk_size * 2].chunk(2, 1)
        ingate, forgetgate, outgate, cell = gates[:, self.chunk_size * 2:].contiguous().\
            view(-1, self.chunk_size * 4, self.chunk_num).chunk(4, 1)

        cingate = 1. - cumsoftmax(cingate)
        cforgetgate = cumsoftmax(cforgetgate)
        cingate = cingate[:, :, None]
        cforgetgate = cforgetgate[:, :, None]

        forgetgate = torch.sigmoid(forgetgate)
        ingate = torch.sigmoid(ingate)
        cell = torch.tanh(cell)
        outgate = torch.sigmoid(outgate)

        overlap = cforgetgate * cingate
        forgetgate = forgetgate * overlap + (cforgetgate - overlap)
        ingate = ingate * overlap + (cingate - overlap)
        c0 = c0.contiguous().view(-1, self.chunk_size, self.chunk_num)
        c1 = forgetgate * c0 + ingate * cell
        h1 = outgate * torch.tanh(c1)
        return h1.contiguous().view(-1, self.hidden_size), c1.contiguous().view(-1, self.hidden_size)

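# Editor note (not in the original file): the ordered-neuron update above follows
# Shen et al. (2019). cumsoftmax yields monotone master gates: cforgetgate rises
# towards 1, cingate falls towards 0, and overlap = cforgetgate * cingate marks
# the region where both are active. Outside the overlap the master gates dominate
# (forgetgate -> cforgetgate, ingate -> cingate), so low-ranked chunks of the
# cell state are erased or rewritten as whole blocks, which encodes the hierarchy.
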
class ONLSTM(RecurrentNeuralNetwork):

    def __init__(self, input_size, hidden_size, num_layers=1, chunk_num=8, bias=True, dropout=0., dropconnect=0.):
        super(ONLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cells = nn.ModuleList(
            [ONLSTMCell(input_size, hidden_size, chunk_num, bias, dropconnect)] +
            [ONLSTMCell(hidden_size, hidden_size, chunk_num, bias, dropconnect) for i in range(num_layers - 1)]
        )
        self.locked_dropout = LockedDropout(hidden_size, num_layers, dropout)  # dropout rate between layers
@@ -0,0 +1,397 @@
# coding=utf-8
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from asdl.sql.sql_transition_system import SelectColumnAction, SelectTableAction
from asdl.transition_system import ApplyRuleAction, ReduceAction
from asdl.action_info import ActionInfo
from asdl.hypothesis import Hypothesis
from asdl.decode_hypothesis import DecodeHypothesis
from utils.batch import Batch
from model.decoder.onlstm import LSTM, ONLSTM
from model.model_utils import Registrable, MultiHeadAttention


@Registrable.register('decoder_tranx')
class SqlParser(nn.Module):
    def __init__(self, args, transition_system):
        super(SqlParser, self).__init__()
        self.args = args
        self.transition_system = transition_system
        self.grammar = self.transition_system.grammar

        # the last entry is the embedding for the Reduce action
        self.production_embed = nn.Embedding(len(transition_system.grammar) + 1, args.action_embed_size)
        # embedding table for ASDL fields in constructors
        self.field_embed = nn.Embedding(len(transition_system.grammar.fields), args.field_embed_size)
        # embedding table for ASDL types
        self.type_embed = nn.Embedding(len(transition_system.grammar.types), args.type_embed_size)

        # input of the decoder lstm
        input_dim = args.action_embed_size  # previous action
        # parent node
        cxt_num = 2 if args.sep_cxt else 1
        input_dim += args.gnn_hidden_size * cxt_num * int(not args.no_context_feeding)
        input_dim += args.action_embed_size * int(not args.no_parent_production_embed)
        input_dim += args.field_embed_size * int(not args.no_parent_field_embed)
        input_dim += args.type_embed_size * int(not args.no_parent_field_type_embed)
        input_dim += args.lstm_hidden_size * int(not args.no_parent_state)
        cell_constructor = ONLSTM if args.lstm == 'onlstm' else LSTM
        self.decoder_lstm = cell_constructor(input_dim, args.lstm_hidden_size, num_layers=args.lstm_num_layers,
                                             chunk_num=args.chunk_size, dropout=args.dropout, dropconnect=args.drop_connect)

        # transform column embeddings into the production embedding space
        self.column_lstm_input = nn.Linear(args.gnn_hidden_size, args.action_embed_size)
        # transform table embeddings into the production embedding space
        self.table_lstm_input = self.column_lstm_input  # nn.Linear(args.gnn_hidden_size, args.action_embed_size)

        self.context_attn = MultiHeadAttention(args.gnn_hidden_size, args.lstm_hidden_size, args.gnn_hidden_size, args.gnn_hidden_size,
                                               num_heads=args.num_heads, feat_drop=args.dropout)
        if args.sep_cxt:  # compute separate context vectors for the question and the database schema
            self.schema_attn = MultiHeadAttention(args.gnn_hidden_size, args.lstm_hidden_size, args.gnn_hidden_size, args.gnn_hidden_size,
                                                  num_heads=args.num_heads, feat_drop=args.dropout)

        # feature vector before ApplyRule or SelectColumn/Table
        self.att_vec_linear = nn.Sequential(nn.Linear(args.lstm_hidden_size + cxt_num * args.gnn_hidden_size, args.att_vec_size), nn.Tanh())

        self.apply_rule_affine = nn.Linear(args.att_vec_size, args.action_embed_size, bias=False)
        self.apply_rule = lambda x: F.linear(x, self.production_embed.weight)  # re-use the action embedding matrix
        self.select_column = MultiHeadAttention(args.gnn_hidden_size, args.att_vec_size, args.gnn_hidden_size, args.gnn_hidden_size,
                                                num_heads=args.num_heads, feat_drop=args.dropout)
        self.select_table = MultiHeadAttention(args.gnn_hidden_size, args.att_vec_size, args.gnn_hidden_size, args.gnn_hidden_size,
                                               num_heads=args.num_heads, feat_drop=args.dropout)

    def score(self, encodings, mask, h0, batch):
        """ Training function
        @input:
            encodings: encoded representations and mask matrix from the encoder,
                bsize x seqlen x gnn_hidden_size
            batch: see utils.batch; we use the fields
                batch.examples, batch.get_frontier_prod_idx(t),
                batch.get_frontier_field_idx(t), batch.get_frontier_field_type_idx(t),
                batch.max_action_num, example.tgt_action (ActionInfo)
        @output:
            loss: summed loss over the training batch
        """
        args = self.args
        zero_action_embed = encodings.new_zeros(args.action_embed_size)
        split_len = [batch.max_question_len, batch.max_table_len, batch.max_column_len]
        q, tab, col = encodings.split(split_len, dim=1)
        q_mask, tab_mask, col_mask = mask.split(split_len, dim=1)
        if args.sep_cxt:
            encodings, mask = q, q_mask
            schema, schema_mask = torch.cat([tab, col], dim=1), torch.cat([tab_mask, col_mask], dim=1)
            schema_context, _ = self.schema_attn(schema, h0, schema_mask)
        context, _ = self.context_attn(encodings, h0, mask)
        h0 = h0.unsqueeze(0).repeat(args.lstm_num_layers, 1, 1)
        # h0 = h0.new_zeros(h0.size())  # init decoder with 0-vector
        h_c = (h0, h0.new_zeros(h0.size()))
        action_probs = [[] for _ in range(encodings.size(0))]
        history_states = []
        for t in range(batch.max_action_num):
            # x: [prev_action_embed, [prev_context, parent_production, parent_field, parent_type, parent_state]]
            if t == 0:
                x = encodings.new_zeros(encodings.size(0), self.decoder_lstm.input_size)
                offset = args.action_embed_size
                if not args.no_context_feeding:
                    x[:, offset: offset + args.gnn_hidden_size] = context
                    offset += args.gnn_hidden_size
                    if args.sep_cxt:
                        x[:, offset: offset + args.gnn_hidden_size] = schema_context
                        offset += args.gnn_hidden_size
                offset += args.action_embed_size * int(not args.no_parent_production_embed)
                offset += args.field_embed_size * int(not args.no_parent_field_embed)
                if not args.no_parent_field_type_embed:
                    start_rule = torch.tensor([self.grammar.type2id[self.grammar.root_type]] * x.size(0), dtype=torch.long, device=x.device)
                    x[:, offset: offset + args.type_embed_size] = self.type_embed(start_rule)
            else:
                prev_action_embed = []
                for e_id, example in enumerate(batch.examples):
                    if t < len(example.tgt_action):
                        prev_action = example.tgt_action[t - 1].action
                        if isinstance(prev_action, ApplyRuleAction):
                            prev_action_embed.append(self.production_embed.weight[self.grammar.prod2id[prev_action.production]])
                        elif isinstance(prev_action, ReduceAction):
                            prev_action_embed.append(self.production_embed.weight[len(self.grammar)])
                        elif isinstance(prev_action, SelectColumnAction):
                            # map the schema item into the production embedding space
                            col_embed = self.column_lstm_input(col[e_id, prev_action.column_id])
                            prev_action_embed.append(col_embed)
                        elif isinstance(prev_action, SelectTableAction):
                            tab_embed = self.table_lstm_input(tab[e_id, prev_action.table_id])
                            prev_action_embed.append(tab_embed)
                        else:
                            raise ValueError('Unrecognized previous action object!')
                    else:
                        prev_action_embed.append(zero_action_embed)
                inputs = [torch.stack(prev_action_embed)]
                if not args.no_context_feeding:
                    inputs.append(context)
                    if args.sep_cxt:
                        inputs.append(schema_context)
                if not args.no_parent_production_embed:
                    inputs.append(self.production_embed(batch.get_frontier_prod_idx(t)))
                if not args.no_parent_field_embed:
                    inputs.append(self.field_embed(batch.get_frontier_field_idx(t)))
                if not args.no_parent_field_type_embed:
                    inputs.append(self.type_embed(batch.get_frontier_field_type_idx(t)))
                if not args.no_parent_state:
                    actions_t = [e.tgt_action[t] if t < len(e.tgt_action) else None for e in batch.examples]
                    parent_state = torch.stack([history_states[p_t][e_id]
                                                for e_id, p_t in
                                                enumerate([a_t.parent_t if a_t else 0 for a_t in actions_t])])
                    inputs.append(parent_state)
                x = torch.cat(inputs, dim=-1)

            # advance the decoder lstm and attention calculation
            out, (h_t, c_t) = self.decoder_lstm(x.unsqueeze(1), h_c, start=(t == 0))
            out = out.squeeze(1)
            context, _ = self.context_attn(encodings, out, mask)
            if args.sep_cxt:
                schema_context, _ = self.schema_attn(schema, out, schema_mask)
                att_vec = self.att_vec_linear(torch.cat([out, context, schema_context], dim=-1))  # bsize x args.att_vec_size
            else:
                att_vec = self.att_vec_linear(torch.cat([out, context], dim=-1))

            # action logprobs
            apply_rule_logprob = F.log_softmax(self.apply_rule(self.apply_rule_affine(att_vec)), dim=-1)  # bsize x prod_num
            _, select_tab_prob = self.select_table(tab, att_vec, tab_mask)
            select_tab_logprob = torch.log(select_tab_prob + 1e-32)
            _, select_col_prob = self.select_column(col, att_vec, col_mask)
            select_col_logprob = torch.log(select_col_prob + 1e-32)

            for e_id, example in enumerate(batch.examples):
                if t < len(example.tgt_action):
                    action_t = example.tgt_action[t].action
                    if isinstance(action_t, ApplyRuleAction):
                        logprob_t = apply_rule_logprob[e_id, self.grammar.prod2id[action_t.production]]
                        # print('Rule %s with prob %s' % (action_t.production, logprob_t.item()))
                    elif isinstance(action_t, ReduceAction):
                        logprob_t = apply_rule_logprob[e_id, len(self.grammar)]
                        # print('Rule %s with prob %s' % ('Reduce', logprob_t.item()))
                    elif isinstance(action_t, SelectColumnAction):
                        logprob_t = select_col_logprob[e_id, action_t.column_id]
                        # print('SelectColumn %s with prob %s' % (action_t.column_id, logprob_t.item()))
                    elif isinstance(action_t, SelectTableAction):
                        logprob_t = select_tab_logprob[e_id, action_t.table_id]
                        # print('SelectTable %s with prob %s' % (action_t.table_id, logprob_t.item()))
                    else:
                        raise ValueError('Unrecognized action object!')
                    action_probs[e_id].append(logprob_t)

            h_c = (h_t, c_t)
            history_states.append(h_t[-1])

        # the loss is the negative sum of all action log-probabilities
        loss = - torch.stack([torch.stack(logprob_i).sum() for logprob_i in action_probs]).sum()
        return loss

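    # Editor sketch (not part of the original file): assumed training-side usage.
    # `encoder` here stands for any module producing (encodings, mask, h0) for a
    # Batch; the name is hypothetical, not an API of this repository.
    #   encodings, mask, h0 = encoder(batch)
    #   loss = parser.score(encodings, mask, h0, batch)
    #   loss.backward()
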
    def parse(self, encodings, mask, h0, batch, beam_size=5):
        """ Parse one example at a time: the batch dimension of each tensor in encodings is 1
        """
        args = self.args
        zero_action_embed = encodings.new_zeros(args.action_embed_size)
        assert encodings.size(0) == 1 and mask.size(0) == 1
        encodings, mask = encodings.repeat(beam_size, 1, 1), mask.repeat(beam_size, 1)
        split_len = [batch.max_question_len, batch.max_table_len, batch.max_column_len]
        q, tab, col = encodings.split(split_len, dim=1)
        q_mask, tab_mask, col_mask = mask.split(split_len, dim=1)
        if args.sep_cxt:
            encodings, mask = q, q_mask
            schema, schema_mask = torch.cat([tab, col], dim=1), torch.cat([tab_mask, col_mask], dim=1)
            schema_context, _ = self.schema_attn(schema[:1], h0, schema_mask[:1])
        context, _ = self.context_attn(encodings[:1], h0, mask[:1])
        h0 = h0.unsqueeze(0).repeat(args.lstm_num_layers, 1, 1)
        # h0 = h0.new_zeros(h0.size())  # init decoder with 0-vector
        h_c = (h0, h0.new_zeros(h0.size()))
        hypotheses = [DecodeHypothesis()]
        hyp_states, completed_hypotheses, t = [[]], [], 0
        while t < args.decode_max_step:
            hyp_num = len(hypotheses)
            cur_encodings, cur_mask = encodings[:hyp_num], mask[:hyp_num]
            if args.sep_cxt:
                cur_schema, cur_schema_mask = schema[:hyp_num], schema_mask[:hyp_num]
            cur_tab, cur_tab_mask, cur_col, cur_col_mask = tab[:hyp_num], tab_mask[:hyp_num], col[:hyp_num], col_mask[:hyp_num]
            # x: [prev_action_embed, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_state]
            if t == 0:
                x = encodings.new_zeros(hyp_num, self.decoder_lstm.input_size)
                offset = args.action_embed_size
                if not args.no_context_feeding:
                    x[:, offset: offset + args.gnn_hidden_size] = context
                    offset += args.gnn_hidden_size
                    if args.sep_cxt:
                        x[:, offset: offset + args.gnn_hidden_size] = schema_context
                        offset += args.gnn_hidden_size
                offset += args.action_embed_size * int(not args.no_parent_production_embed)
                offset += args.field_embed_size * int(not args.no_parent_field_embed)
                if not args.no_parent_field_type_embed:
                    start_rule = torch.tensor([self.grammar.type2id[self.grammar.root_type]] * hyp_num, dtype=torch.long, device=x.device)
                    x[:, offset: offset + args.type_embed_size] = self.type_embed(start_rule)
            else:
                prev_action_embed = []
                for e_id, hyp in enumerate(hypotheses):
                    prev_action = hyp.actions[-1]
                    if isinstance(prev_action, ApplyRuleAction):
                        prev_action_embed.append(self.production_embed.weight[self.grammar.prod2id[prev_action.production]])
                    elif isinstance(prev_action, ReduceAction):
                        prev_action_embed.append(self.production_embed.weight[len(self.grammar)])
                    elif isinstance(prev_action, SelectColumnAction):  # first map the schema item into the production embedding space
                        col_embed = self.column_lstm_input(cur_col[e_id, prev_action.column_id])
                        prev_action_embed.append(col_embed)
                    elif isinstance(prev_action, SelectTableAction):
                        tab_embed = self.table_lstm_input(cur_tab[e_id, prev_action.table_id])
                        prev_action_embed.append(tab_embed)
                    else:
                        raise ValueError('Unrecognized previous action object!')
                inputs = [torch.stack(prev_action_embed)]
                if not args.no_context_feeding:
                    inputs.append(context)
                    if args.sep_cxt:
                        inputs.append(schema_context)
                if not args.no_parent_production_embed:
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(
                        torch.tensor([self.grammar.prod2id[prod] for prod in frontier_prods], dtype=torch.long, device=encodings.device))
                    inputs.append(frontier_prod_embeds)
                if not args.no_parent_field_embed:
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(
                        torch.tensor([self.grammar.field2id[field] for field in frontier_fields], dtype=torch.long, device=encodings.device))
                    inputs.append(frontier_field_embeds)
                if not args.no_parent_field_type_embed:
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(
                        torch.tensor([self.grammar.type2id[tp] for tp in frontier_field_types], dtype=torch.long, device=encodings.device))
                    inputs.append(frontier_field_type_embeds)
                if not args.no_parent_state:
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t] for hyp_id, p_t in enumerate(p_ts)])
                    inputs.append(parent_states)
                x = torch.cat(inputs, dim=-1)  # hyp_num x input_size

            # advance the decoder lstm and attention calculation
            out, (h_t, c_t) = self.decoder_lstm(x.unsqueeze(1), h_c)
            out = out.squeeze(1)
            context, _ = self.context_attn(cur_encodings, out, cur_mask)
            if args.sep_cxt:
                schema_context, _ = self.schema_attn(cur_schema, out, cur_schema_mask)
                att_vec = self.att_vec_linear(torch.cat([out, context, schema_context], dim=-1))  # hyp_num x args.att_vec_size
            else:
                att_vec = self.att_vec_linear(torch.cat([out, context], dim=-1))

            # action logprobs
            apply_rule_logprob = F.log_softmax(self.apply_rule(self.apply_rule_affine(att_vec)), dim=-1)  # remaining_sents*beam x prod_num
            _, select_tab_prob = self.select_table(cur_tab, att_vec, cur_tab_mask)
            select_tab_logprob = torch.log(select_tab_prob + 1e-32)
            _, select_col_prob = self.select_column(cur_col, att_vec, cur_col_mask)
            select_col_logprob = torch.log(select_col_prob + 1e-32)

            new_hyp_meta = []
            for hyp_id, hyp in enumerate(hypotheses):
                action_types = self.transition_system.get_valid_continuation_types(hyp)
                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_logprob[hyp_id, prod_id]
                            new_hyp_score = hyp.score + prod_score
                            meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id,
                                          'score': prod_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_logprob[hyp_id, len(self.grammar)]
                        new_hyp_score = hyp.score + action_score
                        meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar),
                                      'score': action_score, 'new_hyp_score': new_hyp_score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                    elif action_type == SelectColumnAction:
                        col_num = cur_col_mask[hyp_id].int().sum().item()
                        for col_id in range(col_num):
                            col_sel_score = select_col_logprob[hyp_id, col_id]
                            new_hyp_score = hyp.score + col_sel_score
                            meta_entry = {'action_type': 'sel_col', 'col_id': col_id,
                                          'score': col_sel_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == SelectTableAction:
                        tab_num = cur_tab_mask[hyp_id].int().sum().item()
                        for tab_id in range(tab_num):
                            tab_sel_score = select_tab_logprob[hyp_id, tab_id]
                            new_hyp_score = hyp.score + tab_sel_score
                            meta_entry = {'action_type': 'sel_tab', 'tab_id': tab_id,
                                          'score': tab_sel_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    else:
                        raise ValueError('Unrecognized action type while decoding!')
            if not new_hyp_meta: break

            # rank and pick
            new_hyp_scores = torch.stack([x['new_hyp_score'] for x in new_hyp_meta])
            top_new_hyp_scores, meta_ids = new_hyp_scores.topk(min(beam_size, new_hyp_scores.size(0)))

            live_hyp_ids, new_hypotheses = [], []
            for new_hyp_score, meta_id in zip(top_new_hyp_scores.tolist(), meta_ids.tolist()):
                action_info = ActionInfo()
                hyp_meta_entry = new_hyp_meta[meta_id]
                prev_hyp_id = hyp_meta_entry['prev_hyp_id']
                prev_hyp = hypotheses[prev_hyp_id]
                action_type_str = hyp_meta_entry['action_type']
                if action_type_str == 'apply_rule':
                    prod_id = hyp_meta_entry['prod_id']
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    else:
                        action = ReduceAction()
                elif action_type_str == 'sel_col':
                    action = SelectColumnAction(hyp_meta_entry['col_id'])
                elif action_type_str == 'sel_tab':
                    action = SelectTableAction(hyp_meta_entry['tab_id'])
                else:
                    raise ValueError('Unrecognized action type str!')

                action_info.action = action
                action_info.t = t
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                if not args.no_parent_state:
                    hyp_states = [hyp_states[i] + [h_t[-1, i]] for i in live_hyp_ids]
                h_c = (h_t[:, live_hyp_ids], c_t[:, live_hyp_ids])
                if not args.no_context_feeding:
                    context = context[live_hyp_ids]
                    if args.sep_cxt:
                        schema_context = schema_context[live_hyp_ids]
                hypotheses = new_hypotheses
                t += 1
            else:
                break

        # for idx, hyp in enumerate(new_hypotheses):
        #     print('Idx %d' % (idx))
        #     print('Tree:', hyp.tree.to_string())
        #     print('Action:', hyp.action_infos[-1])
        #     print('============================================')
        #     input('********************Next parse step ....************************')

        if len(completed_hypotheses) == 0:  # no completed sql
            completed_hypotheses.append(DecodeHypothesis())
        else:
            completed_hypotheses.sort(key=lambda hyp: - hyp.score)  # / hyp.tree.size
        return completed_hypotheses
@@ -0,0 +1,33 @@
#coding=utf8
import dgl, math, torch

def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def src_sum_edge_mul_dst(src_field, dst_field, e_field, out_field):
    def func(edges):
        return {out_field: ((edges.src[src_field] + edges.data[e_field]) * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field] / scale_constant).clamp(-10, 10))}

    return func

def src_sum_edge_mul_edge(src_field, e_field1, e_field2, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] + edges.data[e_field1]) * edges.data[e_field2]}

    return func

def div_by_z(in_field, norm_field, out_field):
    def func(nodes):
        return {out_field: nodes.data[in_field] / nodes.data[norm_field]}

    return func
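# Editor sketch (not in the original file): these message functions compose into
# an edge-softmax attention pass over a dgl graph g, roughly as in DGL's
# transformer examples. Assumed node fields 'q', 'k', 'v' of dimension d:
#   import dgl.function as fn
#   g.apply_edges(src_dot_dst('k', 'q', 'score'))                    # unnormalized logits per edge
#   g.apply_edges(scaled_exp('score', math.sqrt(d)))                 # exp(score / sqrt(d)), clamped
#   g.update_all(fn.u_mul_e('v', 'score', 'v'), fn.sum('v', 'wv'))   # weighted value sum
#   g.update_all(fn.copy_e('score', 'score'), fn.sum('score', 'z'))  # per-node normalizer
#   g.apply_nodes(div_by_z('wv', 'z', 'o'))                          # attention output in 'o'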
@@ -0,0 +1,25 @@
#coding=utf8
import torch
import torch.nn as nn
from model.encoder.graph_input import *
from model.encoder.rgatsql import RGATSQL
from model.encoder.lgesql import LGESQL
from model.encoder.graph_output import *
from model.model_utils import Registrable

@Registrable.register('encoder_text2sql')
class Text2SQLEncoder(nn.Module):

    def __init__(self, args):
        super(Text2SQLEncoder, self).__init__()
        lazy_load = args.lazy_load if hasattr(args, 'lazy_load') else False
        self.input_layer = GraphInputLayer(args.embed_size, args.gnn_hidden_size, args.word_vocab, dropout=args.dropout, schema_aggregation=args.schema_aggregation) \
            if args.plm is None else GraphInputLayerPLM(args.plm, args.gnn_hidden_size, dropout=args.dropout,
                                                        subword_aggregation=args.subword_aggregation, schema_aggregation=args.schema_aggregation, lazy_load=lazy_load)
        self.hidden_layer = Registrable.by_name(args.model)(args)
        self.output_layer = Registrable.by_name(args.output_model)(args)

    def forward(self, batch):
        outputs = self.input_layer(batch)
        outputs = self.hidden_layer(outputs, batch)
        return self.output_layer(outputs, batch)
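# Editor note (not in the original file): the encoder is a three-stage pipeline:
# input embedding (word-vector lookup or a pre-trained LM) -> graph hidden layer
# (RGATSQL or LGESQL, resolved from args.model via the Registrable registry) ->
# graph output layer (with or without pruning, resolved from args.output_model).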
@@ -0,0 +1,153 @@
#coding=utf8
import os, math
import torch
import torch.nn as nn
from model.model_utils import rnn_wrapper, lens2mask, PoolingFunction
from transformers import AutoModel, AutoConfig

class GraphInputLayer(nn.Module):

    def __init__(self, embed_size, hidden_size, word_vocab, dropout=0.2, fix_grad_idx=60, schema_aggregation='head+tail'):
        super(GraphInputLayer, self).__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.word_vocab = word_vocab
        self.fix_grad_idx = fix_grad_idx
        self.word_embed = nn.Embedding(self.word_vocab, self.embed_size, padding_idx=0)
        self.dropout_layer = nn.Dropout(p=dropout)
        self.rnn_layer = InputRNNLayer(self.embed_size, self.hidden_size, cell='lstm', schema_aggregation=schema_aggregation)

    def pad_embedding_grad_zero(self, index=None):
        self.word_embed.weight.grad[0].zero_()  # padding symbol is always 0
        if index is not None:
            if not torch.is_tensor(index):
                index = torch.tensor(index, dtype=torch.long, device=self.word_embed.weight.grad.device)
            self.word_embed.weight.grad.index_fill_(0, index, 0.)
        else:
            self.word_embed.weight.grad[self.fix_grad_idx:].zero_()

    def forward(self, batch):
        question, table, column = self.word_embed(batch.questions), self.word_embed(batch.tables), self.word_embed(batch.columns)
        if batch.question_unk_mask is not None:
            question = question.masked_scatter_(batch.question_unk_mask.unsqueeze(-1), batch.question_unk_embeddings[:, :self.embed_size])
        if batch.table_unk_mask is not None:
            table = table.masked_scatter_(batch.table_unk_mask.unsqueeze(-1), batch.table_unk_embeddings[:, :self.embed_size])
        if batch.column_unk_mask is not None:
            column = column.masked_scatter_(batch.column_unk_mask.unsqueeze(-1), batch.column_unk_embeddings[:, :self.embed_size])
        input_dict = {
            "question": self.dropout_layer(question),
            "table": self.dropout_layer(table),
            "column": self.dropout_layer(column)
        }
        inputs = self.rnn_layer(input_dict, batch)
        return inputs


class GraphInputLayerPLM(nn.Module):

    def __init__(self, plm='bert-base-uncased', hidden_size=256, dropout=0., subword_aggregation='mean',
                 schema_aggregation='head+tail', lazy_load=False):
        super(GraphInputLayerPLM, self).__init__()
        self.plm_model = AutoModel.from_config(AutoConfig.from_pretrained(os.path.join('./pretrained_models', plm))) \
            if lazy_load else AutoModel.from_pretrained(os.path.join('./pretrained_models', plm))
        self.config = self.plm_model.config
        self.subword_aggregation = SubwordAggregation(self.config.hidden_size, subword_aggregation=subword_aggregation)
        self.dropout_layer = nn.Dropout(p=dropout)
        self.rnn_layer = InputRNNLayer(self.config.hidden_size, hidden_size, cell='lstm', schema_aggregation=schema_aggregation)

    def pad_embedding_grad_zero(self, index=None):
        pass

    def forward(self, batch):
        outputs = self.plm_model(**batch.inputs)[0]  # final layer hidden states
        question, table, column = self.subword_aggregation(outputs, batch)
        input_dict = {
            "question": self.dropout_layer(question),
            "table": self.dropout_layer(table),
            "column": self.dropout_layer(column)
        }
        inputs = self.rnn_layer(input_dict, batch)
        return inputs

class SubwordAggregation(nn.Module):
    """ Map the subwords or wordpieces of each word into one fixed-size vector according to the chosen aggregation method
    """
    def __init__(self, hidden_size, subword_aggregation='mean-pooling'):
        super(SubwordAggregation, self).__init__()
        self.hidden_size = hidden_size
        self.aggregation = PoolingFunction(self.hidden_size, self.hidden_size, method=subword_aggregation)

    def forward(self, inputs, batch):
        """ Transform pretrained model outputs into our desired format
        questions: bsize x max_question_len x hidden_size
        tables: bsize x max_table_word_len x hidden_size
        columns: bsize x max_column_word_len x hidden_size
        """
        old_questions, old_tables, old_columns = inputs.masked_select(batch.question_mask_plm.unsqueeze(-1)), \
            inputs.masked_select(batch.table_mask_plm.unsqueeze(-1)), inputs.masked_select(batch.column_mask_plm.unsqueeze(-1))
        questions = old_questions.new_zeros(batch.question_subword_lens.size(0), batch.max_question_subword_len, self.hidden_size)
        questions = questions.masked_scatter_(batch.question_subword_mask.unsqueeze(-1), old_questions)
        tables = old_tables.new_zeros(batch.table_subword_lens.size(0), batch.max_table_subword_len, self.hidden_size)
        tables = tables.masked_scatter_(batch.table_subword_mask.unsqueeze(-1), old_tables)
        columns = old_columns.new_zeros(batch.column_subword_lens.size(0), batch.max_column_subword_len, self.hidden_size)
        columns = columns.masked_scatter_(batch.column_subword_mask.unsqueeze(-1), old_columns)

        questions = self.aggregation(questions, mask=batch.question_subword_mask)
        tables = self.aggregation(tables, mask=batch.table_subword_mask)
        columns = self.aggregation(columns, mask=batch.column_subword_mask)

        new_questions, new_tables, new_columns = questions.new_zeros(len(batch), batch.max_question_len, self.hidden_size), \
            tables.new_zeros(batch.table_word_mask.size(0), batch.max_table_word_len, self.hidden_size), \
            columns.new_zeros(batch.column_word_mask.size(0), batch.max_column_word_len, self.hidden_size)
        new_questions = new_questions.masked_scatter_(batch.question_mask.unsqueeze(-1), questions)
        new_tables = new_tables.masked_scatter_(batch.table_word_mask.unsqueeze(-1), tables)
        new_columns = new_columns.masked_scatter_(batch.column_word_mask.unsqueeze(-1), columns)
        return new_questions, new_tables, new_columns

class InputRNNLayer(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, cell='lstm', schema_aggregation='head+tail', share_lstm=False):
|
||||
super(InputRNNLayer, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.cell = cell.upper()
|
||||
self.question_lstm = getattr(nn, self.cell)(self.input_size, self.hidden_size // 2, num_layers=1, bidirectional=True, batch_first=True)
|
||||
self.schema_lstm = self.question_lstm if share_lstm else \
|
||||
getattr(nn, self.cell)(self.input_size, self.hidden_size // 2, num_layers=1, bidirectional=True, batch_first=True)
|
||||
self.schema_aggregation = schema_aggregation
|
||||
if self.schema_aggregation != 'head+tail':
|
||||
self.aggregation = PoolingFunction(self.hidden_size, self.hidden_size, method=schema_aggregation)
|
||||
|
||||
def forward(self, input_dict, batch):
|
||||
"""
|
||||
for question sentence, forward into a bidirectional LSTM to get contextual info and sequential dependence
|
||||
for schema phrase, extract representation for each phrase by concatenating head+tail vectors,
|
||||
batch.question_lens, batch.table_word_lens, batch.column_word_lens are used
|
||||
"""
|
||||
questions, _ = rnn_wrapper(self.question_lstm, input_dict['question'], batch.question_lens, cell=self.cell)
|
||||
questions = questions.contiguous().view(-1, self.hidden_size)[lens2mask(batch.question_lens).contiguous().view(-1)]
|
||||
table_outputs, table_hiddens = rnn_wrapper(self.schema_lstm, input_dict['table'], batch.table_word_lens, cell=self.cell)
|
||||
if self.schema_aggregation != 'head+tail':
|
||||
tables = self.aggregation(table_outputs, mask=batch.table_word_mask)
|
||||
else:
|
||||
table_hiddens = table_hiddens[0].transpose(0, 1) if self.cell == 'LSTM' else table_hiddens.transpose(0, 1)
|
||||
tables = table_hiddens.contiguous().view(-1, self.hidden_size)
|
||||
column_outputs, column_hiddens = rnn_wrapper(self.schema_lstm, input_dict['column'], batch.column_word_lens, cell=self.cell)
|
||||
if self.schema_aggregation != 'head+tail':
|
||||
columns = self.aggregation(column_outputs, mask=batch.column_word_mask)
|
||||
else:
|
||||
column_hiddens = column_hiddens[0].transpose(0, 1) if self.cell == 'LSTM' else column_hiddens.transpose(0, 1)
|
||||
columns = column_hiddens.contiguous().view(-1, self.hidden_size)
|
||||
|
||||
questions = questions.split(batch.question_lens.tolist(), dim=0)
|
||||
tables = tables.split(batch.table_lens.tolist(), dim=0)
|
||||
columns = columns.split(batch.column_lens.tolist(), dim=0)
|
||||
# dgl graph node feats format: q11 q12 ... t11 t12 ... c11 c12 ... q21 q22 ...
|
||||
outputs = [th for q_t_c in zip(questions, tables, columns) for th in q_t_c]
|
||||
outputs = torch.cat(outputs, dim=0)
|
||||
# transformer input format: bsize x max([q1 q2 ... t1 t2 ... c1 c2 ...]) x hidden_size
|
||||
# outputs = []
|
||||
# for q, t, c in zip(questions, tables, columns):
|
||||
# zero_paddings = q.new_zeros((batch.max_len - q.size(0) - t.size(0) - c.size(0), q.size(1)))
|
||||
# cur_outputs = torch.cat([q, t, c, zero_paddings], dim=0)
|
||||
# outputs.append(cur_outputs)
|
||||
# outputs = torch.stack(outputs, dim=0) # bsize x max_len x hidden_size
|
||||
return outputs
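# --- Illustrative sketch (added for exposition, not part of the original file) ---
# How masked_scatter_ re-packs a flat token sequence into a padded batch;
# toy shapes only, assuming `torch` is imported as above:
#     flat = torch.arange(5.).unsqueeze(-1)                               # 5 tokens x 1 dim
#     mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool) # 2 x 4 length mask
#     padded = flat.new_zeros(2, 4, 1).masked_scatter_(mask.unsqueeze(-1), flat)
#     # padded[0, :3] holds tokens 0..2, padded[1, :2] holds tokens 3..4; the rest is zero padding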
@@ -0,0 +1,138 @@
#coding=utf8
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from model.model_utils import Registrable
from model.encoder.functions import scaled_exp, div_by_z, src_dot_dst


class ScoreFunction(nn.Module):

    def __init__(self, hidden_size, mlp=1, method='biaffine'):
        super(ScoreFunction, self).__init__()
        assert method in ['dot', 'bilinear', 'affine', 'biaffine']
        self.mlp = int(mlp)
        self.hidden_size = hidden_size // self.mlp
        if self.mlp > 1: # use mlp to perform dim reduction
            self.mlp_q = nn.Sequential(nn.Linear(hidden_size, self.hidden_size), nn.Tanh())
            self.mlp_s = nn.Sequential(nn.Linear(hidden_size, self.hidden_size), nn.Tanh())
        self.method = method
        if self.method == 'bilinear':
            self.W = nn.Linear(self.hidden_size, self.hidden_size)
        elif self.method == 'affine':
            self.affine = nn.Linear(self.hidden_size * 2, 1)
        elif self.method == 'biaffine':
            self.W = nn.Linear(self.hidden_size, self.hidden_size)
            self.affine = nn.Linear(self.hidden_size * 2, 1)

    def forward(self, context, node):
        """
        @args:
            context(torch.FloatTensor): num_nodes x hidden_size
            node(torch.FloatTensor): num_nodes x hidden_size
        @return:
            scores(torch.FloatTensor): num_nodes
        """
        if self.mlp > 1:
            context, node = self.mlp_q(context), self.mlp_s(node)
        if self.method == 'dot':
            scores = (context * node).sum(dim=-1)
        elif self.method == 'bilinear':
            scores = (context * self.W(node)).sum(dim=-1)
        elif self.method == 'affine':
            scores = self.affine(torch.cat([context, node], dim=-1)).squeeze(-1)
        elif self.method == 'biaffine':
            scores = (context * self.W(node)).sum(dim=-1)
            scores += self.affine(torch.cat([context, node], dim=-1)).squeeze(-1)
        else:
            raise ValueError('[Error]: Unrecognized score function method %s!' % (self.method))
        return scores
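# --- Illustrative usage (added for exposition; toy sizes, hypothetical tensors) ---
#     score_fn = ScoreFunction(hidden_size=256, mlp=2, method='biaffine')
#     context, node = torch.randn(10, 256), torch.randn(10, 256)
#     scores = score_fn(context, node)   # shape: (10,), one pruning logit per schema node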
@Registrable.register('without_pruning')
class GraphOutputLayer(nn.Module):

    def __init__(self, args):
        super(GraphOutputLayer, self).__init__()
        self.hidden_size = args.gnn_hidden_size

    def forward(self, inputs, batch):
        """ Re-scatter data format:
        inputs: sum(q_len + t_len + c_len) x hidden_size
        outputs: bsize x (max_q_len + max_t_len + max_c_len) x hidden_size
        """
        outputs = inputs.new_zeros(len(batch), batch.mask.size(1), self.hidden_size)
        outputs = outputs.masked_scatter_(batch.mask.unsqueeze(-1), inputs)
        if self.training:
            return outputs, batch.mask, torch.tensor(0., dtype=torch.float).to(outputs.device)
        else:
            return outputs, batch.mask


@Registrable.register('with_pruning')
class GraphOutputLayerWithPruning(nn.Module):

    def __init__(self, args):
        super(GraphOutputLayerWithPruning, self).__init__()
        self.hidden_size = args.gnn_hidden_size
        self.graph_pruning = GraphPruning(self.hidden_size, args.num_heads, args.dropout, args.score_function)

    def forward(self, inputs, batch):
        outputs = inputs.new_zeros(len(batch), batch.mask.size(1), self.hidden_size)
        outputs = outputs.masked_scatter_(batch.mask.unsqueeze(-1), inputs)

        if self.training:
            g = batch.graph
            question = inputs.masked_select(g.question_mask.unsqueeze(-1)).view(-1, self.hidden_size)
            schema = inputs.masked_select(g.schema_mask.unsqueeze(-1)).view(-1, self.hidden_size)
            loss = self.graph_pruning(question, schema, g.gp, g.node_label)
            return outputs, batch.mask, loss
        else:
            return outputs, batch.mask


class GraphPruning(nn.Module):

    def __init__(self, hidden_size, num_heads=8, feat_drop=0.2, score_function='affine'):
        super(GraphPruning, self).__init__()
        self.hidden_size = hidden_size
        self.node_mha = DGLMHA(hidden_size, hidden_size, num_heads, feat_drop)
        self.node_score_function = ScoreFunction(self.hidden_size, mlp=2, method=score_function)
        self.loss_function = nn.BCEWithLogitsLoss(reduction='sum')

    def forward(self, question, schema, graph, node_label):
        node_context = self.node_mha(question, schema, graph)
        node_score = self.node_score_function(node_context, schema)
        loss = self.loss_function(node_score, node_label)
        return loss


class DGLMHA(nn.Module):
    """ Multi-head attention implemented with DGL lib
    """
    def __init__(self, hidden_size, output_size, num_heads=8, feat_drop=0.2):
        super(DGLMHA, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_heads = num_heads
        self.d_k = self.hidden_size // self.num_heads
        self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.output_size, self.hidden_size),\
            nn.Linear(self.hidden_size, self.hidden_size, bias=False), nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.affine_o = nn.Linear(self.hidden_size, self.output_size)
        self.feat_dropout = nn.Dropout(p=feat_drop)

    def forward(self, context, node, g):
        q, k, v = self.affine_q(self.feat_dropout(node)), self.affine_k(self.feat_dropout(context)), self.affine_v(self.feat_dropout(context))
        with g.local_scope():
            g.nodes['schema'].data['q'] = q.view(-1, self.num_heads, self.d_k)
            g.nodes['question'].data['k'] = k.view(-1, self.num_heads, self.d_k)
            g.nodes['question'].data['v'] = v.view(-1, self.num_heads, self.d_k)
            out_x = self.propagate_attention(g)
        return self.affine_o(out_x.view(-1, self.hidden_size))

    def propagate_attention(self, g):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'))
        g.apply_edges(scaled_exp('score', math.sqrt(self.d_k)))
        # Update node state
        g.update_all(fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv'))
        g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o'))
        out_x = g.nodes['schema'].data['o']
        return out_x
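# --- Usage note (added for exposition) ---
# DGLMHA expects `g` to be a heterogeneous bipartite graph with 'question' and
# 'schema' node types and question->schema edges, e.g. (hypothetical toy graph):
#     import dgl
#     g = dgl.heterograph({('question', 'to', 'schema'): ([0, 1, 2], [0, 0, 1])})
# Each schema node then attends over its incident question nodes.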
@@ -0,0 +1,98 @@
#coding=utf8
import copy, math
import torch, dgl
import torch.nn as nn
import dgl.function as fn
from model.model_utils import Registrable, FFN
from model.encoder.rgatsql import RGATLayer, MultiViewRGATLayer
from model.encoder.functions import *


@Registrable.register('lgesql')
class LGESQL(nn.Module):
    """ Compared with RGAT, we utilize a line graph to explicitly model the propagation among edges:
    1. aggregate info from in-edges
    2. aggregate info from src nodes
    """
    def __init__(self, args):
        super(LGESQL, self).__init__()
        self.num_layers, self.num_heads = args.gnn_num_layers, args.num_heads
        self.relation_share_heads = args.relation_share_heads
        self.graph_view = args.local_and_nonlocal
        self.ndim = args.gnn_hidden_size # node feature dim
        self.edim = self.ndim // self.num_heads if self.relation_share_heads else \
            self.ndim // 2 if self.graph_view == 'mmc' else self.ndim
        self.relation_num = args.relation_num
        self.relation_embed = nn.Embedding(self.relation_num, self.edim)
        self.gnn_layers = nn.ModuleList([
            DualRGATLayer(self.ndim, self.edim, num_heads=args.num_heads, feat_drop=args.dropout, graph_view=self.graph_view)
            for _ in range(self.num_layers)])

    def forward(self, x, batch):
        global_lgx = self.relation_embed(batch.graph.global_edges)
        mask = batch.graph.local_mask
        local_lgx = global_lgx[mask]
        local_g, global_g, lg = batch.graph.local_g, batch.graph.global_g, batch.graph.lg
        src_ids, dst_ids = batch.graph.src_ids, batch.graph.dst_ids
        for i in range(self.num_layers):
            x, local_lgx = self.gnn_layers[i](x, local_lgx, global_lgx, local_g, global_g, lg, src_ids, dst_ids)
            if self.graph_view == 'msde':
                # update local edge embeddings in the global edge embeddings matrix
                global_lgx = global_lgx.masked_scatter_(mask.unsqueeze(-1), local_lgx)
        return x
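# --- Usage note (added for exposition; toy shapes, hypothetical batch object) ---
# One LGESQL layer alternates node updates (over the node graph) and edge updates
# (over the line graph lg, whose nodes are the original local edges), e.g.:
#     x = lgesql(x0, batch)   # x0: num_nodes x gnn_hidden_size node features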
class DualRGATLayer(nn.Module):

    def __init__(self, ndim, edim, num_heads=8, feat_drop=0.2, graph_view='mmc'):
        super(DualRGATLayer, self).__init__()
        self.ndim, self.edim = ndim, edim
        self.num_heads, self.graph_view = num_heads, graph_view
        NodeRGATLayer = MultiViewRGATLayer if self.graph_view == 'mmc' else RGATLayer
        self.node_update = NodeRGATLayer(self.ndim, self.edim, self.num_heads, feat_drop=feat_drop)
        self.edge_update = EdgeRGATLayer(self.edim, self.ndim, self.num_heads, feat_drop=feat_drop)

    def forward(self, x, local_lgx, global_lgx, local_g, global_g, lg, src_ids, dst_ids):
        if self.graph_view == 'mmc':
            out_x, _ = self.node_update(x, local_lgx, global_lgx, local_g, global_g)
        elif self.graph_view == 'msde':
            out_x, _ = self.node_update(x, global_lgx, global_g)
        else:
            out_x, _ = self.node_update(x, local_lgx, local_g)
        src_x = torch.index_select(x, dim=0, index=src_ids)
        dst_x = torch.index_select(x, dim=0, index=dst_ids)
        out_local_lgx, _ = self.edge_update(local_lgx, src_x, dst_x, lg)
        return out_x, out_local_lgx


class EdgeRGATLayer(nn.Module):

    def __init__(self, edim, ndim, num_heads=8, feat_drop=0.2):
        super(EdgeRGATLayer, self).__init__()
        self.edim, self.ndim = edim, ndim
        self.num_heads = num_heads
        self.d_k = self.ndim // self.num_heads
        self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.edim, self.ndim), \
            nn.Linear(self.edim, self.ndim, bias=False), nn.Linear(self.edim, self.ndim, bias=False)
        self.affine_o = nn.Linear(self.ndim, self.edim)
        self.layernorm = nn.LayerNorm(self.edim)
        self.feat_dropout = nn.Dropout(p=feat_drop)
        self.ffn = FFN(self.edim)

    def forward(self, x, src_x, dst_x, g):
        q, k, v = self.affine_q(self.feat_dropout(x)), self.affine_k(self.feat_dropout(x)), self.affine_v(self.feat_dropout(x))
        with g.local_scope():
            g.ndata['q'] = (q + src_x).view(-1, self.num_heads, self.d_k)
            g.ndata['k'] = k.view(-1, self.num_heads, self.d_k)
            g.ndata['v'] = (v + dst_x).view(-1, self.num_heads, self.d_k)
            out_x = self.propagate_attention(g)
        out_x = self.layernorm(x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k)))
        out_x = self.ffn(out_x)
        return out_x, (src_x, dst_x)

    def propagate_attention(self, g):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'))
        g.apply_edges(scaled_exp('score', math.sqrt(self.d_k)))
        # Update node state
        g.update_all(fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv'))
        g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o'))
        out_x = g.ndata['o']
        return out_x
@@ -0,0 +1,125 @@
#coding=utf8
import copy, math
import torch, dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from model.model_utils import Registrable, FFN
from model.encoder.functions import *


@Registrable.register('rgatsql')
class RGATSQL(nn.Module):

    def __init__(self, args):
        super(RGATSQL, self).__init__()
        self.num_layers = args.gnn_num_layers
        self.relation_num = args.relation_num
        self.relation_share_layers, self.relation_share_heads = args.relation_share_layers, args.relation_share_heads
        self.graph_view = 'multiview' if args.local_and_nonlocal in ['mmc', 'msde'] else args.local_and_nonlocal
        edim = args.gnn_hidden_size // args.num_heads if self.relation_share_heads else \
            args.gnn_hidden_size // 2 if self.graph_view == 'multiview' else args.gnn_hidden_size
        if self.relation_share_layers:
            self.relation_embed = nn.Embedding(args.relation_num, edim)
        else:
            self.relation_embed = nn.ModuleList([nn.Embedding(args.relation_num, edim) for _ in range(self.num_layers)])
        gnn_layer = MultiViewRGATLayer if self.graph_view == 'multiview' else RGATLayer
        self.gnn_layers = nn.ModuleList([gnn_layer(args.gnn_hidden_size, edim, num_heads=args.num_heads, feat_drop=args.dropout)
            for _ in range(self.num_layers)])

    def forward(self, x, batch):
        if self.graph_view == 'multiview':
            # multi-view multi-head concatenation
            global_edges, mask = batch.graph.global_edges, batch.graph.local_mask
            local_g, global_g = batch.graph.local_g, batch.graph.global_g
            if self.relation_share_layers:
                global_lgx = self.relation_embed(global_edges)
                local_lgx = global_lgx[mask]
            for i in range(self.num_layers):
                global_lgx = self.relation_embed[i](global_edges) if not self.relation_share_layers else global_lgx
                local_lgx = global_lgx[mask] if not self.relation_share_layers else local_lgx
                x, (local_lgx, global_lgx) = self.gnn_layers[i](x, local_lgx, global_lgx, local_g, global_g)
        else:
            graph = batch.graph.local_g if self.graph_view == 'local' else batch.graph.global_g
            edges, mask = batch.graph.global_edges, batch.graph.local_mask
            if self.relation_share_layers:
                lgx = self.relation_embed(edges)
                lgx = lgx if self.graph_view == 'global' else lgx[mask]
            for i in range(self.num_layers):
                lgx = lgx if self.relation_share_layers else self.relation_embed[i](edges) if self.graph_view == 'global' else self.relation_embed[i](edges)[mask]
                x, lgx = self.gnn_layers[i](x, lgx, graph)
        return x


class RGATLayer(nn.Module):

    def __init__(self, ndim, edim, num_heads=8, feat_drop=0.2):
        super(RGATLayer, self).__init__()
        self.ndim, self.edim = ndim, edim
        self.num_heads = num_heads
        dim = max([ndim, edim])
        self.d_k = dim // self.num_heads
        self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.ndim, dim),\
            nn.Linear(self.ndim, dim, bias=False), nn.Linear(self.ndim, dim, bias=False)
        self.affine_o = nn.Linear(dim, self.ndim)
        self.layernorm = nn.LayerNorm(self.ndim)
        self.feat_dropout = nn.Dropout(p=feat_drop)
        self.ffn = FFN(self.ndim)

    def forward(self, x, lgx, g):
        """ @Params:
        x: node feats, num_nodes x ndim
        lgx: edge feats, num_edges x edim
        g: dgl.graph
        """
        # pre-mapping q/k/v affine
        q, k, v = self.affine_q(self.feat_dropout(x)), self.affine_k(self.feat_dropout(x)), self.affine_v(self.feat_dropout(x))
        e = lgx.view(-1, self.num_heads, self.d_k) if lgx.size(-1) == q.size(-1) else \
            lgx.unsqueeze(1).expand(-1, self.num_heads, -1)
        with g.local_scope():
            g.ndata['q'], g.ndata['k'] = q.view(-1, self.num_heads, self.d_k), k.view(-1, self.num_heads, self.d_k)
            g.ndata['v'] = v.view(-1, self.num_heads, self.d_k)
            g.edata['e'] = e
            out_x = self.propagate_attention(g)

        out_x = self.layernorm(x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k)))
        out_x = self.ffn(out_x)
        return out_x, lgx

    def propagate_attention(self, g):
        # Compute attention score
        g.apply_edges(src_sum_edge_mul_dst('k', 'q', 'e', 'score'))
        g.apply_edges(scaled_exp('score', math.sqrt(self.d_k)))
        # Update node state
        g.update_all(src_sum_edge_mul_edge('v', 'e', 'score', 'v'), fn.sum('v', 'wv'))
        g.update_all(fn.copy_edge('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o'))
        out_x = g.ndata['o']
        return out_x


class MultiViewRGATLayer(RGATLayer):

    def forward(self, x, local_lgx, global_lgx, local_g, global_g):
        """ @Params:
        x: node feats, num_nodes x ndim
        local_lgx: local edge feats, local_edge_num x edim
        global_lgx: all edge feats, global_edge_num x edim
        local_g: dgl.graph, a local graph for node update
        global_g: dgl.graph, a complete graph for node update
        """
        # pre-mapping q/k/v affine
        q, k, v = self.affine_q(self.feat_dropout(x)), self.affine_k(self.feat_dropout(x)), self.affine_v(self.feat_dropout(x))
        q, k, v = q.view(-1, self.num_heads, self.d_k), k.view(-1, self.num_heads, self.d_k), v.view(-1, self.num_heads, self.d_k)
        with local_g.local_scope():
            local_g.ndata['q'], local_g.ndata['k'] = q[:, :self.num_heads // 2], k[:, :self.num_heads // 2]
            local_g.ndata['v'] = v[:, :self.num_heads // 2]
            local_g.edata['e'] = local_lgx.view(-1, self.num_heads // 2, self.d_k) if local_lgx.size(-1) == self.d_k * self.num_heads // 2 else \
                local_lgx.unsqueeze(1).expand(-1, self.num_heads // 2, -1)
            out_x1 = self.propagate_attention(local_g)
        with global_g.local_scope():
            global_g.ndata['q'], global_g.ndata['k'] = q[:, self.num_heads // 2:], k[:, self.num_heads // 2:]
            global_g.ndata['v'] = v[:, self.num_heads // 2:]
            global_g.edata['e'] = global_lgx.view(-1, self.num_heads // 2, self.d_k) if global_lgx.size(-1) == self.d_k * self.num_heads // 2 else \
                global_lgx.unsqueeze(1).expand(-1, self.num_heads // 2, -1)
            out_x2 = self.propagate_attention(global_g)
        out_x = torch.cat([out_x1, out_x2], dim=1)
        out_x = self.layernorm(x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k)))
        out_x = self.ffn(out_x)
        return out_x, (local_lgx, global_lgx)
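# --- Illustrative note (added for exposition; toy numbers) ---
# With num_heads=8, heads 0-3 attend over the local (schema-linking) graph and
# heads 4-7 over the complete graph; concatenating the two halves restores the
# full num_heads x d_k representation before the output affine:
#     out = torch.cat([out_x1, out_x2], dim=1)   # num_nodes x 8 x d_k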
@@ -0,0 +1,41 @@
#coding=utf8
import torch
import torch.nn as nn
from model.model_utils import Registrable, PoolingFunction
from model.encoder.graph_encoder import *
from model.decoder.sql_parser import *


@Registrable.register('text2sql')
class Text2SQL(nn.Module):
    def __init__(self, args, transition_system):
        super(Text2SQL, self).__init__()
        self.encoder = Registrable.by_name('encoder_text2sql')(args)
        self.encoder2decoder = PoolingFunction(args.gnn_hidden_size, args.lstm_hidden_size, method='attentive-pooling')
        self.decoder = Registrable.by_name('decoder_tranx')(args, transition_system)

    def forward(self, batch):
        """ This function is used during training and returns the entire training loss
        """
        encodings, mask, gp_loss = self.encoder(batch)
        h0 = self.encoder2decoder(encodings, mask=mask)
        loss = self.decoder.score(encodings, mask, h0, batch)
        return loss, gp_loss

    def parse(self, batch, beam_size):
        """ This function is used for decoding and returns a batch of [DecodeHypothesis()] * beam_size
        """
        encodings, mask = self.encoder(batch)
        h0 = self.encoder2decoder(encodings, mask=mask)
        hyps = []
        for i in range(len(batch)):
            """ table_mappings and column_mappings are used to map original database ids to local ids,
            while reverse_mappings perform the opposite function, mapping local ids back to database ids
            """
            hyps.append(self.decoder.parse(encodings[i:i+1], mask[i:i+1], h0[i:i+1], batch, beam_size))
        return hyps

    def pad_embedding_grad_zero(self, index=None):
        """ For glove.42B.300d word vectors, the gradient for the <pad> symbol is always 0;
        most words (starting from index) in the word vocab are also fixed, except the most frequent words
        """
        self.encoder.input_layer.pad_embedding_grad_zero(index)
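# --- Usage sketch (added for exposition; hypothetical args/batch objects) ---
#     model = Registrable.by_name('text2sql')(args, transition_system)
#     loss, gp_loss = model(train_batch)            # training: parsing loss + graph-pruning loss
#     hyps = model.parse(dev_batch, beam_size=5)    # decoding: beam search per example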
@@ -0,0 +1,214 @@
#coding=utf8
import copy, math
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def lens2mask(lens):
    bsize = lens.numel()
    max_len = lens.max()
    masks = torch.arange(0, max_len).type_as(lens).to(lens.device).repeat(bsize, 1).lt(lens.unsqueeze(1))
    masks.requires_grad = False
    return masks


def mask2matrix(mask):
    col_mask, row_mask = mask.unsqueeze(-1), mask.unsqueeze(-2)
    return col_mask & row_mask


def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    E.g. [1, 2, 3], count=2 ==> [1, 1, 2, 2, 3, 3]
         [[1, 2], [3, 4]], count=3, dim=1 ==> [[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]]
    Different from torch.repeat
    """
    if x is None:
        return x
    elif type(x) in [list, tuple]:
        return type(x)([tile(each, count, dim) for each in x])
    else:
        perm = list(range(len(x.size())))
        if dim != 0:
            perm[0], perm[dim] = perm[dim], perm[0]
            x = x.permute(perm).contiguous()
        out_size = list(x.size())
        out_size[0] *= count
        batch = x.size(0)
        x = x.contiguous().view(batch, -1) \
            .transpose(0, 1) \
            .repeat(count, 1) \
            .transpose(0, 1) \
            .contiguous() \
            .view(*out_size)
        if dim != 0:
            x = x.permute(perm).contiguous()
        return x
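# --- Doctest-style example (added for exposition) ---
#     >>> tile(torch.tensor([1, 2, 3]), 2)
#     tensor([1, 1, 2, 2, 3, 3])
#     >>> tile(torch.tensor([[1, 2], [3, 4]]), 3, dim=1)
#     tensor([[1, 1, 1, 2, 2, 2],
#             [3, 3, 3, 4, 4, 4]])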
def rnn_wrapper(encoder, inputs, lens, cell='lstm'):
    """
    @args:
        encoder(nn.Module): rnn series bidirectional encoder, batch_first=True
        inputs(torch.FloatTensor): rnn inputs, [bsize x max_seq_len x in_dim]
        lens(torch.LongTensor): seq len for each sample, allow length=0, padding with 0-vector, [bsize]
    @return:
        out(torch.FloatTensor): output of encoder, bsize x max_seq_len x hidden_dim*2
        hidden_states([tuple of ]torch.FloatTensor): final hidden states, num_layers*2 x bsize x hidden_dim
    """
    # rerank according to lens and remove empty inputs
    sorted_lens, sort_key = torch.sort(lens, descending=True)
    nonzero_num, total_num = torch.sum(sorted_lens > 0).item(), sorted_lens.size(0)
    sort_key = sort_key[:nonzero_num]
    sorted_inputs = torch.index_select(inputs, dim=0, index=sort_key)
    # forward non empty inputs
    packed_inputs = rnn_utils.pack_padded_sequence(sorted_inputs, sorted_lens[:nonzero_num].tolist(), batch_first=True)
    packed_out, sorted_h = encoder(packed_inputs) # bsize x srclen x dim
    sorted_out, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)
    if cell.upper() == 'LSTM':
        sorted_h, sorted_c = sorted_h
    # rerank according to sort_key
    out_shape = list(sorted_out.size())
    out_shape[0] = total_num
    out = sorted_out.new_zeros(*out_shape).scatter_(0, sort_key.unsqueeze(-1).unsqueeze(-1).repeat(1, *out_shape[1:]), sorted_out)
    h_shape = list(sorted_h.size())
    h_shape[1] = total_num
    h = sorted_h.new_zeros(*h_shape).scatter_(1, sort_key.unsqueeze(0).unsqueeze(-1).repeat(h_shape[0], 1, h_shape[-1]), sorted_h)
    if cell.upper() == 'LSTM':
        c = sorted_c.new_zeros(*h_shape).scatter_(1, sort_key.unsqueeze(0).unsqueeze(-1).repeat(h_shape[0], 1, h_shape[-1]), sorted_c)
        return out, (h.contiguous(), c.contiguous())
    return out, h.contiguous()
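# --- Usage sketch (added for exposition; toy shapes) ---
#     lstm = nn.LSTM(16, 8, bidirectional=True, batch_first=True)
#     inputs = torch.randn(3, 5, 16)                       # bsize=3, max_len=5
#     lens = torch.tensor([5, 0, 2])                       # length-0 rows stay zero vectors
#     out, (h, c) = rnn_wrapper(lstm, inputs, lens, cell='lstm')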
class MultiHeadAttention(nn.Module):

    def __init__(self, hidden_size, q_size, kv_size, output_size, num_heads=8, bias=True, feat_drop=0.2, attn_drop=0.0):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = int(num_heads)
        self.hidden_size = hidden_size
        assert self.hidden_size % self.num_heads == 0, 'Hidden size %d must be divisible by the number of heads %d' % (hidden_size, num_heads)
        self.d_k = self.hidden_size // self.num_heads
        self.feat_drop = nn.Dropout(p=feat_drop)
        self.attn_drop = nn.Dropout(p=attn_drop)
        self.W_q = nn.Linear(q_size, self.hidden_size, bias=bias)
        self.W_k = nn.Linear(kv_size, self.hidden_size, bias=False)
        self.W_v = nn.Linear(kv_size, self.hidden_size, bias=False)
        self.W_o = nn.Linear(self.hidden_size, output_size, bias=bias)

    def forward(self, hiddens, query_hiddens, mask=None):
        ''' @params:
            hiddens : encoded sequence representations, bsize x seqlen x hidden_size
            query_hiddens : bsize [x tgtlen ]x hidden_size
            mask : length mask for hiddens, ByteTensor, bsize x seqlen
        @return:
            context : bsize x[ tgtlen x] hidden_size
        '''
        remove_flag = False
        if query_hiddens.dim() == 2:
            query_hiddens, remove_flag = query_hiddens.unsqueeze(1), True
        Q, K, V = self.W_q(self.feat_drop(query_hiddens)), self.W_k(self.feat_drop(hiddens)), self.W_v(self.feat_drop(hiddens))
        Q, K, V = Q.reshape(-1, Q.size(1), 1, self.num_heads, self.d_k), K.reshape(-1, 1, K.size(1), self.num_heads, self.d_k), V.reshape(-1, 1, V.size(1), self.num_heads, self.d_k)
        e = (Q * K).sum(-1) / math.sqrt(self.d_k) # bsize x tgtlen x seqlen x num_heads
        if mask is not None:
            e = e + ((1 - mask.float()) * (-1e20)).unsqueeze(1).unsqueeze(-1)
        a = torch.softmax(e, dim=2)
        concat = (a.unsqueeze(-1) * V).sum(dim=2).reshape(-1, query_hiddens.size(1), self.hidden_size)
        context = self.W_o(concat)
        if remove_flag:
            return context.squeeze(dim=1), a.mean(dim=-1).squeeze(dim=1)
        else:
            return context, a.mean(dim=-1)


class PoolingFunction(nn.Module):
    """ Map a sequence of hidden_size dim vectors into one fixed size vector with dimension output_size """
    def __init__(self, hidden_size=256, output_size=256, bias=True, method='attentive-pooling'):
        super(PoolingFunction, self).__init__()
        assert method in ['mean-pooling', 'max-pooling', 'attentive-pooling']
        self.method = method
        if self.method == 'attentive-pooling':
            self.attn = nn.Sequential(
                nn.Linear(hidden_size, hidden_size, bias=bias),
                nn.Tanh(),
                nn.Linear(hidden_size, 1, bias=bias)
            )
        self.mapping_function = nn.Sequential(nn.Linear(hidden_size, output_size, bias=bias), nn.Tanh()) \
            if hidden_size != output_size else lambda x: x

    def forward(self, inputs, mask=None):
        """ @args:
            inputs(torch.FloatTensor): features, batch_size x seq_len x hidden_size
            mask(torch.BoolTensor): mask for inputs, batch_size x seq_len
        @return:
            outputs(torch.FloatTensor): aggregate seq_len dim for inputs, batch_size x output_size
        """
        if self.method == 'max-pooling':
            outputs = inputs.masked_fill(~ mask.unsqueeze(-1), -1e8)
            outputs = outputs.max(dim=1)[0]
        elif self.method == 'mean-pooling':
            mask_float = mask.float().unsqueeze(-1)
            outputs = (inputs * mask_float).sum(dim=1) / mask_float.sum(dim=1)
        elif self.method == 'attentive-pooling':
            e = self.attn(inputs).squeeze(-1)
            e = e + (1 - mask.float()) * (-1e20)
            a = torch.softmax(e, dim=1).unsqueeze(1)
            outputs = torch.bmm(a, inputs).squeeze(1)
        else:
            raise ValueError('[Error]: Unrecognized pooling method %s !' % (self.method))
        outputs = self.mapping_function(outputs)
        return outputs
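# --- Usage sketch (added for exposition; toy shapes) ---
#     pool = PoolingFunction(hidden_size=256, output_size=128, method='attentive-pooling')
#     feats = torch.randn(4, 10, 256)
#     mask = lens2mask(torch.tensor([10, 7, 3, 10]))   # batch_size x seq_len
#     pooled = pool(feats, mask=mask)                  # 4 x 128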
class FFN(nn.Module):

    def __init__(self, input_size):
        super(FFN, self).__init__()
        self.input_size = input_size
        self.feedforward = nn.Sequential(
            nn.Linear(self.input_size, self.input_size * 4),
            nn.ReLU(inplace=True),
            nn.Linear(self.input_size * 4, self.input_size)
        )
        self.layernorm = nn.LayerNorm(self.input_size)

    def forward(self, inputs):
        return self.layernorm(inputs + self.feedforward(inputs))


class Registrable(object):
    """
    A class that collects all registered components,
    adapted from `common.registrable.Registrable` from AllenNLP
    """
    registered_components = dict()

    @staticmethod
    def register(name):
        def register_class(cls):
            if name in Registrable.registered_components:
                raise RuntimeError('class %s already registered' % name)

            Registrable.registered_components[name] = cls
            return cls

        return register_class

    @staticmethod
    def by_name(name):
        return Registrable.registered_components[name]


class cached_property(object):
    """ A property that is only computed once per instance and then replaces
    itself with an ordinary attribute. Deleting the attribute resets the
    property.

    Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76
    """

    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        if obj is None:
            return self
        value = obj.__dict__[self.func.__name__] = self.func(obj)
        return value
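# --- Usage sketch (added for exposition; class and helper are hypothetical) ---
#     class Example(object):
#         @cached_property
#         def stats(self):
#             return expensive_computation()   # runs once, then cached on the instance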
@@ -0,0 +1,15 @@
## Data Preprocess

#### Get Parsed SQL Output

The SQL parsing script is `process_sql.py` in the main directory. Please refer to `parsed_sql_examples.sql` for the explanation of some parsed SQL output examples.

If you would like to use `process_sql.py` to parse SQL queries by yourself, `parse_sql_one.py` provides an example of how the script is called. Alternatively, you can use `parse_raw_json.py` to update all parsed SQL results (the value for `sql`) in `train.json` and `dev.json`.

#### Get Table Info from Database

This step generates the final `tables.json` file. The script reads sqlite files from the `database/` dir and a previous `tables.json` with hand-corrected names:

```
python process/get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]
```
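For example, assuming the default Spider layout (paths here are illustrative):

```
python process/get_tables.py data/database data/tables_new.json data/tables.json
```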
@@ -0,0 +1,57 @@
#coding=utf8
import argparse, os, sys, pickle, json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from collections import Counter


def construct_vocab_from_dataset(*data_paths, table_path='data/tables.bin', mwf=4, reference_file=None, output_path=None, sep='\t'):
    words = []
    tables = pickle.load(open(table_path, 'rb'))
    for fp in data_paths:
        dataset = pickle.load(open(fp, 'rb'))
        for ex in dataset:
            words.extend(ex['processed_question_toks'])
            db = tables[ex['db_id']]
            words.extend(['table'] * len(db['table_names']))
            words.extend(db['column_types'])
            for c in db['processed_column_toks']:
                words.extend(c)
            for t in db['processed_table_toks']:
                words.extend(t)
    cnt = Counter(words)
    vocab = sorted(list(cnt.items()), key=lambda x: - x[1])
    glove_vocab = set()
    with open(reference_file, 'r', encoding='utf-8') as inf:
        for line in inf:
            line = line.strip()
            if line == '': continue
            glove_vocab.add(line)
    oov_words, oov_but_freq_words = set(), []
    for w, c in vocab:
        if w not in glove_vocab:
            oov_words.add(w)
            if c >= mwf:
                oov_but_freq_words.append((w, c))
    print('Out of glove vocabulary size: %d\nAmong them, %d words occur equal or more than %d times in training dataset.' % (len(oov_words), len(oov_but_freq_words), mwf))
    with open(output_path, 'w') as of:
        # first serialize oov but frequent words, allowing fine-tuning them during training
        for w, c in oov_but_freq_words:
            of.write(w + sep + str(c) + '\n')
        # next serialize words in both train vocab and glove vocab according to decreasing frequency
        for w, c in vocab:
            if w not in oov_words:
                of.write(w + sep + str(c) + '\n')
    return len(vocab)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_paths', nargs='+', type=str, help='input preprocessed dataset file')
    parser.add_argument('--table_path', type=str, default='data/tables.bin', help='preprocessed table file')
    parser.add_argument('--output_path', type=str, required=True, help='output word vocabulary path')
    parser.add_argument('--reference_file', type=str, default='pretrained_models/glove-42b-300d/vocab_glove.txt',
        help='eliminate word not in glove vocabulary, unless it occurs frequently >= mwf')
    parser.add_argument('--mwf', default=4, type=int,
        help='minimum word frequency used to pick up frequent words in training dataset, but not in glove vocabulary')
    args = parser.parse_args()

    construct_vocab_from_dataset(*args.data_paths, table_path=args.table_path, mwf=args.mwf, reference_file=args.reference_file, output_path=args.output_path)
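# --- Example invocation (added for exposition; the script path and data paths are illustrative) ---
#     python utils/construct_vocab.py --data_paths data/train.bin data/dev.bin \
#         --table_path data/tables.bin --output_path data/vocab.glove.txt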
@@ -0,0 +1,527 @@
#coding=utf8
import os, sqlite3
import numpy as np
import stanza, torch
from nltk.corpus import stopwords
from itertools import product, combinations
import torch.nn.functional as F
from utils.constants import MAX_RELATIVE_DIST
from transformers import AutoModel, AutoConfig, AutoTokenizer
import geoopt as gt


def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False


def agg(input):
    # if input.size(0) == 1:
    #     return input.squeeze()
    # else:
    return torch.sum(input, dim=1, keepdim=True) / input.size(1)


def quote_normalization(question):
    """ Normalize all usage of quotation marks into a separate \" """
    new_question, quotation_marks = [], ["'", '"', '`', '‘', '’', '“', '”', '``', "''", "‘‘", "’’"]
    for idx, tok in enumerate(question):
        if len(tok) > 2 and tok[0] in quotation_marks and tok[-1] in quotation_marks:
            new_question += ["\"", tok[1:-1], "\""]
        elif len(tok) > 2 and tok[0] in quotation_marks:
            new_question += ["\"", tok[1:]]
        elif len(tok) > 2 and tok[-1] in quotation_marks:
            new_question += [tok[:-1], "\""]
        elif tok in quotation_marks:
            new_question.append("\"")
        elif len(tok) == 2 and tok[0] in quotation_marks:
            # special case: the length of entity value is 1
            if idx + 1 < len(question) and question[idx + 1] in quotation_marks:
                new_question += ["\"", tok[1]]
            else:
                new_question.append(tok)
        else:
            new_question.append(tok)
    return new_question
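# --- Doctest-style example (added for exposition) ---
#     >>> quote_normalization(["show", "'cat'", "names"])
#     ['show', '"', 'cat', '"', 'names']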
class Preprocessor():

    def __init__(self, db_dir='data/database', db_content=True):
        super(Preprocessor, self).__init__()
        self.db_dir = db_dir
        self.db_content = db_content
        self.nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma') # , use_gpu=False)
        self.stopwords = stopwords.words("english")
        self.device = torch.device("cuda:0")
        self.hidden_size = 1024
        self.max_batch_size = 8
        self.plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(self.device)
        self.plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
        self.config = self.plm_model.config
        self.ball = gt.Stereographic(-1)
        self.threshold = 0.4

    def pipeline(self, entry: dict, db: dict, verbose: bool = False):
        """ db should be preprocessed """
        entry = self.preprocess_question(entry, db, verbose=verbose)
        entry = self.schema_linking(entry, db, verbose=verbose)
        entry = self.extract_subgraph(entry, db, verbose=verbose)
        return entry

    def preprocess_database(self, db: dict, verbose: bool = False):
        """ Tokenize, lemmatize, lowercase table and column names for each database """
        table_toks, table_names = [], []
        for tab in db['table_names']:
            doc = self.nlp(tab)
            tab = [w.lemma.lower() for s in doc.sentences for w in s.words]
            table_toks.append(tab)
            table_names.append(" ".join(tab))
        db['processed_table_toks'], db['processed_table_names'] = table_toks, table_names
        column_toks, column_names = [], []
        for _, c in db['column_names']:
            doc = self.nlp(c)
            c = [w.lemma.lower() for s in doc.sentences for w in s.words]
            column_toks.append(c)
            column_names.append(" ".join(c))
        db['processed_column_toks'], db['processed_column_names'] = column_toks, column_names
        column2table = list(map(lambda x: x[0], db['column_names'])) # from column id to table id
        table2columns = [[] for _ in range(len(table_names))] # from table id to column ids list
        for col_id, col in enumerate(db['column_names']):
            if col_id == 0: continue
            table2columns[col[0]].append(col_id)
        db['column2table'], db['table2columns'] = column2table, table2columns

        t_num, c_num, dtype = len(db['table_names']), len(db['column_names']), '<U100'

        # relations in tables, tab_num * tab_num
        tab_mat = np.array([['table-table-generic'] * t_num for _ in range(t_num)], dtype=dtype)
        table_fks = set(map(lambda pair: (column2table[pair[0]], column2table[pair[1]]), db['foreign_keys']))
        for (tab1, tab2) in table_fks:
            if (tab2, tab1) in table_fks:
                tab_mat[tab1, tab2], tab_mat[tab2, tab1] = 'table-table-fkb', 'table-table-fkb'
            else:
                tab_mat[tab1, tab2], tab_mat[tab2, tab1] = 'table-table-fk', 'table-table-fkr'
        tab_mat[list(range(t_num)), list(range(t_num))] = 'table-table-identity'

        # relations in columns, c_num * c_num
        col_mat = np.array([['column-column-generic'] * c_num for _ in range(c_num)], dtype=dtype)
        for i in range(t_num):
            col_ids = [idx for idx, t in enumerate(column2table) if t == i]
            col1, col2 = list(zip(*list(product(col_ids, col_ids))))
            col_mat[col1, col2] = 'column-column-sametable'
        col_mat[list(range(c_num)), list(range(c_num))] = 'column-column-identity'
        if len(db['foreign_keys']) > 0:
            col1, col2 = list(zip(*db['foreign_keys']))
            col_mat[col1, col2], col_mat[col2, col1] = 'column-column-fk', 'column-column-fkr'
        col_mat[0, list(range(c_num))] = '*-column-generic'
        col_mat[list(range(c_num)), 0] = 'column-*-generic'
        col_mat[0, 0] = '*-*-identity'

        # relations between tables and columns, t_num*c_num and c_num*t_num
        tab_col_mat = np.array([['table-column-generic'] * c_num for _ in range(t_num)], dtype=dtype)
        col_tab_mat = np.array([['column-table-generic'] * t_num for _ in range(c_num)], dtype=dtype)
        cols, tabs = list(zip(*list(map(lambda x: (x, column2table[x]), range(1, c_num))))) # ignore *
        col_tab_mat[cols, tabs], tab_col_mat[tabs, cols] = 'column-table-has', 'table-column-has'
        if len(db['primary_keys']) > 0:
            cols, tabs = list(zip(*list(map(lambda x: (x, column2table[x]), db['primary_keys']))))
            col_tab_mat[cols, tabs], tab_col_mat[tabs, cols] = 'column-table-pk', 'table-column-pk'
        col_tab_mat[0, list(range(t_num))] = '*-table-generic'
        tab_col_mat[list(range(t_num)), 0] = 'table-*-generic'

        relations = np.concatenate([
            np.concatenate([tab_mat, tab_col_mat], axis=1),
            np.concatenate([col_tab_mat, col_mat], axis=1)
        ], axis=0)
        db['relations'] = relations.tolist()

        if verbose:
            print('Tables:', ', '.join(db['table_names']))
            print('Lemmatized:', ', '.join(table_names))
            print('Columns:', ', '.join(list(map(lambda x: x[1], db['column_names']))))
            print('Lemmatized:', ', '.join(column_names), '\n')
        return db
    def preprocess_question(self, entry: dict, db: dict, verbose: bool = False):
        """ Tokenize, lemmatize, lowercase question """
        # stanza tokenize, lemmatize and POS tag
        question = ' '.join(quote_normalization(entry['question_toks']))
        doc = self.nlp(question)
        raw_toks = [w.text.lower() for s in doc.sentences for w in s.words]
        toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
        pos_tags = [w.xpos for s in doc.sentences for w in s.words]

        entry['raw_question_toks'] = raw_toks
        entry['processed_question_toks'] = toks
        entry['pos_tags'] = pos_tags

        # relations in questions, q_num * q_num
        q_num, dtype = len(toks), '<U100'
        if q_num <= MAX_RELATIVE_DIST + 1:
            dist_vec = ['question-question-dist' + str(i) if i != 0 else 'question-question-identity'
                for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)]
            starting = MAX_RELATIVE_DIST
        else:
            dist_vec = ['question-question-generic'] * (q_num - MAX_RELATIVE_DIST - 1) + \
                ['question-question-dist' + str(i) if i != 0 else 'question-question-identity' \
                    for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)] + \
                ['question-question-generic'] * (q_num - MAX_RELATIVE_DIST - 1)
            starting = q_num - 1
        q_mat = np.array([dist_vec[starting - i: starting - i + q_num] for i in range(q_num)], dtype=dtype)
        entry['relations'] = q_mat.tolist()

        if verbose:
            print('Question:', entry['question'])
            print('Tokenized:', ' '.join(entry['raw_question_toks']))
            print('Lemmatized:', ' '.join(entry['processed_question_toks']))
            print('Pos tags:', ' '.join(entry['pos_tags']), '\n')
        return entry
    def extract_subgraph(self, entry: dict, db: dict, verbose: bool = False):
        sql = entry['sql']
        used_schema = {'table': set(), 'column': set()}
        used_schema = self.extract_subgraph_from_sql(sql, used_schema)
        entry['used_tables'] = sorted(list(used_schema['table']))
        entry['used_columns'] = sorted(list(used_schema['column']))

        if verbose:
            print('Used tables:', entry['used_tables'])
            print('Used columns:', entry['used_columns'], '\n')
        return entry

    def extract_subgraph_from_sql(self, sql: dict, used_schema: dict):
        select_items = sql['select'][1]
        # select clause
        for _, val_unit in select_items:
            if val_unit[0] == 0:
                col_unit = val_unit[1]
                used_schema['column'].add(col_unit[1])
            else:
                col_unit1, col_unit2 = val_unit[1:]
                used_schema['column'].add(col_unit1[1])
                used_schema['column'].add(col_unit2[1])
        # from clause conds
        table_units = sql['from']['table_units']
        for _, t in table_units:
            if type(t) == dict:
                used_schema = self.extract_subgraph_from_sql(t, used_schema)
            else:
                used_schema['table'].add(t)
        # from, where and having conds
        used_schema = self.extract_subgraph_from_conds(sql['from']['conds'], used_schema)
        used_schema = self.extract_subgraph_from_conds(sql['where'], used_schema)
        used_schema = self.extract_subgraph_from_conds(sql['having'], used_schema)
        # groupBy and orderBy clause
        groupBy = sql['groupBy']
        for col_unit in groupBy:
            used_schema['column'].add(col_unit[1])
        orderBy = sql['orderBy']
        if len(orderBy) > 0:
            orderBy = orderBy[1]
            for val_unit in orderBy:
                if val_unit[0] == 0:
                    col_unit = val_unit[1]
                    used_schema['column'].add(col_unit[1])
                else:
                    col_unit1, col_unit2 = val_unit[1:]
                    used_schema['column'].add(col_unit1[1])
                    used_schema['column'].add(col_unit2[1])
        # union, intersect and except clause
        if sql['intersect']:
            used_schema = self.extract_subgraph_from_sql(sql['intersect'], used_schema)
        if sql['union']:
            used_schema = self.extract_subgraph_from_sql(sql['union'], used_schema)
        if sql['except']:
            used_schema = self.extract_subgraph_from_sql(sql['except'], used_schema)
        return used_schema

    def extract_subgraph_from_conds(self, conds: list, used_schema: dict):
        if len(conds) == 0:
            return used_schema
        for cond in conds:
            if cond in ['and', 'or']:
                continue
            val_unit, val1, val2 = cond[2:]
            if val_unit[0] == 0:
                col_unit = val_unit[1]
                used_schema['column'].add(col_unit[1])
            else:
                col_unit1, col_unit2 = val_unit[1:]
                used_schema['column'].add(col_unit1[1])
                used_schema['column'].add(col_unit2[1])
            if type(val1) == list:
                used_schema['column'].add(val1[1])
            elif type(val1) == dict:
                used_schema = self.extract_subgraph_from_sql(val1, used_schema)
            if type(val2) == list:
                used_schema['column'].add(val2[1]) # fixed: originally added val1[1], an apparent copy-paste slip
            elif type(val2) == dict:
                used_schema = self.extract_subgraph_from_sql(val2, used_schema)
        return used_schema
|
||||
|
||||
def schema_linking(self, entry: dict, db: dict, verbose: bool = False):
|
||||
""" Perform schema linking: both question and database need to be preprocessed """
|
||||
raw_question_toks, question_toks = entry['raw_question_toks'], entry['processed_question_toks']
|
||||
table_toks, column_toks = db['processed_table_toks'], db['processed_column_toks']
|
||||
table_names, column_names = db['processed_table_names'], db['processed_column_names']
|
||||
q_num, t_num, c_num, dtype = len(question_toks), len(table_toks), len(column_toks), '<U100'
|
||||
|
||||
assert len(column_names)==len(column_toks) and len(table_names) == len(table_toks) and len(raw_question_toks)==len(question_toks)
|
||||
question_id = [self.plm_tokenizer.cls_token_id]
|
||||
question = [q.lower() for q in question_toks]
|
||||
question_subword_len = []
|
||||
for w in question:
|
||||
toks = self.plm_tokenizer.convert_tokens_to_ids(self.plm_tokenizer.tokenize(w))
|
||||
question_id.extend(toks)
|
||||
question_subword_len.append(len(toks))
|
||||
question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
|
||||
exact_question_token = len(question_id) - 1
|
||||
question_id.append(self.plm_tokenizer.sep_token_id)
|
||||
masked_question_id = [question_id]
|
||||
start = 1
|
||||
|
||||
for i, sub_len in enumerate(question_subword_len):
|
||||
tmp_question_id = question_id.copy()
|
||||
for m in range(start, start + sub_len):
|
||||
tmp_question_id[m] = self.plm_tokenizer.mask_token_id
|
||||
masked_question_id.append(tmp_question_id)
|
||||
start += sub_len
|
||||
table = [t.lower().split() for t in table_names]
|
||||
table_id, table_mask_plm, table_subword_len = [], [], []
|
||||
table_word_len = []
|
||||
for s in table:
|
||||
l = 0
|
||||
for w in s:
|
||||
toks = self.plm_tokenizer.convert_tokens_to_ids(self.plm_tokenizer.tokenize(w))
|
||||
table_id.extend(toks)
|
||||
table_subword_len.append(len(toks))
|
||||
l += len(toks)
|
||||
table_word_len.append(l)
|
||||
table_mask_plm = [1] * len(table_id)
|
||||
|
||||
column = [t.lower().split() for t in column_names]
|
||||
column_id, column_mask_plm, column_subword_len = [], [], []
|
||||
column_word_len = []
|
||||
for s in column:
|
||||
l = 0
|
||||
for w in s:
|
||||
toks = self.plm_tokenizer.convert_tokens_to_ids(self.plm_tokenizer.tokenize(w))
|
||||
column_id.extend(toks)
|
||||
column_subword_len.append(len(toks))
|
||||
l += len(toks)
|
||||
column_word_len.append(l)
|
||||
column_mask_plm = [1] * len(column_id) + [0]
|
||||
exact_column_token = len(column_id)
|
||||
column_id.append(self.plm_tokenizer.sep_token_id)
|
||||
|
||||
question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
|
||||
table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
|
||||
column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm
|
||||
|
||||
input_id = []
|
||||
segment_id = []
|
||||
atten_mask = []
|
||||
for i, msk_q_id in enumerate(masked_question_id):
|
||||
input_id.append(msk_q_id + table_id + column_id)
|
||||
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
|
||||
atten_mask.append([1] * len(input_id[-1]))
|
||||
|
||||
start = 0
|
||||
total_size = len(input_id)
|
||||
store_arr = []
|
||||
if total_size <= self.max_batch_size:
|
||||
ii = torch.tensor(input_id, dtype=torch.long, device=self.device)
|
||||
im = torch.tensor(atten_mask, dtype=torch.float, device=self.device)
|
||||
si = torch.tensor(segment_id, dtype=torch.long, device=self.device)
|
||||
outputs = self.plm_model(ii, im)[0].squeeze()
|
||||
store_arr.append(outputs)
|
||||
else:
|
||||
while start < len(input_id):
|
||||
if start + self.max_batch_size <= len(input_id):
|
||||
ii = torch.tensor(input_id[start: start + self.max_batch_size], dtype=torch.long, device=self.device)
|
||||
im = torch.tensor(atten_mask[start: start + self.max_batch_size], dtype=torch.float, device=self.device)
|
||||
si = torch.tensor(segment_id[start: start + self.max_batch_size], dtype=torch.long, device=self.device)
|
||||
outputs = self.plm_model(ii, im)[0] # .squeeze()
|
||||
store_arr.append(outputs)
|
||||
else:
|
||||
ii = torch.tensor(input_id[start: len(input_id)], dtype=torch.long, device=self.device)
|
||||
im = torch.tensor(atten_mask[start: len(input_id)], dtype=torch.float, device=self.device)
|
||||
si = torch.tensor(segment_id[start: len(input_id)], dtype=torch.long, device=self.device)
|
||||
outputs = self.plm_model(ii, im)[0] # .squeeze()
|
||||
store_arr.append(outputs)
|
||||
start += self.max_batch_size
|
||||
assert len(store_arr) > 0
|
||||
if len(store_arr) == 1:
|
||||
outputs = store_arr[0]
|
||||
else:
|
||||
outputs = store_arr[0]
|
||||
for t in store_arr[1:]:
|
||||
outputs = torch.cat((outputs, t), dim=0)
|
||||
q_tab_mat = outputs.new_zeros(len(raw_question_toks), len(table_names))
|
||||
old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=self.device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), self.hidden_size)
|
||||
|
||||
start = 0
|
||||
new_table_arr = []
|
||||
for i, sub_len in enumerate(table_word_len):
|
||||
curr = old_tables[:, start:start + sub_len]
|
||||
new_table_arr.append(agg(curr))
|
||||
start += sub_len
|
||||
new_tables = torch.cat(new_table_arr, 1)
|
||||
tbl_cmp = new_tables[0:1]
|
||||
tbl_msk = new_tables[1:]
|
||||
assert tbl_msk.size(0) == len(raw_question_toks)
|
||||
for i in range(len(table_word_len)):
|
||||
a = self.ball.expmap0(tbl_cmp[:, i])
|
||||
b = self.ball.expmap0(tbl_msk[:, i])
|
||||
dis=self.ball.dist(a,b)
|
||||
q_tab_mat[:, i] = dis
|
||||
|
||||
q_col_mat = outputs.new_zeros(len(raw_question_toks), len(column_names))
|
||||
old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=self.device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),exact_column_token, self.hidden_size)
|
||||
new_column_arr = []
|
||||
start = 0
|
||||
for i, sub_len in enumerate(column_word_len):
|
||||
curr = old_columns[:, start:start + sub_len]
|
||||
new_column_arr.append(agg(curr))
|
||||
start += sub_len
|
||||
new_column = torch.cat(new_column_arr, 1)
|
||||
col_cmp = new_column[0:1]
|
||||
col_msk = new_column[1:]
|
||||
assert col_msk.size(0) == len(raw_question_toks)
|
||||
for i in range(len(column_word_len)):
|
||||
a = self.ball.expmap0(col_cmp[:, i])
|
||||
b = self.ball.expmap0(col_msk[:, i])
|
||||
dis=self.ball.dist(a,b)
|
||||
q_col_mat[:, i] = dis
|
||||
|
||||
use_matrix = torch.cat([q_tab_mat,q_col_mat], dim=1)
|
||||
matrix_min=torch.min(use_matrix)
|
||||
matrix_max=torch.max(use_matrix)
|
||||
use_matrix=(use_matrix-matrix_min)/(matrix_max-matrix_min)
|
||||
use_q_tab_mat = use_matrix[:, :q_tab_mat.size(1)]
|
||||
use_q_col_mat = use_matrix[:, q_tab_mat.size(1):]
|
||||
assert use_q_tab_mat.size(1) == t_num and use_q_col_mat.size(1)== c_num
|
||||
|
||||
use_tab_q_mat = use_q_tab_mat.transpose(0,1).cpu().detach().numpy()
|
||||
|
||||
use_col_q_mat = use_q_col_mat.transpose(0,1).cpu().detach().numpy()
|
||||
|
||||
table_matched_pairs = {'partial': [], 'exact': []}
|
||||
q_tab_mat = np.array([['question-table-nomatch'] * t_num for _ in range(q_num)], dtype=dtype)
|
||||
tab_q_mat = np.array([['table-question-nomatch'] * q_num for _ in range(t_num)], dtype=dtype)
|
||||
max_len = max([len(t) for t in table_toks])
|
||||
index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
|
||||
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
|
||||
for i, j in index_pairs:
|
||||
phrase = ' '.join(question_toks[i: j])
|
||||
if phrase in self.stopwords: continue
|
||||
for idx, name in enumerate(table_names):
|
||||
if phrase == name: # fully match will overwrite partial match due to sort
|
||||
q_tab_mat[range(i, j), idx] = 'question-table-exactmatch'
|
||||
tab_q_mat[idx, range(i, j)] = 'table-question-exactmatch'
|
||||
if verbose:
|
||||
table_matched_pairs['exact'].append(str((name, idx, phrase, i, j)))
|
||||
elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
|
||||
q_tab_mat[range(i, j), idx] = 'question-table-partialmatch'
|
||||
tab_q_mat[idx, range(i, j)] = 'table-question-partialmatch'
|
||||
if verbose:
|
||||
table_matched_pairs['partial'].append(str((name, idx, phrase, i, j)))
|
||||
|
||||
assert use_tab_q_mat.shape[0]==t_num and use_tab_q_mat.shape[1]==q_num
|
||||
for x in range(t_num):
|
||||
for y in range(q_num):
|
||||
if question_toks[y] in self.stopwords or question_toks[y] in '."?,':
|
||||
continue
|
||||
if use_tab_q_mat[x,y]>self.threshold:
|
||||
if 'partialmatch' in tab_q_mat[x,y]:
|
||||
tab_q_mat[x,y] = 'table-question-partialsemanticmatch'
|
||||
q_tab_mat[y,x] = 'question-table-partialsemanticmatch'
|
||||
elif 'exact' in tab_q_mat[x,y]:
|
||||
continue
|
||||
elif 'nomatch' in tab_q_mat[x,y]:
|
||||
tab_q_mat[x,y] = 'table-question-semanticmatch'
|
||||
q_tab_mat[y,x] = 'question-table-semanticmatch'
|
||||
# relations between questions and columns
|
||||
column_matched_pairs = {'partial': [], 'exact': [], 'value': []}
|
||||
q_col_mat = np.array([['question-column-nomatch'] * c_num for _ in range(q_num)], dtype=dtype)
|
||||
col_q_mat = np.array([['column-question-nomatch'] * q_num for _ in range(c_num)], dtype=dtype)
|
||||
max_len = max([len(c) for c in column_toks])
|
||||
index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
|
||||
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
|
||||
for i, j in index_pairs:
|
||||
phrase = ' '.join(question_toks[i: j])
|
||||
if phrase in self.stopwords: continue
|
||||
for idx, name in enumerate(column_names):
|
||||
if phrase == name: # fully match will overwrite partial match due to sort
|
||||
q_col_mat[range(i, j), idx] = 'question-column-exactmatch'
|
||||
col_q_mat[idx, range(i, j)] = 'column-question-exactmatch'
|
||||
if verbose:
|
||||
column_matched_pairs['exact'].append(str((name, idx, phrase, i, j)))
|
||||
elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
|
||||
q_col_mat[range(i, j), idx] = 'question-column-partialmatch'
|
||||
col_q_mat[idx, range(i, j)] = 'column-question-partialmatch'
|
||||
if verbose:
|
||||
column_matched_pairs['partial'].append(str((name, idx, phrase, i, j)))
|
||||
if self.db_content:
|
||||
db_file = os.path.join(self.db_dir, db['db_id'], db['db_id'] + '.sqlite')
|
||||
if not os.path.exists(db_file):
|
||||
raise ValueError('[ERROR]: database file %s not found ...' % (db_file))
|
||||
conn = sqlite3.connect(db_file)
|
||||
conn.text_factory = lambda b: b.decode(errors='ignore')
|
||||
conn.execute('pragma foreign_keys=ON')
|
||||
for i, (tab_id, col_name) in enumerate(db['column_names_original']):
|
||||
if i == 0 or 'id' in column_toks[i]: # ignore * and special token 'id'
|
||||
continue
|
||||
tab_name = db['table_names_original'][tab_id]
|
||||
try:
|
||||
cursor = conn.execute("SELECT DISTINCT \"%s\" FROM \"%s\";" % (col_name, tab_name))
|
||||
cell_values = cursor.fetchall()
|
||||
cell_values = [str(each[0]) for each in cell_values]
|
||||
cell_values = [[str(float(each))] if is_number(each) else each.lower().split() for each in cell_values]
|
||||
except Exception as e:
    print(e)
    cell_values = []  # on failure, skip value matching for this column instead of reusing stale values
|
||||
for j, word in enumerate(raw_question_toks):
|
||||
word = str(float(word)) if is_number(word) else word
|
||||
for c in cell_values:
|
||||
if word in c and 'nomatch' in q_col_mat[j, i] and word not in self.stopwords:
|
||||
q_col_mat[j, i] = 'question-column-valuematch'
|
||||
col_q_mat[i, j] = 'column-question-valuematch'
|
||||
if verbose:
|
||||
column_matched_pairs['value'].append(str((column_names[i], i, word, j, j + 1)))
|
||||
break
|
||||
conn.close()
|
||||
|
||||
assert use_col_q_mat.shape[0]==c_num and use_col_q_mat.shape[1]==q_num
|
||||
for x in range(c_num):
|
||||
for y in range(q_num):
|
||||
if question_toks[y] in self.stopwords or question_toks[y] in '."?,':
|
||||
continue
|
||||
if use_col_q_mat[x,y]>self.threshold:
|
||||
if 'partialmatch' in col_q_mat[x,y]:
|
||||
col_q_mat[x,y] = 'column-question-partialsemanticmatch'
|
||||
q_col_mat[y,x] = 'question-column-partialsemanticmatch'
|
||||
elif 'exact' in col_q_mat[x,y] or 'value' in col_q_mat[x,y]:
|
||||
continue
|
||||
elif 'nomatch' in col_q_mat[x,y]:
|
||||
col_q_mat[x,y] = 'column-question-semanticmatch'
|
||||
q_col_mat[y,x] = 'question-column-semanticmatch'
|
||||
|
||||
# two symmetric schema linking matrix: q_num x (t_num + c_num), (t_num + c_num) x q_num
|
||||
q_col_mat[:, 0] = 'question-*-generic'
|
||||
col_q_mat[0] = '*-question-generic'
|
||||
q_schema = np.concatenate([q_tab_mat, q_col_mat], axis=1)
|
||||
schema_q = np.concatenate([tab_q_mat, col_q_mat], axis=0)
|
||||
entry['schema_linking'] = (q_schema.tolist(), schema_q.tolist())
|
||||
|
||||
if verbose:
|
||||
print('Question:', ' '.join(question_toks))
|
||||
print('Table matched: (table name, column id, question span, start id, end id)')
|
||||
print('Exact match:', ', '.join(table_matched_pairs['exact']) if table_matched_pairs['exact'] else 'empty')
|
||||
print('Partial match:', ', '.join(table_matched_pairs['partial']) if table_matched_pairs['partial'] else 'empty')
|
||||
print('Column matched: (column name, column id, question span, start id, end id)')
|
||||
print('Exact match:', ', '.join(column_matched_pairs['exact']) if column_matched_pairs['exact'] else 'empty')
|
||||
print('Partial match:', ', '.join(column_matched_pairs['partial']) if column_matched_pairs['partial'] else 'empty')
|
||||
print('Value match:', ', '.join(column_matched_pairs['value']) if column_matched_pairs['value'] else 'empty', '\n')
|
||||
return entry
@@ -0,0 +1,779 @@
#coding=utf8
import os, sqlite3
import numpy as np
import stanza, torch
from nltk.corpus import stopwords
from itertools import product, combinations
import torch.nn.functional as F
from utils.constants import MAX_RELATIVE_DIST  # needed below by preprocess_question
from transformers import AutoModel, AutoConfig, AutoTokenizer
|
||||
|
||||
def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

def agg(input):
    # mean-pool the subword representations along dim 1
    return torch.sum(input, dim=1) / input.size(1)
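# Quick check of agg's pooling (hypothetical tensor): for a [batch, subwords, hidden]
# input, agg(x) equals x.mean(dim=1), e.g.
#   x = torch.ones(2, 3, 4); assert torch.allclose(agg(x), x.mean(dim=1))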
|
||||
def quote_normalization(question):
    """ Normalize all usage of quotation marks into a separate \" """
    new_question, quotation_marks = [], ["'", '"', '`', '‘', '’', '“', '”', '``', "''", "‘‘", "’’"]
    for idx, tok in enumerate(question):
        if len(tok) > 2 and tok[0] in quotation_marks and tok[-1] in quotation_marks:
            new_question += ["\"", tok[1:-1], "\""]
        elif len(tok) > 2 and tok[0] in quotation_marks:
            new_question += ["\"", tok[1:]]
        elif len(tok) > 2 and tok[-1] in quotation_marks:
            new_question += [tok[:-1], "\""]
        elif tok in quotation_marks:
            new_question.append("\"")
        elif len(tok) == 2 and tok[0] in quotation_marks:
            # special case: the length of entity value is 1
            if idx + 1 < len(question) and question[idx + 1] in quotation_marks:
                new_question += ["\"", tok[1]]
            else:
                new_question.append(tok)
        else:
            new_question.append(tok)
    return new_question
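# Illustration (hypothetical tokens): a token carrying its own quotes is split
# into a bare value wrapped by plain double quotes, e.g.
#   quote_normalization(['named', "'John'", '?']) == ['named', '"', 'John', '"', '?']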
|
||||
|
||||
class Preprocessor():

    def __init__(self, db_dir='data/database', db_content=True):
        super(Preprocessor, self).__init__()
        self.db_dir = db_dir
        self.db_content = db_content
        self.nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')  # , use_gpu=False)
        self.stopwords = stopwords.words("english")
        self.device = torch.device("cuda:0")
        self.hidden_size = 1024
        self.max_batch_size = 8
        self.plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(self.device)
        self.plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
        self.config = self.plm_model.config
    def pipeline(self, entry: dict, db: dict, verbose: bool = False):
        """ db should be preprocessed """
        entry = self.preprocess_question(entry, db, verbose=verbose)
        entry = self.schema_linking(entry, db, verbose=verbose)
        entry = self.extract_subgraph(entry, db, verbose=verbose)
        return entry
|
||||
|
||||
def preprocess_database(self, db: dict, verbose: bool = False):
|
||||
""" Tokenize, lemmatize, lowercase table and column names for each database """
|
||||
table_toks, table_names = [], []
|
||||
for tab in db['table_names']:
|
||||
doc = self.nlp(tab)
|
||||
tab = [w.lemma.lower() for s in doc.sentences for w in s.words]
|
||||
table_toks.append(tab)
|
||||
table_names.append(" ".join(tab))
|
||||
db['processed_table_toks'], db['processed_table_names'] = table_toks, table_names
|
||||
column_toks, column_names = [], []
|
||||
for _, c in db['column_names']:
|
||||
doc = self.nlp(c)
|
||||
c = [w.lemma.lower() for s in doc.sentences for w in s.words]
|
||||
column_toks.append(c)
|
||||
column_names.append(" ".join(c))
|
||||
db['processed_column_toks'], db['processed_column_names'] = column_toks, column_names
|
||||
column2table = list(map(lambda x: x[0], db['column_names'])) # from column id to table id
|
||||
table2columns = [[] for _ in range(len(table_names))] # from table id to column ids list
|
||||
for col_id, col in enumerate(db['column_names']):
|
||||
if col_id == 0: continue
|
||||
table2columns[col[0]].append(col_id)
|
||||
db['column2table'], db['table2columns'] = column2table, table2columns
|
||||
|
||||
t_num, c_num, dtype = len(db['table_names']), len(db['column_names']), '<U100'
|
||||
|
||||
# relations in tables, tab_num * tab_num
|
||||
tab_mat = np.array([['table-table-generic'] * t_num for _ in range(t_num)], dtype=dtype)
|
||||
table_fks = set(map(lambda pair: (column2table[pair[0]], column2table[pair[1]]), db['foreign_keys']))
|
||||
for (tab1, tab2) in table_fks:
|
||||
if (tab2, tab1) in table_fks:
|
||||
tab_mat[tab1, tab2], tab_mat[tab2, tab1] = 'table-table-fkb', 'table-table-fkb'
|
||||
else:
|
||||
tab_mat[tab1, tab2], tab_mat[tab2, tab1] = 'table-table-fk', 'table-table-fkr'
|
||||
tab_mat[list(range(t_num)), list(range(t_num))] = 'table-table-identity'
|
||||
|
||||
# relations in columns, c_num * c_num
|
||||
col_mat = np.array([['column-column-generic'] * c_num for _ in range(c_num)], dtype=dtype)
|
||||
for i in range(t_num):
|
||||
col_ids = [idx for idx, t in enumerate(column2table) if t == i]
|
||||
col1, col2 = list(zip(*list(product(col_ids, col_ids))))
|
||||
col_mat[col1, col2] = 'column-column-sametable'
|
||||
col_mat[list(range(c_num)), list(range(c_num))] = 'column-column-identity'
|
||||
if len(db['foreign_keys']) > 0:
|
||||
col1, col2 = list(zip(*db['foreign_keys']))
|
||||
col_mat[col1, col2], col_mat[col2, col1] = 'column-column-fk', 'column-column-fkr'
|
||||
col_mat[0, list(range(c_num))] = '*-column-generic'
|
||||
col_mat[list(range(c_num)), 0] = 'column-*-generic'
|
||||
col_mat[0, 0] = '*-*-identity'
|
||||
|
||||
# relations between tables and columns, t_num*c_num and c_num*t_num
|
||||
tab_col_mat = np.array([['table-column-generic'] * c_num for _ in range(t_num)], dtype=dtype)
|
||||
col_tab_mat = np.array([['column-table-generic'] * t_num for _ in range(c_num)], dtype=dtype)
|
||||
cols, tabs = list(zip(*list(map(lambda x: (x, column2table[x]), range(1, c_num))))) # ignore *
|
||||
col_tab_mat[cols, tabs], tab_col_mat[tabs, cols] = 'column-table-has', 'table-column-has'
|
||||
if len(db['primary_keys']) > 0:
|
||||
cols, tabs = list(zip(*list(map(lambda x: (x, column2table[x]), db['primary_keys']))))
|
||||
col_tab_mat[cols, tabs], tab_col_mat[tabs, cols] = 'column-table-pk', 'table-column-pk'
|
||||
col_tab_mat[0, list(range(t_num))] = '*-table-generic'
|
||||
tab_col_mat[list(range(t_num)), 0] = 'table-*-generic'
|
||||
|
||||
relations = np.concatenate([
|
||||
np.concatenate([tab_mat, tab_col_mat], axis=1),
|
||||
np.concatenate([col_tab_mat, col_mat], axis=1)
|
||||
], axis=0)
|
||||
db['relations'] = relations.tolist()
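# The schema relation matrix is laid out block-wise over [tables; columns]:
#   [[ tab_mat,     tab_col_mat ],
#    [ col_tab_mat, col_mat     ]]
# giving shape (t_num + c_num) x (t_num + c_num).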
|
||||
|
||||
if verbose:
|
||||
print('Tables:', ', '.join(db['table_names']))
|
||||
print('Lemmatized:', ', '.join(table_names))
|
||||
print('Columns:', ', '.join(list(map(lambda x: x[1], db['column_names']))))
|
||||
print('Lemmatized:', ', '.join(column_names), '\n')
|
||||
return db
|
||||
|
||||
def preprocess_question(self, entry: dict, db: dict, verbose: bool = False):
|
||||
""" Tokenize, lemmatize, lowercase question"""
|
||||
# stanza tokenize, lemmatize and POS tag
|
||||
question = ' '.join(quote_normalization(entry['question_toks']))
|
||||
doc = self.nlp(question)
|
||||
raw_toks = [w.text.lower() for s in doc.sentences for w in s.words]
|
||||
toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
|
||||
pos_tags = [w.xpos for s in doc.sentences for w in s.words]
|
||||
|
||||
entry['raw_question_toks'] = raw_toks
|
||||
entry['processed_question_toks'] = toks
|
||||
entry['pos_tags'] = pos_tags
|
||||
|
||||
# relations in questions, q_num * q_num
|
||||
q_num, dtype = len(toks), '<U100'
|
||||
if q_num <= MAX_RELATIVE_DIST + 1:
|
||||
dist_vec = ['question-question-dist' + str(i) if i != 0 else 'question-question-identity'
|
||||
for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)]
|
||||
starting = MAX_RELATIVE_DIST
|
||||
else:
|
||||
dist_vec = ['question-question-generic'] * (q_num - MAX_RELATIVE_DIST - 1) + \
|
||||
['question-question-dist' + str(i) if i != 0 else 'question-question-identity' \
|
||||
for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)] + \
|
||||
['question-question-generic'] * (q_num - MAX_RELATIVE_DIST - 1)
|
||||
starting = q_num - 1
|
||||
q_mat = np.array([dist_vec[starting - i: starting - i + q_num] for i in range(q_num)], dtype=dtype)
|
||||
entry['relations'] = q_mat.tolist()
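# Worked example (assuming MAX_RELATIVE_DIST = 2): for q_num = 3,
# dist_vec = [dist-2, dist-1, identity, dist1, dist2] and starting = 2, so row i
# reads dist_vec[2-i : 2-i+3], yielding
#   [[identity, dist1,    dist2   ],
#    [dist-1,   identity, dist1   ],
#    [dist-2,   dist-1,   identity]]
# (labels abbreviated; each carries the 'question-question-' prefix).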
|
||||
|
||||
if verbose:
|
||||
print('Question:', entry['question'])
|
||||
print('Tokenized:', ' '.join(entry['raw_question_toks']))
|
||||
print('Lemmatized:', ' '.join(entry['processed_question_toks']))
|
||||
print('Pos tags:', ' '.join(entry['pos_tags']), '\n')
|
||||
return entry
|
||||
|
||||
def extract_subgraph(self, entry: dict, db: dict, verbose: bool = False):
|
||||
sql = entry['sql']
|
||||
used_schema = {'table': set(), 'column': set()}
|
||||
used_schema = self.extract_subgraph_from_sql(sql, used_schema)
|
||||
entry['used_tables'] = sorted(list(used_schema['table']))
|
||||
entry['used_columns'] = sorted(list(used_schema['column']))
|
||||
|
||||
if verbose:
|
||||
print('Used tables:', entry['used_tables'])
|
||||
print('Used columns:', entry['used_columns'], '\n')
|
||||
return entry
|
||||
|
||||
def extract_subgraph_from_sql(self, sql: dict, used_schema: dict):
|
||||
select_items = sql['select'][1]
|
||||
# select clause
|
||||
for _, val_unit in select_items:
|
||||
if val_unit[0] == 0:
|
||||
col_unit = val_unit[1]
|
||||
used_schema['column'].add(col_unit[1])
|
||||
else:
|
||||
col_unit1, col_unit2 = val_unit[1:]
|
||||
used_schema['column'].add(col_unit1[1])
|
||||
used_schema['column'].add(col_unit2[1])
|
||||
# from clause conds
|
||||
table_units = sql['from']['table_units']
|
||||
for _, t in table_units:
|
||||
if type(t) == dict:
|
||||
used_schema = self.extract_subgraph_from_sql(t, used_schema)
|
||||
else:
|
||||
used_schema['table'].add(t)
|
||||
# from, where and having conds
|
||||
used_schema = self.extract_subgraph_from_conds(sql['from']['conds'], used_schema)
|
||||
used_schema = self.extract_subgraph_from_conds(sql['where'], used_schema)
|
||||
used_schema = self.extract_subgraph_from_conds(sql['having'], used_schema)
|
||||
# groupBy and orderBy clause
|
||||
groupBy = sql['groupBy']
|
||||
for col_unit in groupBy:
|
||||
used_schema['column'].add(col_unit[1])
|
||||
orderBy = sql['orderBy']
|
||||
if len(orderBy) > 0:
|
||||
orderBy = orderBy[1]
|
||||
for val_unit in orderBy:
|
||||
if val_unit[0] == 0:
|
||||
col_unit = val_unit[1]
|
||||
used_schema['column'].add(col_unit[1])
|
||||
else:
|
||||
col_unit1, col_unit2 = val_unit[1:]
|
||||
used_schema['column'].add(col_unit1[1])
|
||||
used_schema['column'].add(col_unit2[1])
|
||||
# union, intersect and except clause
|
||||
if sql['intersect']:
|
||||
used_schema = self.extract_subgraph_from_sql(sql['intersect'], used_schema)
|
||||
if sql['union']:
|
||||
used_schema = self.extract_subgraph_from_sql(sql['union'], used_schema)
|
||||
if sql['except']:
|
||||
used_schema = self.extract_subgraph_from_sql(sql['except'], used_schema)
|
||||
return used_schema
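# Example: for "SELECT name FROM singer WHERE age > 20", this collects the
# column ids of name and age plus the table id of singer; nested queries in
# from/where/having and in intersect/union/except recurse through this method.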
|
||||
|
||||
def extract_subgraph_from_conds(self, conds: list, used_schema: dict):
|
||||
if len(conds) == 0:
|
||||
return used_schema
|
||||
for cond in conds:
|
||||
if cond in ['and', 'or']:
|
||||
continue
|
||||
val_unit, val1, val2 = cond[2:]
|
||||
if val_unit[0] == 0:
|
||||
col_unit = val_unit[1]
|
||||
used_schema['column'].add(col_unit[1])
|
||||
else:
|
||||
col_unit1, col_unit2 = val_unit[1:]
|
||||
used_schema['column'].add(col_unit1[1])
|
||||
used_schema['column'].add(col_unit2[1])
|
||||
if type(val1) == list:
|
||||
used_schema['column'].add(val1[1])
|
||||
elif type(val1) == dict:
|
||||
used_schema = self.extract_subgraph_from_sql(val1, used_schema)
|
||||
if type(val2) == list:
    used_schema['column'].add(val2[1])  # fixed: previously added val1[1] a second time
|
||||
elif type(val2) == dict:
|
||||
used_schema = self.extract_subgraph_from_sql(val2, used_schema)
|
||||
return used_schema
|
||||
|
||||
def schema_linking(self, entry: dict, db: dict, verbose: bool = False):
|
||||
""" Perform schema linking: both question and database need to be preprocessed """
|
||||
raw_question_toks, question_toks = entry['raw_question_toks'], entry['processed_question_toks']
|
||||
table_toks, column_toks = db['processed_table_toks'], db['processed_column_toks']
|
||||
table_names, column_names = db['processed_table_names'], db['processed_column_names']
|
||||
q_num, t_num, c_num, dtype = len(question_toks), len(table_toks), len(column_toks), '<U100'
|
||||
|
||||
print(raw_question_toks)
|
||||
print(column_names)
|
||||
print(table_names)
|
||||
assert len(column_names)==len(column_toks) and len(table_names) == len(table_toks) and len(raw_question_toks)==len(question_toks)
|
||||
|
||||
# relations between questions and tables, q_num*t_num and t_num*q_num
|
||||
table_matched_pairs = {'partial': [], 'exact': []}
|
||||
q_tab_mat = np.array([['question-table-nomatch'] * t_num for _ in range(q_num)], dtype=dtype)
|
||||
tab_q_mat = np.array([['table-question-nomatch'] * q_num for _ in range(t_num)], dtype=dtype)
|
||||
max_len = max([len(t) for t in table_toks])
|
||||
index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
|
||||
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
|
||||
for i, j in index_pairs:
|
||||
phrase = ' '.join(question_toks[i: j])
|
||||
if phrase in self.stopwords: continue
|
||||
for idx, name in enumerate(table_names):
|
||||
if phrase == name: # fully match will overwrite partial match due to sort
|
||||
q_tab_mat[range(i, j), idx] = 'question-table-exactmatch'
|
||||
tab_q_mat[idx, range(i, j)] = 'table-question-exactmatch'
|
||||
if verbose:
|
||||
table_matched_pairs['exact'].append(str((name, idx, phrase, i, j)))
|
||||
elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
|
||||
q_tab_mat[range(i, j), idx] = 'question-table-partialmatch'
|
||||
tab_q_mat[idx, range(i, j)] = 'table-question-partialmatch'
|
||||
if verbose:
|
||||
table_matched_pairs['partial'].append(str((name, idx, phrase, i, j)))
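# Note: index_pairs is sorted by span length, so a longer exact match visits a
# cell after any shorter partial match and overwrites it; e.g. the span
# 'concert singer' upgrades the one-word partial matches for table
# 'concert singer' to 'question-table-exactmatch'.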
|
||||
|
||||
# relations between questions and columns
|
||||
column_matched_pairs = {'partial': [], 'exact': [], 'value': []}
|
||||
q_col_mat = np.array([['question-column-nomatch'] * c_num for _ in range(q_num)], dtype=dtype)
|
||||
col_q_mat = np.array([['column-question-nomatch'] * q_num for _ in range(c_num)], dtype=dtype)
|
||||
max_len = max([len(c) for c in column_toks])
|
||||
index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
|
||||
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
|
||||
for i, j in index_pairs:
|
||||
phrase = ' '.join(question_toks[i: j])
|
||||
if phrase in self.stopwords: continue
|
||||
for idx, name in enumerate(column_names):
|
||||
if phrase == name: # fully match will overwrite partial match due to sort
|
||||
q_col_mat[range(i, j), idx] = 'question-column-exactmatch'
|
||||
col_q_mat[idx, range(i, j)] = 'column-question-exactmatch'
|
||||
if verbose:
|
||||
column_matched_pairs['exact'].append(str((name, idx, phrase, i, j)))
|
||||
elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
|
||||
q_col_mat[range(i, j), idx] = 'question-column-partialmatch'
|
||||
col_q_mat[idx, range(i, j)] = 'column-question-partialmatch'
|
||||
if verbose:
|
||||
column_matched_pairs['partial'].append(str((name, idx, phrase, i, j)))
|
||||
if self.db_content:
|
||||
db_file = os.path.join(self.db_dir, db['db_id'], db['db_id'] + '.sqlite')
|
||||
if not os.path.exists(db_file):
|
||||
raise ValueError('[ERROR]: database file %s not found ...' % (db_file))
|
||||
conn = sqlite3.connect(db_file)
|
||||
conn.text_factory = lambda b: b.decode(errors='ignore')
|
||||
conn.execute('pragma foreign_keys=ON')
|
||||
for i, (tab_id, col_name) in enumerate(db['column_names_original']):
|
||||
if i == 0 or 'id' in column_toks[i]: # ignore * and special token 'id'
|
||||
continue
|
||||
tab_name = db['table_names_original'][tab_id]
|
||||
try:
|
||||
cursor = conn.execute("SELECT DISTINCT \"%s\" FROM \"%s\";" % (col_name, tab_name))
|
||||
cell_values = cursor.fetchall()
|
||||
cell_values = [str(each[0]) for each in cell_values]
|
||||
cell_values = [[str(float(each))] if is_number(each) else each.lower().split() for each in cell_values]
|
||||
except Exception as e:
    print(e)
    cell_values = []  # on failure, skip value matching for this column instead of reusing stale values
|
||||
for j, word in enumerate(raw_question_toks):
|
||||
word = str(float(word)) if is_number(word) else word
|
||||
for c in cell_values:
|
||||
if word in c and 'nomatch' in q_col_mat[j, i] and word not in self.stopwords:
|
||||
q_col_mat[j, i] = 'question-column-valuematch'
|
||||
col_q_mat[i, j] = 'column-question-valuematch'
|
||||
if verbose:
|
||||
column_matched_pairs['value'].append(str((column_names[i], i, word, j, j + 1)))
|
||||
break
|
||||
conn.close()
|
||||
|
||||
# two symmetric schema linking matrix: q_num x (t_num + c_num), (t_num + c_num) x q_num
|
||||
q_col_mat[:, 0] = 'question-*-generic'
|
||||
col_q_mat[0] = '*-question-generic'
|
||||
q_schema = np.concatenate([q_tab_mat, q_col_mat], axis=1)
|
||||
schema_q = np.concatenate([tab_q_mat, col_q_mat], axis=0)
|
||||
entry['schema_linking'] = (q_schema.tolist(), schema_q.tolist())
|
||||
|
||||
if verbose:
|
||||
print('Question:', ' '.join(question_toks))
|
||||
print('Table matched: (table name, column id, question span, start id, end id)')
|
||||
print('Exact match:', ', '.join(table_matched_pairs['exact']) if table_matched_pairs['exact'] else 'empty')
|
||||
print('Partial match:', ', '.join(table_matched_pairs['partial']) if table_matched_pairs['partial'] else 'empty')
|
||||
print('Column matched: (column name, column id, question span, start id, end id)')
|
||||
print('Exact match:', ', '.join(column_matched_pairs['exact']) if column_matched_pairs['exact'] else 'empty')
|
||||
print('Partial match:', ', '.join(column_matched_pairs['partial']) if column_matched_pairs['partial'] else 'empty')
|
||||
print('Value match:', ', '.join(column_matched_pairs['value']) if column_matched_pairs['value'] else 'empty', '\n')
|
||||
return entry
if __name__=='__main__':
|
||||
device = torch.device("cuda:0")
|
||||
hidden_size=1024
|
||||
max_batch_size = 8
|
||||
plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(device)
|
||||
plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
|
||||
config = plm_model.config
|
||||
raw_question_toks=['list', 'the', 'service', 'id', 'and', 'details', 'for', 'the', 'events', '.']
|
||||
column_names=['*', 'service id', 'service type code', 'participant id', 'participant type code', 'participant detail', 'event id', 'service id', 'event detail', 'event id', 'participant id']
|
||||
table_names=['service', 'participant', 'event', 'participant in event']
|
||||
|
||||
print('Q', len(raw_question_toks), raw_question_toks)
|
||||
print('C', len(column_names), column_names)
|
||||
print('T', len(table_names), table_names)
|
||||
|
||||
question_id = [plm_tokenizer.cls_token_id]
|
||||
question = [q.lower() for q in raw_question_toks]
|
||||
question_subword_len = []
|
||||
for w in question:
|
||||
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
|
||||
question_id.extend(toks)
|
||||
question_subword_len.append(len(toks))
|
||||
question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
|
||||
exact_question_token=len(question_id) - 1
|
||||
question_id.append(plm_tokenizer.sep_token_id)
|
||||
#print(question_id)
|
||||
|
||||
masked_question_id = []
|
||||
#print(len(masked_question_id))
|
||||
|
||||
start=1
|
||||
#print(question_subword_len)
|
||||
for i,sub_len in enumerate(question_subword_len):
|
||||
#print(question_id)
|
||||
tmp_question_id=question_id.copy()
|
||||
#print(list(range(start,start+sub_len)))
|
||||
for m in range(start,start+sub_len):
|
||||
# print('index',index+i)
|
||||
#print('before', tmp_question_id)
|
||||
tmp_question_id[m]=plm_tokenizer.mask_token_id
|
||||
#print('after', tmp_question_id)
|
||||
# print('next', masked_question_id[index+i+1])
|
||||
masked_question_id.append(tmp_question_id)
|
||||
start+=sub_len
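# masked_question_id now holds one copy of question_id per question word, with
# that word's subword positions replaced by [MASK]; e.g. for a question
# tokenized as [CLS] w1 w2a w2b [SEP] (hypothetical ids) it contains
#   [CLS, MASK, w2a, w2b, SEP] and [CLS, w1, MASK, MASK, SEP]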
|
||||
#print('end')
|
||||
# index+=1
|
||||
# for a in masked_question_id:
|
||||
# print(a)
|
||||
|
||||
|
||||
table = [t.lower().split() for t in table_names]
|
||||
table_id, table_mask_plm, table_subword_len = [], [], []
|
||||
table_word_len = []
|
||||
for s in table:
    # s: one table name, as a list of words
    l = 0
    for w in s:
        # tokenize every word of the table name into subword ids
        toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
        table_id.extend(toks)
        table_subword_len.append(len(toks))
        l += len(toks)
    table_word_len.append(l)
table_mask_plm = [1] * len(table_id)
|
||||
# print(table_id)
|
||||
# print(table_subword_len)
|
||||
# print(table_word_len)
|
||||
|
||||
|
||||
masked_table_id = []
|
||||
#print(len(masked_question_id))
|
||||
|
||||
start=0
|
||||
#print(question_subword_len)
|
||||
for i,sub_len in enumerate(table_word_len):
|
||||
#print(question_id)
|
||||
tmp_table_id=table_id.copy()
|
||||
#print(list(range(start,start+sub_len)))
|
||||
for m in range(start,start+sub_len):
|
||||
# print('index',index+i)
|
||||
#print('before', tmp_question_id)
|
||||
tmp_table_id[m]=plm_tokenizer.mask_token_id
|
||||
#print('after', tmp_question_id)
|
||||
# print('next', masked_question_id[index+i+1])
|
||||
masked_table_id.append(tmp_table_id)
|
||||
start+=sub_len
|
||||
#print('end')
|
||||
# index+=1
|
||||
# for a in masked_table_id:
|
||||
# print(a)
|
||||
|
||||
column = [t.lower().split() for t in column_names]
|
||||
column_id, column_mask_plm, column_subword_len = [], [], []
|
||||
column_word_len = []
|
||||
for s in column:
|
||||
l = 0
|
||||
for w in s:
|
||||
toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
|
||||
column_id.extend(toks)
|
||||
column_subword_len.append(len(toks))
|
||||
l += len(toks)
|
||||
column_word_len.append(l)
|
||||
column_mask_plm = [1] * len(column_id) + [0]
|
||||
exact_column_token=len(column_id)
|
||||
column_id.append(plm_tokenizer.sep_token_id)
|
||||
#print(column_id)
|
||||
#print(column_subword_len)
|
||||
#print(column_word_len)
|
||||
|
||||
masked_column_id = []
|
||||
#print(len(masked_question_id))
|
||||
|
||||
start=0
|
||||
#print(question_subword_len)
|
||||
for i,sub_len in enumerate(column_word_len):
|
||||
#print(question_id)
|
||||
tmp_column_id=column_id.copy()
|
||||
#print(list(range(start,start+sub_len)))
|
||||
for m in range(start,start+sub_len):
|
||||
# print('index',index+i)
|
||||
#print('before', tmp_question_id)
|
||||
tmp_column_id[m]=plm_tokenizer.mask_token_id
|
||||
#print('after', tmp_question_id)
|
||||
# print('next', masked_question_id[index+i+1])
|
||||
masked_column_id.append(tmp_column_id)
|
||||
start+=sub_len
|
||||
#print('end')
|
||||
# index+=1
|
||||
# for a in masked_column_id:
|
||||
# print(a)
|
||||
|
||||
|
||||
# input_id = question_id + table_id + column_id
|
||||
# segment_id = [0] * len(question_id) + [1] * (len(table_id) + len(column_id))
|
||||
question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
|
||||
table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
|
||||
column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm
|
||||
input_id=[]
|
||||
segment_id=[]
|
||||
atten_mask=[]
|
||||
print(len(table_id))
|
||||
for i, msk_q_id in enumerate(masked_question_id):
|
||||
input_id.append(msk_q_id + table_id + column_id)
|
||||
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
|
||||
atten_mask.append([1]*len(input_id[-1]))
|
||||
|
||||
|
||||
|
||||
#device = torch.device("cuda:1")
|
||||
# print(len(input_id))
|
||||
# print(len(segment_id))
|
||||
# print(len(question_mask_plm))
|
||||
#inputs = {"input_ids": torch.tensor(input_id, dtype=torch.long, device=device), "attention_mask": torch.tensor([1]*len(input_id), dtype=torch.float, device=device), "token_type_ids": torch.tensor(segment_id, dtype=torch.long, device=device), "position_ids": None}
|
||||
|
||||
start=0
|
||||
total_size=len(input_id)
|
||||
store_arr=[]
|
||||
if total_size <= max_batch_size:
|
||||
ii=torch.tensor(input_id, dtype=torch.long, device=device)
|
||||
im=torch.tensor(atten_mask, dtype=torch.float, device=device)
|
||||
si=torch.tensor(segment_id, dtype=torch.long, device=device)
|
||||
#print(ii.size())
|
||||
#print(im.size())
|
||||
outputs = plm_model(ii,im,si)[0].squeeze()
|
||||
store_arr.append(outputs)
|
||||
#print(outputs.size())
|
||||
else:
|
||||
while start < len(input_id):
|
||||
if start + max_batch_size <= len(input_id):
|
||||
ii=torch.tensor(input_id[start : start + max_batch_size], dtype=torch.long, device=device)
|
||||
im=torch.tensor(atten_mask[start : start + max_batch_size], dtype=torch.float, device=device)
|
||||
si=torch.tensor(segment_id[start : start + max_batch_size], dtype=torch.long, device=device)
|
||||
#print(ii.size())
|
||||
#print(im.size())
|
||||
outputs = plm_model(ii,im,si)[0].squeeze()
|
||||
store_arr.append(outputs)
|
||||
#print(outputs.size())
|
||||
else:
|
||||
ii=torch.tensor(input_id[start : len(input_id)], dtype=torch.long, device=device)
|
||||
im=torch.tensor(atten_mask[start : len(input_id)], dtype=torch.float, device=device)
|
||||
si=torch.tensor(segment_id[start : len(input_id)], dtype=torch.long, device=device)
|
||||
#print(ii.size())
|
||||
#print(im.size())
|
||||
outputs = plm_model(ii,im,si)[0].squeeze()
|
||||
store_arr.append(outputs)
|
||||
#print(outputs.size())
|
||||
start+=max_batch_size
|
||||
assert len(store_arr)>0
|
||||
if len(store_arr)==1:
|
||||
outputs=store_arr[0]
|
||||
else:
|
||||
outputs=store_arr[0]
|
||||
for t in store_arr[1:]:
|
||||
outputs= torch.cat((outputs, t), dim=0)
|
||||
qu_output=outputs
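# The chunked forward pass above is repeated verbatim for the masked-table and
# masked-column inputs below; a helper along these lines (a sketch, name
# hypothetical) would keep it in one place. Unlike the inline version it skips
# .squeeze(), which would drop the batch dimension for a chunk of size 1:
def batched_forward(model, input_id, atten_mask, segment_id, device, max_batch_size=8):
    chunks = []
    for s in range(0, len(input_id), max_batch_size):
        ii = torch.tensor(input_id[s:s + max_batch_size], dtype=torch.long, device=device)
        im = torch.tensor(atten_mask[s:s + max_batch_size], dtype=torch.float, device=device)
        si = torch.tensor(segment_id[s:s + max_batch_size], dtype=torch.long, device=device)
        chunks.append(model(ii, im, si)[0])  # last_hidden_state: [chunk, seq, hidden]
    return torch.cat(chunks, dim=0)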
|
||||
#print(qu_output.size())
|
||||
|
||||
|
||||
|
||||
q_tab_mat = outputs.new_zeros(len(raw_question_toks),len(table_names))
|
||||
q_col_mat = outputs.new_zeros(len(raw_question_toks),len(column_names))
|
||||
|
||||
|
||||
|
||||
for j, msk_t_id in enumerate(masked_table_id):
|
||||
input_id=[]
|
||||
segment_id=[]
|
||||
atten_mask=[]
|
||||
for i, msk_q_id in enumerate(masked_question_id):
|
||||
input_id.append(msk_q_id + msk_t_id + column_id)
|
||||
segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
|
||||
atten_mask.append([1]*len(input_id[-1]))
|
||||
|
||||
start=0
|
||||
total_size=len(input_id)
|
||||
store_arr=[]
|
||||
if total_size <= max_batch_size:
|
||||
ii=torch.tensor(input_id, dtype=torch.long, device=device)
|
||||
im=torch.tensor(atten_mask, dtype=torch.float, device=device)
|
||||
si=torch.tensor(segment_id, dtype=torch.long, device=device)
|
||||
outputs = plm_model(ii,im,si)[0].squeeze()
|
||||
store_arr.append(outputs)
|
||||
|
||||
else:
|
||||
while start < len(input_id):
|
||||
if start + max_batch_size <= len(input_id):
|
||||
ii=torch.tensor(input_id[start : start + max_batch_size], dtype=torch.long, device=device)
|
||||
im=torch.tensor(atten_mask[start : start + max_batch_size], dtype=torch.float, device=device)
|
||||
si=torch.tensor(segment_id[start : start + max_batch_size], dtype=torch.long, device=device)
|
||||
outputs = plm_model(ii,im,si)[0].squeeze()
|
||||
store_arr.append(outputs)
|
||||
else:
|
||||
ii=torch.tensor(input_id[start : len(input_id)], dtype=torch.long, device=device)
|
||||
im=torch.tensor(atten_mask[start : len(input_id)], dtype=torch.float, device=device)
|
||||
si=torch.tensor(segment_id[start : len(input_id)], dtype=torch.long, device=device)
|
||||
outputs = plm_model(ii,im,si)[0].squeeze()
|
||||
store_arr.append(outputs)
|
||||
start+=max_batch_size
|
||||
assert len(store_arr)>0
|
||||
if len(store_arr)==1:
|
||||
outputs=store_arr[0]
|
||||
else:
|
||||
outputs=store_arr[0]
|
||||
for t in store_arr[1:]:
|
||||
outputs= torch.cat((outputs, t), dim=0)
|
||||
#print(outputs.size())
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), hidden_size)
|
||||
#print(old_tables.size())
|
||||
cmp_tables = qu_output.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),len(table_id), hidden_size)
|
||||
#print(old_tables.size()) #[16, 6, 1024]
|
||||
# table_compare=old_tables[0] # [6, 1024]
|
||||
# table_mask = old_tables[1:] # [15, 6, 1024]
|
||||
start=0
|
||||
new_table_arr = []
|
||||
cmp_table_arr = []
|
||||
for m,sub_len in enumerate(table_word_len):
|
||||
#print(start, start+sub_len-1)
|
||||
if m==j:
|
||||
curr=agg(old_tables[:, start:start+sub_len])
|
||||
cmp_curr = agg(cmp_tables[:, start:start+sub_len])
|
||||
# print(curr.size())
|
||||
# print(cmp_curr.size())
|
||||
dis=F.pairwise_distance(cmp_curr,curr,p=2)
|
||||
q_tab_mat[:,j]=dis
|
||||
start+=sub_len
|
||||
break
|
||||
else:
|
||||
start+=sub_len
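# q_tab_mat[:, j] now holds, for each masked question word, the L2 distance
# between table j's pooled representation with and without that word masked;
# a larger shift marks a stronger word-table correlation, which the
# preprocessor above thresholds into the *semanticmatch relations.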
|
||||
#print(q_tab_mat)
|
||||
|
||||
# curr=old_tables[:, start:start+sub_len]
|
||||
# #print(curr.size())
|
||||
# new_table_arr.append(agg(curr))
|
||||
|
||||
# curr=cmp_tables[:, start:start+sub_len]
|
||||
# #print(curr.size())
|
||||
# cmp_table_arr.append(agg(curr))
|
||||
# # new_tables[i]=agg(curr)
|
||||
# #print(new_questions[i][:10])
|
||||
# #print(curr.size())
|
||||
# start+=sub_len
|
||||
# new_tables = torch.cat(new_table_arr, 1)
|
||||
# #print(new_tables.size()) # [16, 2, 1024]
|
||||
# tbl_msk=new_tables[:,j] #15, 2, 1024
|
||||
# print('tbl_msk',tbl_msk.size())
|
||||
# # assert tbl_msk.size(0) == len(raw_question_toks)
|
||||
# for i, msk_q_id in enumerate(masked_question_id):
|
||||
# a=qu_output[i]
|
||||
# b=tbl_msk[i]
|
||||
# dis=F.pairwise_distance(a,b,p=2)
|
||||
# q_tab_mat[i][j]=dis
|
||||
|
||||
# # for i in range(len(table_word_len)):
|
||||
# # a=tbl_cmp[:,i]
|
||||
# # #print(a.size())
|
||||
# # b=tbl_msk[:,i]
|
||||
# # #print(b.size())
|
||||
# # dis=F.pairwise_distance(a,b,p=2)
|
||||
# # q_tab_mat[:,i]=dis
|
||||
# # #print(dis)
|
||||
#print(q_tab_mat.transpose(0,1).size())
|
||||
print(q_tab_mat.transpose(0,1))
|
||||
|
||||
|
||||
|
||||
|
||||
# old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0),1,1)).view(outputs.size(0),exact_column_token, hidden_size)
|
||||
# #print(old_columns.size())
|
||||
# new_column_arr = []
|
||||
# start=0
|
||||
# for i,sub_len in enumerate(column_word_len):
|
||||
# #print(start, start+sub_len-1)
|
||||
# curr=old_columns[:,start:start+sub_len]
|
||||
# new_column_arr.append(agg(curr))
|
||||
# #print(new_questions[i][:10])
|
||||
# #print(curr.size())
|
||||
# start+=sub_len
|
||||
# new_column = torch.cat(new_column_arr, 1)
|
||||
# #print(new_column.size()) #[16, 8, 1024]
|
||||
|
||||
# col_cmp=new_column[0:1] #1,2,1024
|
||||
# col_msk=new_column[1:] #15, 2, 1024
|
||||
# assert col_msk.size(0) == len(raw_question_toks)
|
||||
# for i in range(len(column_word_len)):
|
||||
# a=col_cmp[:,i]
|
||||
# #print(a.size())
|
||||
# b=col_msk[:,i]
|
||||
# #print(b.size())
|
||||
# dis=F.pairwise_distance(a,b,p=2)
|
||||
# q_col_mat[:,i]=dis
|
||||
# #print(dis)
|
||||
# print(q_col_mat.transpose(0,1).size())
|
||||
# print(q_col_mat.transpose(0,1))
#for i, curr_table in enumerate(table_names):
|
||||
|
||||
|
||||
|
||||
|
||||
# use_whole = outputs[0] #35, 1024]
|
||||
# use_mask = outputs[1:] #[15, 35, 1024]
|
||||
# print(use_whole.size())
|
||||
# print(use_mask.size())
|
||||
# assert use_mask.size(0)==len(raw_question_toks)
|
||||
|
||||
|
||||
# for i, sub_len in enumerate(raw_question_toks):
|
||||
# outputs=use_mask[i]
|
||||
# print(outputs.size())
|
||||
|
||||
|
||||
|
||||
|
||||
# old_questions = outputs.masked_select(torch.tensor(question_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1)).view(exact_question_token, hidden_size)
|
||||
# print(old_questions.size())
|
||||
# new_questions = old_questions.new_zeros(len(raw_question_toks), hidden_size)
|
||||
# start=0
|
||||
# print(question_subword_len)
|
||||
# assert len(question_subword_len) == len(raw_question_toks)
|
||||
# for i,sub_len in enumerate(question_subword_len):
|
||||
# #print(start, start+sub_len-1)
|
||||
# curr=old_questions[start:start+sub_len]
|
||||
# new_questions[i]=agg(curr)
|
||||
# #print(new_questions[i][:10])
|
||||
# #print(curr.size())
|
||||
# start+=sub_len
|
||||
|
||||
# old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1)).view(len(table_id), hidden_size)
|
||||
# print(old_tables.size())
|
||||
# new_tables = old_tables.new_zeros(len(table_names), hidden_size)
|
||||
# start=0
|
||||
# print(table_word_len)
|
||||
# assert len(table_word_len) == len(table_names)
|
||||
# for i,sub_len in enumerate(table_word_len):
|
||||
# #print(start, start+sub_len-1)
|
||||
# curr=old_tables[start:start+sub_len]
|
||||
# new_tables[i]=agg(curr)
|
||||
# #print(new_questions[i][:10])
|
||||
# #print(curr.size())
|
||||
# start+=sub_len
|
||||
|
||||
|
||||
# old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1)).view(exact_column_token, hidden_size)
|
||||
# print(old_columns.size())
|
||||
# new_column = old_columns.new_zeros(len(column_names),hidden_size)
|
||||
# start=0
|
||||
# print(column_word_len)
|
||||
# assert len(column_word_len) == len(column_names)
|
||||
# for i,sub_len in enumerate(column_word_len):
|
||||
# #print(start, start+sub_len-1)
|
||||
# curr=old_columns[start:start+sub_len]
|
||||
# new_column[i]=agg(curr)
|
||||
# #print(new_questions[i][:10])
|
||||
# print(curr.size())
|
||||
# start+=sub_len
|
||||
|
||||
|
||||
|
|
@@ -0,0 +1,123 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
import sqlite3
|
||||
from os import listdir, makedirs
|
||||
from os.path import isfile, isdir, join, split, exists, splitext
|
||||
import traceback
|
||||
|
||||
EXIST = {"atis", "geo", "advising", "yelp", "restaurants", "imdb", "academic"}
|
||||
|
||||
def convert_fk_index(data):
|
||||
fk_holder = []
|
||||
for fk in data["foreign_keys"]:
|
||||
tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1]
|
||||
ref_cid, cid = None, None
|
||||
try:
|
||||
tid = data['table_names_original'].index(tn)
|
||||
ref_tid = data['table_names_original'].index(ref_tn)
|
||||
|
||||
for i, (tab_id, col_org) in enumerate(data['column_names_original']):
|
||||
if tab_id == ref_tid and ref_col == col_org:
|
||||
ref_cid = i
|
||||
elif tid == tab_id and col == col_org:
|
||||
cid = i
|
||||
if ref_cid is not None and cid is not None:  # explicit None checks; a truthiness test would also reject index 0
    fk_holder.append([cid, ref_cid])
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print("table_names_original: ", data['table_names_original'])
|
||||
print("finding tab name: ", tn, ref_tn)
|
||||
sys.exit()
|
||||
return fk_holder
|
||||
|
||||
|
||||
def dump_db_json_schema(db, f):
|
||||
'''read table and column info'''
|
||||
|
||||
conn = sqlite3.connect(db)
|
||||
conn.execute('pragma foreign_keys=ON')
|
||||
cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
|
||||
data = {'db_id': f,
|
||||
'table_names_original': [],
|
||||
'table_names': [],
|
||||
'column_names_original': [(-1, '*')],
|
||||
'column_names': [(-1, '*')],
|
||||
'column_types': ['text'],
|
||||
'primary_keys': [],
|
||||
'foreign_keys': []}
|
||||
|
||||
fk_holder = []
|
||||
for i, item in enumerate(cursor.fetchall()):
|
||||
table_name = item[0]
|
||||
data['table_names_original'].append(table_name)
|
||||
data['table_names'].append(table_name.lower().replace("_", ' '))
|
||||
fks = conn.execute("PRAGMA foreign_key_list('{}') ".format(table_name)).fetchall()
|
||||
fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks])
|
||||
cur = conn.execute("PRAGMA table_info('{}') ".format(table_name))
|
||||
for j, col in enumerate(cur.fetchall()):
|
||||
data['column_names_original'].append((i, col[1]))
|
||||
data['column_names'].append((i, col[1].lower().replace("_", " ")))
|
||||
col_type = col[2].lower()
|
||||
if 'char' in col_type or col_type == '' or 'text' in col_type or 'var' in col_type:
|
||||
data['column_types'].append('text')
|
||||
elif 'int' in col_type or 'numeric' in col_type or 'decimal' in col_type or 'number' in col_type\
|
||||
or 'id' in col_type or 'real' in col_type or 'double' in col_type or 'float' in col_type:
|
||||
data['column_types'].append('number')
|
||||
elif 'date' in col_type or 'time' in col_type or 'year' in col_type:
|
||||
data['column_types'].append('time')
|
||||
elif 'boolean' in col_type:
|
||||
data['column_types'].append('boolean')
|
||||
else:
|
||||
data['column_types'].append('others')
|
||||
|
||||
if col[5] == 1:
|
||||
data['primary_keys'].append(len(data['column_names'])-1)
|
||||
|
||||
data["foreign_keys"] = fk_holder
|
||||
data['foreign_keys'] = convert_fk_index(data)
|
||||
|
||||
return data
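# Hypothetical usage on one Spider-style database directory:
#   data = dump_db_json_schema('data/database/concert_singer/concert_singer.sqlite', 'concert_singer')
#   print(len(data['table_names']), len(data['column_names']))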
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]")
|
||||
sys.exit()
|
||||
input_dir = sys.argv[1]
|
||||
output_file = sys.argv[2]
|
||||
ex_tab_file = sys.argv[3]
|
||||
|
||||
all_fs = [df for df in listdir(input_dir) if exists(join(input_dir, df, df+'.sqlite'))]
|
||||
with open(ex_tab_file) as f:
|
||||
ex_tabs = json.load(f)
|
||||
#for tab in ex_tabs:
|
||||
# tab["foreign_keys"] = convert_fk_index(tab)
|
||||
ex_tabs = {tab["db_id"]: tab for tab in ex_tabs if tab["db_id"] in all_fs}
|
||||
# print "precessed file num: ", len(ex_tabs)
|
||||
not_fs = [df for df in listdir(input_dir) if not exists(join(input_dir, df, df+'.sqlite'))]
|
||||
for d in not_fs:
|
||||
print("no sqlite file found in: ", d)
|
||||
db_files = [(df+'.sqlite', df) for df in listdir(input_dir) if exists(join(input_dir, df, df+'.sqlite'))]
|
||||
tables = []
|
||||
for f, df in db_files:
|
||||
#if df in ex_tabs.keys():
|
||||
#print 'reading old db: ', df
|
||||
# tables.append(ex_tabs[df])
|
||||
db = join(input_dir, df, f)
|
||||
# print('\nreading new db: ', df)
|
||||
table = dump_db_json_schema(db, df)
cur_tab_num = len(table["table_names"])
cur_col_num = len(table["column_names"])
# guard the ex_tabs lookup first: indexing before the membership test raises KeyError for new dbs
if df in ex_tabs and cur_tab_num == len(ex_tabs[df]["table_names"]) and \
        cur_col_num == len(ex_tabs[df]["column_names"]) and cur_tab_num != 0 and cur_col_num > 1:
    table["table_names"] = ex_tabs[df]["table_names"]
    table["column_names"] = ex_tabs[df]["column_names"]
else:
    print("\n----------------------------------problem db: ", df)
tables.append(table)
|
||||
print("final db num: ", len(tables))
|
||||
with open(output_file, 'wt') as out:
|
||||
json.dump(tables, out, sort_keys=True, indent=2, separators=(',', ': '))
|
|
@@ -0,0 +1,90 @@
|
|||
#coding=utf8
|
||||
import math, dgl, torch
|
||||
import numpy as np
|
||||
from utils.constants import MAX_RELATIVE_DIST
|
||||
from utils.graph_example import GraphExample
|
||||
|
||||
# mapping special column * as an ordinary column
|
||||
special_column_mapping_dict = {
|
||||
'question-*-generic': 'question-column-nomatch',
|
||||
'*-question-generic': 'column-question-nomatch',
|
||||
'table-*-generic': 'table-column-has',
|
||||
'*-table-generic': 'column-table-has',
|
||||
'*-column-generic': 'column-column-generic',
|
||||
'column-*-generic': 'column-column-generic',
|
||||
'*-*-identity': 'column-column-identity'
|
||||
}
|
||||
nonlocal_relations = [
|
||||
'question-question-generic', 'table-table-generic', 'column-column-generic', 'table-column-generic', 'column-table-generic',
|
||||
'table-table-fk', 'table-table-fkr', 'table-table-fkb', 'column-column-sametable',
|
||||
'*-column-generic', 'column-*-generic', '*-*-identity', '*-table-generic',
|
||||
'question-question-identity', 'table-table-identity', 'column-column-identity'] + [
|
||||
'question-question-dist' + str(i) for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1) if i not in [-1, 0, 1]
|
||||
]
|
||||
|
||||
class GraphProcessor():
|
||||
|
||||
def process_rgatsql(self, ex: dict, db: dict, relation: list):
|
||||
graph = GraphExample()
|
||||
num_nodes = int(math.sqrt(len(relation)))
|
||||
local_edges = [(idx // num_nodes, idx % num_nodes, (special_column_mapping_dict[r] if r in special_column_mapping_dict else r))
|
||||
for idx, r in enumerate(relation) if r not in nonlocal_relations]
|
||||
nonlocal_edges = [(idx // num_nodes, idx % num_nodes, (special_column_mapping_dict[r] if r in special_column_mapping_dict else r))
|
||||
for idx, r in enumerate(relation) if r in nonlocal_relations]
|
||||
global_edges = local_edges + nonlocal_edges
|
||||
src_ids, dst_ids = list(map(lambda r: r[0], global_edges)), list(map(lambda r: r[1], global_edges))
|
||||
graph.global_g = dgl.graph((src_ids, dst_ids), num_nodes=num_nodes, idtype=torch.int32)
|
||||
graph.global_edges = global_edges
|
||||
src_ids, dst_ids = list(map(lambda r: r[0], local_edges)), list(map(lambda r: r[1], local_edges))
|
||||
graph.local_g = dgl.graph((src_ids, dst_ids), num_nodes=num_nodes, idtype=torch.int32)
|
||||
graph.local_edges = local_edges
|
||||
# graph pruning for nodes
|
||||
q_num = len(ex['processed_question_toks'])
|
||||
s_num = num_nodes - q_num
|
||||
graph.question_mask = [1] * q_num + [0] * s_num
|
||||
graph.schema_mask = [0] * q_num + [1] * s_num
|
||||
graph.gp = dgl.heterograph({
|
||||
('question', 'to', 'schema'): (list(range(q_num)) * s_num,
|
||||
[i for i in range(s_num) for _ in range(q_num)])
|
||||
}, num_nodes_dict={'question': q_num, 'schema': s_num}, idtype=torch.int32
|
||||
)
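# gp is a complete bipartite question -> schema graph used by the graph
# pruning head: every question node links to every schema node (q_num * s_num edges).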
|
||||
t_num = len(db['processed_table_toks'])
|
||||
def check_node(i):
|
||||
if i < t_num and i in ex['used_tables']:
|
||||
return 1.0
|
||||
elif i >= t_num and i - t_num in ex['used_columns']:
|
||||
return 1.0
|
||||
else: return 0.0
|
||||
graph.node_label = list(map(check_node, range(s_num)))
|
||||
ex['graph'] = graph
|
||||
return ex
|
||||
|
||||
def process_lgesql(self, ex: dict, db: dict, relation: list):
|
||||
ex = self.process_rgatsql(ex, db, relation)
|
||||
graph = ex['graph']
|
||||
lg = graph.local_g.line_graph(backtracking=False)
|
||||
# prevent information propagate through matching edges
|
||||
match_ids = [idx for idx, r in enumerate(graph.global_edges) if 'match' in r[2]]
|
||||
src, dst, eids = lg.edges(form='all', order='eid')
|
||||
eids = [e for u, v, e in zip(src.tolist(), dst.tolist(), eids.tolist()) if not (u in match_ids and v in match_ids)]
|
||||
graph.lg = lg.edge_subgraph(eids, preserve_nodes=True).remove_self_loop().add_self_loop()
|
||||
ex['graph'] = graph
|
||||
return ex
|
||||
|
||||
def process_graph_utils(self, ex: dict, db: dict, method: str = 'rgatsql'):
|
||||
""" Example should be preprocessed by self.pipeline
|
||||
"""
|
||||
q = np.array(ex['relations'], dtype='<U100')
|
||||
s = np.array(db['relations'], dtype='<U100')
|
||||
q_s = np.array(ex['schema_linking'][0], dtype='<U100')
|
||||
s_q = np.array(ex['schema_linking'][1], dtype='<U100')
|
||||
relation = np.concatenate([
|
||||
np.concatenate([q, q_s], axis=1),
|
||||
np.concatenate([s_q, s], axis=1)
|
||||
], axis=0)
|
||||
relation = relation.flatten().tolist()
|
||||
if method == 'rgatsql':
|
||||
ex = self.process_rgatsql(ex, db, relation)
|
||||
elif method == 'lgesql':
|
||||
ex = self.process_lgesql(ex, db, relation)
|
||||
return ex
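# Hypothetical usage, assuming ex and db come from Preprocessor.pipeline and
# process_tables respectively:
#   processor = GraphProcessor()
#   ex = processor.process_graph_utils(ex, db, method='lgesql')
#   graph = ex['graph']  # carries global_g, local_g, lg and the pruning labels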
|
|
@@ -0,0 +1,112 @@
|
|||
import os, sys
|
||||
import json
|
||||
import sqlite3
|
||||
import traceback
|
||||
import argparse
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from process_sql import get_sql
|
||||
|
||||
|
||||
#TODO: update the following dirs
|
||||
sql_path = 'data/train.json'
|
||||
db_dir = 'data/database/'
|
||||
output_file = 'train.json'
|
||||
table_file = 'data/tables.json'
|
||||
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
Simple schema which maps table&column to a unique identifier
|
||||
"""
|
||||
def __init__(self, schema, table):
|
||||
self._schema = schema
|
||||
self._table = table
|
||||
self._idMap = self._map(self._schema, self._table)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def idMap(self):
|
||||
return self._idMap
|
||||
|
||||
def _map(self, schema, table):
|
||||
column_names_original = table['column_names_original']
|
||||
table_names_original = table['table_names_original']
|
||||
#print 'column_names_original: ', column_names_original
|
||||
#print 'table_names_original: ', table_names_original
|
||||
for i, (tab_id, col) in enumerate(column_names_original):
|
||||
if tab_id == -1:
    idMap = {'*': i}  # '*' is always listed first, so idMap is initialized on the first iteration
|
||||
else:
|
||||
key = table_names_original[tab_id].lower()
|
||||
val = col.lower()
|
||||
idMap[key + "." + val] = i
|
||||
|
||||
for i, tab in enumerate(table_names_original):
|
||||
key = tab.lower()
|
||||
idMap[key] = i
|
||||
|
||||
return idMap
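# e.g. for the concert_singer database, idMap maps 'singer' to its table id,
# 'singer.name' to its column id, and '*' to 0.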
|
||||
|
||||
|
||||
def get_schemas_from_json(fpath):
|
||||
with open(fpath) as f:
|
||||
data = json.load(f)
|
||||
db_names = [db['db_id'] for db in data]
|
||||
|
||||
tables = {}
|
||||
schemas = {}
|
||||
for db in data:
|
||||
db_id = db['db_id']
|
||||
schema = {} #{'table': [col.lower, ..., ]} * -> __all__
|
||||
column_names_original = db['column_names_original']
|
||||
table_names_original = db['table_names_original']
|
||||
tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original}
|
||||
for i, tabn in enumerate(table_names_original):
|
||||
table = str(tabn.lower())
|
||||
cols = [str(col.lower()) for td, col in column_names_original if td == i]
|
||||
schema[table] = cols
|
||||
schemas[db_id] = schema
|
||||
|
||||
return schemas, db_names, tables
|
||||
|
||||
|
||||
|
||||
schemas, db_names, tables = get_schemas_from_json(table_file)
|
||||
|
||||
with open(sql_path) as inf:
|
||||
sql_data = json.load(inf)
|
||||
|
||||
sql_data_new = []
|
||||
for data in sql_data:
|
||||
try:
|
||||
if data['query'] == 'SELECT T1.company_name FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id = T2.maintenance_contract_company_id JOIN Ref_Company_Types AS T3 ON T1.company_type_code = T3.company_type_code ORDER BY T2.contract_end_date DESC LIMIT 1':
|
||||
data['query'] = 'SELECT T1.company_type FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id = T2.maintenance_contract_company_id ORDER BY T2.contract_end_date DESC LIMIT 1'
|
||||
data['query_toks'] = ['SELECT', 'T1.company_type', 'FROM', 'Third_Party_Companies', 'AS', 'T1', 'JOIN', 'Maintenance_Contracts', 'AS', 'T2', 'ON', 'T1.company_id', '=', 'T2.maintenance_contract_company_id', 'ORDER', 'BY', 'T2.contract_end_date', 'DESC', 'LIMIT', '1']
|
||||
data['query_toks_no_value'] = ['select', 't1', '.', 'company_type', 'from', 'third_party_companies', 'as', 't1', 'join', 'maintenance_contracts', 'as', 't2', 'on', 't1', '.', 'company_id', '=', 't2', '.', 'maintenance_contract_company_id', 'order', 'by', 't2', '.', 'contract_end_date', 'desc', 'limit', 'value']
|
||||
data['question'] = 'What is the type of the company who concluded its contracts most recently?'
|
||||
data['question_toks'] = ['What', 'is', 'the', 'type', 'of', 'the', 'company', 'who', 'concluded', 'its', 'contracts', 'most', 'recently', '?']
|
||||
if data['query'].startswith('SELECT T1.fname FROM student AS T1 JOIN lives_in AS T2 ON T1.stuid = T2.stuid WHERE T2.dormid IN'):
|
||||
data['query'] = data['query'].replace('IN (SELECT T2.dormid)', 'IN (SELECT T3.dormid)')
|
||||
index = data['query_toks'].index('(') + 2
|
||||
assert data['query_toks'][index] == 'T2.dormid'
|
||||
data['query_toks'][index] = 'T3.dormid'
|
||||
index = data['query_toks_no_value'].index('(') + 2
|
||||
assert data['query_toks_no_value'][index] == 't2'
|
||||
data['query_toks_no_value'][index] = 't3'
|
||||
db_id = data["db_id"]
|
||||
schema = schemas[db_id]
|
||||
table = tables[db_id]
|
||||
schema = Schema(schema, table)
|
||||
sql = data["query"]
|
||||
sql_label = get_sql(schema, sql)
|
||||
data["sql"] = sql_label
|
||||
sql_data_new.append(data)
|
||||
except Exception:
    traceback.print_exc()
    print("db_id: ", db_id)
    print("sql: ", sql)
|
||||
|
||||
with open(output_file, 'wt') as out:
|
||||
json.dump(sql_data_new, out, sort_keys=True, indent=4, separators=(',', ': '))
|
|
@@ -0,0 +1,88 @@
|
|||
import os
|
||||
import traceback
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
import sqlite3
|
||||
import sqlparse
|
||||
import random
|
||||
from os import listdir, makedirs
|
||||
from collections import OrderedDict
|
||||
from nltk import word_tokenize, tokenize
|
||||
from os.path import isfile, isdir, join, split, exists, splitext
|
||||
|
||||
from process_sql import get_sql
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
Simple schema which maps table&column to a unique identifier
|
||||
"""
|
||||
def __init__(self, schema, table):
|
||||
self._schema = schema
|
||||
self._table = table
|
||||
self._idMap = self._map(self._schema, self._table)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def idMap(self):
|
||||
return self._idMap
|
||||
|
||||
def _map(self, schema, table):
|
||||
column_names_original = table['column_names_original']
|
||||
table_names_original = table['table_names_original']
|
||||
#print 'column_names_original: ', column_names_original
|
||||
#print 'table_names_original: ', table_names_original
|
||||
for i, (tab_id, col) in enumerate(column_names_original):
|
||||
if tab_id == -1:
    idMap = {'*': i}  # '*' is always listed first, so idMap is initialized on the first iteration
|
||||
else:
|
||||
key = table_names_original[tab_id].lower()
|
||||
val = col.lower()
|
||||
idMap[key + "." + val] = i
|
||||
|
||||
for i, tab in enumerate(table_names_original):
|
||||
key = tab.lower()
|
||||
idMap[key] = i
|
||||
|
||||
return idMap
|
||||
|
||||
|
||||
def get_schemas_from_json(fpath):
|
||||
with open(fpath) as f:
|
||||
data = json.load(f)
|
||||
db_names = [db['db_id'] for db in data]
|
||||
|
||||
tables = {}
|
||||
schemas = {}
|
||||
for db in data:
|
||||
db_id = db['db_id']
|
||||
schema = {} #{'table': [col.lower, ..., ]} * -> __all__
|
||||
column_names_original = db['column_names_original']
|
||||
table_names_original = db['table_names_original']
|
||||
tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original}
|
||||
for i, tabn in enumerate(table_names_original):
|
||||
table = str(tabn.lower())
|
||||
cols = [str(col.lower()) for td, col in column_names_original if td == i]
|
||||
schema[table] = cols
|
||||
schemas[db_id] = schema
|
||||
|
||||
return schemas, db_names, tables
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
sql = "SELECT name , country , age FROM singer ORDER BY age DESC"
|
||||
db_id = "concert_singer"
|
||||
table_file = "tables.json"
|
||||
|
||||
schemas, db_names, tables = get_schemas_from_json(table_file)
|
||||
schema = schemas[db_id]
|
||||
table = tables[db_id]
|
||||
schema = Schema(schema, table)
|
||||
sql_label = get_sql(schema, sql)
|
||||
|
||||
|
||||
|
|
@@ -0,0 +1,73 @@
#coding=utf8
import os, json, pickle, argparse, sys, time
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from asdl.asdl import ASDLGrammar
from asdl.transition_system import TransitionSystem
from asdl.action_info import get_action_infos
from preprocess.common_utils import Preprocessor

def process_example(processor, entry, db, trans, verbose=False):
    # preprocess raw tokens, schema linking and subgraph extraction
    entry = processor.pipeline(entry, db, verbose=verbose)
    # generate target output actions
    ast = trans.surface_code_to_ast(entry['sql'])
    actions = trans.get_actions(ast)
    entry['ast'] = ast
    entry['actions'] = get_action_infos(tgt_actions=actions)
    return entry

def process_tables(processor, tables_list, output_path=None, verbose=False):
    tables = {}
    for each in tables_list:
        if verbose:
            print('*************** Processing database %s **************' % (each['db_id']))
        tables[each['db_id']] = processor.preprocess_database(each, verbose=verbose)
    print('In total, processed %d databases.' % (len(tables)))
    if output_path is not None:
        pickle.dump(tables, open(output_path, 'wb'))
    return tables

def process_dataset(processor, dataset, tables, output_path=None, skip_large=False, verbose=False):
    from utils.constants import GRAMMAR_FILEPATH
    grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH)
    trans = TransitionSystem.get_class_by_lang('sql')(grammar)
    processed_dataset = []
    for idx, entry in enumerate(dataset):
        if skip_large and len(tables[entry['db_id']]['column_names']) > 100: continue
        if verbose:
            print('*************** Processing %d-th sample **************' % (idx))
        entry = process_example(processor, entry, tables[entry['db_id']], trans, verbose=verbose)
        processed_dataset.append(entry)
    print('In total, processed %d samples; skipped %d samples from extremely large databases.' % (len(processed_dataset), len(dataset) - len(processed_dataset)))
    if output_path is not None:
        # serialize preprocessed dataset
        pickle.dump(processed_dataset, open(output_path, 'wb'))
    return processed_dataset

if __name__ == '__main__':

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--db_dir', type=str, default='data/database')
    arg_parser.add_argument('--dataset_path', type=str, required=True, help='dataset path')
    arg_parser.add_argument('--raw_table_path', type=str, help='raw tables path')
    arg_parser.add_argument('--table_path', type=str, required=True, help='processed table path')
    arg_parser.add_argument('--output_path', type=str, required=True, help='output preprocessed dataset')
    arg_parser.add_argument('--skip_large', action='store_true', help='whether to skip large databases')
    arg_parser.add_argument('--verbose', action='store_true', help='whether to print processing information')
    args = arg_parser.parse_args()

    processor = Preprocessor(db_dir=args.db_dir, db_content=True)
    # load database and dataset
    if args.raw_table_path:
        # need to preprocess database items
        tables_list = json.load(open(args.raw_table_path, 'r'))
        print('Firstly, preprocess the original databases ...')
        start_time = time.time()
        tables = process_tables(processor, tables_list, args.table_path, args.verbose)
        print('Databases preprocessing costs %.4fs.' % (time.time() - start_time))
    else:
        tables = pickle.load(open(args.table_path, 'rb'))
    dataset = json.load(open(args.dataset_path, 'r'))
    start_time = time.time()
    dataset = process_dataset(processor, dataset, tables, args.output_path, args.skip_large, verbose=args.verbose)
    print('Dataset preprocessing costs %.4fs.' % (time.time() - start_time))
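
# Hedged usage sketch (editor's addition; paths are illustrative — see
# run/run_preprocessing.sh for the invocation actually used):
#   python <this_script>.py --dataset_path data/train.json --raw_table_path data/tables.json \
#       --table_path data/tables.bin --output_path data/train.bin --verbose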

@@ -0,0 +1,37 @@
#coding=utf8
import os, json, pickle, argparse, sys, time
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from preprocess.graph_utils import GraphProcessor

def process_dataset_graph(processor, dataset, tables, method, output_path=None, skip_large=False):
    processed_dataset = []
    for idx, entry in enumerate(dataset):
        db = tables[entry['db_id']]
        if skip_large and len(db['column_names']) > 100:
            continue
        if (idx + 1) % 500 == 0:
            print('Processing the %d-th example ...' % (idx + 1))
        entry = processor.process_graph_utils(entry, db, method=method)
        processed_dataset.append(entry)
    print('In total, processed %d samples; skipped %d samples.' % (len(processed_dataset), len(dataset) - len(processed_dataset)))
    if output_path is not None:
        # serialize preprocessed dataset
        pickle.dump(processed_dataset, open(output_path, 'wb'))
    return processed_dataset

if __name__ == '__main__':

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--dataset_path', type=str, required=True, help='dataset path')
    arg_parser.add_argument('--table_path', type=str, required=True, help='processed table path')
    arg_parser.add_argument('--method', type=str, default='lgesql', choices=['rgatsql', 'lgesql'])
    arg_parser.add_argument('--output_path', type=str, required=True, help='output preprocessed dataset')
    args = arg_parser.parse_args()

    processor = GraphProcessor()
    # load database and dataset
    tables = pickle.load(open(args.table_path, 'rb'))
    dataset = pickle.load(open(args.dataset_path, 'rb'))
    start_time = time.time()
    dataset = process_dataset_graph(processor, dataset, tables, args.method, args.output_path)
    print('Dataset preprocessing costs %.4fs.' % (time.time() - start_time))
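
# Hedged usage sketch (editor's addition; paths are illustrative):
#   python <this_script>.py --dataset_path data/train.bin --table_path data/tables.bin \
#       --method lgesql --output_path data/train.lgesql.bin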

@@ -0,0 +1,20 @@
# coding=utf8
# import os
# import numpy as np
import torch  # needed by agg below; was commented out in the draft
# import torch.nn.functional as F
# from transformers import AutoModel, AutoConfig, AutoTokenizer
# import matplotlib.pyplot as plt
from nltk.corpus import stopwords

def agg(input):
    # mean over the sub-token dimension
    return torch.sum(input, dim=1, keepdim=True) / input.size(1)


if __name__ == '__main__':
    stopwords = stopwords.words("english")
    b = '."?,'
    a = 'the'
    print(a in stopwords)

@@ -0,0 +1,235 @@
#coding=utf8
import os
import numpy as np
import torch
import stanza
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoTokenizer
import matplotlib.pyplot as plt
from nltk.corpus import stopwords

def agg(input):
    assert input.size(1) >= 2
    input = input[:, 1:]
    return torch.sum(input, dim=1, keepdim=True) / input.size(1)

if __name__ == '__main__':
    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')  #, use_gpu=False)
    stopwords = stopwords.words("english")
    a = '.'
    print(a in stopwords)
    device = torch.device("cuda:0")
    hidden_size = 1024
    max_batch_size = 8
    plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'grappa_large_jnt')).to(device)
    plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'grappa_large_jnt'))
    config = plm_model.config
    raw_question_toks = ['what', 'is', 'the', 'number', 'of', 'employees', '?']
    column_names = ['*', 'flight number', 'origin', 'destination', 'distance', 'departure date', 'arrival date', 'price', 'airline id', 'airline id', 'name', 'distance', 'employee id', 'name', 'salary', 'employee id', 'airline id']
    table_names = ['flight', 'aircraft', 'employee', 'certificate']
    question = " ".join(raw_question_toks)
    doc = nlp(question)
    raw_question_toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
    print('Q', len(raw_question_toks), raw_question_toks)
    print('C', len(column_names), column_names)
    print('T', len(table_names), table_names)

    question_id = [plm_tokenizer.cls_token_id]
    question = [q.lower() for q in raw_question_toks]
    question_subword_len = []
    for w in question:
        toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
        question_id.extend(toks)
        question_subword_len.append(len(toks))
    question_mask_plm = [0] + [1] * (len(question_id) - 1)  #+ [0]
    #question_id.append(plm_tokenizer.sep_token_id)

    masked_question_id = [question_id]

    start = 1

    for i, sub_len in enumerate(question_subword_len):
        tmp_question_id = question_id.copy()
        for m in range(start, start + sub_len):
            tmp_question_id[m] = plm_tokenizer.mask_token_id
        masked_question_id.append(tmp_question_id)
        start += sub_len

    table = [t.lower().split() for t in table_names]
    table_id, table_mask_plm, table_subword_len = [], [], []
    table_word_len = []
    for s in table:
        #l = 1
        toks = [plm_tokenizer.sep_token_id]
        for w in s:
            sub_toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
            toks.extend(sub_toks)
            #table_subword_len.append(len(toks))
            #l += len(sub_toks)
        table_id.extend(toks)
        table_word_len.append(len(toks))
    table_mask_plm = [1] * len(table_id)

    column = [t.lower().split() for t in column_names]
    column_id, column_mask_plm, column_subword_len = [], [], []
    column_word_len = []
    for s in column:
        #l = 1
        toks = [plm_tokenizer.sep_token_id]
        for w in s:
            sub_toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
            toks.extend(sub_toks)
            #column_subword_len.append(len(toks))
            #l += len(sub_toks)
        column_id.extend(toks)
        column_word_len.append(len(toks))
    column_mask_plm = [1] * len(column_id)  #+ [0]
    #exact_column_token = len(column_id)
    #column_id.append(plm_tokenizer.sep_token_id)

    question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
    table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
    column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm

    input_id = []
    segment_id = []
    atten_mask = []
    for i, msk_q_id in enumerate(masked_question_id):
        input_id.append(msk_q_id + table_id + column_id)
        segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
        atten_mask.append([1] * len(input_id[-1]))

    start = 0
    total_size = len(input_id)
    #print(total_size)
    store_arr = []
    if total_size <= max_batch_size:
        ii = torch.tensor(input_id, dtype=torch.long, device=device)
        im = torch.tensor(atten_mask, dtype=torch.float, device=device)
        si = torch.tensor(segment_id, dtype=torch.long, device=device)
        outputs = plm_model(ii, im)[0].squeeze()
        store_arr.append(outputs)
    else:
        while start < len(input_id):
            if start + max_batch_size <= len(input_id):
                ii = torch.tensor(input_id[start: start + max_batch_size], dtype=torch.long, device=device)
                im = torch.tensor(atten_mask[start: start + max_batch_size], dtype=torch.float, device=device)
                si = torch.tensor(segment_id[start: start + max_batch_size], dtype=torch.long, device=device)
                outputs = plm_model(ii, im)[0]  #.squeeze()
                store_arr.append(outputs)
            else:
                ii = torch.tensor(input_id[start: len(input_id)], dtype=torch.long, device=device)
                im = torch.tensor(atten_mask[start: len(input_id)], dtype=torch.float, device=device)
                si = torch.tensor(segment_id[start: len(input_id)], dtype=torch.long, device=device)
                outputs = plm_model(ii, im)[0]  #.squeeze()
                store_arr.append(outputs)
            #print(outputs.size())
            start += max_batch_size
    assert len(store_arr) > 0
    if len(store_arr) == 1:
        outputs = store_arr[0]
    else:
        outputs = store_arr[0]
        print(outputs.size())
        for t in store_arr[1:]:
            print(t.size())
            outputs = torch.cat((outputs, t), dim=0)
    q_tab_mat = outputs.new_zeros(len(raw_question_toks), len(table_names))

    old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0), 1, 1)).view(outputs.size(0), len(table_id), hidden_size)

    start = 0
    new_table_arr = []
    for i, sub_len in enumerate(table_word_len):
        curr = old_tables[:, start:start + sub_len]
        new_table_arr.append(agg(curr))
        start += sub_len
    new_tables = torch.cat(new_table_arr, 1)
    tbl_cmp = new_tables[0:1]
    tbl_msk = new_tables[1:]
    assert tbl_msk.size(0) == len(raw_question_toks)
    for i in range(len(table_word_len)):
        a = tbl_cmp[:, i]
        b = tbl_msk[:, i]
        dis = F.pairwise_distance(a, b, p=2)
        q_tab_mat[:, i] = dis
    #print(q_tab_mat.transpose(0,1).size())
    #print(q_tab_mat.transpose(0,1).cpu().detach().numpy())
    print(q_tab_mat.transpose(0, 1))

    q_col_mat = outputs.new_zeros(len(raw_question_toks), len(column_names))

    old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0), 1, 1)).view(outputs.size(0), len(column_id), hidden_size)
    new_column_arr = []
    start = 0
    for i, sub_len in enumerate(column_word_len):
        curr = old_columns[:, start:start + sub_len]
        new_column_arr.append(agg(curr))
        start += sub_len
    new_column = torch.cat(new_column_arr, 1)

    col_cmp = new_column[0:1]
    col_msk = new_column[1:]
    assert col_msk.size(0) == len(raw_question_toks)
    for i in range(len(column_word_len)):
        a = col_cmp[:, i]
        b = col_msk[:, i]
        dis = F.pairwise_distance(a, b, p=2)
        q_col_mat[:, i] = dis
        #print(dis)
    #print(q_col_mat.transpose(0,1).size())
    #print(q_col_mat.transpose(0,1).cpu().detach().numpy())
    print(q_col_mat.transpose(0, 1))

    use_matrix = torch.cat([q_tab_mat, q_col_mat], dim=1)
    matrix_min = torch.min(use_matrix)
    print(matrix_min)
    matrix_max = torch.max(use_matrix)
    print(matrix_max)
    matrix_mean = torch.mean(use_matrix)
    print(matrix_mean)
    matrix_var = torch.std(use_matrix)  # std for the commented z-score below; the draft computed sqrt(mean)
    print(matrix_var)
    use_matrix = (use_matrix - matrix_min) / (matrix_max - matrix_min)
    #use_matrix = (use_matrix - matrix_mean) / matrix_var
    print(use_matrix.size())
    print(use_matrix)

    # vegetables = table_names + column_names
    # farmers = raw_question_toks

    vegetables = raw_question_toks
    farmers = table_names + column_names

    harvest = np.around(use_matrix.cpu().detach().numpy(), 3)

    fig, ax = plt.subplots(figsize=(15, 15))
    im = ax.imshow(harvest, cmap="YlGn")
    #threshold = im.norm(use_matrix.max())/2
    # We want to show all ticks ...
    ax.set_xticks(np.arange(len(farmers)))
    ax.set_yticks(np.arange(len(vegetables)))
    # ... and label them with the respective list entries
    ax.set_xticklabels(farmers)
    ax.set_yticklabels(vegetables)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(len(vegetables)):
        for j in range(len(farmers)):
            text = ax.text(j, i, harvest[i, j],
                           ha="center", va="center", color="w")

    ax.set_title("Probing Matrix")
    fig.tight_layout()
    plt.show()
    plt.savefig('1.jpg')
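
# ------------------------------------------------------------------------
# Editor's sketch (not part of the original file): the probing procedure
# above, distilled. For every question word we build one copy of the input
# with that word's sub-tokens masked, re-encode, and take the L2 shift of
# each pooled schema-item vector as the word-item relevance. `encode` is a
# stand-in for the PLM forward pass and `item_spans` are (start, end)
# sub-token offsets of the schema items; both are assumptions for illustration.
def probe_matrix(encode, seqs, item_spans):
    # seqs[0] is the unmasked sequence; seqs[i] has question word i masked
    outputs = encode(seqs)  # (n_seq, seq_len, hidden)
    # mean-pool the sub-tokens of every schema item in every run
    pooled = torch.stack([outputs[:, s:e].mean(dim=1) for s, e in item_spans], dim=1)
    base, masked = pooled[0:1], pooled[1:]  # unmasked run vs. word-masked runs
    return (masked - base).norm(p=2, dim=-1)  # (n_question_words, n_items)

# toy check with a random stand-in encoder: 3 question words, 2 schema items
# _mat = probe_matrix(lambda s: torch.randn(len(s), 12, 16), [None] * 4, [(2, 4), (5, 8)])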

@@ -0,0 +1,247 @@
#coding=utf8
import os
import numpy as np
import torch
import stanza
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoTokenizer
import matplotlib.pyplot as plt

def agg(input):
    return torch.sum(input, dim=1, keepdim=True) / input.size(1)

if __name__ == '__main__':
    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')  #, use_gpu=False)
    device = torch.device("cuda:1")
    hidden_size = 1024
    max_batch_size = 8
    plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(device)
    plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
    config = plm_model.config
    raw_question_toks = ['How', 'many', 'flights', 'depart', 'from', 'Jackson']
    column_names = ['*', 'airline id', 'airline name', 'abbreviation', 'country', 'city', 'airport code', 'airport name', 'country', 'country abbrev', 'airline', 'flight number', 'source airport', 'destination airport']
    table_names = ['airlines', 'airports', 'flights']
    question = " ".join(raw_question_toks)
    doc = nlp(question)
    raw_question_toks = [w.lemma.lower() for s in doc.sentences for w in s.words]

    print('Q', len(raw_question_toks), raw_question_toks)
    print('C', len(column_names), column_names)
    print('T', len(table_names), table_names)

    question_id = [plm_tokenizer.cls_token_id]
    question = [q.lower() for q in raw_question_toks]
    question_subword_len = []
    for w in question:
        toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
        question_id.extend(toks)
        question_subword_len.append(len(toks))
    question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
    exact_question_token = len(question_id) - 1
    question_id.append(plm_tokenizer.sep_token_id)

    masked_question_id = [question_id]

    start = 1

    for i, sub_len in enumerate(question_subword_len):
        tmp_question_id = question_id.copy()
        for m in range(start, start + sub_len):
            tmp_question_id[m] = plm_tokenizer.mask_token_id
        masked_question_id.append(tmp_question_id)
        start += sub_len

    table = [t.lower().split() for t in table_names]
    table_id, table_mask_plm, table_subword_len = [], [], []
    table_word_len = []
    for s in table:
        l = 0
        for w in s:
            toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
            table_id.extend(toks)
            table_subword_len.append(len(toks))
            l += len(toks)
        table_word_len.append(l)
    table_mask_plm = [1] * len(table_id)

    column = [t.lower().split() for t in column_names]
    column_id, column_mask_plm, column_subword_len = [], [], []
    column_word_len = []
    for s in column:
        l = 0
        for w in s:
            toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
            column_id.extend(toks)
            column_subword_len.append(len(toks))
            l += len(toks)
        column_word_len.append(l)
    column_mask_plm = [1] * len(column_id) + [0]
    exact_column_token = len(column_id)
    column_id.append(plm_tokenizer.sep_token_id)

    question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
    table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
    column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm

    input_id = []
    segment_id = []
    atten_mask = []
    for i, msk_q_id in enumerate(masked_question_id):
        input_id.append(msk_q_id + table_id + column_id)
        segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
        atten_mask.append([1] * len(input_id[-1]))

    start = 0
    total_size = len(input_id)
    #print(total_size)
    store_arr = []
    if total_size <= max_batch_size:
        ii = torch.tensor(input_id, dtype=torch.long, device=device)
        im = torch.tensor(atten_mask, dtype=torch.float, device=device)
        si = torch.tensor(segment_id, dtype=torch.long, device=device)
        outputs = plm_model(ii, im)[0].squeeze()
        store_arr.append(outputs)
    else:
        while start < len(input_id):
            if start + max_batch_size <= len(input_id):
                ii = torch.tensor(input_id[start: start + max_batch_size], dtype=torch.long, device=device)
                im = torch.tensor(atten_mask[start: start + max_batch_size], dtype=torch.float, device=device)
                si = torch.tensor(segment_id[start: start + max_batch_size], dtype=torch.long, device=device)
                outputs = plm_model(ii, im)[0]  #.squeeze()
                store_arr.append(outputs)
            else:
                ii = torch.tensor(input_id[start: len(input_id)], dtype=torch.long, device=device)
                im = torch.tensor(atten_mask[start: len(input_id)], dtype=torch.float, device=device)
                si = torch.tensor(segment_id[start: len(input_id)], dtype=torch.long, device=device)
                outputs = plm_model(ii, im)[0]  #.squeeze()
                store_arr.append(outputs)
            #print(outputs.size())
            start += max_batch_size
    assert len(store_arr) > 0
    if len(store_arr) == 1:
        outputs = store_arr[0]
    else:
        outputs = store_arr[0]
        #print(outputs.size())
        for t in store_arr[1:]:
            #print(t.size())
            outputs = torch.cat((outputs, t), dim=0)
    q_tab_mat = outputs.new_zeros(len(raw_question_toks), len(table_names))

    old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0), 1, 1)).view(outputs.size(0), len(table_id), hidden_size)

    start = 0
    new_table_arr = []
    maybe_use = []
    for i, sub_len in enumerate(table_word_len):
        curr = old_tables[:, start:start + sub_len]
        new_table_arr.append(agg(curr))
        maybe_use.append(curr)
        start += sub_len
    new_tables = torch.cat(new_table_arr, 1)
    tbl_cmp = new_tables[0:1]
    tbl_msk = new_tables[1:]
    assert tbl_msk.size(0) == len(raw_question_toks)
    for i in range(len(table_word_len)):
        a = tbl_cmp[:, i]  # 1, hidden
        b = tbl_msk[:, i]  # qu_len, hidden
        # maybe_use[i]: qu_len, sub_len, hidden
        dis = F.pairwise_distance(a, b, p=2)
        q_tab_mat[:, i] = dis
        # for m in range(maybe_use[i].size(1)):
        #     tmp = maybe_use[i][1:, m]
        #     pre_dis = F.pairwise_distance(a, tmp, p=2)  # qu_len
        #     for k in range(pre_dis.size(0)):
        #         if pre_dis[k] > q_tab_mat[k, i]:
        #             q_tab_mat[k, i] = pre_dis[k]

        #q_tab_mat[:, i] = pre_dis
    #print(q_tab_mat.transpose(0,1).size())
    #print(q_tab_mat.transpose(0,1).cpu().detach().numpy())
    print(q_tab_mat.transpose(0, 1))

    q_col_mat = outputs.new_zeros(len(raw_question_toks), len(column_names))

    old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0), 1, 1)).view(outputs.size(0), exact_column_token, hidden_size)
    new_column_arr = []
    start = 0
    maybe_use = []
    for i, sub_len in enumerate(column_word_len):
        curr = old_columns[:, start:start + sub_len]
        new_column_arr.append(agg(curr))
        maybe_use.append(curr)
        start += sub_len
    new_column = torch.cat(new_column_arr, 1)

    col_cmp = new_column[0:1]
    col_msk = new_column[1:]
    assert col_msk.size(0) == len(raw_question_toks)
    for i in range(len(column_word_len)):
        a = col_cmp[:, i]
        b = col_msk[:, i]
        dis = F.pairwise_distance(a, b, p=2)
        q_col_mat[:, i] = dis
        # for m in range(maybe_use[i].size(1)):
        #     tmp = maybe_use[i][1:, m]
        #     pre_dis = F.pairwise_distance(a, tmp, p=2)  # qu_len
        #     for k in range(pre_dis.size(0)):
        #         if pre_dis[k] > q_col_mat[k, i]:
        #             q_col_mat[k, i] = pre_dis[k]
        #print(dis)
    #print(q_col_mat.transpose(0,1).size())
    #print(q_col_mat.transpose(0,1).cpu().detach().numpy())
    print(q_col_mat.transpose(0, 1))

    use_matrix = torch.cat([q_tab_mat, q_col_mat], dim=1)
    matrix_min = torch.min(use_matrix)
    #print(matrix_min)
    matrix_max = torch.max(use_matrix)
    #print(matrix_max)
    matrix_mean = torch.mean(use_matrix)
    #print(matrix_mean)
    matrix_var = torch.std(use_matrix)  # std for the commented z-score below; the draft computed sqrt(mean)
    #print(matrix_var)
    use_matrix = (use_matrix - matrix_min) / (matrix_max - matrix_min)
    #use_matrix = (use_matrix - matrix_mean) / matrix_var
    #print(use_matrix.size())
    #print(use_matrix)

    # vegetables = table_names + column_names
    # farmers = raw_question_toks

    vegetables = raw_question_toks
    farmers = table_names + column_names

    harvest = np.around(use_matrix.cpu().detach().numpy(), 3)
    print()
    fig, ax = plt.subplots(figsize=(25, 25))
    im = ax.imshow(harvest, cmap="YlGn")
    #threshold = im.norm(use_matrix.max())/2
    # We want to show all ticks ...
    ax.set_xticks(np.arange(len(farmers)))
    ax.set_yticks(np.arange(len(vegetables)))
    # ... and label them with the respective list entries
    ax.set_xticklabels(farmers)
    ax.set_yticklabels(vegetables)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(len(vegetables)):
        for j in range(len(farmers)):
            text = ax.text(j, i, harvest[i, j],
                           ha="center", va="center", color="w")

    ax.set_title("Probing Matrix")
    fig.tight_layout()
    plt.show()
    plt.savefig('1.jpg')

@@ -0,0 +1,269 @@
#coding=utf8
import os
import numpy as np
import torch
import stanza
import torch.nn.functional as F
from itertools import product, combinations
from transformers import AutoModel, AutoConfig, AutoTokenizer
import matplotlib.pyplot as plt

def agg(input):
    return torch.sum(input, dim=1, keepdim=True) / input.size(1)

if __name__ == '__main__':
    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')  #, use_gpu=False)
    device = torch.device("cuda:0")
    hidden_size = 1024
    max_batch_size = 8
    plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'grappa_large_jnt')).to(device)
    plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'grappa_large_jnt'))
    config = plm_model.config
    raw_question_toks = ['list', 'the', 'most', 'common', 'result', 'of', 'the', 'musicals', '.']
    column_names = ['*', 'musical id', 'name', 'year', 'award', 'category', 'nominee', 'result', 'actor id', 'name', 'musical id', 'character', 'duration', 'age']
    table_names = ['musical', 'actor']
    question = " ".join(raw_question_toks)
    doc = nlp(question)
    raw_question_toks = [w.lemma.lower() for s in doc.sentences for w in s.words]

    q_num, t_num, c_num = len(raw_question_toks), len(table_names), len(column_names)
    t_q_dict = {}
    c_q_dict = {}
    max_len = max([len(t.split()) for t in table_names])
    index_pairs = list(filter(lambda x: x[1] - x[0] <= max_len, combinations(range(q_num + 1), 2)))
    index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
    for i, j in index_pairs:
        phrase = ' '.join(raw_question_toks[i: j])
        for idx, name in enumerate(table_names):
            if phrase == name:
                if idx in t_q_dict.keys():
                    t_q_dict[idx].append([phrase, i, j])
                else:
                    t_q_dict[idx] = [[phrase, i, j]]
            elif (j - i == 1 and phrase in name.split()) or (j - i > 1 and phrase in name):
                if idx in t_q_dict.keys():
                    t_q_dict[idx].append([phrase, i, j])
                else:
                    t_q_dict[idx] = [[phrase, i, j]]
    print(t_q_dict)
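
    # Illustrative note (editor's addition): the n-gram matcher above maps each
    # table index to its matched question spans, e.g. roughly
    # {0: [['musical', 7, 8]]} when the lemma 'musical' (table 0) appears as
    # question token 7; exact indices depend on the stanza lemmatization.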

    print('Q', len(raw_question_toks), raw_question_toks)
    print('C', len(column_names), column_names)
    print('T', len(table_names), table_names)

    question_id = [plm_tokenizer.cls_token_id]
    question = [q.lower() for q in raw_question_toks]
    question_subword_len = []
    for w in question:
        toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
        question_id.extend(toks)
        question_subword_len.append(len(toks))
    question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
    exact_question_token = len(question_id) - 1
    question_id.append(plm_tokenizer.sep_token_id)

    masked_question_id = [question_id]

    start = 1

    for i, sub_len in enumerate(question_subword_len):
        tmp_question_id = question_id.copy()
        for m in range(start, start + sub_len):
            tmp_question_id[m] = plm_tokenizer.mask_token_id
        masked_question_id.append(tmp_question_id)
        start += sub_len

    table = [t.lower().split() for t in table_names]
    table_id, table_mask_plm, table_subword_len = [], [], []
    table_word_len = []
    for s in table:
        l = 0
        for w in s:
            toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
            table_id.extend(toks)
            table_subword_len.append(len(toks))
            l += len(toks)
        table_word_len.append(l)
    table_mask_plm = [1] * len(table_id)

    column = [t.lower().split() for t in column_names]
    column_id, column_mask_plm, column_subword_len = [], [], []
    column_word_len = []
    for s in column:
        l = 0
        for w in s:
            toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
            column_id.extend(toks)
            column_subword_len.append(len(toks))
            l += len(toks)
        column_word_len.append(l)
    column_mask_plm = [1] * len(column_id) + [0]
    exact_column_token = len(column_id)
    column_id.append(plm_tokenizer.sep_token_id)

    question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
    table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
    column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm

    input_id = []
    segment_id = []
    atten_mask = []
    for i, msk_q_id in enumerate(masked_question_id):
        input_id.append(msk_q_id + table_id + column_id)
        segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
        atten_mask.append([1] * len(input_id[-1]))

    start = 0
    total_size = len(input_id)
    #print(total_size)
    store_arr = []
    if total_size <= max_batch_size:
        ii = torch.tensor(input_id, dtype=torch.long, device=device)
        im = torch.tensor(atten_mask, dtype=torch.float, device=device)
        si = torch.tensor(segment_id, dtype=torch.long, device=device)
        outputs = plm_model(ii, im)[0].squeeze()
        store_arr.append(outputs)
    else:
        while start < len(input_id):
            if start + max_batch_size <= len(input_id):
                ii = torch.tensor(input_id[start: start + max_batch_size], dtype=torch.long, device=device)
                im = torch.tensor(atten_mask[start: start + max_batch_size], dtype=torch.float, device=device)
                si = torch.tensor(segment_id[start: start + max_batch_size], dtype=torch.long, device=device)
                outputs = plm_model(ii, im)[0]  #.squeeze()
                store_arr.append(outputs)
            else:
                ii = torch.tensor(input_id[start: len(input_id)], dtype=torch.long, device=device)
                im = torch.tensor(atten_mask[start: len(input_id)], dtype=torch.float, device=device)
                si = torch.tensor(segment_id[start: len(input_id)], dtype=torch.long, device=device)
                outputs = plm_model(ii, im)[0]  #.squeeze()
                store_arr.append(outputs)
            #print(outputs.size())
            start += max_batch_size
    assert len(store_arr) > 0
    if len(store_arr) == 1:
        outputs = store_arr[0]
    else:
        outputs = store_arr[0]
        #print(outputs.size())
        for t in store_arr[1:]:
            #print(t.size())
            outputs = torch.cat((outputs, t), dim=0)
    q_tab_mat = outputs.new_zeros(len(raw_question_toks), len(table_names))

    old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0), 1, 1)).view(outputs.size(0), len(table_id), hidden_size)

    start = 0
    new_table_arr = []
    maybe_use = []
    for i, sub_len in enumerate(table_word_len):
        curr = old_tables[:, start:start + sub_len]
        new_table_arr.append(agg(curr))
        maybe_use.append(curr)
        start += sub_len
    new_tables = torch.cat(new_table_arr, 1)
    tbl_cmp = new_tables[0:1]
    tbl_msk = new_tables[1:]
    assert tbl_msk.size(0) == len(raw_question_toks)
    for i in range(len(table_word_len)):
        a = tbl_cmp[:, i]  # 1, hidden
        b = tbl_msk[:, i]  # qu_len, hidden
        # maybe_use[i]: qu_len, sub_len, hidden
        dis = F.pairwise_distance(a, b, p=2)
        q_tab_mat[:, i] = dis
        # for m in range(maybe_use[i].size(1)):
        #     tmp = maybe_use[i][1:, m]
        #     pre_dis = F.pairwise_distance(a, tmp, p=2)  # qu_len
        #     for k in range(pre_dis.size(0)):
        #         if pre_dis[k] > q_tab_mat[k, i]:
        #             q_tab_mat[k, i] = pre_dis[k]

        #q_tab_mat[:, i] = pre_dis
    #print(q_tab_mat.transpose(0,1).size())
    #print(q_tab_mat.transpose(0,1).cpu().detach().numpy())
    print(q_tab_mat.transpose(0, 1))

    q_col_mat = outputs.new_zeros(len(raw_question_toks), len(column_names))

    old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0), 1, 1)).view(outputs.size(0), exact_column_token, hidden_size)
    new_column_arr = []
    start = 0
    maybe_use = []
    for i, sub_len in enumerate(column_word_len):
        curr = old_columns[:, start:start + sub_len]
        new_column_arr.append(agg(curr))
        maybe_use.append(curr)
        start += sub_len
    new_column = torch.cat(new_column_arr, 1)

    col_cmp = new_column[0:1]
    col_msk = new_column[1:]
    assert col_msk.size(0) == len(raw_question_toks)
    for i in range(len(column_word_len)):
        a = col_cmp[:, i]
        b = col_msk[:, i]
        dis = F.pairwise_distance(a, b, p=2)
        q_col_mat[:, i] = dis
        # for m in range(maybe_use[i].size(1)):
        #     tmp = maybe_use[i][1:, m]
        #     pre_dis = F.pairwise_distance(a, tmp, p=2)  # qu_len
        #     for k in range(pre_dis.size(0)):
        #         if pre_dis[k] > q_col_mat[k, i]:
        #             q_col_mat[k, i] = pre_dis[k]
        #print(dis)
    #print(q_col_mat.transpose(0,1).size())
    #print(q_col_mat.transpose(0,1).cpu().detach().numpy())
    print(q_col_mat.transpose(0, 1))

    use_matrix = torch.cat([q_tab_mat, q_col_mat], dim=1)
    matrix_min = torch.min(use_matrix)
    #print(matrix_min)
    matrix_max = torch.max(use_matrix)
    #print(matrix_max)
    matrix_mean = torch.mean(use_matrix)
    #print(matrix_mean)
    matrix_var = torch.std(use_matrix)  # std for the commented z-score below; the draft computed sqrt(mean)
    #print(matrix_var)
    use_matrix = (use_matrix - matrix_min) / (matrix_max - matrix_min)
    #use_matrix = (use_matrix - matrix_mean) / matrix_var
    #print(use_matrix.size())
    #print(use_matrix)

    # vegetables = table_names + column_names
    # farmers = raw_question_toks

    vegetables = raw_question_toks
    farmers = table_names + column_names

    harvest = np.around(use_matrix.cpu().detach().numpy(), 3)
    print()
    fig, ax = plt.subplots(figsize=(15, 15))
    im = ax.imshow(harvest, cmap="YlGn")
    #threshold = im.norm(use_matrix.max())/2
    # We want to show all ticks ...
    ax.set_xticks(np.arange(len(farmers)))
    ax.set_yticks(np.arange(len(vegetables)))
    # ... and label them with the respective list entries
    ax.set_xticklabels(farmers)
    ax.set_yticklabels(vegetables)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(len(vegetables)):
        for j in range(len(farmers)):
            text = ax.text(j, i, harvest[i, j],
                           ha="center", va="center", color="w")

    ax.set_title("Probing Matrix")
    fig.tight_layout()
    plt.show()
    plt.savefig('1.jpg')

@@ -0,0 +1,223 @@
#coding=utf8
import os
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoTokenizer
import matplotlib.pyplot as plt

def agg(input):
    return torch.sum(input, dim=1, keepdim=True) / input.size(1)

if __name__ == '__main__':

    device = torch.device("cuda:0")
    hidden_size = 1024  # electra-large hidden size; the draft's 512 breaks the .view calls below
    max_batch_size = 8
    plm_model = AutoModel.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator')).to(device)
    plm_tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', 'electra-large-discriminator'))
    config = plm_model.config
    raw_question_toks = ['list', 'the', 'most', 'common', 'result', 'of', 'the', 'musicals', '.']
    column_names = ['*', 'musical id', 'name', 'year', 'award', 'category', 'nominee', 'result', 'actor id', 'name', 'musical id', 'character', 'duration', 'age']
    table_names = ['musical', 'actor']

    print('Q', len(raw_question_toks), raw_question_toks)
    print('C', len(column_names), column_names)
    print('T', len(table_names), table_names)

    question_id = [plm_tokenizer.cls_token_id]
    question = [q.lower() for q in raw_question_toks]
    question_subword_len = []
    for w in question:
        toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
        question_id.extend(toks)
        question_subword_len.append(len(toks))
    question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]
    exact_question_token = len(question_id) - 1
    question_id.append(plm_tokenizer.sep_token_id)

    masked_question_id = [question_id]

    start = 1

    for i, sub_len in enumerate(question_subword_len):
        tmp_question_id = question_id.copy()
        for m in range(start, start + sub_len):
            tmp_question_id[m] = plm_tokenizer.mask_token_id
        masked_question_id.append(tmp_question_id)
        start += sub_len

    table = [t.lower().split() for t in table_names]
    table_id, table_mask_plm, table_subword_len = [], [], []
    table_word_len = []
    for s in table:
        l = 0
        for w in s:
            toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
            table_id.extend(toks)
            table_subword_len.append(len(toks))
            l += len(toks)
        table_word_len.append(l)
    table_mask_plm = [1] * len(table_id)

    column = [t.lower().split() for t in column_names]
    column_id, column_mask_plm, column_subword_len = [], [], []
    column_word_len = []
    for s in column:
        l = 0
        for w in s:
            toks = plm_tokenizer.convert_tokens_to_ids(plm_tokenizer.tokenize(w))
            column_id.extend(toks)
            column_subword_len.append(len(toks))
            l += len(toks)
        column_word_len.append(l)
    column_mask_plm = [1] * len(column_id) + [0]
    exact_column_token = len(column_id)
    column_id.append(plm_tokenizer.sep_token_id)

    question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
    table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
    column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm

    input_id = []
    segment_id = []
    atten_mask = []
    for i, msk_q_id in enumerate(masked_question_id):
        input_id.append(msk_q_id + table_id + column_id)
        segment_id.append([0] * len(msk_q_id) + [1] * (len(table_id) + len(column_id)))
        atten_mask.append([1] * len(input_id[-1]))

    start = 0
    total_size = len(input_id)
    #print(total_size)
    store_arr = []
    if total_size <= max_batch_size:
        ii = torch.tensor(input_id, dtype=torch.long, device=device)
        im = torch.tensor(atten_mask, dtype=torch.float, device=device)
        si = torch.tensor(segment_id, dtype=torch.long, device=device)
        outputs = plm_model(ii, im, si)[0].squeeze()
        store_arr.append(outputs)
    else:
        while start < len(input_id):
            if start + max_batch_size <= len(input_id):
                ii = torch.tensor(input_id[start: start + max_batch_size], dtype=torch.long, device=device)
                im = torch.tensor(atten_mask[start: start + max_batch_size], dtype=torch.float, device=device)
                si = torch.tensor(segment_id[start: start + max_batch_size], dtype=torch.long, device=device)
                outputs = plm_model(ii, im, si)[0]  #.squeeze()
                store_arr.append(outputs)
            else:
                ii = torch.tensor(input_id[start: len(input_id)], dtype=torch.long, device=device)
                im = torch.tensor(atten_mask[start: len(input_id)], dtype=torch.float, device=device)
                si = torch.tensor(segment_id[start: len(input_id)], dtype=torch.long, device=device)
                outputs = plm_model(ii, im, si)[0]  #.squeeze()
                store_arr.append(outputs)
            #print(outputs.size())
            start += max_batch_size
    assert len(store_arr) > 0
    if len(store_arr) == 1:
        outputs = store_arr[0]
    else:
        outputs = store_arr[0]
        print(outputs.size())
        for t in store_arr[1:]:
            print(t.size())
            outputs = torch.cat((outputs, t), dim=0)
    q_tab_mat = outputs.new_zeros(len(raw_question_toks), len(table_names))

    old_tables = outputs.masked_select(torch.tensor(table_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0), 1, 1)).view(outputs.size(0), len(table_id), hidden_size)

    start = 0
    new_table_arr = []
    for i, sub_len in enumerate(table_word_len):
        curr = old_tables[:, start:start + sub_len]
        new_table_arr.append(agg(curr))
        start += sub_len
    new_tables = torch.cat(new_table_arr, 1)
    tbl_cmp = new_tables[0:1]
    tbl_msk = new_tables[1:]
    assert tbl_msk.size(0) == len(raw_question_toks)
    for i in range(len(table_word_len)):
        a = tbl_cmp[:, i]
        b = tbl_msk[:, i]
        dis = F.pairwise_distance(a, b, p=2)
        q_tab_mat[:, i] = dis
    #print(q_tab_mat.transpose(0,1).size())
    #print(q_tab_mat.transpose(0,1).cpu().detach().numpy())
    print(q_tab_mat.transpose(0, 1))

    q_col_mat = outputs.new_zeros(len(raw_question_toks), len(column_names))

    old_columns = outputs.masked_select(torch.tensor(column_mask_plm, dtype=torch.bool, device=device).unsqueeze(-1).unsqueeze(0).repeat(outputs.size(0), 1, 1)).view(outputs.size(0), exact_column_token, hidden_size)
    new_column_arr = []
    start = 0
    for i, sub_len in enumerate(column_word_len):
        curr = old_columns[:, start:start + sub_len]
        new_column_arr.append(agg(curr))
        start += sub_len
    new_column = torch.cat(new_column_arr, 1)

    col_cmp = new_column[0:1]
    col_msk = new_column[1:]
    assert col_msk.size(0) == len(raw_question_toks)
    for i in range(len(column_word_len)):
        a = col_cmp[:, i]
        b = col_msk[:, i]
        dis = F.pairwise_distance(a, b, p=2)
        q_col_mat[:, i] = dis
        #print(dis)
    #print(q_col_mat.transpose(0,1).size())
    #print(q_col_mat.transpose(0,1).cpu().detach().numpy())
    print(q_col_mat.transpose(0, 1))

    use_matrix = torch.cat([q_tab_mat, q_col_mat], dim=1)
    matrix_min = torch.min(use_matrix)
    print(matrix_min)
    matrix_max = torch.max(use_matrix)
    print(matrix_max)
    matrix_mean = torch.mean(use_matrix)
    print(matrix_mean)
    matrix_var = torch.std(use_matrix)  # std for the commented z-score below; the draft computed sqrt(mean)
    print(matrix_var)
    use_matrix = (use_matrix - matrix_min) / (matrix_max - matrix_min)
    #use_matrix = (use_matrix - matrix_mean) / matrix_var
    print(use_matrix.size())
    print(use_matrix)

    # vegetables = table_names + column_names
    # farmers = raw_question_toks

    vegetables = raw_question_toks
    farmers = table_names + column_names

    harvest = np.around(use_matrix.cpu().detach().numpy(), 3)

    fig, ax = plt.subplots(figsize=(15, 15))
    im = ax.imshow(harvest, cmap="YlGn")
    #threshold = im.norm(use_matrix.max())/2
    # We want to show all ticks ...
    ax.set_xticks(np.arange(len(farmers)))
    ax.set_yticks(np.arange(len(vegetables)))
    # ... and label them with the respective list entries
    ax.set_xticklabels(farmers)
    ax.set_yticklabels(vegetables)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(len(vegetables)):
        for j in range(len(farmers)):
            text = ax.text(j, i, harvest[i, j],
                           ha="center", va="center", color="w")

    ax.set_title("Probing Matrix")
    fig.tight_layout()
    plt.show()
    plt.savefig('1.jpg')

@@ -0,0 +1,665 @@
################################
# Assumptions:
#   1. sql is correct
#   2. only table name has alias
#   3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
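
# Illustrative example (editor's addition, not in the original file): with the
# op tables below ('>' is WHERE_OPS index 3; agg/unit op 0 means 'none'),
# "SELECT name FROM singer WHERE age > 20" parses to roughly:
# {
#   'select': (False, [(0, (0, (0, '__singer.name__', False), None))]),
#   'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
#   'where': [(False, 3, (0, (0, '__singer.age__', False), None), 20.0, None)],
#   'groupBy': [], 'orderBy': [], 'having': [], 'limit': None,
#   'intersect': None, 'except': None, 'union': None
# }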

import json
import sqlite3
from nltk import word_tokenize

CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')


class Schema:
    """
    Simple schema which maps table&column to a unique identifier
    """
    def __init__(self, schema):
        self._schema = schema
        self._idMap = self._map(self._schema)

    @property
    def schema(self):  # map lowercased raw tab name to lowercased raw col name list
        return self._schema

    @property
    def idMap(self):  # map tab name to __tab__, tab.col to __tab.col__, * to __all__, all lowercased
        return self._idMap

    def _map(self, schema):
        idMap = {'*': "__all__"}
        id = 1
        for key, vals in schema.items():
            for val in vals:
                idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
                id += 1

        for key in schema:
            idMap[key.lower()] = "__" + key.lower() + "__"
            id += 1

        return idMap


def get_schema(db):
    """
    Get database's schema, which is a dict with table name as key
    and list of column names as value
    :param db: database path
    :return: schema dict
    """

    schema = {}
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    # fetch table names
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [str(table[0].lower()) for table in cursor.fetchall()]

    # fetch table info
    for table in tables:
        cursor.execute("PRAGMA table_info({})".format(table))
        schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]

    return schema


def get_schema_from_json(fpath):
    with open(fpath) as f:
        data = json.load(f)

    schema = {}
    for entry in data:
        table = str(entry['table'].lower())
        cols = [str(col['column_name'].lower()) for col in entry['col_data']]
        schema[table] = cols

    return schema


def tokenize(string):
    string = str(string)
    string = string.replace("\'", "\"")  # ensure all string values are wrapped in double quotes (brittle if a value itself contains quotes)
    quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
    assert len(quote_idxs) % 2 == 0, "Unexpected quote"

    # keep string value as token
    vals = {}
    for i in range(len(quote_idxs)-1, -1, -2):
        qidx1 = quote_idxs[i-1]
        qidx2 = quote_idxs[i]
        val = string[qidx1: qidx2+1]
        key = "__val_{}_{}__".format(qidx1, qidx2)
        string = string[:qidx1] + key + string[qidx2+1:]
        vals[key] = val

    toks = [word.lower() for word in word_tokenize(string)]
    # replace with string value token
    for i in range(len(toks)):
        if toks[i] in vals:
            toks[i] = vals[toks[i]]

    # find if there exist !=, >=, <=
    eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
    eq_idxs.reverse()
    prefix = ('!', '>', '<')
    for eq_idx in eq_idxs:
        pre_tok = toks[eq_idx-1]
        if pre_tok in prefix:
            toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1:]

    return toks
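
# Illustrative example (editor's addition): tokenize keeps quoted values as
# single tokens and re-merges comparison operators, e.g.
#   tokenize('SELECT name FROM singer WHERE age != "20"')
#   -> ['select', 'name', 'from', 'singer', 'where', 'age', '!=', '"20"']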


def scan_alias(toks):
    """Scan the index of 'as' and build the map for all aliases"""
    as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
    alias = {}
    for idx in as_idxs:
        alias[toks[idx+1]] = toks[idx-1]  # a later alias silently overwrites an earlier one with the same name
    return alias

def check_contradiction(toks):
    as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
    alias = {}
    for idx in as_idxs:
        a = toks[idx+1]
        if a in alias and alias[a] != toks[idx-1]:
            return True
        alias[a] = toks[idx-1]
    return False

def toks2nested(toks):
    """
    Determine the scope for each sub-sql,
    mapping [select, count, (, c1, ), from, (, select c1, c2, from, t, ), ... ] into
    [select, count, (, c1, ), from, [select, c1, c2, from, t], ... ]
    """
    def detect_sql(idx):
        count, sql_list = 0, []
        while idx < len(toks):
            if toks[idx] == '(':
                count += 1
                if toks[idx + 1] == 'select':
                    sub_sql_list, idx = detect_sql(idx + 1)
                    count -= 1
                    sql_list.append(sub_sql_list)
                else:
                    sql_list.append('(')
            elif toks[idx] == ')':
                count -= 1
                if count < 0:
                    return sql_list, idx
                else:
                    sql_list.append(')')
            else:
                sql_list.append(toks[idx])
            idx += 1
        return sql_list, idx

    def intersect_union_except(tok_list):
        for idx, tok in enumerate(tok_list):
            if type(tok) == list:
                new_tok = intersect_union_except(tok)
                tok_list[idx] = new_tok
        for op in ['intersect', 'union', 'except']:
            if op in tok_list:
                idx = tok_list.index(op)
                tok_list = [tok_list[:idx]] + [op] + [tok_list[idx+1:]]
                break
        return tok_list

    try:
        nested_toks, _ = detect_sql(0)
        # not all sqls are wrapped with (), e.g. sql1 intersect sql2;
        # add a wrapper for each sql on the left and right hand side of intersect/union/except
        nested_toks = intersect_union_except(nested_toks)
        return nested_toks
    except:
        print('Something unknown happened when transforming %s' % (' '.join(toks)))
        return None

def reassign_table_alias(nested_toks, index):
    current_map = {}  # map old alias in the current sql to new alias in global map
    as_idxs = [idx for idx, tok in enumerate(nested_toks) if tok == 'as']
    for idx in as_idxs:
        index += 1  # add 1 to global index for table alias before assignment
        assert nested_toks[idx+1] not in current_map
        current_map[nested_toks[idx+1]] = 't' + str(index)
        nested_toks[idx+1] = 't' + str(index)
    for j, tok in enumerate(nested_toks):
        if type(tok) == list:
            new_tok, index = reassign_table_alias(tok, index)
            nested_toks[j] = new_tok
        elif '.' in tok:
            for alias in current_map.keys():
                if tok.startswith(alias + '.'):
                    nested_toks[j] = current_map[alias] + '.' + tok[tok.index('.')+1:]
                    break
    return nested_toks, index

def normalize_table_alias(toks):
    """ Make sure that different table aliases are assigned to different tables """
    if toks.count('select') == 1:
        return toks  # no nested sql, nothing to normalize
    elif toks.count('select') > 1:
        flag = check_contradiction(toks)
        if flag:  # avoid an unnecessary normalization pass
            nested_toks = toks2nested(toks)
            index = 0  # global index for table alias
            nested_toks, _ = reassign_table_alias(nested_toks, index)
            def flatten(x):
                if type(x) == list and ('intersect' in x or 'union' in x or 'except' in x):
                    assert len(x) == 3  # sql1 union sql2
                    return ['('] + flatten(x[0])[1:-1] + [x[1]] + flatten(x[2])[1:-1] + [')']
                elif type(x) == list:
                    return ['('] + [y for l in x for y in flatten(l)] + [')']
                else:
                    return [x]
            toks = flatten(nested_toks)[1:-1]
        return toks
    else:
        raise ValueError('Something wrong in sql because no select clause is found!')


def get_tables_with_alias(schema, toks):
    toks = normalize_table_alias(toks)
    tables = scan_alias(toks)
    for key in schema:
        assert key not in tables, "Alias {} has the same name in table".format(key)
        tables[key] = key
    return tables, toks
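
# Illustrative example (editor's addition): for
#   toks = tokenize("SELECT t1.name FROM singer AS t1")
# and a schema containing table 'singer', get_tables_with_alias returns
#   ({'t1': 'singer', 'singer': 'singer'}, toks)
# with every other schema table likewise mapped to itself.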
|
||||
|
||||
|
||||
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, column id
|
||||
"""
|
||||
tok = toks[start_idx]
|
||||
if tok == "*":
|
||||
return start_idx + 1, schema.idMap[tok]
|
||||
|
||||
if '.' in tok: # if token is a composite
|
||||
alias, col = tok.split('.')
|
||||
key = tables_with_alias[alias] + "." + col
|
||||
return start_idx+1, schema.idMap[key]
|
||||
|
||||
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
|
||||
|
||||
for alias in default_tables:
|
||||
table = tables_with_alias[alias]
|
||||
if tok in schema.schema[table]:
|
||||
key = table + "." + tok
|
||||
return start_idx+1, schema.idMap[key]
|
||||
|
||||
assert False, "Error col: {}".format(tok)
|
||||
|
||||
|
||||
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, (agg_op id, col_id)
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
isDistinct = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] in AGG_OPS:
|
||||
agg_id = AGG_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
assert idx < len_ and toks[idx] == '('
|
||||
idx += 1
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
assert idx < len_ and toks[idx] == ')'
|
||||
idx += 1
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
agg_id = AGG_OPS.index("none")
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
|
||||
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
col_unit1 = None
|
||||
col_unit2 = None
|
||||
unit_op = UNIT_OPS.index('none')
|
||||
|
||||
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
if idx < len_ and toks[idx] in UNIT_OPS:
|
||||
unit_op = UNIT_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (unit_op, col_unit1, col_unit2)
|
||||
|
||||
|
||||
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
|
||||
"""
|
||||
:returns next idx, table id, table name
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
key = tables_with_alias[toks[idx]]
|
||||
|
||||
if idx + 1 < len_ and toks[idx+1] == "as":
|
||||
idx += 3
|
||||
else:
|
||||
idx += 1
|
||||
|
||||
return idx, schema.idMap[key], key
|
||||
|
||||
|
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
        :returns next idx, value (nested sql dict, string, float or col_unit)
    """
    idx = start_idx
    len_ = len(toks)

    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] == 'select':
        idx, val = parse_sql(toks, idx, tables_with_alias, schema)
    elif "\"" in toks[idx]:  # token is a string value
        val = toks[idx]
        idx += 1
    else:
        try:
            val = float(toks[idx])
            idx += 1
        except ValueError:  # not a number: fall back to parsing a column unit
            end_idx = idx
            while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')' \
                    and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
                end_idx += 1

            idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
            idx = end_idx

    if isBlock:
        assert toks[idx] == ')'
        idx += 1

    return idx, val


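# Illustrative note (editor comment): a value is, in order of the checks above,
# a nested query (starting with 'select'), a quoted string token such as
# '"usa"', a number parsed by float(), or, as a fallback, a column unit
# (e.g. the right-hand side of t1.id = t2.id).

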
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
        :returns next idx, list of condition tuples interleaved with 'and'/'or'
    """
    idx = start_idx
    len_ = len(toks)
    conds = []

    while idx < len_:
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        not_op = False
        if toks[idx] == 'not':
            not_op = True
            idx += 1

        assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
        op_id = WHERE_OPS.index(toks[idx])
        idx += 1
        val1 = val2 = None
        if op_id == WHERE_OPS.index('between'):  # between..and... special case: dual values
            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            assert toks[idx] == 'and'
            idx += 1
            idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
        else:  # normal case: single value
            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            val2 = None

        conds.append((not_op, op_id, val_unit, val1, val2))

        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
            break

        if idx < len_ and toks[idx] in COND_OPS:
            conds.append(toks[idx])
            idx += 1  # skip and/or

    return idx, conds


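# Illustrative note (editor comment): toks = ['age', 'between', '20', 'and',
# '30'] produces a single tuple (False, WHERE_OPS.index('between'), val_unit,
# 20.0, 30.0); connector tokens from COND_OPS ('and'/'or') between conditions
# are appended to the conds list verbatim.

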
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
        :returns next idx, (isDistinct, [(agg_id, val_unit), ...])
    """
    idx = start_idx
    len_ = len(toks)

    assert toks[idx] == 'select', "'select' not found"
    idx += 1
    isDistinct = False
    if idx < len_ and toks[idx] == 'distinct':
        idx += 1
        isDistinct = True
    val_units = []

    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
        agg_id = AGG_OPS.index("none")
        if toks[idx] in AGG_OPS:
            agg_id = AGG_OPS.index(toks[idx])
            idx += 1
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append((agg_id, val_unit))
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','

    return idx, (isDistinct, val_units)


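# Illustrative note (editor comment): 'select distinct count ( * ) , name'
# parses to (True, [(AGG_OPS.index('count'), val_unit of *),
# (AGG_OPS.index('none'), val_unit of name)]): one (agg_id, val_unit) pair is
# collected per output column, with the leading aggregator captured here.

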
def parse_from(toks, start_idx, tables_with_alias, schema):
    """
        Assume in the from clause, all table units are combined with join
    """
    assert 'from' in toks[start_idx:], "'from' not found"

    len_ = len(toks)
    idx = toks.index('from', start_idx) + 1
    default_tables = []
    table_units = []
    conds = []

    while idx < len_:
        isBlock = False
        if toks[idx] == '(':
            isBlock = True
            idx += 1

        if toks[idx] == 'select':
            idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['sql'], sql))
        else:
            if idx < len_ and toks[idx] == 'join':
                idx += 1  # skip join
            idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['table_unit'], table_unit))
            default_tables.append(table_name)
        if idx < len_ and toks[idx] == "on":
            idx += 1  # skip on
            idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
            if len(conds) > 0:
                conds.append('and')
            conds.extend(this_conds)

        if isBlock:
            assert toks[idx] == ')'
            idx += 1
        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
            break

    return idx, table_units, conds, default_tables


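# Illustrative note (editor comment): 'from concert join stadium on
# concert.stadium_id = stadium.stadium_id' yields two
# (TABLE_TYPE['table_unit'], table_id) entries, the join predicate in conds,
# and default_tables = ['concert', 'stadium'], which later resolves
# unqualified column names.

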
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)

    if idx >= len_ or toks[idx] != 'where':
        return idx, []

    idx += 1
    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx, conds


def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    col_units = []

    if idx >= len_ or toks[idx] != 'group':
        return idx, col_units

    idx += 1
    assert toks[idx] == 'by'
    idx += 1

    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
        col_units.append(col_unit)
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break

    return idx, col_units


def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    val_units = []
    order_type = 'asc'  # default type is 'asc'

    if idx >= len_ or toks[idx] != 'order':
        return idx, val_units

    idx += 1
    assert toks[idx] == 'by'
    idx += 1

    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append(val_unit)
        if idx < len_ and toks[idx] in ORDER_OPS:
            order_type = toks[idx]
            idx += 1
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break

    return idx, (order_type, val_units)


def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)

    if idx >= len_ or toks[idx] != 'having':
        return idx, []

    idx += 1
    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx, conds


def parse_limit(toks, start_idx):
    idx = start_idx
    len_ = len(toks)

    if idx < len_ and toks[idx] == 'limit':
        idx += 2  # consume the 'limit' keyword and its numeric argument
        return idx, int(toks[idx-1])

    return idx, None


def parse_sql(toks, start_idx, tables_with_alias, schema):
    """
        :returns next idx, sql dict covering all clauses
    """
    isBlock = False  # indicate whether this is a block of sql/sub-sql
    len_ = len(toks)
    idx = start_idx

    sql = {}
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    # parse from clause in order to get default tables
    from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
    sql['from'] = {'table_units': table_units, 'conds': conds}
    # select clause
    _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
    idx = from_end_idx
    sql['select'] = select_col_units
    # where clause
    idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
    sql['where'] = where_conds
    # group by clause
    idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
    sql['groupBy'] = group_col_units
    # having clause
    idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
    sql['having'] = having_conds
    # order by clause
    idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
    sql['orderBy'] = order_col_units
    # limit clause
    idx, limit_val = parse_limit(toks, idx)
    sql['limit'] = limit_val

    idx = skip_semicolon(toks, idx)
    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'
    idx = skip_semicolon(toks, idx)

    # intersect/union/except clause
    for op in SQL_OPS:  # initialize IUE
        sql[op] = None
    if idx < len_ and toks[idx] in SQL_OPS:
        sql_op = toks[idx]
        idx += 1
        idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
        sql[sql_op] = IUE_sql
    return idx, sql


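# Illustrative note (editor comment): the returned dict always carries the
# keys 'select', 'from', 'where', 'groupBy', 'having', 'orderBy' and 'limit',
# plus one key per set operator in SQL_OPS (each None when that clause is
# absent), so downstream evaluation code can index it unconditionally.

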
def load_data(fpath):
    with open(fpath) as f:
        data = json.load(f)
    return data


def get_sql(schema, query):
    toks = tokenize(query)
    tables_with_alias, toks = get_tables_with_alias(schema.schema, toks)
    _, sql = parse_sql(toks, 0, tables_with_alias, schema)
    return sql


def skip_semicolon(toks, start_idx):
    idx = start_idx
    while idx < len(toks) and toks[idx] == ";":
        idx += 1
    return idx


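# Minimal usage sketch (editor addition, not part of the original script). It
# assumes the Schema class and tokenize helper defined earlier in this file:
# Schema wraps a {table: [columns]} dict and exposes the idMap used above.
if __name__ == "__main__":
    toy_schema = Schema({"singer": ["singer_id", "name", "country"]})
    # Parses into the clause dict described above; string literals arrive as
    # quoted tokens such as '"usa"'.
    print(get_sql(toy_schema, "SELECT name FROM singer WHERE country = 'usa'"))
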
@ -0,0 +1,5 @@
dgl-cu101==0.5.3
embeddings==0.0.8
transformers==3.5.1
stanza==1.1.1
nltk==3.5

@ -0,0 +1,825 @@
|
|||
medium pred: SELECT singer.Name , singer.Country FROM singer ORDER BY singer.Country DESC
|
||||
medium gold: SELECT name , country FROM singer ORDER BY birthday ASC
|
||||
|
||||
medium pred: SELECT singer.Name , singer.Country FROM singer ORDER BY singer.Birthday DESC
|
||||
medium gold: SELECT name , country FROM singer ORDER BY birthday ASC
|
||||
|
||||
medium pred: SELECT singer.Song_Name , singer.Song_release_year FROM singer ORDER BY singer.Birthday ASC LIMIT 1
|
||||
medium gold: SELECT song_name , song_release_year FROM singer ORDER BY birthday desc LIMIT 1
|
||||
|
||||
medium pred: SELECT singer.Country FROM singer WHERE singer.Birthday = "value"
|
||||
medium gold: SELECT DISTINCT country FROM singer WHERE birthday like '2001%'
|
||||
|
||||
medium pred: SELECT singer.Country FROM singer WHERE singer.Birthday = "value"
|
||||
medium gold: SELECT DISTINCT country FROM singer WHERE birthday like '2001%'
|
||||
|
||||
hard pred: SELECT singer.Song_Name FROM singer WHERE singer.Song_Name > ( SELECT AVG(singer.Song_Name) FROM singer )
|
||||
hard gold: SELECT song_name FROM singer WHERE birthday < (SELECT avg(birthday) FROM singer)
|
||||
|
||||
hard pred: SELECT singer.Song_Name FROM singer WHERE singer.Is_male > ( SELECT AVG(singer.Is_male) FROM singer )
|
||||
hard gold: SELECT song_name FROM singer WHERE birthday < (SELECT avg(birthday) FROM singer)
|
||||
|
||||
medium pred: SELECT AVG(stadium.Average) , MAX(stadium.Capacity) FROM stadium
|
||||
medium gold: SELECT avg(capacity) , max(capacity) FROM stadium
|
||||
|
||||
medium pred: SELECT AVG(stadium.Capacity) , MAX(stadium.Highest) FROM stadium
|
||||
medium gold: SELECT avg(capacity) , max(capacity) FROM stadium
|
||||
|
||||
medium pred: SELECT stadium.Stadium_ID , COUNT(*) FROM stadium JOIN concert GROUP BY stadium.Stadium_ID
|
||||
medium gold: SELECT T2.name , count(*) FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id GROUP BY T1.stadium_id
|
||||
|
||||
extra pred: SELECT stadium.Name , stadium.Highest FROM stadium JOIN concert WHERE concert.Year > "value" GROUP BY concert.Stadium_ID ORDER BY COUNT(*) DESC LIMIT 1
|
||||
extra gold: SELECT T2.name , T2.capacity FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.year >= 2014 GROUP BY T2.stadium_id ORDER BY count(*) DESC LIMIT 1
|
||||
|
||||
hard pred: SELECT MAX(stadium.Highest) FROM stadium WHERE stadium.Stadium_ID NOT IN ( SELECT concert.Stadium_ID FROM concert )
|
||||
hard gold: SELECT highest FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
|
||||
|
||||
hard pred: SELECT MIN(stadium.Lowest) FROM stadium WHERE stadium.Stadium_ID NOT IN ( SELECT concert.Stadium_ID FROM concert )
|
||||
hard gold: SELECT lowest FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
|
||||
|
||||
extra pred: SELECT singer.Country FROM singer WHERE singer.Birthday = "value" OR singer.Birthday = "value"
|
||||
extra gold: SELECT country FROM singer WHERE birthday like '1981%' or birthday like '1991%'
|
||||
|
||||
hard pred: SELECT MIN(stadium.Lowest) FROM stadium WHERE stadium.Stadium_ID NOT IN ( SELECT concert.Stadium_ID FROM concert WHERE concert.Year = "value" )
|
||||
hard gold: SELECT lowest FROM stadium EXCEPT SELECT T2.lowest FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.year = 2014
|
||||
|
||||
medium pred: SELECT singer.Name , COUNT(*) FROM singer_in_concert JOIN singer GROUP BY singer.Name
|
||||
medium gold: SELECT T2.name , count(*) FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id = T2.singer_id GROUP BY T2.singer_id
|
||||
|
||||
hard pred: SELECT singer.Name FROM singer_in_concert JOIN singer JOIN concert WHERE concert.Year < "value"
|
||||
hard gold: SELECT T2.name FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id = T2.singer_id JOIN concert AS T3 ON T1.concert_id = T3.concert_id WHERE T3.year <= 2014
|
||||
|
||||
extra pred: SELECT stadium.Lowest , stadium.Highest FROM concert JOIN stadium WHERE concert.Year = "value" OR concert.Year = "value"
|
||||
extra gold: SELECT T2.lowest , T2.highest FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.Year = 2014 INTERSECT SELECT T2.name , T2.location FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.Year = 2015
|
||||
|
||||
extra pred: SELECT stadium.Lowest , stadium.Highest FROM concert JOIN stadium WHERE concert.Year = "value" INTERSECT SELECT stadium.Lowest , stadium.Highest FROM concert JOIN stadium WHERE concert.Year = "value"
|
||||
extra gold: SELECT T2.lowest , T2.highest FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.Year = 2014 INTERSECT SELECT T2.name , T2.location FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id WHERE T1.Year = 2015
|
||||
|
||||
medium pred: SELECT Pets.weight FROM Pets ORDER BY Pets.birthdate ASC LIMIT 1
|
||||
medium gold: SELECT weight FROM pets ORDER BY birthdate desc LIMIT 1
|
||||
|
||||
medium pred: SELECT Pets.weight FROM Pets ORDER BY Pets.birthdate ASC LIMIT 1
|
||||
medium gold: SELECT weight FROM pets ORDER BY birthdate desc LIMIT 1
|
||||
|
||||
medium pred: SELECT MAX(Pets.weight) , MAX(Pets.weight) , Pets.PetType FROM Pets GROUP BY Pets.PetType
|
||||
medium gold: SELECT max(weight) , petType FROM pets GROUP BY petType
|
||||
|
||||
hard pred: SELECT COUNT(*) FROM Student JOIN Has_Pet JOIN Student WHERE Student.Sex = "value"
|
||||
hard gold: SELECT count(*) FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T2.petid = T3.petid WHERE T1.sex = 'F' AND T3.pettype = 'dog'
|
||||
|
||||
hard pred: SELECT COUNT(*) FROM Has_Pet JOIN Student JOIN Has_Pet WHERE Student.Sex = "value"
|
||||
hard gold: SELECT count(*) FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T2.petid = T3.petid WHERE T1.sex = 'F' AND T3.pettype = 'dog'
|
||||
|
||||
extra pred: SELECT Student.LName FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value" INTERSECT SELECT Student.LName FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
|
||||
extra gold: SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog'
|
||||
|
||||
extra pred: SELECT Student.Major , Student.Age FROM Student EXCEPT SELECT Student.Major , Student.Age FROM Student JOIN Has_Pet WHERE Has_Pet.StuID = "value"
|
||||
extra gold: SELECT major , age FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
|
||||
|
||||
extra pred: SELECT Student.Major , Student.Age FROM Student EXCEPT SELECT Student.Major , Student.Age FROM Student JOIN Has_Pet WHERE Has_Pet.StuID = "value"
|
||||
extra gold: SELECT major , age FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
|
||||
|
||||
hard pred: SELECT Student.StuID FROM Student EXCEPT SELECT Has_Pet.StuID FROM Has_Pet WHERE Has_Pet.StuID = "value"
|
||||
hard gold: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat'
|
||||
|
||||
hard pred: SELECT Student.StuID FROM Student EXCEPT SELECT Has_Pet.StuID FROM Has_Pet WHERE Has_Pet.PetID = "value"
|
||||
hard gold: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat'
|
||||
|
||||
extra pred: SELECT Student.Fname , Student.Age FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
|
||||
extra gold: SELECT T1.fname , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog' AND T1.stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
|
||||
|
||||
extra pred: SELECT Student.Fname FROM Student WHERE Student.StuID = "value" EXCEPT SELECT Has_Pet.StuID FROM Has_Pet WHERE Has_Pet.PetID = "value"
|
||||
extra gold: SELECT T1.fname , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog' AND T1.stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')
|
||||
|
||||
medium pred: SELECT Pets.PetType , Pets.weight FROM Pets ORDER BY Pets.birthdate ASC LIMIT 1
|
||||
medium gold: SELECT pettype , weight FROM pets ORDER BY birthdate desc LIMIT 1
|
||||
|
||||
medium pred: SELECT Pets.PetType , Pets.weight FROM Pets ORDER BY Pets.birthdate ASC LIMIT 1
|
||||
medium gold: SELECT pettype , weight FROM pets ORDER BY birthdate desc LIMIT 1
|
||||
|
||||
medium pred: SELECT Pets.PetID , Pets.weight FROM Pets WHERE Pets.birthdate > "value"
|
||||
medium gold: SELECT petid , weight FROM pets WHERE birthdate < '2020-01-01'
|
||||
|
||||
medium pred: SELECT Pets.PetID , Pets.weight FROM Pets WHERE Pets.birthdate > "value"
|
||||
medium gold: SELECT petid , weight FROM pets WHERE birthdate < '2020-05-01'
|
||||
|
||||
medium pred: SELECT Student.Fname , Student.LName FROM Student JOIN Has_Pet
|
||||
medium gold: SELECT DISTINCT T1.fname , T1.LName , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid
|
||||
|
||||
medium pred: SELECT Student.Fname , Student.LName FROM Student JOIN Has_Pet
|
||||
medium gold: SELECT DISTINCT T1.fname , T1.LName T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid
|
||||
|
||||
medium pred: SELECT COUNT(*) , Has_Pet.StuID FROM Has_Pet GROUP BY Has_Pet.StuID
|
||||
medium gold: SELECT count(*) , T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid GROUP BY T1.stuid
|
||||
|
||||
medium pred: SELECT Has_Pet.StuID , COUNT(*) FROM Has_Pet GROUP BY Has_Pet.StuID
|
||||
medium gold: SELECT count(*) , T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid GROUP BY T1.stuid
|
||||
|
||||
extra pred: SELECT Student.LName FROM Pets JOIN Has_Pet JOIN Student WHERE Pets.birthdate = "value"
|
||||
extra gold: SELECT T1.lname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.birthdate like '2001%' AND T3.pettype = 'cat'
|
||||
|
||||
extra pred: SELECT Student.LName FROM Pets JOIN Has_Pet JOIN Student WHERE Pets.birthdate = "value"
|
||||
extra gold: SELECT T1.lname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.birthdate like '2001%' AND T3.pettype = 'cat'
|
||||
|
||||
extra pred: SELECT AVG(Student.Age) FROM Student WHERE Student.StuID NOT IN ( SELECT Has_Pet.StuID FROM Has_Pet )
|
||||
extra gold: SELECT avg(age) FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid)
|
||||
|
||||
extra pred: SELECT AVG(Student.Age) FROM Student WHERE Student.StuID NOT IN ( SELECT Has_Pet.StuID FROM Has_Pet )
|
||||
extra gold: SELECT avg(age) FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid)
|
||||
|
||||
extra pred: SELECT car_makers.Maker FROM cars_data JOIN car_makers WHERE cars_data.Year = "value"
|
||||
extra gold: SELECT DISTINCT T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model JOIN CARS_DATA AS T4 ON T3.MakeId = T4.id WHERE T4.year >= '2019';
|
||||
|
||||
extra pred: SELECT car_makers.FullName FROM cars_data JOIN car_makers WHERE cars_data.Year = "value"
|
||||
extra gold: SELECT DISTINCT T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model JOIN CARS_DATA AS T4 ON T3.MakeId = T4.id WHERE T4.year >= '2019';
|
||||
|
||||
extra pred: SELECT car_names.Make , cars_data.Year FROM cars_data JOIN car_names ORDER BY cars_data.Year ASC LIMIT 1
|
||||
extra gold: SELECT T2.Make , T1.Year FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Year = (SELECT min(YEAR) FROM CARS_DATA);
|
||||
|
||||
extra pred: SELECT car_makers.Maker , cars_data.Year FROM cars_data JOIN car_names JOIN model_list JOIN car_makers WHERE cars_data.Year = "value" ORDER BY cars_data.Year ASC LIMIT 1
|
||||
extra gold: SELECT T2.Make , T1.Year FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Year = (SELECT min(YEAR) FROM CARS_DATA);
|
||||
|
||||
hard pred: SELECT car_names.Model FROM cars_data JOIN car_names WHERE cars_data.Year <= "value"
|
||||
hard gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.model = T2.model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.id WHERE T3.year <= 1980;
|
||||
|
||||
medium pred: SELECT COUNT(*) , car_makers.FullName FROM car_makers JOIN model_list GROUP BY car_makers.Maker
|
||||
medium gold: SELECT Count(*) , T2.FullName , T2.id FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker = T2.Id GROUP BY T2.id;
|
||||
|
||||
medium pred: SELECT cars_data.Accelerate FROM cars_data JOIN car_names WHERE car_names.Model = "value" AND car_names.Make = "value"
|
||||
medium gold: SELECT T1.Accelerate FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Make = 'amc hornet sportabout (sw)';
|
||||
|
||||
medium pred: SELECT cars_data.Accelerate FROM cars_data JOIN car_names WHERE car_names.Model = "value" AND car_names.Make = "value"
|
||||
medium gold: SELECT T1.Accelerate FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Make = 'amc hornet sportabout (sw)';
|
||||
|
||||
medium pred: SELECT COUNT(*) FROM car_makers WHERE car_makers.Country = "value"
|
||||
medium gold: SELECT count(*) FROM CAR_MAKERS AS T1 JOIN COUNTRIES AS T2 ON T1.Country = T2.CountryId WHERE T2.CountryName = 'Japan';
|
||||
|
||||
medium pred: SELECT COUNT(*) FROM car_makers WHERE car_makers.Country = "value"
|
||||
medium gold: SELECT count(*) FROM CAR_MAKERS AS T1 JOIN COUNTRIES AS T2 ON T1.Country = T2.CountryId WHERE T2.CountryName = 'Japan';
|
||||
|
||||
hard pred: SELECT COUNT(*) FROM model_list JOIN car_makers WHERE car_makers.Country = "value"
|
||||
hard gold: SELECT count(*) FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker = T2.Id JOIN COUNTRIES AS T3 ON T2.Country = T3.CountryId WHERE T3.CountryName = 'usa';
|
||||
|
||||
hard pred: SELECT COUNT(*) FROM model_list JOIN car_makers WHERE car_makers.Country = "value"
|
||||
hard gold: SELECT count(*) FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker = T2.Id JOIN COUNTRIES AS T3 ON T2.Country = T3.CountryId WHERE T3.CountryName = 'usa';
|
||||
|
||||
hard pred: SELECT MIN(cars_data.Weight) FROM cars_data WHERE cars_data.Year = "value" AND cars_data.Cylinders = "value"
|
||||
hard gold: SELECT Weight FROM CARS_DATA WHERE Cylinders = 4 AND YEAR = 1974 ORDER BY Weight ASC LIMIT 1;
|
||||
|
||||
hard pred: SELECT MIN(cars_data.Weight) FROM cars_data WHERE cars_data.Year = "value" AND cars_data.Cylinders = "value"
|
||||
hard gold: SELECT Weight FROM CARS_DATA WHERE Cylinders = 4 AND YEAR = 1974 ORDER BY Weight ASC LIMIT 1;
|
||||
|
||||
medium pred: SELECT car_makers.Maker , model_list.Model FROM car_makers JOIN model_list
|
||||
medium gold: SELECT Maker , Model FROM MODEL_LIST;
|
||||
|
||||
medium pred: SELECT countries.CountryName , car_makers.Id FROM car_makers JOIN countries
|
||||
medium gold: SELECT T1.CountryName , T1.CountryId FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) >= 1;
|
||||
|
||||
medium pred: SELECT countries.CountryName , countries.CountryId FROM car_makers JOIN countries
|
||||
medium gold: SELECT T1.CountryName , T1.CountryId FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) >= 1;
|
||||
|
||||
medium pred: SELECT AVG(cars_data.Weight) , MAX(cars_data.Weight) , cars_data.Year FROM cars_data GROUP BY cars_data.Year
|
||||
medium gold: SELECT avg(Weight) , YEAR FROM CARS_DATA GROUP BY YEAR;
|
||||
|
||||
extra pred: SELECT countries.CountryName FROM countries JOIN car_makers WHERE countries.CountryName = "value" GROUP BY countries.CountryName HAVING COUNT(*) >= "value"
|
||||
extra gold: SELECT T1.CountryName FROM COUNTRIES AS T1 JOIN CONTINENTS AS T2 ON T1.Continent = T2.ContId JOIN CAR_MAKERS AS T3 ON T1.CountryId = T3.Country WHERE T2.Continent = 'europe' GROUP BY T1.CountryName HAVING count(*) >= 3;
|
||||
|
||||
extra pred: SELECT countries.CountryName FROM car_makers JOIN countries WHERE countries.Continent = "value" GROUP BY car_makers.Country HAVING COUNT(*) >= "value"
|
||||
extra gold: SELECT T1.CountryName FROM COUNTRIES AS T1 JOIN CONTINENTS AS T2 ON T1.Continent = T2.ContId JOIN CAR_MAKERS AS T3 ON T1.CountryId = T3.Country WHERE T2.Continent = 'europe' GROUP BY T1.CountryName HAVING count(*) >= 3;
|
||||
|
||||
extra pred: SELECT MAX(cars_data.Horsepower) , MAX(car_names.Make) FROM cars_data JOIN car_names WHERE cars_data.Cylinders = "value"
|
||||
extra gold: SELECT T2.horsepower , T1.Make FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T2.cylinders = 3 ORDER BY T2.horsepower DESC LIMIT 1;
|
||||
|
||||
extra pred: SELECT MAX(car_names.Make) , MAX(cars_data.Horsepower) FROM cars_data JOIN car_names WHERE cars_data.Cylinders = "value"
|
||||
extra gold: SELECT T2.horsepower , T1.Make FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T2.cylinders = 3 ORDER BY T2.horsepower DESC LIMIT 1;
|
||||
|
||||
hard pred: SELECT cars_data.MPG FROM cars_data ORDER BY cars_data.MPG DESC LIMIT 1
|
||||
hard gold: SELECT T1.Model FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id ORDER BY T2.mpg DESC LIMIT 1;
|
||||
|
||||
easy pred: SELECT COUNT(*) FROM cars_data ORDER BY cars_data.Year DESC LIMIT 1
|
||||
easy gold: SELECT count(*) FROM CARS_DATA WHERE YEAR >= 2019;
|
||||
|
||||
easy pred: SELECT COUNT(*) FROM cars_data ORDER BY cars_data.Year DESC LIMIT 1
|
||||
easy gold: SELECT count(*) FROM CARS_DATA WHERE YEAR >= 2019;
|
||||
|
||||
medium pred: SELECT car_makers.FullName , car_makers.Id , COUNT(*) FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) > "value"
|
||||
medium gold: SELECT T1.FullName , T1.Id FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) > 3;
|
||||
|
||||
extra pred: SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_makers.FullName = "value" OR cars_data.Weight > "value"
|
||||
extra gold: SELECT DISTINCT T2.Model FROM CAR_NAMES AS T1 JOIN MODEL_LIST AS T2 ON T1.Model = T2.Model JOIN CAR_MAKERS AS T3 ON T2.Maker = T3.Id JOIN CARS_DATA AS T4 ON T1.MakeId = T4.Id WHERE T3.FullName = 'General Motors' OR T4.weight > 3500;
|
||||
|
||||
extra pred: SELECT model_list.Model FROM model_list JOIN car_makers JOIN cars_data WHERE car_makers.FullName = "value" OR cars_data.Weight > "value"
|
||||
extra gold: SELECT DISTINCT T2.Model FROM CAR_NAMES AS T1 JOIN MODEL_LIST AS T2 ON T1.Model = T2.Model JOIN CAR_MAKERS AS T3 ON T2.Maker = T3.Id JOIN CARS_DATA AS T4 ON T1.MakeId = T4.Id WHERE T3.FullName = 'General Motors' OR T4.weight > 3500;
|
||||
|
||||
medium pred: SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight BETWEEN "value" AND "value"
|
||||
medium gold: SELECT DISTINCT T1.Year FROM CARS_DATA AS T1 WHERE T1.Weight > 3000 AND T1.weight < 4000;
|
||||
|
||||
medium pred: SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight < "value" INTERSECT SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight > "value"
|
||||
medium gold: SELECT DISTINCT T1.Year FROM CARS_DATA AS T1 WHERE T1.Weight > 3000 AND T1.weight < 4000;
|
||||
|
||||
extra pred: SELECT cars_data.Cylinders FROM cars_data WHERE cars_data.Accelerate = "value" ORDER BY cars_data.Cylinders ASC LIMIT 1
|
||||
extra gold: SELECT T1.cylinders FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Model = 'tesla' ORDER BY T1.accelerate ASC LIMIT 1;
|
||||
|
||||
hard pred: SELECT COUNT(*) FROM cars_data WHERE cars_data.Accelerate > ( SELECT MAX(cars_data.Accelerate) FROM cars_data ) ORDER BY cars_data.Horsepower DESC LIMIT 1
|
||||
hard gold: SELECT COUNT(*) FROM CARS_DATA WHERE Accelerate > ( SELECT Accelerate FROM CARS_DATA ORDER BY Horsepower DESC LIMIT 1 );
|
||||
|
||||
hard pred: SELECT COUNT(*) FROM cars_data WHERE cars_data.Accelerate > ( SELECT MAX(cars_data.Accelerate) FROM cars_data ) ORDER BY cars_data.Horsepower DESC LIMIT 1
|
||||
hard gold: SELECT COUNT(*) FROM CARS_DATA WHERE Accelerate > ( SELECT Accelerate FROM CARS_DATA ORDER BY Horsepower DESC LIMIT 1 );
|
||||
|
||||
easy pred: SELECT COUNT(*) FROM ( SELECT car_makers.Country FROM car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" )
|
||||
easy gold: SELECT COUNT(*) FROM ( SELECT T1.CountryId , COUNT(*) FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) > 2 );
|
||||
|
||||
easy pred: SELECT COUNT(*) FROM ( SELECT car_makers.Country FROM car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" )
|
||||
easy gold: SELECT COUNT(*) FROM ( SELECT T1.CountryId , COUNT(*) FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) > 2 );
|
||||
|
||||
extra pred: SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders > "value" EXCEPT SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders > "value"
|
||||
extra gold: SELECT T2.MakeId , T2.Make FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Horsepower > (SELECT min(Horsepower) FROM CARS_DATA) AND T1.Cylinders <= 3;
|
||||
|
||||
extra pred: SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders < "value"
|
||||
extra gold: SELECT T2.MakeId , T2.Make FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Horsepower > (SELECT min(Horsepower) FROM CARS_DATA) AND T1.Cylinders <= 3;
|
||||
|
||||
extra pred: SELECT MAX(cars_data.MPG) FROM cars_data WHERE cars_data.Cylinders = "value" OR cars_data.Year < "value"
|
||||
extra gold: SELECT mpg FROM CARS_DATA WHERE Cylinders = 8 OR YEAR <= 1980 ORDER BY mpg DESC LIMIT 1;
|
||||
|
||||
extra pred: SELECT MAX(cars_data.MPG) FROM cars_data WHERE cars_data.Cylinders = "value" OR cars_data.Year < "value"
|
||||
extra gold: SELECT mpg FROM CARS_DATA WHERE Cylinders = 8 OR YEAR <= 1980 ORDER BY mpg DESC LIMIT 1;
|
||||
|
||||
extra pred: SELECT model_list.Model FROM cars_data JOIN car_names JOIN car_makers WHERE cars_data.Weight < "value" EXCEPT SELECT model_list.Model FROM cars_data JOIN car_names JOIN car_makers WHERE car_makers.FullName = "value"
|
||||
extra gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.Model = T2.Model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.Id JOIN CAR_MAKERS AS T4 ON T1.Maker = T4.Id WHERE T3.weight < 3500 AND T4.FullName != 'Ford Motor Company';
|
||||
|
||||
extra pred: SELECT model_list.Model FROM model_list JOIN car_names WHERE cars_data.Weight < "value" EXCEPT SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_names.Make = "value"
|
||||
extra gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.Model = T2.Model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.Id JOIN CAR_MAKERS AS T4 ON T1.Maker = T4.Id WHERE T3.weight < 3500 AND T4.FullName != 'Ford Motor Company';
|
||||
|
||||
hard pred: SELECT countries.CountryName FROM countries WHERE countries.CountryId NOT IN ( SELECT car_makers.Country FROM car_makers )
|
||||
hard gold: SELECT CountryName FROM countries EXCEPT SELECT T1.CountryName FROM countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.countryId = T2.Country;
|
||||
|
||||
hard pred: SELECT countries.CountryName FROM countries WHERE countries.CountryId NOT IN ( SELECT car_makers.Country FROM car_makers )
|
||||
hard gold: SELECT CountryName FROM countries EXCEPT SELECT T1.CountryName FROM countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.countryId = T2.Country;
|
||||
|
||||
extra pred: SELECT car_makers.Id , car_makers.Maker FROM car_makers JOIN model_list GROUP BY car_makers.Id HAVING COUNT(*) >= "value" INTERSECT SELECT car_makers.Id , car_makers.Maker FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) > "value"
|
||||
extra gold: SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) >= 2 INTERSECT SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model GROUP BY T1.Id HAVING count(*) > 3;
|
||||
|
||||
extra pred: SELECT car_makers.Id , car_makers.Maker FROM model_list JOIN car_makers GROUP BY car_makers.Id HAVING COUNT(*) >= "value" INTERSECT SELECT car_makers.Id , car_makers.Maker FROM model_list JOIN car_makers GROUP BY model_list.Maker HAVING COUNT(*) > "value"
|
||||
extra gold: SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) >= 2 INTERSECT SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model GROUP BY T1.Id HAVING count(*) > 3;
|
||||
|
||||
extra pred: SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN countries GROUP BY car_makers.Country HAVING COUNT(*) > "value" UNION SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN model_list JOIN countries JOIN countries JOIN model_list WHERE model_list.Model = "value"
|
||||
extra gold: SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.countryId HAVING count(*) > 3 UNION SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country JOIN MODEL_LIST AS T3 ON T2.Id = T3.Maker WHERE T3.Model = 'tesla';
|
||||
|
||||
extra pred: SELECT countries.CountryId , countries.CountryName FROM countries JOIN car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" UNION SELECT countries.CountryId , countries.CountryName FROM countries JOIN car_makers WHERE car_makers.Country = "value"
|
||||
extra gold: SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.countryId HAVING count(*) > 3 UNION SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country JOIN MODEL_LIST AS T3 ON T2.Id = T3.Maker WHERE T3.Model = 'tesla';
|
||||
|
||||
medium pred: SELECT COUNT(*) FROM flights WHERE flights.SourceAirport = "value"
|
||||
medium gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.SourceAirport = T2.AirportCode WHERE T2.City = "Jackson"
|
||||
|
||||
hard pred: SELECT COUNT(*) FROM airports JOIN flights WHERE airports.City = "value" AND airports.AirportName = "value"
|
||||
hard gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode JOIN AIRPORTS AS T3 ON T1.SourceAirport = T3.AirportCode WHERE T2.City = "Ashley" AND T3.City = "Syracuse"
|
||||
|
||||
hard pred: SELECT COUNT(*) FROM airports JOIN flights WHERE airports.City = "value" AND airports.AirportName = "value"
|
||||
hard gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode JOIN AIRPORTS AS T3 ON T1.SourceAirport = T3.AirportCode WHERE T2.City = "Ashley" AND T3.City = "Syracuse"
|
||||
|
||||
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE flights.DestAirport = "value" AND airlines.Airline = "value"
|
||||
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "JetBlue Airways" AND T2.DestAirport = "ASY"
|
||||
|
||||
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airports.AirportName = "value" AND airlines.Airline = "value"
|
||||
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "JetBlue Airways" AND T2.DestAirport = "ASY"
|
||||
|
||||
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE flights.SourceAirport = "value" AND airlines.Airline = "value"
|
||||
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "JetBlue Airways" AND T2.SourceAirport = "AHD"
|
||||
|
||||
medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airports.AirportName = "value" AND airlines.Airline = "value"
|
||||
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "JetBlue Airways" AND T2.SourceAirport = "AHD"
|
||||
|
||||
extra pred: SELECT airports.City FROM airports JOIN flights GROUP BY flights.SourceAirport ORDER BY COUNT(*) DESC LIMIT 1
|
||||
extra gold: SELECT T1.City FROM AIRPORTS AS T1 JOIN FLIGHTS AS T2 ON T1.AirportCode = T2.SourceAirport GROUP BY T1.City ORDER BY count(*) DESC LIMIT 1
|
||||
|
||||
extra pred: SELECT airlines.Abbreviation , airlines.Country FROM airlines JOIN flights GROUP BY flights.Airline ORDER BY COUNT(*) ASC LIMIT 1
|
||||
extra gold: SELECT T1.Abbreviation , T1.Country FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline ORDER BY count(*) LIMIT 1
|
||||
|
||||
extra pred: SELECT airlines.Abbreviation , airports.Country FROM airlines JOIN airports JOIN flights JOIN flights GROUP BY airlines.Abbreviation ORDER BY COUNT(*) ASC LIMIT 1
|
||||
extra gold: SELECT T1.Abbreviation , T1.Country FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline ORDER BY count(*) LIMIT 1
|
||||
|
||||
medium pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airlines WHERE flights.SourceAirport = "value"
|
||||
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "AHD"
|
||||
|
||||
medium pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airlines WHERE flights.DestAirport = "value"
|
||||
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.DestAirport = "AHD"
|
||||
|
||||
extra pred: SELECT airlines.Airline FROM airports JOIN flights WHERE flights.SourceAirport = "value" EXCEPT SELECT airlines.Airline FROM airlines WHERE flights.SourceAirport = "value"
|
||||
extra gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "CVO" EXCEPT SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "APG"
|
||||
|
||||
extra pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airports JOIN airlines JOIN flights WHERE flights.SourceAirport = "value" EXCEPT SELECT airlines.Airline FROM airlines WHERE airlines.Airline = "value"
|
||||
extra gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "CVO" EXCEPT SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "APG"
|
||||
|
||||
medium pred: SELECT airlines.Airline FROM airlines JOIN flights GROUP BY airlines.Airline HAVING COUNT(*) >= "value"
|
||||
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline HAVING count(*) > 10
|
||||
|
||||
medium pred: SELECT airlines.Airline FROM flights JOIN airlines GROUP BY airlines.Airline HAVING COUNT(*) >= "value"
|
||||
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline HAVING count(*) > 10
|
||||
|
||||
easy pred: SELECT flights.FlightNo FROM airports JOIN flights WHERE airports.AirportName = "value"
|
||||
easy gold: SELECT FlightNo FROM FLIGHTS WHERE DestAirport = "APG"
|
||||
|
||||
medium pred: SELECT flights.FlightNo FROM airports JOIN flights WHERE airports.AirportName = "value"
|
||||
medium gold: SELECT T1.FlightNo FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.SourceAirport = T2.AirportCode WHERE T2.City = "Jackson"
|
||||
|
||||
medium pred: SELECT flights.FlightNo FROM airports JOIN flights WHERE airports.AirportName = "value"
|
||||
medium gold: SELECT T1.FlightNo FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode WHERE T2.City = "Jackson"
|
||||
|
||||
hard pred: SELECT airports.AirportName FROM airports WHERE airports.AirportCode NOT IN ( SELECT flights.SourceAirport FROM flights )
|
||||
hard gold: SELECT AirportName FROM Airports WHERE AirportCode NOT IN (SELECT SourceAirport FROM Flights UNION SELECT DestAirport FROM Flights)
|
||||
|
||||
hard pred: SELECT airports.AirportName FROM airports EXCEPT SELECT airports.AirportName FROM airports JOIN flights
|
||||
hard gold: SELECT AirportName FROM Airports WHERE AirportCode NOT IN (SELECT SourceAirport FROM Flights UNION SELECT DestAirport FROM Flights)
|
||||
|
||||
medium pred: SELECT Documents.Document_ID , Documents.Document_Name FROM Documents
|
||||
medium gold: SELECT document_id , document_name , document_description FROM Documents
|
||||
|
||||
medium pred: SELECT Documents.Document_Name , Documents.Template_ID FROM Documents WHERE Documents.Document_Description LIKE "value"
|
||||
medium gold: SELECT document_name , Document_ID, template_id FROM Documents WHERE Document_Description LIKE "%w%"
|
||||
|
||||
medium pred: SELECT Documents.Document_Description , Documents.Template_ID FROM Documents WHERE Documents.Document_Name = "value"
|
||||
medium gold: SELECT document_id , template_id , Document_Description FROM Documents WHERE document_name = "Robbin CV"
|
||||
|
||||
medium pred: SELECT Documents.Document_Description , Documents.Template_ID , Documents.Document_Name FROM Documents WHERE Documents.Document_Name = "value"
|
||||
medium gold: SELECT document_id , template_id , Document_Description FROM Documents WHERE document_name = "Robbin CV"
|
||||
|
||||
extra pred: SELECT Templates.Date_Effective_To FROM Templates JOIN Documents GROUP BY Documents.Template_ID ORDER BY COUNT(*) DESC LIMIT 1
|
||||
extra gold: SELECT T2.Date_Effective_From , T2.Date_Effective_To FROM Documents AS T1 JOIN Templates AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_id ORDER BY count(*) DESC LIMIT 1
|
||||
|
||||
extra pred: SELECT Templates.Date_Effective_To FROM Templates JOIN Documents GROUP BY Documents.Template_ID ORDER BY COUNT(*) DESC LIMIT 1
|
||||
extra gold: SELECT T2.Date_Effective_From , T2.Date_Effective_To FROM Documents AS T1 JOIN Templates AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_id ORDER BY count(*) DESC LIMIT 1
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_From , Templates.Template_Type_Code FROM Templates
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To , template_type_code FROM Templates
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To , Templates.Template_Type_Code FROM Templates
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To , template_type_code FROM Templates
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To FROM Templates
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_From FROM Templates
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates
|
||||
|
||||
extra pred: SELECT Templates.Date_Effective_To FROM Templates WHERE Templates.Template_Type_Code = "value" OR Templates.Template_Type_Code = "value"
|
||||
extra gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates WHERE template_type_code = "PP" OR template_type_code = "PPT"
|
||||
|
||||
extra pred: SELECT Templates.Date_Effective_To FROM Templates WHERE Templates.Template_Type_Code = "value" OR Templates.Template_Type_Code = "value"
|
||||
extra gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates WHERE template_type_code = "PP" OR template_type_code = "PPT"
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To FROM Templates WHERE Templates.Template_Type_Code = "value"
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates WHERE template_type_code = "CV"
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To FROM Templates WHERE Templates.Template_Type_Code = "value"
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates WHERE template_type_code = "CV"
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To , Templates.Template_Type_Code FROM Templates WHERE Templates.Version_Number > "value"
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To , template_type_code FROM Templates WHERE version_number > 5
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_From , Templates.Template_Type_Code FROM Templates WHERE Templates.Version_Number > "value"
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To , template_type_code FROM Templates WHERE version_number > 5
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To , COUNT(*) FROM Templates GROUP BY Templates.Date_Effective_To
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To , count(*) FROM Templates GROUP BY template_type_code
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To , COUNT(*) FROM Templates GROUP BY Templates.Date_Effective_To
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To , count(*) FROM Templates GROUP BY template_type_code
|
||||
|
||||
hard pred: SELECT Templates.Date_Effective_From FROM Templates GROUP BY Templates.Date_Effective_To ORDER BY COUNT(*) DESC LIMIT 1
|
||||
hard gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates GROUP BY template_type_code ORDER BY count(*) DESC LIMIT 1
|
||||
|
||||
hard pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To ORDER BY COUNT(*) DESC LIMIT 1
|
||||
hard gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates GROUP BY template_type_code ORDER BY count(*) DESC LIMIT 1
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To HAVING COUNT(*) < "value"
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates GROUP BY template_type_code HAVING count(*) < 3
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To HAVING COUNT(*) < "value"
|
||||
medium gold: SELECT Date_Effective_From , Date_Effective_To FROM Templates GROUP BY template_type_code HAVING count(*) < 3
|
||||
|
||||
medium pred: SELECT MIN(Templates.Date_Effective_From) , MIN(Templates.Date_Effective_To) , Templates.Version_Number FROM Templates
|
||||
medium gold: SELECT min(Version_Number) , Date_Effective_From , Date_Effective_To FROM Templates
|
||||
|
||||
medium pred: SELECT MIN(Templates.Date_Effective_From) , AVG(Templates.Date_Effective_To) , Templates.Version_Number FROM Templates
|
||||
medium gold: SELECT min(Version_Number) , Date_Effective_From , Date_Effective_To FROM Templates
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To FROM Documents JOIN Templates WHERE Documents.Document_Name = "value"
|
||||
medium gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id WHERE T2.document_name = "Data base"
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To FROM Templates JOIN Documents WHERE Documents.Document_Name = "value"
|
||||
medium gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id WHERE T2.document_name = "Data base"
|
||||
|
||||
medium pred: SELECT Documents.Document_Name , Documents.Template_ID , Templates.Template_Type_Code FROM Documents JOIN Templates WHERE Templates.Template_Type_Code = "value"
|
||||
medium gold: SELECT T2.document_name, T2.Document_ID, T2.template_id FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id WHERE T1.template_type_code = "BK"
|
||||
|
||||
medium pred: SELECT Documents.Document_Name , Documents.Template_ID FROM Documents JOIN Templates WHERE Templates.Template_Type_Code = "value"
|
||||
medium gold: SELECT T2.document_name, T2.Document_ID, T2.template_id FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id WHERE T1.template_type_code = "BK"
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_From , Templates.Date_Effective_To , COUNT(*) FROM Templates GROUP BY Templates.Template_Type_Code
|
||||
medium gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To , count(*) FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_From , Templates.Date_Effective_To , COUNT(*) FROM Templates GROUP BY Templates.Template_Type_Code
|
||||
medium gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To , count(*) FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code
|
||||
|
||||
extra pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To ORDER BY COUNT(*) DESC LIMIT 1
|
||||
extra gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code ORDER BY count(*) DESC LIMIT 1
|
||||
|
||||
extra pred: SELECT Templates.Date_Effective_To FROM Templates GROUP BY Templates.Date_Effective_To ORDER BY COUNT(*) DESC LIMIT 1
|
||||
extra gold: SELECT T1.Date_Effective_From , T1.Date_Effective_To FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code ORDER BY count(*) DESC LIMIT 1
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_From FROM Templates JOIN Ref_Template_Types WHERE Ref_Template_Types.Template_Type_Description = "value"
|
||||
medium gold: SELECT T2.Date_Effective_From , T2.Date_Effective_To FROM Ref_template_types AS T1 JOIN Templates AS T2 ON T1.template_type_code = T2.template_type_code WHERE T1.template_type_description = "Presentation"
|
||||
|
||||
medium pred: SELECT Templates.Date_Effective_To FROM Templates JOIN Ref_Template_Types WHERE Ref_Template_Types.Template_Type_Description = "value"
|
||||
medium gold: SELECT T2.Date_Effective_From , T2.Date_Effective_To FROM Ref_template_types AS T1 JOIN Templates AS T2 ON T1.template_type_code = T2.template_type_code WHERE T1.template_type_description = "Presentation"
|
||||
|
||||
medium pred: SELECT players.first_name FROM players JOIN matches
|
||||
medium gold: SELECT loser_name, winner_name FROM matches
|
||||
|
||||
medium pred: SELECT players.first_name FROM players JOIN matches
|
||||
medium gold: SELECT loser_name, winner_name FROM matches
|
||||
|
||||
medium pred: SELECT matches.loser_age FROM matches
|
||||
medium gold: SELECT winner_age, loser_age FROM matches
|
||||
|
||||
medium pred: SELECT matches.loser_age FROM matches
|
||||
medium gold: SELECT winner_age, loser_age FROM matches
|
||||
|
||||
medium pred: SELECT AVG(matches.winner_rank) FROM matches
|
||||
medium gold: SELECT avg(winner_rank), avg(loser_rank) FROM matches

medium pred: SELECT MAX(matches.winner_rank) FROM matches
medium gold: SELECT min(winner_rank), min(loser_rank) FROM matches

medium pred: SELECT MAX(matches.winner_rank_points) FROM matches
medium gold: SELECT min(winner_rank), min(loser_rank) FROM matches

easy pred: SELECT matches.tourney_name FROM matches GROUP BY matches.tourney_id HAVING COUNT(*) > "value"
easy gold: SELECT tourney_name FROM matches GROUP BY tourney_name HAVING count(*) > 10

extra pred: SELECT matches.winner_name , players.last_name , matches.loser_name FROM players JOIN matches WHERE matches.year = "value" INTERSECT SELECT matches.winner_name , players.last_name , matches.loser_name FROM players JOIN matches WHERE matches.year = "value"
extra gold: SELECT loser_name, winner_name FROM matches WHERE YEAR = 2013 INTERSECT SELECT loser_name, winner_name FROM matches WHERE YEAR = 2016

extra pred: SELECT players.first_name , players.last_name , matches.loser_name FROM players JOIN matches WHERE matches.year = "value" INTERSECT SELECT players.first_name , players.last_name , matches.loser_name FROM players JOIN matches WHERE matches.year = "value"
extra gold: SELECT loser_name, winner_name FROM matches WHERE YEAR = 2013 INTERSECT SELECT loser_name, winner_name FROM matches WHERE YEAR = 2016

extra pred: SELECT players.first_name FROM players JOIN matches WHERE matches.year = "value" OR matches.year = "value"
extra gold: SELECT loser_name, winner_name FROM matches WHERE YEAR = 2013 OR YEAR = 2016

medium pred: SELECT players.first_name , players.country_code FROM players ORDER BY players.birth_date DESC LIMIT 1
medium gold: SELECT first_name , country_code FROM players ORDER BY birth_date LIMIT 1

medium pred: SELECT players.first_name , players.country_code FROM players ORDER BY players.birth_date DESC LIMIT 1
medium gold: SELECT first_name , country_code FROM players ORDER BY birth_date LIMIT 1

medium pred: SELECT players.first_name , players.last_name FROM players ORDER BY players.last_name DESC
medium gold: SELECT first_name , last_name FROM players ORDER BY birth_date

medium pred: SELECT players.first_name , players.last_name FROM players WHERE players.hand = "value" ORDER BY players.birth_date DESC
medium gold: SELECT first_name , last_name FROM players WHERE hand = 'L' ORDER BY birth_date

medium pred: SELECT players.first_name FROM players WHERE players.hand = "value" ORDER BY players.birth_date ASC
medium gold: SELECT first_name , last_name FROM players WHERE hand = 'L' ORDER BY birth_date desc

hard pred: SELECT players.first_name , players.country_code FROM players JOIN rankings GROUP BY rankings.player_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT T1.country_code , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id ORDER BY T2.tours DESC LIMIT 1

hard pred: SELECT players.first_name , players.country_code FROM players JOIN rankings GROUP BY rankings.player_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT T1.country_code , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id ORDER BY T2.tours DESC LIMIT 1

hard pred: SELECT matches.winner_name , matches.winner_rank_points FROM players JOIN matches GROUP BY players.first_name ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT winner_name , winner_rank_points FROM matches GROUP BY winner_name ORDER BY count(*) DESC LIMIT 1

hard pred: SELECT players.first_name , players.last_name FROM players JOIN matches GROUP BY players.first_name ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT winner_name , winner_rank_points FROM matches GROUP BY winner_name ORDER BY count(*) DESC LIMIT 1

hard pred: SELECT matches.winner_name , matches.loser_name , matches.winner_rank_points FROM matches WHERE matches.tourney_name = "value" ORDER BY matches.winner_rank_points DESC LIMIT 1
hard gold: SELECT loser_name, winner_name FROM matches WHERE tourney_name = 'Australian Open' ORDER BY winner_rank_points DESC LIMIT 1

hard pred: SELECT matches.winner_name , matches.loser_name , matches.winner_rank_points FROM matches WHERE matches.tourney_name = "value" ORDER BY matches.winner_rank_points DESC LIMIT 1
hard gold: SELECT loser_name, winner_name FROM matches WHERE tourney_name = 'Australian Open' ORDER BY winner_rank_points DESC LIMIT 1

medium pred: SELECT players.first_name , players.last_name FROM players JOIN matches ORDER BY matches.minutes DESC LIMIT 1
medium gold: SELECT winner_name , loser_name FROM matches ORDER BY minutes DESC LIMIT 1

medium pred: SELECT players.first_name , AVG(rankings.ranking_points) FROM players JOIN rankings GROUP BY players.first_name
medium gold: SELECT avg(ranking) , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id GROUP BY T1.first_name

medium pred: SELECT SUM(rankings.ranking_points) , players.last_name FROM players JOIN rankings GROUP BY players.first_name
medium gold: SELECT sum(ranking_points) , T1.first_name, T1.last_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id GROUP BY T1.first_name

medium pred: SELECT players.first_name , SUM(players.last_name) FROM players JOIN rankings GROUP BY players.last_name
medium gold: SELECT sum(ranking_points) , T1.first_name, T1.last_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id GROUP BY T1.first_name

medium pred: SELECT COUNT(*) , rankings.ranking_date FROM rankings GROUP BY rankings.ranking_date
medium gold: SELECT sum(tours) , ranking_date FROM rankings GROUP BY ranking_date

medium pred: SELECT COUNT(matches.winner_entry) FROM matches WHERE matches.tourney_name = "value" AND matches.winner_hand = "value"
medium gold: SELECT count(DISTINCT winner_name) FROM matches WHERE tourney_name = 'WTA Championships' AND winner_hand = 'L'

medium pred: SELECT COUNT(*) FROM matches WHERE matches.tourney_name = "value" AND matches.winner_hand = "value"
medium gold: SELECT count(DISTINCT winner_name) FROM matches WHERE tourney_name = 'WTA Championships' AND winner_hand = 'L'

hard pred: SELECT matches.winner_name , players.birth_date FROM players JOIN matches ORDER BY matches.winner_rank_points DESC LIMIT 1
hard gold: SELECT T1.first_name , T1.last_name , T1.birth_date FROM players AS T1 JOIN matches AS T2 ON T1.player_id = T2.winner_id ORDER BY T2.winner_rank_points DESC LIMIT 1

hard pred: SELECT players.first_name , players.birth_date FROM players JOIN matches ORDER BY matches.winner_rank_points DESC LIMIT 1
hard gold: SELECT T1.first_name , T1.last_name , T1.birth_date FROM players AS T1 JOIN matches AS T2 ON T1.player_id = T2.winner_id ORDER BY T2.winner_rank_points DESC LIMIT 1

medium pred: SELECT Addresses.line_1 FROM Addresses
medium gold: SELECT line_1 , line_2, line_3 FROM addresses

medium pred: SELECT Addresses.line_3 FROM Addresses
medium gold: SELECT line_1 , line_2, line_3 FROM addresses

easy pred: SELECT COUNT(*) FROM show
easy gold: SELECT count(*) FROM show where If_first_show = 'T'

easy pred: SELECT COUNT(*) FROM show
easy gold: SELECT count(*) FROM show where If_first_show = 'T'

easy pred: SELECT conductor.Name FROM conductor ORDER BY conductor.birthday ASC
easy gold: SELECT Name FROM conductor ORDER BY birthday desc

easy pred: SELECT conductor.Name FROM conductor ORDER BY conductor.birthday ASC
easy gold: SELECT Name FROM conductor ORDER BY birthday desc

easy pred: SELECT AVG(show.Attendance) FROM show WHERE show.If_first_show NOT IN ( SELECT show.If_first_show FROM show )
easy gold: SELECT avg(Attendance) FROM SHOW where If_first_show = 'F'

easy pred: SELECT AVG(show.Attendance) FROM show WHERE show.If_first_show != "value"
easy gold: SELECT avg(Attendance) FROM SHOW where If_first_show = 'F'

easy pred: SELECT conductor.Name FROM conductor ORDER BY conductor.birthday ASC
easy gold: SELECT Name FROM conductor ORDER BY birthday DESC

easy pred: SELECT conductor.Name FROM conductor ORDER BY conductor.birthday ASC
easy gold: SELECT Name FROM conductor ORDER BY birthday DESC

hard pred: SELECT conductor.Name FROM orchestra JOIN conductor ORDER BY orchestra.Year_of_Founded DESC LIMIT 1
hard gold: SELECT T1.Name FROM conductor AS T1 JOIN orchestra AS T2 ON T1.Conductor_ID = T2.Conductor_ID order by Year_of_Founded asc limit 1

hard pred: SELECT conductor.Name FROM conductor JOIN orchestra ORDER BY orchestra.Orchestra DESC LIMIT 1
hard gold: SELECT T1.Name FROM conductor AS T1 JOIN orchestra AS T2 ON T1.Conductor_ID = T2.Conductor_ID order by Year_of_Founded desc limit 1

easy pred: SELECT orchestra.Major_Record_Format FROM orchestra ORDER BY orchestra.Year_of_Founded ASC
easy gold: SELECT Major_Record_Format FROM orchestra ORDER BY Year_of_Founded desc

easy pred: SELECT orchestra.Major_Record_Format FROM orchestra ORDER BY orchestra.Year_of_Founded ASC
easy gold: SELECT Major_Record_Format FROM orchestra ORDER BY Year_of_Founded desc

medium pred: SELECT orchestra.Record_Company FROM orchestra ORDER BY orchestra.Record_Company DESC LIMIT 1
medium gold: SELECT Record_Company FROM orchestra ORDER BY Year_of_Founded desc LIMIT 1

easy pred: SELECT orchestra.Record_Company FROM orchestra WHERE orchestra.Year_of_Founded >= "value" INTERSECT SELECT orchestra.Record_Company FROM orchestra WHERE orchestra.Year_of_Founded >= "value"
easy gold: SELECT Record_Company FROM orchestra WHERE Year_of_Founded >= 2003

medium pred: SELECT COUNT(*) FROM show WHERE show.Result = "value"
medium gold: SELECT COUNT(*) FROM show WHERE Result = "Glebe Park" and If_first_show = "T"

medium pred: SELECT COUNT(*) FROM show WHERE show.Result = "value"
medium gold: SELECT COUNT(*) FROM show WHERE Result = "Glebe Park" and If_first_show = "T"

hard pred: SELECT performance.Type FROM performance JOIN show WHERE show.If_first_show != "value" GROUP BY show.Performance_ID HAVING COUNT(*) > "value"
hard gold: SELECT Type FROM performance AS T1 JOIN show AS T2 ON T1.Performance_ID = T2.Performance_ID WHERE If_first_show = 'F' GROUP BY T2.Performance_ID HAVING COUNT(*) > 1

hard pred: SELECT performance.Type FROM performance JOIN show WHERE show.If_first_show != "value" GROUP BY show.Performance_ID HAVING COUNT(*) > "value"
hard gold: SELECT Type FROM performance AS T1 JOIN show AS T2 ON T1.Performance_ID = T2.Performance_ID Where If_first_show = 'F' GROUP BY T2.Performance_ID HAVING COUNT(*) > 1

hard pred: SELECT AVG(Dogs.age) FROM Dogs JOIN Treatments
hard gold: SELECT avg(age) FROM Dogs WHERE dog_id IN ( SELECT dog_id FROM Treatments )

hard pred: SELECT AVG(Dogs.age) FROM Treatments JOIN Dogs
hard gold: SELECT avg(age) FROM Dogs WHERE dog_id IN ( SELECT dog_id FROM Treatments )

extra pred: SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments WHERE Professionals.state = "value" UNION SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments GROUP BY Treatments.professional_id HAVING COUNT(*) > "value"
extra gold: SELECT professional_id , last_name , cell_number FROM Professionals WHERE state = 'Indiana' UNION SELECT T1.professional_id , T1.last_name , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) > 2

extra pred: SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments WHERE Professionals.state = "value" UNION SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments GROUP BY Professionals.state HAVING COUNT(*) > "value"
extra gold: SELECT professional_id , last_name , cell_number FROM Professionals WHERE state = 'Indiana' UNION SELECT T1.professional_id , T1.last_name , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) > 2

hard pred: SELECT Dogs.name FROM Dogs EXCEPT SELECT Dogs.name FROM Dogs JOIN Treatments WHERE Treatments.cost_of_treatment > "value"
hard gold: SELECT name FROM Dogs WHERE dog_id NOT IN( SELECT dog_id FROM Treatments GROUP BY dog_id HAVING sum(cost_of_treatment) > 1000 )

hard pred: SELECT Dogs.name FROM Dogs JOIN Treatments WHERE Treatments.cost_of_treatment > "value"
hard gold: SELECT name FROM Dogs WHERE dog_id NOT IN( SELECT dog_id FROM Treatments GROUP BY dog_id HAVING sum(cost_of_treatment) > 1000 )

hard pred: SELECT Professionals.first_name FROM Professionals JOIN Owners EXCEPT SELECT Owners.first_name FROM Owners
hard gold: SELECT first_name FROM Professionals UNION SELECT first_name FROM Owners EXCEPT SELECT name FROM Dogs

hard pred: SELECT Professionals.first_name FROM Professionals JOIN Owners EXCEPT SELECT Owners.first_name FROM Owners
hard gold: SELECT first_name FROM Professionals UNION SELECT first_name FROM Owners EXCEPT SELECT name FROM Dogs

extra pred: SELECT Professionals.professional_id , Professionals.last_name FROM Professionals WHERE Professionals.professional_id NOT IN ( SELECT Treatments.professional_id FROM Treatments )
extra gold: SELECT professional_id , first_name , last_name FROM Professionals EXCEPT SELECT T1.professional_id , T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id

extra pred: SELECT Professionals.professional_id , Professionals.last_name FROM Professionals WHERE Professionals.professional_id NOT IN ( SELECT Treatments.professional_id FROM Treatments )
extra gold: SELECT professional_id , first_name , last_name FROM Professionals EXCEPT SELECT T1.professional_id , T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id

extra pred: SELECT Owners.owner_id , Owners.last_name FROM Owners JOIN Dogs GROUP BY Owners.owner_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T2.first_name , T2.last_name FROM Dogs AS T1 JOIN Owners AS T2 ON T1.owner_id = T2.owner_id GROUP BY T1.owner_id ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Owners.owner_id , Owners.last_name FROM Owners JOIN Dogs GROUP BY Owners.owner_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T2.first_name , T2.last_name FROM Dogs AS T1 JOIN Owners AS T2 ON T1.owner_id = T2.owner_id GROUP BY T1.owner_id ORDER BY count(*) DESC LIMIT 1

medium pred: SELECT Treatments.professional_id , Professionals.home_phone , COUNT(*) FROM Treatments JOIN Professionals GROUP BY Treatments.professional_id HAVING COUNT(*) >= "value"
medium gold: SELECT T1.professional_id , T1.home_phone , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) >= 2

medium pred: SELECT Treatments.professional_id , Professionals.home_phone , COUNT(*) FROM Treatments JOIN Professionals GROUP BY Treatments.professional_id HAVING COUNT(*) >= "value"
medium gold: SELECT T1.professional_id , T1.home_phone , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) >= 2

extra pred: SELECT Breeds.breed_name FROM Dogs JOIN Breeds GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.breed_name FROM Breeds AS T1 JOIN Dogs AS T2 ON T1.breed_code = T2.breed_code where abandoned_yn = 1 GROUP BY T1.breed_name ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Breeds.breed_name FROM Dogs JOIN Breeds GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.breed_name FROM Breeds AS T1 JOIN Dogs AS T2 ON T1.breed_code = T2.breed_code where abandoned_yn = 1 GROUP BY T1.breed_name ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Owners.owner_id , Owners.zip_code FROM Owners JOIN Dogs JOIN Charges GROUP BY Owners.owner_id ORDER BY SUM(Charges.charge_amount) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T1.zip_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id JOIN Treatments AS T3 ON T2.dog_id = T3.dog_id GROUP BY T1.owner_id ORDER BY sum(T3.cost_of_treatment) DESC LIMIT 1

extra pred: SELECT Owners.owner_id , Owners.zip_code FROM Owners JOIN Dogs GROUP BY Owners.owner_id ORDER BY SUM(Dogs.date_adopted) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T1.zip_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id JOIN Treatments AS T3 ON T2.dog_id = T3.dog_id GROUP BY T1.owner_id ORDER BY sum(T3.cost_of_treatment) DESC LIMIT 1

medium pred: SELECT Treatments.professional_id , Professionals.last_name FROM Professionals JOIN Treatments GROUP BY Treatments.professional_id HAVING COUNT(*) >= "value"
medium gold: SELECT T1.professional_id , T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) >= 2

medium pred: SELECT Treatments.professional_id , Professionals.home_phone , COUNT(*) FROM Professionals JOIN Treatments GROUP BY Professionals.professional_id HAVING COUNT(*) >= "value"
medium gold: SELECT T1.professional_id , T1.home_phone , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) >= 2

extra pred: SELECT Professionals.first_name FROM Professionals JOIN Treatments WHERE Professionals.last_name < ( SELECT AVG(Treatments.cost_of_treatment) FROM Treatments )
extra gold: SELECT DISTINCT T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 WHERE cost_of_treatment < ( SELECT avg(cost_of_treatment) FROM Treatments )

extra pred: SELECT Professionals.first_name FROM Professionals JOIN Treatments WHERE Professionals.last_name < ( SELECT AVG(Treatments.cost_of_treatment) FROM Treatments )
extra gold: SELECT DISTINCT T1.first_name , T1.last_name FROM Professionals AS T1 JOIN Treatments AS T2 WHERE cost_of_treatment < ( SELECT avg(cost_of_treatment) FROM Treatments )

medium pred: SELECT Treatments.date_of_treatment , Professionals.last_name FROM Treatments JOIN Professionals
medium gold: SELECT T1.date_of_treatment , T2.first_name , T2.last_name FROM Treatments AS T1 JOIN Professionals AS T2 ON T1.professional_id = T2.professional_id

medium pred: SELECT Treatments.date_of_treatment , Professionals.last_name FROM Treatments JOIN Professionals
medium gold: SELECT T1.date_of_treatment , T2.first_name , T2.last_name FROM Treatments AS T1 JOIN Professionals AS T2 ON T1.professional_id = T2.professional_id

medium pred: SELECT Owners.first_name , Owners.last_name FROM Owners JOIN Dogs
medium gold: SELECT T1.first_name , T1.last_name , T2.size_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id

medium pred: SELECT Owners.first_name , Owners.last_name FROM Owners JOIN Dogs
medium gold: SELECT T1.first_name , T1.last_name , T2.size_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id

medium pred: SELECT Owners.first_name , Owners.last_name FROM Owners JOIN Dogs
medium gold: SELECT T1.first_name , T1.last_name , T2.name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id

medium pred: SELECT Owners.first_name , Owners.last_name FROM Owners JOIN Dogs
medium gold: SELECT T1.first_name , T1.last_name , T2.name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id

extra pred: SELECT Dogs.name , Treatments.date_of_treatment FROM Treatments JOIN Dogs GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.name , T2.date_of_treatment FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id WHERE T1.breed_code = ( SELECT breed_code FROM Dogs GROUP BY breed_code ORDER BY count(*) ASC LIMIT 1 )

extra pred: SELECT Dogs.name , Treatments.date_of_treatment FROM Dogs JOIN Treatments GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.name , T2.date_of_treatment FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id WHERE T1.breed_code = ( SELECT breed_code FROM Dogs GROUP BY breed_code ORDER BY count(*) ASC LIMIT 1 )

medium pred: SELECT Dogs.date_arrived FROM Dogs
medium gold: SELECT DISTINCT T1.date_arrived , T1.date_departed FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id

medium pred: SELECT Dogs.date_arrived FROM Dogs ORDER BY Dogs.date_departed DESC
medium gold: SELECT DISTINCT T1.date_arrived , T1.date_departed FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id

extra pred: SELECT Owners.last_name FROM Owners JOIN Dogs ORDER BY Dogs.age ASC LIMIT 1
extra gold: SELECT T1.last_name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id WHERE T2.age = ( SELECT max(age) FROM Dogs )

extra pred: SELECT Owners.last_name FROM Owners JOIN Dogs ORDER BY Dogs.age ASC LIMIT 1
extra gold: SELECT T1.last_name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id WHERE T2.age = ( SELECT max(age) FROM Dogs )

medium pred: SELECT Dogs.date_arrived , Dogs.date_departed FROM Dogs
medium gold: SELECT date_arrived , date_departed FROM Dogs where abandoned_yn = 1

medium pred: SELECT Dogs.date_arrived , Dogs.date_departed FROM Dogs
medium gold: SELECT date_arrived , date_departed FROM Dogs where abandoned_yn = 1

medium pred: SELECT Professionals.first_name FROM Professionals WHERE Professionals.city LIKE "value"
medium gold: SELECT first_name , last_name FROM professionals WHERE city LIKE '%West%'

medium pred: SELECT Professionals.first_name FROM Professionals WHERE Professionals.city LIKE "value"
medium gold: SELECT first_name , last_name FROM professionals WHERE city LIKE '%West%'

medium pred: SELECT Owners.last_name FROM Owners WHERE Owners.state LIKE "value"
medium gold: SELECT first_name , last_name FROM Owners WHERE state LIKE '%North%'

medium pred: SELECT Owners.first_name FROM Owners WHERE Owners.state LIKE "value"
medium gold: SELECT first_name , last_name FROM Owners WHERE state LIKE '%North%'

extra pred: SELECT COUNT(*) FROM Dogs WHERE Dogs.age < ( SELECT AVG(Dogs.age) FROM Dogs )
extra gold: SELECT count(*) FROM Dogs WHERE abandoned_yn = 1 and age < ( SELECT avg(age) FROM Dogs )

extra pred: SELECT COUNT(*) FROM Dogs WHERE Dogs.age < ( SELECT AVG(Dogs.age) FROM Dogs )
extra gold: SELECT count(*) FROM Dogs WHERE abandoned_yn = 1 and age < ( SELECT avg(age) FROM Dogs )

extra pred: SELECT COUNT(*) FROM Dogs WHERE Dogs.dog_id NOT IN ( SELECT Treatments.dog_id FROM Treatments )
extra gold: SELECT count(*) FROM Dogs WHERE abandoned_yn = 1 and dog_id NOT IN ( SELECT dog_id FROM Treatments )

extra pred: SELECT COUNT(Treatments.dog_id) FROM Treatments
extra gold: SELECT count(*) FROM Dogs WHERE dog_id NOT IN ( SELECT dog_id FROM Treatments )

extra pred: SELECT COUNT(*) FROM Owners WHERE Owners.owner_id NOT IN ( SELECT Dogs.owner_id FROM Dogs )
extra gold: SELECT count(*) FROM Owners WHERE owner_id NOT IN ( SELECT owner_id FROM Dogs where abandoned_yn = 1 )

extra pred: SELECT COUNT(*) FROM Owners WHERE Owners.owner_id NOT IN ( SELECT Dogs.owner_id FROM Dogs )
extra gold: SELECT count(*) FROM Owners WHERE owner_id NOT IN ( SELECT owner_id FROM Dogs where abandoned_yn = 1 )

easy pred: SELECT AVG(Dogs.age) FROM Dogs
easy gold: SELECT avg(age) FROM Dogs where abandoned_yn = 1

easy pred: SELECT AVG(Dogs.age) FROM Dogs
easy gold: SELECT avg(age) FROM Dogs where abandoned_yn = 1

medium pred: SELECT Charges.charge_type , SUM(Charges.charge_amount) FROM Charges GROUP BY Charges.charge_type
medium gold: SELECT charge_type , charge_amount FROM Charges

medium pred: SELECT Charges.charge_type , SUM(Charges.charge_amount) FROM Charges GROUP BY Charges.charge_type
medium gold: SELECT charge_type , charge_amount FROM Charges

easy pred: SELECT Charges.charge_type FROM Charges ORDER BY Charges.charge_amount DESC LIMIT 1
easy gold: SELECT max(charge_amount) FROM Charges

easy pred: SELECT Charges.charge_amount FROM Charges ORDER BY Charges.charge_type DESC LIMIT 1
easy gold: SELECT max(charge_amount) FROM Charges

medium pred: SELECT Professionals.email_address , Professionals.last_name FROM Professionals
medium gold: SELECT email_address , first_name , last_name FROM professionals

medium pred: SELECT Professionals.email_address , Professionals.last_name FROM Professionals
medium gold: SELECT email_address , first_name , last_name FROM professionals

medium pred: SELECT Sizes.size_code , Sizes.size_code FROM Sizes
medium gold: SELECT DISTINCT breed_code , size_code FROM dogs

medium pred: SELECT Professionals.first_name , Professionals.last_name FROM Professionals JOIN Treatments JOIN Treatment_Types
medium gold: SELECT DISTINCT T1.first_name , T1.last_name , T3.treatment_type_description FROM professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id JOIN Treatment_types AS T3 ON T2.treatment_type_code = T3.treatment_type_code

medium pred: SELECT Professionals.first_name , Professionals.last_name FROM Professionals JOIN Treatments JOIN Treatment_Types
medium gold: SELECT DISTINCT T1.first_name , T1.last_name , T3.treatment_type_description FROM professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id JOIN Treatment_types AS T3 ON T2.treatment_type_code = T3.treatment_type_code

easy pred: SELECT singer.Name FROM singer ORDER BY singer.Birth_Year ASC
easy gold: SELECT Name FROM singer ORDER BY Birth_Year desc

easy pred: SELECT singer.Name FROM singer ORDER BY singer.Birth_Year ASC
easy gold: SELECT Name FROM singer ORDER BY Birth_Year desc

easy pred: SELECT singer.Name FROM singer WHERE singer.Birth_Year < ( SELECT MAX(singer.Birth_Year) FROM singer WHERE singer.Birth_Year = "value" )
easy gold: SELECT Name FROM singer WHERE Birth_Year <= 1948

easy pred: SELECT singer.Name FROM singer WHERE singer.Birth_Year < ( SELECT MAX(singer.Birth_Year) FROM singer WHERE singer.Birth_Year = "value" )
easy gold: SELECT Name FROM singer WHERE Birth_Year <= 1948

medium pred: SELECT singer.Name FROM singer ORDER BY singer.Birth_Year ASC LIMIT 1
medium gold: SELECT Name FROM singer ORDER BY Birth_Year DESC LIMIT 1

medium pred: SELECT singer.Name FROM singer ORDER BY singer.Birth_Year ASC LIMIT 1
medium gold: SELECT Name FROM singer ORDER BY Birth_Year DESC LIMIT 1

medium pred: SELECT singer.Name FROM singer JOIN song GROUP BY song.Singer_ID HAVING COUNT(*) > "value"
medium gold: SELECT T1.Name FROM singer AS T1 JOIN song AS T2 ON T1.Singer_ID = T2.Singer_ID GROUP BY T1.Name HAVING COUNT(*) > 1

easy pred: SELECT singer.Citizenship FROM singer WHERE singer.Birth_Year < "value" INTERSECT SELECT singer.Citizenship FROM singer WHERE singer.Birth_Year <= "value"
easy gold: SELECT Citizenship FROM singer WHERE Birth_Year <= 1945

easy pred: SELECT singer.Citizenship FROM singer WHERE singer.Birth_Year < "value" INTERSECT SELECT singer.Citizenship FROM singer WHERE singer.Birth_Year = "value"
easy gold: SELECT Citizenship FROM singer WHERE Birth_Year <= 1945

                        easy     medium   hard     extra    all
count                   110      246      74       105      535

====================== EXACT MATCHING ACCURACY =====================
exact match             0.755    0.508    0.405    0.333    0.510

---------------------PARTIAL MATCHING ACCURACY----------------------
select                  0.964    0.695    0.824    0.752    0.779
select(no AGG)          0.973    0.711    0.892    0.790    0.806
where                   0.860    0.807    0.486    0.508    0.692
where(no OP)            0.980    0.852    0.657    0.689    0.808
group(no Having)        0.625    0.861    0.667    0.775    0.797
group                   0.625    0.833    0.667    0.750    0.775
order                   0.143    0.543    0.786    0.786    0.639
and/or                  1.000    0.984    0.973    0.894    0.968
IUEN                    0.000    0.000    0.455    0.320    0.325
keywords                0.778    0.896    0.792    0.760    0.828

---------------------- PARTIAL MATCHING RECALL ----------------------
select                  0.964    0.695    0.824    0.752    0.779
select(no AGG)          0.973    0.711    0.892    0.790    0.806
where                   0.768    0.789    0.472    0.449    0.645
where(no OP)            0.875    0.833    0.639    0.609    0.753
group(no Having)        0.625    0.861    0.750    0.816    0.821
group                   0.625    0.833    0.750    0.789    0.799
order                   0.200    0.559    0.786    0.825    0.679
and/or                  1.000    0.992    1.000    0.989    0.994
IUEN                    0.000    0.000    0.417    0.364    0.382
keywords                0.757    0.892    0.770    0.752    0.817

---------------------- PARTIAL MATCHING F1 --------------------------
select                  0.964    0.695    0.824    0.752    0.779
select(no AGG)          0.973    0.711    0.892    0.790    0.806
where                   0.811    0.798    0.479    0.477    0.668
where(no OP)            0.925    0.843    0.648    0.646    0.779
group(no Having)        0.625    0.861    0.706    0.795    0.809
group                   0.625    0.833    0.706    0.769    0.787
order                   0.167    0.551    0.786    0.805    0.658
and/or                  1.000    0.988    0.986    0.939    0.981
IUEN                    1.000    1.000    0.435    0.340    0.351
keywords                0.767    0.894    0.781    0.756    0.822

@ -0,0 +1,774 @@

medium pred: SELECT MAX(stadium.Capacity) , AVG(stadium.Average) FROM stadium
medium gold: select max(capacity), average from stadium

medium pred: SELECT AVG(stadium.Average) , MAX(stadium.Capacity) FROM stadium
medium gold: select avg(capacity) , max(capacity) from stadium

medium pred: SELECT stadium.Stadium_ID , COUNT(*) FROM stadium JOIN concert GROUP BY stadium.Stadium_ID
medium gold: SELECT T2.name , count(*) FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id GROUP BY T1.stadium_id

medium pred: SELECT singer.Name , COUNT(*) FROM singer_in_concert JOIN singer GROUP BY singer.Name
medium gold: SELECT T2.name , count(*) FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id = T2.singer_id GROUP BY T2.singer_id

hard pred: SELECT COUNT(*) FROM concert JOIN stadium ORDER BY stadium.Capacity DESC LIMIT 1
hard gold: select count(*) from concert where stadium_id = (select stadium_id from stadium order by capacity desc limit 1)

hard pred: SELECT COUNT(*) FROM concert JOIN stadium ORDER BY stadium.Capacity DESC LIMIT 1
hard gold: select count(*) from concert where stadium_id = (select stadium_id from stadium order by capacity desc limit 1)

medium pred: SELECT MAX(Pets.weight) , MAX(Pets.weight) , Pets.PetType FROM Pets GROUP BY Pets.PetType
medium gold: SELECT max(weight) , petType FROM pets GROUP BY petType

medium pred: SELECT COUNT(*) FROM Pets JOIN Has_Pet JOIN Student WHERE Student.Age > "value"
medium gold: SELECT count(*) FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid WHERE T1.age > 20

extra pred: SELECT Student.Fname FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value" INTERSECT SELECT Student.Fname FROM Student JOIN Has_Pet WHERE Has_Pet.PetID = "value"
extra gold: SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog'

extra pred: SELECT Student.Major , Student.Age FROM Student WHERE Student.StuID NOT IN ( SELECT Has_Pet.StuID FROM Has_Pet JOIN Pets WHERE Pets.PetType = "value" )
extra gold: SELECT major , age FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')

extra pred: SELECT Student.Major , Student.Age FROM Student EXCEPT SELECT Student.Major , Student.Age FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
extra gold: SELECT major , age FROM student WHERE stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')

hard pred: SELECT Student.StuID FROM Student EXCEPT SELECT Has_Pet.StuID FROM Has_Pet JOIN Pets WHERE Pets.PetType = "value"
hard gold: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat'

hard pred: SELECT Student.StuID FROM Student EXCEPT SELECT Has_Pet.StuID FROM Has_Pet WHERE Has_Pet.PetID = "value"
hard gold: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat'

extra pred: SELECT Student.Fname , Student.Age FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
extra gold: SELECT T1.fname , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog' AND T1.stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')

extra pred: SELECT Student.Fname FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value" EXCEPT SELECT Student.Fname FROM Student JOIN Has_Pet JOIN Pets WHERE Pets.PetType = "value"
extra gold: SELECT T1.fname , T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'dog' AND T1.stuid NOT IN (SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')

medium pred: SELECT COUNT(*) , Has_Pet.StuID FROM Has_Pet GROUP BY Has_Pet.StuID
medium gold: SELECT count(*) , T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid GROUP BY T1.stuid

medium pred: SELECT Has_Pet.StuID , COUNT(*) FROM Has_Pet WHERE Has_Pet.StuID != "value" GROUP BY Has_Pet.StuID
medium gold: select count(*) , t1.stuid from student as t1 join has_pet as t2 on t1.stuid = t2.stuid group by t1.stuid

hard pred: SELECT Student.LName FROM Pets JOIN Has_Pet JOIN Student WHERE Pets.pet_age = "value"
hard gold: SELECT T1.lname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pet_age = 3 AND T3.pettype = 'cat'

hard pred: SELECT Student.LName FROM Pets JOIN Has_Pet JOIN Student WHERE Pets.pet_age = "value"
hard gold: SELECT T1.lname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pet_age = 3 AND T3.pettype = 'cat'

extra pred: SELECT car_makers.Maker FROM cars_data JOIN car_makers WHERE cars_data.Year = "value"
extra gold: SELECT DISTINCT T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model JOIN CARS_DATA AS T4 ON T3.MakeId = T4.id WHERE T4.year = '1970';

extra pred: SELECT car_makers.FullName FROM cars_data JOIN car_makers WHERE cars_data.Year = "value"
extra gold: SELECT DISTINCT T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model JOIN CARS_DATA AS T4 ON T3.MakeId = T4.id WHERE T4.year = '1970';

extra pred: SELECT car_names.Make , cars_data.Year FROM cars_data JOIN car_names ORDER BY cars_data.Year ASC LIMIT 1
extra gold: SELECT T2.Make , T1.Year FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Year = (SELECT min(YEAR) FROM CARS_DATA);

extra pred: SELECT car_makers.Maker , cars_data.Year FROM cars_data JOIN car_names JOIN model_list JOIN car_makers WHERE cars_data.Year = "value" ORDER BY cars_data.Year ASC LIMIT 1
extra gold: SELECT T2.Make , T1.Year FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Year = (SELECT min(YEAR) FROM CARS_DATA);

hard pred: SELECT car_names.Model FROM cars_data JOIN car_names WHERE cars_data.Year > "value"
hard gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.model = T2.model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.id WHERE T3.year > 1980;

medium pred: SELECT COUNT(*) , car_makers.FullName FROM car_makers JOIN model_list GROUP BY car_makers.Maker
medium gold: select count(*) , t2.fullname from model_list as t1 join car_makers as t2 on t1.maker = t2.id group by t2.id;

medium pred: SELECT cars_data.Accelerate FROM cars_data JOIN car_names WHERE car_names.Model = "value" AND car_names.Make = "value"
medium gold: SELECT T1.Accelerate FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Make = 'amc hornet sportabout (sw)';

medium pred: SELECT cars_data.Accelerate FROM cars_data JOIN car_names WHERE car_names.Model = "value" AND car_names.Make = "value"
medium gold: SELECT T1.Accelerate FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Make = 'amc hornet sportabout (sw)';

hard pred: SELECT COUNT(*) FROM model_list JOIN car_makers WHERE car_makers.Country = "value"
hard gold: SELECT count(*) FROM MODEL_LIST AS T1 JOIN CAR_MAKERS AS T2 ON T1.Maker = T2.Id JOIN COUNTRIES AS T3 ON T2.Country = T3.CountryId WHERE T3.CountryName = 'usa';

medium pred: SELECT car_makers.Maker , model_list.Model FROM car_makers JOIN model_list
medium gold: SELECT Maker , Model FROM MODEL_LIST;

medium pred: SELECT countries.CountryName , car_makers.Id FROM car_makers JOIN countries
medium gold: SELECT T1.CountryName , T1.CountryId FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) >= 1;

medium pred: SELECT countries.CountryName , countries.CountryId FROM car_makers JOIN countries
medium gold: SELECT T1.CountryName , T1.CountryId FROM COUNTRIES AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.CountryId HAVING count(*) >= 1;

medium pred: SELECT AVG(cars_data.Weight) , MAX(cars_data.Weight) , cars_data.Year FROM cars_data GROUP BY cars_data.Year
medium gold: SELECT avg(Weight) , YEAR FROM CARS_DATA GROUP BY YEAR;

extra pred: SELECT countries.CountryName FROM car_makers JOIN countries WHERE countries.Continent = "value" GROUP BY car_makers.Country HAVING COUNT(*) >= "value"
extra gold: SELECT T1.CountryName FROM COUNTRIES AS T1 JOIN CONTINENTS AS T2 ON T1.Continent = T2.ContId JOIN CAR_MAKERS AS T3 ON T1.CountryId = T3.Country WHERE T2.Continent = 'europe' GROUP BY T1.CountryName HAVING count(*) >= 3;

extra pred: SELECT countries.CountryName FROM car_makers JOIN countries WHERE countries.Continent = "value" GROUP BY car_makers.Country HAVING COUNT(*) >= "value"
extra gold: SELECT T1.CountryName FROM COUNTRIES AS T1 JOIN CONTINENTS AS T2 ON T1.Continent = T2.ContId JOIN CAR_MAKERS AS T3 ON T1.CountryId = T3.Country WHERE T2.Continent = 'europe' GROUP BY T1.CountryName HAVING count(*) >= 3;

extra pred: SELECT MAX(cars_data.Horsepower) , MAX(car_names.Make) FROM cars_data JOIN car_names WHERE cars_data.Cylinders = "value"
extra gold: SELECT T2.horsepower , T1.Make FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T2.cylinders = 3 ORDER BY T2.horsepower DESC LIMIT 1;

extra pred: SELECT MAX(car_names.Make) , MAX(cars_data.Horsepower) FROM cars_data JOIN car_names WHERE cars_data.Cylinders = "value"
extra gold: SELECT T2.horsepower , T1.Make FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T2.cylinders = 3 ORDER BY T2.horsepower DESC LIMIT 1;

medium pred: SELECT AVG(cars_data.Edispl) FROM cars_data JOIN car_names WHERE car_names.Make = "value"
medium gold: SELECT avg(T2.edispl) FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId = T2.Id WHERE T1.Model = 'volvo';

medium pred: SELECT car_makers.FullName , car_makers.Id , COUNT(*) FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) > "value"
medium gold: SELECT T1.FullName , T1.Id FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) > 3;

extra pred: SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_makers.FullName = "value" OR cars_data.Weight > "value"
extra gold: SELECT DISTINCT T2.Model FROM CAR_NAMES AS T1 JOIN MODEL_LIST AS T2 ON T1.Model = T2.Model JOIN CAR_MAKERS AS T3 ON T2.Maker = T3.Id JOIN CARS_DATA AS T4 ON T1.MakeId = T4.Id WHERE T3.FullName = 'General Motors' OR T4.weight > 3500;

extra pred: SELECT model_list.Model FROM model_list JOIN car_makers JOIN cars_data WHERE car_makers.FullName = "value" OR cars_data.Weight > "value"
extra gold: SELECT DISTINCT T2.Model FROM CAR_NAMES AS T1 JOIN MODEL_LIST AS T2 ON T1.Model = T2.Model JOIN CAR_MAKERS AS T3 ON T2.Maker = T3.Id JOIN CARS_DATA AS T4 ON T1.MakeId = T4.Id WHERE T3.FullName = 'General Motors' OR T4.weight > 3500;

easy pred: SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight < "value" INTERSECT SELECT cars_data.Year FROM cars_data WHERE cars_data.Weight > "value"
easy gold: select distinct year from cars_data where weight between 3000 and 4000;

extra pred: SELECT cars_data.Cylinders FROM cars_data JOIN car_names WHERE cars_data.Accelerate = "value" ORDER BY cars_data.Cylinders ASC LIMIT 1
extra gold: SELECT T1.cylinders FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Model = 'volvo' ORDER BY T1.accelerate ASC LIMIT 1;

extra pred: SELECT cars_data.Cylinders FROM cars_data JOIN car_names WHERE cars_data.Accelerate = "value" ORDER BY cars_data.Cylinders ASC LIMIT 1
extra gold: SELECT T1.cylinders FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T2.Model = 'volvo' ORDER BY T1.accelerate ASC LIMIT 1;

hard pred: SELECT COUNT(*) FROM cars_data WHERE cars_data.Accelerate > ( SELECT MAX(cars_data.Accelerate) FROM cars_data ) ORDER BY cars_data.Horsepower DESC LIMIT 1
hard gold: SELECT COUNT(*) FROM CARS_DATA WHERE Accelerate > ( SELECT Accelerate FROM CARS_DATA ORDER BY Horsepower DESC LIMIT 1 );

hard pred: SELECT COUNT(*) FROM cars_data WHERE cars_data.Accelerate > ( SELECT MAX(cars_data.Accelerate) FROM cars_data ) ORDER BY cars_data.Horsepower DESC LIMIT 1
hard gold: SELECT COUNT(*) FROM CARS_DATA WHERE Accelerate > ( SELECT Accelerate FROM CARS_DATA ORDER BY Horsepower DESC LIMIT 1 );

medium pred: SELECT COUNT(*) FROM ( SELECT car_makers.Country FROM car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" )
medium gold: select count(*) from countries as t1 join car_makers as t2 on t1.countryid = t2.country group by t1.countryid having count(*) > 2

medium pred: SELECT COUNT(*) FROM ( SELECT car_makers.Country FROM car_makers GROUP BY car_makers.Country HAVING COUNT(*) > "value" )
medium gold: select count(*) from countries as t1 join car_makers as t2 on t1.countryid = t2.country group by t1.countryid having count(*) > 2

extra pred: SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders > "value" EXCEPT SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders > "value"
extra gold: SELECT T2.MakeId , T2.Make FROM CARS_DATA AS T1 JOIN CAR_NAMES AS T2 ON T1.Id = T2.MakeId WHERE T1.Horsepower > (SELECT min(Horsepower) FROM CARS_DATA) AND T1.Cylinders <= 3;

extra pred: SELECT car_names.MakeId , car_names.Make FROM cars_data JOIN car_names WHERE cars_data.Cylinders < "value"
extra gold: select t2.makeid , t2.make from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t1.horsepower > (select min(horsepower) from cars_data) and t1.cylinders < 4;

extra pred: SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE cars_data.Weight < "value" EXCEPT SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_makers.FullName = "value"
extra gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.Model = T2.Model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.Id JOIN CAR_MAKERS AS T4 ON T1.Maker = T4.Id WHERE T3.weight < 3500 AND T4.FullName != 'Ford Motor Company';

extra pred: SELECT model_list.Model FROM model_list JOIN car_names WHERE cars_data.Weight < "value" EXCEPT SELECT model_list.Model FROM model_list JOIN car_names JOIN car_makers WHERE car_names.Make = "value"
extra gold: SELECT DISTINCT T1.model FROM MODEL_LIST AS T1 JOIN CAR_NAMES AS T2 ON T1.Model = T2.Model JOIN CARS_DATA AS T3 ON T2.MakeId = T3.Id JOIN CAR_MAKERS AS T4 ON T1.Maker = T4.Id WHERE T3.weight < 3500 AND T4.FullName != 'Ford Motor Company';

hard pred: SELECT countries.CountryName FROM countries WHERE countries.CountryId NOT IN ( SELECT car_makers.Country FROM car_makers )
hard gold: SELECT CountryName FROM countries EXCEPT SELECT T1.CountryName FROM countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.countryId = T2.Country;

hard pred: SELECT countries.CountryName FROM countries WHERE countries.CountryId NOT IN ( SELECT car_makers.Country FROM car_makers )
hard gold: SELECT CountryName FROM countries EXCEPT SELECT T1.CountryName FROM countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.countryId = T2.Country;

extra pred: SELECT car_makers.Id , car_makers.Maker FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) >= "value" INTERSECT SELECT car_makers.Id , car_makers.Maker FROM car_makers JOIN model_list GROUP BY model_list.Maker HAVING COUNT(*) > "value"
extra gold: select t1.id , t1.maker from car_makers as t1 join model_list as t2 on t1.id = t2.maker group by t1.id having count(*) >= 2 intersect select t1.id , t1.maker from car_makers as t1 join model_list as t2 on t1.id = t2.maker join car_names as t3 on t2.model = t3.model group by t1.id having count(*) > 3;

extra pred: SELECT car_makers.Id , car_makers.Maker FROM model_list JOIN car_makers GROUP BY car_makers.Id HAVING COUNT(*) >= "value" INTERSECT SELECT car_makers.Id , car_makers.Maker FROM model_list JOIN car_makers GROUP BY model_list.Maker HAVING COUNT(*) > "value"
extra gold: SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker GROUP BY T1.Id HAVING count(*) >= 2 INTERSECT SELECT T1.Id , T1.Maker FROM CAR_MAKERS AS T1 JOIN MODEL_LIST AS T2 ON T1.Id = T2.Maker JOIN CAR_NAMES AS T3 ON T2.model = T3.model GROUP BY T1.Id HAVING count(*) > 3;

extra pred: SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN countries GROUP BY car_makers.Country HAVING COUNT(*) > "value" UNION SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN model_list JOIN countries JOIN model_list JOIN countries WHERE model_list.Model = "value"
extra gold: SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country GROUP BY T1.countryId HAVING count(*) > 3 UNION SELECT T1.countryId , T1.CountryName FROM Countries AS T1 JOIN CAR_MAKERS AS T2 ON T1.CountryId = T2.Country JOIN MODEL_LIST AS T3 ON T2.Id = T3.Maker WHERE T3.Model = 'fiat';

extra pred: SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN countries GROUP BY car_makers.Country HAVING COUNT(*) > "value" UNION SELECT countries.CountryId , countries.CountryName FROM car_makers JOIN countries WHERE car_makers.Maker = "value"
extra gold: select t1.countryid , t1.countryname from countries as t1 join car_makers as t2 on t1.countryid = t2.country group by t1.countryid having count(*) > 3 union select t1.countryid , t1.countryname from countries as t1 join car_makers as t2 on t1.countryid = t2.country join model_list as t3 on t2.id = t3.maker where t3.model = 'fiat';

hard pred: SELECT COUNT(*) FROM airports JOIN flights WHERE airports.City = "value" AND airports.AirportName = "value"
hard gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode JOIN AIRPORTS AS T3 ON T1.SourceAirport = T3.AirportCode WHERE T2.City = "Ashley" AND T3.City = "Aberdeen"

hard pred: SELECT COUNT(*) FROM airports JOIN flights WHERE airports.City = "value" AND airports.AirportName = "value"
hard gold: SELECT count(*) FROM FLIGHTS AS T1 JOIN AIRPORTS AS T2 ON T1.DestAirport = T2.AirportCode JOIN AIRPORTS AS T3 ON T1.SourceAirport = T3.AirportCode WHERE T2.City = "Ashley" AND T3.City = "Aberdeen"

medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airlines.Airline = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "United Airlines" AND T2.DestAirport = "ASY"

medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airlines.Airline = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "United Airlines" AND T2.DestAirport = "ASY"

medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airlines.Airline = "value" AND airlines.Airline = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "United Airlines" AND T2.SourceAirport = "AHD"

medium pred: SELECT COUNT(*) FROM airports JOIN flights JOIN airlines WHERE airlines.Airline = "value" AND airports.AirportName = "value"
medium gold: SELECT count(*) FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T2.Airline = T1.uid WHERE T1.Airline = "United Airlines" AND T2.SourceAirport = "AHD"

extra pred: SELECT airports.City FROM airports JOIN flights GROUP BY flights.SourceAirport ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.City FROM AIRPORTS AS T1 JOIN FLIGHTS AS T2 ON T1.AirportCode = T2.SourceAirport GROUP BY T1.City ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT airlines.Abbreviation , airlines.Country FROM airlines JOIN flights GROUP BY flights.Airline ORDER BY COUNT(*) ASC LIMIT 1
extra gold: SELECT T1.Abbreviation , T1.Country FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline ORDER BY count(*) LIMIT 1

extra pred: SELECT airlines.Abbreviation , airports.Country FROM airlines JOIN airports JOIN flights JOIN flights GROUP BY airlines.Abbreviation ORDER BY COUNT(*) ASC LIMIT 1
extra gold: SELECT T1.Abbreviation , T1.Country FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline ORDER BY count(*) LIMIT 1

medium pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airlines WHERE flights.SourceAirport = "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "AHD"

medium pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airlines WHERE flights.DestAirport = "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.DestAirport = "AHD"

extra pred: SELECT airlines.Airline FROM airports JOIN flights WHERE flights.SourceAirport = "value" EXCEPT SELECT airlines.Airline FROM airlines WHERE flights.SourceAirport = "value"
extra gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "CVO" EXCEPT SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "APG"

extra pred: SELECT airlines.Airline FROM airports JOIN flights JOIN airports JOIN airlines JOIN flights WHERE flights.SourceAirport = "value" EXCEPT SELECT airlines.Airline FROM airlines WHERE airlines.Airline = "value"
extra gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "CVO" EXCEPT SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline WHERE T2.SourceAirport = "APG"

medium pred: SELECT airlines.Airline FROM airlines JOIN flights GROUP BY airlines.Airline HAVING COUNT(*) >= "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline HAVING count(*) > 10

medium pred: SELECT airlines.Airline FROM flights JOIN airlines GROUP BY airlines.Airline HAVING COUNT(*) >= "value"
medium gold: SELECT T1.Airline FROM AIRLINES AS T1 JOIN FLIGHTS AS T2 ON T1.uid = T2.Airline GROUP BY T1.Airline HAVING count(*) > 10

easy pred: SELECT flights.FlightNo FROM airports JOIN flights WHERE airports.AirportName = "value"
easy gold: SELECT FlightNo FROM FLIGHTS WHERE DestAirport = "APG"

hard pred: SELECT airports.AirportName FROM airports WHERE airports.AirportCode NOT IN ( SELECT flights.SourceAirport FROM flights )
hard gold: SELECT AirportName FROM Airports WHERE AirportCode NOT IN (SELECT SourceAirport FROM Flights UNION SELECT DestAirport FROM Flights)

hard pred: SELECT airports.AirportName FROM airports EXCEPT SELECT airports.AirportName FROM airports JOIN flights
hard gold: SELECT AirportName FROM Airports WHERE AirportCode NOT IN (SELECT SourceAirport FROM Flights UNION SELECT DestAirport FROM Flights)

medium pred: SELECT COUNT(*) , shop.Name FROM hiring JOIN shop GROUP BY hiring.Shop_ID
medium gold: SELECT count(*) , t2.name FROM hiring AS t1 JOIN shop AS t2 ON t1.shop_id = t2.shop_id GROUP BY t2.name

medium pred: SELECT Templates.Version_Number , Templates.Template_Type_Code FROM Templates ORDER BY Templates.Version_Number ASC LIMIT 1
medium gold: SELECT min(Version_Number) , template_type_code FROM Templates

medium pred: SELECT Templates.Version_Number , Templates.Template_Type_Code FROM Templates ORDER BY Templates.Version_Number ASC LIMIT 1
medium gold: SELECT min(Version_Number) , template_type_code FROM Templates

medium pred: SELECT Templates.Template_Type_Code , COUNT(*) FROM Templates GROUP BY Templates.Template_Type_Code
medium gold: SELECT T1.template_type_code , count(*) FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code

extra pred: SELECT Templates.Template_Type_Code FROM Templates GROUP BY Templates.Template_Type_Code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.template_type_code FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Templates.Template_Type_Code FROM Templates GROUP BY Templates.Template_Type_Code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.template_type_code FROM Templates AS T1 JOIN Documents AS T2 ON T1.template_id = T2.template_id GROUP BY T1.template_type_code ORDER BY count(*) DESC LIMIT 1

medium pred: SELECT Paragraphs.Document_ID , COUNT(*) FROM Paragraphs GROUP BY Paragraphs.Document_ID ORDER BY COUNT(*) ASC
medium gold: SELECT document_id , count(*) FROM Paragraphs GROUP BY document_id ORDER BY document_id

medium pred: SELECT Paragraphs.Document_ID , COUNT(*) FROM Paragraphs GROUP BY Paragraphs.Document_ID ORDER BY COUNT(*) ASC
medium gold: SELECT document_id , count(*) FROM Paragraphs GROUP BY document_id ORDER BY document_id

medium pred: SELECT teacher.Name , COUNT(*) FROM course_arrange JOIN teacher GROUP BY course_arrange.Teacher_ID
medium gold: SELECT T2.Name , COUNT(*) FROM course_arrange AS T1 JOIN teacher AS T2 ON T1.Teacher_ID = T2.Teacher_ID GROUP BY T2.Name

medium pred: SELECT teacher.Name , COUNT(*) FROM course_arrange JOIN teacher GROUP BY course_arrange.Teacher_ID
medium gold: SELECT T2.Name , COUNT(*) FROM course_arrange AS T1 JOIN teacher AS T2 ON T1.Teacher_ID = T2.Teacher_ID GROUP BY T2.Name

medium pred: SELECT teacher.Name FROM course_arrange JOIN teacher GROUP BY course_arrange.Teacher_ID HAVING COUNT(*) >= "value"
medium gold: SELECT T2.Name FROM course_arrange AS T1 JOIN teacher AS T2 ON T1.Teacher_ID = T2.Teacher_ID GROUP BY T2.Name HAVING COUNT(*) >= 2

medium pred: SELECT teacher.Name FROM course_arrange JOIN teacher GROUP BY course_arrange.Teacher_ID HAVING COUNT(*) >= "value"
medium gold: SELECT T2.Name FROM course_arrange AS T1 JOIN teacher AS T2 ON T1.Teacher_ID = T2.Teacher_ID GROUP BY T2.Name HAVING COUNT(*) >= 2

medium pred: SELECT visitor.Name FROM visitor WHERE visitor.Level_of_membership > "value" ORDER BY visitor.Name DESC
medium gold: SELECT name FROM visitor WHERE Level_of_membership > 4 ORDER BY Level_of_membership DESC

easy pred: SELECT AVG(visitor.Age) FROM visitor WHERE visitor.Level_of_membership > "value"
easy gold: SELECT avg(age) FROM visitor WHERE Level_of_membership <= 4

medium pred: SELECT visitor.Name , visitor.Level_of_membership FROM visitor WHERE visitor.Age > "value" ORDER BY visitor.Age ASC
medium gold: SELECT name , Level_of_membership FROM visitor WHERE Level_of_membership > 4 ORDER BY age DESC

hard pred: SELECT visitor.Name , visitor.Age FROM visitor JOIN visit GROUP BY visit.visitor_ID ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT t1.name , t1.age FROM visitor AS t1 JOIN visit AS t2 ON t1.id = t2.visitor_id ORDER BY t2.num_of_ticket DESC LIMIT 1

extra pred: SELECT COUNT(*) FROM visitor WHERE visitor.ID NOT IN ( SELECT visit.visitor_ID FROM visit JOIN museum WHERE museum.Open_Year > "value" )
extra gold: SELECT count(*) FROM visitor WHERE id NOT IN (SELECT t2.visitor_id FROM museum AS t1 JOIN visit AS t2 ON t1.Museum_ID = t2.Museum_ID WHERE t1.open_year > 2010)

easy pred: SELECT AVG(matches.winner_rank_points) FROM matches
easy gold: SELECT avg(winner_rank) FROM matches

easy pred: SELECT MAX(matches.loser_rank_points) FROM matches
easy gold: SELECT min(loser_rank) FROM matches

easy pred: SELECT MAX(matches.loser_rank_points) FROM matches
easy gold: SELECT min(loser_rank) FROM matches

easy pred: SELECT matches.tourney_name FROM matches GROUP BY matches.tourney_id HAVING COUNT(*) > "value"
easy gold: SELECT tourney_name FROM matches GROUP BY tourney_name HAVING count(*) > 10

hard pred: SELECT players.last_name FROM players JOIN matches WHERE matches.year = "value" INTERSECT SELECT matches.winner_name FROM players JOIN matches WHERE matches.year = "value"
hard gold: SELECT winner_name FROM matches WHERE YEAR = 2013 INTERSECT SELECT winner_name FROM matches WHERE YEAR = 2016

medium pred: SELECT players.first_name , players.country_code FROM players ORDER BY players.birth_date DESC LIMIT 1
medium gold: SELECT first_name , country_code FROM players ORDER BY birth_date LIMIT 1

medium pred: SELECT players.first_name , players.country_code FROM players ORDER BY players.birth_date DESC LIMIT 1
medium gold: SELECT first_name , country_code FROM players ORDER BY birth_date LIMIT 1

hard pred: SELECT players.first_name , players.country_code FROM players JOIN rankings GROUP BY rankings.player_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT T1.country_code , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id ORDER BY T2.tours DESC LIMIT 1

hard pred: SELECT players.first_name , players.country_code FROM players JOIN rankings GROUP BY rankings.player_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT T1.country_code , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id ORDER BY T2.tours DESC LIMIT 1

hard pred: SELECT matches.winner_name , matches.winner_rank_points FROM matches GROUP BY matches.winner_rank ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT winner_name , winner_rank_points FROM matches GROUP BY winner_name ORDER BY count(*) DESC LIMIT 1

medium pred: SELECT players.first_name , AVG(rankings.ranking_points) FROM players JOIN rankings GROUP BY players.first_name
medium gold: SELECT avg(ranking) , T1.first_name FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id GROUP BY T1.first_name

medium pred: SELECT COUNT(*) , rankings.ranking_date FROM rankings GROUP BY rankings.ranking_date
medium gold: SELECT sum(tours) , ranking_date FROM rankings GROUP BY ranking_date

medium pred: SELECT COUNT(matches.winner_entry) FROM matches WHERE matches.tourney_name = "value" AND matches.winner_hand = "value"
medium gold: SELECT count(DISTINCT winner_name) FROM matches WHERE tourney_name = 'WTA Championships' AND winner_hand = 'L'

medium pred: SELECT COUNT(*) FROM matches WHERE matches.tourney_name = "value" AND matches.winner_hand = "value"
medium gold: SELECT count(DISTINCT winner_name) FROM matches WHERE tourney_name = 'WTA Championships' AND winner_hand = 'L'

medium pred: SELECT battle.name , battle.date , battle.result FROM battle
medium gold: SELECT name , date FROM battle

medium pred: SELECT MAX(death.killed) , MIN(death.killed) , death.caused_by_ship_id FROM death GROUP BY death.caused_by_ship_id
|
||||
medium gold: SELECT max(killed) , min(killed) FROM death
|
||||
|
||||
easy pred: SELECT AVG(death.injured) , * FROM death GROUP BY *
|
||||
easy gold: SELECT avg(injured) FROM death
|
||||
|
||||
hard pred: SELECT battle.id , battle.name FROM battle JOIN death GROUP BY battle.id HAVING SUM(death.killed) > "value"
hard gold: SELECT T1.id , T1.name FROM battle AS T1 JOIN ship AS T2 ON T1.id = T2.lost_in_battle JOIN death AS T3 ON T2.id = T3.caused_by_ship_id GROUP BY T1.id HAVING sum(T3.killed) > 10

extra pred: SELECT ship.id , ship.name FROM ship JOIN death GROUP BY death.caused_by_ship_id ORDER BY SUM(death.injured) DESC LIMIT 1
extra gold: SELECT T2.id , T2.name FROM death AS T1 JOIN ship AS t2 ON T1.caused_by_ship_id = T2.id GROUP BY T2.id ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT battle.name , battle.result , battle.bulgarian_commander FROM battle WHERE battle.id NOT IN ( SELECT ship.lost_in_battle FROM ship WHERE ship.location = "value" )
extra gold: SELECT name , RESULT , bulgarian_commander FROM battle EXCEPT SELECT T1.name , T1.result , T1.bulgarian_commander FROM battle AS T1 JOIN ship AS T2 ON T1.id = T2.lost_in_battle WHERE T2.location = 'English Channel'

medium pred: SELECT Addresses.line_1 FROM Addresses WHERE Addresses.line_2 BETWEEN "value" AND "value"
medium gold: SELECT line_1 , line_2 FROM addresses

easy pred: SELECT COUNT(Degree_Programs.degree_program_id) FROM Degree_Programs
easy gold: SELECT count(DISTINCT degree_summary_name) FROM Degree_Programs

medium pred: SELECT Courses.course_name , Courses.course_id FROM Courses JOIN Sections GROUP BY Sections.course_id HAVING COUNT(*) < "value"
medium gold: SELECT T1.course_name , T1.course_id FROM Courses AS T1 JOIN Sections AS T2 ON T1.course_id = T2.course_id GROUP BY T1.course_id HAVING count(*) <= 2

medium pred: SELECT Students.first_name , Students.middle_name , Students.last_name , COUNT(*) FROM Student_Enrolment JOIN Students GROUP BY Student_Enrolment.semester_id HAVING COUNT(*) = "value"
medium gold: SELECT T1.first_name , T1.middle_name , T1.last_name , T1.student_id FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id GROUP BY T1.student_id HAVING count(*) = 2

medium pred: SELECT Students.first_name , Students.middle_name , Students.last_name FROM Student_Enrolment JOIN Students GROUP BY Student_Enrolment.student_id HAVING COUNT(*) = "value"
medium gold: SELECT T1.first_name , T1.middle_name , T1.last_name , T1.student_id FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id GROUP BY T1.student_id HAVING count(*) = 2

hard pred: SELECT Students.first_name , Students.middle_name , Students.last_name FROM Student_Enrolment JOIN Students WHERE Student_Enrolment.degree_program_id = "value"
hard gold: SELECT DISTINCT T1.first_name , T1.middle_name , T1.last_name FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id JOIN Degree_Programs AS T3 ON T2.degree_program_id = T3.degree_program_id WHERE T3.degree_summary_name = 'Bachelor'

hard pred: SELECT Students.first_name , Students.middle_name , Students.last_name FROM Student_Enrolment JOIN Students WHERE Student_Enrolment.degree_program_id = "value"
hard gold: SELECT DISTINCT T1.first_name , T1.middle_name , T1.last_name FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id JOIN Degree_Programs AS T3 ON T2.degree_program_id = T3.degree_program_id WHERE T3.degree_summary_name = 'Bachelor'

extra pred: SELECT Student_Enrolment.degree_program_id FROM Student_Enrolment GROUP BY Student_Enrolment.degree_program_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.degree_summary_name FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id GROUP BY T1.degree_summary_name ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Student_Enrolment.student_id , Students.first_name , Students.middle_name , Students.last_name FROM Student_Enrolment JOIN Students GROUP BY Student_Enrolment.student_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.student_id , T1.first_name , T1.middle_name , T1.last_name , count(*) , T1.student_id FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Students.first_name , Students.middle_name , COUNT(*) FROM Student_Enrolment JOIN Students GROUP BY Student_Enrolment.student_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.student_id , T1.first_name , T1.middle_name , T1.last_name , count(*) , T1.student_id FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id = T2.student_id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Courses.course_name FROM Courses JOIN Student_Enrolment_Courses GROUP BY Student_Enrolment_Courses.course_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.course_name FROM Courses AS T1 JOIN Student_Enrolment_Courses AS T2 ON T1.course_id = T2.course_id GROUP BY T1.course_name ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Courses.course_name FROM Courses JOIN Student_Enrolment_Courses GROUP BY Student_Enrolment_Courses.course_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.course_name FROM Courses AS T1 JOIN Student_Enrolment_Courses AS T2 ON T1.course_id = T2.course_id GROUP BY T1.course_name ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Addresses.address_id , Addresses.line_1 FROM Addresses JOIN Students GROUP BY Addresses.address_id ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.address_id , T1.line_1 , T1.line_2 FROM Addresses AS T1 JOIN Students AS T2 ON T1.address_id = T2.current_address_id GROUP BY T1.address_id ORDER BY count(*) DESC LIMIT 1

medium pred: SELECT MAX(Transcripts.transcript_date) FROM Transcripts
medium gold: SELECT transcript_date FROM Transcripts ORDER BY transcript_date DESC LIMIT 1

hard pred: SELECT Transcript_Contents.transcript_id , COUNT(*) FROM Transcript_Contents JOIN Student_Enrolment_Courses GROUP BY Student_Enrolment_Courses.student_enrolment_id ORDER BY COUNT(*) DESC LIMIT 1
hard gold: SELECT count(*) , student_course_id FROM Transcript_Contents GROUP BY student_course_id ORDER BY count(*) DESC LIMIT 1

hard pred: SELECT Student_Enrolment_Courses.course_id , COUNT(Transcript_Contents.transcript_id) FROM Student_Enrolment_Courses JOIN Transcript_Contents GROUP BY Student_Enrolment_Courses.course_id ORDER BY COUNT(Transcript_Contents.transcript_id) DESC LIMIT 1
hard gold: SELECT count(*) , student_course_id FROM Transcript_Contents GROUP BY student_course_id ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Semesters.semester_name FROM Student_Enrolment JOIN Semesters JOIN Degree_Programs WHERE Degree_Programs.degree_summary_name = "value" INTERSECT SELECT Semesters.semester_name FROM Student_Enrolment JOIN Semesters JOIN Semesters WHERE Student_Enrolment.student_id = "value"
extra gold: SELECT DISTINCT T2.semester_id FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id WHERE degree_summary_name = 'Master' INTERSECT SELECT DISTINCT T2.semester_id FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id WHERE degree_summary_name = 'Bachelor'

extra pred: SELECT Student_Enrolment.semester_id FROM Student_Enrolment WHERE Student_Enrolment.degree_program_id = "value" INTERSECT SELECT Student_Enrolment.semester_id FROM Student_Enrolment WHERE Student_Enrolment.degree_program_id = "value"
extra gold: SELECT DISTINCT T2.semester_id FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id WHERE degree_summary_name = 'Master' INTERSECT SELECT DISTINCT T2.semester_id FROM Degree_Programs AS T1 JOIN Student_Enrolment AS T2 ON T1.degree_program_id = T2.degree_program_id WHERE degree_summary_name = 'Bachelor'

easy pred: SELECT Students.current_address_id FROM Students
easy gold: SELECT count(DISTINCT current_address_id) FROM Students

medium pred: SELECT COUNT(TV_Channel.series_name) , TV_Channel.Content FROM TV_Channel
medium gold: SELECT count(DISTINCT series_name) , count(DISTINCT content) FROM TV_Channel;

medium pred: SELECT TV_Channel.series_name FROM TV_series JOIN TV_Channel JOIN Cartoon WHERE Cartoon.Title = "value"
medium gold: SELECT T1.series_name FROM TV_Channel AS T1 JOIN Cartoon AS T2 ON T1.id = T2.Channel WHERE T2.Title = "The Rise of the Blue Beetle!";

medium pred: SELECT TV_Channel.series_name FROM TV_series JOIN TV_Channel JOIN Cartoon WHERE Cartoon.Title = "value"
medium gold: SELECT T1.series_name FROM TV_Channel AS T1 JOIN Cartoon AS T2 ON T1.id = T2.Channel WHERE T2.Title = "The Rise of the Blue Beetle!";

medium pred: SELECT Cartoon.Title FROM Cartoon JOIN TV_Channel JOIN TV_series WHERE TV_Channel.series_name = "value"
medium gold: SELECT T2.Title FROM TV_Channel AS T1 JOIN Cartoon AS T2 ON T1.id = T2.Channel WHERE T1.series_name = "Sky Radio";

easy pred: SELECT TV_Channel.id FROM TV_Channel GROUP BY TV_Channel.id HAVING COUNT(*) > "value"
easy gold: SELECT id FROM tv_channel GROUP BY country HAVING count(*) > 2

hard pred: SELECT TV_Channel.Package_Option FROM TV_Channel EXCEPT SELECT TV_Channel.Package_Option FROM TV_Channel JOIN Cartoon WHERE Cartoon.Directed_by = "value"
hard gold: SELECT package_option FROM TV_Channel WHERE id NOT IN (SELECT channel FROM cartoon WHERE directed_by = 'Ben Jones')

hard pred: SELECT TV_Channel.Package_Option FROM TV_Channel EXCEPT SELECT TV_Channel.Package_Option FROM TV_Channel JOIN Cartoon WHERE Cartoon.Directed_by = "value"
hard gold: SELECT package_option FROM TV_Channel WHERE id NOT IN (SELECT channel FROM cartoon WHERE directed_by = 'Ben Jones')

medium pred: SELECT people.Name FROM poker_player JOIN people GROUP BY people.Name ORDER BY COUNT(*) ASC
medium gold: SELECT T1.Name FROM people AS T1 JOIN poker_player AS T2 ON T1.People_ID = T2.People_ID ORDER BY T2.Final_Table_Made

extra pred: SELECT CONTESTANTS.contestant_number , CONTESTANTS.contestant_name FROM CONTESTANTS JOIN VOTES ORDER BY VOTES.contestant_number ASC LIMIT 1
extra gold: SELECT T1.contestant_number , T1.contestant_name FROM contestants AS T1 JOIN votes AS T2 ON T1.contestant_number = T2.contestant_number GROUP BY T1.contestant_number ORDER BY count(*) ASC LIMIT 1

extra pred: SELECT AREA_CODE_STATE.area_code FROM VOTES JOIN CONTESTANTS WHERE CONTESTANTS.contestant_name = "value" INTERSECT SELECT AREA_CODE_STATE.area_code FROM VOTES JOIN CONTESTANTS WHERE CONTESTANTS.contestant_name = "value"
extra gold: SELECT T3.area_code FROM contestants AS T1 JOIN votes AS T2 ON T1.contestant_number = T2.contestant_number JOIN area_code_state AS T3 ON T2.state = T3.state WHERE T1.contestant_name = 'Tabatha Gehling' INTERSECT SELECT T3.area_code FROM contestants AS T1 JOIN votes AS T2 ON T1.contestant_number = T2.contestant_number JOIN area_code_state AS T3 ON T2.state = T3.state WHERE T1.contestant_name = 'Kelly Clauss'

easy pred: SELECT SUM(country.SurfaceArea) FROM country WHERE country.Continent = "value"
easy gold: SELECT sum(SurfaceArea) FROM country WHERE Region = "Caribbean"

easy pred: SELECT country.Continent FROM country WHERE country.LocalName = "value"
easy gold: SELECT Continent FROM country WHERE Name = "Anguilla"

extra pred: SELECT countrylanguage.Language FROM countrylanguage JOIN country WHERE country.Name = "value" GROUP BY countrylanguage.Language ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Aruba" ORDER BY Percentage DESC LIMIT 1

extra pred: SELECT countrylanguage.Language FROM countrylanguage JOIN country WHERE country.Name = "value" GROUP BY countrylanguage.Language ORDER BY COUNT(country.Code) DESC LIMIT 1
extra gold: SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Aruba" ORDER BY Percentage DESC LIMIT 1

medium pred: SELECT country.Population , country.GNP FROM country WHERE country.Continent = "value" ORDER BY country.GNP DESC LIMIT 1
medium gold: SELECT sum(Population) , max(GNP) FROM country WHERE Continent = "Asia"

medium pred: SELECT AVG(country.LifeExpectancy) FROM country WHERE country.Continent = "value" AND country.Name = "value"
medium gold: SELECT avg(LifeExpectancy) FROM country WHERE Continent = "Africa" AND GovernmentForm = "Republic"

medium pred: SELECT AVG(country.LifeExpectancy) FROM country WHERE country.Continent = "value" AND country.Continent = "value"
medium gold: SELECT avg(LifeExpectancy) FROM country WHERE Continent = "Africa" AND GovernmentForm = "Republic"

easy pred: SELECT city.Population FROM city WHERE city.District = "value"
easy gold: SELECT sum(Population) FROM city WHERE District = "Gelderland"

medium pred: SELECT AVG(country.GNP) , AVG(country.Population) FROM country WHERE country.GovernmentForm = "value"
medium gold: SELECT avg(GNP) , sum(population) FROM country WHERE GovernmentForm = "US Territory"

medium pred: SELECT AVG(country.GNP) , AVG(country.Population) FROM country WHERE country.LocalName = "value"
medium gold: SELECT avg(GNP) , sum(population) FROM country WHERE GovernmentForm = "US Territory"

medium pred: SELECT COUNT(countrylanguage.Language) FROM countrylanguage JOIN country WHERE country.Name = "value"
medium gold: SELECT COUNT(*) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Afghanistan" AND IsOfficial = "T"

medium pred: SELECT COUNT(countrylanguage.Language) FROM countrylanguage JOIN country WHERE country.Name = "value"
medium gold: SELECT COUNT(*) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Afghanistan" AND IsOfficial = "T"

extra pred: SELECT country.Name FROM countrylanguage JOIN country GROUP BY countrylanguage.CountryCode ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode GROUP BY T1.Name ORDER BY COUNT(*) DESC LIMIT 1

extra pred: SELECT country.Continent FROM country GROUP BY country.Continent ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.Continent FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode GROUP BY T1.Continent ORDER BY COUNT(*) DESC LIMIT 1

easy pred: SELECT COUNT(*) FROM countrylanguage WHERE countrylanguage.Language = "value" AND countrylanguage.Language = "value"
easy gold: SELECT COUNT(*) FROM (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" INTERSECT SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "Dutch")

easy pred: SELECT COUNT(*) FROM countrylanguage WHERE countrylanguage.Language = "value" AND countrylanguage.Language = "value"
easy gold: SELECT COUNT(*) FROM (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" INTERSECT SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "Dutch")

extra pred: SELECT country.Name FROM country JOIN countrylanguage WHERE countrylanguage.Language = "value" INTERSECT SELECT country.Name FROM country JOIN countrylanguage WHERE countrylanguage.Language = "value"
extra gold: SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND T2.IsOfficial = "T" INTERSECT SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "French" AND T2.IsOfficial = "T"

extra pred: SELECT country.Name FROM country JOIN countrylanguage WHERE countrylanguage.Language = "value" INTERSECT SELECT country.Name FROM country JOIN countrylanguage WHERE countrylanguage.Language = "value"
extra gold: SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND T2.IsOfficial = "T" INTERSECT SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "French" AND T2.IsOfficial = "T"

extra pred: SELECT country.Name FROM countrylanguage JOIN country WHERE countrylanguage.Language = "value" OR countrylanguage.Language = "value"
extra gold: select t1.name from country as t1 join countrylanguage as t2 on t1.code = t2.countrycode where t2.language = "english" and isofficial = "t" union select t1.name from country as t1 join countrylanguage as t2 on t1.code = t2.countrycode where t2.language = "dutch" and isofficial = "t"

extra pred: SELECT country.Name FROM countrylanguage JOIN country WHERE countrylanguage.Language = "value" OR countrylanguage.Language = "value"
extra gold: SELECT * FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND IsOfficial = "T" UNION SELECT * FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "Dutch" AND IsOfficial = "T"

extra pred: SELECT city.Name FROM city WHERE city.Population = "value" ORDER BY city.Population DESC LIMIT 1
extra gold: SELECT T1.Name , T1.Population FROM city AS T1 JOIN countrylanguage AS T2 ON T1.CountryCode = T2.CountryCode WHERE T2.Language = "English" ORDER BY T1.Population DESC LIMIT 1

extra pred: SELECT city.Name FROM city JOIN countrylanguage WHERE countrylanguage.Language = "value" ORDER BY city.Population DESC LIMIT 1
extra gold: SELECT T1.Name , T1.Population FROM city AS T1 JOIN countrylanguage AS T2 ON T1.CountryCode = T2.CountryCode WHERE T2.Language = "English" ORDER BY T1.Population DESC LIMIT 1

hard pred: SELECT country.Name , country.Population , country.LifeExpectancy FROM country WHERE country.SurfaceArea = "value" ORDER BY country.Continent DESC LIMIT 1
hard gold: SELECT Name , Population , LifeExpectancy FROM country WHERE Continent = "Asia" ORDER BY SurfaceArea DESC LIMIT 1

extra pred: SELECT AVG(country.LifeExpectancy) FROM country JOIN countrylanguage WHERE countrylanguage.IsOfficial = "value" AND countrylanguage.Language != "value"
extra gold: SELECT avg(LifeExpectancy) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND T2.IsOfficial = "T")

extra pred: SELECT AVG(country.LifeExpectancy) FROM country JOIN countrylanguage WHERE countrylanguage.IsOfficial = "value" AND countrylanguage.Language != "value"
extra gold: SELECT avg(LifeExpectancy) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" AND T2.IsOfficial = "T")

extra pred: SELECT SUM(country.Population) FROM country JOIN countrylanguage WHERE countrylanguage.Language != "value"
extra gold: SELECT sum(Population) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English")

extra pred: SELECT country.Population FROM country JOIN countrylanguage WHERE countrylanguage.Language != "value"
extra gold: SELECT sum(Population) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English")

medium pred: SELECT countrylanguage.Language FROM countrylanguage JOIN country WHERE country.HeadOfState = "value"
medium gold: SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.HeadOfState = "Beatrix" AND T2.IsOfficial = "T"

medium pred: SELECT countrylanguage.Language FROM countrylanguage JOIN country WHERE country.HeadOfState = "value"
medium gold: SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.HeadOfState = "Beatrix" AND T2.IsOfficial = "T"

medium pred: SELECT COUNT(countrylanguage.Language) FROM countrylanguage JOIN country WHERE country.IndepYear < "value"
medium gold: SELECT count(DISTINCT T2.Language) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE IndepYear < 1930 AND T2.IsOfficial = "T"

medium pred: SELECT COUNT(countrylanguage.Language) FROM countrylanguage JOIN country WHERE country.IndepYear < "value"
medium gold: SELECT count(DISTINCT T2.Language) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE IndepYear < 1930 AND T2.IsOfficial = "T"

hard pred: SELECT country.Name FROM country WHERE country.SurfaceArea > ( SELECT MAX(country.SurfaceArea) FROM country WHERE country.Continent = "value" )
hard gold: SELECT Name FROM country WHERE SurfaceArea > (SELECT min(SurfaceArea) FROM country WHERE Continent = "Europe")

hard pred: SELECT country.Name FROM country WHERE country.SurfaceArea > ( SELECT MAX(country.SurfaceArea) FROM country WHERE country.Continent = "value" )
hard gold: SELECT Name FROM country WHERE SurfaceArea > (SELECT min(SurfaceArea) FROM country WHERE Continent = "Europe")

extra pred: SELECT country.Name FROM country WHERE country.Population < ( SELECT MAX(country.Population) FROM country WHERE country.Continent = "value" )
extra gold: SELECT Name FROM country WHERE Continent = "Africa" AND population < (SELECT max(population) FROM country WHERE Continent = "Asia")

extra pred: SELECT country.Name FROM country WHERE country.Population < ( SELECT MAX(country.Population) FROM country WHERE country.Continent = "value" )
extra gold: SELECT Name FROM country WHERE Continent = "Africa" AND population < (SELECT min(population) FROM country WHERE Continent = "Asia")

extra pred: SELECT country.Name FROM country WHERE country.Population > ( SELECT MAX(country.Population) FROM country WHERE country.Continent = "value" )
extra gold: SELECT Name FROM country WHERE Continent = "Asia" AND population > (SELECT max(population) FROM country WHERE Continent = "Africa")

extra pred: SELECT country.Name FROM country WHERE country.Population > ( SELECT MAX(country.Population) FROM country WHERE country.Continent = "value" )
extra gold: SELECT Name FROM country WHERE Continent = "Asia" AND population > (SELECT min(population) FROM country WHERE Continent = "Africa")

hard pred: SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language != "value"
hard gold: SELECT CountryCode FROM countrylanguage EXCEPT SELECT CountryCode FROM countrylanguage WHERE LANGUAGE = "English"

hard pred: SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language != "value"
hard gold: SELECT CountryCode FROM countrylanguage EXCEPT SELECT CountryCode FROM countrylanguage WHERE LANGUAGE = "English"

hard pred: SELECT country.Code FROM country WHERE country.GovernmentForm = "value" AND country.GovernmentForm != "value"
hard gold: SELECT Code FROM country WHERE GovernmentForm != "Republic" EXCEPT SELECT CountryCode FROM countrylanguage WHERE LANGUAGE = "English"

hard pred: SELECT country.Code FROM country WHERE country.GovernmentForm = "value" EXCEPT SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language != "value"
hard gold: SELECT Code FROM country WHERE GovernmentForm != "Republic" EXCEPT SELECT CountryCode FROM countrylanguage WHERE LANGUAGE = "English"

extra pred: SELECT city.Name FROM city WHERE city.CountryCode = "value" AND city.CountryCode != "value"
extra gold: SELECT DISTINCT T2.Name FROM country AS T1 JOIN city AS T2 ON T2.CountryCode = T1.Code WHERE T1.Continent = 'Europe' AND T1.Name NOT IN (SELECT T3.Name FROM country AS T3 JOIN countrylanguage AS T4 ON T3.Code = T4.CountryCode WHERE T4.IsOfficial = 'T' AND T4.Language = 'English')

extra pred: SELECT city.Name FROM city JOIN country WHERE countrylanguage.IsOfficial = "value" AND country.Continent != "value"
extra gold: SELECT DISTINCT T2.Name FROM country AS T1 JOIN city AS T2 ON T2.CountryCode = T1.Code WHERE T1.Continent = 'Europe' AND T1.Name NOT IN (SELECT T3.Name FROM country AS T3 JOIN countrylanguage AS T4 ON T3.Code = T4.CountryCode WHERE T4.IsOfficial = 'T' AND T4.Language = 'English')

hard pred: SELECT city.Name FROM city JOIN country JOIN countrylanguage WHERE countrylanguage.IsOfficial = "value" AND countrylanguage.Language = "value"
hard gold: select distinct t3.name from country as t1 join countrylanguage as t2 on t1.code = t2.countrycode join city as t3 on t1.code = t3.countrycode where t2.isofficial = 't' and t2.language = 'chinese' and t1.continent = "asia"

hard pred: SELECT city.Name FROM city JOIN country JOIN country JOIN countrylanguage WHERE country.Continent = "value" AND countrylanguage.Language = "value"
hard gold: SELECT DISTINCT T3.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode JOIN city AS T3 ON T1.Code = T3.CountryCode WHERE T2.IsOfficial = 'T' AND T2.Language = 'Chinese' AND T1.Continent = "Asia"

medium pred: SELECT country.Population , country.Name , country.SurfaceArea FROM country ORDER BY country.SurfaceArea DESC LIMIT 1
medium gold: SELECT Name , population , HeadOfState FROM country ORDER BY SurfaceArea DESC LIMIT 1

medium pred: SELECT country.Name , COUNT(*) FROM countrylanguage JOIN country GROUP BY country.Name HAVING COUNT(*) >= "value"
medium gold: SELECT COUNT(T2.Language) , T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode GROUP BY T1.Name HAVING COUNT(*) > 2

medium pred: SELECT country.Name , COUNT(*) FROM countrylanguage JOIN country GROUP BY country.Name HAVING COUNT(*) > "value"
medium gold: SELECT COUNT(T2.Language) , T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode GROUP BY T1.Name HAVING COUNT(*) > 2

medium pred: SELECT country.GovernmentForm , AVG(country.Population) FROM country GROUP BY country.GovernmentForm HAVING AVG(country.LifeExpectancy) > "value"
medium gold: SELECT sum(Population) , GovernmentForm FROM country GROUP BY GovernmentForm HAVING avg(LifeExpectancy) > 72

medium pred: SELECT country.GovernmentForm , SUM(country.Population) FROM country WHERE country.LifeExpectancy > "value" GROUP BY country.GovernmentForm
medium gold: SELECT sum(Population) , GovernmentForm FROM country GROUP BY GovernmentForm HAVING avg(LifeExpectancy) > 72

medium pred: SELECT AVG(country.LifeExpectancy) , AVG(country.Population) , country.Continent FROM country GROUP BY country.Continent HAVING AVG(country.Population) < "value"
medium gold: SELECT sum(Population) , avg(LifeExpectancy) , Continent FROM country GROUP BY Continent HAVING avg(LifeExpectancy) < 72

medium pred: SELECT country.Continent , SUM(country.Population) , SUM(country.LifeExpectancy) FROM country WHERE country.LifeExpectancy < "value" GROUP BY country.Continent
medium gold: SELECT sum(Population) , avg(LifeExpectancy) , Continent FROM country GROUP BY Continent HAVING avg(LifeExpectancy) < 72

medium pred: SELECT country.Name FROM country WHERE country.Continent = "value" AND country.Population > "value"
medium gold: SELECT Name FROM country WHERE continent = "Europe" AND Population = "80000"

medium pred: SELECT country.Name FROM country WHERE country.Continent = "value" AND country.Population > "value"
medium gold: SELECT Name FROM country WHERE continent = "Europe" AND Population = "80000"

hard pred: SELECT SUM(country.Population) , AVG(country.SurfaceArea) FROM country WHERE country.Continent = "value"
hard gold: select sum(population) , avg(surfacearea) from country where continent = "north america" and surfacearea > 3000

hard pred: SELECT SUM(country.Population) , AVG(country.SurfaceArea) FROM country WHERE country.Continent = "value"
hard gold: select sum(population) , avg(surfacearea) from country where continent = "north america" and surfacearea > 3000

medium pred: SELECT countrylanguage.Language , MAX(countrylanguage.Percentage) FROM countrylanguage GROUP BY countrylanguage.Language ORDER BY SUM(countrylanguage.Percentage) DESC LIMIT 1
medium gold: SELECT LANGUAGE , CountryCode , max(Percentage) FROM countrylanguage GROUP BY CountryCode

medium pred: SELECT countrylanguage.CountryCode , SUM(countrylanguage.Language) FROM countrylanguage GROUP BY countrylanguage.CountryCode ORDER BY SUM(countrylanguage.Percentage) DESC LIMIT 1
medium gold: SELECT LANGUAGE , CountryCode , max(Percentage) FROM countrylanguage GROUP BY CountryCode

extra pred: SELECT COUNT(*) FROM countrylanguage WHERE countrylanguage.Language = "value" ORDER BY countrylanguage.Percentage DESC LIMIT 1
extra gold: SELECT count(*) , max(Percentage) FROM countrylanguage WHERE LANGUAGE = "Spanish" GROUP BY CountryCode

extra pred: SELECT COUNT(*) FROM countrylanguage WHERE countrylanguage.Language = "value" GROUP BY * ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT count(*) , max(Percentage) FROM countrylanguage WHERE LANGUAGE = "Spanish" GROUP BY CountryCode

medium pred: SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language = "value" ORDER BY countrylanguage.Percentage DESC LIMIT 1
medium gold: SELECT CountryCode , max(Percentage) FROM countrylanguage WHERE LANGUAGE = "Spanish" GROUP BY CountryCode

medium pred: SELECT countrylanguage.CountryCode FROM countrylanguage WHERE countrylanguage.Language = "value" GROUP BY countrylanguage.CountryCode ORDER BY AVG(countrylanguage.Percentage) DESC LIMIT 1
medium gold: SELECT CountryCode , max(Percentage) FROM countrylanguage WHERE LANGUAGE = "Spanish" GROUP BY CountryCode

medium pred: SELECT COUNT(*) FROM Friend
medium gold: SELECT student_id , count(*) FROM Friend GROUP BY student_id

medium pred: SELECT Highschooler.name , COUNT(*) FROM Highschooler JOIN Friend GROUP BY Highschooler.name
medium gold: SELECT T2.name , count(*) FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id GROUP BY T1.student_id

hard pred: SELECT Highschooler.name FROM Highschooler JOIN Friend WHERE Highschooler.name = "value"
hard gold: SELECT T3.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id JOIN Highschooler AS T3 ON T1.friend_id = T3.id WHERE T2.name = "Kyle"

hard pred: SELECT Friend.friend_id FROM Highschooler JOIN Friend WHERE Highschooler.name = "value"
hard gold: SELECT T3.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id JOIN Highschooler AS T3 ON T1.friend_id = T3.id WHERE T2.name = "Kyle"

medium pred: SELECT COUNT(Friend.friend_id) FROM Highschooler JOIN Friend WHERE Highschooler.name = "value"
medium gold: SELECT count(*) FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id WHERE T2.name = "Kyle"

hard pred: SELECT Highschooler.ID FROM Highschooler INTERSECT SELECT Likes.student_id FROM Likes
hard gold: SELECT student_id FROM Friend INTERSECT SELECT liked_id FROM Likes

hard pred: SELECT Likes.student_id FROM Likes INTERSECT SELECT Likes.liked_id FROM Likes
hard gold: SELECT student_id FROM Friend INTERSECT SELECT liked_id FROM Likes

hard pred: SELECT Highschooler.name FROM Highschooler JOIN Likes INTERSECT SELECT Highschooler.name FROM Highschooler JOIN Likes
hard gold: SELECT T2.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id INTERSECT SELECT T2.name FROM Likes AS T1 JOIN Highschooler AS T2 ON T1.liked_id = T2.id

medium pred: SELECT COUNT(*) FROM Likes GROUP BY Likes.student_id
medium gold: SELECT student_id , count(*) FROM Likes GROUP BY student_id

extra pred: SELECT Highschooler.name FROM Highschooler JOIN Likes ORDER BY Highschooler.name DESC LIMIT 1
extra gold: SELECT T2.name FROM Likes AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Highschooler.name FROM Highschooler JOIN Likes ORDER BY Likes.student_id DESC LIMIT 1
extra gold: SELECT T2.name FROM Likes AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1

hard pred: SELECT Highschooler.name FROM Highschooler JOIN Friend WHERE Highschooler.grade > "value" GROUP BY Friend.student_id ORDER BY COUNT(*) ASC
hard gold: SELECT T2.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id WHERE T2.grade > 5 GROUP BY T1.student_id HAVING count(*) >= 2

hard pred: SELECT Highschooler.name FROM Highschooler JOIN Friend WHERE Highschooler.grade > "value" GROUP BY Friend.student_id ORDER BY COUNT(*) ASC
hard gold: SELECT T2.name FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id WHERE T2.grade > 5 GROUP BY T1.student_id HAVING count(*) >= 2

medium pred: SELECT Likes.liked_id FROM Highschooler JOIN Likes WHERE Highschooler.name = "value"
medium gold: SELECT count(*) FROM Likes AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id WHERE T2.name = "Kyle"

hard pred: SELECT AVG(Highschooler.grade) FROM Highschooler WHERE Highschooler.ID IN ( SELECT Friend.student_id FROM Friend )
hard gold: SELECT avg(grade) FROM Highschooler WHERE id IN (SELECT T1.student_id FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id)

hard pred: SELECT AVG(Highschooler.grade) FROM Highschooler WHERE Highschooler.ID IN ( SELECT Friend.student_id FROM Friend )
hard gold: SELECT avg(grade) FROM Highschooler WHERE id IN (SELECT T1.student_id FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id)

extra pred: SELECT MIN(Highschooler.grade) FROM Highschooler WHERE Highschooler.ID NOT IN ( SELECT Friend.student_id FROM Friend )
extra gold: SELECT min(grade) FROM Highschooler WHERE id NOT IN (SELECT T1.student_id FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id)

extra pred: SELECT MIN(Highschooler.grade) FROM Highschooler WHERE Highschooler.ID NOT IN ( SELECT Friend.student_id FROM Friend )
extra gold: SELECT min(grade) FROM Highschooler WHERE id NOT IN (SELECT T1.student_id FROM Friend AS T1 JOIN Highschooler AS T2 ON T1.student_id = T2.id)

hard pred: SELECT AVG(Dogs.age) FROM Dogs JOIN Treatments
hard gold: SELECT avg(age) FROM Dogs WHERE dog_id IN ( SELECT dog_id FROM Treatments )

hard pred: SELECT AVG(Dogs.age) FROM Treatments JOIN Dogs
hard gold: SELECT avg(age) FROM Dogs WHERE dog_id IN ( SELECT dog_id FROM Treatments )

extra pred: SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments WHERE Professionals.state = "value" UNION SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments GROUP BY Treatments.professional_id HAVING COUNT(*) > "value"
extra gold: SELECT professional_id , last_name , cell_number FROM Professionals WHERE state = 'Indiana' UNION SELECT T1.professional_id , T1.last_name , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) > 2

extra pred: SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments WHERE Professionals.state = "value" UNION SELECT Treatments.professional_id , Professionals.last_name , Professionals.cell_number FROM Professionals JOIN Treatments GROUP BY Professionals.state HAVING COUNT(*) > "value"
extra gold: SELECT professional_id , last_name , cell_number FROM Professionals WHERE state = 'Indiana' UNION SELECT T1.professional_id , T1.last_name , T1.cell_number FROM Professionals AS T1 JOIN Treatments AS T2 ON T1.professional_id = T2.professional_id GROUP BY T1.professional_id HAVING count(*) > 2

hard pred: SELECT Dogs.name FROM Dogs EXCEPT SELECT Dogs.name FROM Dogs JOIN Treatments WHERE Treatments.cost_of_treatment > "value"
hard gold: select name from dogs where dog_id not in ( select dog_id from treatments group by dog_id having sum(cost_of_treatment) > 1000 )

hard pred: SELECT Dogs.name FROM Dogs EXCEPT SELECT Dogs.name FROM Dogs JOIN Treatments WHERE Treatments.cost_of_treatment > "value"
hard gold: select name from dogs where dog_id not in ( select dog_id from treatments group by dog_id having sum(cost_of_treatment) > 1000 )

hard pred: SELECT Professionals.first_name FROM Professionals JOIN Owners EXCEPT SELECT Owners.first_name FROM Owners
hard gold: SELECT first_name FROM Professionals UNION SELECT first_name FROM Owners EXCEPT SELECT name FROM Dogs

hard pred: SELECT Professionals.first_name FROM Professionals JOIN Owners EXCEPT SELECT Owners.first_name FROM Owners
hard gold: SELECT first_name FROM Professionals UNION SELECT first_name FROM Owners EXCEPT SELECT name FROM Dogs

extra pred: SELECT Breeds.breed_name FROM Dogs JOIN Breeds GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.breed_name FROM Breeds AS T1 JOIN Dogs AS T2 ON T1.breed_code = T2.breed_code GROUP BY T1.breed_name ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Breeds.breed_name FROM Dogs JOIN Breeds GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.breed_name FROM Breeds AS T1 JOIN Dogs AS T2 ON T1.breed_code = T2.breed_code GROUP BY T1.breed_name ORDER BY count(*) DESC LIMIT 1

extra pred: SELECT Owners.owner_id , Owners.zip_code FROM Owners JOIN Dogs JOIN Charges GROUP BY Owners.owner_id ORDER BY SUM(Charges.charge_amount) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T1.zip_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id JOIN Treatments AS T3 ON T2.dog_id = T3.dog_id GROUP BY T1.owner_id ORDER BY sum(T3.cost_of_treatment) DESC LIMIT 1

extra pred: SELECT Owners.owner_id , Owners.zip_code FROM Owners JOIN Dogs GROUP BY Owners.owner_id ORDER BY SUM(Dogs.date_adopted) DESC LIMIT 1
extra gold: SELECT T1.owner_id , T1.zip_code FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id JOIN Treatments AS T3 ON T2.dog_id = T3.dog_id GROUP BY T1.owner_id ORDER BY sum(T3.cost_of_treatment) DESC LIMIT 1

extra pred: SELECT Dogs.name , Treatments.date_of_treatment FROM Treatments JOIN Dogs GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.name , T2.date_of_treatment FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id WHERE T1.breed_code = ( SELECT breed_code FROM Dogs GROUP BY breed_code ORDER BY count(*) ASC LIMIT 1 )

extra pred: SELECT Dogs.name , Treatments.date_of_treatment FROM Dogs JOIN Treatments GROUP BY Dogs.breed_code ORDER BY COUNT(*) DESC LIMIT 1
extra gold: SELECT T1.name , T2.date_of_treatment FROM Dogs AS T1 JOIN Treatments AS T2 ON T1.dog_id = T2.dog_id WHERE T1.breed_code = ( SELECT breed_code FROM Dogs GROUP BY breed_code ORDER BY count(*) ASC LIMIT 1 )

extra pred: SELECT Owners.last_name FROM Owners JOIN Dogs ORDER BY Dogs.age ASC LIMIT 1
extra gold: SELECT T1.last_name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id WHERE T2.age = ( SELECT max(age) FROM Dogs )

extra pred: SELECT Owners.last_name FROM Owners JOIN Dogs ORDER BY Dogs.age ASC LIMIT 1
extra gold: SELECT T1.last_name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id = T2.owner_id WHERE T2.age = ( SELECT max(age) FROM Dogs )

medium pred: SELECT Charges.charge_type , SUM(Charges.charge_amount) FROM Charges GROUP BY Charges.charge_type
medium gold: SELECT charge_type , charge_amount FROM Charges

medium pred: SELECT Charges.charge_type , SUM(Charges.charge_amount) FROM Charges GROUP BY Charges.charge_type
medium gold: SELECT charge_type , charge_amount FROM Charges

easy pred: SELECT Charges.charge_type FROM Charges ORDER BY Charges.charge_amount DESC LIMIT 1
easy gold: SELECT max(charge_amount) FROM Charges

easy pred: SELECT Charges.charge_amount FROM Charges ORDER BY Charges.charge_type DESC LIMIT 1
easy gold: SELECT max(charge_amount) FROM Charges

medium pred: SELECT Sizes.size_code , Sizes.size_code FROM Sizes
medium gold: SELECT DISTINCT breed_code , size_code FROM dogs

medium pred: SELECT singer.Name FROM singer JOIN song GROUP BY song.Singer_ID HAVING COUNT(*) > "value"
medium gold: SELECT T1.Name FROM singer AS T1 JOIN song AS T2 ON T1.Singer_ID = T2.Singer_ID GROUP BY T1.Name HAVING COUNT(*) > 1

medium pred: SELECT Ref_Property_Types.property_type_description FROM Ref_Property_Types JOIN Properties ORDER BY Properties.property_type_code DESC LIMIT 1
medium gold: SELECT T2.property_type_description FROM Properties AS T1 JOIN Ref_Property_Types AS T2 ON T1.property_type_code = T2.property_type_code GROUP BY T1.property_type_code

hard pred: SELECT Properties.property_name FROM Properties WHERE Properties.property_type_code = "value" UNION SELECT Properties.property_name FROM Properties WHERE Properties.room_count > "value"
hard gold: SELECT property_name FROM Properties WHERE property_type_code = "House" UNION SELECT property_name FROM Properties WHERE property_type_code = "Apartment" AND room_count > 1

                     easy      medium    hard      extra     all
count                248       446       174       166       1034

====================== EXACT MATCHING ACCURACY =====================
exact match          0.927     0.796     0.684     0.512     0.763

---------------------PARTIAL MATCHING ACCURACY----------------------
select               0.964     0.904     0.977     0.898     0.929
select(no AGG)       0.976     0.933     0.977     0.916     0.948
where                0.936     0.876     0.719     0.640     0.816
where(no OP)         0.955     0.887     0.809     0.697     0.852
group(no Having)     0.857     0.885     0.857     0.810     0.857
group                0.857     0.831     0.810     0.797     0.820
order                0.917     0.817     0.820     0.788     0.817
and/or               1.000     0.986     0.977     0.915     0.977
IUEN                 0.000     0.000     0.667     0.500     0.582
keywords             0.961     0.950     0.878     0.819     0.913

---------------------- PARTIAL MATCHING RECALL ----------------------
select               0.964     0.904     0.977     0.898     0.929
select(no AGG)       0.976     0.933     0.977     0.916     0.948
where                0.954     0.896     0.681     0.606     0.810
where(no OP)         0.972     0.907     0.766     0.660     0.845
group(no Having)     0.900     0.865     0.923     0.810     0.860
group                0.900     0.812     0.872     0.797     0.823
order                1.000     0.893     0.909     0.848     0.892
and/or               0.992     0.995     0.994     0.974     0.991
IUEN                 0.000     0.000     0.667     0.529     0.605
keywords             0.993     0.950     0.868     0.819     0.916

---------------------- PARTIAL MATCHING F1 --------------------------
select               0.964     0.904     0.977     0.898     0.929
select(no AGG)       0.976     0.933     0.977     0.916     0.948
where                0.945     0.886     0.699     0.623     0.813
where(no OP)         0.963     0.897     0.787     0.678     0.849
group(no Having)     0.878     0.875     0.889     0.810     0.858
group                0.878     0.821     0.840     0.797     0.821
order                0.957     0.854     0.862     0.817     0.853
and/or               0.996     0.991     0.985     0.943     0.984
IUEN                 1.000     1.000     0.667     0.514     0.594
keywords             0.977     0.950     0.873     0.819     0.914
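
The block above is the standard Spider evaluation report: every dev example is assigned a hardness bucket (easy/medium/hard/extra), and exact match checks whether the predicted query matches the gold query after both are decomposed into canonical components, so the casing and aliasing differences visible in the pred/gold pairs do not by themselves count as errors. As a rough illustration of how the exact-match row aggregates per bucket, here is a minimal Python sketch over hypothetical `(hardness, pred, gold)` records parsed from a log like this one; the real numbers come from `evaluation.py --etype match`, which performs proper canonicalization rather than the naive string comparison used below:

```python
from collections import Counter

def exact_match_by_hardness(records):
    """records: iterable of (hardness, pred_sql, gold_sql) triples.

    Illustrative only: the official Spider evaluator decomposes both
    queries into component sets before comparing, so the plain string
    equality used here (for brevity) under-counts true matches.
    """
    total, correct = Counter(), Counter()
    for hardness, pred, gold in records:
        total[hardness] += 1
        correct[hardness] += int(pred.strip().lower() == gold.strip().lower())
    scores = {h: correct[h] / total[h]
              for h in ("easy", "medium", "hard", "extra") if total[h]}
    if total:
        scores["all"] = sum(correct.values()) / sum(total.values())
    return scores
```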
File diff suppressed because it is too large

@ -0,0 +1,9 @@
#!/bin/bash
task=evaluation
read_model_path=saved_models/$1
batch_size=20
beam_size=5
device=0

python scripts/text2sql.py --task $task --testing --read_model_path $read_model_path \
    --batch_size $batch_size --beam_size $beam_size --device $device

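Assuming this 9-line hunk is the `run/run_evaluation.sh` referenced elsewhere in this commit, it decodes the dev set with an already-trained checkpoint selected by directory name, e.g. `sh run/run_evaluation.sh electra-msde-75.3` for a checkpoint stored under `saved_models/electra-msde-75.3` (that directory name is shown only as an example).
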
@ -0,0 +1,58 @@
task=lgesql_large
seed=999
device=0
testing='' #'--testing'
read_model_path='saved_models/electra-msde-75.3/'

model=lgesql
output_model=with_pruning # without_pruning
local_and_nonlocal=$1 # mmc, msde, local
plm=$2
subword_aggregation=attentive-pooling
schema_aggregation=head+tail
gnn_hidden_size=512
gnn_num_layers=8
relation_share_heads='--relation_share_heads'
score_function='affine'
num_heads=8
dropout=0.2
attn_drop=0.0
drop_connect=0.2

lstm=onlstm
chunk_size=8
att_vec_size=512
sep_cxt=''
lstm_hidden_size=512
lstm_num_layers=1
action_embed_size=128
field_embed_size=64
type_embed_size=64
no_context_feeding='--no_context_feeding'
no_parent_production_embed=''
no_parent_field_embed=''
no_parent_field_type_embed=''
no_parent_state=''

batch_size=20
grad_accumulate=5
lr=1e-4
layerwise_decay=0.8
l2=0.1
smoothing=0.15
warmup_ratio=0.1
lr_schedule=linear
eval_after_epoch=1
max_epoch=200
max_norm=5
beam_size=5

python scripts/text2gan.py --task $task --seed $seed --device $device $testing --read_model_path $read_model_path \
    --plm $plm --gnn_hidden_size $gnn_hidden_size --dropout $dropout --attn_drop $attn_drop --att_vec_size $att_vec_size \
    --model $model --output_model $output_model --local_and_nonlocal $local_and_nonlocal --score_function $score_function $relation_share_heads \
    --subword_aggregation $subword_aggregation --schema_aggregation $schema_aggregation --gnn_num_layers $gnn_num_layers --num_heads $num_heads $sep_cxt \
    --lstm $lstm --chunk_size $chunk_size --drop_connect $drop_connect --lstm_hidden_size $lstm_hidden_size --lstm_num_layers $lstm_num_layers \
    --action_embed_size $action_embed_size --field_embed_size $field_embed_size --type_embed_size $type_embed_size \
    $no_context_feeding $no_parent_production_embed $no_parent_field_embed $no_parent_field_type_embed $no_parent_state \
    --batch_size $batch_size --grad_accumulate $grad_accumulate --lr $lr --l2 $l2 --warmup_ratio $warmup_ratio --lr_schedule $lr_schedule --eval_after_epoch $eval_after_epoch \
    --smoothing $smoothing --layerwise_decay $layerwise_decay --max_epoch $max_epoch --max_norm $max_norm --beam_size $beam_size

@ -0,0 +1,56 @@
task=lgesql_glove
seed=999
device=0
testing='' #'--testing'
read_model_path=''

model=lgesql
output_model=with_pruning # without_pruning
local_and_nonlocal=$1 # mmc, msde, local
embed_size=300
schema_aggregation=head+tail
gnn_hidden_size=256
gnn_num_layers=8
relation_share_heads='' # '--relation_share_heads'
score_function='affine'
num_heads=8
dropout=0.2
attn_drop=0.0
drop_connect=0.2

lstm=onlstm
chunk_size=8
att_vec_size=512
sep_cxt=''
lstm_hidden_size=512
lstm_num_layers=1
action_embed_size=128
field_embed_size=64
type_embed_size=64
no_context_feeding='--no_context_feeding'
no_parent_production_embed=''
no_parent_field_embed=''
no_parent_field_type_embed=''
no_parent_state=''

batch_size=20
grad_accumulate=2
lr=5e-4
l2=1e-4
smooth=0.15
warmup_ratio=0.1
lr_schedule=linear
eval_after_epoch=60
max_epoch=100
max_norm=5
beam_size=5

python scripts/text2sql.py --task $task --seed $seed --device $device $testing $read_model_path \
    --gnn_hidden_size $gnn_hidden_size --dropout $dropout --attn_drop $attn_drop --att_vec_size $att_vec_size \
    --model $model --output_model $output_model --local_and_nonlocal $local_and_nonlocal --score_function $score_function $relation_share_heads \
    --schema_aggregation $schema_aggregation --embed_size $embed_size --gnn_num_layers $gnn_num_layers --num_heads $num_heads $sep_cxt \
    --lstm $lstm --chunk_size $chunk_size --drop_connect $drop_connect --lstm_hidden_size $lstm_hidden_size --lstm_num_layers $lstm_num_layers \
    --action_embed_size $action_embed_size --field_embed_size $field_embed_size --type_embed_size $type_embed_size \
    $no_context_feeding $no_parent_production_embed $no_parent_field_embed $no_parent_field_type_embed $no_parent_state \
    --batch_size $batch_size --grad_accumulate $grad_accumulate --lr $lr --l2 $l2 --warmup_ratio $warmup_ratio --lr_schedule $lr_schedule --eval_after_epoch $eval_after_epoch \
    --smooth $smooth --max_epoch $max_epoch --max_norm $max_norm --beam_size $beam_size

@ -0,0 +1,58 @@
task=lgesql_large
seed=999
device=0
testing='' #'--testing'
read_model_path=''

model=lgesql
output_model=with_pruning # without_pruning
local_and_nonlocal=$1 # mmc, msde, local
plm=$2
subword_aggregation=attentive-pooling
schema_aggregation=head+tail
gnn_hidden_size=512
gnn_num_layers=8
relation_share_heads='--relation_share_heads'
score_function='affine'
num_heads=8
dropout=0.2
attn_drop=0.0
drop_connect=0.2

lstm=onlstm
chunk_size=8
att_vec_size=512
sep_cxt=''
lstm_hidden_size=512
lstm_num_layers=1
action_embed_size=128
field_embed_size=64
type_embed_size=64
no_context_feeding='--no_context_feeding'
no_parent_production_embed=''
no_parent_field_embed=''
no_parent_field_type_embed=''
no_parent_state=''

batch_size=20
grad_accumulate=5
lr=1e-4
layerwise_decay=0.8
l2=0.1
smoothing=0.15
warmup_ratio=0.1
lr_schedule=linear
eval_after_epoch=120
max_epoch=200
max_norm=5
beam_size=5

python scripts/text2sql.py --task $task --seed $seed --device $device $testing $read_model_path \
    --plm $plm --gnn_hidden_size $gnn_hidden_size --dropout $dropout --attn_drop $attn_drop --att_vec_size $att_vec_size \
    --model $model --output_model $output_model --local_and_nonlocal $local_and_nonlocal --score_function $score_function $relation_share_heads \
    --subword_aggregation $subword_aggregation --schema_aggregation $schema_aggregation --gnn_num_layers $gnn_num_layers --num_heads $num_heads $sep_cxt \
    --lstm $lstm --chunk_size $chunk_size --drop_connect $drop_connect --lstm_hidden_size $lstm_hidden_size --lstm_num_layers $lstm_num_layers \
    --action_embed_size $action_embed_size --field_embed_size $field_embed_size --type_embed_size $type_embed_size \
    $no_context_feeding $no_parent_production_embed $no_parent_field_embed $no_parent_field_type_embed $no_parent_state \
    --batch_size $batch_size --grad_accumulate $grad_accumulate --lr $lr --l2 $l2 --warmup_ratio $warmup_ratio --lr_schedule $lr_schedule --eval_after_epoch $eval_after_epoch \
    --smoothing $smoothing --layerwise_decay $layerwise_decay --max_epoch $max_epoch --max_norm $max_norm --beam_size $beam_size

@ -0,0 +1,20 @@
|
|||
#!/bin/bash

train_data='data/train.json'
dev_data='data/dev.json'
table_data='data/tables.json'
train_out='data/train.lgesql.bin'
dev_out='data/dev.lgesql.bin'
table_out='data/tables.bin'
vocab_glove='pretrained_models/glove.42b.300d/vocab_glove.txt'
vocab='pretrained_models/glove.42b.300d/vocab.txt'

#echo "Start to preprocess the original train dataset ..."
#python3 -u preprocess/process_dataset.py --dataset_path ${train_data} --raw_table_path ${table_data} --table_path ${table_out} --output_path 'data/train.bin' --skip_large #--verbose > train.log
echo "Start to preprocess the original dev dataset ..."
python3 -u preprocess/process_dataset.py --dataset_path ${dev_data} --raw_table_path ${table_data} --table_path ${table_out} --output_path 'data/dev.bin' #--verbose > dev.log
#echo "Start to build word vocab for the dataset ..."
#python3 -u preprocess/build_glove_vocab.py --data_paths 'data/train.bin' --table_path ${table_out} --reference_file ${vocab_glove} --mwf 4 --output_path ${vocab}
echo "Start to construct graphs for the dataset ..."
#python3 -u preprocess/process_graphs.py --dataset_path 'data/train.bin' --table_path ${table_out} --method 'lgesql' --output_path ${train_out}
python3 -u preprocess/process_graphs.py --dataset_path 'data/dev.bin' --table_path ${table_out} --method 'lgesql' --output_path ${dev_out}
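
# NOTE: the train-split commands above are left commented out; uncomment them
# (and the glove vocab step if training without a PLM) to preprocess the full dataset.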

@ -0,0 +1,7 @@
saved_model=saved_models/$1
output_path=saved_models/$1/predicted_sql.txt
batch_size=10
beam_size=5

python eval.py --db_dir 'data/database' --table_path 'data/tables.json' --dataset_path 'data/dev.json' --saved_model $saved_model --output_path $output_path --batch_size $batch_size --beam_size $beam_size
python evaluation.py --gold data/dev_gold.sql --pred $output_path --db data/database --table data/tables.json --etype match > $saved_model/evaluation.log
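
# Usage sketch (directory name assumed): sh run/run_evaluation.sh <exp_name>
# where saved_models/<exp_name> holds a trained checkpoint; the second command
# then scores the decoded SQL with the exact-set-match evaluator (--etype match).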

@ -0,0 +1,124 @@
#coding=utf8
import sys, os, time, json, gc
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from argparse import Namespace
from utils.args import init_args
from utils.hyperparams import hyperparam_path
from utils.initialization import *
from utils.example import Example
from utils.batch import Batch
from utils.optimization import set_optimizer
from model.model_utils import Registrable
from model.model_constructor import *

# initialization: params, output path, logger, random seed and torch.device
args = init_args(sys.argv[1:])
exp_path = hyperparam_path(args)
logger = set_logger(exp_path, args.testing)
set_random_seed(args.seed)
device = set_torch_device(args.device)
logger.info("Initialization finished ...")
logger.info("Output path is %s" % (exp_path))
logger.info("Random seed is set to %d" % (args.seed))
logger.info("Use GPU with index %s" % (args.device) if args.device >= 0 else "Use CPU as target torch device")

# load dataset and vocabulary
start_time = time.time()
if args.read_model_path:
    params = json.load(open(os.path.join(args.read_model_path, 'params.json')), object_hook=lambda d: Namespace(**d))
    params.lazy_load = True
else:
    params = args
# set up the grammar, transition system, evaluator, etc.
Example.configuration(plm=params.plm, method=params.model)
train_dataset, dev_dataset = Example.load_dataset('train'), Example.load_dataset('dev')
logger.info("Load dataset and database finished, cost %.4fs ..." % (time.time() - start_time))
logger.info("Dataset size: train -> %d ; dev -> %d" % (len(train_dataset), len(dev_dataset)))
sql_trans, evaluator = Example.trans, Example.evaluator
args.word_vocab, args.relation_num = len(Example.word_vocab), len(Example.relation_vocab)

# model init, set optimizer
model = Registrable.by_name('text2sql')(params, sql_trans).to(device)
if args.read_model_path:
    check_point = torch.load(open(os.path.join(args.read_model_path, 'model.bin'), 'rb'), map_location=device)
    model.load_state_dict(check_point['model'])
    logger.info("Load saved model from path: %s" % (args.read_model_path))
else:
    json.dump(vars(params), open(os.path.join(exp_path, 'params.json'), 'w'), indent=4)
    if params.plm is None:
        ratio = Example.word2vec.load_embeddings(model.encoder.input_layer.word_embed, Example.word_vocab, device=device)
        logger.info("Init model and word embedding layer with a coverage %.2f" % (ratio))
# logger.info(str(model))

def decode(choice, output_path, acc_type='sql', use_checker=False):
    assert acc_type in ['beam', 'ast', 'sql'] and choice in ['train', 'dev']
    model.eval()
    dataset = train_dataset if choice == 'train' else dev_dataset
    all_hyps = []
    with torch.no_grad():
        for i in range(0, len(dataset), args.batch_size):
            current_batch = Batch.from_example_list(dataset[i: i + args.batch_size], device, train=False)
            hyps = model.parse(current_batch, args.beam_size)
            all_hyps.extend(hyps)
    acc = evaluator.acc(all_hyps, dataset, output_path, acc_type=acc_type, etype='match', use_checker=use_checker)
    torch.cuda.empty_cache()
    gc.collect()
    return acc

if not args.testing:
    num_training_steps = ((len(train_dataset) + args.batch_size - 1) // args.batch_size) * args.max_epoch
    num_warmup_steps = int(num_training_steps * args.warmup_ratio)
    logger.info('Total training steps: %d;\t Warmup steps: %d' % (num_training_steps, num_warmup_steps))
    optimizer, scheduler = set_optimizer(model, args, num_warmup_steps, num_training_steps)
    start_epoch, nsamples, best_result = 0, len(train_dataset), {'dev_acc': 0.}
    train_index, step_size = np.arange(nsamples), args.batch_size // args.grad_accumulate
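    # e.g. with the large-run script values (batch_size=20, grad_accumulate=5),
    # each micro-batch holds step_size = 20 // 5 = 4 examples and the optimizer
    # steps once every 5 backward passes, keeping the effective batch size at 20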
    if args.read_model_path and args.load_optimizer:
        optimizer.load_state_dict(check_point['optim'])
        scheduler.load_state_dict(check_point['scheduler'])
        start_epoch = check_point['epoch'] + 1
    logger.info('Start training ......')
    for i in range(start_epoch, args.max_epoch):
        start_time = time.time()
        epoch_loss, epoch_gp_loss, count = 0, 0, 0
        np.random.shuffle(train_index)
        model.train()
        for j in range(0, nsamples, step_size):
            count += 1
            cur_dataset = [train_dataset[k] for k in train_index[j: j + step_size]]
            current_batch = Batch.from_example_list(cur_dataset, device, train=True, smoothing=args.smoothing)
            loss, gp_loss = model(current_batch) # see utils/batch.py for batch elements
            epoch_loss += loss.item()
            epoch_gp_loss += gp_loss.item()
            loss += gp_loss
            loss.backward()
            if count == args.grad_accumulate or j + step_size >= nsamples:
                count = 0
                model.pad_embedding_grad_zero()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        logger.info('Training: \tEpoch: %d\tTime: %.4f\tTraining loss: %.4f/%.4f' % (i, time.time() - start_time, epoch_loss, epoch_gp_loss))
        torch.cuda.empty_cache()
        gc.collect()

        if i < args.eval_after_epoch: # avoid unnecessary evaluation
            continue

        start_time = time.time()
        dev_acc = decode('dev', os.path.join(exp_path, 'dev.iter' + str(i)), acc_type='sql')
        logger.info('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f' % (i, time.time() - start_time, dev_acc))
        if dev_acc > best_result['dev_acc']:
            best_result['dev_acc'], best_result['iter'] = dev_acc, i
            torch.save({
                'epoch': i, 'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }, open(os.path.join(exp_path, 'model.bin'), 'wb'))
            logger.info('NEW BEST MODEL: \tEpoch: %d\tDev acc: %.4f' % (i, dev_acc))

    logger.info('FINAL BEST RESULT: \tEpoch: %d\tDev acc: %.4f' % (best_result['iter'], best_result['dev_acc']))
else:
    start_time = time.time()
    dev_acc = decode('dev', output_path=os.path.join(args.read_model_path, 'dev.eval'), acc_type='sql')
    dev_acc_checker = decode('dev', output_path=os.path.join(args.read_model_path, 'dev.eval.checker'), acc_type='sql', use_checker=True)
    logger.info("Evaluation costs %.2fs ; Dev dataset exact match/checker is %.4f/%.4f ." % (time.time() - start_time, dev_acc, dev_acc_checker))

@ -0,0 +1,14 @@
conda create -n text2sql python=3.6
source activate text2sql
pip install torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
python -c "import stanza; stanza.download('en')"
python -c "from embeddings import GloveEmbedding; emb = GloveEmbedding('common_crawl_48', d_emb=300)"
python -c "import nltk; nltk.download('stopwords')"
mkdir -p pretrained_models && cd pretrained_models
git lfs install
git clone https://huggingface.co/bert-large-uncased-whole-word-masking
git clone https://huggingface.co/google/electra-large-discriminator
mkdir -p glove.42b.300d && cd glove.42b.300d
wget -c http://nlp.stanford.edu/data/glove.42B.300d.zip && unzip glove.42B.300d.zip
awk -v FS=' ' '{print $1}' glove.42B.300d.txt > vocab_glove.txt
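# each line of glove.42B.300d.txt is "<word> <300 floats>"; printing field 1
# keeps just the words, which becomes the GloVe vocabulary file used above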
Binary file not shown. (new image, 315 KiB)

@ -0,0 +1,84 @@
#coding=utf-8
import argparse
import sys

def init_args(params=sys.argv[1:]):
    arg_parser = argparse.ArgumentParser()
    arg_parser = add_argument_base(arg_parser)
    arg_parser = add_argument_encoder(arg_parser)
    arg_parser = add_argument_decoder(arg_parser)
    opt = arg_parser.parse_args(params)
    if opt.model == 'rgatsql' and opt.local_and_nonlocal == 'msde':
        opt.local_and_nonlocal = 'global'
    if opt.model == 'lgesql' and opt.local_and_nonlocal == 'global':
        opt.local_and_nonlocal = 'msde'
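    # the two encoders name the same mixed local/global view differently:
    # rgatsql calls it 'global' while lgesql calls it 'msde', so either spelling
    # is accepted and rewritten above (an inference from the mapping itself,
    # not documented elsewhere in the repo)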
    return opt

def add_argument_base(arg_parser):
    #### General configuration ####
    arg_parser.add_argument('--task', default='text2sql', help='task name')
    arg_parser.add_argument('--seed', default=999, type=int, help='Random seed')
    arg_parser.add_argument('--device', type=int, default=0, help='Use which device: -1 -> cpu ; the index of gpu o.w.')
    arg_parser.add_argument('--testing', action='store_true', help='training or evaluation mode')
    arg_parser.add_argument('--read_model_path', type=str, help='read pretrained model path')
    #### Training Hyperparams ####
    arg_parser.add_argument('--batch_size', default=20, type=int, help='Batch size')
    arg_parser.add_argument('--grad_accumulate', default=1, type=int, help='accumulate grad and update once every x steps')
    arg_parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
    arg_parser.add_argument('--layerwise_decay', type=float, default=1.0, help='layerwise decay rate for lr, used for PLM')
    arg_parser.add_argument('--l2', type=float, default=1e-4, help='weight decay coefficient')
    arg_parser.add_argument('--warmup_ratio', type=float, default=0.1, help='warmup steps proportion')
    arg_parser.add_argument('--lr_schedule', default='linear', choices=['constant', 'linear', 'ratsql', 'cosine'], help='lr scheduler')
    arg_parser.add_argument('--eval_after_epoch', default=40, type=int, help='Start to evaluate after x epoch')
    arg_parser.add_argument('--load_optimizer', action='store_true', default=False, help='Whether to load optimizer state')
    arg_parser.add_argument('--max_epoch', type=int, default=100, help='terminate after maximum epochs')
    arg_parser.add_argument('--max_norm', default=5., type=float, help='clip gradients')
    return arg_parser

def add_argument_encoder(arg_parser):
    # Encoder Hyperparams
    arg_parser.add_argument('--model', choices=['rgatsql', 'lgesql'], default='lgesql', help='which text2sql model to use')
    arg_parser.add_argument('--local_and_nonlocal', choices=['mmc', 'msde', 'local', 'global'], default='mmc',
        help='how to integrate local and non-local relations: mmc -> multi-head multi-view concatenation ; msde -> mixed static and dynamic embeddings')
    arg_parser.add_argument('--output_model', choices=['without_pruning', 'with_pruning'], default='without_pruning', help='whether to add graph pruning')
    arg_parser.add_argument('--plm', type=str, choices=['bert-base-uncased', 'bert-large-uncased', 'bert-large-uncased-whole-word-masking',
        'roberta-base', 'roberta-large', 'grappa_large_jnt', 'electra-base-discriminator', 'electra-large-discriminator'], help='pretrained model name')
    arg_parser.add_argument('--subword_aggregation', choices=['mean-pooling', 'max-pooling', 'attentive-pooling'], default='attentive-pooling', help='aggregate subword feats from PLM')
    arg_parser.add_argument('--schema_aggregation', choices=['mean-pooling', 'max-pooling', 'attentive-pooling', 'head+tail'], default='head+tail', help='aggregate schema word feats')
    arg_parser.add_argument('--dropout', type=float, default=0.2, help='feature dropout rate')
    arg_parser.add_argument('--attn_drop', type=float, default=0., help='dropout rate of attention weights')
    arg_parser.add_argument('--embed_size', default=300, type=int, help='size of word embeddings, only used with glove.42B.300d')
    arg_parser.add_argument('--gnn_num_layers', default=8, type=int, help='num of GNN layers in encoder')
    arg_parser.add_argument('--gnn_hidden_size', default=256, type=int, help='size of GNN layers hidden states')
    arg_parser.add_argument('--num_heads', default=8, type=int, help='num of heads in multihead attn')
    arg_parser.add_argument('--relation_share_layers', action='store_true')
    arg_parser.add_argument('--relation_share_heads', action='store_true')
    arg_parser.add_argument('--score_function', choices=['affine', 'bilinear', 'biaffine', 'dot'], default='affine', help='graph pruning score function')
    arg_parser.add_argument('--smoothing', type=float, default=0.15, help='label smoothing factor for graph pruning')
    return arg_parser

def add_argument_decoder(arg_parser):
    # Decoder Hyperparams
    arg_parser.add_argument('--lstm', choices=['lstm', 'onlstm'], default='onlstm', help='Type of LSTM used, ONLSTM or traditional LSTM')
    arg_parser.add_argument('--chunk_size', default=8, type=int, help='parameter of ONLSTM')
    arg_parser.add_argument('--att_vec_size', default=512, type=int, help='size of attentional vector')
    arg_parser.add_argument('--sep_cxt', action='store_true', help='when calculating context vectors, use separate cxt for question and schema')
    arg_parser.add_argument('--drop_connect', type=float, default=0.2, help='recurrent connection dropout rate in decoder lstm')
    arg_parser.add_argument('--lstm_num_layers', type=int, default=1, help='num_layers of decoder')
    arg_parser.add_argument('--lstm_hidden_size', default=512, type=int, help='Size of LSTM hidden states')
    arg_parser.add_argument('--action_embed_size', default=128, type=int, help='Size of ApplyRule/GenToken action embeddings')
    arg_parser.add_argument('--field_embed_size', default=64, type=int, help='Embedding size of ASDL fields')
    arg_parser.add_argument('--type_embed_size', default=64, type=int, help='Embedding size of ASDL types')
    arg_parser.add_argument('--no_context_feeding', action='store_true', default=False,
        help='Do not use embedding of context vectors')
    arg_parser.add_argument('--no_parent_production_embed', default=False, action='store_true',
        help='Do not use embedding of parent ASDL production to update decoder LSTM state')
    arg_parser.add_argument('--no_parent_field_embed', default=False, action='store_true',
        help='Do not use embedding of parent field to update decoder LSTM state')
    arg_parser.add_argument('--no_parent_field_type_embed', default=False, action='store_true',
        help='Do not use embedding of the ASDL type of parent field to update decoder LSTM state')
    arg_parser.add_argument('--no_parent_state', default=False, action='store_true',
        help='Do not use the parent hidden state to update decoder LSTM state')
    arg_parser.add_argument('--beam_size', default=5, type=int, help='Beam size for beam search')
    arg_parser.add_argument('--decode_max_step', default=100, type=int, help='Maximum number of time steps used in decoding')
    return arg_parser

@ -0,0 +1,230 @@
#coding=utf8
import torch
import numpy as np
from utils.example import Example, get_position_ids
from utils.constants import PAD, UNK
from model.model_utils import lens2mask, cached_property
import torch.nn.functional as F

def from_example_list_base(ex_list, device='cpu', train=True):
    """
    question_lens: torch.long, bsize
    questions: torch.long, bsize x max_question_len, include [CLS] if add_cls
    table_lens: torch.long, bsize, number of tables for each example
    table_word_lens: torch.long, number of words for each table name
    tables: torch.long, sum_of_tables x max_table_word_len
    column_lens: torch.long, bsize, number of columns for each example
    column_word_lens: torch.long, number of words for each column name
    columns: torch.long, sum_of_columns x max_column_word_len
    """
    batch = Batch(ex_list, device)
    plm = Example.plm
    pad_idx = Example.word_vocab[PAD] if plm is None else Example.tokenizer.pad_token_id

    question_lens = [len(ex.question) for ex in ex_list]
    batch.question_lens = torch.tensor(question_lens, dtype=torch.long, device=device)
    batch.table_lens = torch.tensor([len(ex.table) for ex in ex_list], dtype=torch.long, device=device)
    table_word_lens = [len(t) for ex in ex_list for t in ex.table]
    batch.table_word_lens = torch.tensor(table_word_lens, dtype=torch.long, device=device)
    batch.column_lens = torch.tensor([len(ex.column) for ex in ex_list], dtype=torch.long, device=device)
    column_word_lens = [len(c) for ex in ex_list for c in ex.column]
    batch.column_word_lens = torch.tensor(column_word_lens, dtype=torch.long, device=device)

    if plm is None: # glove.42B.300d
        questions = [ex.question_id + [pad_idx] * (batch.max_question_len - len(ex.question_id)) for ex in ex_list]
        batch.questions = torch.tensor(questions, dtype=torch.long, device=device)
        tables = [t + [pad_idx] * (batch.max_table_word_len - len(t)) for ex in ex_list for t in ex.table_id]
        batch.tables = torch.tensor(tables, dtype=torch.long, device=device)
        columns = [c + [pad_idx] * (batch.max_column_word_len - len(c)) for ex in ex_list for c in ex.column_id]
        batch.columns = torch.tensor(columns, dtype=torch.long, device=device)
    else:
        # prepare inputs for pretrained models
        batch.inputs = {"input_ids": None, "attention_mask": None, "token_type_ids": None, "position_ids": None}
        input_lens = [len(ex.input_id) for ex in ex_list]
        max_len = max(input_lens)
        input_ids = [ex.input_id + [pad_idx] * (max_len - len(ex.input_id)) for ex in ex_list]
        batch.inputs["input_ids"] = torch.tensor(input_ids, dtype=torch.long, device=device)
        attention_mask = [[1] * l + [0] * (max_len - l) for l in input_lens]
        batch.inputs["attention_mask"] = torch.tensor(attention_mask, dtype=torch.float, device=device)
        token_type_ids = [ex.segment_id + [0] * (max_len - len(ex.segment_id)) for ex in ex_list]
        batch.inputs["token_type_ids"] = torch.tensor(token_type_ids, dtype=torch.long, device=device)
        position_ids = [get_position_ids(ex, shuffle=train) + [0] * (max_len - len(ex.input_id)) for ex in ex_list]
        batch.inputs["position_ids"] = torch.tensor(position_ids, dtype=torch.long, device=device)
        # extract representations after plm, remove [SEP]
        question_mask_plm = [ex.question_mask_plm + [0] * (max_len - len(ex.question_mask_plm)) for ex in ex_list]
        batch.question_mask_plm = torch.tensor(question_mask_plm, dtype=torch.bool, device=device)
        table_mask_plm = [ex.table_mask_plm + [0] * (max_len - len(ex.table_mask_plm)) for ex in ex_list]
        batch.table_mask_plm = torch.tensor(table_mask_plm, dtype=torch.bool, device=device)
        column_mask_plm = [ex.column_mask_plm + [0] * (max_len - len(ex.column_mask_plm)) for ex in ex_list]
        batch.column_mask_plm = torch.tensor(column_mask_plm, dtype=torch.bool, device=device)
        # subword aggregation
        question_subword_lens = [l for ex in ex_list for l in ex.question_subword_len]
        batch.question_subword_lens = torch.tensor(question_subword_lens, dtype=torch.long, device=device)
        table_subword_lens = [l for ex in ex_list for l in ex.table_subword_len]
        batch.table_subword_lens = torch.tensor(table_subword_lens, dtype=torch.long, device=device)
        column_subword_lens = [l for ex in ex_list for l in ex.column_subword_len]
        batch.column_subword_lens = torch.tensor(column_subword_lens, dtype=torch.long, device=device)

    batch.question_unk_mask, batch.table_unk_mask, batch.column_unk_mask = None, None, None
    if not train and plm is None:
        # during evaluation, for words not in vocab but in glove vocab, extract their corresponding embeddings
        word2vec, unk_idx = Example.word2vec, Example.word_vocab[UNK]

        question_unk_mask = (batch.questions == unk_idx).cpu()
        if question_unk_mask.any().item():
            raw_questions = np.array([ex.question + [PAD] * (batch.max_question_len - len(ex.question)) for ex in ex_list], dtype='<U100')
            unk_words = raw_questions[question_unk_mask.numpy()].tolist()
            unk_word_embeddings = [word2vec.emb(w) for w in unk_words]
            oov_flag = torch.tensor([True if e is not None else False for e in unk_word_embeddings], dtype=torch.bool)
            if oov_flag.any().item():
                batch.question_unk_mask = question_unk_mask.masked_scatter_(torch.clone(question_unk_mask), oov_flag).to(device)
                batch.question_unk_embeddings = torch.tensor([e for e in unk_word_embeddings if e is not None], dtype=torch.float, device=device)
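                # masked_scatter_ rewrites the True (UNK) positions of the mask with
                # oov_flag, so the final mask keeps only UNK words that do have a
                # GloVe vector; their embeddings above replace the UNK embedding later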

        table_unk_mask = (batch.tables == unk_idx).cpu()
        if table_unk_mask.any().item():
            raw_tables = np.array([t + [PAD] * (batch.max_table_word_len - len(t)) for ex in ex_list for t in ex.table], dtype='<U100')
            unk_words = raw_tables[table_unk_mask.numpy()].tolist()
            unk_word_embeddings = [word2vec.emb(w) for w in unk_words]
            oov_flag = torch.tensor([True if e is not None else False for e in unk_word_embeddings], dtype=torch.bool)
            if oov_flag.any().item():
                batch.table_unk_mask = table_unk_mask.masked_scatter_(torch.clone(table_unk_mask), oov_flag).to(device)
                batch.table_unk_embeddings = torch.tensor([e for e in unk_word_embeddings if e is not None], dtype=torch.float, device=device)

        column_unk_mask = (batch.columns == unk_idx).cpu()
        if column_unk_mask.any().item():
            raw_columns = np.array([c + [PAD] * (batch.max_column_word_len - len(c)) for ex in ex_list for c in ex.column], dtype='<U100')
            unk_words = raw_columns[column_unk_mask.numpy()].tolist()
            unk_word_embeddings = [word2vec.emb(w) for w in unk_words]
            oov_flag = torch.tensor([True if e is not None else False for e in unk_word_embeddings], dtype=torch.bool)
            if oov_flag.any().item():
                batch.column_unk_mask = column_unk_mask.masked_scatter_(torch.clone(column_unk_mask), oov_flag).to(device)
                batch.column_unk_embeddings = torch.tensor([e for e in unk_word_embeddings if e is not None], dtype=torch.float, device=device)
    return batch

def from_example_list_text2sql(ex_list, device='cpu', train=True, **kwargs):
    """ New fields: batch.lens, batch.max_len, batch.relations, batch.relations_mask
    """
    batch = from_example_list_base(ex_list, device, train)
    batch.graph = Example.graph_factory.batch_graphs(ex_list, device, train=train, **kwargs)
    if train:
        batch.max_action_num = max([len(ex.tgt_action) for ex in ex_list])
    return batch

class Batch():

    def __init__(self, examples, device='cpu'):
        super(Batch, self).__init__()
        self.examples = examples
        self.device = device

    @classmethod
    def from_example_list(cls, ex_list, device='cpu', train=True, method='text2sql', **kwargs):
        method_dict = {
            "text2sql": from_example_list_text2sql,
        }
        return method_dict[method](ex_list, device, train=train, **kwargs)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    @cached_property
    def max_question_len(self):
        return torch.max(self.question_lens).item()

    @cached_property
    def max_table_len(self):
        return torch.max(self.table_lens).item()

    @cached_property
    def max_column_len(self):
        return torch.max(self.column_lens).item()

    @cached_property
    def max_table_word_len(self):
        return torch.max(self.table_word_lens).item()

    @cached_property
    def max_column_word_len(self):
        return torch.max(self.column_word_lens).item()

    @cached_property
    def max_question_subword_len(self):
        return torch.max(self.question_subword_lens).item()

    @cached_property
    def max_table_subword_len(self):
        return torch.max(self.table_subword_lens).item()

    @cached_property
    def max_column_subword_len(self):
        return torch.max(self.column_subword_lens).item()

    """ Different types of nodes are separated instead of concatenated together """
    @cached_property
    def mask(self):
        return torch.cat([self.question_mask, self.table_mask, self.column_mask], dim=1)

    @cached_property
    def question_mask(self):
        return lens2mask(self.question_lens)

    @cached_property
    def table_mask(self):
        return lens2mask(self.table_lens)

    @cached_property
    def column_mask(self):
        return lens2mask(self.column_lens)

    @cached_property
    def table_word_mask(self):
        return lens2mask(self.table_word_lens)

    @cached_property
    def column_word_mask(self):
        return lens2mask(self.column_word_lens)

    @cached_property
    def question_subword_mask(self):
        return lens2mask(self.question_subword_lens)

    @cached_property
    def table_subword_mask(self):
        return lens2mask(self.table_subword_lens)

    @cached_property
    def column_subword_mask(self):
        return lens2mask(self.column_subword_lens)

    def get_frontier_field_idx(self, t):
        ids = []
        for e in self.examples:
            if t < len(e.tgt_action):
                ids.append(Example.grammar.field2id[e.tgt_action[t].frontier_field])
                # assert self.grammar.id2field[ids[-1]] == e.tgt_action[t].frontier_field
            else:
                ids.append(0)
        return torch.tensor(ids, dtype=torch.long, device=self.device)

    def get_frontier_prod_idx(self, t):
        ids = []
        for e in self.examples:
            if t < len(e.tgt_action):
                ids.append(Example.grammar.prod2id[e.tgt_action[t].frontier_prod])
                # assert self.grammar.id2prod[ids[-1]] == e.tgt_action[t].frontier_prod
            else:
                ids.append(0)
        return torch.tensor(ids, dtype=torch.long, device=self.device)

    def get_frontier_field_type_idx(self, t):
        ids = []
        for e in self.examples:
            if t < len(e.tgt_action):
                ids.append(Example.grammar.type2id[e.tgt_action[t].frontier_field.type])
                # assert self.grammar.id2type[ids[-1]] == e.tgt_action[t].frontier_field.type
            else:
                ids.append(0)
        return torch.tensor(ids, dtype=torch.long, device=self.device)

@ -0,0 +1,19 @@
PAD = '[PAD]'
BOS = '[CLS]'
EOS = '[SEP]'
UNK = '[UNK]'

GRAMMAR_FILEPATH = 'asdl/sql/grammar/sql_asdl_v2.txt'
SCHEMA_TYPES = ['table', 'others', 'text', 'time', 'number', 'boolean']
MAX_RELATIVE_DIST = 2
# relations: type_1-type_2-rel_name, r represents reverse edge, b represents bidirectional edge
RELATIONS = ['question-question-dist' + str(i) if i != 0 else 'question-question-identity' for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1)] + \
    ['table-table-identity', 'table-table-fk', 'table-table-fkr', 'table-table-fkb'] + \
    ['column-column-identity', 'column-column-sametable', 'column-column-fk', 'column-column-fkr'] + \
    ['table-column-pk', 'column-table-pk', 'table-column-has', 'column-table-has'] + \
    ['question-column-exactmatch', 'question-column-partialmatch', 'question-column-partialsemanticmatch', 'question-column-nomatch', 'question-column-semanticmatch', 'question-column-valuematch',
     'column-question-exactmatch', 'column-question-partialmatch', 'column-question-partialsemanticmatch', 'column-question-nomatch', 'column-question-semanticmatch', 'column-question-valuematch'] + \
    ['question-table-exactmatch', 'question-table-partialmatch', 'question-table-partialsemanticmatch', 'question-table-nomatch', 'question-table-semanticmatch',
     'table-question-exactmatch', 'table-question-partialmatch', 'table-question-partialsemanticmatch', 'table-question-semanticmatch', 'table-question-nomatch'] + \
    ['question-question-generic', 'table-table-generic', 'column-column-generic', 'table-column-generic', 'column-table-generic'] + \
    ['*-*-identity', '*-question-generic', 'question-*-generic', '*-table-generic', 'table-*-generic', '*-column-generic', 'column-*-generic']
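
# Naming scheme example: 'question-column-exactmatch' is an edge from a question
# token to a column whose name it matches exactly, and the fk suffixes 'r'/'b'
# mark the reverse and bidirectional foreign-key variants (table-table-fkr,
# table-table-fkb), following the comment above.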

@ -0,0 +1,375 @@
#coding=utf8
import sys, tempfile, os
import pickle, json
import numpy as np
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from evaluation import evaluate, build_foreign_key_map_from_json, build_valid_col_units, rebuild_sql_val, rebuild_sql_col, eval_exec_match
from evaluation import Evaluator as Engine
from process_sql import get_schema, Schema, get_sql

class Evaluator():

    def __init__(self, transition_system, table_path='data/tables.json', database_dir='data/database'):
        super(Evaluator, self).__init__()
        self.transition_system = transition_system
        self.kmaps = build_foreign_key_map_from_json(table_path)
        self.database_dir = database_dir
        self.engine = Engine()
        self.checker = Checker(table_path, database_dir)
        self.acc_dict = {
            "sql": self.sql_acc, # use golden sql as references
            "ast": self.ast_acc, # compare ast accuracy, ast may be incorrect when constructed from raw sql
            "beam": self.beam_acc, # if the correct answer exists in the beam, assume the result is true
        }

    def acc(self, pred_hyps, dataset, output_path=None, acc_type='sql', etype='match', use_checker=False):
        assert len(pred_hyps) == len(dataset) and acc_type in self.acc_dict and etype in ['match', 'exec']
        acc_method = self.acc_dict[acc_type]
        return acc_method(pred_hyps, dataset, output_path, etype, use_checker)

    def beam_acc(self, pred_hyps, dataset, output_path, etype, use_checker):
        scores, results = {}, []
        for each in ['easy', 'medium', 'hard', 'extra', 'all']:
            scores[each] = [0, 0.] # first is count, second is total score
        for idx, pred in enumerate(pred_hyps):
            question, gold_sql, db = dataset[idx].ex['question'], dataset[idx].query, dataset[idx].db
            for b_id, hyp in enumerate(pred):
                pred_sql = self.transition_system.ast_to_surface_code(hyp.tree, db)
                score, hardness = self.single_acc(pred_sql, gold_sql, db['db_id'], etype)
                if int(score) == 1:
                    scores[hardness][0] += 1
                    scores[hardness][1] += 1.0
                    scores['all'][0] += 1
                    scores['all'][1] += 1.0
                    results.append((hardness, question, gold_sql, pred_sql, b_id, True))
                    break
            else: # for-else: runs only when no beam hypothesis matched (no break)
                scores[hardness][0] += 1
                scores['all'][0] += 1
                pred_sql = self.transition_system.ast_to_surface_code(pred[0].tree, db)
                results.append((hardness, question, gold_sql, pred_sql, 0, False))
        for each in scores:
            accuracy = scores[each][1] / float(scores[each][0]) if scores[each][0] != 0 else 0.
            scores[each].append(accuracy)
        of = open(output_path, 'w', encoding='utf8') if output_path is not None else \
            tempfile.TemporaryFile('w+t')
        for item in results:
            of.write('Level: %s\n' % (item[0]))
            of.write('Question: %s\n' % (item[1]))
            of.write('Gold SQL: %s\n' % (item[2]))
            of.write('Pred SQL (%s): %s\n' % (item[4], item[3]))
            of.write('Correct: %s\n\n' % (item[5]))
        for each in scores:
            of.write('Level %s: %s\n' % (each, scores[each]))
        of.close()
        return scores['all'][2]

    def single_acc(self, pred_sql, gold_sql, db, etype):
        """
        @return:
            score(float): 0 or 1, etype score
            hardness(str): one of 'easy', 'medium', 'hard', 'extra'
        """
        db_name = db
        db = os.path.join(self.database_dir, db, db + ".sqlite")
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, gold_sql)
        hardness = self.engine.eval_hardness(g_sql)
        try:
            p_sql = get_sql(schema, pred_sql)
        except:
            # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [
                    False,
                    []
                ],
                "union": None,
                "where": []
            }
        kmap = self.kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) # kmap: map __tab.col__ to pivot __tab.col__
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
        if etype == 'exec':
            score = float(eval_exec_match(db, pred_sql, gold_sql, p_sql, g_sql))
        if etype == 'match':
            score = float(self.engine.eval_exact_match(p_sql, g_sql))
        return score, hardness

    def ast_acc(self, pred_hyps, dataset, output_path, etype, use_checker):
        pred_asts = [hyp[0].tree for hyp in pred_hyps]
        ref_asts = [ex.ast for ex in dataset]
        dbs = [ex.db for ex in dataset]
        pred_sqls, ref_sqls = [], []
        for pred, ref, db in zip(pred_asts, ref_asts, dbs):
            pred_sql = self.transition_system.ast_to_surface_code(pred, db)
            ref_sql = self.transition_system.ast_to_surface_code(ref, db)
            pred_sqls.append(pred_sql)
            ref_sqls.append(ref_sql)
        with tempfile.NamedTemporaryFile('w+t', encoding='utf8', suffix='.sql') as tmp_pred, \
                tempfile.NamedTemporaryFile('w+t', encoding='utf8', suffix='.sql') as tmp_ref:
            of = open(output_path, 'w', encoding='utf8') if output_path is not None \
                else tempfile.TemporaryFile('w+t')
            # write pred and ref sqls
            for s in pred_sqls:
                tmp_pred.write(s + '\n')
            tmp_pred.flush()
            for s, db in zip(ref_sqls, dbs):
                tmp_ref.write(s + '\t' + db['db_id'] + '\n')
            tmp_ref.flush()
            # calculate ast accuracy
            old_print = sys.stdout
            sys.stdout = of
            result_type = 'exact' if etype == 'match' else 'exec'
            all_exact_acc = evaluate(tmp_ref.name, tmp_pred.name, self.database_dir, etype, self.kmaps)['all'][result_type]
            sys.stdout = old_print
            of.close()
        return float(all_exact_acc)

    def sql_acc(self, pred_hyps, dataset, output_path, etype, use_checker):
        pred_sqls, ref_sqls = [], [ex.query for ex in dataset]
        dbs = [ex.db for ex in dataset]
        for idx, hyp in enumerate(pred_hyps):
            if use_checker:
                pred_sql = self.obtain_sql(hyp, dbs[idx])
            else:
                best_ast = hyp[0].tree # by default, the top beam prediction
                pred_sql = self.transition_system.ast_to_surface_code(best_ast, dbs[idx])
            pred_sqls.append(pred_sql)
        with tempfile.NamedTemporaryFile('w+t', encoding='utf8', suffix='.sql') as tmp_pred, \
                tempfile.NamedTemporaryFile('w+t', encoding='utf8', suffix='.sql') as tmp_ref:
            of = open(output_path, 'w', encoding='utf8') if output_path is not None \
                else tempfile.TemporaryFile('w+t')
            # write pred and ref sqls
            for s in pred_sqls:
                tmp_pred.write(s + '\n')
            tmp_pred.flush()
            for s, db in zip(ref_sqls, dbs):
                tmp_ref.write(s + '\t' + db['db_id'] + '\n')
            tmp_ref.flush()
            # calculate sql accuracy
            old_print = sys.stdout
            sys.stdout = of
            result_type = 'exact' if etype == 'match' else 'exec'
            all_exact_acc = evaluate(tmp_ref.name, tmp_pred.name, self.database_dir, etype, self.kmaps)['all'][result_type]
            sys.stdout = old_print
            of.close()
        return float(all_exact_acc)

    def obtain_sql(self, hyps, db):
        beam = len(hyps)
        for hyp in hyps:
            cur_ast = hyp.tree
            pred_sql = self.transition_system.ast_to_surface_code(cur_ast, db)
            if self.checker.validity_check(pred_sql, db['db_id']):
                sql = pred_sql
                break
        else: # no hypothesis passed the checker, fall back to the top beam
            best_ast = hyps[0].tree
            sql = self.transition_system.ast_to_surface_code(best_ast, db)
        return sql

class Checker():

    def __init__(self, table_path='data/tables.json', db_dir='data/database'):
        super(Checker, self).__init__()
        self.table_path = table_path
        self.db_dir = db_dir
        self.schemas, self.database, self.tables = self._get_schemas_from_json(self.table_path)

    def _get_schemas_from_json(self, table_path):
        database_list = json.load(open(table_path, 'r'))
        database = {}
        for db in database_list:
            database[db['db_id']] = db
        tables, schemas = {}, {}
        for db in database_list:
            db_id = db['db_id']
            schema = {}
            column_names_original = db['column_names_original']
            table_names_original = db['table_names_original']
            tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original}
            for i, tabn in enumerate(table_names_original):
                table = str(tabn.lower())
                cols = [str(col.lower()) for td, col in column_names_original if td == i]
                schema[table] = cols
            schemas[db_id] = schema
        return schemas, database, tables

    def validity_check(self, sql: str, db: str):
        """ Check whether the given sql query is valid, including:
        1. only use columns in tables mentioned in FROM clause
        2. comparison operator or MAX/MIN/SUM/AVG only applied to columns of type number/time
        @params:
            sql(str): SQL query
            db(str): db_id field, database name
        @return:
            flag(boolean)
        """
        schema, table = self.schemas[db], self.tables[db]
        schema = SchemaID(schema, table)
        try:
            sql = get_sql(schema, sql)
            return self.sql_check(sql, self.database[db])
        except Exception as e:
            print('Runtime error occurs:', e)
            return False
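
    # Illustrative usage (example query and db assumed, not shipped with this file):
    #   checker = Checker('data/tables.json', 'data/database')
    #   ok = checker.validity_check("SELECT name FROM singer WHERE age > 20", 'concert_singer')
    # ok is True for a well-formed query and False on any parse or type-check failure.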

    def sql_check(self, sql: dict, db: dict):
        if sql['intersect']:
            return self.sqlunit_check(sql, db) & self.sqlunit_check(sql['intersect'], db)
        if sql['union']:
            return self.sqlunit_check(sql, db) & self.sqlunit_check(sql['union'], db)
        if sql['except']:
            return self.sqlunit_check(sql, db) & self.sqlunit_check(sql['except'], db)
        return self.sqlunit_check(sql, db)

    def sqlunit_check(self, sql: dict, db: dict):
        if sql['from']['table_units'][0][0] == 'sql':
            if not self.sql_check(sql['from']['table_units'][0][1], db): return False
            table_ids = []
        else:
            table_ids = list(map(lambda table_unit: table_unit[1], sql['from']['table_units']))
        return self.select_check(sql['select'], table_ids, db) & \
            self.cond_check(sql['where'], table_ids, db) & \
            self.groupby_check(sql['groupBy'], table_ids, db) & \
            self.cond_check(sql['having'], table_ids, db) & \
            self.orderby_check(sql['orderBy'], table_ids, db)

    def select_check(self, select, table_ids: list, db: dict):
        select = select[1]
        for agg_id, val_unit in select:
            if not self.valunit_check(val_unit, table_ids, db): return False
            # MAX/MIN/SUM/AVG
            # if agg_id in [1, 2, 4, 5] and (self.valunit_type(val_unit, db) not in ['number', 'time']):
            #     return False
        return True

    def cond_check(self, cond, table_ids: list, db: dict):
        if len(cond) == 0:
            return True
        for idx in range(0, len(cond), 2):
            cond_unit = cond[idx]
            _, cmp_op, val_unit, val1, val2 = cond_unit
            flag = self.valunit_check(val_unit, table_ids, db)
            # if cmp_op in [3, 4, 5, 6]: # >, <, >=, <=
            #     flag &= (self.valunit_type(val_unit, db) in ['number', 'time'])
            if type(val1) == dict:
                flag &= self.sql_check(val1, db)
            if type(val2) == dict:
                flag &= self.sql_check(val2, db)
            if not flag: return False
        return True

    def groupby_check(self, groupby, table_ids: list, db: dict):
        if not groupby: return True
        for col_unit in groupby:
            if not self.colunit_check(col_unit, table_ids, db): return False
        return True

    def orderby_check(self, orderby, table_ids: list, db: dict):
        if not orderby: return True
        orderby = orderby[1]
        for val_unit in orderby:
            if not self.valunit_check(val_unit, table_ids, db): return False
        return True

    def colunit_check(self, col_unit: list, table_ids: list, db: dict):
        """ Check from the following aspects:
        1. column belongs to the tables in FROM clause
        2. column type is valid for AGG_OP
        """
        agg_id, col_id, _ = col_unit
        if col_id == 0: return True
        tab_id = db['column_names'][col_id][0]
        if tab_id not in table_ids: return False
        col_type = db['column_types'][col_id]
        if agg_id in [1, 2, 4, 5]: # MAX, MIN, SUM, AVG
            return (col_type in ['time', 'number'])
        return True

    def valunit_check(self, val_unit: list, table_ids: list, db: dict):
        unit_op, col_unit1, col_unit2 = val_unit
        if unit_op == 0:
            return self.colunit_check(col_unit1, table_ids, db)
        if not (self.colunit_check(col_unit1, table_ids, db) and self.colunit_check(col_unit2, table_ids, db)):
            return False
        # COUNT/SUM/AVG -> number
        agg_id1, col_id1, _ = col_unit1
        agg_id2, col_id2, _ = col_unit2
        t1 = 'number' if agg_id1 > 2 else db['column_types'][col_id1]
        t2 = 'number' if agg_id2 > 2 else db['column_types'][col_id2]
        if (t1 not in ['number', 'time']) or (t2 not in ['number', 'time']) or t1 != t2:
            return False
        return True

    def valunit_type(self, val_unit: list, db: dict):
        unit_op, col_unit1, col_unit2 = val_unit
        if unit_op == 0:
            agg_id, col_id, _ = col_unit1
            if agg_id > 2: return 'number'
            else: return ('number' if col_id == 0 else db['column_types'][col_id])
        else:
            return 'number'

class SchemaID():
    """
    Simple schema which maps table&column to a unique identifier
    """
    def __init__(self, schema, table):
        self._schema = schema
        self._table = table
        self._idMap = self._map(self._schema, self._table)

    @property
    def schema(self):
        return self._schema

    @property
    def idMap(self):
        return self._idMap

    def _map(self, schema, table):
        column_names_original = table['column_names_original']
        table_names_original = table['table_names_original']
        # print('column_names_original: ', column_names_original)
        # print('table_names_original: ', table_names_original)
        for i, (tab_id, col) in enumerate(column_names_original):
            if tab_id == -1:
                idMap = {'*': i}
            else:
                key = table_names_original[tab_id].lower()
                val = col.lower()
                idMap[key + "." + val] = i

        for i, tab in enumerate(table_names_original):
            key = tab.lower()
            idMap[key] = i
        return idMap

if __name__ == '__main__':
    checker = Checker('data/tables.json', 'data/database')
    train, dev = json.load(open('data/train.json', 'r')), json.load(open('data/dev.json', 'r'))
    test = json.load(open('data/test.json', 'r'))
    count = 0
    for idx, ex in enumerate(train):
        sql, db = ex['query'].strip(), ex['db_id']
        flag = checker.validity_check(sql, db)
        if not flag:
            print(idx, ': ' + sql + '\t' + db)
            count += 1
    print('Total invalid is %d' % (count))

@ -0,0 +1,151 @@
#coding=utf8
import os, pickle, json
import torch, random
import numpy as np
from asdl.asdl import ASDLGrammar
from asdl.transition_system import TransitionSystem
from utils.constants import UNK, GRAMMAR_FILEPATH, SCHEMA_TYPES, RELATIONS
from utils.graph_example import GraphFactory
from utils.vocab import Vocab
from utils.word2vec import Word2vecUtils
from transformers import AutoTokenizer
from utils.evaluator import Evaluator
from itertools import chain

class Example():

    @classmethod
    def configuration(cls, plm=None, method='lgesql', table_path='data/tables.json', tables='data/tables.bin', db_dir='data/database'):
        cls.plm, cls.method = plm, method
        cls.grammar = ASDLGrammar.from_filepath(GRAMMAR_FILEPATH)
        cls.trans = TransitionSystem.get_class_by_lang('sql')(cls.grammar)
        cls.tables = pickle.load(open(tables, 'rb')) if type(tables) == str else tables
        cls.evaluator = Evaluator(cls.trans, table_path, db_dir)
        if plm is None:
            cls.word2vec = Word2vecUtils()
            cls.tokenizer = lambda x: x
            cls.word_vocab = Vocab(padding=True, unk=True, boundary=True, default=UNK,
                filepath='./pretrained_models/glove.42b.300d/vocab.txt', specials=SCHEMA_TYPES) # word vocab for glove.42B.300d
        else:
            cls.tokenizer = AutoTokenizer.from_pretrained(os.path.join('./pretrained_models', plm))
            cls.word_vocab = cls.tokenizer.get_vocab()
        cls.relation_vocab = Vocab(padding=False, unk=False, boundary=False, iterable=RELATIONS, default=None)
        cls.graph_factory = GraphFactory(cls.method, cls.relation_vocab)

    @classmethod
    def load_dataset(cls, choice, debug=False):
        assert choice in ['train', 'dev']
        fp = os.path.join('data', choice + '.' + cls.method + '.bin')
        datasets = pickle.load(open(fp, 'rb'))
        # question_lens = [len(ex['processed_question_toks']) for ex in datasets]
        # print('Max/Min/Avg question length in %s dataset is: %d/%d/%.2f' % (choice, max(question_lens), min(question_lens), float(sum(question_lens))/len(question_lens)))
        # action_lens = [len(ex['actions']) for ex in datasets]
        # print('Max/Min/Avg action length in %s dataset is: %d/%d/%.2f' % (choice, max(action_lens), min(action_lens), float(sum(action_lens))/len(action_lens)))
        examples, outliers = [], 0
        for ex in datasets:
            if choice == 'train' and len(cls.tables[ex['db_id']]['column_names']) > 100:
                outliers += 1
                continue
            examples.append(cls(ex, cls.tables[ex['db_id']]))
            if debug and len(examples) >= 100:
                return examples
        if choice == 'train':
            print("Skip %d extremely large samples in training dataset ..." % (outliers))
        return examples

    def __init__(self, ex: dict, db: dict):
        super(Example, self).__init__()
        self.ex = ex
        self.db = db

        """ Mapping word to corresponding index """
        if Example.plm is None:
            self.question = ex['processed_question_toks']
            self.question_id = [Example.word_vocab[w] for w in self.question]
            self.column = [[db['column_types'][idx].lower()] + c for idx, c in enumerate(db['processed_column_toks'])]
            self.column_id = [[Example.word_vocab[w] for w in c] for c in self.column]
            self.table = [['table'] + t for t in db['processed_table_toks']]
            self.table_id = [[Example.word_vocab[w] for w in t] for t in self.table]
        else:
            t = Example.tokenizer
            self.question = [q.lower() for q in ex['raw_question_toks']]
            self.question_id = [t.cls_token_id] # map token to id
            self.question_mask_plm = [] # remove SEP token in our case
            self.question_subword_len = [] # subword len for each word, exclude SEP token
            for w in self.question:
                toks = t.convert_tokens_to_ids(t.tokenize(w))
                self.question_id.extend(toks)
                self.question_subword_len.append(len(toks))
            self.question_mask_plm = [0] + [1] * (len(self.question_id) - 1) + [0]
            self.question_id.append(t.sep_token_id)

            self.table = [['table'] + t.lower().split() for t in db['table_names']]
            self.table_id, self.table_mask_plm, self.table_subword_len = [], [], []
            self.table_word_len = []
            for s in self.table:
                l = 0
                for w in s:
                    toks = t.convert_tokens_to_ids(t.tokenize(w))
                    self.table_id.extend(toks)
                    self.table_subword_len.append(len(toks))
                    l += len(toks)
                self.table_word_len.append(l)
            self.table_mask_plm = [1] * len(self.table_id)

            self.column = [[db['column_types'][idx].lower()] + c.lower().split() for idx, (_, c) in enumerate(db['column_names'])]
            self.column_id, self.column_mask_plm, self.column_subword_len = [], [], []
            self.column_word_len = []
            for s in self.column:
                l = 0
                for w in s:
                    toks = t.convert_tokens_to_ids(t.tokenize(w))
                    self.column_id.extend(toks)
                    self.column_subword_len.append(len(toks))
                    l += len(toks)
                self.column_word_len.append(l)
            self.column_mask_plm = [1] * len(self.column_id) + [0]
            self.column_id.append(t.sep_token_id)

            self.input_id = self.question_id + self.table_id + self.column_id
            self.segment_id = [0] * len(self.question_id) + [1] * (len(self.table_id) + len(self.column_id)) \
                if Example.plm != 'grappa_large_jnt' and not Example.plm.startswith('roberta') \
                else [0] * (len(self.question_id) + len(self.table_id) + len(self.column_id))

            self.question_mask_plm = self.question_mask_plm + [0] * (len(self.table_id) + len(self.column_id))
            self.table_mask_plm = [0] * len(self.question_id) + self.table_mask_plm + [0] * len(self.column_id)
            self.column_mask_plm = [0] * (len(self.question_id) + len(self.table_id)) + self.column_mask_plm

        self.graph = Example.graph_factory.graph_construction(ex, db)

        # outputs
        self.query = ' '.join(ex['query'].split('\t'))
        self.ast = ex['ast']
        self.tgt_action = ex['actions']
        self.used_tables, self.used_columns = ex['used_tables'], ex['used_columns']

def get_position_ids(ex, shuffle=True):
    # cluster columns with their corresponding table and randomly shuffle tables and columns
    # [CLS] q1 q2 ... [SEP] * t1 c1 c2 c3 t2 c4 c5 ... [SEP]
    db, table_word_len, column_word_len = ex.db, ex.table_word_len, ex.column_word_len
    table_num, column_num = len(db['table_names']), len(db['column_names'])
    question_position_id = list(range(len(ex.question_id)))
    start = len(question_position_id)
    table_position_id, column_position_id = [None] * table_num, [None] * column_num
    column_position_id[0] = list(range(start, start + column_word_len[0]))
    start += column_word_len[0] # special symbol * first
    table_idxs = list(range(table_num))
    if shuffle:
        random.shuffle(table_idxs)
    for idx in table_idxs:
        col_idxs = db['table2columns'][idx]
        table_position_id[idx] = list(range(start, start + table_word_len[idx]))
        start += table_word_len[idx]
        if shuffle:
            random.shuffle(col_idxs)
        for col_id in col_idxs:
            column_position_id[col_id] = list(range(start, start + column_word_len[col_id]))
            start += column_word_len[col_id]
    position_id = question_position_id + list(chain.from_iterable(table_position_id)) + \
        list(chain.from_iterable(column_position_id)) + [start]
    assert len(position_id) == len(ex.input_id)
    return position_id
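
# Note on the layout above (a sketch of the effect, not extra behavior):
# input_id always stores question + all table words + all column words, but the
# assigned position ids interleave them as [CLS] q1 ... [SEP] * | t1 c1 c2 | t2 c3 | ... [SEP],
# so every column is positionally adjacent to its table; table/column order is
# re-randomized each time a training batch is built (shuffle=True) and fixed at evaluation.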

@ -0,0 +1,76 @@
#coding=utf8
import numpy as np
import dgl, torch, math

class GraphExample():

    pass

class BatchedGraph():

    pass

class GraphFactory():

    def __init__(self, method='rgatsql', relation_vocab=None):
        super(GraphFactory, self).__init__()
        self.method = getattr(self, method) # resolve graph-construction method by name, e.g. self.lgesql
        self.batch_method = getattr(self, 'batch_' + method)
        self.relation_vocab = relation_vocab

    def graph_construction(self, ex: dict, db: dict):
        return self.method(ex, db)

    def rgatsql(self, ex, db):
        graph = GraphExample()
        local_edges = ex['graph'].local_edges
        rel_ids = list(map(lambda r: self.relation_vocab[r[2]], local_edges))
        graph.local_edges = torch.tensor(rel_ids, dtype=torch.long)
        global_edges = ex['graph'].global_edges
        rel_ids = list(map(lambda r: self.relation_vocab[r[2]], global_edges))
        graph.global_edges = torch.tensor(rel_ids, dtype=torch.long)
        graph.local_g, graph.global_g = ex['graph'].local_g, ex['graph'].global_g
        graph.gp = ex['graph'].gp
        graph.question_mask = torch.tensor(ex['graph'].question_mask, dtype=torch.bool)
        graph.schema_mask = torch.tensor(ex['graph'].schema_mask, dtype=torch.bool)
        graph.node_label = torch.tensor(ex['graph'].node_label, dtype=torch.float)
        # extract local relations (used in msde), global_edges = local_edges + nonlocal_edges
        local_enum, global_enum = graph.local_edges.size(0), graph.global_edges.size(0)
        graph.local_mask = torch.tensor([1] * local_enum + [0] * (global_enum - local_enum), dtype=torch.bool)
        return graph

    def lgesql(self, ex, db):
        graph = self.rgatsql(ex, db)
        # add line graph
        graph.lg = ex['graph'].lg
        return graph

    def batch_graphs(self, ex_list, device, train=True, **kwargs):
        """ Batch graphs in example list """
        return self.batch_method(ex_list, device, train=train, **kwargs)

    def batch_lgesql(self, ex_list, device, train=True, **kwargs):
        bg = self.batch_rgatsql(ex_list, device, train=train, **kwargs)
        src_ids, dst_ids = bg.local_g.edges(order='eid')
        bg.src_ids, bg.dst_ids = src_ids.long(), dst_ids.long()
        bg.lg = dgl.batch([ex.graph.lg for ex in ex_list]).to(device)
        return bg

    def batch_rgatsql(self, ex_list, device, train=True, **kwargs):
        # method = kwargs.pop('local_and_nonlocal', 'global')
        graph_list = [ex.graph for ex in ex_list]
        bg = BatchedGraph()
        bg.local_g = dgl.batch([ex.local_g for ex in graph_list]).to(device)
        # bg.local_edges = torch.cat([ex.local_edges for ex in graph_list], dim=0).to(device)
        bg.local_mask = torch.cat([ex.graph.local_mask for ex in ex_list], dim=0).to(device)
        bg.global_g = dgl.batch([ex.global_g for ex in graph_list]).to(device)
        bg.global_edges = torch.cat([ex.global_edges for ex in graph_list], dim=0).to(device)
        if train:
            bg.question_mask = torch.cat([ex.question_mask for ex in graph_list], dim=0).to(device)
            bg.schema_mask = torch.cat([ex.schema_mask for ex in graph_list], dim=0).to(device)
            smoothing = kwargs.pop('smoothing', 0.0)
            node_label = torch.cat([ex.node_label for ex in graph_list], dim=0)
            node_label = node_label.masked_fill_(~ node_label.bool(), 2 * smoothing) - smoothing
|
||||
bg.node_label = node_label.to(device)
|
||||
bg.gp = dgl.batch([ex.gp for ex in graph_list]).to(device)
|
||||
return bg
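For orientation, a minimal usage sketch of this factory: the `method` argument selects `rgatsql` or `lgesql` dispatch via `eval`, per-example graphs are built once at preprocessing time, and batching happens inside the data loader. Here `ex_list` and `relation_vocab` are hypothetical placeholders for the preprocessed examples and the relation vocabulary:

```python
# Sketch only; ex_list and relation_vocab come from the preprocessing step.
factory = GraphFactory(method='lgesql', relation_vocab=relation_vocab)
for ex in ex_list:
    ex.graph = factory.graph_construction(ex.orig, ex.db)   # per-example graphs
device = torch.device('cuda:0')
bg = factory.batch_graphs(ex_list, device, train=True, smoothing=0.15)
# bg.local_g / bg.global_g are batched dgl graphs; bg.node_label carries the
# smoothed 0/1 labels consumed by the graph-pruning auxiliary loss.
```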
@@ -0,0 +1,48 @@
#coding=utf8
import sys, os


EXP_PATH = 'exp'


def hyperparam_path(args):
    if args.read_model_path and args.testing:
        return args.read_model_path
    exp_path = hyperparam_path_text2sql(args)
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)
    return exp_path


def hyperparam_path_text2sql(args):
    task = 'task_%s__model_%s_view_%s' % (args.task, args.model, args.local_and_nonlocal)
    task += '' if 'without' in args.output_model else '_gp_%s' % (args.smoothing)
    # encoder params
    exp_path = 'emb_%s' % (args.embed_size) if args.plm is None else 'plm_%s' % (args.plm)
    exp_path += '__gnn_%s_x_%s' % (args.gnn_hidden_size, args.gnn_num_layers)
    exp_path += '__share' if args.relation_share_layers else ''
    exp_path += '__head_%s' % (args.num_heads)
    exp_path += '__share' if args.relation_share_heads else ''
    exp_path += '__dp_%s' % (args.dropout)
    exp_path += '__dpa_%s' % (args.attn_drop)
    exp_path += '__dpc_%s' % (args.drop_connect)
    # decoder params
    # exp_path += '__cell_%s_%s_x_%s' % (args.lstm, args.lstm_hidden_size, args.lstm_num_layers)
    # exp_path += '_chunk_%s' % (args.chunk_size) if args.lstm == 'onlstm' else ''
    # exp_path += '_no' if args.no_parent_state else ''
    # exp_path += '__attvec_%s' % (args.att_vec_size)
    # exp_path += '__sepcxt' if args.sep_cxt else '__jointcxt'
    # exp_path += '_no' if args.no_context_feeding else ''
    # exp_path += '__ae_%s' % (args.action_embed_size)
    # exp_path += '_no' if args.no_parent_production_embed else ''
    # exp_path += '__fe_%s' % ('no' if args.no_parent_field_embed else args.field_embed_size)
    # exp_path += '__te_%s' % ('no' if args.no_parent_field_type_embed else args.type_embed_size)
    # training params
    exp_path += '__bs_%s' % (args.batch_size)
    exp_path += '__lr_%s' % (args.lr) if args.plm is None else '__lr_%s_ld_%s' % (args.lr, args.layerwise_decay)
    exp_path += '__l2_%s' % (args.l2)
    exp_path += '__wp_%s' % (args.warmup_ratio)
    exp_path += '__sd_%s' % (args.lr_schedule)
    exp_path += '__me_%s' % (args.max_epoch)
    exp_path += '__mn_%s' % (args.max_norm)
    exp_path += '__bm_%s' % (args.beam_size)
    exp_path += '__seed_%s' % (args.seed)
    exp_path = os.path.join(EXP_PATH, task, exp_path)
    return exp_path
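Every hyperparameter is encoded into the experiment directory name, so distinct configurations never collide on disk. With purely hypothetical argument values, the function yields a path like:

```
exp/task_spider__model_lgesql_view_msde_gp_0.15/plm_electra-large-discriminator__gnn_512_x_8__head_8__dp_0.2__dpa_0.0__dpc_0.2__bs_20__lr_0.0001_ld_0.8__l2_0.1__wp_0.1__sd_linear__me_200__mn_5.0__bm_5__seed_999
```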
@@ -0,0 +1,46 @@
#coding=utf8
""" Utility functions include:
    1. set output logging path
    2. set random seed for all libs
    3. select torch.device
"""

import sys, os, logging
import random, torch, dgl
import numpy as np


def set_logger(exp_path, testing=False):
    logFormatter = logging.Formatter('%(asctime)s - %(message)s') # ('%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger('mylogger')
    logger.setLevel(logging.DEBUG)
    if testing:
        fileHandler = logging.FileHandler('%s/log_test.txt' % (exp_path), mode='w')
    else:
        fileHandler = logging.FileHandler('%s/log_train.txt' % (exp_path), mode='w')
    fileHandler.setFormatter(logFormatter)
    logger.addHandler(fileHandler)
    consoleHandler = logging.StreamHandler(sys.stdout)
    consoleHandler.setFormatter(logFormatter)
    logger.addHandler(consoleHandler)
    return logger


def set_random_seed(random_seed=999):
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)
    dgl.random.seed(random_seed)


def set_torch_device(deviceId):
    if deviceId < 0:
        device = torch.device("cpu")
    else:
        assert torch.cuda.device_count() >= deviceId + 1
        device = torch.device("cuda:%d" % (deviceId))
        # os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # used when debugging
        # These two lines ensure reproducibility with the cudnn backend
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.enabled = False
    return device
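A sketch of how these utilities are typically wired together at the top of a training script, assuming `args` comes from the argument parser and `hyperparam_path` from the module above:

```python
# Sketch only; args is assumed to hold the parsed command-line options.
exp_path = hyperparam_path(args)
logger = set_logger(exp_path, testing=args.testing)
set_random_seed(args.seed)
device = set_torch_device(args.device)   # a negative id selects the CPU
logger.info('Experiment directory: %s' % (exp_path))
```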
@@ -0,0 +1,243 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

import logging
import re, math
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_
from collections import defaultdict

logger = logging.getLogger(__name__)


def set_optimizer(model, args, num_warmup_steps, num_training_steps, last_epoch=-1):
    plm = hasattr(model.encoder.input_layer, 'plm_model')
    if plm and args.layerwise_decay <= 0.: # freeze plm params
        for n, p in model.named_parameters():
            if 'plm_model' in n:
                p.requires_grad = False
    params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    no_decay = ['bias', 'LayerNorm.weight']
    if plm and 0. < args.layerwise_decay <= 0.5: # separate lr for plm
        grouped_params = [
            {'params': list(set([p for n, p in params if 'plm_model' in n and not any(nd in n for nd in no_decay)])), 'lr': args.layerwise_decay * args.lr, 'weight_decay': args.l2},
            {'params': list(set([p for n, p in params if 'plm_model' in n and any(nd in n for nd in no_decay)])), 'lr': args.layerwise_decay * args.lr, 'weight_decay': 0.0},
            {'params': list(set([p for n, p in params if 'plm_model' not in n and not any(nd in n for nd in no_decay)])), 'weight_decay': args.l2},
            {'params': list(set([p for n, p in params if 'plm_model' not in n and any(nd in n for nd in no_decay)])), 'weight_decay': 0.0},
        ]
        print('Use separate lr %f for pretrained model ...' % (args.lr * args.layerwise_decay))
    elif plm and 0.5 < args.layerwise_decay < 1.: # layerwise lr decay for plm
        pattern = r'encoder\.layer\.(.*?)\.'
        num_layers = int(model.encoder.input_layer.plm_model.config.num_hidden_layers)
        groups = {"decay": defaultdict(list), "no_decay": defaultdict(list)} # record grouped params
        for n, p in params:
            res = re.search(pattern, n) if 'plm_model' in n else None
            depth = int(res.group(1)) if res is not None else 0 if 'plm_model' in n else num_layers
            if any(nd in n for nd in no_decay):
                groups["no_decay"][int(depth)].append(p)
            else:
                groups["decay"][int(depth)].append(p)
        grouped_params = []
        for d in groups["decay"]:
            lr = args.lr * (args.layerwise_decay ** (num_layers - d))
            grouped_params.append({'params': list(set(groups["decay"][d])), 'lr': lr, 'weight_decay': args.l2})
        for d in groups["no_decay"]:
            lr = args.lr * (args.layerwise_decay ** (num_layers - d))
            grouped_params.append({'params': list(set(groups["no_decay"][d])), 'lr': lr, 'weight_decay': 0.0})
        print('Use layerwise decay (rate %f) lr %f for pretrained model ...' % (args.layerwise_decay, args.lr))
    else: # the same lr for plm and other modules
        grouped_params = [
            {'params': list(set([p for n, p in params if not any(nd in n for nd in no_decay)])), 'weight_decay': args.l2},
            {'params': list(set([p for n, p in params if any(nd in n for nd in no_decay)])), 'weight_decay': 0.0},
        ]
        print('Use the same lr %f for all parameters ...' % (args.lr))
    optimizer = AdamW(grouped_params, lr=args.lr, max_grad_norm=args.max_norm)
    schedule_func = schedule_dict[args.lr_schedule]
    scheduler = schedule_func(optimizer, num_warmup_steps, num_training_steps, last_epoch=last_epoch)
    return optimizer, scheduler


def get_ratsql_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases according to the formula
        used in the RATSQL model.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return max(0.0, math.sqrt((num_training_steps - current_step) / float(num_training_steps - num_warmup_steps)))
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_constant_schedule(optimizer, *args, last_epoch=-1):
    """ Create a schedule with a constant learning rate.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)


def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
    """ Create a schedule with a constant learning rate preceded by a warmup
        period during which the learning rate increases linearly between 0 and 1.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases linearly after
        linearly increasing during a warmup period.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases following the
        values of the cosine function between 0 and `pi * cycles` after a warmup
        period during which it increases linearly between 0 and 1.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=1.0, last_epoch=-1
):
    """ Create a schedule with a learning rate that decreases following the
        values of the cosine function with several hard restarts, after a warmup
        period during which it increases linearly between 0 and 1.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


schedule_dict = {
    "constant": get_constant_schedule,
    "linear": get_linear_schedule_with_warmup,
    "ratsql": get_ratsql_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
}


class AdamW(Optimizer):
    """ Implements the Adam algorithm with the weight decay fix.

    Parameters:
        lr (float): learning rate. Default: 1e-3.
        betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
        eps (float): Adam's epsilon. Default: 1e-6
        weight_decay (float): Weight decay. Default: 0.0
        correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in the BERT TF repository). Default: True.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, max_grad_norm=-1, correct_bias=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, max_grad_norm=max_grad_norm, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Add gradient clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]: # no bias correction for BERT
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])

        return loss
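A minimal training-loop sketch showing how the optimizer and scheduler interact; `model`, `args`, `num_batches`, and `train_batches` are assumed to come from the surrounding training script:

```python
# Sketch only; the step counts follow the usual warmup-ratio convention.
num_training_steps = num_batches * args.max_epoch
num_warmup_steps = int(num_training_steps * args.warmup_ratio)
optimizer, scheduler = set_optimizer(model, args, num_warmup_steps, num_training_steps)
for epoch in range(args.max_epoch):
    for batch in train_batches:
        loss = model(batch)
        loss.backward()
        optimizer.step()     # AdamW clips each parameter to max_grad_norm internally
        scheduler.step()     # then the LambdaLR schedule updates the learning rate
        optimizer.zero_grad()
```

Note the design choice: gradient clipping lives inside `AdamW.step()` rather than in the training loop, so each parameter group is clipped to `max_grad_norm` individually.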
@@ -0,0 +1,70 @@
#coding=utf8
from utils.constants import PAD, UNK, BOS, EOS


class Vocab():

    def __init__(self, padding=False, unk=False, boundary=False, min_freq=1,
                 filepath=None, iterable=None, default=UNK, specials=[]):
        super(Vocab, self).__init__()
        self.word2id = dict()
        self.id2word = dict()
        self.default = default # if default is None, no OOV words are allowed
        if padding:
            idx = len(self.word2id)
            self.word2id[PAD], self.id2word[idx] = idx, PAD
        if unk:
            idx = len(self.word2id)
            self.word2id[UNK], self.id2word[idx] = idx, UNK
        if boundary:
            idx = len(self.word2id)
            self.word2id[BOS], self.id2word[idx] = idx, BOS
            self.word2id[EOS], self.id2word[idx + 1] = idx + 1, EOS
        for w in specials:
            if w not in self.word2id:
                idx = len(self.word2id)
                self.word2id[w], self.id2word[idx] = idx, w
        if filepath is not None:
            self.from_filepath(filepath, min_freq=min_freq)
        elif iterable is not None:
            self.from_iterable(iterable)
        assert (self.default is None) or (self.default in self.word2id)

    def from_filepath(self, filepath, min_freq=1):
        with open(filepath, 'r', encoding='utf-8') as inf:
            for line in inf:
                line = line.strip()
                if line == '': continue
                line = line.split('\t') # the second field, if present, is a count/frequency
                if len(line) == 1:
                    word, freq = line[0], min_freq
                else:
                    assert len(line) == 2
                    word, freq = line
                word = word.lower()
                if word not in self.word2id and int(freq) >= min_freq:
                    idx = len(self.word2id)
                    self.word2id[word] = idx
                    self.id2word[idx] = word

    def from_iterable(self, iterable):
        for item in iterable:
            if item not in self.word2id:
                idx = len(self.word2id)
                self.word2id[item] = idx
                self.id2word[idx] = item

    def __len__(self):
        return len(self.word2id)

    @property
    def vocab_size(self):
        return len(self.word2id)

    def __getitem__(self, key):
        """ If self.default is None, out-of-vocabulary tokens are not allowed;
            otherwise, return the idx of self.default when key does not exist.
        """
        if self.default is None:
            return self.word2id[key]
        else:
            return self.word2id.get(key, self.word2id[self.default])
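A short usage sketch of the lookup semantics (the three tokens are hypothetical; run in the module's context so the PAD/UNK constants resolve):

```python
vocab = Vocab(padding=True, unk=True, iterable=['select', 'from', 'where'])
assert len(vocab) == 5        # PAD, UNK + 3 tokens
print(vocab['select'])        # a known word -> its id
print(vocab['groupby'])       # OOV -> falls back to the UNK id, since default=UNK
```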
@@ -0,0 +1,37 @@
#coding=utf8

from embeddings import GloveEmbedding
import numpy as np
from utils.constants import PAD
import torch, random


class Word2vecUtils():

    def __init__(self):
        super(Word2vecUtils, self).__init__()
        self.word_embed = GloveEmbedding('common_crawl_48', d_emb=300)
        self.initializer = lambda: np.random.normal(size=300).tolist()

    def load_embeddings(self, module, vocab, device='cpu'):
        """ Initialize the embedding module with GloVe vectors;
            OOV words are initialized from a random normal distribution.
        """
        emb_size = module.weight.data.size(-1)
        assert emb_size == 300, 'Embedding size is not 300, cannot be initialized by GloVe'
        outliers = 0
        for word in vocab.word2id:
            if word == PAD: # PAD symbol is always the 0-vector
                module.weight.data[vocab[PAD]] = torch.zeros(emb_size, dtype=torch.float, device=device)
                continue
            word_emb = self.word_embed.emb(word, default='none')
            if word_emb[0] is None: # oov
                word_emb = self.initializer()
                outliers += 1
            module.weight.data[vocab[word]] = torch.tensor(word_emb, dtype=torch.float, device=device)
        # return the ratio of words covered by pretrained GloVe vectors
        return 1 - outliers / float(len(vocab))

    def emb(self, word):
        word_emb = self.word_embed.emb(word, default='none')
        if word_emb[0] is None:
            return None
        else:
            return word_emb
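A usage sketch for initializing a static embedding layer (assumes a `Vocab` instance as above and that the `embeddings` package has the `common_crawl_48` GloVe vectors cached locally):

```python
import torch.nn as nn

word2vec = Word2vecUtils()
embed = nn.Embedding(len(vocab), 300)   # must be 300-d to match GloVe
coverage = word2vec.load_embeddings(embed, vocab, device='cpu')
print('GloVe covers %.2f%% of the vocabulary' % (coverage * 100))
```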