forked from mindspore-Ecosystem/mindspore
!13838 [ModelZoo]Add tprr model 8p version
From: @zhan_ke Reviewed-by: Signed-off-by:
This commit is contained in:
commit
0dd7003fd7
|
@ -19,6 +19,7 @@ Retriever Evaluation.
|
|||
|
||||
import time
|
||||
import json
|
||||
from multiprocessing import Pool
|
||||
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
|
@ -69,16 +70,20 @@ def eval_output(out_2, last_out, path_raw, gold_path, val, true_count):
|
|||
return val, true_count, topk_titles
|
||||
|
||||
|
||||
def evaluation():
|
||||
def evaluation(d_id):
|
||||
"""evaluation"""
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target='Ascend',
|
||||
device_id=d_id,
|
||||
save_graphs=False)
|
||||
print('********************** loading corpus ********************** ')
|
||||
s_lc = time.time()
|
||||
data_generator = DataGen(config)
|
||||
queries = read_query(config)
|
||||
queries = read_query(config, d_id)
|
||||
print("loading corpus time (h):", (time.time() - s_lc) / 3600)
|
||||
print('********************** loading model ********************** ')
|
||||
s_lm = time.time()
|
||||
|
||||
s_lm = time.time()
|
||||
model_onehop_bert = ModelOneHop()
|
||||
param_dict = load_checkpoint(config.onehop_bert_path)
|
||||
load_param_into_net(model_onehop_bert, param_dict)
|
||||
|
@ -90,10 +95,10 @@ def evaluation():
|
|||
|
||||
print("loading model time (h):", (time.time() - s_lm) / 3600)
|
||||
print('********************** evaluation ********************** ')
|
||||
s_tr = time.time()
|
||||
|
||||
f_dev = open(config.dev_path, 'rb')
|
||||
dev_data = json.load(f_dev)
|
||||
f_dev.close()
|
||||
q_gold = {}
|
||||
q_2id = {}
|
||||
for onedata in dev_data:
|
||||
|
@ -101,10 +106,10 @@ def evaluation():
|
|||
q_gold[onedata["question"]] = [get_new_title(get_raw_title(item)) for item in onedata['path']]
|
||||
q_2id[onedata["question"]] = onedata['_id']
|
||||
val, true_count, count, step = 0, 0, 0, 0
|
||||
batch_queries = split_queries(config, queries)[:-1]
|
||||
batch_queries = split_queries(config, queries)
|
||||
output_path = []
|
||||
for _, batch in enumerate(batch_queries):
|
||||
print("###step###: ", step)
|
||||
print("###step###: ", str(step) + "_" + str(d_id))
|
||||
query = batch[0]
|
||||
temp_dict = {}
|
||||
temp_dict['q_id'] = q_2id[query]
|
||||
|
@ -158,23 +163,36 @@ def evaluation():
|
|||
val, true_count, topk_titles = eval_output(out_2, last_out, path_raw, gold_path, val, true_count)
|
||||
temp_dict['topk_titles'] = topk_titles
|
||||
output_path.append(temp_dict)
|
||||
count += 1
|
||||
print("val:", val)
|
||||
print("count:", count)
|
||||
print("true count:", true_count)
|
||||
if count:
|
||||
print("PEM:", val / count)
|
||||
if true_count:
|
||||
print("true top8 PEM:", val / true_count)
|
||||
step += 1
|
||||
save_json(output_path, config.save_path, config.save_name)
|
||||
print("evaluation time (h):", (time.time() - s_tr) / 3600)
|
||||
count += 1
|
||||
return {'val': val, 'count': count, 'true_count': true_count, 'path': output_path}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
t_s = time.time()
|
||||
config = ThinkRetrieverConfig()
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target='Ascend',
|
||||
device_id=config.device_id,
|
||||
save_graphs=False)
|
||||
evaluation()
|
||||
pool = Pool(processes=config.device_num)
|
||||
results = []
|
||||
for device_id in range(config.device_num):
|
||||
results.append(pool.apply_async(evaluation, (device_id,)))
|
||||
|
||||
print("Waiting for all subprocess done...")
|
||||
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
val_all, true_count_all, count_all = 0, 0, 0
|
||||
output_path_all = []
|
||||
for res in results:
|
||||
output = res.get()
|
||||
val_all += output['val']
|
||||
count_all += output['count']
|
||||
true_count_all += output['true_count']
|
||||
output_path_all += output['path']
|
||||
print("val:", val_all)
|
||||
print("count:", count_all)
|
||||
print("true count:", true_count_all)
|
||||
print("PEM:", val_all / count_all)
|
||||
print("true top8 PEM:", val_all / true_count_all)
|
||||
save_json(output_path_all, config.save_path, config.save_name)
|
||||
print("evaluation time (h):", (time.time() - t_s) / 3600)
|
||||
|
|
|
@ -31,7 +31,7 @@ def ThinkRetrieverConfig():
|
|||
parser.add_argument("--topk", type=int, default=8, help="top num")
|
||||
parser.add_argument("--onehop_num", type=int, default=8, help="onehop num")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--device_num", type=int, default=8, help="device num")
|
||||
parser.add_argument("--save_name", type=str, default='doc_path', help='name of output')
|
||||
parser.add_argument("--save_path", type=str, default='../', help='path of output')
|
||||
parser.add_argument("--vocab_path", type=str, default='../vocab.txt', help="vocab path")
|
||||
|
@ -43,4 +43,5 @@ def ThinkRetrieverConfig():
|
|||
parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path")
|
||||
parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path")
|
||||
parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path")
|
||||
parser.add_argument("--q_path", type=str, default="../queries", help="queries data path")
|
||||
return parser.parse_args()
|
||||
|
|
|
@ -53,6 +53,9 @@ class DataGen:
|
|||
data_db = pkl.load(f_wiki, encoding="gbk")
|
||||
dev_data = json.load(f_train)
|
||||
q_doc_text = pkl.load(f_doc, encoding='gbk')
|
||||
f_wiki.close()
|
||||
f_train.close()
|
||||
f_doc.close()
|
||||
return data_db, dev_data, q_doc_text
|
||||
|
||||
def process_data(self):
|
||||
|
|
|
@ -28,13 +28,10 @@ def normalize(text):
|
|||
return text[0].capitalize() + text[1:]
|
||||
|
||||
|
||||
def read_query(config):
|
||||
def read_query(config, device_id):
|
||||
"""get query data"""
|
||||
with open(config.dev_data_path, 'rb') as f:
|
||||
temp_dic = pkl.load(f, encoding='gbk')
|
||||
queries = []
|
||||
for item in temp_dic:
|
||||
queries.append(temp_dic[item]["query"])
|
||||
with open(config.q_path + str(device_id), 'rb') as f:
|
||||
queries = pkl.load(f, encoding='gbk')
|
||||
return queries
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue