From f09087a68aac904f263c3ea8ad507c0950103f3c Mon Sep 17 00:00:00 2001 From: wanghua Date: Thu, 22 Oct 2020 19:51:37 +0800 Subject: [PATCH] update tinybert_model.py --- model_zoo/official/nlp/tinybert/run_task_distill.py | 2 +- model_zoo/official/nlp/tinybert/src/tinybert_model.py | 6 ++++-- model_zoo/official/nlp/tinybert/src/utils.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py index cd35bc5c346..430f27ac3b3 100644 --- a/model_zoo/official/nlp/tinybert/run_task_distill.py +++ b/model_zoo/official/nlp/tinybert/run_task_distill.py @@ -301,7 +301,7 @@ def do_eval_standalone(): input_data.append(data[i]) input_ids, input_mask, token_type_id, label_ids = input_data logits = eval_model(input_ids, token_type_id, input_mask) - callback.update(logits[3], label_ids) + callback.update(logits, label_ids) acc = callback.acc_num / callback.total_num print("======================================") print("============== acc is {}".format(acc)) diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_model.py b/model_zoo/official/nlp/tinybert/src/tinybert_model.py index 09504abcd8c..d802d4ba8a7 100644 --- a/model_zoo/official/nlp/tinybert/src/tinybert_model.py +++ b/model_zoo/official/nlp/tinybert/src/tinybert_model.py @@ -964,7 +964,7 @@ class BertModelCLS(nn.Cell): The returned output represents the final logits as the results of log_softmax is propotional to that of softmax. """ def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, - use_one_hot_embeddings=False, phase_type="teacher"): + use_one_hot_embeddings=False, phase_type="student"): super(BertModelCLS, self).__init__() self.bert = BertModel(config, is_training, use_one_hot_embeddings) self.cast = P.Cast() @@ -992,4 +992,6 @@ class BertModelCLS(nn.Cell): logits = self.dense_1(cls) logits = self.cast(logits, self.dtype) log_probs = self.log_softmax(logits) - return seq_output, att_output, logits, log_probs + if self._phase == 'train' or self.phase_type == "teacher": + return seq_output, att_output, logits, log_probs + return log_probs diff --git a/model_zoo/official/nlp/tinybert/src/utils.py b/model_zoo/official/nlp/tinybert/src/utils.py index 40b970aa8bb..2b2fd69c3f6 100644 --- a/model_zoo/official/nlp/tinybert/src/utils.py +++ b/model_zoo/official/nlp/tinybert/src/utils.py @@ -100,7 +100,7 @@ class EvalCallBack(Callback): input_ids, input_mask, token_type_id, label_ids = input_data self.network.set_train(False) logits = self.network(input_ids, token_type_id, input_mask) - callback.update(logits[3], label_ids) + callback.update(logits, label_ids) acc = callback.acc_num / callback.total_num with open("./eval.log", "a+") as f: f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(callback.acc_num, callback.total_num,