forked from mindspore-Ecosystem/mindspore
!7646 eval result only return probs
Merge pull request !7646 from wanghua/master
This commit is contained in:
commit
7c98803ad6
|
@ -301,7 +301,7 @@ def do_eval_standalone():
|
|||
input_data.append(data[i])
|
||||
input_ids, input_mask, token_type_id, label_ids = input_data
|
||||
logits = eval_model(input_ids, token_type_id, input_mask)
|
||||
callback.update(logits[3], label_ids)
|
||||
callback.update(logits, label_ids)
|
||||
acc = callback.acc_num / callback.total_num
|
||||
print("======================================")
|
||||
print("============== acc is {}".format(acc))
|
||||
|
|
|
@ -964,7 +964,7 @@ class BertModelCLS(nn.Cell):
|
|||
The returned output represents the final logits as the results of log_softmax is propotional to that of softmax.
|
||||
"""
|
||||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0,
|
||||
use_one_hot_embeddings=False, phase_type="teacher"):
|
||||
use_one_hot_embeddings=False, phase_type="student"):
|
||||
super(BertModelCLS, self).__init__()
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
|
@ -992,4 +992,6 @@ class BertModelCLS(nn.Cell):
|
|||
logits = self.dense_1(cls)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
log_probs = self.log_softmax(logits)
|
||||
if self._phase == 'train' or self.phase_type == "teacher":
|
||||
return seq_output, att_output, logits, log_probs
|
||||
return log_probs
|
||||
|
|
|
@ -100,7 +100,7 @@ class EvalCallBack(Callback):
|
|||
input_ids, input_mask, token_type_id, label_ids = input_data
|
||||
self.network.set_train(False)
|
||||
logits = self.network(input_ids, token_type_id, input_mask)
|
||||
callback.update(logits[3], label_ids)
|
||||
callback.update(logits, label_ids)
|
||||
acc = callback.acc_num / callback.total_num
|
||||
with open("./eval.log", "a+") as f:
|
||||
f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(callback.acc_num, callback.total_num,
|
||||
|
|
Loading…
Reference in New Issue