import torch
import csv
import os
import re
import json
import numpy as np

from settings import parse_args


class PromptCLSDataset(torch.utils.data.Dataset):
    def __init__(self, task_name, tokz, data_path, num_workers=8, ctx_max_len=100):
        self.num_workers = num_workers
        self.data_path = data_path
        self.ctx_max_len = ctx_max_len
        self.tokz = tokz
        self.max_ans_len = 0
        self.pseudo_data_prompt = self._pseudo_data_prompt(task_name)  # Prompt (containing the task id) used to infer pseudo data.
        self.pseudo_ans_prompt = self._pseudo_ans_prompt()  # Answer prompt appended before inferring the label.
        self.pseudo_data_prompt_id = tokz.encode(self.pseudo_data_prompt)
        self.pseudo_ans_prompt_id = tokz.encode(self.pseudo_ans_prompt)

        self.pseudo_gene_prompt = self._pseudo_general_prompt()  # Task-agnostic prompt used to infer pseudo data.
        self.pseudo_gene_prompt_id = tokz.encode(self.pseudo_gene_prompt)

        # Load data.
        with open(data_path, "r", encoding='utf-8') as f:
            data = [json.loads(i) for i in f.readlines()]
        data = [self.parse_example(i) for i in data]  # data format: [utter, label]

        self.data = []
        if len(data) > 0:
            self.data = self.data_tokenization(task_name, data)

    @staticmethod
    def apply_prompt(task_name, text, ctx_max_len):  # Build a prompt-based context.
        text = text.split(' ')[:ctx_max_len]
        # Spaces are added before and after the utterance so it can be recovered exactly.
        return f"In the \"{task_name}\" task, which intent category best describes: \" {' '.join(text)} \"? Answer: "

    @staticmethod
    def apply_general_prompt(text, ctx_max_len):
        text = text.split(' ')[:ctx_max_len]
        return f"In the current task, which intent category best describes: \" {' '.join(text)} \"? Answer: "

    @staticmethod
    def _pseudo_data_prompt(task_name):  # Prompt prefix used to infer pseudo data.
        return f"In the \"{task_name}\" task, which intent category best describes: \""

    @staticmethod
    def _pseudo_general_prompt():  # Task-agnostic prompt prefix used to infer pseudo data.
        return "In the current task, which intent category best describes: \""

    @staticmethod
    def _pseudo_ans_prompt():  # Answer prompt appended before inferring the label.
        return " \"? Answer: "
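
    # Illustrative example (hedged sketch, values assumed): for task "banking" and the
    # utterance "transfer money", apply_prompt produces
    #   'In the "banking" task, which intent category best describes: " transfer money "? Answer: '
    # while _pseudo_data_prompt yields only the prefix up to the opening quote, which the
    # generator then completes with a pseudo utterance and label.
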
    @staticmethod
    def parse_pseudo_data(text):  # Parse the utterance and label back out of generated pseudo data.
        try:
            task_name = re.findall(r'In the "(.+?)" task, which intent category best describes:', text)[0]
            utter = re.findall(r'task, which intent category best describes: " (.+?) "\? Answer: ', text)[0]
            label = text.replace(f'In the "{task_name}" task, which intent category best describes: " {utter} "? Answer: ', '').strip()
            return {'task_name': task_name, 'utter': utter, 'label': label}
        except IndexError:  # The generated text does not match the prompt template.
            return None
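
    # Illustrative round trip (hedged sketch, values assumed):
    #   parse_pseudo_data('In the "banking" task, which intent category best describes: '
    #                     '" transfer money "? Answer: transfer')
    #   -> {'task_name': 'banking', 'utter': 'transfer money', 'label': 'transfer'}
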
    @staticmethod
    def parse_example(example):
        text = example['userInput']['text']
        ans = example['intent']
        # ans = example['intent'].replace("-", " ").replace("_", " ")  # - and _ are replaced later, in parallel_tokenization
        return text, ans

    def parallel_tokenization(self, task_name, d):
        ori_utter = d[0]
        prompt = self._pseudo_data_prompt(task_name)
        gene_prompt = self._pseudo_general_prompt()
        context = self.apply_prompt(task_name, d[0], self.ctx_max_len)  # Apply the task-specific prompt to the utterance.
        general_context = self.apply_general_prompt(d[0], self.ctx_max_len)  # Apply the general prompt (no task id) to the utterance.
        ans = d[1].replace("-", " ").replace("_", " ")  # Replace - and _ to make the label read more naturally.

        prompt_id = self.tokz.encode(prompt)
        gene_prompt_id = self.tokz.encode(gene_prompt)
        context_id = self.tokz.encode(context)
        general_context_id = self.tokz.encode(general_context)
        ans_id = self.tokz.encode(ans)
        utter_id = self.tokz.encode(" " + ori_utter)
        self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)

        # * Ids construction.
        return {
            'utter_id': utter_id + [self.tokz.eos_token_id],
            'posterior_id': [self.tokz.bos_token_id] + prompt_id + utter_id + [self.tokz.eos_token_id],  # bos, prompt1, utter, eos
            'input_id': [self.tokz.bos_token_id] + prompt_id + utter_id + [self.tokz.eos_token_id],  # bos, prompt1, utter, eos
            'prompt_id': [self.tokz.bos_token_id] + prompt_id + [self.tokz.eos_token_id],  # bos, prompt1, eos
            'gene_prompt_id': [self.tokz.bos_token_id] + gene_prompt_id + [self.tokz.eos_token_id],  # bos, general prompt, eos
            'gene_posterior_id': [self.tokz.bos_token_id] + gene_prompt_id + utter_id + [self.tokz.eos_token_id],  # bos, general prompt, utter, eos
            'gene_input_id': [self.tokz.bos_token_id] + gene_prompt_id + utter_id + [self.tokz.eos_token_id],  # bos, general prompt, utter, eos
            'all_id': [self.tokz.bos_token_id] + context_id + ans_id + [self.tokz.eos_token_id],  # bos, prompt1, utter, prompt2, ans, eos
            'context_id': [self.tokz.bos_token_id] + context_id,  # bos, prompt1, utter, prompt2
            'general_context_id': [self.tokz.bos_token_id] + general_context_id,
            'gene_all_id': [self.tokz.bos_token_id] + general_context_id + ans_id + [self.tokz.eos_token_id],
            'ans_id': ans_id + [self.tokz.eos_token_id],
        }
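
    # Illustrative sketch (values assumed, not from the data): for task "banking",
    # utterance "check my balance", and intent "balance", 'all_id' decodes to
    #   <bos>In the "banking" task, which intent category best describes: " check my balance "? Answer: balance<eos>
    # and 'context_id' is the same sequence truncated right after "Answer: ".
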
    def data_tokenization(self, task_name, data):
        # TODO: parallelize (num_workers is currently unused).
        data = [self.parallel_tokenization(task_name, i) for i in data]
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class MixedCLSDataset(PromptCLSDataset):
    def __init__(self, task2data, tokz, ctx_max_len=100, curr_data=None):
        '''
        task2data: {'task_name': [data1, data2, data3, ...]}
        '''
        self.ctx_max_len = ctx_max_len
        self.tokz = tokz
        self.max_ans_len = 0

        self.data = []
        for task in task2data:
            self.data += self.data_tokenization(task, task2data[task])
        if curr_data is not None:
            self.data += curr_data


class PseudoCLSDataset(PromptCLSDataset):
    def __init__(self, taskname, data, tokz, ctx_max_len=100):
        self.ctx_max_len = ctx_max_len
        self.tokz = tokz
        self.max_ans_len = 0
        self.data = self.data_tokenization(taskname, data)


def get_dataclass_dict(task2data, curr_task, curr_data, tokz, ctx_max_len):
    '''
    task2data: {'task_name': [data1, data2, data3, ...]}
    '''
    data_dict = {}
    data_dict[curr_task] = curr_data
    for task in task2data:
        data_dict[task] = PseudoCLSDataset(task, task2data[task], tokz, ctx_max_len=ctx_max_len)

    return data_dict


class PromptSlotTaggingDataset(torch.utils.data.Dataset):
    def __init__(self, task_name, tokz, data_path, num_workers=8, ctx_max_len=100):
        self.num_workers = num_workers
        self.data_path = data_path
        self.ctx_max_len = ctx_max_len
        self.tokz = tokz
        self.max_ans_len = 0
        self.pseudo_data_prompt = self._pseudo_data_prompt1(task_name)  # Prompt used to infer pseudo data (short variant).
        # self.pseudo_data_prompt = self._pseudo_data_prompt(task_name)  # Long variant.
        self.pseudo_data_prompt_id = tokz.encode(self.pseudo_data_prompt)
        self.pseudo_ans_prompt = self._pseudo_ans_prompt()  # Answer prompt appended before inferring the label.
        self.pseudo_ans_prompt_id = tokz.encode(self.pseudo_ans_prompt)

        self.pseudo_gene_prompt = self._pseudo_general_prompt()  # Task-agnostic prompt used to infer pseudo data.
        self.pseudo_gene_prompt_id = tokz.encode(self.pseudo_gene_prompt)

        # * Load data.
        with open(data_path, "r", encoding='utf-8') as f:
            data = [json.loads(i) for i in f.readlines()]
        data = [self.parse_example(i) for i in data]  # data format: [utter, label]

        self.data = []
        if len(data) > 0:
            self.data = self.data_tokenization(task_name, data)

    @staticmethod
    def apply_prompt(task_name, text, ctx_max_len):  # Build a prompt-based context (long variant).
        text = text.split(' ')[:ctx_max_len]
        return f"In the \"{task_name}\" task, if there are any slots and values, what are they in this sentence: \" {' '.join(text)} \"? Answer: "

    @staticmethod
    def _pseudo_data_prompt(task_name):  # Prompt prefix used to infer pseudo data (long variant).
        return f"In the \"{task_name}\" task, if there are any slots and values, what are they in this sentence: \""

    @staticmethod
    def _pseudo_ans_prompt():  # Answer prompt appended before inferring the label.
        return " \"? Answer: "

    @staticmethod
    def apply_prompt1(task_name, text, ctx_max_len):  # Build a prompt-based context (short variant).
        text = text.split(' ')[:ctx_max_len]
        return f"In the \"{task_name}\" task, what are slots and values: \" {' '.join(text)} \"? Answer: "

    @staticmethod
    def _pseudo_data_prompt1(task_name):  # Prompt prefix used to infer pseudo data (short variant).
        return f"In the \"{task_name}\" task, what are slots and values: \""

    @staticmethod
    def _pseudo_general_prompt():  # Task-agnostic prompt prefix used to infer pseudo data.
        return "In the current task, if there are any slots and values, what are they in this sentence: \""

    @staticmethod
    def apply_general_prompt(text, ctx_max_len):
        text = text.split(' ')[:ctx_max_len]
        return f"In the current task, if there are any slots and values, what are they in this sentence: \" {' '.join(text)} \"? Answer: "
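
    # Illustrative example (hedged sketch, values assumed):
    #   apply_prompt1("dstc8", "book a table", 100)
    #   -> 'In the "dstc8" task, what are slots and values: " book a table "? Answer: '
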
    @staticmethod
    def parse_pseudo_data(text):  # Parse the utterance and label back out of generated pseudo data.
        try:
            task_name = re.findall(r'In the "(.+?)" task, if there are any slots and values, what are they in this sentence:', text)[0]
            utter = re.findall(r'task, if there are any slots and values, what are they in this sentence: " (.+?) "\? Answer: ', text)[0]
            label = text.replace(f'In the "{task_name}" task, if there are any slots and values, what are they in this sentence: " {utter} "? Answer: ', '').strip()
            return {'task_name': task_name, 'utter': utter, 'label': label}
        except IndexError:  # The generated text does not match the prompt template.
            return None

    @staticmethod
    def parse_example(example):
        text = example['userInput']['text']
        # Map each slot to its surface value and character span.
        slot_to_word = {}
        for label in example.get('labels', []):
            slot = label['slot']
            start = label['valueSpan'].get('startIndex', 0)
            end = label['valueSpan'].get('endIndex', -1)
            slot_to_word[slot] = [text[start:end].strip(), (start, end)]

        # Sort slots by their position in the sentence.
        slot_to_word = sorted(slot_to_word.items(), key=lambda x: x[1][1])

        if len(slot_to_word) > 0:
            answer_list = []
            for key, slot in slot_to_word:
                slot = slot[0]
                answer_list.append(key + ': ' + slot)
            answer = '; '.join(answer_list)
        else:
            answer = "No slot in this sentence."
        return text, answer
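
    # Illustrative example (hedged sketch, spans hypothetical): for the text
    # "book a table in rome for two" with labels location=(16, 20) and people=(25, 28),
    # parse_example returns
    #   ("book a table in rome for two", "location: rome; people: two")
    # and the answer falls back to "No slot in this sentence." when no labels are present.
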
    def parallel_tokenization(self, task_name, d):
        input_text, label_slots = d
        ori_utter = d[0]

        # * Apply prompt on input text (short prompt variant).
        prompt = self._pseudo_data_prompt1(task_name)
        context = self.apply_prompt1(task_name, input_text, self.ctx_max_len)
        # prompt = self._pseudo_data_prompt(task_name)  # Long variant.
        # context = self.apply_prompt(task_name, input_text, self.ctx_max_len)

        gene_prompt = self._pseudo_general_prompt()
        general_context = self.apply_general_prompt(d[0], self.ctx_max_len)  # Apply the general prompt (no task id) to the utterance.

        # * Answer part.
        ans_id = self.tokz.encode(label_slots)
        prompt_id = self.tokz.encode(prompt)
        gene_prompt_id = self.tokz.encode(gene_prompt)
        context_id = self.tokz.encode(context)  # prompt + input
        general_context_id = self.tokz.encode(general_context)
        utter_id = self.tokz.encode(" " + ori_utter)
        self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)

        return {
            'utter_id': utter_id + [self.tokz.eos_token_id],
            'posterior_id': [self.tokz.bos_token_id] + prompt_id + utter_id + [self.tokz.eos_token_id],  # bos, prompt1, utter, eos
            'input_id': [self.tokz.bos_token_id] + prompt_id + utter_id + [self.tokz.eos_token_id],  # bos, prompt1, utter, eos
            'prompt_id': [self.tokz.bos_token_id] + prompt_id + [self.tokz.eos_token_id],  # bos, prompt1, eos
            'gene_prompt_id': [self.tokz.bos_token_id] + gene_prompt_id + [self.tokz.eos_token_id],  # bos, general prompt, eos
            'gene_posterior_id': [self.tokz.bos_token_id] + gene_prompt_id + utter_id + [self.tokz.eos_token_id],  # bos, general prompt, utter, eos
            'gene_input_id': [self.tokz.bos_token_id] + gene_prompt_id + utter_id + [self.tokz.eos_token_id],  # bos, general prompt, utter, eos
            'all_id': [self.tokz.bos_token_id] + context_id + ans_id + [self.tokz.eos_token_id],
            'context_id': [self.tokz.bos_token_id] + context_id,
            'general_context_id': [self.tokz.bos_token_id] + general_context_id,
            'gene_all_id': [self.tokz.bos_token_id] + general_context_id + ans_id + [self.tokz.eos_token_id],
            'ans_id': ans_id + [self.tokz.eos_token_id],  # eos terminates the answer (matches PromptCLSDataset).
        }

    def data_tokenization(self, task_name, data):
        # TODO: parallelize (num_workers is currently unused).
        data = [self.parallel_tokenization(task_name, i) for i in data]
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class MixedSlotTaggingDataset(PromptSlotTaggingDataset):
    def __init__(self, task2data, tokz, ctx_max_len=100, curr_data=None):
        '''
        task2data: {'task_name': [data1, data2, data3, ...]}
        '''
        self.ctx_max_len = ctx_max_len
        self.tokz = tokz
        self.max_ans_len = 0

        self.data = []
        for task in task2data:
            self.data += self.data_tokenization(task, task2data[task])
        if curr_data is not None:
            self.data += curr_data


args = parse_args()

# * Information for different tasks.
if args.data_type == 'intent':
    TASK2INFO = {
        "banking": {"dataset_class": PromptCLSDataset, "dataset_folder": "banking", "task_type": "CLS"},
        "clinc": {"dataset_class": PromptCLSDataset, "dataset_folder": "clinc", "task_type": "CLS"},
        "hwu": {"dataset_class": PromptCLSDataset, "dataset_folder": "hwu", "task_type": "CLS"},
        "dstc8": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "dstc8", "task_type": "SlotTagging"},
        "restaurant8": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "restaurant8", "task_type": "SlotTagging"},
        "atis": {"dataset_class": PromptCLSDataset, "dataset_folder": "atis", "task_type": "CLS"},
        "tod": {"dataset_class": PromptCLSDataset, "dataset_folder": "FB_TOD_SF", "task_type": "CLS"},
        "mit_movie_eng": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "mit_movie_eng", "task_type": "SlotTagging"},
        "mit_movie_trivia": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "mit_movie_trivia", "task_type": "SlotTagging"},
        "mit_restaurant": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "mit_restaurant", "task_type": "SlotTagging"},
        "snips": {"dataset_class": PromptCLSDataset, "dataset_folder": "snips", "task_type": "CLS"},
        "top_split1": {"dataset_class": PromptCLSDataset, "dataset_folder": "top_split1", "task_type": "CLS"},
        "top_split2": {"dataset_class": PromptCLSDataset, "dataset_folder": "top_split2", "task_type": "CLS"},
        "top_split3": {"dataset_class": PromptCLSDataset, "dataset_folder": "top_split3", "task_type": "CLS"},
    }
else:
    TASK2INFO = {
        "dstc8": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "dstc8", "task_type": "SlotTagging"},
        "restaurant8": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "restaurant8", "task_type": "SlotTagging"},
        "atis": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "atis", "task_type": "SlotTagging"},
        "mit_movie_eng": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "mit_movie_eng", "task_type": "SlotTagging"},
        "mit_movie_trivia": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "mit_movie_trivia", "task_type": "SlotTagging"},
        "mit_restaurant": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "mit_restaurant", "task_type": "SlotTagging"},
        "snips": {"dataset_class": PromptSlotTaggingDataset, "dataset_folder": "snips", "task_type": "SlotTagging"},
    }


class PinnedBatch:
    def __init__(self, data):
        self.data = data

    def __getitem__(self, k):
        return self.data[k]

    def __setitem__(self, key, value):
        self.data[key] = value

    def pin_memory(self):
        for k in self.data.keys():
            self.data[k] = self.data[k].pin_memory()
        return self

    def __repr__(self):
        return self.data.__repr__()

    def __str__(self):
        return self.data.__str__()

    def items(self):
        return self.data.items()

    def keys(self):
        return self.data.keys()
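
# Note: exposing a pin_memory() method is the documented way to let
# torch.utils.data.DataLoader pin custom batch types when the loader is
# constructed with pin_memory=True; it calls batch.pin_memory() on each batch.

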
def rolling_window(a, window):
    # All overlapping length-`window` windows of `a`, as a zero-copy strided view.
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)


def vview(a):
    # View each row of a 2-D array as a single opaque (void) element so that
    # whole rows can be compared at once.
    return np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))
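
# Illustrative example (hedged sketch): rolling_window(np.array([1, 2, 3, 4]), 2)
# yields the overlapping windows [[1, 2], [2, 3], [3, 4]] without copying.

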
# * Get the start index of sublist b inside the full list a.
def sublist_start_index(a, b):
    n = min(len(b), len(a))
    windows = rolling_window(np.array(a), n)  # all length-n windows of a
    res = np.flatnonzero(vview(windows) == vview(np.array([b]))[0])
    if len(res) == 0:
        # Not found: retry with the last element of b dropped
        # (e.g. when b carries a trailing eos that a lacks).
        return sublist_start_index(a, b[:-1])
    else:
        return res[0]
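
# Illustrative example (hedged sketch): sublist_start_index([5, 6, 7, 8], [7, 8])
# returns 2, the position where the token sub-sequence begins in the full sequence.

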
def pad_seq(seq, pad, max_len, pad_left=False):
    if pad_left:
        return [pad] * (max_len - len(seq)) + seq
    else:
        return seq + [pad] * (max_len - len(seq))
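
# Illustrative example (hedged sketch): pad_seq([7, 8], 0, 4) -> [7, 8, 0, 0], and
# pad_seq([7, 8], 0, 4, pad_left=True) -> [0, 0, 7, 8]. Left padding is used below
# for generation contexts so new tokens continue right after the prompt.

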
class PadBatchSeq:
    def __init__(self, pad_id=0):
        self.pad_id = pad_id

    def __call__(self, batch):
        # Fetch all ids.
        utter_id = [i['utter_id'] for i in batch]
        input_id = [i['input_id'] for i in batch]
        posterior_id = [i['posterior_id'] for i in batch]
        prompt_id = [i['prompt_id'] for i in batch]
        gene_prompt_id = [i['gene_prompt_id'] for i in batch]
        gene_input_id = [i['gene_input_id'] for i in batch]
        gene_posterior_id = [i['gene_posterior_id'] for i in batch]
        context_id = [i['context_id'] for i in batch]
        general_context_id = [i['general_context_id'] for i in batch]
        ans_id = [i['ans_id'] for i in batch]
        all_id = [i['all_id'] for i in batch]
        gene_all_id = [i['gene_all_id'] for i in batch]

        # Store lengths.
        all_lens = [len(i) for i in all_id]
        gene_all_lens = [len(i) for i in gene_all_id]
        context_lens = [len(i) for i in context_id]
        general_context_lens = [len(i) for i in general_context_id]
        ans_lens = [len(i) for i in ans_id]
        prompt_lens = [len(i) for i in prompt_id]
        gene_prompt_lens = [len(i) for i in gene_prompt_id]
        input_lens = [len(i) for i in input_id]
        gene_input_lens = [len(i) for i in gene_input_id]
        posterior_lens = [len(i) for i in posterior_id]
        gene_posterior_lens = [len(i) for i in gene_posterior_id]
        utter_lens = [len(i) for i in utter_id]

        # Construct attention masks (1 on real tokens, 0 on padding).
        ans_mask = torch.ByteTensor([[1] * ans_lens[i] + [0] * (max(ans_lens) - ans_lens[i]) for i in range(len(ans_id))])
        context_mask = torch.ByteTensor([[1] * context_lens[i] + [0] * (max(context_lens) - context_lens[i]) for i in range(len(context_id))])
        general_context_mask = torch.ByteTensor([[1] * general_context_lens[i] + [0] * (max(general_context_lens) - general_context_lens[i]) for i in range(len(general_context_id))])
        prompt_mask = torch.ByteTensor([[1] * prompt_lens[i] + [0] * (max(prompt_lens) - prompt_lens[i]) for i in range(len(prompt_id))])
        gene_prompt_mask = torch.ByteTensor([[1] * gene_prompt_lens[i] + [0] * (max(gene_prompt_lens) - gene_prompt_lens[i]) for i in range(len(gene_prompt_id))])
        input_mask = torch.ByteTensor([[1] * input_lens[i] + [0] * (max(input_lens) - input_lens[i]) for i in range(len(input_id))])
        gene_input_mask = torch.ByteTensor([[1] * gene_input_lens[i] + [0] * (max(gene_input_lens) - gene_input_lens[i]) for i in range(len(gene_input_id))])
        utter_mask = torch.ByteTensor([[1] * utter_lens[i] + [0] * (max(utter_lens) - utter_lens[i]) for i in range(len(utter_id))])
        posterior_mask = torch.ByteTensor([[1] * posterior_lens[i] + [0] * (max(posterior_lens) - posterior_lens[i]) for i in range(len(posterior_id))])
        gene_posterior_mask = torch.ByteTensor([[1] * gene_posterior_lens[i] + [0] * (max(gene_posterior_lens) - gene_posterior_lens[i]) for i in range(len(gene_posterior_id))])
        all_mask = torch.ByteTensor([[1] * all_lens[i] + [0] * (max(all_lens) - all_lens[i]) for i in range(len(all_id))])
        gene_all_mask = torch.ByteTensor([[1] * gene_all_lens[i] + [0] * (max(gene_all_lens) - gene_all_lens[i]) for i in range(len(gene_all_id))])

        # * Label masks so that no loss is computed on prompt tokens. Changed 1202/2021, since 'Answer' was added to the prompts.
        # * Label masks over the full sequences (prompt, utterance, answer): losses only on answer tokens.
        all_label_mask = torch.ByteTensor([[0] * (context_lens[i]) + [1] * (all_lens[i] - context_lens[i]) + [0] * (max(all_lens) - all_lens[i]) for i in range(len(all_id))])
        gene_all_label_mask = torch.ByteTensor([[0] * (general_context_lens[i]) + [1] * (gene_all_lens[i] - general_context_lens[i]) + [0] * (max(gene_all_lens) - gene_all_lens[i]) for i in range(len(gene_all_id))])
        # * Label masks for inputs: losses only on utterance tokens.
        # (prompt_lens counts the eos in prompt_id, hence the -1/+1.)
        input_label_mask = torch.ByteTensor([[0] * (prompt_lens[i] - 1) + [1] * (input_lens[i] - prompt_lens[i] + 1) + [0] * (max(input_lens) - input_lens[i]) for i in range(len(input_id))])
        gene_input_label_mask = torch.ByteTensor([[0] * (gene_prompt_lens[i] - 1) + [1] * (gene_input_lens[i] - gene_prompt_lens[i] + 1) + [0] * (max(gene_input_lens) - gene_input_lens[i]) for i in range(len(gene_input_id))])

        # * (Alternative) label mask for inputs, locating the utterance via sublist_start_index:
        # input_label_mask = []
        # for i in range(len(input_id)):
        #     l = sublist_start_index(input_id[i], utter_id[i])
        #     input_label_mask.append([0] * l + [1] * len(utter_id[i]) + [0] * (max(input_lens) - l - len(utter_id[i])))
        # input_label_mask = torch.ByteTensor(input_label_mask)

        # Return ids and masks.
        res = {}
        res['prompt_id'] = torch.tensor([pad_seq(i, self.pad_id, max(prompt_lens)) for i in prompt_id], dtype=torch.long)
        res['gene_prompt_id'] = torch.tensor([pad_seq(i, self.pad_id, max(gene_prompt_lens)) for i in gene_prompt_id], dtype=torch.long)
        res['input_id'] = torch.tensor([pad_seq(i, self.pad_id, max(input_lens)) for i in input_id], dtype=torch.long)
        res['gene_input_id'] = torch.tensor([pad_seq(i, self.pad_id, max(gene_input_lens)) for i in gene_input_id], dtype=torch.long)
        res['posterior_id'] = torch.tensor([pad_seq(i, self.pad_id, max(posterior_lens)) for i in posterior_id], dtype=torch.long)
        res['gene_posterior_id'] = torch.tensor([pad_seq(i, self.pad_id, max(gene_posterior_lens)) for i in gene_posterior_id], dtype=torch.long)
        res["ans_id"] = torch.tensor([pad_seq(i, self.pad_id, max(ans_lens)) for i in ans_id], dtype=torch.long)
        # Contexts are left-padded so generation continues directly after the prompt.
        res["context_id"] = torch.tensor([pad_seq(i, self.pad_id, max(context_lens), pad_left=True) for i in context_id], dtype=torch.long)
        res["general_context_id"] = torch.tensor([pad_seq(i, self.pad_id, max(general_context_lens), pad_left=True) for i in general_context_id], dtype=torch.long)
        res["all_id"] = torch.tensor([pad_seq(i, self.pad_id, max(all_lens)) for i in all_id], dtype=torch.long)
        res["gene_all_id"] = torch.tensor([pad_seq(i, self.pad_id, max(gene_all_lens)) for i in gene_all_id], dtype=torch.long)
        res["utter_id"] = torch.tensor([pad_seq(i, self.pad_id, max(utter_lens)) for i in utter_id], dtype=torch.long)

        res["all_lens"] = torch.tensor(all_lens, dtype=torch.long)
        res["context_lens"] = torch.tensor(context_lens, dtype=torch.long)
        res["general_context_lens"] = torch.tensor(general_context_lens, dtype=torch.long)
        res["ans_lens"] = torch.tensor(ans_lens, dtype=torch.long)
        res["prompt_lens"] = torch.tensor(prompt_lens, dtype=torch.long)

        res["all_mask"], res["context_mask"], res['prompt_mask'], res['ans_mask'], res['input_mask'] = all_mask, context_mask, prompt_mask, ans_mask, input_mask
        res['utter_mask'] = utter_mask
        res['posterior_mask'] = posterior_mask
        res['input_label_mask'] = input_label_mask
        res['all_label_mask'] = all_label_mask

        res['general_context_mask'] = general_context_mask
        res['gene_all_mask'] = gene_all_mask
        res['gene_input_mask'] = gene_input_mask
        res['gene_input_label_mask'] = gene_input_label_mask
        res['gene_prompt_mask'] = gene_prompt_mask
        res['gene_all_label_mask'] = gene_all_label_mask
        res['gene_posterior_mask'] = gene_posterior_mask

        return PinnedBatch(res)


def get_datasets(path, tasks, tokz, num_workers=8, ctx_max_len=100):
    res = {}

    for task in tasks:
        res[task] = {}
        info = TASK2INFO[task]
        res[task]['train'] = info['dataset_class'](
            task, tokz, os.path.join(path, info['dataset_folder'], 'ripe_data', 'train.json'), num_workers=num_workers, ctx_max_len=ctx_max_len)
        res[task]['val'] = info['dataset_class'](
            task, tokz, os.path.join(path, info['dataset_folder'], 'ripe_data', 'valid.json'), num_workers=num_workers, ctx_max_len=ctx_max_len)
        res[task]['test'] = info['dataset_class'](
            task, tokz, os.path.join(path, info['dataset_folder'], 'ripe_data', 'test.json'), num_workers=num_workers, ctx_max_len=ctx_max_len)

    return res


if __name__ == '__main__':
    data_path = 'PLL_DATA/banking/ripe_data/test.json'
    from transformers import GPT2Tokenizer
    tokz = GPT2Tokenizer.from_pretrained('gpt2')

    dataset = PromptCLSDataset('banking', tokz, data_path)
    sample = dataset[10]
    print(sample, flush=True)
    print('input:', tokz.decode(sample['input_id']), flush=True)
    print('prompt:', tokz.decode(sample['prompt_id']), flush=True)
    print('all:', tokz.decode(sample['all_id']), flush=True)
    print('context:', tokz.decode(sample['context_id']), flush=True)
    print('posterior:', tokz.decode(sample['posterior_id']), flush=True)
    print('ans:', tokz.decode(sample['ans_id']), flush=True)
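
    # Hedged usage sketch (not part of the original script): batch the dataset with
    # PadBatchSeq as the collate_fn. Padding with the GPT-2 eos id is an assumption,
    # since GPT2Tokenizer defines no dedicated pad token.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=False,
        collate_fn=PadBatchSeq(pad_id=tokz.eos_token_id))
    batch = next(iter(loader))
    print('batched all_id:', batch['all_id'].shape, flush=True)
    print('batched all_label_mask:', batch['all_label_mask'].shape, flush=True)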