add reranker and reader in modelzoo/research/tprr

This commit is contained in:
huenrui 2021-03-16 10:37:53 +08:00
parent e71b601e65
commit 58673a086d
17 changed files with 2765 additions and 29 deletions

View File

@ -38,6 +38,9 @@ Wikipedia data: the 2017 English Wikipedia dump version with bidirectional hyper
dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs.
dev tf-idf data: the candidates for each question in dev data, which are the top-500 paragraphs retrieved from the 5M paragraphs of Wikipedia
through TF-IDF.
The dataset of the re-ranker consists of two parts:
Wikipedia data: the 2017 English Wikipedia dump version.
dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs.
# [Features](#contents)
@ -64,6 +67,7 @@ After installing MindSpore via the official website and Dataset is correctly gen
```python
# run evaluation example with HotPotQA dev dataset
sh run_eval_ascend.sh
sh run_eval_ascend_reranker_reader.sh
```
# [Script Description](#contents)
@ -75,25 +79,39 @@ After installing MindSpore via the official website and Dataset is correctly gen
└─tprr
├─README.md
├─scripts
| ├─run_eval_ascend.sh # Launch evaluation in ascend
| ├─run_eval_ascend.sh # Launch retriever evaluation in Ascend
| └─run_eval_ascend_reranker_reader.sh # Launch re-ranker and reader evaluation in Ascend
|
├─src
| ├─config.py # Evaluation configurations
| ├─onehop.py # Onehop model
| ├─onehop_bert.py # Onehop bert model
| ├─process_data.py # Data preprocessing
| ├─twohop.py # Twohop model
| ├─twohop_bert.py # Twohop bert model
| └─utils.py # Utils for evaluation
| ├─build_reranker_data.py # Build re-ranker data from retriever results
| ├─config.py # Evaluation configurations for retriever
| ├─hotpot_evaluate_v1.py # HotpotQA evaluation script
| ├─onehop.py # Onehop model of retriever
| ├─onehop_bert.py # Onehop bert model of retriever
| ├─process_data.py # Data preprocessing for retriever
| ├─reader.py # Reader model
| ├─reader_albert_xxlarge.py # Albert-xxlarge module of reader model
| ├─reader_downstream.py # Downstream module of reader model
| ├─reader_eval.py # Reader evaluation script
| ├─rerank_albert_xxlarge.py # Albert-xxlarge module of re-ranker model
| ├─rerank_and_reader_data_generator.py # Data generator for re-ranker and reader
| ├─rerank_and_reader_utils.py # Utils for re-ranker and reader
| ├─rerank_downstream.py # Downstream module of re-ranker model
| ├─reranker.py # Re-ranker model
| ├─reranker_eval.py # Re-ranker evaluation script
| ├─twohop.py # Twohop model of retriever
| ├─twohop_bert.py # Twohop bert model of retriever
| └─utils.py # Utils for retriever
|
└─retriever_eval.py # Evaluation net for retriever
├─retriever_eval.py # Evaluation net for retriever
└─reranker_and_reader_eval.py # Evaluation net for re-ranker and reader
```
## [Script Parameters](#contents)
Parameters for evaluation can be set in config.py.
Parameters for retriever evaluation can be set in config.py.
- config for TPRR retriever dataset
- config for TPRR retriever
```python
"q_len": 64, # Max query length
@ -108,17 +126,30 @@ Parameters for evaluation can be set in config.py.
config.py for more configuration.
Parameters for re-ranker and reader evaluation can be passed directly at execution time.
- parameters for TPRR re-ranker and reader
```python
"seq_len": 512, # sequence length
"rerank_batch_size": 32, # batch size for re-ranker evaluation
"reader_batch_size": 448, # batch size for reader evaluation
"sp_threshold": 8 # threshold for picking supporting sentence
```
config.py for more configuration.
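For example, a full evaluation run that overrides these defaults could look like the sketch below. The stage flags come from run_eval_ascend_reranker_reader.sh; the value flags are assumed to match the argument parser in src/rerank_and_reader_utils.py (check get_parse() before relying on them).

```python
# hypothetical invocation; value flags are assumptions, stage flags come from the launch script
python reranker_and_reader_eval.py --seq_len 512 --rerank_batch_size 32 --reader_batch_size 448 --sp_threshold 8 \
    --get_reranker_data --run_reranker --cal_reranker_metrics --select_reader_data --run_reader --cal_reader_metrics
```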
## [Evaluation Process](#contents)
### Evaluation
- Evaluation on Ascend
- Retriever evaluation on Ascend
```python
sh run_eval_ascend.sh
```
Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the
Evaluation results will be stored in the scripts path, in a folder whose name begins with "eval_tr". You can find results like the
following in the log.
```python
@ -138,6 +169,35 @@ Parameters for evaluation can be set in config.py.
evaluation time (h): 20.155506462653477
```
- Re-ranker and reader evaluation on Ascend
Use the output of the retriever as the input of the re-ranker.
```python
sh run_eval_ascend_reranker_reader.sh
```
Evaluation results will be stored in the scripts path, in a folder whose name begins with "eval". You can find results like the
following in the log.
```python
total top1 pem: 0.8803511141120864
...
em: 0.67440918298447
f1: 0.8025625656569652
prec: 0.8292800393689271
recall: 0.8136908451841731
sp_em: 0.6009453072248481
sp_f1: 0.844555664157302
sp_prec: 0.8640844345841021
sp_recall: 0.8446123918845106
joint_em: 0.4537474679270763
joint_f1: 0.715119580346802
joint_prec: 0.7540052057184267
joint_recall: 0.7250240424067661
```
# [Model Description](#contents)
## [Performance](#contents)
@ -154,6 +214,8 @@ Parameters for evaluation can be set in config.py.
| Batch_size | 1 |
| Output | inference path |
| PEM | 0.9188 |
| total top1 pem | 0.88 |
| joint_f1 | 0.7151 |
# [Description of random situation](#contents)

View File

@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main file"""
from mindspore import context
from src.rerank_and_reader_utils import get_parse, cal_reranker_metrics, select_reader_dev_data
from src.reranker_eval import rerank
from src.reader_eval import read
from src.hotpot_evaluate_v1 import hotpotqa_eval
from src.build_reranker_data import get_rerank_data
def rerank_and_reader_eval():
"""main function"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
parser = get_parse()
args = parser.parse_args()
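# each stage below is toggled by its own command-line flag; see
# scripts/run_eval_ascend_reranker_reader.sh for a run that enables all of them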
if args.get_reranker_data:
get_rerank_data(args)
if args.run_reranker:
rerank(args)
if args.cal_reranker_metrics:
total_top1_pem, _, _ = \
cal_reranker_metrics(dev_gold_file=args.dev_gold_file, rerank_result_file=args.rerank_result_file)
print(f"total top1 pem: {total_top1_pem}")
if args.select_reader_data:
select_reader_dev_data(args)
if args.run_reader:
read(args)
if args.cal_reader_metrics:
metrics = hotpotqa_eval(args.reader_result_file, args.dev_gold_file)
for k in metrics:
print(f"{k}: {metrics[k]}")
if __name__ == "__main__":
rerank_and_reader_eval()

View File

@ -21,16 +21,16 @@ export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
if [ -d "eval_tr" ];
then
rm -rf ./eval
rm -rf ./eval_tr
fi
mkdir ./eval
mkdir ./eval_tr
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
cp ../*.py ./eval_tr
cp *.sh ./eval_tr
cp -r ../src ./eval_tr
cd ./eval_tr || exit
env > env.log
echo "start evaluation"

View File

@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# eval script
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation"
python reranker_and_reader_eval.py --get_reranker_data --run_reranker --cal_reranker_metrics --select_reader_data --run_reader --cal_reader_metrics > log_reranker_and_reader.txt 2>&1 &
cd ..

View File

@ -0,0 +1,430 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""build reranker data from retriever result"""
import pickle
import gzip
from tqdm import tqdm
from src.rerank_and_reader_utils import read_json, make_wiki_id, convert_text_to_tokens, normalize_title, \
whitespace_tokenize, DocDB, _largest_valid_index, generate_mapping, InputFeatures, Example
from transformers import AutoTokenizer
def judge_para(data):
"""judge whether is valid para"""
for _, para_tokens in data["context"].items():
if len(para_tokens) == 1:
return False
return True
def judge_sp(data, sent_name2id, para2id):
"""judge whether is valid sp"""
for sp in data['sp']:
title = normalize_title(sp[0])
name = normalize_title(sp[0]) + '_{}'.format(sp[1])
if title in para2id and name not in sent_name2id:
return False
return True
def judge(path, path_set, reverse=False, golds=None, mode='or'):
"""judge function"""
if path[0] == path[-1]:
return False
if path in path_set:
return False
if reverse and path[::-1] in path_set:
return False
if not golds:
return True
if mode == 'or':
return any(gold not in path for gold in golds)
if mode == 'and':
return all(gold not in path for gold in golds)
return False
def get_context_and_sents(path, doc_db):
"""get context ans sentences"""
context = {}
sents = {}
for title in path:
para_info = doc_db.get_doc_info(title)
if title.endswith('_0'):
title = title[:-2]
context[title] = pickle.loads(para_info[1])
sents[title] = pickle.loads(para_info[2])
return context, sents
def gen_dev_data(dev_file, db_path, topk_file):
"""generate dev data"""
# ----------------------------------------db info-----------------------------------------------
topk_data = read_json(topk_file) # path
doc_db = DocDB(db_path) # db get offset
print('load db successfully!')
# ---------------------------------------------supervision ------------------------------------------
dev_data = read_json(dev_file)
qid2sp = {}
qid2ans = {}
qid2type = {}
qid2path = {}
for _, data in enumerate(dev_data):
sp_facts = data['supporting_facts'] if 'supporting_facts' in data else None
qid2sp[data['_id']] = sp_facts
qid2ans[data['_id']] = data['answer'] if 'answer' in data else None
qid2type[data['_id']] = data['type'] if 'type' in data else None
qid2path[data['_id']] = list(set(list(zip(*sp_facts))[0])) if sp_facts else None
new_dev_data = []
for _, data in enumerate(tqdm(topk_data)):
qid = data['q_id']
question = data['question']
topk_titles = data['topk_titles']
gold_path = list(map(normalize_title, qid2path[qid])) if qid2path[qid] else None
all_titles = []
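# build candidate two-paragraph paths from the retriever's top-k title chains,
# skipping single-title chains, self-loops and duplicate paths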
for titles in topk_titles:
titles = list(map(normalize_title, titles))
if len(titles) == 1:
continue
path = titles[:2]
if judge(path, all_titles):
all_titles.append(titles[:2])
if len(titles) == 3:
path = titles[1:]
if judge(path, all_titles):
all_titles.append(titles[1:])
# --------------------------------------------------process query-----------------------------------
question = " ".join(whitespace_tokenize(question))
question = question.strip()
q_tokens, _ = convert_text_to_tokens(question)
gold_path = list(map(lambda x: make_wiki_id(x, 0), gold_path)) if gold_path else None
for path in all_titles:
context, sents = get_context_and_sents(path, doc_db)
ans_label = int(gold_path[0] in path and gold_path[1] in path) if gold_path else None
new_dev_data.append({
'qid': qid,
'type': qid2type[qid],
'question': question,
'q_tokens': q_tokens,
'context': context,
'sents': sents,
'answer': qid2ans[qid],
'sp': qid2sp[qid],
'ans_para': None,
'is_impossible': ans_label != 1
})
return new_dev_data
def read_hotpot_examples(path_data):
"""reader examples"""
examples = []
max_sent_cnt = 0
failed = 0
for _, data in enumerate(path_data):
if not judge_para(data):
failed += 1
continue
question = data['question']
question = " ".join(whitespace_tokenize(question))
question = question.strip()
path = list(map(normalize_title, data["context"].keys()))
qid = data['qid']
q_tokens = data['q_tokens']
# -------------------------------------add para------------------------------------------------------------
doc_tokens = []
para_start_end_position = []
title_start_end_position = []
sent_start_end_position = []
sent_names = []
sent_name2id = {}
para2id = {}
for para, para_tokens in data["context"].items():
sents = data["sents"][para]
para = normalize_title(para)
title_tokens = convert_text_to_tokens(para)[0]
para_node_id = len(para_start_end_position)
para2id[para] = para_node_id
doc_offset = len(doc_tokens)
doc_tokens += title_tokens
doc_tokens += para_tokens
title_start_end_position.append((doc_offset, doc_offset + len(title_tokens) - 1))
doc_offset += len(title_tokens)
para_start_end_position.append((doc_offset, doc_offset + len(para_tokens) - 1, para))
for idx, sent in enumerate(sents):
if sent[0] == -1 and sent[1] == -1:
continue
sent_names.append([para, idx]) # local name
sent_node_id = len(sent_start_end_position)
sent_name2id[normalize_title(para) + '_{}'.format(str(idx))] = sent_node_id
sent_start_end_position.append((doc_offset + sent[0],
doc_offset + sent[1]))
# add sp and ans
sp_facts = []
sup_fact_id = []
for sp in sp_facts:
name = normalize_title(sp[0]) + '_{}'.format(sp[1])
if name in sent_name2id:
sup_fact_id.append(sent_name2id[name])
sup_para_id = set() # use set
if sp_facts:
for para in list(zip(*sp_facts))[0]:
para = normalize_title(para)
if para in para2id:
sup_para_id.add(para2id[para])
sup_para_id = list(sup_para_id)
example = Example(
qas_id=qid,
path=path,
unique_id=qid + '_' + '_'.join(path),
question_tokens=q_tokens,
doc_tokens=doc_tokens, # multi-para tokens w/o query
sent_names=sent_names,
sup_fact_id=sup_fact_id, # global sent id
sup_para_id=sup_para_id, # global para id
para_start_end_position=para_start_end_position,
sent_start_end_position=sent_start_end_position,
title_start_end_position=title_start_end_position,
question_text=question)
examples.append(example)
max_sent_cnt = max(max_sent_cnt, len(sent_start_end_position))
print(f"Maximum sentence cnt: {max_sent_cnt}")
print(f'failed examples: {failed}')
print(f'convert {len(examples)} examples successfully!')
return examples
def add_sub_token(sub_tokens, idx, tok_to_orig_index, all_query_tokens):
"""add sub tokens"""
for sub_token in sub_tokens:
tok_to_orig_index.append(idx)
all_query_tokens.append(sub_token)
return tok_to_orig_index, all_query_tokens
def get_sent_spans(example, orig_to_tok_index, orig_to_tok_back_index):
"""get sentences' spans"""
sentence_spans = []
for sent_span in example.sent_start_end_position:
sent_start_position = orig_to_tok_index[sent_span[0]]
sent_end_position = orig_to_tok_back_index[sent_span[1]]
sentence_spans.append((sent_start_position, sent_end_position + 1))
return sentence_spans
def get_para_spans(example, orig_to_tok_index, orig_to_tok_back_index, all_doc_tokens, marker):
"""get paragraphs' spans"""
para_spans = []
for title_span, para_span in zip(example.title_start_end_position, example.para_start_end_position):
para_start_position = orig_to_tok_index[title_span[0]]
para_end_position = orig_to_tok_back_index[para_span[1]]
if para_end_position + 1 < len(all_doc_tokens) and all_doc_tokens[para_end_position + 1] == \
marker['sent'][0]:
para_spans.append((para_start_position - 1, para_end_position + 1, para_span[2]))
else:
para_spans.append((para_start_position - 1, para_end_position, para_span[2]))
return para_spans
def build_feature(example, all_doc_tokens, doc_input_ids, doc_input_mask, doc_segment_ids, all_query_tokens,
query_input_ids, query_input_mask, query_segment_ids, para_spans, sentence_spans, tok_to_orig_index):
"""build a input feature"""
feature = InputFeatures(
qas_id=example.qas_id,
path=example.path,
unique_id=example.qas_id + '_' + '_'.join(example.path),
sent_names=example.sent_names,
doc_tokens=all_doc_tokens,
doc_input_ids=doc_input_ids,
doc_input_mask=doc_input_mask,
doc_segment_ids=doc_segment_ids,
query_tokens=all_query_tokens,
query_input_ids=query_input_ids,
query_input_mask=query_input_mask,
query_segment_ids=query_segment_ids,
para_spans=para_spans,
sent_spans=sentence_spans,
token_to_orig_map=tok_to_orig_index)
return feature
def convert_example_to_features(tokenizer, args, examples):
"""convert examples to features"""
features = []
failed = 0
marker = {'q': ['[q]', '[/q]'], 'para': ['<t>', '</t>'], 'sent': ['[s]']}
for (_, example) in enumerate(tqdm(examples)):
all_query_tokens = [tokenizer.cls_token, marker['q'][0]]
tok_to_orig_index = [-1, -1] # orig: query + doc tokens
ques_orig_to_tok_index = [] # start position
ques_orig_to_tok_back_index = [] # end position
q_spans = []
# -------------------------------------------for query---------------------------------------------
for (idx, token) in enumerate(example.question_tokens):
sub_tokens = tokenizer.tokenize(token)
ques_orig_to_tok_index.append(len(all_query_tokens))
tok_to_orig_index, all_query_tokens = add_sub_token(sub_tokens, idx, tok_to_orig_index, all_query_tokens)
ques_orig_to_tok_back_index.append(len(all_query_tokens) - 1)
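# truncate the query to 63 sub-tokens so that, with the trailing [/q] marker, it fits the 64-token query budget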
all_query_tokens = all_query_tokens[:63]
tok_to_orig_index = tok_to_orig_index[:63]
all_query_tokens.append(marker['q'][-1])
tok_to_orig_index.append(-1)
q_spans.append((1, len(all_query_tokens) - 1))
# ---------------------------------------add doc tokens------------------------------------------------
all_doc_tokens = []
orig_to_tok_index = [] # orig: token in doc
orig_to_tok_back_index = []
title_start_mapping, title_end_mapping = generate_mapping(len(example.doc_tokens),
example.title_start_end_position)
_, sent_end_mapping = generate_mapping(len(example.doc_tokens),
example.sent_start_end_position)
all_doc_tokens += all_query_tokens
for (idx, token) in enumerate(example.doc_tokens):
sub_tokens = tokenizer.tokenize(token)
if title_start_mapping[idx] == 1:
all_doc_tokens.append(marker['para'][0])
tok_to_orig_index.append(-1)
# orig: position in doc tokens tok: global tokenized tokens (start)
orig_to_tok_index.append(len(all_doc_tokens))
tok_to_orig_index, all_doc_tokens = add_sub_token(sub_tokens, idx + len(example.question_tokens),
tok_to_orig_index, all_doc_tokens)
orig_to_tok_back_index.append(len(all_doc_tokens) - 1)
if title_end_mapping[idx] == 1:
all_doc_tokens.append(marker['para'][1])
tok_to_orig_index.append(-1)
if sent_end_mapping[idx] == 1:
all_doc_tokens.append(marker['sent'][0])
tok_to_orig_index.append(-1)
# -----------------------------------for sentence-------------------------------------------------
sentence_spans = get_sent_spans(example, orig_to_tok_index, orig_to_tok_back_index)
# -----------------------------------for para-------------------------------------------------------
para_spans = get_para_spans(example, orig_to_tok_index, orig_to_tok_back_index, all_doc_tokens, marker)
# -----------------------------------remove sent > max seq length-----------------------------------------
sent_max_index = _largest_valid_index(sentence_spans, args.seq_len)
max_sent_cnt = len(sentence_spans)
if sent_max_index != len(sentence_spans):
if sent_max_index == 0:
failed += 1
continue
sentence_spans = sentence_spans[:sent_max_index]
max_tok_length = sentence_spans[-1][1] # max_tok_length [s]
# max end index: max_tok_length
para_max_index = _largest_valid_index(para_spans, max_tok_length + 1)
if para_max_index == 0: # only one para
failed += 1
continue
if orig_to_tok_back_index[example.title_start_end_position[1][1]] + 1 >= max_tok_length:
failed += 1
continue
max_para_span = para_spans[para_max_index]
para_spans = para_spans[:para_max_index]
para_spans.append((max_para_span[0], max_tok_length, max_para_span[2]))
all_doc_tokens = all_doc_tokens[:max_tok_length + 1]
sentence_spans = sentence_spans[:min(max_sent_cnt, args.max_sent_num)]
# ----------------------------------------Padding Document-----------------------------------------------------
if len(all_doc_tokens) > args.seq_len:
st, _, title = para_spans[-1]
para_spans[-1] = (st, args.seq_len - 1, title)
all_doc_tokens = all_doc_tokens[:args.seq_len - 1] + [marker['sent'][0]]
doc_input_ids = tokenizer.convert_tokens_to_ids(all_doc_tokens)
query_input_ids = tokenizer.convert_tokens_to_ids(all_query_tokens)
doc_input_mask = [1] * len(doc_input_ids)
doc_segment_ids = [0] * len(query_input_ids) + [1] * (len(doc_input_ids) - len(query_input_ids))
doc_pad_length = args.seq_len - len(doc_input_ids)
doc_input_ids += [0] * doc_pad_length
doc_input_mask += [0] * doc_pad_length
doc_segment_ids += [0] * doc_pad_length
# Padding Question
query_input_mask = [1] * len(query_input_ids)
query_segment_ids = [0] * len(query_input_ids)
query_pad_length = 64 - len(query_input_ids)
query_input_ids += [0] * query_pad_length
query_input_mask += [0] * query_pad_length
query_segment_ids += [0] * query_pad_length
feature = build_feature(example, all_doc_tokens, doc_input_ids, doc_input_mask, doc_segment_ids,
all_query_tokens, query_input_ids, query_input_mask, query_segment_ids, para_spans,
sentence_spans, tok_to_orig_index)
features.append(feature)
return features
def get_rerank_data(args):
"""function for generating reranker's data"""
new_dev_data = gen_dev_data(dev_file=args.dev_gold_file,
db_path=args.wiki_db_file,
topk_file=args.retriever_result_file)
tokenizer = AutoTokenizer.from_pretrained(args.albert_model_path)
new_tokens = ['[q]', '[/q]', '<t>', '</t>', '[s]']
tokenizer.add_tokens(new_tokens)
examples = read_hotpot_examples(new_dev_data)
features = convert_example_to_features(tokenizer=tokenizer, args=args, examples=examples)
with gzip.open(args.rerank_example_file, "wb") as f:
pickle.dump(examples, f)
with gzip.open(args.rerank_feature_file, "wb") as f:
pickle.dump(features, f)

View File

@ -33,14 +33,14 @@ def ThinkRetrieverConfig():
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--save_name", type=str, default='doc_path', help='name of output')
parser.add_argument("--save_path", type=str, default='./', help='path of output')
parser.add_argument("--vocab_path", type=str, default='./scripts/vocab.txt', help="vocab path")
parser.add_argument("--wiki_path", type=str, default='./scripts/db_docs_bidirection_new.pkl', help="wiki path")
parser.add_argument("--dev_path", type=str, default='./scripts/hotpot_dev_fullwiki_v1_for_retriever.json',
parser.add_argument("--save_path", type=str, default='../', help='path of output')
parser.add_argument("--vocab_path", type=str, default='../vocab.txt', help="vocab path")
parser.add_argument("--wiki_path", type=str, default='../db_docs_bidirection_new.pkl', help="wiki path")
parser.add_argument("--dev_path", type=str, default='../hotpot_dev_fullwiki_v1_for_retriever.json',
help="dev path")
parser.add_argument("--dev_data_path", type=str, default='./scripts/dev_tf_idf_data_raw.pkl', help="dev data path")
parser.add_argument("--onehop_bert_path", type=str, default='./scripts/onehop.ckpt', help="onehop bert ckpt path")
parser.add_argument("--onehop_mlp_path", type=str, default='./scripts/onehop_mlp.ckpt', help="onehop mlp ckpt path")
parser.add_argument("--twohop_bert_path", type=str, default='./scripts/twohop.ckpt', help="twohop bert ckpt path")
parser.add_argument("--twohop_mlp_path", type=str, default='./scripts/twohop_mlp.ckpt', help="twohop mlp ckpt path")
parser.add_argument("--dev_data_path", type=str, default='../dev_tf_idf_data_raw.pkl', help="dev data path")
parser.add_argument("--onehop_bert_path", type=str, default='../onehop.ckpt', help="onehop bert ckpt path")
parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path")
parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path")
parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path")
return parser.parse_args()

View File

@ -0,0 +1,153 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hotpotqa evaluate script"""
import re
import string
from collections import Counter
import ujson as json
def normalize_answer(s):
"""normalize answer"""
def remove_articles(text):
"""remove articles"""
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
"""fix whitespace"""
return ' '.join(text.split())
def remove_punc(text):
"""remove punctuation from text"""
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
"""lower text"""
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
"""calculate f1 score"""
normalized_prediction = normalize_answer(prediction)
normalized_ground_truth = normalize_answer(ground_truth)
ZERO_METRIC = (0, 0, 0)
if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
prediction_tokens = normalized_prediction.split()
ground_truth_tokens = normalized_ground_truth.split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return ZERO_METRIC
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1, precision, recall
def exact_match_score(prediction, ground_truth):
"""calculate exact match score"""
return normalize_answer(prediction) == normalize_answer(ground_truth)
def update_answer(metrics, prediction, gold):
"""update answer"""
em = exact_match_score(prediction, gold)
f1, prec, recall = f1_score(prediction, gold)
metrics['em'] += float(em)
metrics['f1'] += f1
metrics['prec'] += prec
metrics['recall'] += recall
return em, prec, recall
def update_sp(metrics, prediction, gold):
"""update supporting sentences"""
cur_sp_pred = set(map(tuple, prediction))
gold_sp_pred = set(map(tuple, gold))
tp, fp, fn = 0, 0, 0
for e in cur_sp_pred:
if e in gold_sp_pred:
tp += 1
else:
fp += 1
for e in gold_sp_pred:
if e not in cur_sp_pred:
fn += 1
prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
em = 1.0 if fp + fn == 0 else 0.0
metrics['sp_em'] += em
metrics['sp_f1'] += f1
metrics['sp_prec'] += prec
metrics['sp_recall'] += recall
return em, prec, recall
def hotpotqa_eval(prediction_file, gold_file):
"""hotpotqa evaluate function"""
with open(prediction_file) as f:
prediction = json.load(f)
with open(gold_file) as f:
gold = json.load(f)
metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
for dp in gold:
cur_id = dp['_id']
can_eval_joint = True
if cur_id not in prediction['answer']:
print('missing answer {}'.format(cur_id))
can_eval_joint = False
else:
em, prec, recall = update_answer(
metrics, prediction['answer'][cur_id], dp['answer'])
if cur_id not in prediction['sp']:
print('missing sp fact {}'.format(cur_id))
can_eval_joint = False
else:
sp_em, sp_prec, sp_recall = update_sp(
metrics, prediction['sp'][cur_id], dp['supporting_facts'])
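# joint metrics multiply the answer and supporting-fact precision/recall for each question,
# following the official HotpotQA evaluation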
if can_eval_joint:
joint_prec = prec * sp_prec
joint_recall = recall * sp_recall
if joint_prec + joint_recall > 0:
joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
else:
joint_f1 = 0.
joint_em = em * sp_em
metrics['joint_em'] += joint_em
metrics['joint_f1'] += joint_f1
metrics['joint_prec'] += joint_prec
metrics['joint_recall'] += joint_recall
num = len(gold)
for k in metrics:
metrics[k] /= num
return metrics

View File

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reader model"""
import mindspore.nn as nn
from mindspore import load_checkpoint, load_param_into_net
from mindspore.ops import BatchMatMul
from mindspore import ops
from mindspore import dtype as mstype
from src.reader_albert_xxlarge import Reader_Albert
from src.reader_downstream import Reader_Downstream
dst_type = mstype.float16
dst_type2 = mstype.float32
class Reader(nn.Cell):
"""Reader model"""
def __init__(self, batch_size, encoder_ck_file, downstream_ck_file):
"""init function"""
super(Reader, self).__init__(auto_prefix=False)
self.encoder = Reader_Albert(batch_size)
param_dict = load_checkpoint(encoder_ck_file)
not_load_params = load_param_into_net(self.encoder, param_dict)
print(f"not loaded: {not_load_params}")
self.downstream = Reader_Downstream()
param_dict = load_checkpoint(downstream_ck_file)
not_load_params = load_param_into_net(self.downstream, param_dict)
print(f"not loaded: {not_load_params}")
self.bmm = BatchMatMul()
def construct(self, input_ids, attn_mask, token_type_ids,
context_mask, square_mask, packing_mask, cache_mask,
para_start_mapping, sent_end_mapping):
"""construct function"""
state = self.encoder(attn_mask, input_ids, token_type_ids)
para_state = self.bmm(ops.Cast()(para_start_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, 2, D]
sent_state = self.bmm(ops.Cast()(sent_end_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, max_sent, D]
q_type, start, end, para_logit, sent_logit = self.downstream(ops.Cast()(para_state, dst_type2),
ops.Cast()(sent_state, dst_type2),
state,
context_mask)
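# outer[b, i, j] = start_logit[i] + end_logit[j] scores the candidate answer span (i, j);
# cache_mask keeps spans with 0 <= j - i <= 30, square_mask further restricts valid pairs,
# and packing_mask removes spans that start inside the query; the row/column maxima below
# give the predicted start (y1) and end (y2) positions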
outer = start[:, :, None] + end[:, None]
outer_mask = cache_mask
outer_mask = square_mask * outer_mask[None]
outer = outer - 1e30 * (1 - outer_mask)
outer = outer - 1e30 * packing_mask[:, :, None]
max_row = ops.ReduceMax()(outer, 2)
y1 = ops.Argmax()(max_row)
max_col = ops.ReduceMax()(outer, 1)
y2 = ops.Argmax()(max_col)
return start, end, q_type, para_logit, sent_logit, y1, y2

View File

@ -0,0 +1,263 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""albert-xxlarge Model for reader"""
import numpy as np
from mindspore import nn, ops
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore import dtype as mstype
dst_type = mstype.float16
dst_type2 = mstype.float32
class LayerNorm(nn.Cell):
"""LayerNorm layer"""
def __init__(self, mul_7_w_shape, add_8_bias_shape):
"""init function"""
super(LayerNorm, self).__init__()
self.reducemean_0 = P.ReduceMean(keep_dims=True)
self.sub_1 = P.Sub()
self.pow_2 = P.Pow()
self.pow_2_input_weight = 2.0
self.reducemean_3 = P.ReduceMean(keep_dims=True)
self.add_4 = P.Add()
self.add_4_bias = 9.999999960041972e-13
self.sqrt_5 = P.Sqrt()
self.div_6 = P.Div()
self.mul_7 = P.Mul()
self.mul_7_w = Parameter(Tensor(np.random.uniform(0, 1, mul_7_w_shape).astype(np.float32)), name=None)
self.add_8 = P.Add()
self.add_8_bias = Parameter(Tensor(np.random.uniform(0, 1, add_8_bias_shape).astype(np.float32)), name=None)
def construct(self, x):
"""construct function"""
opt_reducemean_0 = self.reducemean_0(x, -1)
opt_sub_1 = self.sub_1(x, opt_reducemean_0)
opt_pow_2 = self.pow_2(opt_sub_1, self.pow_2_input_weight)
opt_reducemean_3 = self.reducemean_3(opt_pow_2, -1)
opt_add_4 = self.add_4(opt_reducemean_3, self.add_4_bias)
opt_sqrt_5 = self.sqrt_5(opt_add_4)
opt_div_6 = self.div_6(opt_sub_1, opt_sqrt_5)
opt_mul_7 = self.mul_7(opt_div_6, self.mul_7_w)
opt_add_8 = self.add_8(opt_mul_7, self.add_8_bias)
return opt_add_8
class Linear(nn.Cell):
"""Linear layer"""
def __init__(self, matmul_0_weight_shape, add_1_bias_shape):
"""init function"""
super(Linear, self).__init__()
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)),
name=None)
self.add_1 = P.Add()
self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None)
def construct(self, x):
"""construct function"""
opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
return opt_add_1
class MultiHeadAttn(nn.Cell):
"""Multi-head attention layer"""
def __init__(self, batch_size):
"""init function"""
super(MultiHeadAttn, self).__init__()
self.batch_size = batch_size
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
self.matmul_1 = nn.MatMul()
self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
self.matmul_2 = nn.MatMul()
self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
self.add_3 = P.Add()
self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.add_4 = P.Add()
self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.add_5 = P.Add()
self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.reshape_6 = P.Reshape()
self.reshape_6_shape = tuple([batch_size, 512, 64, 64])
self.reshape_7 = P.Reshape()
self.reshape_7_shape = tuple([batch_size, 512, 64, 64])
self.reshape_8 = P.Reshape()
self.reshape_8_shape = tuple([batch_size, 512, 64, 64])
self.transpose_9 = P.Transpose()
self.transpose_10 = P.Transpose()
self.transpose_11 = P.Transpose()
self.matmul_12 = nn.MatMul()
self.div_13 = P.Div()
self.div_13_w = 8.0
self.add_14 = P.Add()
self.softmax_15 = nn.Softmax(axis=3)
self.matmul_16 = nn.MatMul()
self.transpose_17 = P.Transpose()
self.matmul_18 = P.MatMul()
self.matmul_18_weight = Parameter(Tensor(np.random.uniform(0, 1, (64, 64, 4096)).astype(np.float32)), name=None)
self.add_19 = P.Add()
self.add_19_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
def construct(self, x, x0):
"""construct function"""
opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
opt_matmul_1 = self.matmul_1(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_1_w, dst_type))
opt_matmul_2 = self.matmul_2(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_2_w, dst_type))
opt_add_3 = self.add_3(ops.Cast()(opt_matmul_0, dst_type2), self.add_3_bias)
opt_add_4 = self.add_4(ops.Cast()(opt_matmul_1, dst_type2), self.add_4_bias)
opt_add_5 = self.add_5(ops.Cast()(opt_matmul_2, dst_type2), self.add_5_bias)
opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape)
opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape)
opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape)
opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3))
opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1))
opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3))
opt_matmul_12 = self.matmul_12(ops.Cast()(opt_transpose_9, dst_type), ops.Cast()(opt_transpose_10, dst_type))
opt_div_13 = self.div_13(ops.Cast()(opt_matmul_12, dst_type2), ops.Cast()(self.div_13_w, dst_type2))
opt_add_14 = self.add_14(opt_div_13, x0)
opt_softmax_15 = self.softmax_15(opt_add_14)
opt_matmul_16 = self.matmul_16(ops.Cast()(opt_softmax_15, dst_type), ops.Cast()(opt_transpose_11, dst_type))
opt_transpose_17 = self.transpose_17(ops.Cast()(opt_matmul_16, dst_type2), (0, 2, 1, 3))
opt_matmul_18 = self.matmul_18(ops.Cast()(opt_transpose_17, dst_type).view(self.batch_size * 512, -1),
ops.Cast()(self.matmul_18_weight, dst_type).view(-1, 4096))\
.view(self.batch_size, 512, 4096)
opt_add_19 = self.add_19(ops.Cast()(opt_matmul_18, dst_type2), self.add_19_bias)
return opt_add_19
class NewGeLU(nn.Cell):
"""new gelu layer"""
def __init__(self):
"""init function"""
super(NewGeLU, self).__init__()
self.mul_0 = P.Mul()
self.mul_0_w = 0.5
self.pow_1 = P.Pow()
self.pow_1_input_weight = 3.0
self.mul_2 = P.Mul()
self.mul_2_w = 0.044714998453855515
self.add_3 = P.Add()
self.mul_4 = P.Mul()
self.mul_4_w = 0.7978845834732056
self.tanh_5 = nn.Tanh()
self.add_6 = P.Add()
self.add_6_bias = 1.0
self.mul_7 = P.Mul()
def construct(self, x):
"""construct function"""
opt_mul_0 = self.mul_0(x, self.mul_0_w)
opt_pow_1 = self.pow_1(x, self.pow_1_input_weight)
opt_mul_2 = self.mul_2(opt_pow_1, self.mul_2_w)
opt_add_3 = self.add_3(x, opt_mul_2)
opt_mul_4 = self.mul_4(opt_add_3, self.mul_4_w)
opt_tanh_5 = self.tanh_5(opt_mul_4)
opt_add_6 = self.add_6(opt_tanh_5, self.add_6_bias)
opt_mul_7 = self.mul_7(opt_mul_0, opt_add_6)
return opt_mul_7
class TransformerLayer(nn.Cell):
"""Transformer layer"""
def __init__(self, batch_size, layernorm1_0_mul_7_w_shape, layernorm1_0_add_8_bias_shape,
linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape,
linear3_1_add_1_bias_shape):
"""init function"""
super(TransformerLayer, self).__init__()
self.multiheadattn_0 = MultiHeadAttn(batch_size)
self.add_0 = P.Add()
self.layernorm1_0 = LayerNorm(mul_7_w_shape=layernorm1_0_mul_7_w_shape,
add_8_bias_shape=layernorm1_0_add_8_bias_shape)
self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape,
add_1_bias_shape=linear3_0_add_1_bias_shape)
self.newgelu2_0 = NewGeLU()
self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape,
add_1_bias_shape=linear3_1_add_1_bias_shape)
self.add_1 = P.Add()
def construct(self, x, x0):
"""construct function"""
multiheadattn_0_opt = self.multiheadattn_0(x, x0)
opt_add_0 = self.add_0(x, multiheadattn_0_opt)
layernorm1_0_opt = self.layernorm1_0(opt_add_0)
linear3_0_opt = self.linear3_0(layernorm1_0_opt)
newgelu2_0_opt = self.newgelu2_0(linear3_0_opt)
linear3_1_opt = self.linear3_1(newgelu2_0_opt)
opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt)
return opt_add_1
class Reader_Albert(nn.Cell):
"""Albert model for reader"""
def __init__(self, batch_size):
"""init function"""
super(Reader_Albert, self).__init__()
self.expanddims_0 = P.ExpandDims()
self.expanddims_0_axis = 1
self.expanddims_3 = P.ExpandDims()
self.expanddims_3_axis = 2
self.cast_5 = P.Cast()
self.cast_5_to = mstype.float32
self.sub_7 = P.Sub()
self.sub_7_bias = 1.0
self.mul_9 = P.Mul()
self.mul_9_w = -10000.0
self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30005, 128)).astype(np.float32)),
name=None)
self.gather_1_axis = 0
self.gather_1 = P.Gather()
self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 128)).astype(np.float32)), name=None)
self.gather_2_axis = 0
self.gather_2 = P.Gather()
self.add_4 = P.Add()
self.add_6 = P.Add()
self.add_6_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 512, 128)).astype(np.float32)), name=None)
self.layernorm1_0 = LayerNorm(mul_7_w_shape=(128,), add_8_bias_shape=(128,))
self.linear3_0 = Linear(matmul_0_weight_shape=(128, 4096), add_1_bias_shape=(4096,))
self.module34_0 = TransformerLayer(batch_size,
layernorm1_0_mul_7_w_shape=(4096,),
layernorm1_0_add_8_bias_shape=(4096,),
linear3_0_matmul_0_weight_shape=(4096, 16384),
linear3_0_add_1_bias_shape=(16384,),
linear3_1_matmul_0_weight_shape=(16384, 4096),
linear3_1_add_1_bias_shape=(4096,))
self.layernorm1_1 = LayerNorm(mul_7_w_shape=(4096,), add_8_bias_shape=(4096,))
def construct(self, x, x0, x1):
"""construct function"""
opt_expanddims_0 = self.expanddims_0(x, self.expanddims_0_axis)
opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis)
opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to)
opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5)
opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w)
opt_gather_1_axis = self.gather_1_axis
opt_gather_1 = self.gather_1(self.gather_1_input_weight, x0, opt_gather_1_axis)
opt_gather_2_axis = self.gather_2_axis
opt_gather_2 = self.gather_2(self.gather_2_input_weight, x1, opt_gather_2_axis)
opt_add_4 = self.add_4(opt_gather_1, opt_gather_2)
opt_add_6 = self.add_6(opt_add_4, self.add_6_bias)
layernorm1_0_opt = self.layernorm1_0(opt_add_6)
linear3_0_opt = self.linear3_0(layernorm1_0_opt)
module34_0_opt = self.module34_0(linear3_0_opt, opt_mul_9)
out = self.layernorm1_1(module34_0_opt)
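# ALBERT-style cross-layer parameter sharing: the same transformer block and layer norm
# are reused for the remaining 11 of the 12 layers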
for _ in range(11):
out = self.module34_0(out, opt_mul_9)
out = self.layernorm1_1(out)
return out

View File

@ -0,0 +1,213 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""downstream Model for reader"""
import numpy as np
from mindspore import nn, ops
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore import dtype as mstype
dst_type = mstype.float16
dst_type2 = mstype.float32
class Module15(nn.Cell):
"""module of reader downstream"""
def __init__(self, matmul_0_weight_shape, add_1_bias_shape):
"""init function"""
super(Module15, self).__init__()
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)),
name=None)
self.add_1 = P.Add()
self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None)
self.relu_2 = nn.ReLU()
def construct(self, x):
"""construct function"""
opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
opt_relu_2 = self.relu_2(opt_add_1)
return opt_relu_2
class NormModule(nn.Cell):
"""Normalization module of reader downstream"""
def __init__(self, mul_8_w_shape, add_9_bias_shape):
"""init function"""
super(NormModule, self).__init__()
self.reducemean_0 = P.ReduceMean(keep_dims=True)
self.sub_1 = P.Sub()
self.sub_2 = P.Sub()
self.pow_3 = P.Pow()
self.pow_3_input_weight = 2.0
self.reducemean_4 = P.ReduceMean(keep_dims=True)
self.add_5 = P.Add()
self.add_5_bias = 9.999999960041972e-13
self.sqrt_6 = P.Sqrt()
self.div_7 = P.Div()
self.mul_8 = P.Mul()
self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, mul_8_w_shape).astype(np.float32)), name=None)
self.add_9 = P.Add()
self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, add_9_bias_shape).astype(np.float32)), name=None)
def construct(self, x):
"""construct function"""
opt_reducemean_0 = self.reducemean_0(x, -1)
opt_sub_1 = self.sub_1(x, opt_reducemean_0)
opt_sub_2 = self.sub_2(x, opt_reducemean_0)
opt_pow_3 = self.pow_3(opt_sub_1, self.pow_3_input_weight)
opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1)
opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias)
opt_sqrt_6 = self.sqrt_6(opt_add_5)
opt_div_7 = self.div_7(opt_sub_2, opt_sqrt_6)
opt_mul_8 = self.mul_8(self.mul_8_w, opt_div_7)
opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias)
return opt_add_9
class Module16(nn.Cell):
"""module of reader downstream"""
def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape,
normmodule_0_add_9_bias_shape):
"""init function"""
super(Module16, self).__init__()
self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape,
add_1_bias_shape=module15_0_add_1_bias_shape)
self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape,
add_9_bias_shape=normmodule_0_add_9_bias_shape)
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (8192, 1)).astype(np.float32)), name=None)
def construct(self, x):
"""construct function"""
module15_0_opt = self.module15_0(x)
normmodule_0_opt = self.normmodule_0(module15_0_opt)
opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
return ops.Cast()(opt_matmul_0, dst_type2)
class Module17(nn.Cell):
"""module of reader downstream"""
def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape,
normmodule_0_add_9_bias_shape):
"""init function"""
super(Module17, self).__init__()
self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape,
add_1_bias_shape=module15_0_add_1_bias_shape)
self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape,
add_9_bias_shape=normmodule_0_add_9_bias_shape)
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 1)).astype(np.float32)), name=None)
self.add_1 = P.Add()
self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
def construct(self, x):
"""construct function"""
module15_0_opt = self.module15_0(x)
normmodule_0_opt = self.normmodule_0(module15_0_opt)
opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
return opt_add_1
class Module5(nn.Cell):
"""module of reader downstream"""
def __init__(self):
"""init function"""
super(Module5, self).__init__()
self.sub_0 = P.Sub()
self.sub_0_bias = 1.0
self.mul_1 = P.Mul()
self.mul_1_w = 1.0000000150474662e+30
def construct(self, x):
"""construct function"""
opt_sub_0 = self.sub_0(self.sub_0_bias, x)
opt_mul_1 = self.mul_1(opt_sub_0, self.mul_1_w)
return opt_mul_1
class Module10(nn.Cell):
"""module of reader downstream"""
def __init__(self):
"""init function"""
super(Module10, self).__init__()
self.squeeze_0 = P.Squeeze(2)
self.module5_0 = Module5()
self.sub_1 = P.Sub()
def construct(self, x, x0):
"""construct function"""
opt_squeeze_0 = self.squeeze_0(x)
module5_0_opt = self.module5_0(x0)
opt_sub_1 = self.sub_1(opt_squeeze_0, module5_0_opt)
return opt_sub_1
class Reader_Downstream(nn.Cell):
"""Downstream model for reader"""
def __init__(self):
"""init function"""
super(Reader_Downstream, self).__init__()
self.module16_0 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192),
module15_0_add_1_bias_shape=(8192,),
normmodule_0_mul_8_w_shape=(8192,),
normmodule_0_add_9_bias_shape=(8192,))
self.add_74 = P.Add()
self.add_74_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
self.module16_1 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192),
module15_0_add_1_bias_shape=(8192,),
normmodule_0_mul_8_w_shape=(8192,),
normmodule_0_add_9_bias_shape=(8192,))
self.add_75 = P.Add()
self.add_75_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
self.module17_0 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096),
module15_0_add_1_bias_shape=(4096,),
normmodule_0_mul_8_w_shape=(4096,),
normmodule_0_add_9_bias_shape=(4096,))
self.module10_0 = Module10()
self.module17_1 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096),
module15_0_add_1_bias_shape=(4096,),
normmodule_0_mul_8_w_shape=(4096,),
normmodule_0_add_9_bias_shape=(4096,))
self.module10_1 = Module10()
self.gather_6_input_weight = Tensor(np.array(0))
self.gather_6_axis = 1
self.gather_6 = P.Gather()
self.dense_13 = nn.Dense(in_channels=4096, out_channels=4096, has_bias=True)
self.relu_18 = nn.ReLU()
self.normmodule_0 = NormModule(mul_8_w_shape=(4096,), add_9_bias_shape=(4096,))
self.dense_73 = nn.Dense(in_channels=4096, out_channels=3, has_bias=True)
def construct(self, x, x0, x1, x2):
"""construct function"""
module16_0_opt = self.module16_0(x)
opt_add_74 = self.add_74(module16_0_opt, self.add_74_bias)
module16_1_opt = self.module16_1(x0)
opt_add_75 = self.add_75(module16_1_opt, self.add_75_bias)
module17_0_opt = self.module17_0(x1)
opt_module10_0 = self.module10_0(module17_0_opt, x2)
module17_1_opt = self.module17_1(x1)
opt_module10_1 = self.module10_1(module17_1_opt, x2)
opt_gather_6_axis = self.gather_6_axis
opt_gather_6 = self.gather_6(x1, self.gather_6_input_weight, opt_gather_6_axis)
opt_dense_13 = self.dense_13(opt_gather_6)
opt_relu_18 = self.relu_18(opt_dense_13)
normmodule_0_opt = self.normmodule_0(opt_relu_18)
opt_dense_73 = self.dense_73(normmodule_0_opt)
return opt_dense_73, opt_module10_0, opt_module10_1, opt_add_74, opt_add_75

View File

@ -0,0 +1,142 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""execute reader"""
from collections import defaultdict
import random
from time import time
import json
from tqdm import tqdm
import numpy as np
from transformers import AlbertTokenizer
from mindspore import Tensor, ops
from mindspore import dtype as mstype
from src.rerank_and_reader_data_generator import DataGenerator
from src.rerank_and_reader_utils import convert_to_tokens, make_wiki_id, DocDB
from src.reader import Reader
def read(args):
"""reader function"""
db_file = args.wiki_db_file
reader_feature_file = args.reader_feature_file
reader_example_file = args.reader_example_file
encoder_ck_file = args.reader_encoder_ck_file
downstream_ck_file = args.reader_downstream_ck_file
albert_model_path = args.albert_model_path
reader_result_file = args.reader_result_file
seed = args.seed
sp_threshold = args.sp_threshold
seq_len = args.seq_len
batch_size = args.reader_batch_size
para_limit = args.max_para_num
sent_limit = args.max_sent_num
random.seed(seed)
np.random.seed(seed)
t1 = time()
doc_db = DocDB(db_file)
generator = DataGenerator(feature_file_path=reader_feature_file,
example_file_path=reader_example_file,
batch_size=batch_size, seq_len=seq_len,
para_limit=para_limit, sent_limit=sent_limit,
task_type="reader")
example_dict = generator.example_dict
feature_dict = generator.feature_dict
answer_dict = defaultdict(lambda: defaultdict(list))
new_answer_dict = {}
total_sp_dict = defaultdict(list)
new_total_sp_dict = defaultdict(list)
tokenizer = AlbertTokenizer.from_pretrained(albert_model_path)
new_tokens = ['[q]', '[/q]', '<t>', '</t>', '[s]']
tokenizer.add_tokens(new_tokens)
reader = Reader(batch_size=batch_size,
encoder_ck_file=encoder_ck_file,
downstream_ck_file=downstream_ck_file)
print("start reading ...")
for _, batch in tqdm(enumerate(generator)):
input_ids = Tensor(batch["context_idxs"], mstype.int32)
attn_mask = Tensor(batch["context_mask"], mstype.int32)
token_type_ids = Tensor(batch["segment_idxs"], mstype.int32)
context_mask = Tensor(batch["context_mask"], mstype.float32)
square_mask = Tensor(batch["square_mask"], mstype.float32)
packing_mask = Tensor(batch["query_mapping"], mstype.float32)
para_start_mapping = Tensor(batch["para_start_mapping"], mstype.float32)
sent_end_mapping = Tensor(batch["sent_end_mapping"], mstype.float32)
unique_ids = batch["unique_ids"]
sent_names = batch["sent_names"]
cache_mask = Tensor(np.tril(np.triu(np.ones((seq_len, seq_len)), 0), 30), mstype.float32)
_, _, q_type, _, sent_logit, y1, y2 = reader(input_ids, attn_mask, token_type_ids,
context_mask, square_mask, packing_mask, cache_mask,
para_start_mapping, sent_end_mapping)
type_prob = ops.Softmax()(q_type).asnumpy()
answer_dict_ = convert_to_tokens(example_dict,
feature_dict,
batch['ids'],
y1.asnumpy().tolist(),
y2.asnumpy().tolist(),
type_prob,
tokenizer,
sent_logit.asnumpy(),
sent_names,
unique_ids)
for q_id in answer_dict_:
answer_dict[q_id] = answer_dict_[q_id]
for q_id in answer_dict:
res = answer_dict[q_id]
answer_text_ = res[0]
sent_ = res[1]
sent_names_ = res[2]
new_answer_dict[q_id] = answer_text_
predict_support_np = ops.Sigmoid()(Tensor(sent_, mstype.float32)).asnumpy()
for j in range(predict_support_np.shape[0]):
if j >= len(sent_names_):
break
if predict_support_np[j] > sp_threshold:
total_sp_dict[q_id].append(sent_names_[j])
for _id in total_sp_dict:
_sent_names = total_sp_dict[_id]
for para in _sent_names:
title = make_wiki_id(para[0], 0)
para_original_title = doc_db.get_doc_info(title)[-1]
para[0] = para_original_title
new_total_sp_dict[_id].append(para)
prediction = {'answer': new_answer_dict,
'sp': new_total_sp_dict}
with open(reader_result_file, 'w') as f:
json.dump(prediction, f, indent=4)
t2 = time()
print(f"reader cost time: {t2-t1} s")

View File

@ -0,0 +1,276 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""albert-xxlarge Model for reranker"""
import numpy as np
from mindspore import nn, ops
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore import dtype as mstype
dst_type = mstype.float16
dst_type2 = mstype.float32
class LayerNorm(nn.Cell):
"""LayerNorm layer"""
def __init__(self, passthrough_w_0, passthrough_w_1):
"""init function"""
super(LayerNorm, self).__init__()
self.reducemean_0 = P.ReduceMean(keep_dims=True)
self.sub_1 = P.Sub()
self.pow_2 = P.Pow()
self.pow_2_input_weight = 2.0
self.reducemean_3 = P.ReduceMean(keep_dims=True)
self.add_4 = P.Add()
self.add_4_bias = 9.999999960041972e-13
self.sqrt_5 = P.Sqrt()
self.div_6 = P.Div()
self.mul_7 = P.Mul()
self.mul_7_w = passthrough_w_0
self.add_8 = P.Add()
self.add_8_bias = passthrough_w_1
def construct(self, x):
"""construct function"""
opt_reducemean_0 = self.reducemean_0(x, -1)
opt_sub_1 = self.sub_1(x, opt_reducemean_0)
opt_pow_2 = self.pow_2(opt_sub_1, self.pow_2_input_weight)
opt_reducemean_3 = self.reducemean_3(opt_pow_2, -1)
opt_add_4 = self.add_4(opt_reducemean_3, self.add_4_bias)
opt_sqrt_5 = self.sqrt_5(opt_add_4)
opt_div_6 = self.div_6(opt_sub_1, opt_sqrt_5)
opt_mul_7 = self.mul_7(opt_div_6, self.mul_7_w)
opt_add_8 = self.add_8(opt_mul_7, self.add_8_bias)
return opt_add_8
class Linear(nn.Cell):
"""Linear layer"""
def __init__(self, matmul_0_w_shape, passthrough_w_0):
"""init function"""
super(Linear, self).__init__()
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_w_shape).astype(np.float32)), name=None)
self.add_1 = P.Add()
self.add_1_bias = passthrough_w_0
def construct(self, x):
"""construct function"""
opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
return opt_add_1
class MultiHeadAttn(nn.Cell):
"""Multi-head attention layer"""
def __init__(self, batch_size, passthrough_w_0, passthrough_w_1, passthrough_w_2):
"""init function"""
super(MultiHeadAttn, self).__init__()
self.batch_size = batch_size
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
self.matmul_1 = nn.MatMul()
self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
self.matmul_2 = nn.MatMul()
self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
self.add_3 = P.Add()
self.add_3_bias = passthrough_w_0
self.add_4 = P.Add()
self.add_4_bias = passthrough_w_1
self.add_5 = P.Add()
self.add_5_bias = passthrough_w_2
self.reshape_6 = P.Reshape()
self.reshape_6_shape = tuple([batch_size, 512, 64, 64])
self.reshape_7 = P.Reshape()
self.reshape_7_shape = tuple([batch_size, 512, 64, 64])
self.reshape_8 = P.Reshape()
self.reshape_8_shape = tuple([batch_size, 512, 64, 64])
self.transpose_9 = P.Transpose()
self.transpose_10 = P.Transpose()
self.transpose_11 = P.Transpose()
self.matmul_12 = nn.MatMul()
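# scaled dot-product attention: attention scores are divided by sqrt(head_dim) = sqrt(64) = 8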
self.div_13 = P.Div()
self.div_13_w = 8.0
self.add_14 = P.Add()
self.softmax_15 = nn.Softmax(axis=3)
self.matmul_16 = nn.MatMul()
self.transpose_17 = P.Transpose()
self.matmul_18 = P.MatMul()
self.matmul_18_weight = Parameter(Tensor(np.random.uniform(0, 1, (64, 64, 4096)).astype(np.float32)), name=None)
self.add_19 = P.Add()
self.add_19_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
def construct(self, x, x0):
"""construct function"""
opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
opt_matmul_1 = self.matmul_1(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_1_w, dst_type))
opt_matmul_2 = self.matmul_2(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_2_w, dst_type))
opt_add_3 = self.add_3(ops.Cast()(opt_matmul_0, dst_type2), self.add_3_bias)
opt_add_4 = self.add_4(ops.Cast()(opt_matmul_1, dst_type2), self.add_4_bias)
opt_add_5 = self.add_5(ops.Cast()(opt_matmul_2, dst_type2), self.add_5_bias)
opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape)
opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape)
opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape)
opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3))
opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1))
opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3))
opt_matmul_12 = self.matmul_12(ops.Cast()(opt_transpose_9, dst_type), ops.Cast()(opt_transpose_10, dst_type))
opt_div_13 = self.div_13(ops.Cast()(opt_matmul_12, dst_type2), ops.Cast()(self.div_13_w, dst_type2))
opt_add_14 = self.add_14(opt_div_13, x0)
opt_softmax_15 = self.softmax_15(opt_add_14)
opt_matmul_16 = self.matmul_16(ops.Cast()(opt_softmax_15, dst_type), ops.Cast()(opt_transpose_11, dst_type))
opt_transpose_17 = self.transpose_17(ops.Cast()(opt_matmul_16, dst_type2), (0, 2, 1, 3))
opt_matmul_18 = self.matmul_18(ops.Cast()(opt_transpose_17, dst_type).view(self.batch_size * 512, -1),
ops.Cast()(self.matmul_18_weight, dst_type).view(-1, 4096))\
.view(self.batch_size, 512, 4096)
opt_add_19 = self.add_19(ops.Cast()(opt_matmul_18, dst_type2), self.add_19_bias)
return opt_add_19
class NewGeLU(nn.Cell):
"""Gelu layer"""
def __init__(self):
"""init function"""
super(NewGeLU, self).__init__()
self.mul_0 = P.Mul()
self.mul_0_w = 0.5
self.pow_1 = P.Pow()
self.pow_1_input_weight = 3.0
self.mul_2 = P.Mul()
self.mul_2_w = 0.044714998453855515
self.add_3 = P.Add()
self.mul_4 = P.Mul()
self.mul_4_w = 0.7978845834732056
self.tanh_5 = nn.Tanh()
self.add_6 = P.Add()
self.add_6_bias = 1.0
self.mul_7 = P.Mul()
def construct(self, x):
"""construct function"""
opt_mul_0 = self.mul_0(x, self.mul_0_w)
opt_pow_1 = self.pow_1(x, self.pow_1_input_weight)
opt_mul_2 = self.mul_2(opt_pow_1, self.mul_2_w)
opt_add_3 = self.add_3(x, opt_mul_2)
opt_mul_4 = self.mul_4(opt_add_3, self.mul_4_w)
opt_tanh_5 = self.tanh_5(opt_mul_4)
opt_add_6 = self.add_6(opt_tanh_5, self.add_6_bias)
opt_mul_7 = self.mul_7(opt_mul_0, opt_add_6)
return opt_mul_7
class TransformerLayerWithLayerNorm(nn.Cell):
"""Transformer layer with LayerNOrm"""
def __init__(self, batch_size, linear3_0_matmul_0_w_shape, linear3_1_matmul_0_w_shape, passthrough_w_0,
passthrough_w_1, passthrough_w_2, passthrough_w_3, passthrough_w_4, passthrough_w_5, passthrough_w_6):
"""init function"""
super(TransformerLayerWithLayerNorm, self).__init__()
self.multiheadattn_0 = MultiHeadAttn(batch_size=batch_size,
passthrough_w_0=passthrough_w_0,
passthrough_w_1=passthrough_w_1,
passthrough_w_2=passthrough_w_2)
self.add_0 = P.Add()
self.layernorm1_0 = LayerNorm(passthrough_w_0=passthrough_w_3, passthrough_w_1=passthrough_w_4)
self.linear3_0 = Linear(matmul_0_w_shape=linear3_0_matmul_0_w_shape, passthrough_w_0=passthrough_w_5)
self.newgelu2_0 = NewGeLU()
self.linear3_1 = Linear(matmul_0_w_shape=linear3_1_matmul_0_w_shape, passthrough_w_0=passthrough_w_6)
self.add_1 = P.Add()
def construct(self, x, x0):
"""construct function"""
multiheadattn_0_opt = self.multiheadattn_0(x, x0)
opt_add_0 = self.add_0(x, multiheadattn_0_opt)
layernorm1_0_opt = self.layernorm1_0(opt_add_0)
linear3_0_opt = self.linear3_0(layernorm1_0_opt)
newgelu2_0_opt = self.newgelu2_0(linear3_0_opt)
linear3_1_opt = self.linear3_1(newgelu2_0_opt)
opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt)
return opt_add_1
class Rerank_Albert(nn.Cell):
"""Albert model for rerank"""
def __init__(self, batch_size):
"""init function"""
super(Rerank_Albert, self).__init__()
self.passthrough_w_0 = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None)
self.passthrough_w_1 = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None)
self.passthrough_w_2 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.passthrough_w_3 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.passthrough_w_4 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.passthrough_w_5 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.passthrough_w_6 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.passthrough_w_7 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.passthrough_w_8 = Parameter(Tensor(np.random.uniform(0, 1, (16384,)).astype(np.float32)), name=None)
self.passthrough_w_9 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.passthrough_w_10 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.passthrough_w_11 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
self.expanddims_0 = P.ExpandDims()
self.expanddims_0_axis = 1
self.expanddims_3 = P.ExpandDims()
self.expanddims_3_axis = 2
self.cast_5 = P.Cast()
self.cast_5_to = mstype.float32
self.sub_7 = P.Sub()
self.sub_7_bias = 1.0
self.mul_9 = P.Mul()
self.mul_9_w = -10000.0
self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30005, 128)).astype(np.float32)),
name=None)
self.gather_1_axis = 0
self.gather_1 = P.Gather()
self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 128)).astype(np.float32)), name=None)
self.gather_2_axis = 0
self.gather_2 = P.Gather()
self.add_4 = P.Add()
self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 512, 128)).astype(np.float32)), name=None)
self.add_6 = P.Add()
self.layernorm1_0 = LayerNorm(passthrough_w_0=self.passthrough_w_0, passthrough_w_1=self.passthrough_w_1)
self.linear3_0 = Linear(matmul_0_w_shape=(128, 4096), passthrough_w_0=self.passthrough_w_2)
self.module34_0 = TransformerLayerWithLayerNorm(batch_size=batch_size,
linear3_0_matmul_0_w_shape=(4096, 16384),
linear3_1_matmul_0_w_shape=(16384, 4096),
passthrough_w_0=self.passthrough_w_3,
passthrough_w_1=self.passthrough_w_4,
passthrough_w_2=self.passthrough_w_5,
passthrough_w_3=self.passthrough_w_6,
passthrough_w_4=self.passthrough_w_7,
passthrough_w_5=self.passthrough_w_8,
passthrough_w_6=self.passthrough_w_9)
self.layernorm1_1 = LayerNorm(passthrough_w_0=self.passthrough_w_10, passthrough_w_1=self.passthrough_w_11)
def construct(self, input_ids, attention_mask, token_type_ids):
"""construct function"""
opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis)
opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis)
opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to)
opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5)
opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w)
opt_gather_1_axis = self.gather_1_axis
opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis)
opt_gather_2_axis = self.gather_2_axis
opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis)
opt_add_4 = self.add_4(opt_gather_1, self.add_4_bias)
opt_add_6 = self.add_6(opt_add_4, opt_gather_2)
layernorm1_0_opt = self.layernorm1_0(opt_add_6)
linear3_0_opt = self.linear3_0(layernorm1_0_opt)
opt = self.module34_0(linear3_0_opt, opt_mul_9)
opt = self.layernorm1_1(opt)
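# ALBERT-style cross-layer parameter sharing: the same transformer block is reused for the remaining 11 layers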
for _ in range(11):
opt = self.module34_0(opt, opt_mul_9)
opt = self.layernorm1_1(opt)
return opt

View File

@ -0,0 +1,183 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define a data generator"""
import gzip
import pickle
import random
import numpy as np
random.seed(42)
np.random.seed(42)
class DataGenerator:
"""data generator for reranker and reader"""
def __init__(self, feature_file_path, example_file_path, batch_size, seq_len,
para_limit=None, sent_limit=None, task_type=None):
"""init function"""
self.example_ptr = 0
self.bsz = batch_size
self.seq_length = seq_len
self.para_limit = para_limit
self.sent_limit = sent_limit
self.task_type = task_type
self.feature_file_path = feature_file_path
self.example_file_path = example_file_path
self.features = self.load_features()
self.examples = self.load_examples()
self.feature_dict = self.get_feature_dict()
self.example_dict = self.get_example_dict()
self.features = self.padding_feature(self.features, self.bsz)
def load_features(self):
"""load features from feature file"""
with gzip.open(self.feature_file_path, 'rb') as fin:
features = pickle.load(fin)
print("load features successful !!!")
return features
def padding_feature(self, features, bsz):
"""padding features as multiples of batch size"""
padding_num = ((len(features) // bsz + 1) * bsz - len(features))
print(f"features padding num is {padding_num}")
new_features = features + features[:padding_num]
return new_features
def load_examples(self):
"""laod examples from file"""
if self.example_file_path:
with gzip.open(self.example_file_path, 'rb') as fin:
examples = pickle.load(fin)
print("load examples successful !!!")
return examples
return {}
def get_feature_dict(self):
"""build a feature dict"""
return {f.unique_id: f for f in self.features}
def get_example_dict(self):
"""build a example dict"""
if self.example_file_path:
return {e.unique_id: e for e in self.examples}
return {}
def common_process_single_case(self, i, case, context_idxs, context_mask, segment_idxs, ids, path, unique_ids):
"""common process for a single case"""
context_idxs[i] = np.array(case.doc_input_ids)
context_mask[i] = np.array(case.doc_input_mask)
segment_idxs[i] = np.array(case.doc_segment_ids)
ids.append(case.qas_id)
path.append(case.path)
unique_ids.append(case.unique_id)
return context_idxs, context_mask, segment_idxs, ids, path, unique_ids
def reader_process_single_case(self, i, case, sent_names, square_mask, query_mapping, ques_start_mapping,
para_start_mapping, sent_end_mapping):
"""process for a single case about reader"""
sent_names.append(case.sent_names)
prev_position = None
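# token ids >= 30000 are assumed to be special marker tokens inserted during preprocessing;
# the square mask covers the span between two consecutive markers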
for cur_position, token_id in enumerate(case.doc_input_ids):
if token_id >= 30000:
if prev_position:
square_mask[i, prev_position + 1: cur_position, prev_position + 1: cur_position] = 1.0
prev_position = cur_position
if case.sent_spans:
for j in range(case.sent_spans[0][0] - 1):
query_mapping[i, j] = 1
ques_start_mapping[i, 0, 1] = 1
for j, para_span in enumerate(case.para_spans[:self.para_limit]):
start, end, _ = para_span
if start <= end:
para_start_mapping[i, j, start] = 1
for j, sent_span in enumerate(case.sent_spans[:self.sent_limit]):
start, end = sent_span
if start <= end:
end = min(end, self.seq_length - 1)
sent_end_mapping[i, j, end] = 1
return sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping
def __iter__(self):
"""iteration function"""
while True:
if self.example_ptr >= len(self.features):
break
start_id = self.example_ptr
cur_bsz = min(self.bsz, len(self.features) - start_id)
cur_batch = self.features[start_id: start_id + cur_bsz]
# BERT input
context_idxs = np.zeros((cur_bsz, self.seq_length))
context_mask = np.zeros((cur_bsz, self.seq_length))
segment_idxs = np.zeros((cur_bsz, self.seq_length))
# others
ids = []
path = []
unique_ids = []
if self.task_type == "reader":
# Mappings
ques_start_mapping = np.zeros((cur_bsz, 1, self.seq_length))
query_mapping = np.zeros((cur_bsz, self.seq_length))
para_start_mapping = np.zeros((cur_bsz, self.para_limit, self.seq_length))
sent_end_mapping = np.zeros((cur_bsz, self.sent_limit, self.seq_length))
square_mask = np.zeros((cur_bsz, self.seq_length, self.seq_length))
sent_names = []
for i, case in enumerate(cur_batch):
context_idxs, context_mask, segment_idxs, ids, path, unique_ids = \
self.common_process_single_case(i, case, context_idxs, context_mask, segment_idxs, ids, path,
unique_ids)
if self.task_type == "reader":
sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping = \
self.reader_process_single_case(i, case, sent_names, square_mask, query_mapping,
ques_start_mapping, para_start_mapping, sent_end_mapping)
self.example_ptr += cur_bsz
if self.task_type == "reranker":
yield {
"context_idxs": context_idxs,
"context_mask": context_mask,
"segment_idxs": segment_idxs,
"ids": ids,
"unique_ids": unique_ids,
"path": path
}
elif self.task_type == "reader":
yield {
"context_idxs": context_idxs,
"context_mask": context_mask,
"segment_idxs": segment_idxs,
"query_mapping": query_mapping,
"para_start_mapping": para_start_mapping,
"sent_end_mapping": sent_end_mapping,
"square_mask": square_mask,
"ques_start_mapping": ques_start_mapping,
"ids": ids,
"unique_ids": unique_ids,
"sent_names": sent_names,
"path": path
}
else:
print(f"data generator received a error type: {self.task_type} !!!")

View File

@ -0,0 +1,656 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""utils"""
import re
import argparse
from urllib.parse import unquote
from collections import defaultdict
import collections
import logging
import unicodedata
import json
import gzip
import string
import pickle
import sqlite3
from tqdm import tqdm
import numpy as np
from transformers import BasicTokenizer
logger = logging.getLogger(__name__)
class Example:
"""A single example of data"""
def __init__(self,
qas_id,
path,
unique_id,
question_tokens,
doc_tokens,
sent_names,
sup_fact_id,
sup_para_id,
para_start_end_position,
sent_start_end_position,
question_text,
title_start_end_position=None):
"""init function"""
self.qas_id = qas_id
self.path = path
self.unique_id = unique_id
self.question_tokens = question_tokens
self.doc_tokens = doc_tokens
self.question_text = question_text
self.sent_names = sent_names
self.sup_fact_id = sup_fact_id
self.sup_para_id = sup_para_id
self.para_start_end_position = para_start_end_position
self.sent_start_end_position = sent_start_end_position
self.title_start_end_position = title_start_end_position
class InputFeatures:
"""A single set of features of data."""
def __init__(self,
unique_id,
qas_id,
path,
sent_names,
doc_tokens,
doc_input_ids,
doc_input_mask,
doc_segment_ids,
query_tokens,
query_input_ids,
query_input_mask,
query_segment_ids,
para_spans,
sent_spans,
token_to_orig_map):
"""init function"""
self.qas_id = qas_id
self.doc_tokens = doc_tokens
self.doc_input_ids = doc_input_ids
self.doc_input_mask = doc_input_mask
self.doc_segment_ids = doc_segment_ids
self.path = path
self.unique_id = unique_id
self.sent_names = sent_names
self.query_tokens = query_tokens
self.query_input_ids = query_input_ids
self.query_input_mask = query_input_mask
self.query_segment_ids = query_segment_ids
self.para_spans = para_spans
self.sent_spans = sent_spans
self.token_to_orig_map = token_to_orig_map
class DocDB:
"""
Sqlite backed document storage.
Implements get_doc_ids() and get_doc_info(doc_id).
"""
def __init__(self, db_path):
"""init function"""
self.path = db_path
self.connection = sqlite3.connect(self.path, check_same_thread=False)
def __enter__(self):
"""enter function"""
return self
def __exit__(self, *args):
"""exit function"""
self.close()
def close(self):
"""Close the connection to the database."""
self.connection.close()
def get_doc_ids(self):
"""Fetch all ids of docs stored in the db."""
cursor = self.connection.cursor()
cursor.execute("SELECT id FROM documents")
results = [r[0] for r in cursor.fetchall()]
cursor.close()
return results
def get_doc_info(self, doc_id):
"""get docment information"""
if not doc_id.endswith('_0'):
doc_id += '_0'
cursor = self.connection.cursor()
cursor.execute(
"SELECT * FROM documents WHERE id = ?",
(normalize_title(doc_id),)
)
result = cursor.fetchall()
cursor.close()
return result[0] if result else None
def get_parse():
"""get parse function"""
parser = argparse.ArgumentParser()
# Environment
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--seq_len', type=int, default=512,
help="max sentence length")
parser.add_argument("--get_reranker_data",
action='store_true',
help="Set this flag if you want to get reranker data from retrieved result")
parser.add_argument("--run_reranker",
action='store_true',
help="Set this flag if you want to run reranker")
parser.add_argument("--cal_reranker_metrics",
action='store_true',
help="Set this flag if you want to calculate rerank metrics")
parser.add_argument("--select_reader_data",
action='store_true',
help="Set this flag if you want to select reader data")
parser.add_argument("--run_reader",
action='store_true',
help="Set this flag if you want to run reader")
parser.add_argument("--cal_reader_metrics",
action='store_true',
help="Set this flag if you want to calculate reader metrics")
parser.add_argument('--dev_gold_file',
type=str,
default="../hotpot_dev_fullwiki_v1.json",
help='file of dev ground truth')
parser.add_argument('--wiki_db_file',
type=str,
default="../enwiki_offset.db",
help='wiki_database_file')
parser.add_argument('--albert_model_path',
type=str,
default="../albert-xxlarge/",
help='model path of huggingface albert-xxlarge')
# Retriever
parser.add_argument('--retriever_result_file',
type=str,
default="../doc_path",
help='file of retriever result')
# Rerank
parser.add_argument('--rerank_batch_size', type=int, default=32,
help="rerank batchsize for evaluating")
parser.add_argument('--rerank_feature_file',
type=str,
default="../reranker_feature_file.pkl.gz",
help='file of rerank feature')
parser.add_argument('--rerank_example_file',
type=str,
default="../reranker_example_file.pkl.gz",
help='file of rerank example')
parser.add_argument('--rerank_result_file',
type=str,
default="../rerank_result.json",
help='file of rerank result')
parser.add_argument('--rerank_encoder_ck_file',
type=str,
default="../rerank_albert_12.ckpt",
help='checkpoint of rerank albert-xxlarge')
parser.add_argument('--rerank_downstream_ck_file',
type=str,
default="../rerank_downstream.ckpt",
help='checkpoint of rerank downstream')
# Reader
parser.add_argument('--reader_batch_size', type=int, default=32,
help="reader batchsize for evaluating")
parser.add_argument('--reader_feature_file',
type=str,
default="../reader_feature_file.pkl.gz",
help='file of reader feature')
parser.add_argument('--reader_example_file',
type=str,
default="../reader_example_file.pkl.gz",
help='file of reader example')
parser.add_argument('--reader_encoder_ck_file',
type=str,
default="../albert_12_layer.ckpt",
help='checkpoint of reader albert-xxlarge')
parser.add_argument('--reader_downstream_ck_file',
type=str,
default="../reader_downstream.ckpt",
help='checkpoint of reader downstream')
parser.add_argument('--reader_result_file',
type=str,
default="../reader_result_file.json",
help='file of reader result')
parser.add_argument('--sp_threshold', type=float, default=0.65,
help="threshold for selecting supporting sentences")
parser.add_argument("--max_para_num", default=2, type=int)
parser.add_argument("--max_sent_num", default=40, type=int)
return parser
def select_reader_dev_data(args):
"""select reader dev data from result of retriever based on result of reranker"""
rerank_result_file = args.rerank_result_file
rerank_feature_file = args.rerank_feature_file
rerank_example_file = args.rerank_example_file
reader_feature_file = args.reader_feature_file
reader_example_file = args.reader_example_file
with gzip.open(rerank_example_file, "rb") as f:
dev_examples = pickle.load(f)
with gzip.open(rerank_feature_file, "rb") as f:
dev_features = pickle.load(f)
with open(rerank_result_file, "r") as f:
rerank_result = json.load(f)
new_dev_examples = []
new_dev_features = []
rerank_unique_ids = defaultdict(int)
feature_unique_ids = defaultdict(int)
for _, res in tqdm(rerank_result.items(), desc="get rerank unique ids"):
rerank_unique_ids[res[0]] = True
print(f"rerank result num is {len(rerank_unique_ids)}")
for feature in tqdm(dev_features, desc="select rerank top1 feature"):
if feature.unique_id in rerank_unique_ids:
feature_unique_ids[feature.unique_id] = True
new_dev_features.append(feature)
print(f"new feature num is {len(new_dev_features)}")
for example in tqdm(dev_examples, desc="select rerank top1 example"):
if example.unique_id in rerank_unique_ids and example.unique_id in feature_unique_ids:
new_dev_examples.append(example)
print(f"new examples num is {len(new_dev_examples)}")
print("start save new examples ......")
with gzip.open(reader_example_file, "wb") as f:
pickle.dump(new_dev_examples, f)
print("start save new features ......")
with gzip.open(reader_feature_file, "wb") as f:
pickle.dump(new_dev_features, f)
print("finish selecting reader data !!!")
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
"""Project the tokenized prediction back to the original text."""
def _strip_spaces(text):
ns_chars = []
ns_to_s_map = collections.OrderedDict()
for (i, c) in enumerate(text):
if c == " ":
continue
ns_to_s_map[len(ns_chars)] = i
ns_chars.append(c)
ns_text = "".join(ns_chars)
return (ns_text, ns_to_s_map)
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
start_position = tok_text.find(pred_text)
if start_position == -1:
if verbose_logging:
print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
return orig_text
end_position = start_position + len(pred_text) - 1
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
if verbose_logging:
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
orig_ns_text, tok_ns_text)
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in tok_ns_to_s_map.items():
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
if start_position in tok_s_to_ns_map:
ns_start_position = tok_s_to_ns_map[start_position]
if ns_start_position in orig_ns_to_s_map:
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
if verbose_logging:
print("Couldn't map start position")
return orig_text
orig_end_position = None
if end_position in tok_s_to_ns_map:
ns_end_position = tok_s_to_ns_map[end_position]
if ns_end_position in orig_ns_to_s_map:
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
if verbose_logging:
print("Couldn't map end position")
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
return output_text
def get_ans_from_pos(tokenizer, examples, features, y1, y2, unique_id):
"""get answer text from predicted position"""
feature = features[unique_id]
example = examples[unique_id]
tok_to_orig_map = feature.token_to_orig_map
orig_all_tokens = example.question_tokens + example.doc_tokens
final_text = " "
if y1 < len(tok_to_orig_map) and y2 < len(tok_to_orig_map):
orig_tok_start = tok_to_orig_map[y1]
orig_tok_end = tok_to_orig_map[y2]
# -----------------orig all tokens-----------------------------------
orig_tokens = orig_all_tokens[orig_tok_start: (orig_tok_end + 1)]
tok_tokens = feature.doc_tokens[y1: y2 + 1]
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text, True, False)
# print("final_text: " + final_text)
return final_text
def convert_to_tokens(examples, features, ids, y1, y2, q_type_prob, tokenizer, sent, sent_names,
unique_ids):
"""get raw answer text and supporting sentences"""
answer_dict = defaultdict(list)
q_type = np.argmax(q_type_prob, 1)
for i, qid in enumerate(ids):
unique_id = unique_ids[i]
if q_type[i] == 0:
answer_text = 'yes'
elif q_type[i] == 1:
answer_text = 'no'
elif q_type[i] == 2:
answer_text = get_ans_from_pos(tokenizer, examples, features, y1[i], y2[i], unique_id)
else:
raise ValueError("question type error")
answer_dict[qid].append(answer_text)
answer_dict[qid].append(sent[i])
answer_dict[qid].append(sent_names[i])
return answer_dict
def normalize_title(text):
"""Resolve different type of unicode encodings / capitarization in HotpotQA data."""
text = unicodedata.normalize('NFD', text)
return text[0].capitalize() + text[1:]
def make_wiki_id(title, para_index):
"""make wiki id"""
title_id = "{0}_{1}".format(normalize_title(title), para_index)
return title_id
def cal_reranker_metrics(dev_gold_file, rerank_result_file):
"""function for calculating reranker's metrics"""
with open(dev_gold_file, 'rb') as f:
gt = json.load(f)
with open(rerank_result_file, 'rb') as f:
rerank_result = json.load(f)
cnt = 0
all_ = len(gt)
cnt_c = 0
cnt_b = 0
all_c = 0
all_b = 0
for item in tqdm(gt, desc="get com and bridge "):
q_type = item["type"]
if q_type == "comparison":
all_c += 1
elif q_type == "bridge":
all_b += 1
else:
print(f"{q_type} is a error question type!!!")
for item in tqdm(gt, desc="cal pem"):
_id = item["_id"]
if _id in rerank_result:
pred = rerank_result[_id][1]
sps = item["supporting_facts"]
q_type = item["type"]
gold = []
for t in sps:
gold.append(normalize_title(t[0]))
gold = set(gold)
flag = True
for t in gold:
if t not in pred:
flag = False
break
if flag:
cnt += 1
if q_type == "comparison":
cnt_c += 1
elif q_type == "bridge":
cnt_b += 1
else:
print(f"{q_type} is a error question type!!!")
return cnt/all_, cnt_c/all_c, cnt_b/all_b
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
def find_hyper_linked_titles(text_w_links):
"""find hyperlinked titles"""
titles = re.findall(r'href=[\'"]?([^\'" >]+)', text_w_links)
titles = [unquote(title) for title in titles]
titles = [title[0].capitalize() + title[1:] for title in titles]
return titles
def normalize_text(text):
"""Resolve different type of unicode encodings / capitarization in HotpotQA data."""
text = unicodedata.normalize('NFD', text)
return text
def convert_char_to_token_offset(orig_text, start_offset, end_offset, char_to_word_offset, doc_tokens):
"""build characters' offset"""
length = len(orig_text)
assert start_offset + length == end_offset
assert end_offset <= len(char_to_word_offset)
start_position = char_to_word_offset[start_offset]
end_position = char_to_word_offset[start_offset + length - 1]
actual_text = " ".join(
doc_tokens[start_position:(end_position + 1)])
assert actual_text.lower().find(orig_text.lower()) != -1
return start_position, end_position
def _is_whitespace(c):
"""check whitespace"""
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
return True
return False
def convert_text_to_tokens(context_text, return_word_start=False):
"""convert text to tokens"""
doc_tokens = []
char_to_word_offset = []
words_start_idx = []
prev_is_whitespace = True
for idx, c in enumerate(context_text):
if _is_whitespace(c):
prev_is_whitespace = True
else:
if prev_is_whitespace:
doc_tokens.append(c)
words_start_idx.append(idx)
else:
doc_tokens[-1] += c
prev_is_whitespace = False
char_to_word_offset.append(len(doc_tokens) - 1)
if not return_word_start:
return doc_tokens, char_to_word_offset
return doc_tokens, char_to_word_offset, words_start_idx
def read_json(eval_file_name):
"""reader json files"""
print("loading examples from {0}".format(eval_file_name))
with open(eval_file_name) as reader:
lines = json.load(reader)
return lines
def write_json(data, out_file_name):
"""write json files"""
print("writing {0} examples to {1}".format(len(data), out_file_name))
with open(out_file_name, 'w') as writer:
json.dump(data, writer, indent=4)
def get_edges(sentence):
"""get edges"""
EDGE_XY = re.compile(r'<a href="(?!http|<a)(.*?)">(.*?)<\/a>')
ret = EDGE_XY.findall(sentence)
return [(unquote(x), y) for x, y in ret]
def relocate_tok_span(orig_to_tok_index, orig_to_tok_back_index, word_tokens, subword_tokens,
orig_start_position, orig_end_position, orig_text, tokenizer, tok_to_orig_index=None):
"""relocate tokens' span"""
if orig_start_position is None:
return 0, 0
tok_start_position = orig_to_tok_index[orig_start_position]
if tok_start_position >= len(subword_tokens):
return 0, 0
if orig_end_position < len(word_tokens) - 1:
tok_end_position = orig_to_tok_back_index[orig_end_position]
if tok_to_orig_index and tok_to_orig_index[tok_end_position + 1] == -1:
assert tok_end_position <= orig_to_tok_index[orig_end_position + 1] - 2
else:
assert tok_end_position == orig_to_tok_index[orig_end_position + 1] - 1
else:
tok_end_position = orig_to_tok_back_index[orig_end_position]
return _improve_answer_span(
subword_tokens, tok_start_position, tok_end_position, tokenizer, orig_text)
def generate_mapping(length, positions):
"""generate mapping"""
start_mapping = [0] * length
end_mapping = [0] * length
for _, (start, end) in enumerate(positions):
start_mapping[start] = 1
end_mapping[end] = 1
return start_mapping, end_mapping
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
"""Returns tokenized answer spans that better match the annotated answer."""
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, add_prefix_space=True))
for new_start in range(input_start, input_end + 1):
for new_end in range(input_end, new_start - 1, -1):
text_span = " ".join(doc_tokens[new_start: (new_end + 1)])
if text_span == tok_answer_text:
return new_start, new_end
return input_start, input_end
def _largest_valid_index(spans, limit):
"""return largest valid index"""
for idx, _ in enumerate(spans):
if spans[idx][1] >= limit:
return idx
return len(spans)
def remove_punc(text):
"""remove punctuation"""
if text == " ":
return ''
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def check_text_include_ans(ans, text):
"""check whether text include answer"""
if normalize_answer(ans) in normalize_answer(text):
return True
return False
def remove_articles(text):
"""remove articles"""
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
"""fix whitespace"""
return ' '.join(text.split())
def lower(text):
"""lower text"""
return text.lower()
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
return white_space_fix(remove_articles(remove_punc(lower(s))))

View File

@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""downstream Model for reranker"""
import numpy as np
from mindspore import nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
class Rerank_Downstream(nn.Cell):
"""Downstream model for rerank"""
def __init__(self):
"""init function"""
super(Rerank_Downstream, self).__init__()
self.dense_0 = nn.Dense(in_channels=4096, out_channels=8192, has_bias=True)
self.relu_1 = nn.ReLU()
self.reducemean_2 = P.ReduceMean(keep_dims=True)
self.sub_3 = P.Sub()
self.sub_4 = P.Sub()
self.pow_5 = P.Pow()
self.pow_5_input_weight = 2.0
self.reducemean_6 = P.ReduceMean(keep_dims=True)
self.add_7 = P.Add()
self.add_7_bias = 9.999999960041972e-13
self.sqrt_8 = P.Sqrt()
self.div_9 = P.Div()
self.mul_10 = P.Mul()
self.mul_10_w = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
self.add_11 = P.Add()
self.add_11_bias = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
self.dense_12 = nn.Dense(in_channels=8192, out_channels=2, has_bias=True)
def construct(self, x):
"""construct function"""
opt_dense_0 = self.dense_0(x)
opt_relu_1 = self.relu_1(opt_dense_0)
opt_reducemean_2 = self.reducemean_2(opt_relu_1, -1)
opt_sub_3 = self.sub_3(opt_relu_1, opt_reducemean_2)
opt_sub_4 = self.sub_4(opt_relu_1, opt_reducemean_2)
opt_pow_5 = self.pow_5(opt_sub_3, self.pow_5_input_weight)
opt_reducemean_6 = self.reducemean_6(opt_pow_5, -1)
opt_add_7 = self.add_7(opt_reducemean_6, self.add_7_bias)
opt_sqrt_8 = self.sqrt_8(opt_add_7)
opt_div_9 = self.div_9(opt_sub_4, opt_sqrt_8)
opt_mul_10 = self.mul_10(self.mul_10_w, opt_div_9)
opt_add_11 = self.add_11(opt_mul_10, self.add_11_bias)
opt_dense_12 = self.dense_12(opt_add_11)
return opt_dense_12

View File

@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reranker Model"""
import mindspore.nn as nn
from mindspore import load_checkpoint, load_param_into_net
from src.rerank_albert_xxlarge import Rerank_Albert
from src.rerank_downstream import Rerank_Downstream
class Reranker(nn.Cell):
"""Reranker model"""
def __init__(self, batch_size, encoder_ck_file, downstream_ck_file):
"""init function"""
super(Reranker, self).__init__(auto_prefix=False)
self.encoder = Rerank_Albert(batch_size)
param_dict = load_checkpoint(encoder_ck_file)
not_load_params_1 = load_param_into_net(self.encoder, param_dict)
print(f"not loaded albert: {not_load_params_1}")
self.no_answer_mlp = Rerank_Downstream()
param_dict = load_checkpoint(downstream_ck_file)
not_load_params_2 = load_param_into_net(self.no_answer_mlp, param_dict)
print(f"not loaded downstream: {not_load_params_2}")
def construct(self, input_ids, attn_mask, token_type_ids):
"""construct function"""
state = self.encoder(input_ids, attn_mask, token_type_ids)
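# take the representation of the first token (the [CLS] position) as the sequence summary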
state = state[:, 0, :]
no_answer = self.no_answer_mlp(state)
return no_answer

View File

@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""execute reranker"""
import json
import random
from collections import defaultdict
from time import time
from tqdm import tqdm
import numpy as np
from mindspore import Tensor, ops
from mindspore import dtype as mstype
from src.rerank_and_reader_data_generator import DataGenerator
from src.reranker import Reranker
def rerank(args):
"""rerank function"""
rerank_feature_file = args.rerank_feature_file
rerank_result_file = args.rerank_result_file
encoder_ck_file = args.rerank_encoder_ck_file
downstream_ck_file = args.rerank_downstream_ck_file
seed = args.seed
seq_len = args.seq_len
batch_size = args.rerank_batch_size
random.seed(seed)
np.random.seed(seed)
t1 = time()
generator = DataGenerator(feature_file_path=rerank_feature_file,
example_file_path=None,
batch_size=batch_size, seq_len=seq_len,
task_type="reranker")
gather_dict = defaultdict(lambda: defaultdict(list))
reranker = Reranker(batch_size=batch_size,
encoder_ck_file=encoder_ck_file,
downstream_ck_file=downstream_ck_file)
print("start re-ranking ...")
for _, batch in tqdm(enumerate(generator)):
input_ids = Tensor(batch["context_idxs"], mstype.int32)
attn_mask = Tensor(batch["context_mask"], mstype.int32)
token_type_ids = Tensor(batch["segment_idxs"], mstype.int32)
no_answer = reranker(input_ids, attn_mask, token_type_ids)
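# softmax over the two classes; column 0 is read as the probability that the path contains no answer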
no_answer_prob = ops.Softmax()(no_answer).asnumpy()
no_answer_prob = no_answer_prob[:, 0]
for i in range(len(batch['ids'])):
qas_id = batch['ids'][i]
gather_dict[qas_id][no_answer_prob[i]].append(batch['unique_ids'][i])
gather_dict[qas_id][no_answer_prob[i]].append(batch['path'][i])
rerank_result = {}
for qas_id in tqdm(gather_dict, desc="get top1 path from re-rank result"):
all_paths = gather_dict[qas_id]
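# sort candidate paths by ascending no-answer probability and keep the most confident one as top-1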
all_paths = sorted(all_paths.items(), key=lambda item: item[0])
assert qas_id not in rerank_result
rerank_result[qas_id] = all_paths[0][1]
with open(rerank_result_file, 'w') as f:
json.dump(rerank_result, f)
t2 = time()
print(f"re-rank cost time: {t2-t1} s")