DAMO-ConvAI/sunsql/model/decoder/sql_parser.py

# coding=utf-8
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from asdl.sql.sql_transition_system import SelectColumnAction, SelectTableAction
from asdl.transition_system import ApplyRuleAction, ReduceAction
from asdl.action_info import ActionInfo
from asdl.hypothesis import Hypothesis
from asdl.decode_hypothesis import DecodeHypothesis
from utils.batch import Batch
from model.decoder.onlstm import LSTM, ONLSTM
from model.model_utils import Registrable, MultiHeadAttention


@Registrable.register('decoder_tranx')
class SqlParser(nn.Module):

    def __init__(self, args, transition_system):
        super(SqlParser, self).__init__()
        self.args = args
        self.transition_system = transition_system
        self.grammar = self.transition_system.grammar
        # the last entry is the embedding for the Reduce action
        self.production_embed = nn.Embedding(len(transition_system.grammar) + 1, args.action_embed_size)
        # embedding table for ASDL fields in constructors
        self.field_embed = nn.Embedding(len(transition_system.grammar.fields), args.field_embed_size)
        # embedding table for ASDL types
        self.type_embed = nn.Embedding(len(transition_system.grammar.types), args.type_embed_size)
        # input dimension of the decoder lstm
        input_dim = args.action_embed_size  # previous action
        # optional extra inputs: context vector(s) and parent node features
        cxt_num = 2 if args.sep_cxt else 1
        input_dim += args.gnn_hidden_size * cxt_num * int(not args.no_context_feeding)
        input_dim += args.action_embed_size * int(not args.no_parent_production_embed)
        input_dim += args.field_embed_size * int(not args.no_parent_field_embed)
        input_dim += args.type_embed_size * int(not args.no_parent_field_type_embed)
        input_dim += args.lstm_hidden_size * int(not args.no_parent_state)
        cell_constructor = ONLSTM if args.lstm == 'onlstm' else LSTM
        self.decoder_lstm = cell_constructor(input_dim, args.lstm_hidden_size, num_layers=args.lstm_num_layers,
                                             chunk_num=args.chunk_size, dropout=args.dropout, dropconnect=args.drop_connect)
        # transform column embedding to the production embedding space
        self.column_lstm_input = nn.Linear(args.gnn_hidden_size, args.action_embed_size)
        # transform table embedding to the production embedding space
        self.table_lstm_input = self.column_lstm_input  # shared; alternatively nn.Linear(args.gnn_hidden_size, args.action_embed_size)
        self.context_attn = MultiHeadAttention(args.gnn_hidden_size, args.lstm_hidden_size, args.gnn_hidden_size, args.gnn_hidden_size,
                                               num_heads=args.num_heads, feat_drop=args.dropout)
        if args.sep_cxt:  # calculate separate context vectors for the question and the database schema
            self.schema_attn = MultiHeadAttention(args.gnn_hidden_size, args.lstm_hidden_size, args.gnn_hidden_size, args.gnn_hidden_size,
                                                  num_heads=args.num_heads, feat_drop=args.dropout)
        # feature vector before ApplyRule or SelectColumn/SelectTable
        self.att_vec_linear = nn.Sequential(nn.Linear(args.lstm_hidden_size + cxt_num * args.gnn_hidden_size, args.att_vec_size), nn.Tanh())
        self.apply_rule_affine = nn.Linear(args.att_vec_size, args.action_embed_size, bias=False)
        self.apply_rule = lambda x: F.linear(x, self.production_embed.weight)  # re-use the action embedding matrix
        self.select_column = MultiHeadAttention(args.gnn_hidden_size, args.att_vec_size, args.gnn_hidden_size, args.gnn_hidden_size,
                                                num_heads=args.num_heads, feat_drop=args.dropout)
        self.select_table = MultiHeadAttention(args.gnn_hidden_size, args.att_vec_size, args.gnn_hidden_size, args.gnn_hidden_size,
                                               num_heads=args.num_heads, feat_drop=args.dropout)
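
    # Layout of the decoder input vector assembled in score()/parse(), matching
    # the input_dim accumulation above (each optional part is present only when
    # the corresponding flag enables it):
    #   x_t = [ prev_action_embed           : action_embed_size
    #         ; prev (question) context     : gnn_hidden_size   (context feeding)
    #         ; prev schema context         : gnn_hidden_size   (only if sep_cxt)
    #         ; parent production embedding : action_embed_size
    #         ; parent field embedding      : field_embed_size
    #         ; parent field type embedding : type_embed_size
    #         ; parent decoder hidden state : lstm_hidden_size ]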

    def score(self, encodings, mask, h0, batch):
        """ Training function
        @input:
            encodings: encoded representations from the encoder, bsize x seqlen x gnn_hidden_size
            mask: mask matrix over the encoded inputs, bsize x seqlen
            h0: initial decoder state, bsize x lstm_hidden_size
            batch: see utils.batch; we use the fields
                batch.examples, batch.get_frontier_prod_idx(t),
                batch.get_frontier_field_idx(t), batch.get_frontier_field_type_idx(t),
                batch.max_action_num and example.tgt_action (list of ActionInfo)
        @output:
            loss: summed negative log-likelihood of the gold actions over the batch
        """
        args = self.args
        zero_action_embed = encodings.new_zeros(args.action_embed_size)
        split_len = [batch.max_question_len, batch.max_table_len, batch.max_column_len]
        q, tab, col = encodings.split(split_len, dim=1)
        q_mask, tab_mask, col_mask = mask.split(split_len, dim=1)
        if args.sep_cxt:
            encodings, mask = q, q_mask
            schema, schema_mask = torch.cat([tab, col], dim=1), torch.cat([tab_mask, col_mask], dim=1)
            schema_context, _ = self.schema_attn(schema, h0, schema_mask)
        context, _ = self.context_attn(encodings, h0, mask)
        h0 = h0.unsqueeze(0).repeat(args.lstm_num_layers, 1, 1)
        # h0 = h0.new_zeros(h0.size()) # init decoder with 0-vector
        h_c = (h0, h0.new_zeros(h0.size()))
        action_probs = [[] for _ in range(encodings.size(0))]
        history_states = []
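        # Teacher-forced unrolling: at step t we feed the gold action from step
        # t-1 and score the gold action at step t. history_states caches the
        # top-layer hidden state of every step so that later steps can retrieve
        # their parent's state (parent feeding).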
        for t in range(batch.max_action_num):
            # x: [prev_action_embed, [prev_context, parent_production, parent_field, parent_type, parent_state]]
            if t == 0:
                x = encodings.new_zeros(encodings.size(0), self.decoder_lstm.input_size)
                offset = args.action_embed_size
                if not args.no_context_feeding:
                    x[:, offset: offset + args.gnn_hidden_size] = context
                    offset += args.gnn_hidden_size
                    if args.sep_cxt:
                        x[:, offset: offset + args.gnn_hidden_size] = schema_context
                        offset += args.gnn_hidden_size
                offset += args.action_embed_size * int(not args.no_parent_production_embed)
                offset += args.field_embed_size * int(not args.no_parent_field_embed)
                if not args.no_parent_field_type_embed:
                    start_rule = torch.tensor([self.grammar.type2id[self.grammar.root_type]] * x.size(0), dtype=torch.long, device=x.device)
                    x[:, offset: offset + args.type_embed_size] = self.type_embed(start_rule)
            else:
                prev_action_embed = []
                for e_id, example in enumerate(batch.examples):
                    if t < len(example.tgt_action):
                        prev_action = example.tgt_action[t - 1].action
                        if isinstance(prev_action, ApplyRuleAction):
                            prev_action_embed.append(self.production_embed.weight[self.grammar.prod2id[prev_action.production]])
                        elif isinstance(prev_action, ReduceAction):
                            prev_action_embed.append(self.production_embed.weight[len(self.grammar)])
                        elif isinstance(prev_action, SelectColumnAction):
                            # map the schema item into the production embedding space first
                            col_embed = self.column_lstm_input(col[e_id, prev_action.column_id])
                            prev_action_embed.append(col_embed)
                        elif isinstance(prev_action, SelectTableAction):
                            tab_embed = self.table_lstm_input(tab[e_id, prev_action.table_id])
                            prev_action_embed.append(tab_embed)
                        else:
                            raise ValueError('Unrecognized previous action object!')
                    else:
                        prev_action_embed.append(zero_action_embed)
                inputs = [torch.stack(prev_action_embed)]
                if not args.no_context_feeding:
                    inputs.append(context)
                    if args.sep_cxt:
                        inputs.append(schema_context)
                if not args.no_parent_production_embed:
                    inputs.append(self.production_embed(batch.get_frontier_prod_idx(t)))
                if not args.no_parent_field_embed:
                    inputs.append(self.field_embed(batch.get_frontier_field_idx(t)))
                if not args.no_parent_field_type_embed:
                    inputs.append(self.type_embed(batch.get_frontier_field_type_idx(t)))
                if not args.no_parent_state:
                    actions_t = [e.tgt_action[t] if t < len(e.tgt_action) else None for e in batch.examples]
                    parent_state = torch.stack([history_states[p_t][e_id]
                                                for e_id, p_t in
                                                enumerate([a_t.parent_t if a_t else 0 for a_t in actions_t])])
                    inputs.append(parent_state)
                x = torch.cat(inputs, dim=-1)
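                # note: the concatenation order here must match the input_dim
                # layout accumulated in __init__ (see the sketch after __init__)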
            # advance decoder lstm and attention calculation
            out, (h_t, c_t) = self.decoder_lstm(x.unsqueeze(1), h_c, start=(t == 0))
            out = out.squeeze(1)
            context, _ = self.context_attn(encodings, out, mask)
            if args.sep_cxt:
                schema_context, _ = self.schema_attn(schema, out, schema_mask)
                att_vec = self.att_vec_linear(torch.cat([out, context, schema_context], dim=-1))  # bsize x args.att_vec_size
            else:
                att_vec = self.att_vec_linear(torch.cat([out, context], dim=-1))
            # action logprobs
            apply_rule_logprob = F.log_softmax(self.apply_rule(self.apply_rule_affine(att_vec)), dim=-1)  # bsize x prod_num
            _, select_tab_prob = self.select_table(tab, att_vec, tab_mask)
            select_tab_logprob = torch.log(select_tab_prob + 1e-32)
            _, select_col_prob = self.select_column(col, att_vec, col_mask)
            select_col_logprob = torch.log(select_col_prob + 1e-32)
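            # Three scoring heads share att_vec: ApplyRule/Reduce scores come
            # from a linear map back onto the production embeddings, while
            # SelectTable/SelectColumn are pointer-style attention distributions
            # over schema items; the 1e-32 guards against log(0) on masked entries.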
            for e_id, example in enumerate(batch.examples):
                if t < len(example.tgt_action):
                    action_t = example.tgt_action[t].action
                    if isinstance(action_t, ApplyRuleAction):
                        logprob_t = apply_rule_logprob[e_id, self.grammar.prod2id[action_t.production]]
                        # print('Rule %s with prob %s' % (action_t.production, logprob_t.item()))
                    elif isinstance(action_t, ReduceAction):
                        logprob_t = apply_rule_logprob[e_id, len(self.grammar)]
                        # print('Rule %s with prob %s' % ('Reduce', logprob_t.item()))
                    elif isinstance(action_t, SelectColumnAction):
                        logprob_t = select_col_logprob[e_id, action_t.column_id]
                        # print('SelectColumn %s with prob %s' % (action_t.column_id, logprob_t.item()))
                    elif isinstance(action_t, SelectTableAction):
                        logprob_t = select_tab_logprob[e_id, action_t.table_id]
                        # print('SelectTable %s with prob %s' % (action_t.table_id, logprob_t.item()))
                    else:
                        raise ValueError('Unrecognized action object!')
                    action_probs[e_id].append(logprob_t)
            h_c = (h_t, c_t)
            history_states.append(h_t[-1])
        # loss is the negative sum of all the gold action log-probabilities
        loss = - torch.stack([torch.stack(logprob_i).sum() for logprob_i in action_probs]).sum()
        return loss

    def parse(self, encodings, mask, h0, batch, beam_size=5):
        """ Beam-search decoding; parses one example at a time, so encodings/mask must have batch size 1
        """
        args = self.args
        zero_action_embed = encodings.new_zeros(args.action_embed_size)
        assert encodings.size(0) == 1 and mask.size(0) == 1
        encodings, mask = encodings.repeat(beam_size, 1, 1), mask.repeat(beam_size, 1)
        split_len = [batch.max_question_len, batch.max_table_len, batch.max_column_len]
        q, tab, col = encodings.split(split_len, dim=1)
        q_mask, tab_mask, col_mask = mask.split(split_len, dim=1)
        if args.sep_cxt:
            encodings, mask = q, q_mask
            schema, schema_mask = torch.cat([tab, col], dim=1), torch.cat([tab_mask, col_mask], dim=1)
            schema_context, _ = self.schema_attn(schema[:1], h0, schema_mask[:1])
        context, _ = self.context_attn(encodings[:1], h0, mask[:1])
        h0 = h0.unsqueeze(0).repeat(args.lstm_num_layers, 1, 1)
        # h0 = h0.new_zeros(h0.size()) # init decoder with 0-vector
        h_c = (h0, h0.new_zeros(h0.size()))
        hypotheses = [DecodeHypothesis()]
        hyp_states, completed_hypotheses, t = [[]], [], 0
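        # Beam-search bookkeeping: hypotheses holds the live beams, hyp_states
        # the per-beam list of decoder states (for parent feeding), and
        # completed_hypotheses the beams whose AST is already finished.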
        while t < args.decode_max_step:
            hyp_num = len(hypotheses)
            cur_encodings, cur_mask = encodings[:hyp_num], mask[:hyp_num]
            if args.sep_cxt:
                cur_schema, cur_schema_mask = schema[:hyp_num], schema_mask[:hyp_num]
            cur_tab, cur_tab_mask, cur_col, cur_col_mask = tab[:hyp_num], tab_mask[:hyp_num], col[:hyp_num], col_mask[:hyp_num]
            # x: [prev_action_embed, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_state]
            if t == 0:
                x = encodings.new_zeros(hyp_num, self.decoder_lstm.input_size)
                offset = args.action_embed_size
                if not args.no_context_feeding:
                    x[:, offset: offset + args.gnn_hidden_size] = context
                    offset += args.gnn_hidden_size
                    if args.sep_cxt:
                        x[:, offset: offset + args.gnn_hidden_size] = schema_context
                        offset += args.gnn_hidden_size
                offset += args.action_embed_size * int(not args.no_parent_production_embed)
                offset += args.field_embed_size * int(not args.no_parent_field_embed)
                if not args.no_parent_field_type_embed:
                    start_rule = torch.tensor([self.grammar.type2id[self.grammar.root_type]] * hyp_num, dtype=torch.long, device=x.device)
                    x[:, offset: offset + args.type_embed_size] = self.type_embed(start_rule)
            else:
                prev_action_embed = []
                for e_id, hyp in enumerate(hypotheses):
                    prev_action = hyp.actions[-1]
                    if isinstance(prev_action, ApplyRuleAction):
                        prev_action_embed.append(self.production_embed.weight[self.grammar.prod2id[prev_action.production]])
                    elif isinstance(prev_action, ReduceAction):
                        prev_action_embed.append(self.production_embed.weight[len(self.grammar)])
                    elif isinstance(prev_action, SelectColumnAction):  # map the schema item into the production embedding space first
                        col_embed = self.column_lstm_input(cur_col[e_id, prev_action.column_id])
                        prev_action_embed.append(col_embed)
                    elif isinstance(prev_action, SelectTableAction):
                        tab_embed = self.table_lstm_input(cur_tab[e_id, prev_action.table_id])
                        prev_action_embed.append(tab_embed)
                    else:
                        raise ValueError('Unrecognized previous action object!')
                inputs = [torch.stack(prev_action_embed)]
                if not args.no_context_feeding:
                    inputs.append(context)
                    if args.sep_cxt:
                        inputs.append(schema_context)
                if not args.no_parent_production_embed:
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(
                        torch.tensor([self.grammar.prod2id[prod] for prod in frontier_prods], dtype=torch.long, device=encodings.device))
                    inputs.append(frontier_prod_embeds)
                if not args.no_parent_field_embed:
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(
                        torch.tensor([self.grammar.field2id[field] for field in frontier_fields], dtype=torch.long, device=encodings.device))
                    inputs.append(frontier_field_embeds)
                if not args.no_parent_field_type_embed:
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(
                        torch.tensor([self.grammar.type2id[tp] for tp in frontier_field_types], dtype=torch.long, device=encodings.device))
                    inputs.append(frontier_field_type_embeds)
                if not args.no_parent_state:
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t] for hyp_id, p_t in enumerate(p_ts)])
                    inputs.append(parent_states)
                x = torch.cat(inputs, dim=-1)  # hyp_num x input_size
            # advance decoder lstm and attention calculation
            out, (h_t, c_t) = self.decoder_lstm(x.unsqueeze(1), h_c)
            out = out.squeeze(1)
            context, _ = self.context_attn(cur_encodings, out, cur_mask)
            if args.sep_cxt:
                schema_context, _ = self.schema_attn(cur_schema, out, cur_schema_mask)
                att_vec = self.att_vec_linear(torch.cat([out, context, schema_context], dim=-1))  # hyp_num x args.att_vec_size
            else:
                att_vec = self.att_vec_linear(torch.cat([out, context], dim=-1))
            # action logprobs
            apply_rule_logprob = F.log_softmax(self.apply_rule(self.apply_rule_affine(att_vec)), dim=-1)  # hyp_num x prod_num
            _, select_tab_prob = self.select_table(cur_tab, att_vec, cur_tab_mask)
            select_tab_logprob = torch.log(select_tab_prob + 1e-32)
            _, select_col_prob = self.select_column(cur_col, att_vec, cur_col_mask)
            select_col_logprob = torch.log(select_col_prob + 1e-32)
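            # Same three heads as in score(); here they score every grammar-legal
            # continuation of each live beam before pruning.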
            new_hyp_meta = []
            for hyp_id, hyp in enumerate(hypotheses):
                action_types = self.transition_system.get_valid_continuation_types(hyp)
                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_logprob[hyp_id, prod_id]
                            new_hyp_score = hyp.score + prod_score
                            meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id,
                                          'score': prod_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_logprob[hyp_id, len(self.grammar)]
                        new_hyp_score = hyp.score + action_score
                        meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar),
                                      'score': action_score, 'new_hyp_score': new_hyp_score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                    elif action_type == SelectColumnAction:
                        col_num = cur_col_mask[hyp_id].int().sum().item()
                        for col_id in range(col_num):
                            col_sel_score = select_col_logprob[hyp_id, col_id]
                            new_hyp_score = hyp.score + col_sel_score
                            meta_entry = {'action_type': 'sel_col', 'col_id': col_id,
                                          'score': col_sel_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == SelectTableAction:
                        tab_num = cur_tab_mask[hyp_id].int().sum().item()
                        for tab_id in range(tab_num):
                            tab_sel_score = select_tab_logprob[hyp_id, tab_id]
                            new_hyp_score = hyp.score + tab_sel_score
                            meta_entry = {'action_type': 'sel_tab', 'tab_id': tab_id,
                                          'score': tab_sel_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    else:
                        raise ValueError('Unrecognized action type while decoding!')
            if not new_hyp_meta:
                break
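            # All legal continuations of all beams compete in one pool; topk below
            # prunes the pool back to at most beam_size survivors.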
            # rank and pick
            new_hyp_scores = torch.stack([x['new_hyp_score'] for x in new_hyp_meta])
            top_new_hyp_scores, meta_ids = new_hyp_scores.topk(min(beam_size, new_hyp_scores.size(0)))
            live_hyp_ids, new_hypotheses = [], []
            for new_hyp_score, meta_id in zip(top_new_hyp_scores.tolist(), meta_ids.tolist()):
                action_info = ActionInfo()
                hyp_meta_entry = new_hyp_meta[meta_id]
                prev_hyp_id = hyp_meta_entry['prev_hyp_id']
                prev_hyp = hypotheses[prev_hyp_id]
                action_type_str = hyp_meta_entry['action_type']
                if action_type_str == 'apply_rule':
                    prod_id = hyp_meta_entry['prod_id']
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    else:
                        action = ReduceAction()
                elif action_type_str == 'sel_col':
                    action = SelectColumnAction(hyp_meta_entry['col_id'])
                elif action_type_str == 'sel_tab':
                    action = SelectTableAction(hyp_meta_entry['tab_id'])
                else:
                    raise ValueError('Unrecognized action type str!')
                action_info.action = action
                action_info.t = t
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field
                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score
                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)
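            # Re-index every piece of recurrent state so that entry i corresponds
            # to the i-th surviving hypothesis; completed beams leave the live set.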
            if live_hyp_ids:
                if not args.no_parent_state:
                    hyp_states = [hyp_states[i] + [h_t[-1, i]] for i in live_hyp_ids]
                h_c = (h_t[:, live_hyp_ids], c_t[:, live_hyp_ids])
                if not args.no_context_feeding:
                    context = context[live_hyp_ids]
                    if args.sep_cxt:
                        schema_context = schema_context[live_hyp_ids]
                hypotheses = new_hypotheses
                t += 1
            else:
                break
            # for idx, hyp in enumerate(new_hypotheses):
            #     print('Idx %d' % (idx))
            #     print('Tree:', hyp.tree.to_string())
            #     print('Action:', hyp.action_infos[-1])
            #     print('============================================')
            # input('********************Next parse step ....************************')
        if len(completed_hypotheses) == 0:  # no completed sql
            completed_hypotheses.append(DecodeHypothesis())
        else:
            completed_hypotheses.sort(key=lambda hyp: - hyp.score)  # alternatively, normalize by hyp.tree.size
        return completed_hypotheses
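
# A minimal usage sketch (hypothetical variable names; the encoder outputs and
# batch construction live elsewhere in this repo):
#
#     parser = SqlParser(args, transition_system)
#     loss = parser.score(encodings, mask, h0, batch)                # training
#     hyps = parser.parse(encodings, mask, h0, batch, beam_size=5)   # inference
#     best_ast = hyps[0].tree  # highest-scoring hypothesis, sorted above
#     # the surface SQL can then be recovered through the transition system,
#     # e.g. transition_system.ast_to_surface_code(best_ast) in tranx-style systems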