!13838 [ModelZoo] Add TPRR model 8p version

From: @zhan_ke
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-25 11:22:35 +08:00 committed by Gitee
commit 0dd7003fd7
4 changed files with 47 additions and 28 deletions

View File

@ -19,6 +19,7 @@ Retriever Evaluation.
import time
import json
from multiprocessing import Pool
import numpy as np
from mindspore import Tensor
@ -69,16 +70,20 @@ def eval_output(out_2, last_out, path_raw, gold_path, val, true_count):
return val, true_count, topk_titles
def evaluation():
def evaluation(d_id):
"""evaluation"""
context.set_context(mode=context.GRAPH_MODE,
device_target='Ascend',
device_id=d_id,
save_graphs=False)
print('********************** loading corpus ********************** ')
s_lc = time.time()
data_generator = DataGen(config)
queries = read_query(config)
queries = read_query(config, d_id)
print("loading corpus time (h):", (time.time() - s_lc) / 3600)
print('********************** loading model ********************** ')
s_lm = time.time()
s_lm = time.time()
model_onehop_bert = ModelOneHop()
param_dict = load_checkpoint(config.onehop_bert_path)
load_param_into_net(model_onehop_bert, param_dict)
@ -90,10 +95,10 @@ def evaluation():
print("loading model time (h):", (time.time() - s_lm) / 3600)
print('********************** evaluation ********************** ')
s_tr = time.time()
f_dev = open(config.dev_path, 'rb')
dev_data = json.load(f_dev)
f_dev.close()
q_gold = {}
q_2id = {}
for onedata in dev_data:
@ -101,10 +106,10 @@ def evaluation():
q_gold[onedata["question"]] = [get_new_title(get_raw_title(item)) for item in onedata['path']]
q_2id[onedata["question"]] = onedata['_id']
val, true_count, count, step = 0, 0, 0, 0
batch_queries = split_queries(config, queries)[:-1]
batch_queries = split_queries(config, queries)
output_path = []
for _, batch in enumerate(batch_queries):
print("###step###: ", step)
print("###step###: ", str(step) + "_" + str(d_id))
query = batch[0]
temp_dict = {}
temp_dict['q_id'] = q_2id[query]
@ -158,23 +163,36 @@ def evaluation():
val, true_count, topk_titles = eval_output(out_2, last_out, path_raw, gold_path, val, true_count)
temp_dict['topk_titles'] = topk_titles
output_path.append(temp_dict)
count += 1
print("val:", val)
print("count:", count)
print("true count:", true_count)
if count:
print("PEM:", val / count)
if true_count:
print("true top8 PEM:", val / true_count)
step += 1
save_json(output_path, config.save_path, config.save_name)
print("evaluation time (h):", (time.time() - s_tr) / 3600)
count += 1
return {'val': val, 'count': count, 'true_count': true_count, 'path': output_path}
if __name__ == "__main__":
    t_s = time.time()
    config = ThinkRetrieverConfig()
    # Fan evaluation out across Ascend devices: one worker process per
    # device, each running evaluation(device_id) on its own query shard.
    pool = Pool(processes=config.device_num)
    results = []
    for device_id in range(config.device_num):
        results.append(pool.apply_async(evaluation, (device_id,)))
    print("Waiting for all subprocess done...")
    pool.close()
    pool.join()
    # Aggregate the per-device partial results into global counters.
    val_all, true_count_all, count_all = 0, 0, 0
    output_path_all = []
    for res in results:
        output = res.get()
        val_all += output['val']
        count_all += output['count']
        true_count_all += output['true_count']
        output_path_all += output['path']
    print("val:", val_all)
    print("count:", count_all)
    print("true count:", true_count_all)
    # Guard the ratios: with an empty query set (count_all == 0 or
    # true_count_all == 0) the unguarded division raises ZeroDivisionError.
    if count_all:
        print("PEM:", val_all / count_all)
    if true_count_all:
        print("true top8 PEM:", val_all / true_count_all)
    save_json(output_path_all, config.save_path, config.save_name)
    print("evaluation time (h):", (time.time() - t_s) / 3600)

View File

@ -31,7 +31,7 @@ def ThinkRetrieverConfig():
parser.add_argument("--topk", type=int, default=8, help="top num")
parser.add_argument("--onehop_num", type=int, default=8, help="onehop num")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--device_num", type=int, default=8, help="device num")
parser.add_argument("--save_name", type=str, default='doc_path', help='name of output')
parser.add_argument("--save_path", type=str, default='../', help='path of output')
parser.add_argument("--vocab_path", type=str, default='../vocab.txt', help="vocab path")
@ -43,4 +43,5 @@ def ThinkRetrieverConfig():
parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path")
parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path")
parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path")
parser.add_argument("--q_path", type=str, default="../queries", help="queries data path")
return parser.parse_args()

View File

@ -53,6 +53,9 @@ class DataGen:
data_db = pkl.load(f_wiki, encoding="gbk")
dev_data = json.load(f_train)
q_doc_text = pkl.load(f_doc, encoding='gbk')
f_wiki.close()
f_train.close()
f_doc.close()
return data_db, dev_data, q_doc_text
def process_data(self):

View File

@ -28,13 +28,10 @@ def normalize(text):
return text[0].capitalize() + text[1:]
def read_query(config, device_id):
    """Load the pickled query list for one device shard.

    Queries are pre-split into per-device files named
    ``<q_path><device_id>`` so each worker process evaluates only
    its own shard.

    Args:
        config: parsed arguments; only ``config.q_path`` is read here.
        device_id (int): shard index appended to the shard file name.

    Returns:
        The unpickled queries object for this shard (presumably a list
        of query strings — confirm against the shard-writing code).
    """
    # NOTE(review): pickle.load is unsafe on untrusted files; these
    # shards are assumed to be produced by this project's own tooling.
    with open(config.q_path + str(device_id), 'rb') as f:
        queries = pkl.load(f, encoding='gbk')
    return queries