From 5d420716f949e9ad242f4347763944e0ce9bc03e Mon Sep 17 00:00:00 2001 From: zhanke Date: Tue, 9 Mar 2021 16:52:44 +0800 Subject: [PATCH] add tprr retriever model --- model_zoo/research/nlp/tprr/README.md | 164 ++++++++++ model_zoo/research/nlp/tprr/retriever_eval.py | 180 +++++++++++ .../nlp/tprr/scripts/run_eval_ascend.sh | 39 +++ model_zoo/research/nlp/tprr/src/config.py | 46 +++ model_zoo/research/nlp/tprr/src/onehop.py | 58 ++++ .../research/nlp/tprr/src/onehop_bert.py | 302 ++++++++++++++++++ .../research/nlp/tprr/src/process_data.py | 254 +++++++++++++++ model_zoo/research/nlp/tprr/src/twohop.py | 60 ++++ .../research/nlp/tprr/src/twohop_bert.py | 302 ++++++++++++++++++ model_zoo/research/nlp/tprr/src/utils.py | 62 ++++ 10 files changed, 1467 insertions(+) create mode 100644 model_zoo/research/nlp/tprr/README.md create mode 100644 model_zoo/research/nlp/tprr/retriever_eval.py create mode 100644 model_zoo/research/nlp/tprr/scripts/run_eval_ascend.sh create mode 100644 model_zoo/research/nlp/tprr/src/config.py create mode 100644 model_zoo/research/nlp/tprr/src/onehop.py create mode 100644 model_zoo/research/nlp/tprr/src/onehop_bert.py create mode 100644 model_zoo/research/nlp/tprr/src/process_data.py create mode 100644 model_zoo/research/nlp/tprr/src/twohop.py create mode 100644 model_zoo/research/nlp/tprr/src/twohop_bert.py create mode 100644 model_zoo/research/nlp/tprr/src/utils.py diff --git a/model_zoo/research/nlp/tprr/README.md b/model_zoo/research/nlp/tprr/README.md new file mode 100644 index 00000000000..25ffac30a88 --- /dev/null +++ b/model_zoo/research/nlp/tprr/README.md @@ -0,0 +1,164 @@ +# Contents + +- [Thinking Path Re-Ranker](#thinking-path-re-ranker) +- [Model Architecture](#model-architecture) +- [Dataset](#dataset) +- [Features](#features) + - [Mixed Precision](#mixed-precision) +- [Environment Requirements](#environment-requirements) +- [Quick Start](#quick-start) +- [Script Description](#script-description) + - [Script and Sample Code](#script-and-sample-code) + - [Script Parameters](#script-parameters) + - [Training Process](#training-process) + - [Training](#training) + - [Evaluation Process](#evaluation-process) + - [Evaluation](#evaluation) +- [Model Description](#model-description) + - [Performance](#performance) +- [Description of random situation](#description-of-random-situation) +- [ModelZoo Homepage](#modelzoo-homepage) + +# [Thinking Path Re-Ranker](#contents) + +Thinking Path Re-Ranker(TPRR) was proposed in 2021 by Huawei Poisson Lab & Parallel Distributed Computing Lab. By incorporating the +retriever, reranker and reader modules, TPRR shows excellent performance on open-domain multi-hop question answering. Moreover, TPRR has won +the first place in the current HotpotQA official leaderboard. This is a example of evaluation of TPRR with HotPotQA dataset in MindSpore. More +importantly, this is the first open source version for TPRR. + +# [Model Architecture](#contents) + +Specially, TPRR contains three main modules. The first is retriever, which generate document sequences of each hop iteratively. The second +is reranker for selecting the best path from candidate paths generated by retriever. The last one is reader for extracting answer spans. + +# [Dataset](#contents) + +The retriever dataset consists of three parts: +Wikipedia data: the 2017 English Wikipedia dump version with bidirectional hyperlinks. +dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs. +dev tf-idf data: the candidates for each question in dev data which is originated from top-500 retrieved from 5M paragraphs of Wikipedia +through TF-IDF. + +# [Features](#contents) + +## [Mixed Precision](#contents) + +To ultilize the strong computation power of Ascend chip, and accelerate the evaluation process, the mixed evaluation method is used. MindSpore +is able to cope with FP32 inputs and FP16 operators. In TPRR example, the model is set to FP16 mode for the matmul calculation part. + +# [Environment Requirements](#contents) + +- Hardware (Ascend) +- Framework + - [MindSpore](https://www.mindspore.cn/install/en) +- For more information, please check the resources below: + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + +# [Quick Start](#contents) + +After installing MindSpore via the official website and Dataset is correctly generated, you can start training and evaluation as follows. + +- running on Ascend + + ```python + # run evaluation example with HotPotQA dev dataset + sh run_eval_ascend.sh + ``` + +# [Script Description](#contents) + +## [Script and Sample Code](#contents) + +```shell +. +└─tprr + ├─README.md + ├─scripts + | ├─run_eval_ascend.sh # Launch evaluation in ascend + | + ├─src + | ├─config.py # Evaluation configurations + | ├─onehop.py # Onehop model + | ├─onehop_bert.py # Onehop bert model + | ├─process_data.py # Data preprocessing + | ├─twohop.py # Twohop model + | ├─twohop_bert.py # Twohop bert model + | └─utils.py # Utils for evaluation + | + └─retriever_eval.py # Evaluation net for retriever +``` + +## [Script Parameters](#contents) + +Parameters for evaluation can be set in config.py. + +- config for TPRR retriever dataset + + ```python + "q_len": 64, # Max query length + "d_len": 192, # Max doc length + "s_len": 448, # Max sequence length + "in_len": 768, # Input dim + "out_len": 1, # Output dim + "num_docs": 500, # Num of docs + "topk": 8, # Top k + "onehop_num": 8 # Num of onehop doc as twohop neighbor + ``` + + config.py for more configuration. + +## [Evaluation Process](#contents) + +### Evaluation + +- Evaluation on Ascend + + ```python + sh run_eval_ascend.sh + ``` + + Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the + followings in log. + + ```python + ###step###: 0 + val: 0 + count: 1 + true count: 0 + PEM: 0.0 + + ... + ###step###: 7396 + val:6796 + count:7397 + true count: 6924 + PEM: 0.9187508449371367 + true top8 PEM: 0.9815135759676488 + evaluation time (h): 20.155506462653477 + ``` + +# [Model Description](#contents) + +## [Performance](#contents) + +### Inference Performance + +| Parameter | BGCF Ascend | +| ------------------------------ | ---------------------------- | +| Model Version | Inception V1 | +| Resource | Ascend 910 | +| uploaded Date | 03/12/2021(month/day/year) | +| MindSpore Version | 1.2.0 | +| Dataset | HotPotQA | +| Batch_size | 1 | +| Output | inference path | +| PEM | 0.9188 | + +# [Description of random situation](#contents) + +No random situation for evaluation. + +# [ModelZoo Homepage](#contents) + +Please check the official [homepage](http://gitee.com/mindspore/mindspore/tree/master/model_zoo). \ No newline at end of file diff --git a/model_zoo/research/nlp/tprr/retriever_eval.py b/model_zoo/research/nlp/tprr/retriever_eval.py new file mode 100644 index 00000000000..b0b7a125043 --- /dev/null +++ b/model_zoo/research/nlp/tprr/retriever_eval.py @@ -0,0 +1,180 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Retriever Evaluation. + +""" + +import time +import json + +import numpy as np +from mindspore import Tensor +import mindspore.context as context +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from mindspore import load_checkpoint, load_param_into_net + +from src.onehop import OneHopBert +from src.twohop import TwoHopBert +from src.process_data import DataGen +from src.onehop_bert import ModelOneHop +from src.twohop_bert import ModelTwoHop +from src.config import ThinkRetrieverConfig +from src.utils import read_query, split_queries, get_new_title, get_raw_title, save_json + + +def eval_output(out_2, last_out, path_raw, gold_path, val, true_count): + """evaluation output""" + y_pred_raw = out_2.asnumpy() + last_out_raw = last_out.asnumpy() + path = [] + y_pred = [] + last_out_list = [] + topk_titles = [] + index_list_raw = np.argsort(y_pred_raw) + for index_r in index_list_raw[::-1]: + tag = 1 + for raw_path in path: + if path_raw[index_r][0] in raw_path and path_raw[index_r][1] in raw_path: + tag = 0 + break + if tag: + path.append(path_raw[index_r]) + y_pred.append(y_pred_raw[index_r]) + last_out_list.append(last_out_raw[index_r]) + index_list = np.argsort(y_pred) + for path_index in index_list: + if gold_path[0] in path[path_index] and gold_path[1] in path[path_index]: + true_count += 1 + break + for path_index in index_list[-8:][::-1]: + topk_titles.append(list(path[path_index])) + for path_index in index_list[-8:]: + if gold_path[0] in path[path_index] and gold_path[1] in path[path_index]: + val += 1 + break + return val, true_count, topk_titles + + +def evaluation(): + """evaluation""" + print('********************** loading corpus ********************** ') + s_lc = time.time() + data_generator = DataGen(config) + queries = read_query(config) + print("loading corpus time (h):", (time.time() - s_lc) / 3600) + print('********************** loading model ********************** ') + s_lm = time.time() + + model_onehop_bert = ModelOneHop() + param_dict = load_checkpoint(config.onehop_bert_path) + load_param_into_net(model_onehop_bert, param_dict) + model_twohop_bert = ModelTwoHop() + param_dict2 = load_checkpoint(config.twohop_bert_path) + load_param_into_net(model_twohop_bert, param_dict2) + onehop = OneHopBert(config, model_onehop_bert) + twohop = TwoHopBert(config, model_twohop_bert) + + print("loading model time (h):", (time.time() - s_lm) / 3600) + print('********************** evaluation ********************** ') + s_tr = time.time() + + f_dev = open(config.dev_path, 'rb') + dev_data = json.load(f_dev) + q_gold = {} + q_2id = {} + for onedata in dev_data: + if onedata["question"] not in q_gold: + q_gold[onedata["question"]] = [get_new_title(get_raw_title(item)) for item in onedata['path']] + q_2id[onedata["question"]] = onedata['_id'] + val, true_count, count, step = 0, 0, 0, 0 + batch_queries = split_queries(config, queries)[:-1] + output_path = [] + for _, batch in enumerate(batch_queries): + print("###step###: ", step) + query = batch[0] + temp_dict = {} + temp_dict['q_id'] = q_2id[query] + temp_dict['question'] = query + gold_path = q_gold[query] + input_ids_1, token_type_ids_1, input_mask_1 = data_generator.convert_onehop_to_features(batch) + start = 0 + TOTAL = len(input_ids_1) + split_chunk = 8 + while start < TOTAL: + end = min(start + split_chunk - 1, TOTAL - 1) + chunk_len = end - start + 1 + input_ids_1_ = input_ids_1[start:start + chunk_len] + input_ids_1_ = Tensor(input_ids_1_, mstype.int32) + token_type_ids_1_ = token_type_ids_1[start:start + chunk_len] + token_type_ids_1_ = Tensor(token_type_ids_1_, mstype.int32) + input_mask_1_ = input_mask_1[start:start + chunk_len] + input_mask_1_ = Tensor(input_mask_1_, mstype.int32) + cls_out = onehop(input_ids_1_, token_type_ids_1_, input_mask_1_) + if start == 0: + out = cls_out + else: + out = P.Concat(0)((out, cls_out)) + start = end + 1 + out = P.Squeeze(1)(out) + onehop_prob, onehop_index = P.TopK(sorted=True)(out, config.topk) + onehop_prob = P.Softmax()(onehop_prob) + sample, path_raw, last_out = data_generator.get_samples(query, onehop_index, onehop_prob) + input_ids_2, token_type_ids_2, input_mask_2 = data_generator.convert_twohop_to_features(sample) + start_2 = 0 + TOTAL_2 = len(input_ids_2) + split_chunk = 8 + while start_2 < TOTAL_2: + end_2 = min(start_2 + split_chunk - 1, TOTAL_2 - 1) + chunk_len = end_2 - start_2 + 1 + input_ids_2_ = input_ids_2[start_2:start_2 + chunk_len] + input_ids_2_ = Tensor(input_ids_2_, mstype.int32) + token_type_ids_2_ = token_type_ids_2[start_2:start_2 + chunk_len] + token_type_ids_2_ = Tensor(token_type_ids_2_, mstype.int32) + input_mask_2_ = input_mask_2[start_2:start_2 + chunk_len] + input_mask_2_ = Tensor(input_mask_2_, mstype.int32) + cls_out = twohop(input_ids_2_, token_type_ids_2_, input_mask_2_) + if start_2 == 0: + out_2 = cls_out + else: + out_2 = P.Concat(0)((out_2, cls_out)) + start_2 = end_2 + 1 + out_2 = P.Softmax()(out_2) + last_out = Tensor(last_out, mstype.float32) + out_2 = P.Mul()(out_2, last_out) + val, true_count, topk_titles = eval_output(out_2, last_out, path_raw, gold_path, val, true_count) + temp_dict['topk_titles'] = topk_titles + output_path.append(temp_dict) + count += 1 + print("val:", val) + print("count:", count) + print("true count:", true_count) + if count: + print("PEM:", val / count) + if true_count: + print("true top8 PEM:", val / true_count) + step += 1 + save_json(output_path, config.save_path, config.save_name) + print("evaluation time (h):", (time.time() - s_tr) / 3600) + + +if __name__ == "__main__": + config = ThinkRetrieverConfig() + context.set_context(mode=context.GRAPH_MODE, + device_target='Ascend', + device_id=config.device_id, + save_graphs=False) + evaluation() diff --git a/model_zoo/research/nlp/tprr/scripts/run_eval_ascend.sh b/model_zoo/research/nlp/tprr/scripts/run_eval_ascend.sh new file mode 100644 index 00000000000..4d1f2e5c1fd --- /dev/null +++ b/model_zoo/research/nlp/tprr/scripts/run_eval_ascend.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# eval script + +ulimit -u unlimited +export DEVICE_NUM=1 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval + +cp ../*.py ./eval +cp *.sh ./eval +cp -r ../src ./eval +cd ./eval || exit +env > env.log +echo "start evaluation" + +python retriever_eval.py > log.txt 2>&1 & + +cd .. diff --git a/model_zoo/research/nlp/tprr/src/config.py b/model_zoo/research/nlp/tprr/src/config.py new file mode 100644 index 00000000000..9af4e5d46f7 --- /dev/null +++ b/model_zoo/research/nlp/tprr/src/config.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Retriever Config. + +""" +import argparse + + +def ThinkRetrieverConfig(): + """retriever config""" + parser = argparse.ArgumentParser() + parser.add_argument("--q_len", type=int, default=64, help="max query len") + parser.add_argument("--d_len", type=int, default=192, help="max doc len") + parser.add_argument("--s_len", type=int, default=448, help="max seq len") + parser.add_argument("--in_len", type=int, default=768, help="in len") + parser.add_argument("--out_len", type=int, default=1, help="out len") + parser.add_argument("--num_docs", type=int, default=500, help="docs num") + parser.add_argument("--topk", type=int, default=8, help="top num") + parser.add_argument("--onehop_num", type=int, default=8, help="onehop num") + parser.add_argument("--batch_size", type=int, default=1, help="batch size") + parser.add_argument("--device_id", type=int, default=0, help="device id") + parser.add_argument("--save_name", type=str, default='doc_path', help='name of output') + parser.add_argument("--save_path", type=str, default='./', help='path of output') + parser.add_argument("--vocab_path", type=str, default='./scripts/vocab.txt', help="vocab path") + parser.add_argument("--wiki_path", type=str, default='./scripts/db_docs_bidirection_new.pkl', help="wiki path") + parser.add_argument("--dev_path", type=str, default='./scripts/hotpot_dev_fullwiki_v1_for_retriever.json', + help="dev path") + parser.add_argument("--dev_data_path", type=str, default='./scripts/dev_tf_idf_data_raw.pkl', help="dev data path") + parser.add_argument("--onehop_bert_path", type=str, default='./scripts/onehop.ckpt', help="onehop bert ckpt path") + parser.add_argument("--onehop_mlp_path", type=str, default='./scripts/onehop_mlp.ckpt', help="onehop mlp ckpt path") + parser.add_argument("--twohop_bert_path", type=str, default='./scripts/twohop.ckpt', help="twohop bert ckpt path") + parser.add_argument("--twohop_mlp_path", type=str, default='./scripts/twohop_mlp.ckpt', help="twohop mlp ckpt path") + return parser.parse_args() diff --git a/model_zoo/research/nlp/tprr/src/onehop.py b/model_zoo/research/nlp/tprr/src/onehop.py new file mode 100644 index 00000000000..db9f9293bb6 --- /dev/null +++ b/model_zoo/research/nlp/tprr/src/onehop.py @@ -0,0 +1,58 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +One Hop Model. + +""" + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from mindspore import load_checkpoint, load_param_into_net + + +class Model(nn.Cell): + """mlp model""" + def __init__(self): + super(Model, self).__init__() + self.tanh_0 = nn.Tanh() + self.dense_1 = nn.Dense(in_channels=768, out_channels=1, has_bias=True) + + def construct(self, x): + """construct function""" + opt_tanh_0 = self.tanh_0(x) + opt_dense_1 = self.dense_1(opt_tanh_0) + return opt_dense_1 + + +class OneHopBert(nn.Cell): + """onehop model""" + def __init__(self, config, network): + super(OneHopBert, self).__init__(auto_prefix=False) + self.network = network + self.mlp = Model() + param_dict = load_checkpoint(config.onehop_mlp_path) + load_param_into_net(self.mlp, param_dict) + self.cast = P.Cast() + + def construct(self, + input_ids, + token_type_id, + input_mask): + """construct function""" + out = self.network(input_ids, token_type_id, input_mask) + out = self.mlp(out) + out = self.cast(out, mstype.float32) + return out diff --git a/model_zoo/research/nlp/tprr/src/onehop_bert.py b/model_zoo/research/nlp/tprr/src/onehop_bert.py new file mode 100644 index 00000000000..501c4160889 --- /dev/null +++ b/model_zoo/research/nlp/tprr/src/onehop_bert.py @@ -0,0 +1,302 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +One Hop BERT. + +""" + +import numpy as np + +from mindspore import nn +from mindspore import Tensor, Parameter +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P + +BATCH_SIZE = -1 + + +class LayerNorm(nn.Cell): + """layer norm""" + def __init__(self): + super(LayerNorm, self).__init__() + self.reducemean_0 = P.ReduceMean(keep_dims=True) + self.sub_1 = P.Sub() + self.cast_2 = P.Cast() + self.cast_2_to = mstype.float32 + self.pow_3 = P.Pow() + self.pow_3_input_weight = 2.0 + self.reducemean_4 = P.ReduceMean(keep_dims=True) + self.add_5 = P.Add() + self.add_5_bias = 9.999999960041972e-13 + self.sqrt_6 = P.Sqrt() + self.div_7 = P.Div() + self.mul_8 = P.Mul() + self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + self.add_9 = P.Add() + self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + + def construct(self, x): + """construct function""" + opt_reducemean_0 = self.reducemean_0(x, -1) + opt_sub_1 = self.sub_1(x, opt_reducemean_0) + opt_cast_2 = self.cast_2(opt_sub_1, self.cast_2_to) + opt_pow_3 = self.pow_3(opt_cast_2, self.pow_3_input_weight) + opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1) + opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias) + opt_sqrt_6 = self.sqrt_6(opt_add_5) + opt_div_7 = self.div_7(opt_sub_1, opt_sqrt_6) + opt_mul_8 = self.mul_8(opt_div_7, self.mul_8_w) + opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias) + return opt_add_9 + + +class MultiHeadAttn(nn.Cell): + """multi head attention layer""" + def __init__(self): + super(MultiHeadAttn, self).__init__() + self.matmul_0 = nn.MatMul() + self.matmul_0.to_float(mstype.float16) + self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) + self.matmul_1 = nn.MatMul() + self.matmul_1.to_float(mstype.float16) + self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) + self.matmul_2 = nn.MatMul() + self.matmul_2.to_float(mstype.float16) + self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) + self.add_3 = P.Add() + self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + self.add_4 = P.Add() + self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + self.add_5 = P.Add() + self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + self.reshape_6 = P.Reshape() + self.reshape_6_shape = tuple([BATCH_SIZE, 256, 12, 64]) + self.reshape_7 = P.Reshape() + self.reshape_7_shape = tuple([BATCH_SIZE, 256, 12, 64]) + self.reshape_8 = P.Reshape() + self.reshape_8_shape = tuple([BATCH_SIZE, 256, 12, 64]) + self.transpose_9 = P.Transpose() + self.transpose_10 = P.Transpose() + self.transpose_11 = P.Transpose() + self.matmul_12 = nn.MatMul() + self.matmul_12.to_float(mstype.float16) + self.div_13 = P.Div() + self.div_13_w = 8.0 + self.add_14 = P.Add() + self.softmax_15 = nn.Softmax(axis=3) + self.matmul_16 = nn.MatMul() + self.matmul_16.to_float(mstype.float16) + self.transpose_17 = P.Transpose() + self.reshape_18 = P.Reshape() + self.reshape_18_shape = tuple([BATCH_SIZE, 256, 768]) + self.matmul_19 = nn.MatMul() + self.matmul_19.to_float(mstype.float16) + self.matmul_19_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) + self.add_20 = P.Add() + self.add_20_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + + def construct(self, x, x0): + """construct function""" + opt_matmul_0 = self.matmul_0(x, self.matmul_0_w) + opt_matmul_1 = self.matmul_1(x, self.matmul_1_w) + opt_matmul_2 = self.matmul_2(x, self.matmul_2_w) + opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32) + opt_matmul_1 = P.Cast()(opt_matmul_1, mstype.float32) + opt_matmul_2 = P.Cast()(opt_matmul_2, mstype.float32) + opt_add_3 = self.add_3(opt_matmul_0, self.add_3_bias) + opt_add_4 = self.add_4(opt_matmul_1, self.add_4_bias) + opt_add_5 = self.add_5(opt_matmul_2, self.add_5_bias) + opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape) + opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape) + opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape) + opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3)) + opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1)) + opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3)) + opt_matmul_12 = self.matmul_12(opt_transpose_9, opt_transpose_10) + opt_matmul_12 = P.Cast()(opt_matmul_12, mstype.float32) + opt_div_13 = self.div_13(opt_matmul_12, self.div_13_w) + opt_add_14 = self.add_14(opt_div_13, x0) + opt_add_14 = P.Cast()(opt_add_14, mstype.float32) + opt_softmax_15 = self.softmax_15(opt_add_14) + opt_matmul_16 = self.matmul_16(opt_softmax_15, opt_transpose_11) + opt_matmul_16 = P.Cast()(opt_matmul_16, mstype.float32) + opt_transpose_17 = self.transpose_17(opt_matmul_16, (0, 2, 1, 3)) + opt_reshape_18 = self.reshape_18(opt_transpose_17, self.reshape_18_shape) + opt_matmul_19 = self.matmul_19(opt_reshape_18, self.matmul_19_w) + opt_matmul_19 = P.Cast()(opt_matmul_19, mstype.float32) + opt_add_20 = self.add_20(opt_matmul_19, self.add_20_bias) + return opt_add_20 + + +class Linear(nn.Cell): + """linear layer""" + def __init__(self, matmul_0_weight_shape, add_1_bias_shape): + super(Linear, self).__init__() + self.matmul_0 = nn.MatMul() + self.matmul_0.to_float(mstype.float16) + self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)), + name=None) + self.add_1 = P.Add() + self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None) + + def construct(self, x): + """construct function""" + opt_matmul_0 = self.matmul_0(x, self.matmul_0_w) + opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32) + opt_add_1 = self.add_1(opt_matmul_0, self.add_1_bias) + return opt_add_1 + + +class GeLU(nn.Cell): + """gelu layer""" + def __init__(self): + super(GeLU, self).__init__() + self.div_0 = P.Div() + self.div_0_w = 1.4142135381698608 + self.erf_1 = P.Erf() + self.add_2 = P.Add() + self.add_2_bias = 1.0 + self.mul_3 = P.Mul() + self.mul_4 = P.Mul() + self.mul_4_w = 0.5 + + def construct(self, x): + """construct function""" + opt_div_0 = self.div_0(x, self.div_0_w) + opt_erf_1 = self.erf_1(opt_div_0) + opt_add_2 = self.add_2(opt_erf_1, self.add_2_bias) + opt_mul_3 = self.mul_3(x, opt_add_2) + opt_mul_4 = self.mul_4(opt_mul_3, self.mul_4_w) + return opt_mul_4 + + +class TransformerLayer(nn.Cell): + """transformer layer""" + def __init__(self, linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape, + linear3_1_add_1_bias_shape): + super(TransformerLayer, self).__init__() + self.multiheadattn_0 = MultiHeadAttn() + self.add_0 = P.Add() + self.layernorm1_0 = LayerNorm() + self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape, + add_1_bias_shape=linear3_0_add_1_bias_shape) + self.gelu1_0 = GeLU() + self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape, + add_1_bias_shape=linear3_1_add_1_bias_shape) + self.add_1 = P.Add() + self.layernorm1_1 = LayerNorm() + + def construct(self, x, x0): + """construct function""" + multiheadattn_0_opt = self.multiheadattn_0(x, x0) + opt_add_0 = self.add_0(multiheadattn_0_opt, x) + layernorm1_0_opt = self.layernorm1_0(opt_add_0) + linear3_0_opt = self.linear3_0(layernorm1_0_opt) + gelu1_0_opt = self.gelu1_0(linear3_0_opt) + linear3_1_opt = self.linear3_1(gelu1_0_opt) + opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt) + layernorm1_1_opt = self.layernorm1_1(opt_add_1) + return layernorm1_1_opt + + +class Encoder1_4(nn.Cell): + """encoder layer""" + def __init__(self): + super(Encoder1_4, self).__init__() + self.module47_0 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), + linear3_0_add_1_bias_shape=(3072,), + linear3_1_matmul_0_weight_shape=(3072, 768), + linear3_1_add_1_bias_shape=(768,)) + self.module47_1 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), + linear3_0_add_1_bias_shape=(3072,), + linear3_1_matmul_0_weight_shape=(3072, 768), + linear3_1_add_1_bias_shape=(768,)) + self.module47_2 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), + linear3_0_add_1_bias_shape=(3072,), + linear3_1_matmul_0_weight_shape=(3072, 768), + linear3_1_add_1_bias_shape=(768,)) + self.module47_3 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), + linear3_0_add_1_bias_shape=(3072,), + linear3_1_matmul_0_weight_shape=(3072, 768), + linear3_1_add_1_bias_shape=(768,)) + + def construct(self, x, x0): + """construct function""" + module47_0_opt = self.module47_0(x, x0) + module47_1_opt = self.module47_1(module47_0_opt, x0) + module47_2_opt = self.module47_2(module47_1_opt, x0) + module47_3_opt = self.module47_3(module47_2_opt, x0) + return module47_3_opt + + +class ModelOneHop(nn.Cell): + """one hop layer""" + def __init__(self): + super(ModelOneHop, self).__init__() + self.expanddims_0 = P.ExpandDims() + self.expanddims_0_axis = 1 + self.expanddims_3 = P.ExpandDims() + self.expanddims_3_axis = 2 + self.cast_5 = P.Cast() + self.cast_5_to = mstype.float32 + self.sub_7 = P.Sub() + self.sub_7_bias = 1.0 + self.mul_9 = P.Mul() + self.mul_9_w = -10000.0 + self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30522, 768)).astype(np.float32)), + name=None) + self.gather_1_axis = 0 + self.gather_1 = P.Gather() + self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 768)).astype(np.float32)), name=None) + self.gather_2_axis = 0 + self.gather_2 = P.Gather() + self.add_4 = P.Add() + self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 256, 768)).astype(np.float32)), name=None) + self.add_6 = P.Add() + self.layernorm1_0 = LayerNorm() + self.module51_0 = Encoder1_4() + self.module51_1 = Encoder1_4() + self.module51_2 = Encoder1_4() + self.gather_643_input_weight = Tensor(np.array(0)) + self.gather_643_axis = 1 + self.gather_643 = P.Gather() + self.dense_644 = nn.Dense(in_channels=768, out_channels=768, has_bias=True) + self.tanh_645 = nn.Tanh() + + def construct(self, input_ids, token_type_ids, attention_mask): + """construct function""" + input_ids = self.cast_5(input_ids, mstype.int32) + token_type_ids = self.cast_5(token_type_ids, mstype.int32) + attention_mask = self.cast_5(attention_mask, mstype.int32) + opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis) + opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis) + opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to) + opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5) + opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w) + opt_gather_1_axis = self.gather_1_axis + opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis) + opt_gather_2_axis = self.gather_2_axis + opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis) + opt_add_4 = self.add_4(opt_gather_1, self.add_4_bias) + opt_add_6 = self.add_6(opt_add_4, opt_gather_2) + layernorm1_0_opt = self.layernorm1_0(opt_add_6) + module51_0_opt = self.module51_0(layernorm1_0_opt, opt_mul_9) + module51_1_opt = self.module51_1(module51_0_opt, opt_mul_9) + module51_2_opt = self.module51_2(module51_1_opt, opt_mul_9) + opt_gather_643_axis = self.gather_643_axis + opt_gather_643 = self.gather_643(module51_2_opt, self.gather_643_input_weight, opt_gather_643_axis) + opt_dense_644 = self.dense_644(opt_gather_643) + opt_tanh_645 = self.tanh_645(opt_dense_644) + return opt_tanh_645 diff --git a/model_zoo/research/nlp/tprr/src/process_data.py b/model_zoo/research/nlp/tprr/src/process_data.py new file mode 100644 index 00000000000..e8028d10361 --- /dev/null +++ b/model_zoo/research/nlp/tprr/src/process_data.py @@ -0,0 +1,254 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Process Data. + +""" + +import json +import pickle as pkl + +from transformers import BertTokenizer + +from src.utils import get_new_title, get_raw_title + + +class DataGen: + """data generator""" + + def __init__(self, config): + """init function""" + self.wiki_path = config.wiki_path + self.dev_path = config.dev_path + self.dev_data_path = config.dev_data_path + self.num_docs = config.num_docs + self.max_q_len = config.q_len + self.max_doc_len = config.d_len + self.max_seq_len2 = config.s_len + self.vocab = config.vocab_path + self.onehop_num = config.onehop_num + + self.data_db, self.dev_data, self.q_doc_text = self.load_data() + self.query2id, self.q_gold = self.process_data() + self.id2title, self.id2doc, self.query_id_list, self.id2query = self.load_id2() + + def load_data(self): + """load data""" + print('********************** loading data ********************** ') + f_wiki = open(self.wiki_path, 'rb') + f_train = open(self.dev_path, 'rb') + f_doc = open(self.dev_data_path, 'rb') + data_db = pkl.load(f_wiki, encoding="gbk") + dev_data = json.load(f_train) + q_doc_text = pkl.load(f_doc, encoding='gbk') + return data_db, dev_data, q_doc_text + + def process_data(self): + """process data""" + query2id = {} + q_gold = {} + for onedata in self.dev_data: + if onedata['question'] not in query2id: + q_gold[onedata['_id']] = {} + query2id[onedata['question']] = onedata['_id'] + gold_path = [] + for item in onedata['path']: + gold_path.append(get_raw_title(item)) + q_gold[onedata['_id']]['title'] = gold_path + gold_text = [] + for item in gold_path: + gold_text.append(self.data_db[get_new_title(item)]['text']) + q_gold[onedata['_id']]['text'] = gold_text + return query2id, q_gold + + def load_id2(self): + """load dev data""" + with open(self.dev_data_path, 'rb') as f: + temp_dev_dic = pkl.load(f, encoding='gbk') + id2title = {} + id2doc = {} + id2query = {} + query_id_list = [] + for q_id in temp_dev_dic: + id2title[q_id] = temp_dev_dic[q_id]['title'] + id2doc[q_id] = temp_dev_dic[q_id]['text'] + query_id_list.append(q_id) + id2query[q_id] = temp_dev_dic[q_id]['query'] + return id2title, id2doc, query_id_list, id2query + + def get_query2id(self, query): + """get query id""" + output_list = [] + for item in query: + output_list.append(self.query2id[item]) + return output_list + + def get_linked_text(self, title): + """get linked text""" + linked_title_list = [] + raw_title_list = self.data_db[get_new_title(title)]['linked_title'].split('\t') + for item in raw_title_list: + if item and self.data_db[get_new_title(item)].get("text"): + linked_title_list.append(get_new_title(item)) + output_twohop_list = [] + for item in linked_title_list: + output_twohop_list.append(self.data_db[get_new_title(item)]['text']) + return output_twohop_list, linked_title_list + + def convert_onehop_to_features(self, query, + cls_token='[CLS]', + sep_token='[SEP]', + pad_token=0): + """convert one hop data to features""" + query_id = self.get_query2id(query) + examples = [] + count = 0 + for item in query_id: + title_doc_list = [] + for i in range(len(self.q_doc_text[item]['text'][:self.num_docs])): + title_doc_list.append([query[count], self.q_doc_text[item]["text"][i]]) + examples += title_doc_list + count += 1 + + max_q_len = self.max_q_len + max_doc_len = self.max_doc_len + tokenizer = BertTokenizer.from_pretrained(self.vocab, do_lower_case=True) + input_ids_list = [] + token_type_ids_list = [] + attention_mask_list = [] + for _, example in enumerate(examples): + tokens_q = tokenizer.tokenize(example[0]) + tokens_d1 = tokenizer.tokenize(example[1]) + special_tokens_count = 2 + if len(tokens_q) > max_q_len - 1: + tokens_q = tokens_q[:(max_q_len - 1)] + if len(tokens_d1) > max_doc_len - special_tokens_count: + tokens_d1 = tokens_d1[:(max_doc_len - special_tokens_count)] + tokens_q = [cls_token] + tokens_q + tokens_d = [sep_token] + tokens_d += tokens_d1 + tokens_d += [sep_token] + + q_ids = tokenizer.convert_tokens_to_ids(tokens_q) + d_ids = tokenizer.convert_tokens_to_ids(tokens_d) + padding_length_d = max_doc_len - len(d_ids) + padding_length_q = max_q_len - len(q_ids) + input_ids = q_ids + ([pad_token] * padding_length_q) + d_ids + ([pad_token] * padding_length_d) + token_type_ids = [0] * max_q_len + token_type_ids += [1] * max_doc_len + attention_mask_id = [] + + for item in input_ids: + attention_mask_id.append(item != 0) + + input_ids_list.append(input_ids) + token_type_ids_list.append(token_type_ids) + attention_mask_list.append(attention_mask_id) + return input_ids_list, token_type_ids_list, attention_mask_list + + def convert_twohop_to_features(self, examples, + cls_token='[CLS]', + sep_token='[SEP]', + pad_token=0): + """convert two hop data to features""" + max_q_len = self.max_q_len + max_doc_len = self.max_doc_len + max_seq_len = self.max_seq_len2 + tokenizer = BertTokenizer.from_pretrained(self.vocab, do_lower_case=True) + input_ids_list = [] + token_type_ids_list = [] + attention_mask_list = [] + + for _, example in enumerate(examples): + tokens_q = tokenizer.tokenize(example[0]) + tokens_d1 = tokenizer.tokenize(example[1]) + tokens_d2 = tokenizer.tokenize(example[2]) + + special_tokens_count1 = 1 + special_tokens_count2 = 2 + + if len(tokens_q) > max_q_len - 1: + tokens_q = tokens_q[:(max_q_len - 1)] + if len(tokens_d1) > max_doc_len - special_tokens_count1: + tokens_d1 = tokens_d1[:(max_doc_len - special_tokens_count1)] + if len(tokens_d2) > max_doc_len - special_tokens_count2: + tokens_d2 = tokens_d2[:(max_doc_len - special_tokens_count2)] + tokens = [cls_token] + tokens_q + tokens += [sep_token] + tokens += tokens_d1 + tokens += [sep_token] + tokens += tokens_d2 + tokens += [sep_token] + + input_ids = tokenizer.convert_tokens_to_ids(tokens) + padding_length = max_seq_len - len(input_ids) + input_ids = input_ids + ([pad_token] * padding_length) + + token_type_ids = [0] * (len(tokens_q) + 1) + token_type_ids += [1] * (len(tokens_d1) + 1) + token_type_ids += [1] * (max_seq_len - len(token_type_ids)) + + attention_mask_id = [] + for item in input_ids: + attention_mask_id.append(item != 0) + + input_ids_list.append(input_ids) + token_type_ids_list.append(token_type_ids) + attention_mask_list.append(attention_mask_id) + return input_ids_list, token_type_ids_list, attention_mask_list + + def get_samples(self, query, onehop_index, onehop_prob): + """get samples""" + query = self.get_query2id([query]) + index_np = onehop_index.asnumpy() + onehop_prob = onehop_prob.asnumpy() + sample = [] + path = [] + last_out = [] + q_id = query[0] + q_text = self.id2query[q_id] + onehop_ids_list = index_np + + onehop_text_list = [] + onehop_title_list = [] + for ids in list(onehop_ids_list): + onehop_text_list.append(self.id2doc[q_id][ids]) + onehop_title_list.append(self.id2title[q_id][ids]) + twohop_text_list = [] + twohop_title_list = [] + for title in onehop_title_list: + two_hop_text, two_hop_title = self.get_linked_text(title) + twohop_text_list.append(two_hop_text[:1000]) + twohop_title_list.append(two_hop_title[:1000]) + d1_count = 0 + d2_count = 0 + tiny_sample = [] + tiny_path = [] + for i in range(1, self.onehop_num): + tiny_sample.append((q_text, onehop_text_list[0], onehop_text_list[i])) + tiny_path.append((get_new_title(onehop_title_list[0]), get_new_title(onehop_title_list[i]))) + last_out.append(onehop_prob[d1_count]) + for twohop_text_tiny_list in twohop_text_list: + for twohop_text in twohop_text_tiny_list: + tiny_sample.append((q_text, onehop_text_list[d1_count], twohop_text)) + last_out.append(onehop_prob[d1_count]) + d1_count += 1 + for twohop_title_tiny_list in twohop_title_list: + for twohop_title in twohop_title_tiny_list: + tiny_path.append((get_new_title(onehop_title_list[d2_count]), get_new_title(twohop_title))) + d2_count += 1 + sample += tiny_sample + path += tiny_path + return sample, path, last_out diff --git a/model_zoo/research/nlp/tprr/src/twohop.py b/model_zoo/research/nlp/tprr/src/twohop.py new file mode 100644 index 00000000000..1ed96ffefb7 --- /dev/null +++ b/model_zoo/research/nlp/tprr/src/twohop.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Two Hop Model. + +""" + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from mindspore import load_checkpoint, load_param_into_net + + +class Model(nn.Cell): + """mlp model""" + def __init__(self): + super(Model, self).__init__() + self.tanh_0 = nn.Tanh() + self.dense_1 = nn.Dense(in_channels=768, out_channels=1, has_bias=True) + + def construct(self, x): + """construct function""" + opt_tanh_0 = self.tanh_0(x) + opt_dense_1 = self.dense_1(opt_tanh_0) + return opt_dense_1 + + +class TwoHopBert(nn.Cell): + """two hop model""" + def __init__(self, config, network): + super(TwoHopBert, self).__init__(auto_prefix=False) + self.network = network + self.mlp = Model() + param_dict = load_checkpoint(config.twohop_mlp_path) + load_param_into_net(self.mlp, param_dict) + self.reshape = P.Reshape() + self.cast = P.Cast() + + def construct(self, + input_ids, + token_type_id, + input_mask): + """construct function""" + out = self.network(input_ids, token_type_id, input_mask) + out = self.mlp(out) + out = self.cast(out, mstype.float32) + out = self.reshape(out, (-1,)) + return out diff --git a/model_zoo/research/nlp/tprr/src/twohop_bert.py b/model_zoo/research/nlp/tprr/src/twohop_bert.py new file mode 100644 index 00000000000..cff80d9c203 --- /dev/null +++ b/model_zoo/research/nlp/tprr/src/twohop_bert.py @@ -0,0 +1,302 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Two Hop BERT. + +""" + +import numpy as np + +from mindspore import nn +from mindspore import Tensor, Parameter +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P + +BATCH_SIZE = -1 + + +class LayerNorm(nn.Cell): + """layer norm""" + def __init__(self): + super(LayerNorm, self).__init__() + self.reducemean_0 = P.ReduceMean(keep_dims=True) + self.sub_1 = P.Sub() + self.cast_2 = P.Cast() + self.cast_2_to = mstype.float32 + self.pow_3 = P.Pow() + self.pow_3_input_weight = 2.0 + self.reducemean_4 = P.ReduceMean(keep_dims=True) + self.add_5 = P.Add() + self.add_5_bias = 9.999999960041972e-13 + self.sqrt_6 = P.Sqrt() + self.div_7 = P.Div() + self.mul_8 = P.Mul() + self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + self.add_9 = P.Add() + self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + + def construct(self, x): + """construct function""" + opt_reducemean_0 = self.reducemean_0(x, -1) + opt_sub_1 = self.sub_1(x, opt_reducemean_0) + opt_cast_2 = self.cast_2(opt_sub_1, self.cast_2_to) + opt_pow_3 = self.pow_3(opt_cast_2, self.pow_3_input_weight) + opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1) + opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias) + opt_sqrt_6 = self.sqrt_6(opt_add_5) + opt_div_7 = self.div_7(opt_sub_1, opt_sqrt_6) + opt_mul_8 = self.mul_8(opt_div_7, self.mul_8_w) + opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias) + return opt_add_9 + + +class MultiHeadAttn(nn.Cell): + """multi head attention layer""" + def __init__(self): + super(MultiHeadAttn, self).__init__() + self.matmul_0 = nn.MatMul() + self.matmul_0.to_float(mstype.float16) + self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) + self.matmul_1 = nn.MatMul() + self.matmul_1.to_float(mstype.float16) + self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) + self.matmul_2 = nn.MatMul() + self.matmul_2.to_float(mstype.float16) + self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) + self.add_3 = P.Add() + self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + self.add_4 = P.Add() + self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + self.add_5 = P.Add() + self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + self.reshape_6 = P.Reshape() + self.reshape_6_shape = tuple([BATCH_SIZE, 448, 12, 64]) + self.reshape_7 = P.Reshape() + self.reshape_7_shape = tuple([BATCH_SIZE, 448, 12, 64]) + self.reshape_8 = P.Reshape() + self.reshape_8_shape = tuple([BATCH_SIZE, 448, 12, 64]) + self.transpose_9 = P.Transpose() + self.transpose_10 = P.Transpose() + self.transpose_11 = P.Transpose() + self.matmul_12 = nn.MatMul() + self.matmul_12.to_float(mstype.float16) + self.div_13 = P.Div() + self.div_13_w = 8.0 + self.add_14 = P.Add() + self.softmax_15 = nn.Softmax(axis=3) + self.matmul_16 = nn.MatMul() + self.matmul_16.to_float(mstype.float16) + self.transpose_17 = P.Transpose() + self.reshape_18 = P.Reshape() + self.reshape_18_shape = tuple([BATCH_SIZE, 448, 768]) + self.matmul_19 = nn.MatMul() + self.matmul_19.to_float(mstype.float16) + self.matmul_19_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) + self.add_20 = P.Add() + self.add_20_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) + + def construct(self, x, x0): + """construct function""" + opt_matmul_0 = self.matmul_0(x, self.matmul_0_w) + opt_matmul_1 = self.matmul_1(x, self.matmul_1_w) + opt_matmul_2 = self.matmul_2(x, self.matmul_2_w) + opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32) + opt_matmul_1 = P.Cast()(opt_matmul_1, mstype.float32) + opt_matmul_2 = P.Cast()(opt_matmul_2, mstype.float32) + opt_add_3 = self.add_3(opt_matmul_0, self.add_3_bias) + opt_add_4 = self.add_4(opt_matmul_1, self.add_4_bias) + opt_add_5 = self.add_5(opt_matmul_2, self.add_5_bias) + opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape) + opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape) + opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape) + opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3)) + opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1)) + opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3)) + opt_matmul_12 = self.matmul_12(opt_transpose_9, opt_transpose_10) + opt_matmul_12 = P.Cast()(opt_matmul_12, mstype.float32) + opt_div_13 = self.div_13(opt_matmul_12, self.div_13_w) + opt_add_14 = self.add_14(opt_div_13, x0) + opt_add_14 = P.Cast()(opt_add_14, mstype.float32) + opt_softmax_15 = self.softmax_15(opt_add_14) + opt_matmul_16 = self.matmul_16(opt_softmax_15, opt_transpose_11) + opt_matmul_16 = P.Cast()(opt_matmul_16, mstype.float32) + opt_transpose_17 = self.transpose_17(opt_matmul_16, (0, 2, 1, 3)) + opt_reshape_18 = self.reshape_18(opt_transpose_17, self.reshape_18_shape) + opt_matmul_19 = self.matmul_19(opt_reshape_18, self.matmul_19_w) + opt_matmul_19 = P.Cast()(opt_matmul_19, mstype.float32) + opt_add_20 = self.add_20(opt_matmul_19, self.add_20_bias) + return opt_add_20 + + +class Linear(nn.Cell): + """linear layer""" + def __init__(self, matmul_0_weight_shape, add_1_bias_shape): + super(Linear, self).__init__() + self.matmul_0 = nn.MatMul() + self.matmul_0.to_float(mstype.float16) + self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)), + name=None) + self.add_1 = P.Add() + self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None) + + def construct(self, x): + """construct function""" + opt_matmul_0 = self.matmul_0(x, self.matmul_0_w) + opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32) + opt_add_1 = self.add_1(opt_matmul_0, self.add_1_bias) + return opt_add_1 + + +class GeLU(nn.Cell): + """gelu layer""" + def __init__(self): + super(GeLU, self).__init__() + self.div_0 = P.Div() + self.div_0_w = 1.4142135381698608 + self.erf_1 = P.Erf() + self.add_2 = P.Add() + self.add_2_bias = 1.0 + self.mul_3 = P.Mul() + self.mul_4 = P.Mul() + self.mul_4_w = 0.5 + + def construct(self, x): + """construct function""" + opt_div_0 = self.div_0(x, self.div_0_w) + opt_erf_1 = self.erf_1(opt_div_0) + opt_add_2 = self.add_2(opt_erf_1, self.add_2_bias) + opt_mul_3 = self.mul_3(x, opt_add_2) + opt_mul_4 = self.mul_4(opt_mul_3, self.mul_4_w) + return opt_mul_4 + + +class TransformerLayer(nn.Cell): + """transformer layer""" + def __init__(self, linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape, + linear3_1_add_1_bias_shape): + super(TransformerLayer, self).__init__() + self.multiheadattn_0 = MultiHeadAttn() + self.add_0 = P.Add() + self.layernorm1_0 = LayerNorm() + self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape, + add_1_bias_shape=linear3_0_add_1_bias_shape) + self.gelu1_0 = GeLU() + self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape, + add_1_bias_shape=linear3_1_add_1_bias_shape) + self.add_1 = P.Add() + self.layernorm1_1 = LayerNorm() + + def construct(self, x, x0): + """construct function""" + multiheadattn_0_opt = self.multiheadattn_0(x, x0) + opt_add_0 = self.add_0(multiheadattn_0_opt, x) + layernorm1_0_opt = self.layernorm1_0(opt_add_0) + linear3_0_opt = self.linear3_0(layernorm1_0_opt) + gelu1_0_opt = self.gelu1_0(linear3_0_opt) + linear3_1_opt = self.linear3_1(gelu1_0_opt) + opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt) + layernorm1_1_opt = self.layernorm1_1(opt_add_1) + return layernorm1_1_opt + + +class Encoder1_4(nn.Cell): + """encoder layer""" + def __init__(self): + super(Encoder1_4, self).__init__() + self.module46_0 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), + linear3_0_add_1_bias_shape=(3072,), + linear3_1_matmul_0_weight_shape=(3072, 768), + linear3_1_add_1_bias_shape=(768,)) + self.module46_1 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), + linear3_0_add_1_bias_shape=(3072,), + linear3_1_matmul_0_weight_shape=(3072, 768), + linear3_1_add_1_bias_shape=(768,)) + self.module46_2 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), + linear3_0_add_1_bias_shape=(3072,), + linear3_1_matmul_0_weight_shape=(3072, 768), + linear3_1_add_1_bias_shape=(768,)) + self.module46_3 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), + linear3_0_add_1_bias_shape=(3072,), + linear3_1_matmul_0_weight_shape=(3072, 768), + linear3_1_add_1_bias_shape=(768,)) + + def construct(self, x, x0): + """construct function""" + module46_0_opt = self.module46_0(x, x0) + module46_1_opt = self.module46_1(module46_0_opt, x0) + module46_2_opt = self.module46_2(module46_1_opt, x0) + module46_3_opt = self.module46_3(module46_2_opt, x0) + return module46_3_opt + + +class ModelTwoHop(nn.Cell): + """two hop layer""" + def __init__(self): + super(ModelTwoHop, self).__init__() + self.expanddims_0 = P.ExpandDims() + self.expanddims_0_axis = 1 + self.expanddims_3 = P.ExpandDims() + self.expanddims_3_axis = 2 + self.cast_5 = P.Cast() + self.cast_5_to = mstype.float32 + self.sub_7 = P.Sub() + self.sub_7_bias = 1.0 + self.mul_9 = P.Mul() + self.mul_9_w = -10000.0 + self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30522, 768)).astype(np.float32)), + name=None) + self.gather_1_axis = 0 + self.gather_1 = P.Gather() + self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 768)).astype(np.float32)), name=None) + self.gather_2_axis = 0 + self.gather_2 = P.Gather() + self.add_4 = P.Add() + self.add_6 = P.Add() + self.add_6_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 448, 768)).astype(np.float32)), name=None) + self.layernorm1_0 = LayerNorm() + self.module50_0 = Encoder1_4() + self.module50_1 = Encoder1_4() + self.module50_2 = Encoder1_4() + self.gather_643_input_weight = Tensor(np.array(0)) + self.gather_643_axis = 1 + self.gather_643 = P.Gather() + self.dense_644 = nn.Dense(in_channels=768, out_channels=768, has_bias=True) + self.tanh_645 = nn.Tanh() + + def construct(self, input_ids, token_type_ids, attention_mask): + """construct function""" + input_ids = P.Cast()(input_ids, mstype.int32) + token_type_ids = P.Cast()(token_type_ids, mstype.int32) + attention_mask = P.Cast()(attention_mask, mstype.int32) + opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis) + opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis) + opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to) + opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5) + opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w) + opt_gather_1_axis = self.gather_1_axis + opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis) + opt_gather_2_axis = self.gather_2_axis + opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis) + opt_add_4 = self.add_4(opt_gather_1, opt_gather_2) + opt_add_6 = self.add_6(opt_add_4, self.add_6_bias) + layernorm1_0_opt = self.layernorm1_0(opt_add_6) + module50_0_opt = self.module50_0(layernorm1_0_opt, opt_mul_9) + module50_1_opt = self.module50_1(module50_0_opt, opt_mul_9) + module50_2_opt = self.module50_2(module50_1_opt, opt_mul_9) + opt_gather_643_axis = self.gather_643_axis + opt_gather_643 = self.gather_643(module50_2_opt, self.gather_643_input_weight, opt_gather_643_axis) + opt_dense_644 = self.dense_644(opt_gather_643) + opt_tanh_645 = self.tanh_645(opt_dense_644) + return opt_tanh_645 diff --git a/model_zoo/research/nlp/tprr/src/utils.py b/model_zoo/research/nlp/tprr/src/utils.py new file mode 100644 index 00000000000..cb3f94746e5 --- /dev/null +++ b/model_zoo/research/nlp/tprr/src/utils.py @@ -0,0 +1,62 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Retriever Utils. + +""" + +import json +import unicodedata +import pickle as pkl + + +def normalize(text): + """normalize text""" + text = unicodedata.normalize('NFD', text) + return text[0].capitalize() + text[1:] + + +def read_query(config): + """get query data""" + with open(config.dev_data_path, 'rb') as f: + temp_dic = pkl.load(f, encoding='gbk') + queries = [] + for item in temp_dic: + queries.append(temp_dic[item]["query"]) + return queries + + +def split_queries(config, queries): + batch_size = config.batch_size + batch_queries = [queries[i:i + batch_size] for i in range(0, len(queries), batch_size)] + return batch_queries + + +def save_json(obj, path, name): + with open(path + name, "w") as f: + return json.dump(obj, f) + +def get_new_title(title): + """get new title""" + if title[-2:] == "_0": + return normalize(title[:-2]) + "_0" + return normalize(title) + "_0" + + +def get_raw_title(title): + """get raw title""" + if title[-2:] == "_0": + return title[:-2] + return title