!7646 eval result only return probs

Merge pull request !7646 from wanghua/master
This commit is contained in:
mindspore-ci-bot 2020-10-27 19:53:51 +08:00 committed by Gitee
commit 7c98803ad6
3 changed files with 6 additions and 4 deletions

View File

@ -301,7 +301,7 @@ def do_eval_standalone():
input_data.append(data[i])
input_ids, input_mask, token_type_id, label_ids = input_data
logits = eval_model(input_ids, token_type_id, input_mask)
callback.update(logits[3], label_ids)
callback.update(logits, label_ids)
acc = callback.acc_num / callback.total_num
print("======================================")
print("============== acc is {}".format(acc))

View File

@ -964,7 +964,7 @@ class BertModelCLS(nn.Cell):
The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.
"""
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0,
use_one_hot_embeddings=False, phase_type="teacher"):
use_one_hot_embeddings=False, phase_type="student"):
super(BertModelCLS, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.cast = P.Cast()
@ -992,4 +992,6 @@ class BertModelCLS(nn.Cell):
logits = self.dense_1(cls)
logits = self.cast(logits, self.dtype)
log_probs = self.log_softmax(logits)
return seq_output, att_output, logits, log_probs
if self._phase == 'train' or self.phase_type == "teacher":
return seq_output, att_output, logits, log_probs
return log_probs

View File

@ -100,7 +100,7 @@ class EvalCallBack(Callback):
input_ids, input_mask, token_type_id, label_ids = input_data
self.network.set_train(False)
logits = self.network(input_ids, token_type_id, input_mask)
callback.update(logits[3], label_ids)
callback.update(logits, label_ids)
acc = callback.acc_num / callback.total_num
with open("./eval.log", "a+") as f:
f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(callback.acc_num, callback.total_num,