Support CPU tinybert and ner task

This commit is contained in:
zhaoting 2021-01-20 14:12:52 +08:00
parent 5d0490909d
commit c14c707475
8 changed files with 220 additions and 89 deletions

View File

@ -1,4 +1,4 @@
# Contents
# Contents
- [Contents](#contents)
- [TinyBERT Description](#tinybert-description)
@ -197,8 +197,9 @@ usage: run_general_task.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN
[--load_gd_ckpt_path LOAD_GD_CKPT_PATH]
[--load_td1_ckpt_path LOAD_TD1_CKPT_PATH]
[--train_data_dir TRAIN_DATA_DIR]
[--eval_data_dir EVAL_DATA_DIR]
[--eval_data_dir EVAL_DATA_DIR] [--task_type TASK_TYPE]
[--task_name TASK_NAME] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE]
[--assessment_method ASSESSMENT_METHOD]
options:
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
@ -217,7 +218,9 @@ options:
--load_td1_ckpt_path path to load checkpoint files which produced by task distill phase 1: PATH, default is ""
--train_data_dir path to train dataset directory: PATH, default is ""
--eval_data_dir path to eval dataset directory: PATH, default is ""
--task_name classification task: "SST-2" | "QNLI" | "MNLI", default is ""
--task_type task type: "classification" | "ner", default is "classification"
--task_name classification or ner task: "SST-2" | "QNLI" | "MNLI" | "TNEWS" | "CLUENER", default is ""
--assessment_method assessment method to do evaluation: acc | f1
--schema_dir path to schema.json file, PATH, default is ""
--dataset_type the dataset type which can be tfrecord/mindrecord, default is tfrecord
```
@ -249,6 +252,7 @@ Parameters for optimizer:
Parameters for bert network:
seq_length length of input sequence: N, default is 128
vocab_size size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
Usually, we use 21128 for CN vocabs and 30522 for EN vocabs according to the origin paper. Default is 30522
hidden_size size of bert encoder layers: N
num_hidden_layers number of hidden layers: N
num_attention_heads number of attention heads: N, default is 12

View File

@ -22,7 +22,7 @@ from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.td_config import td_student_net_cfg
from src.tinybert_model import BertModelCLS
from src.tinybert_model import BertModelCLS, BertModelNER
parser = argparse.ArgumentParser(description='tinybert task distill')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
@ -31,7 +31,10 @@ parser.add_argument("--file_name", type=str, default="tinybert", help="output fi
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
help="The type of the task to train.")
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
help="The name of the task to train.")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@ -43,7 +46,9 @@ DEFAULT_SEQ_LENGTH = 128
DEFAULT_BS = 32
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
"QNLI": {"num_labels": 2, "seq_length": 128},
"MNLI": {"num_labels": 3, "seq_length": 128}}
"MNLI": {"num_labels": 3, "seq_length": 128},
"TNEWS": {"num_labels": 15, "seq_length": 128},
"CLUENER": {"num_labels": 10, "seq_length": 128}}
class Task:
"""
@ -68,8 +73,13 @@ if __name__ == '__main__':
task = Task(args.task_name)
td_student_net_cfg.seq_length = task.seq_length
td_student_net_cfg.batch_size = DEFAULT_BS
if args.task_type == "classification":
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
elif args.task_type == "ner":
eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
else:
raise ValueError(f"Not support task type: {args.task_type}")
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
param_dict = load_checkpoint(args.ckpt_file)
new_param_dict = {}
for key, value in param_dict.items():

View File

@ -33,14 +33,10 @@ from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
def run_general_distill():
"""
run general distill
"""
def get_argument():
"""Tinybert general distill argument parser."""
parser = argparse.ArgumentParser(description='tinybert general distill')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
help="Run distribute, default is false.")
@ -61,20 +57,21 @@ def run_general_distill():
parser.add_argument("--dataset_type", type=str, default="tfrecord",
help="dataset type tfrecord/mindrecord, default is tfrecord")
args_opt = parser.parse_args()
return args_opt
def run_general_distill():
"""
run general distill
"""
args_opt = get_argument()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
reserve_class_name_in_scope=False)
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
context.set_context(reserve_class_name_in_scope=False)
context.set_context(device_id=args_opt.device_id)
save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
if args_opt.distribute == "true":
if args_opt.device_target == 'Ascend':
D.init()
@ -104,6 +101,14 @@ def run_general_distill():
# and the loss scale is not necessary
enable_loss_scale = False
if args_opt.device_target == "CPU":
logger.warning('CPU only support float32 temporarily, run with float32.')
bert_teacher_net_cfg.dtype = mstype.float32
bert_teacher_net_cfg.compute_type = mstype.float32
bert_student_net_cfg.dtype = mstype.float32
bert_student_net_cfg.compute_type = mstype.float32
enable_loss_scale = False
netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
teacher_ckpt=args_opt.load_teacher_ckpt_path,
student_config=bert_student_net_cfg,

View File

@ -28,10 +28,10 @@ from mindspore.nn.optim import AdamWeightDecay
from mindspore import log as logger
from src.dataset import create_tinybert_dataset, DataType
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
from src.assessment_method import Accuracy
from src.assessment_method import Accuracy, F1
from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
from src.tinybert_model import BertModelCLS
from src.tinybert_model import BertModelCLS, BertModelNER
_cur_dir = os.getcwd()
td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
@ -46,7 +46,7 @@ def parse_args():
parse args
"""
parser = argparse.ArgumentParser(description='tinybert task distill')
parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"],
help="Do train task, default is true.")
@ -69,21 +69,46 @@ def parse_args():
parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
help="The type of the task to train.")
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
help="The name of the task to train.")
parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "bf1", "mf1"],
help="assessment_method include: [accuracy, bf1, mf1], default is accuracy")
parser.add_argument("--dataset_type", type=str, default="tfrecord",
help="dataset type tfrecord/mindrecord, default is tfrecord")
args = parser.parse_args()
if args.do_train.lower() != "true" and args.do_eval.lower() != "true":
raise ValueError("do train or do eval must have one be true, please confirm your config")
if args.task_name in ["SST-2", "QNLI", "MNLI", "TNEWS"] and args.task_type != "classification":
raise ValueError(f"{args.task_name} is a classification dataset, please set --task_type=classification")
if args.task_name in ["CLUENER"] and args.task_type != "ner":
raise ValueError(f"{args.task_name} is a ner dataset, please set --task_type=ner")
if args.task_name in ["SST-2", "QNLI", "MNLI"] and \
(td_teacher_net_cfg.vocab_size != 30522 or td_student_net_cfg.vocab_size != 30522):
logger.warning(f"{args.task_name} is an English dataset. Usually, we use 21128 for CN vocabs and 30522 for "\
"EN vocabs according to the origin paper.")
if args.task_name in ["TNEWS", "CLUENER"] and \
(td_teacher_net_cfg.vocab_size != 21128 or td_student_net_cfg.vocab_size != 21128):
logger.warning(f"{args.task_name} is a Chinese dataset. Usually, we use 21128 for CN vocabs and 30522 for " \
"EN vocabs according to the origin paper.")
return args
args_opt = parse_args()
if args_opt.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif args_opt.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
raise Exception("dataset format is not supported yet")
DEFAULT_NUM_LABELS = 2
DEFAULT_SEQ_LENGTH = 128
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
"QNLI": {"num_labels": 2, "seq_length": 128},
"MNLI": {"num_labels": 3, "seq_length": 128}}
"MNLI": {"num_labels": 3, "seq_length": 128},
"TNEWS": {"num_labels": 15, "seq_length": 128},
"CLUENER": {"num_labels": 43, "seq_length": 128}}
class Task:
@ -112,29 +137,15 @@ def run_predistill():
run predistill
"""
cfg = phase1_cfg
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
context.set_context(reserve_class_name_in_scope=False)
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
load_student_checkpoint_path = args_opt.load_gd_ckpt_path
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
is_training=True, task_type='classification',
is_training=True, task_type=args_opt.task_type,
num_labels=task.num_labels, is_predistill=True)
rank = 0
device_num = 1
if args_opt.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif args_opt.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
raise Exception("dataset format is not supported yet")
dataset = create_tinybert_dataset('td', cfg.batch_size,
device_num, rank, args_opt.do_shuffle,
args_opt.train_data_dir, args_opt.schema_dir,
@ -190,25 +201,19 @@ def run_task_distill(ckpt_file):
raise ValueError("Student ckpt file should not be None")
cfg = phase2_cfg
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
load_student_checkpoint_path = ckpt_file
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
is_training=True, task_type='classification',
is_training=True, task_type=args_opt.task_type,
num_labels=task.num_labels, is_predistill=False)
rank = 0
device_num = 1
train_dataset = create_tinybert_dataset('td', cfg.batch_size,
device_num, rank, args_opt.do_shuffle,
args_opt.train_data_dir, args_opt.schema_dir)
args_opt.train_data_dir, args_opt.schema_dir,
data_type=dataset_type)
dataset_size = train_dataset.get_dataset_size()
print('td2 train dataset size: ', dataset_size)
@ -238,7 +243,8 @@ def run_task_distill(ckpt_file):
eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size,
device_num, rank, args_opt.do_shuffle,
args_opt.eval_data_dir, args_opt.schema_dir)
args_opt.eval_data_dir, args_opt.schema_dir,
data_type=dataset_type)
print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
if args_opt.do_eval.lower() == "true":
@ -263,6 +269,19 @@ def run_task_distill(ckpt_file):
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
sink_size=args_opt.data_sink_steps)
def eval_result_print(assessment_method="accuracy", callback=None):
    """Print the evaluation result accumulated by an assessment callback.

    Args:
        assessment_method (str): one of "accuracy", "bf1" (binary F1) or
            "mf1" (multi-label F1), matching the --assessment_method choices.
        callback: metric object that accumulated statistics during evaluation
            (Accuracy for "accuracy", F1 for "bf1"/"mf1").

    Raises:
        ValueError: if ``assessment_method`` is not one of the supported names.
    """
    if assessment_method == "accuracy":
        print("============== acc is {}".format(callback.acc_num / callback.total_num))
    elif assessment_method == "bf1":
        # NOTE(review): assumes at least one predicted/actual positive;
        # otherwise these divisions raise ZeroDivisionError -- confirm upstream.
        print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
        print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
        print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
    elif assessment_method == "mf1":
        print("F1 {:.6f} ".format(callback.eval()))
    else:
        # Message kept in sync with the actually supported method names.
        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
def do_eval_standalone():
"""
do eval standalone
@ -270,13 +289,12 @@ def do_eval_standalone():
ckpt_file = args_opt.load_td1_ckpt_path
if ckpt_file == '':
raise ValueError("Student ckpt file should not be None")
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
if args_opt.task_type == "classification":
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
elif args_opt.task_type == "ner":
eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
else:
raise Exception("Target error, GPU or Ascend is supported.")
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
raise ValueError(f"Not support the task type {args_opt.task_type}")
param_dict = load_checkpoint(ckpt_file)
new_param_dict = {}
for key, value in param_dict.items():
@ -289,11 +307,18 @@ def do_eval_standalone():
eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size,
device_num=1, rank=0, do_shuffle="false",
data_dir=args_opt.eval_data_dir,
schema_dir=args_opt.schema_dir)
schema_dir=args_opt.schema_dir,
data_type=dataset_type)
print('eval dataset size: ', eval_dataset.get_dataset_size())
print('eval dataset batch size: ', eval_dataset.get_batch_size())
callback = Accuracy()
if args_opt.assessment_method == "accuracy":
callback = Accuracy()
elif args_opt.assessment_method == "bf1":
callback = F1(num_labels=task.num_labels)
elif args_opt.assessment_method == "mf1":
callback = F1(num_labels=task.num_labels, mode="MultiLabel")
else:
raise ValueError("Assessment method not supported, support: [accuracy, f1]")
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
for data in eval_dataset.create_dict_iterator(num_epochs=1):
input_data = []
@ -302,16 +327,16 @@ def do_eval_standalone():
input_ids, input_mask, token_type_id, label_ids = input_data
logits = eval_model(input_ids, token_type_id, input_mask)
callback.update(logits, label_ids)
acc = callback.acc_num / callback.total_num
print("======================================")
print("============== acc is {}".format(acc))
print("======================================")
print("==============================================================")
eval_result_print(args_opt.assessment_method, callback)
print("==============================================================")
if __name__ == '__main__':
if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
raise ValueError("do_train or do eval must have one be true, please confirm your config")
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
reserve_class_name_in_scope=False)
if args_opt.device_target == "Ascend":
context.set_context(device_id=args_opt.device_id)
enable_loss_scale = True
if args_opt.device_target == "GPU":
if td_student_net_cfg.compute_type != mstype.float32:
@ -321,6 +346,14 @@ if __name__ == '__main__':
# and the loss scale is not necessary
enable_loss_scale = False
if args_opt.device_target == "CPU":
logger.warning('CPU only support float32 temporarily, run with float32.')
td_teacher_net_cfg.dtype = mstype.float32
td_teacher_net_cfg.compute_type = mstype.float32
td_student_net_cfg.dtype = mstype.float32
td_student_net_cfg.compute_type = mstype.float32
enable_loss_scale = False
td_teacher_net_cfg.seq_length = task.seq_length
td_student_net_cfg.seq_length = task.seq_length

View File

@ -32,7 +32,6 @@ python ${PROJECT_DIR}/../run_task_distill.py \
--do_eval="true" \
--td_phase1_epoch_size=10 \
--td_phase2_epoch_size=3 \
--task_name="" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=100 \
@ -44,5 +43,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
--train_data_dir="" \
--eval_data_dir="" \
--schema_dir="" \
--dataset_type="tfrecord" > log.txt 2>&1 &
--dataset_type="tfrecord" \
--task_type="classification" \
--task_name="" \
--assessment_method="accuracy" > log.txt 2>&1 &

View File

@ -32,23 +32,56 @@ class Accuracy():
self.total_num += len(labels)
class F1():
    """F1 metric for NER/classification evaluation.

    Statistics are accumulated across ``update`` calls and reported by
    ``eval``.

    Args:
        num_labels (int): total number of label ids; label 0 is treated as the
            negative ("O") class in binary mode. Default: 2.
        mode (str): "Binary" for entity-vs-non-entity F1, "MultiLabel" for
            micro-averaged F1 over all labels. Default: "Binary".

    Raises:
        ValueError: if ``mode`` is not "Binary" or "MultiLabel".
    """
    def __init__(self, num_labels=2, mode="Binary"):
        self.TP = 0
        self.FP = 0
        self.FN = 0
        self.num_labels = num_labels
        self.P = 0   # per-label predicted positives (multi-label mode)
        self.AP = 0  # per-label actual positives (multi-label mode)
        self.mode = mode
        if self.mode.lower() not in ("binary", "multilabel"):
            raise ValueError("Assessment mode not supported, support: [Binary, MultiLabel]")

    def update(self, logits, labels):
        """Accumulate F1 statistics from one batch of logits and labels."""
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logit_id = np.argmax(logits, axis=-1)
        logit_id = np.reshape(logit_id, -1)
        if self.mode.lower() == "binary":
            # Any label id >= 1 counts as a positive (entity) tag.
            pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)])
            pos_label = np.isin(labels, [i for i in range(1, self.num_labels)])
            self.TP += np.sum(pos_eva & pos_label)
            self.FP += np.sum(pos_eva & (~pos_label))
            self.FN += np.sum((~pos_eva) & pos_label)
        else:
            # One-hot encode predictions and targets. The builtin int dtype is
            # used because the np.int alias was removed in NumPy >= 1.24.
            target = np.zeros((len(labels), self.num_labels), dtype=int)
            pred = np.zeros((len(logit_id), self.num_labels), dtype=int)
            for i, label in enumerate(labels):
                target[i][label] = 1
            for i, label in enumerate(logit_id):
                pred[i][label] = 1
            positives = pred.sum(axis=0)
            actual_positives = target.sum(axis=0)
            true_positives = (target * pred).sum(axis=0)
            self.TP += true_positives
            self.P += positives
            self.AP += actual_positives

    def eval(self):
        """Return the accumulated F1 score."""
        if self.mode.lower() == "binary":
            # F1 = 2*TP / (2*TP + FP + FN); the factor 2 was missing before,
            # which reported half the true F1 (the bf1 print path already
            # used the correct formula).
            f1 = 2 * self.TP / (2 * self.TP + self.FP + self.FN)
        else:
            # Micro-averaged F1: 2*TP / (predicted positives + actual positives).
            tp = np.sum(self.TP)
            p = np.sum(self.P)
            ap = np.sum(self.AP)
            f1 = 2 * tp / (ap + p)
        return f1

View File

@ -28,7 +28,7 @@ from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from .tinybert_model import BertModel, TinyBertModel, BertModelCLS
from .tinybert_model import BertModel, TinyBertModel, BertModelCLS, BertModelNER
GRADIENT_CLIP_TYPE = 1
@ -362,8 +362,18 @@ class BertNetworkWithLoss_td(nn.Cell):
temperature=1.0, dropout_prob=0.1):
super(BertNetworkWithLoss_td, self).__init__()
# load teacher model
self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
use_one_hot_embeddings, "teacher")
if task_type == "classification":
self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
use_one_hot_embeddings, "teacher")
self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
use_one_hot_embeddings, "student")
elif task_type == "ner":
self.teacher = BertModelNER(teacher_config, False, num_labels, dropout_prob,
use_one_hot_embeddings, "teacher")
self.bert = BertModelNER(student_config, is_training, num_labels, dropout_prob,
use_one_hot_embeddings, "student")
else:
raise ValueError(f"Not support task type: {task_type}")
param_dict = load_checkpoint(teacher_ckpt)
new_param_dict = {}
for key, value in param_dict.items():
@ -377,8 +387,6 @@ class BertNetworkWithLoss_td(nn.Cell):
for param in params:
param.requires_grad = False
# load student model
self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
use_one_hot_embeddings, "student")
param_dict = load_checkpoint(student_ckpt)
if is_predistill:
new_param_dict = {}
@ -401,7 +409,7 @@ class BertNetworkWithLoss_td(nn.Cell):
self.is_predistill = is_predistill
self.is_att_fit = is_att_fit
self.is_rep_fit = is_rep_fit
self.task_type = task_type
self.use_soft_cross_entropy = task_type in ["classification", "ner"]
self.temperature = temperature
self.loss_mse = nn.MSELoss()
self.select = P.Select()
@ -456,7 +464,7 @@ class BertNetworkWithLoss_td(nn.Cell):
rep_loss += self.loss_mse(student_rep, teacher_rep)
total_loss += rep_loss
else:
if self.task_type == "classification":
if self.use_soft_cross_entropy:
cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature)
else:
cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1])

View File

@ -926,3 +926,40 @@ class BertModelCLS(nn.Cell):
if self._phase == 'train' or self.phase_type == "teacher":
return seq_output, att_output, logits, log_probs
return log_probs
class BertModelNER(nn.Cell):
    """
    This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
    The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.

    Args:
        config: BertConfig-like object describing the backbone network.
        is_training (bool): if False, dropout in the backbone is disabled.
        num_labels (int): number of token-level labels. Default: 11.
        dropout_prob (float): accepted for interface symmetry with
            BertModelCLS; not used by this head. Default: 0.0.
        use_one_hot_embeddings (bool): passed through to BertModel.
        phase_type (str): "student" or "teacher"; the teacher also returns
            intermediate outputs for distillation. Default: "student".
    """
    def __init__(self, config, is_training, num_labels=11, dropout_prob=0.0,
                 use_one_hot_embeddings=False, phase_type="student"):
        super(BertModelNER, self).__init__()
        if not is_training:
            # Disable dropout in the backbone config for inference.
            config.hidden_dropout_prob = 0.0
            # NOTE(review): the usual BertConfig field is named
            # `attention_probs_dropout_prob`; `hidden_probs_dropout_prob` may be
            # a typo that merely creates an unused attribute -- confirm.
            config.hidden_probs_dropout_prob = 0.0
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        # Fix: construct() reads self.phase_type, but the parameter was never
        # stored, causing an AttributeError during student-phase evaluation.
        self.phase_type = phase_type
        # Token-level classifier projecting hidden states to label logits.
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        # NOTE(review): despite its name, this is a ReLU activation, not a
        # dropout layer -- confirm whether nn.Dropout was intended here.
        self.dropout = nn.ReLU()
        self.reshape = P.Reshape()
        self.shape = (-1, config.hidden_size)
        self.origin_shape = (-1, config.seq_length, self.num_labels)

    def construct(self, input_ids, input_mask, token_type_id):
        """Return the final logits as the results of log_softmax."""
        sequence_output, _, _, encoder_outputs, attention_outputs = \
            self.bert(input_ids, token_type_id, input_mask)
        seq = self.dropout(sequence_output)
        seq = self.reshape(seq, self.shape)
        logits = self.dense_1(seq)
        logits = self.cast(logits, self.dtype)
        return_value = self.log_softmax(logits)
        # In training, or for the teacher network, also expose intermediate
        # encoder/attention outputs required by the distillation loss.
        if self._phase == 'train' or self.phase_type == "teacher":
            return encoder_outputs, attention_outputs, logits, return_value
        return return_value