forked from mindspore-Ecosystem/mindspore
add tprr retriever model
This commit is contained in:
parent
c12abe7a46
commit
5d420716f9
|
@ -0,0 +1,164 @@
|
||||||
|
# Contents
|
||||||
|
|
||||||
|
- [Thinking Path Re-Ranker](#thinking-path-re-ranker)
|
||||||
|
- [Model Architecture](#model-architecture)
|
||||||
|
- [Dataset](#dataset)
|
||||||
|
- [Features](#features)
|
||||||
|
- [Mixed Precision](#mixed-precision)
|
||||||
|
- [Environment Requirements](#environment-requirements)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [Script Description](#script-description)
|
||||||
|
- [Script and Sample Code](#script-and-sample-code)
|
||||||
|
- [Script Parameters](#script-parameters)
|
||||||
|
- [Training Process](#training-process)
|
||||||
|
- [Training](#training)
|
||||||
|
- [Evaluation Process](#evaluation-process)
|
||||||
|
- [Evaluation](#evaluation)
|
||||||
|
- [Model Description](#model-description)
|
||||||
|
- [Performance](#performance)
|
||||||
|
- [Description of random situation](#description-of-random-situation)
|
||||||
|
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||||
|
|
||||||
|
# [Thinking Path Re-Ranker](#contents)
|
||||||
|
|
||||||
|
Thinking Path Re-Ranker(TPRR) was proposed in 2021 by Huawei Poisson Lab & Parallel Distributed Computing Lab. By incorporating the
|
||||||
|
retriever, reranker and reader modules, TPRR shows excellent performance on open-domain multi-hop question answering. Moreover, TPRR has won
|
||||||
|
the first place in the current HotpotQA official leaderboard. This is a example of evaluation of TPRR with HotPotQA dataset in MindSpore. More
|
||||||
|
importantly, this is the first open source version for TPRR.
|
||||||
|
|
||||||
|
# [Model Architecture](#contents)
|
||||||
|
|
||||||
|
Specially, TPRR contains three main modules. The first is retriever, which generate document sequences of each hop iteratively. The second
|
||||||
|
is reranker for selecting the best path from candidate paths generated by retriever. The last one is reader for extracting answer spans.
|
||||||
|
|
||||||
|
# [Dataset](#contents)
|
||||||
|
|
||||||
|
The retriever dataset consists of three parts:
|
||||||
|
Wikipedia data: the 2017 English Wikipedia dump version with bidirectional hyperlinks.
|
||||||
|
dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs.
|
||||||
|
dev tf-idf data: the candidates for each question in dev data which is originated from top-500 retrieved from 5M paragraphs of Wikipedia
|
||||||
|
through TF-IDF.
|
||||||
|
|
||||||
|
# [Features](#contents)
|
||||||
|
|
||||||
|
## [Mixed Precision](#contents)
|
||||||
|
|
||||||
|
To ultilize the strong computation power of Ascend chip, and accelerate the evaluation process, the mixed evaluation method is used. MindSpore
|
||||||
|
is able to cope with FP32 inputs and FP16 operators. In TPRR example, the model is set to FP16 mode for the matmul calculation part.
|
||||||
|
|
||||||
|
# [Environment Requirements](#contents)
|
||||||
|
|
||||||
|
- Hardware (Ascend)
|
||||||
|
- Framework
|
||||||
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
|
- For more information, please check the resources below:
|
||||||
|
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||||
|
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||||
|
|
||||||
|
# [Quick Start](#contents)
|
||||||
|
|
||||||
|
After installing MindSpore via the official website and Dataset is correctly generated, you can start training and evaluation as follows.
|
||||||
|
|
||||||
|
- running on Ascend
|
||||||
|
|
||||||
|
```python
|
||||||
|
# run evaluation example with HotPotQA dev dataset
|
||||||
|
sh run_eval_ascend.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
# [Script Description](#contents)
|
||||||
|
|
||||||
|
## [Script and Sample Code](#contents)
|
||||||
|
|
||||||
|
```shell
|
||||||
|
.
|
||||||
|
└─tprr
|
||||||
|
├─README.md
|
||||||
|
├─scripts
|
||||||
|
| ├─run_eval_ascend.sh # Launch evaluation in ascend
|
||||||
|
|
|
||||||
|
├─src
|
||||||
|
| ├─config.py # Evaluation configurations
|
||||||
|
| ├─onehop.py # Onehop model
|
||||||
|
| ├─onehop_bert.py # Onehop bert model
|
||||||
|
| ├─process_data.py # Data preprocessing
|
||||||
|
| ├─twohop.py # Twohop model
|
||||||
|
| ├─twohop_bert.py # Twohop bert model
|
||||||
|
| └─utils.py # Utils for evaluation
|
||||||
|
|
|
||||||
|
└─retriever_eval.py # Evaluation net for retriever
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Script Parameters](#contents)
|
||||||
|
|
||||||
|
Parameters for evaluation can be set in config.py.
|
||||||
|
|
||||||
|
- config for TPRR retriever dataset
|
||||||
|
|
||||||
|
```python
|
||||||
|
"q_len": 64, # Max query length
|
||||||
|
"d_len": 192, # Max doc length
|
||||||
|
"s_len": 448, # Max sequence length
|
||||||
|
"in_len": 768, # Input dim
|
||||||
|
"out_len": 1, # Output dim
|
||||||
|
"num_docs": 500, # Num of docs
|
||||||
|
"topk": 8, # Top k
|
||||||
|
"onehop_num": 8 # Num of onehop doc as twohop neighbor
|
||||||
|
```
|
||||||
|
|
||||||
|
config.py for more configuration.
|
||||||
|
|
||||||
|
## [Evaluation Process](#contents)
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
|
||||||
|
- Evaluation on Ascend
|
||||||
|
|
||||||
|
```python
|
||||||
|
sh run_eval_ascend.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the
|
||||||
|
followings in log.
|
||||||
|
|
||||||
|
```python
|
||||||
|
###step###: 0
|
||||||
|
val: 0
|
||||||
|
count: 1
|
||||||
|
true count: 0
|
||||||
|
PEM: 0.0
|
||||||
|
|
||||||
|
...
|
||||||
|
###step###: 7396
|
||||||
|
val:6796
|
||||||
|
count:7397
|
||||||
|
true count: 6924
|
||||||
|
PEM: 0.9187508449371367
|
||||||
|
true top8 PEM: 0.9815135759676488
|
||||||
|
evaluation time (h): 20.155506462653477
|
||||||
|
```
|
||||||
|
|
||||||
|
# [Model Description](#contents)
|
||||||
|
|
||||||
|
## [Performance](#contents)
|
||||||
|
|
||||||
|
### Inference Performance
|
||||||
|
|
||||||
|
| Parameter | BGCF Ascend |
|
||||||
|
| ------------------------------ | ---------------------------- |
|
||||||
|
| Model Version | Inception V1 |
|
||||||
|
| Resource | Ascend 910 |
|
||||||
|
| uploaded Date | 03/12/2021(month/day/year) |
|
||||||
|
| MindSpore Version | 1.2.0 |
|
||||||
|
| Dataset | HotPotQA |
|
||||||
|
| Batch_size | 1 |
|
||||||
|
| Output | inference path |
|
||||||
|
| PEM | 0.9188 |
|
||||||
|
|
||||||
|
# [Description of random situation](#contents)
|
||||||
|
|
||||||
|
No random situation for evaluation.
|
||||||
|
|
||||||
|
# [ModelZoo Homepage](#contents)
|
||||||
|
|
||||||
|
Please check the official [homepage](http://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,180 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Retriever Evaluation.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import Tensor
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
from src.onehop import OneHopBert
|
||||||
|
from src.twohop import TwoHopBert
|
||||||
|
from src.process_data import DataGen
|
||||||
|
from src.onehop_bert import ModelOneHop
|
||||||
|
from src.twohop_bert import ModelTwoHop
|
||||||
|
from src.config import ThinkRetrieverConfig
|
||||||
|
from src.utils import read_query, split_queries, get_new_title, get_raw_title, save_json
|
||||||
|
|
||||||
|
|
||||||
|
def eval_output(out_2, last_out, path_raw, gold_path, val, true_count):
|
||||||
|
"""evaluation output"""
|
||||||
|
y_pred_raw = out_2.asnumpy()
|
||||||
|
last_out_raw = last_out.asnumpy()
|
||||||
|
path = []
|
||||||
|
y_pred = []
|
||||||
|
last_out_list = []
|
||||||
|
topk_titles = []
|
||||||
|
index_list_raw = np.argsort(y_pred_raw)
|
||||||
|
for index_r in index_list_raw[::-1]:
|
||||||
|
tag = 1
|
||||||
|
for raw_path in path:
|
||||||
|
if path_raw[index_r][0] in raw_path and path_raw[index_r][1] in raw_path:
|
||||||
|
tag = 0
|
||||||
|
break
|
||||||
|
if tag:
|
||||||
|
path.append(path_raw[index_r])
|
||||||
|
y_pred.append(y_pred_raw[index_r])
|
||||||
|
last_out_list.append(last_out_raw[index_r])
|
||||||
|
index_list = np.argsort(y_pred)
|
||||||
|
for path_index in index_list:
|
||||||
|
if gold_path[0] in path[path_index] and gold_path[1] in path[path_index]:
|
||||||
|
true_count += 1
|
||||||
|
break
|
||||||
|
for path_index in index_list[-8:][::-1]:
|
||||||
|
topk_titles.append(list(path[path_index]))
|
||||||
|
for path_index in index_list[-8:]:
|
||||||
|
if gold_path[0] in path[path_index] and gold_path[1] in path[path_index]:
|
||||||
|
val += 1
|
||||||
|
break
|
||||||
|
return val, true_count, topk_titles
|
||||||
|
|
||||||
|
|
||||||
|
def evaluation():
|
||||||
|
"""evaluation"""
|
||||||
|
print('********************** loading corpus ********************** ')
|
||||||
|
s_lc = time.time()
|
||||||
|
data_generator = DataGen(config)
|
||||||
|
queries = read_query(config)
|
||||||
|
print("loading corpus time (h):", (time.time() - s_lc) / 3600)
|
||||||
|
print('********************** loading model ********************** ')
|
||||||
|
s_lm = time.time()
|
||||||
|
|
||||||
|
model_onehop_bert = ModelOneHop()
|
||||||
|
param_dict = load_checkpoint(config.onehop_bert_path)
|
||||||
|
load_param_into_net(model_onehop_bert, param_dict)
|
||||||
|
model_twohop_bert = ModelTwoHop()
|
||||||
|
param_dict2 = load_checkpoint(config.twohop_bert_path)
|
||||||
|
load_param_into_net(model_twohop_bert, param_dict2)
|
||||||
|
onehop = OneHopBert(config, model_onehop_bert)
|
||||||
|
twohop = TwoHopBert(config, model_twohop_bert)
|
||||||
|
|
||||||
|
print("loading model time (h):", (time.time() - s_lm) / 3600)
|
||||||
|
print('********************** evaluation ********************** ')
|
||||||
|
s_tr = time.time()
|
||||||
|
|
||||||
|
f_dev = open(config.dev_path, 'rb')
|
||||||
|
dev_data = json.load(f_dev)
|
||||||
|
q_gold = {}
|
||||||
|
q_2id = {}
|
||||||
|
for onedata in dev_data:
|
||||||
|
if onedata["question"] not in q_gold:
|
||||||
|
q_gold[onedata["question"]] = [get_new_title(get_raw_title(item)) for item in onedata['path']]
|
||||||
|
q_2id[onedata["question"]] = onedata['_id']
|
||||||
|
val, true_count, count, step = 0, 0, 0, 0
|
||||||
|
batch_queries = split_queries(config, queries)[:-1]
|
||||||
|
output_path = []
|
||||||
|
for _, batch in enumerate(batch_queries):
|
||||||
|
print("###step###: ", step)
|
||||||
|
query = batch[0]
|
||||||
|
temp_dict = {}
|
||||||
|
temp_dict['q_id'] = q_2id[query]
|
||||||
|
temp_dict['question'] = query
|
||||||
|
gold_path = q_gold[query]
|
||||||
|
input_ids_1, token_type_ids_1, input_mask_1 = data_generator.convert_onehop_to_features(batch)
|
||||||
|
start = 0
|
||||||
|
TOTAL = len(input_ids_1)
|
||||||
|
split_chunk = 8
|
||||||
|
while start < TOTAL:
|
||||||
|
end = min(start + split_chunk - 1, TOTAL - 1)
|
||||||
|
chunk_len = end - start + 1
|
||||||
|
input_ids_1_ = input_ids_1[start:start + chunk_len]
|
||||||
|
input_ids_1_ = Tensor(input_ids_1_, mstype.int32)
|
||||||
|
token_type_ids_1_ = token_type_ids_1[start:start + chunk_len]
|
||||||
|
token_type_ids_1_ = Tensor(token_type_ids_1_, mstype.int32)
|
||||||
|
input_mask_1_ = input_mask_1[start:start + chunk_len]
|
||||||
|
input_mask_1_ = Tensor(input_mask_1_, mstype.int32)
|
||||||
|
cls_out = onehop(input_ids_1_, token_type_ids_1_, input_mask_1_)
|
||||||
|
if start == 0:
|
||||||
|
out = cls_out
|
||||||
|
else:
|
||||||
|
out = P.Concat(0)((out, cls_out))
|
||||||
|
start = end + 1
|
||||||
|
out = P.Squeeze(1)(out)
|
||||||
|
onehop_prob, onehop_index = P.TopK(sorted=True)(out, config.topk)
|
||||||
|
onehop_prob = P.Softmax()(onehop_prob)
|
||||||
|
sample, path_raw, last_out = data_generator.get_samples(query, onehop_index, onehop_prob)
|
||||||
|
input_ids_2, token_type_ids_2, input_mask_2 = data_generator.convert_twohop_to_features(sample)
|
||||||
|
start_2 = 0
|
||||||
|
TOTAL_2 = len(input_ids_2)
|
||||||
|
split_chunk = 8
|
||||||
|
while start_2 < TOTAL_2:
|
||||||
|
end_2 = min(start_2 + split_chunk - 1, TOTAL_2 - 1)
|
||||||
|
chunk_len = end_2 - start_2 + 1
|
||||||
|
input_ids_2_ = input_ids_2[start_2:start_2 + chunk_len]
|
||||||
|
input_ids_2_ = Tensor(input_ids_2_, mstype.int32)
|
||||||
|
token_type_ids_2_ = token_type_ids_2[start_2:start_2 + chunk_len]
|
||||||
|
token_type_ids_2_ = Tensor(token_type_ids_2_, mstype.int32)
|
||||||
|
input_mask_2_ = input_mask_2[start_2:start_2 + chunk_len]
|
||||||
|
input_mask_2_ = Tensor(input_mask_2_, mstype.int32)
|
||||||
|
cls_out = twohop(input_ids_2_, token_type_ids_2_, input_mask_2_)
|
||||||
|
if start_2 == 0:
|
||||||
|
out_2 = cls_out
|
||||||
|
else:
|
||||||
|
out_2 = P.Concat(0)((out_2, cls_out))
|
||||||
|
start_2 = end_2 + 1
|
||||||
|
out_2 = P.Softmax()(out_2)
|
||||||
|
last_out = Tensor(last_out, mstype.float32)
|
||||||
|
out_2 = P.Mul()(out_2, last_out)
|
||||||
|
val, true_count, topk_titles = eval_output(out_2, last_out, path_raw, gold_path, val, true_count)
|
||||||
|
temp_dict['topk_titles'] = topk_titles
|
||||||
|
output_path.append(temp_dict)
|
||||||
|
count += 1
|
||||||
|
print("val:", val)
|
||||||
|
print("count:", count)
|
||||||
|
print("true count:", true_count)
|
||||||
|
if count:
|
||||||
|
print("PEM:", val / count)
|
||||||
|
if true_count:
|
||||||
|
print("true top8 PEM:", val / true_count)
|
||||||
|
step += 1
|
||||||
|
save_json(output_path, config.save_path, config.save_name)
|
||||||
|
print("evaluation time (h):", (time.time() - s_tr) / 3600)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
config = ThinkRetrieverConfig()
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
device_target='Ascend',
|
||||||
|
device_id=config.device_id,
|
||||||
|
save_graphs=False)
|
||||||
|
evaluation()
|
|
@ -0,0 +1,39 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# eval script
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export RANK_SIZE=$DEVICE_NUM
|
||||||
|
export RANK_ID=0
|
||||||
|
|
||||||
|
if [ -d "eval" ];
|
||||||
|
then
|
||||||
|
rm -rf ./eval
|
||||||
|
fi
|
||||||
|
mkdir ./eval
|
||||||
|
|
||||||
|
cp ../*.py ./eval
|
||||||
|
cp *.sh ./eval
|
||||||
|
cp -r ../src ./eval
|
||||||
|
cd ./eval || exit
|
||||||
|
env > env.log
|
||||||
|
echo "start evaluation"
|
||||||
|
|
||||||
|
python retriever_eval.py > log.txt 2>&1 &
|
||||||
|
|
||||||
|
cd ..
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Retriever Config.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def ThinkRetrieverConfig():
|
||||||
|
"""retriever config"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--q_len", type=int, default=64, help="max query len")
|
||||||
|
parser.add_argument("--d_len", type=int, default=192, help="max doc len")
|
||||||
|
parser.add_argument("--s_len", type=int, default=448, help="max seq len")
|
||||||
|
parser.add_argument("--in_len", type=int, default=768, help="in len")
|
||||||
|
parser.add_argument("--out_len", type=int, default=1, help="out len")
|
||||||
|
parser.add_argument("--num_docs", type=int, default=500, help="docs num")
|
||||||
|
parser.add_argument("--topk", type=int, default=8, help="top num")
|
||||||
|
parser.add_argument("--onehop_num", type=int, default=8, help="onehop num")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||||
|
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||||
|
parser.add_argument("--save_name", type=str, default='doc_path', help='name of output')
|
||||||
|
parser.add_argument("--save_path", type=str, default='./', help='path of output')
|
||||||
|
parser.add_argument("--vocab_path", type=str, default='./scripts/vocab.txt', help="vocab path")
|
||||||
|
parser.add_argument("--wiki_path", type=str, default='./scripts/db_docs_bidirection_new.pkl', help="wiki path")
|
||||||
|
parser.add_argument("--dev_path", type=str, default='./scripts/hotpot_dev_fullwiki_v1_for_retriever.json',
|
||||||
|
help="dev path")
|
||||||
|
parser.add_argument("--dev_data_path", type=str, default='./scripts/dev_tf_idf_data_raw.pkl', help="dev data path")
|
||||||
|
parser.add_argument("--onehop_bert_path", type=str, default='./scripts/onehop.ckpt', help="onehop bert ckpt path")
|
||||||
|
parser.add_argument("--onehop_mlp_path", type=str, default='./scripts/onehop_mlp.ckpt', help="onehop mlp ckpt path")
|
||||||
|
parser.add_argument("--twohop_bert_path", type=str, default='./scripts/twohop.ckpt', help="twohop bert ckpt path")
|
||||||
|
parser.add_argument("--twohop_mlp_path", type=str, default='./scripts/twohop_mlp.ckpt', help="twohop mlp ckpt path")
|
||||||
|
return parser.parse_args()
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
One Hop Model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Cell):
|
||||||
|
"""mlp model"""
|
||||||
|
def __init__(self):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
self.tanh_0 = nn.Tanh()
|
||||||
|
self.dense_1 = nn.Dense(in_channels=768, out_channels=1, has_bias=True)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_tanh_0 = self.tanh_0(x)
|
||||||
|
opt_dense_1 = self.dense_1(opt_tanh_0)
|
||||||
|
return opt_dense_1
|
||||||
|
|
||||||
|
|
||||||
|
class OneHopBert(nn.Cell):
|
||||||
|
"""onehop model"""
|
||||||
|
def __init__(self, config, network):
|
||||||
|
super(OneHopBert, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.mlp = Model()
|
||||||
|
param_dict = load_checkpoint(config.onehop_mlp_path)
|
||||||
|
load_param_into_net(self.mlp, param_dict)
|
||||||
|
self.cast = P.Cast()
|
||||||
|
|
||||||
|
def construct(self,
|
||||||
|
input_ids,
|
||||||
|
token_type_id,
|
||||||
|
input_mask):
|
||||||
|
"""construct function"""
|
||||||
|
out = self.network(input_ids, token_type_id, input_mask)
|
||||||
|
out = self.mlp(out)
|
||||||
|
out = self.cast(out, mstype.float32)
|
||||||
|
return out
|
|
@ -0,0 +1,302 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
One Hop BERT.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
BATCH_SIZE = -1
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Cell):
|
||||||
|
"""layer norm"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LayerNorm, self).__init__()
|
||||||
|
self.reducemean_0 = P.ReduceMean(keep_dims=True)
|
||||||
|
self.sub_1 = P.Sub()
|
||||||
|
self.cast_2 = P.Cast()
|
||||||
|
self.cast_2_to = mstype.float32
|
||||||
|
self.pow_3 = P.Pow()
|
||||||
|
self.pow_3_input_weight = 2.0
|
||||||
|
self.reducemean_4 = P.ReduceMean(keep_dims=True)
|
||||||
|
self.add_5 = P.Add()
|
||||||
|
self.add_5_bias = 9.999999960041972e-13
|
||||||
|
self.sqrt_6 = P.Sqrt()
|
||||||
|
self.div_7 = P.Div()
|
||||||
|
self.mul_8 = P.Mul()
|
||||||
|
self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
self.add_9 = P.Add()
|
||||||
|
self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_reducemean_0 = self.reducemean_0(x, -1)
|
||||||
|
opt_sub_1 = self.sub_1(x, opt_reducemean_0)
|
||||||
|
opt_cast_2 = self.cast_2(opt_sub_1, self.cast_2_to)
|
||||||
|
opt_pow_3 = self.pow_3(opt_cast_2, self.pow_3_input_weight)
|
||||||
|
opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1)
|
||||||
|
opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias)
|
||||||
|
opt_sqrt_6 = self.sqrt_6(opt_add_5)
|
||||||
|
opt_div_7 = self.div_7(opt_sub_1, opt_sqrt_6)
|
||||||
|
opt_mul_8 = self.mul_8(opt_div_7, self.mul_8_w)
|
||||||
|
opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias)
|
||||||
|
return opt_add_9
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttn(nn.Cell):
|
||||||
|
"""multi head attention layer"""
|
||||||
|
def __init__(self):
|
||||||
|
super(MultiHeadAttn, self).__init__()
|
||||||
|
self.matmul_0 = nn.MatMul()
|
||||||
|
self.matmul_0.to_float(mstype.float16)
|
||||||
|
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None)
|
||||||
|
self.matmul_1 = nn.MatMul()
|
||||||
|
self.matmul_1.to_float(mstype.float16)
|
||||||
|
self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None)
|
||||||
|
self.matmul_2 = nn.MatMul()
|
||||||
|
self.matmul_2.to_float(mstype.float16)
|
||||||
|
self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None)
|
||||||
|
self.add_3 = P.Add()
|
||||||
|
self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
self.add_4 = P.Add()
|
||||||
|
self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
self.add_5 = P.Add()
|
||||||
|
self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
self.reshape_6 = P.Reshape()
|
||||||
|
self.reshape_6_shape = tuple([BATCH_SIZE, 256, 12, 64])
|
||||||
|
self.reshape_7 = P.Reshape()
|
||||||
|
self.reshape_7_shape = tuple([BATCH_SIZE, 256, 12, 64])
|
||||||
|
self.reshape_8 = P.Reshape()
|
||||||
|
self.reshape_8_shape = tuple([BATCH_SIZE, 256, 12, 64])
|
||||||
|
self.transpose_9 = P.Transpose()
|
||||||
|
self.transpose_10 = P.Transpose()
|
||||||
|
self.transpose_11 = P.Transpose()
|
||||||
|
self.matmul_12 = nn.MatMul()
|
||||||
|
self.matmul_12.to_float(mstype.float16)
|
||||||
|
self.div_13 = P.Div()
|
||||||
|
self.div_13_w = 8.0
|
||||||
|
self.add_14 = P.Add()
|
||||||
|
self.softmax_15 = nn.Softmax(axis=3)
|
||||||
|
self.matmul_16 = nn.MatMul()
|
||||||
|
self.matmul_16.to_float(mstype.float16)
|
||||||
|
self.transpose_17 = P.Transpose()
|
||||||
|
self.reshape_18 = P.Reshape()
|
||||||
|
self.reshape_18_shape = tuple([BATCH_SIZE, 256, 768])
|
||||||
|
self.matmul_19 = nn.MatMul()
|
||||||
|
self.matmul_19.to_float(mstype.float16)
|
||||||
|
self.matmul_19_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None)
|
||||||
|
self.add_20 = P.Add()
|
||||||
|
self.add_20_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
|
||||||
|
def construct(self, x, x0):
|
||||||
|
"""construct function"""
|
||||||
|
opt_matmul_0 = self.matmul_0(x, self.matmul_0_w)
|
||||||
|
opt_matmul_1 = self.matmul_1(x, self.matmul_1_w)
|
||||||
|
opt_matmul_2 = self.matmul_2(x, self.matmul_2_w)
|
||||||
|
opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32)
|
||||||
|
opt_matmul_1 = P.Cast()(opt_matmul_1, mstype.float32)
|
||||||
|
opt_matmul_2 = P.Cast()(opt_matmul_2, mstype.float32)
|
||||||
|
opt_add_3 = self.add_3(opt_matmul_0, self.add_3_bias)
|
||||||
|
opt_add_4 = self.add_4(opt_matmul_1, self.add_4_bias)
|
||||||
|
opt_add_5 = self.add_5(opt_matmul_2, self.add_5_bias)
|
||||||
|
opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape)
|
||||||
|
opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape)
|
||||||
|
opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape)
|
||||||
|
opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3))
|
||||||
|
opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1))
|
||||||
|
opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3))
|
||||||
|
opt_matmul_12 = self.matmul_12(opt_transpose_9, opt_transpose_10)
|
||||||
|
opt_matmul_12 = P.Cast()(opt_matmul_12, mstype.float32)
|
||||||
|
opt_div_13 = self.div_13(opt_matmul_12, self.div_13_w)
|
||||||
|
opt_add_14 = self.add_14(opt_div_13, x0)
|
||||||
|
opt_add_14 = P.Cast()(opt_add_14, mstype.float32)
|
||||||
|
opt_softmax_15 = self.softmax_15(opt_add_14)
|
||||||
|
opt_matmul_16 = self.matmul_16(opt_softmax_15, opt_transpose_11)
|
||||||
|
opt_matmul_16 = P.Cast()(opt_matmul_16, mstype.float32)
|
||||||
|
opt_transpose_17 = self.transpose_17(opt_matmul_16, (0, 2, 1, 3))
|
||||||
|
opt_reshape_18 = self.reshape_18(opt_transpose_17, self.reshape_18_shape)
|
||||||
|
opt_matmul_19 = self.matmul_19(opt_reshape_18, self.matmul_19_w)
|
||||||
|
opt_matmul_19 = P.Cast()(opt_matmul_19, mstype.float32)
|
||||||
|
opt_add_20 = self.add_20(opt_matmul_19, self.add_20_bias)
|
||||||
|
return opt_add_20
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Cell):
|
||||||
|
"""linear layer"""
|
||||||
|
def __init__(self, matmul_0_weight_shape, add_1_bias_shape):
|
||||||
|
super(Linear, self).__init__()
|
||||||
|
self.matmul_0 = nn.MatMul()
|
||||||
|
self.matmul_0.to_float(mstype.float16)
|
||||||
|
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)),
|
||||||
|
name=None)
|
||||||
|
self.add_1 = P.Add()
|
||||||
|
self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_matmul_0 = self.matmul_0(x, self.matmul_0_w)
|
||||||
|
opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32)
|
||||||
|
opt_add_1 = self.add_1(opt_matmul_0, self.add_1_bias)
|
||||||
|
return opt_add_1
|
||||||
|
|
||||||
|
|
||||||
|
class GeLU(nn.Cell):
|
||||||
|
"""gelu layer"""
|
||||||
|
def __init__(self):
|
||||||
|
super(GeLU, self).__init__()
|
||||||
|
self.div_0 = P.Div()
|
||||||
|
self.div_0_w = 1.4142135381698608
|
||||||
|
self.erf_1 = P.Erf()
|
||||||
|
self.add_2 = P.Add()
|
||||||
|
self.add_2_bias = 1.0
|
||||||
|
self.mul_3 = P.Mul()
|
||||||
|
self.mul_4 = P.Mul()
|
||||||
|
self.mul_4_w = 0.5
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_div_0 = self.div_0(x, self.div_0_w)
|
||||||
|
opt_erf_1 = self.erf_1(opt_div_0)
|
||||||
|
opt_add_2 = self.add_2(opt_erf_1, self.add_2_bias)
|
||||||
|
opt_mul_3 = self.mul_3(x, opt_add_2)
|
||||||
|
opt_mul_4 = self.mul_4(opt_mul_3, self.mul_4_w)
|
||||||
|
return opt_mul_4
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerLayer(nn.Cell):
|
||||||
|
"""transformer layer"""
|
||||||
|
def __init__(self, linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape,
|
||||||
|
linear3_1_add_1_bias_shape):
|
||||||
|
super(TransformerLayer, self).__init__()
|
||||||
|
self.multiheadattn_0 = MultiHeadAttn()
|
||||||
|
self.add_0 = P.Add()
|
||||||
|
self.layernorm1_0 = LayerNorm()
|
||||||
|
self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape,
|
||||||
|
add_1_bias_shape=linear3_0_add_1_bias_shape)
|
||||||
|
self.gelu1_0 = GeLU()
|
||||||
|
self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape,
|
||||||
|
add_1_bias_shape=linear3_1_add_1_bias_shape)
|
||||||
|
self.add_1 = P.Add()
|
||||||
|
self.layernorm1_1 = LayerNorm()
|
||||||
|
|
||||||
|
def construct(self, x, x0):
|
||||||
|
"""construct function"""
|
||||||
|
multiheadattn_0_opt = self.multiheadattn_0(x, x0)
|
||||||
|
opt_add_0 = self.add_0(multiheadattn_0_opt, x)
|
||||||
|
layernorm1_0_opt = self.layernorm1_0(opt_add_0)
|
||||||
|
linear3_0_opt = self.linear3_0(layernorm1_0_opt)
|
||||||
|
gelu1_0_opt = self.gelu1_0(linear3_0_opt)
|
||||||
|
linear3_1_opt = self.linear3_1(gelu1_0_opt)
|
||||||
|
opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt)
|
||||||
|
layernorm1_1_opt = self.layernorm1_1(opt_add_1)
|
||||||
|
return layernorm1_1_opt
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder1_4(nn.Cell):
|
||||||
|
"""encoder layer"""
|
||||||
|
def __init__(self):
|
||||||
|
super(Encoder1_4, self).__init__()
|
||||||
|
self.module47_0 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072),
|
||||||
|
linear3_0_add_1_bias_shape=(3072,),
|
||||||
|
linear3_1_matmul_0_weight_shape=(3072, 768),
|
||||||
|
linear3_1_add_1_bias_shape=(768,))
|
||||||
|
self.module47_1 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072),
|
||||||
|
linear3_0_add_1_bias_shape=(3072,),
|
||||||
|
linear3_1_matmul_0_weight_shape=(3072, 768),
|
||||||
|
linear3_1_add_1_bias_shape=(768,))
|
||||||
|
self.module47_2 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072),
|
||||||
|
linear3_0_add_1_bias_shape=(3072,),
|
||||||
|
linear3_1_matmul_0_weight_shape=(3072, 768),
|
||||||
|
linear3_1_add_1_bias_shape=(768,))
|
||||||
|
self.module47_3 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072),
|
||||||
|
linear3_0_add_1_bias_shape=(3072,),
|
||||||
|
linear3_1_matmul_0_weight_shape=(3072, 768),
|
||||||
|
linear3_1_add_1_bias_shape=(768,))
|
||||||
|
|
||||||
|
def construct(self, x, x0):
|
||||||
|
"""construct function"""
|
||||||
|
module47_0_opt = self.module47_0(x, x0)
|
||||||
|
module47_1_opt = self.module47_1(module47_0_opt, x0)
|
||||||
|
module47_2_opt = self.module47_2(module47_1_opt, x0)
|
||||||
|
module47_3_opt = self.module47_3(module47_2_opt, x0)
|
||||||
|
return module47_3_opt
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOneHop(nn.Cell):
|
||||||
|
"""one hop layer"""
|
||||||
|
def __init__(self):
|
||||||
|
super(ModelOneHop, self).__init__()
|
||||||
|
self.expanddims_0 = P.ExpandDims()
|
||||||
|
self.expanddims_0_axis = 1
|
||||||
|
self.expanddims_3 = P.ExpandDims()
|
||||||
|
self.expanddims_3_axis = 2
|
||||||
|
self.cast_5 = P.Cast()
|
||||||
|
self.cast_5_to = mstype.float32
|
||||||
|
self.sub_7 = P.Sub()
|
||||||
|
self.sub_7_bias = 1.0
|
||||||
|
self.mul_9 = P.Mul()
|
||||||
|
self.mul_9_w = -10000.0
|
||||||
|
self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30522, 768)).astype(np.float32)),
|
||||||
|
name=None)
|
||||||
|
self.gather_1_axis = 0
|
||||||
|
self.gather_1 = P.Gather()
|
||||||
|
self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 768)).astype(np.float32)), name=None)
|
||||||
|
self.gather_2_axis = 0
|
||||||
|
self.gather_2 = P.Gather()
|
||||||
|
self.add_4 = P.Add()
|
||||||
|
self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 256, 768)).astype(np.float32)), name=None)
|
||||||
|
self.add_6 = P.Add()
|
||||||
|
self.layernorm1_0 = LayerNorm()
|
||||||
|
self.module51_0 = Encoder1_4()
|
||||||
|
self.module51_1 = Encoder1_4()
|
||||||
|
self.module51_2 = Encoder1_4()
|
||||||
|
self.gather_643_input_weight = Tensor(np.array(0))
|
||||||
|
self.gather_643_axis = 1
|
||||||
|
self.gather_643 = P.Gather()
|
||||||
|
self.dense_644 = nn.Dense(in_channels=768, out_channels=768, has_bias=True)
|
||||||
|
self.tanh_645 = nn.Tanh()
|
||||||
|
|
||||||
|
def construct(self, input_ids, token_type_ids, attention_mask):
|
||||||
|
"""construct function"""
|
||||||
|
input_ids = self.cast_5(input_ids, mstype.int32)
|
||||||
|
token_type_ids = self.cast_5(token_type_ids, mstype.int32)
|
||||||
|
attention_mask = self.cast_5(attention_mask, mstype.int32)
|
||||||
|
opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis)
|
||||||
|
opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis)
|
||||||
|
opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to)
|
||||||
|
opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5)
|
||||||
|
opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w)
|
||||||
|
opt_gather_1_axis = self.gather_1_axis
|
||||||
|
opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis)
|
||||||
|
opt_gather_2_axis = self.gather_2_axis
|
||||||
|
opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis)
|
||||||
|
opt_add_4 = self.add_4(opt_gather_1, self.add_4_bias)
|
||||||
|
opt_add_6 = self.add_6(opt_add_4, opt_gather_2)
|
||||||
|
layernorm1_0_opt = self.layernorm1_0(opt_add_6)
|
||||||
|
module51_0_opt = self.module51_0(layernorm1_0_opt, opt_mul_9)
|
||||||
|
module51_1_opt = self.module51_1(module51_0_opt, opt_mul_9)
|
||||||
|
module51_2_opt = self.module51_2(module51_1_opt, opt_mul_9)
|
||||||
|
opt_gather_643_axis = self.gather_643_axis
|
||||||
|
opt_gather_643 = self.gather_643(module51_2_opt, self.gather_643_input_weight, opt_gather_643_axis)
|
||||||
|
opt_dense_644 = self.dense_644(opt_gather_643)
|
||||||
|
opt_tanh_645 = self.tanh_645(opt_dense_644)
|
||||||
|
return opt_tanh_645
|
|
@ -0,0 +1,254 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Process Data.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pickle as pkl
|
||||||
|
|
||||||
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
|
from src.utils import get_new_title, get_raw_title
|
||||||
|
|
||||||
|
|
||||||
|
class DataGen:
|
||||||
|
"""data generator"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
"""init function"""
|
||||||
|
self.wiki_path = config.wiki_path
|
||||||
|
self.dev_path = config.dev_path
|
||||||
|
self.dev_data_path = config.dev_data_path
|
||||||
|
self.num_docs = config.num_docs
|
||||||
|
self.max_q_len = config.q_len
|
||||||
|
self.max_doc_len = config.d_len
|
||||||
|
self.max_seq_len2 = config.s_len
|
||||||
|
self.vocab = config.vocab_path
|
||||||
|
self.onehop_num = config.onehop_num
|
||||||
|
|
||||||
|
self.data_db, self.dev_data, self.q_doc_text = self.load_data()
|
||||||
|
self.query2id, self.q_gold = self.process_data()
|
||||||
|
self.id2title, self.id2doc, self.query_id_list, self.id2query = self.load_id2()
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
|
"""load data"""
|
||||||
|
print('********************** loading data ********************** ')
|
||||||
|
f_wiki = open(self.wiki_path, 'rb')
|
||||||
|
f_train = open(self.dev_path, 'rb')
|
||||||
|
f_doc = open(self.dev_data_path, 'rb')
|
||||||
|
data_db = pkl.load(f_wiki, encoding="gbk")
|
||||||
|
dev_data = json.load(f_train)
|
||||||
|
q_doc_text = pkl.load(f_doc, encoding='gbk')
|
||||||
|
return data_db, dev_data, q_doc_text
|
||||||
|
|
||||||
|
def process_data(self):
|
||||||
|
"""process data"""
|
||||||
|
query2id = {}
|
||||||
|
q_gold = {}
|
||||||
|
for onedata in self.dev_data:
|
||||||
|
if onedata['question'] not in query2id:
|
||||||
|
q_gold[onedata['_id']] = {}
|
||||||
|
query2id[onedata['question']] = onedata['_id']
|
||||||
|
gold_path = []
|
||||||
|
for item in onedata['path']:
|
||||||
|
gold_path.append(get_raw_title(item))
|
||||||
|
q_gold[onedata['_id']]['title'] = gold_path
|
||||||
|
gold_text = []
|
||||||
|
for item in gold_path:
|
||||||
|
gold_text.append(self.data_db[get_new_title(item)]['text'])
|
||||||
|
q_gold[onedata['_id']]['text'] = gold_text
|
||||||
|
return query2id, q_gold
|
||||||
|
|
||||||
|
def load_id2(self):
|
||||||
|
"""load dev data"""
|
||||||
|
with open(self.dev_data_path, 'rb') as f:
|
||||||
|
temp_dev_dic = pkl.load(f, encoding='gbk')
|
||||||
|
id2title = {}
|
||||||
|
id2doc = {}
|
||||||
|
id2query = {}
|
||||||
|
query_id_list = []
|
||||||
|
for q_id in temp_dev_dic:
|
||||||
|
id2title[q_id] = temp_dev_dic[q_id]['title']
|
||||||
|
id2doc[q_id] = temp_dev_dic[q_id]['text']
|
||||||
|
query_id_list.append(q_id)
|
||||||
|
id2query[q_id] = temp_dev_dic[q_id]['query']
|
||||||
|
return id2title, id2doc, query_id_list, id2query
|
||||||
|
|
||||||
|
def get_query2id(self, query):
|
||||||
|
"""get query id"""
|
||||||
|
output_list = []
|
||||||
|
for item in query:
|
||||||
|
output_list.append(self.query2id[item])
|
||||||
|
return output_list
|
||||||
|
|
||||||
|
def get_linked_text(self, title):
|
||||||
|
"""get linked text"""
|
||||||
|
linked_title_list = []
|
||||||
|
raw_title_list = self.data_db[get_new_title(title)]['linked_title'].split('\t')
|
||||||
|
for item in raw_title_list:
|
||||||
|
if item and self.data_db[get_new_title(item)].get("text"):
|
||||||
|
linked_title_list.append(get_new_title(item))
|
||||||
|
output_twohop_list = []
|
||||||
|
for item in linked_title_list:
|
||||||
|
output_twohop_list.append(self.data_db[get_new_title(item)]['text'])
|
||||||
|
return output_twohop_list, linked_title_list
|
||||||
|
|
||||||
|
def convert_onehop_to_features(self, query,
|
||||||
|
cls_token='[CLS]',
|
||||||
|
sep_token='[SEP]',
|
||||||
|
pad_token=0):
|
||||||
|
"""convert one hop data to features"""
|
||||||
|
query_id = self.get_query2id(query)
|
||||||
|
examples = []
|
||||||
|
count = 0
|
||||||
|
for item in query_id:
|
||||||
|
title_doc_list = []
|
||||||
|
for i in range(len(self.q_doc_text[item]['text'][:self.num_docs])):
|
||||||
|
title_doc_list.append([query[count], self.q_doc_text[item]["text"][i]])
|
||||||
|
examples += title_doc_list
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
max_q_len = self.max_q_len
|
||||||
|
max_doc_len = self.max_doc_len
|
||||||
|
tokenizer = BertTokenizer.from_pretrained(self.vocab, do_lower_case=True)
|
||||||
|
input_ids_list = []
|
||||||
|
token_type_ids_list = []
|
||||||
|
attention_mask_list = []
|
||||||
|
for _, example in enumerate(examples):
|
||||||
|
tokens_q = tokenizer.tokenize(example[0])
|
||||||
|
tokens_d1 = tokenizer.tokenize(example[1])
|
||||||
|
special_tokens_count = 2
|
||||||
|
if len(tokens_q) > max_q_len - 1:
|
||||||
|
tokens_q = tokens_q[:(max_q_len - 1)]
|
||||||
|
if len(tokens_d1) > max_doc_len - special_tokens_count:
|
||||||
|
tokens_d1 = tokens_d1[:(max_doc_len - special_tokens_count)]
|
||||||
|
tokens_q = [cls_token] + tokens_q
|
||||||
|
tokens_d = [sep_token]
|
||||||
|
tokens_d += tokens_d1
|
||||||
|
tokens_d += [sep_token]
|
||||||
|
|
||||||
|
q_ids = tokenizer.convert_tokens_to_ids(tokens_q)
|
||||||
|
d_ids = tokenizer.convert_tokens_to_ids(tokens_d)
|
||||||
|
padding_length_d = max_doc_len - len(d_ids)
|
||||||
|
padding_length_q = max_q_len - len(q_ids)
|
||||||
|
input_ids = q_ids + ([pad_token] * padding_length_q) + d_ids + ([pad_token] * padding_length_d)
|
||||||
|
token_type_ids = [0] * max_q_len
|
||||||
|
token_type_ids += [1] * max_doc_len
|
||||||
|
attention_mask_id = []
|
||||||
|
|
||||||
|
for item in input_ids:
|
||||||
|
attention_mask_id.append(item != 0)
|
||||||
|
|
||||||
|
input_ids_list.append(input_ids)
|
||||||
|
token_type_ids_list.append(token_type_ids)
|
||||||
|
attention_mask_list.append(attention_mask_id)
|
||||||
|
return input_ids_list, token_type_ids_list, attention_mask_list
|
||||||
|
|
||||||
|
def convert_twohop_to_features(self, examples,
|
||||||
|
cls_token='[CLS]',
|
||||||
|
sep_token='[SEP]',
|
||||||
|
pad_token=0):
|
||||||
|
"""convert two hop data to features"""
|
||||||
|
max_q_len = self.max_q_len
|
||||||
|
max_doc_len = self.max_doc_len
|
||||||
|
max_seq_len = self.max_seq_len2
|
||||||
|
tokenizer = BertTokenizer.from_pretrained(self.vocab, do_lower_case=True)
|
||||||
|
input_ids_list = []
|
||||||
|
token_type_ids_list = []
|
||||||
|
attention_mask_list = []
|
||||||
|
|
||||||
|
for _, example in enumerate(examples):
|
||||||
|
tokens_q = tokenizer.tokenize(example[0])
|
||||||
|
tokens_d1 = tokenizer.tokenize(example[1])
|
||||||
|
tokens_d2 = tokenizer.tokenize(example[2])
|
||||||
|
|
||||||
|
special_tokens_count1 = 1
|
||||||
|
special_tokens_count2 = 2
|
||||||
|
|
||||||
|
if len(tokens_q) > max_q_len - 1:
|
||||||
|
tokens_q = tokens_q[:(max_q_len - 1)]
|
||||||
|
if len(tokens_d1) > max_doc_len - special_tokens_count1:
|
||||||
|
tokens_d1 = tokens_d1[:(max_doc_len - special_tokens_count1)]
|
||||||
|
if len(tokens_d2) > max_doc_len - special_tokens_count2:
|
||||||
|
tokens_d2 = tokens_d2[:(max_doc_len - special_tokens_count2)]
|
||||||
|
tokens = [cls_token] + tokens_q
|
||||||
|
tokens += [sep_token]
|
||||||
|
tokens += tokens_d1
|
||||||
|
tokens += [sep_token]
|
||||||
|
tokens += tokens_d2
|
||||||
|
tokens += [sep_token]
|
||||||
|
|
||||||
|
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
padding_length = max_seq_len - len(input_ids)
|
||||||
|
input_ids = input_ids + ([pad_token] * padding_length)
|
||||||
|
|
||||||
|
token_type_ids = [0] * (len(tokens_q) + 1)
|
||||||
|
token_type_ids += [1] * (len(tokens_d1) + 1)
|
||||||
|
token_type_ids += [1] * (max_seq_len - len(token_type_ids))
|
||||||
|
|
||||||
|
attention_mask_id = []
|
||||||
|
for item in input_ids:
|
||||||
|
attention_mask_id.append(item != 0)
|
||||||
|
|
||||||
|
input_ids_list.append(input_ids)
|
||||||
|
token_type_ids_list.append(token_type_ids)
|
||||||
|
attention_mask_list.append(attention_mask_id)
|
||||||
|
return input_ids_list, token_type_ids_list, attention_mask_list
|
||||||
|
|
||||||
|
def get_samples(self, query, onehop_index, onehop_prob):
|
||||||
|
"""get samples"""
|
||||||
|
query = self.get_query2id([query])
|
||||||
|
index_np = onehop_index.asnumpy()
|
||||||
|
onehop_prob = onehop_prob.asnumpy()
|
||||||
|
sample = []
|
||||||
|
path = []
|
||||||
|
last_out = []
|
||||||
|
q_id = query[0]
|
||||||
|
q_text = self.id2query[q_id]
|
||||||
|
onehop_ids_list = index_np
|
||||||
|
|
||||||
|
onehop_text_list = []
|
||||||
|
onehop_title_list = []
|
||||||
|
for ids in list(onehop_ids_list):
|
||||||
|
onehop_text_list.append(self.id2doc[q_id][ids])
|
||||||
|
onehop_title_list.append(self.id2title[q_id][ids])
|
||||||
|
twohop_text_list = []
|
||||||
|
twohop_title_list = []
|
||||||
|
for title in onehop_title_list:
|
||||||
|
two_hop_text, two_hop_title = self.get_linked_text(title)
|
||||||
|
twohop_text_list.append(two_hop_text[:1000])
|
||||||
|
twohop_title_list.append(two_hop_title[:1000])
|
||||||
|
d1_count = 0
|
||||||
|
d2_count = 0
|
||||||
|
tiny_sample = []
|
||||||
|
tiny_path = []
|
||||||
|
for i in range(1, self.onehop_num):
|
||||||
|
tiny_sample.append((q_text, onehop_text_list[0], onehop_text_list[i]))
|
||||||
|
tiny_path.append((get_new_title(onehop_title_list[0]), get_new_title(onehop_title_list[i])))
|
||||||
|
last_out.append(onehop_prob[d1_count])
|
||||||
|
for twohop_text_tiny_list in twohop_text_list:
|
||||||
|
for twohop_text in twohop_text_tiny_list:
|
||||||
|
tiny_sample.append((q_text, onehop_text_list[d1_count], twohop_text))
|
||||||
|
last_out.append(onehop_prob[d1_count])
|
||||||
|
d1_count += 1
|
||||||
|
for twohop_title_tiny_list in twohop_title_list:
|
||||||
|
for twohop_title in twohop_title_tiny_list:
|
||||||
|
tiny_path.append((get_new_title(onehop_title_list[d2_count]), get_new_title(twohop_title)))
|
||||||
|
d2_count += 1
|
||||||
|
sample += tiny_sample
|
||||||
|
path += tiny_path
|
||||||
|
return sample, path, last_out
|
|
@ -0,0 +1,60 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Two Hop Model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Cell):
|
||||||
|
"""mlp model"""
|
||||||
|
def __init__(self):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
self.tanh_0 = nn.Tanh()
|
||||||
|
self.dense_1 = nn.Dense(in_channels=768, out_channels=1, has_bias=True)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_tanh_0 = self.tanh_0(x)
|
||||||
|
opt_dense_1 = self.dense_1(opt_tanh_0)
|
||||||
|
return opt_dense_1
|
||||||
|
|
||||||
|
|
||||||
|
class TwoHopBert(nn.Cell):
|
||||||
|
"""two hop model"""
|
||||||
|
def __init__(self, config, network):
|
||||||
|
super(TwoHopBert, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.mlp = Model()
|
||||||
|
param_dict = load_checkpoint(config.twohop_mlp_path)
|
||||||
|
load_param_into_net(self.mlp, param_dict)
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
|
||||||
|
def construct(self,
|
||||||
|
input_ids,
|
||||||
|
token_type_id,
|
||||||
|
input_mask):
|
||||||
|
"""construct function"""
|
||||||
|
out = self.network(input_ids, token_type_id, input_mask)
|
||||||
|
out = self.mlp(out)
|
||||||
|
out = self.cast(out, mstype.float32)
|
||||||
|
out = self.reshape(out, (-1,))
|
||||||
|
return out
|
|
@ -0,0 +1,302 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Two Hop BERT.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
BATCH_SIZE = -1
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Cell):
|
||||||
|
"""layer norm"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LayerNorm, self).__init__()
|
||||||
|
self.reducemean_0 = P.ReduceMean(keep_dims=True)
|
||||||
|
self.sub_1 = P.Sub()
|
||||||
|
self.cast_2 = P.Cast()
|
||||||
|
self.cast_2_to = mstype.float32
|
||||||
|
self.pow_3 = P.Pow()
|
||||||
|
self.pow_3_input_weight = 2.0
|
||||||
|
self.reducemean_4 = P.ReduceMean(keep_dims=True)
|
||||||
|
self.add_5 = P.Add()
|
||||||
|
self.add_5_bias = 9.999999960041972e-13
|
||||||
|
self.sqrt_6 = P.Sqrt()
|
||||||
|
self.div_7 = P.Div()
|
||||||
|
self.mul_8 = P.Mul()
|
||||||
|
self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
self.add_9 = P.Add()
|
||||||
|
self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_reducemean_0 = self.reducemean_0(x, -1)
|
||||||
|
opt_sub_1 = self.sub_1(x, opt_reducemean_0)
|
||||||
|
opt_cast_2 = self.cast_2(opt_sub_1, self.cast_2_to)
|
||||||
|
opt_pow_3 = self.pow_3(opt_cast_2, self.pow_3_input_weight)
|
||||||
|
opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1)
|
||||||
|
opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias)
|
||||||
|
opt_sqrt_6 = self.sqrt_6(opt_add_5)
|
||||||
|
opt_div_7 = self.div_7(opt_sub_1, opt_sqrt_6)
|
||||||
|
opt_mul_8 = self.mul_8(opt_div_7, self.mul_8_w)
|
||||||
|
opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias)
|
||||||
|
return opt_add_9
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttn(nn.Cell):
|
||||||
|
"""multi head attention layer"""
|
||||||
|
def __init__(self):
|
||||||
|
super(MultiHeadAttn, self).__init__()
|
||||||
|
self.matmul_0 = nn.MatMul()
|
||||||
|
self.matmul_0.to_float(mstype.float16)
|
||||||
|
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None)
|
||||||
|
self.matmul_1 = nn.MatMul()
|
||||||
|
self.matmul_1.to_float(mstype.float16)
|
||||||
|
self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None)
|
||||||
|
self.matmul_2 = nn.MatMul()
|
||||||
|
self.matmul_2.to_float(mstype.float16)
|
||||||
|
self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None)
|
||||||
|
self.add_3 = P.Add()
|
||||||
|
self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
self.add_4 = P.Add()
|
||||||
|
self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
self.add_5 = P.Add()
|
||||||
|
self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
self.reshape_6 = P.Reshape()
|
||||||
|
self.reshape_6_shape = tuple([BATCH_SIZE, 448, 12, 64])
|
||||||
|
self.reshape_7 = P.Reshape()
|
||||||
|
self.reshape_7_shape = tuple([BATCH_SIZE, 448, 12, 64])
|
||||||
|
self.reshape_8 = P.Reshape()
|
||||||
|
self.reshape_8_shape = tuple([BATCH_SIZE, 448, 12, 64])
|
||||||
|
self.transpose_9 = P.Transpose()
|
||||||
|
self.transpose_10 = P.Transpose()
|
||||||
|
self.transpose_11 = P.Transpose()
|
||||||
|
self.matmul_12 = nn.MatMul()
|
||||||
|
self.matmul_12.to_float(mstype.float16)
|
||||||
|
self.div_13 = P.Div()
|
||||||
|
self.div_13_w = 8.0
|
||||||
|
self.add_14 = P.Add()
|
||||||
|
self.softmax_15 = nn.Softmax(axis=3)
|
||||||
|
self.matmul_16 = nn.MatMul()
|
||||||
|
self.matmul_16.to_float(mstype.float16)
|
||||||
|
self.transpose_17 = P.Transpose()
|
||||||
|
self.reshape_18 = P.Reshape()
|
||||||
|
self.reshape_18_shape = tuple([BATCH_SIZE, 448, 768])
|
||||||
|
self.matmul_19 = nn.MatMul()
|
||||||
|
self.matmul_19.to_float(mstype.float16)
|
||||||
|
self.matmul_19_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None)
|
||||||
|
self.add_20 = P.Add()
|
||||||
|
self.add_20_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None)
|
||||||
|
|
||||||
|
def construct(self, x, x0):
|
||||||
|
"""construct function"""
|
||||||
|
opt_matmul_0 = self.matmul_0(x, self.matmul_0_w)
|
||||||
|
opt_matmul_1 = self.matmul_1(x, self.matmul_1_w)
|
||||||
|
opt_matmul_2 = self.matmul_2(x, self.matmul_2_w)
|
||||||
|
opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32)
|
||||||
|
opt_matmul_1 = P.Cast()(opt_matmul_1, mstype.float32)
|
||||||
|
opt_matmul_2 = P.Cast()(opt_matmul_2, mstype.float32)
|
||||||
|
opt_add_3 = self.add_3(opt_matmul_0, self.add_3_bias)
|
||||||
|
opt_add_4 = self.add_4(opt_matmul_1, self.add_4_bias)
|
||||||
|
opt_add_5 = self.add_5(opt_matmul_2, self.add_5_bias)
|
||||||
|
opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape)
|
||||||
|
opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape)
|
||||||
|
opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape)
|
||||||
|
opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3))
|
||||||
|
opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1))
|
||||||
|
opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3))
|
||||||
|
opt_matmul_12 = self.matmul_12(opt_transpose_9, opt_transpose_10)
|
||||||
|
opt_matmul_12 = P.Cast()(opt_matmul_12, mstype.float32)
|
||||||
|
opt_div_13 = self.div_13(opt_matmul_12, self.div_13_w)
|
||||||
|
opt_add_14 = self.add_14(opt_div_13, x0)
|
||||||
|
opt_add_14 = P.Cast()(opt_add_14, mstype.float32)
|
||||||
|
opt_softmax_15 = self.softmax_15(opt_add_14)
|
||||||
|
opt_matmul_16 = self.matmul_16(opt_softmax_15, opt_transpose_11)
|
||||||
|
opt_matmul_16 = P.Cast()(opt_matmul_16, mstype.float32)
|
||||||
|
opt_transpose_17 = self.transpose_17(opt_matmul_16, (0, 2, 1, 3))
|
||||||
|
opt_reshape_18 = self.reshape_18(opt_transpose_17, self.reshape_18_shape)
|
||||||
|
opt_matmul_19 = self.matmul_19(opt_reshape_18, self.matmul_19_w)
|
||||||
|
opt_matmul_19 = P.Cast()(opt_matmul_19, mstype.float32)
|
||||||
|
opt_add_20 = self.add_20(opt_matmul_19, self.add_20_bias)
|
||||||
|
return opt_add_20
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Cell):
|
||||||
|
"""linear layer"""
|
||||||
|
def __init__(self, matmul_0_weight_shape, add_1_bias_shape):
|
||||||
|
super(Linear, self).__init__()
|
||||||
|
self.matmul_0 = nn.MatMul()
|
||||||
|
self.matmul_0.to_float(mstype.float16)
|
||||||
|
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)),
|
||||||
|
name=None)
|
||||||
|
self.add_1 = P.Add()
|
||||||
|
self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_matmul_0 = self.matmul_0(x, self.matmul_0_w)
|
||||||
|
opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32)
|
||||||
|
opt_add_1 = self.add_1(opt_matmul_0, self.add_1_bias)
|
||||||
|
return opt_add_1
|
||||||
|
|
||||||
|
|
||||||
|
class GeLU(nn.Cell):
|
||||||
|
"""gelu layer"""
|
||||||
|
def __init__(self):
|
||||||
|
super(GeLU, self).__init__()
|
||||||
|
self.div_0 = P.Div()
|
||||||
|
self.div_0_w = 1.4142135381698608
|
||||||
|
self.erf_1 = P.Erf()
|
||||||
|
self.add_2 = P.Add()
|
||||||
|
self.add_2_bias = 1.0
|
||||||
|
self.mul_3 = P.Mul()
|
||||||
|
self.mul_4 = P.Mul()
|
||||||
|
self.mul_4_w = 0.5
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_div_0 = self.div_0(x, self.div_0_w)
|
||||||
|
opt_erf_1 = self.erf_1(opt_div_0)
|
||||||
|
opt_add_2 = self.add_2(opt_erf_1, self.add_2_bias)
|
||||||
|
opt_mul_3 = self.mul_3(x, opt_add_2)
|
||||||
|
opt_mul_4 = self.mul_4(opt_mul_3, self.mul_4_w)
|
||||||
|
return opt_mul_4
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerLayer(nn.Cell):
|
||||||
|
"""transformer layer"""
|
||||||
|
def __init__(self, linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape,
|
||||||
|
linear3_1_add_1_bias_shape):
|
||||||
|
super(TransformerLayer, self).__init__()
|
||||||
|
self.multiheadattn_0 = MultiHeadAttn()
|
||||||
|
self.add_0 = P.Add()
|
||||||
|
self.layernorm1_0 = LayerNorm()
|
||||||
|
self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape,
|
||||||
|
add_1_bias_shape=linear3_0_add_1_bias_shape)
|
||||||
|
self.gelu1_0 = GeLU()
|
||||||
|
self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape,
|
||||||
|
add_1_bias_shape=linear3_1_add_1_bias_shape)
|
||||||
|
self.add_1 = P.Add()
|
||||||
|
self.layernorm1_1 = LayerNorm()
|
||||||
|
|
||||||
|
def construct(self, x, x0):
|
||||||
|
"""construct function"""
|
||||||
|
multiheadattn_0_opt = self.multiheadattn_0(x, x0)
|
||||||
|
opt_add_0 = self.add_0(multiheadattn_0_opt, x)
|
||||||
|
layernorm1_0_opt = self.layernorm1_0(opt_add_0)
|
||||||
|
linear3_0_opt = self.linear3_0(layernorm1_0_opt)
|
||||||
|
gelu1_0_opt = self.gelu1_0(linear3_0_opt)
|
||||||
|
linear3_1_opt = self.linear3_1(gelu1_0_opt)
|
||||||
|
opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt)
|
||||||
|
layernorm1_1_opt = self.layernorm1_1(opt_add_1)
|
||||||
|
return layernorm1_1_opt
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder1_4(nn.Cell):
|
||||||
|
"""encoder layer"""
|
||||||
|
def __init__(self):
|
||||||
|
super(Encoder1_4, self).__init__()
|
||||||
|
self.module46_0 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072),
|
||||||
|
linear3_0_add_1_bias_shape=(3072,),
|
||||||
|
linear3_1_matmul_0_weight_shape=(3072, 768),
|
||||||
|
linear3_1_add_1_bias_shape=(768,))
|
||||||
|
self.module46_1 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072),
|
||||||
|
linear3_0_add_1_bias_shape=(3072,),
|
||||||
|
linear3_1_matmul_0_weight_shape=(3072, 768),
|
||||||
|
linear3_1_add_1_bias_shape=(768,))
|
||||||
|
self.module46_2 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072),
|
||||||
|
linear3_0_add_1_bias_shape=(3072,),
|
||||||
|
linear3_1_matmul_0_weight_shape=(3072, 768),
|
||||||
|
linear3_1_add_1_bias_shape=(768,))
|
||||||
|
self.module46_3 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072),
|
||||||
|
linear3_0_add_1_bias_shape=(3072,),
|
||||||
|
linear3_1_matmul_0_weight_shape=(3072, 768),
|
||||||
|
linear3_1_add_1_bias_shape=(768,))
|
||||||
|
|
||||||
|
def construct(self, x, x0):
|
||||||
|
"""construct function"""
|
||||||
|
module46_0_opt = self.module46_0(x, x0)
|
||||||
|
module46_1_opt = self.module46_1(module46_0_opt, x0)
|
||||||
|
module46_2_opt = self.module46_2(module46_1_opt, x0)
|
||||||
|
module46_3_opt = self.module46_3(module46_2_opt, x0)
|
||||||
|
return module46_3_opt
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTwoHop(nn.Cell):
|
||||||
|
"""two hop layer"""
|
||||||
|
def __init__(self):
|
||||||
|
super(ModelTwoHop, self).__init__()
|
||||||
|
self.expanddims_0 = P.ExpandDims()
|
||||||
|
self.expanddims_0_axis = 1
|
||||||
|
self.expanddims_3 = P.ExpandDims()
|
||||||
|
self.expanddims_3_axis = 2
|
||||||
|
self.cast_5 = P.Cast()
|
||||||
|
self.cast_5_to = mstype.float32
|
||||||
|
self.sub_7 = P.Sub()
|
||||||
|
self.sub_7_bias = 1.0
|
||||||
|
self.mul_9 = P.Mul()
|
||||||
|
self.mul_9_w = -10000.0
|
||||||
|
self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30522, 768)).astype(np.float32)),
|
||||||
|
name=None)
|
||||||
|
self.gather_1_axis = 0
|
||||||
|
self.gather_1 = P.Gather()
|
||||||
|
self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 768)).astype(np.float32)), name=None)
|
||||||
|
self.gather_2_axis = 0
|
||||||
|
self.gather_2 = P.Gather()
|
||||||
|
self.add_4 = P.Add()
|
||||||
|
self.add_6 = P.Add()
|
||||||
|
self.add_6_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 448, 768)).astype(np.float32)), name=None)
|
||||||
|
self.layernorm1_0 = LayerNorm()
|
||||||
|
self.module50_0 = Encoder1_4()
|
||||||
|
self.module50_1 = Encoder1_4()
|
||||||
|
self.module50_2 = Encoder1_4()
|
||||||
|
self.gather_643_input_weight = Tensor(np.array(0))
|
||||||
|
self.gather_643_axis = 1
|
||||||
|
self.gather_643 = P.Gather()
|
||||||
|
self.dense_644 = nn.Dense(in_channels=768, out_channels=768, has_bias=True)
|
||||||
|
self.tanh_645 = nn.Tanh()
|
||||||
|
|
||||||
|
def construct(self, input_ids, token_type_ids, attention_mask):
|
||||||
|
"""construct function"""
|
||||||
|
input_ids = P.Cast()(input_ids, mstype.int32)
|
||||||
|
token_type_ids = P.Cast()(token_type_ids, mstype.int32)
|
||||||
|
attention_mask = P.Cast()(attention_mask, mstype.int32)
|
||||||
|
opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis)
|
||||||
|
opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis)
|
||||||
|
opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to)
|
||||||
|
opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5)
|
||||||
|
opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w)
|
||||||
|
opt_gather_1_axis = self.gather_1_axis
|
||||||
|
opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis)
|
||||||
|
opt_gather_2_axis = self.gather_2_axis
|
||||||
|
opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis)
|
||||||
|
opt_add_4 = self.add_4(opt_gather_1, opt_gather_2)
|
||||||
|
opt_add_6 = self.add_6(opt_add_4, self.add_6_bias)
|
||||||
|
layernorm1_0_opt = self.layernorm1_0(opt_add_6)
|
||||||
|
module50_0_opt = self.module50_0(layernorm1_0_opt, opt_mul_9)
|
||||||
|
module50_1_opt = self.module50_1(module50_0_opt, opt_mul_9)
|
||||||
|
module50_2_opt = self.module50_2(module50_1_opt, opt_mul_9)
|
||||||
|
opt_gather_643_axis = self.gather_643_axis
|
||||||
|
opt_gather_643 = self.gather_643(module50_2_opt, self.gather_643_input_weight, opt_gather_643_axis)
|
||||||
|
opt_dense_644 = self.dense_644(opt_gather_643)
|
||||||
|
opt_tanh_645 = self.tanh_645(opt_dense_644)
|
||||||
|
return opt_tanh_645
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Retriever Utils.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import unicodedata
|
||||||
|
import pickle as pkl
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(text):
|
||||||
|
"""normalize text"""
|
||||||
|
text = unicodedata.normalize('NFD', text)
|
||||||
|
return text[0].capitalize() + text[1:]
|
||||||
|
|
||||||
|
|
||||||
|
def read_query(config):
|
||||||
|
"""get query data"""
|
||||||
|
with open(config.dev_data_path, 'rb') as f:
|
||||||
|
temp_dic = pkl.load(f, encoding='gbk')
|
||||||
|
queries = []
|
||||||
|
for item in temp_dic:
|
||||||
|
queries.append(temp_dic[item]["query"])
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def split_queries(config, queries):
|
||||||
|
batch_size = config.batch_size
|
||||||
|
batch_queries = [queries[i:i + batch_size] for i in range(0, len(queries), batch_size)]
|
||||||
|
return batch_queries
|
||||||
|
|
||||||
|
|
||||||
|
def save_json(obj, path, name):
|
||||||
|
with open(path + name, "w") as f:
|
||||||
|
return json.dump(obj, f)
|
||||||
|
|
||||||
|
def get_new_title(title):
|
||||||
|
"""get new title"""
|
||||||
|
if title[-2:] == "_0":
|
||||||
|
return normalize(title[:-2]) + "_0"
|
||||||
|
return normalize(title) + "_0"
|
||||||
|
|
||||||
|
|
||||||
|
def get_raw_title(title):
|
||||||
|
"""get raw title"""
|
||||||
|
if title[-2:] == "_0":
|
||||||
|
return title[:-2]
|
||||||
|
return title
|
Loading…
Reference in New Issue