Support CPU for TinyBERT and the NER task
Parent: 5d0490909d · Commit: c14c707475
README.md

@@ -1,4 +1,4 @@
 # Contents

 - [Contents](#contents)
 - [TinyBERT Description](#tinybert-description)
@@ -197,8 +197,9 @@ usage: run_general_task.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN
                            [--load_gd_ckpt_path LOAD_GD_CKPT_PATH]
                            [--load_td1_ckpt_path LOAD_TD1_CKPT_PATH]
                            [--train_data_dir TRAIN_DATA_DIR]
-                           [--eval_data_dir EVAL_DATA_DIR]
+                           [--eval_data_dir EVAL_DATA_DIR] [--task_type TASK_TYPE]
                            [--task_name TASK_NAME] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE]
+                           [--assessment_method ASSESSMENT_METHOD]

 options:
     --device_target            device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
@@ -217,7 +218,9 @@ options:
     --load_td1_ckpt_path       path to load checkpoint files which produced by task distill phase 1: PATH, default is ""
     --train_data_dir           path to train dataset directory: PATH, default is ""
     --eval_data_dir            path to eval dataset directory: PATH, default is ""
-    --task_name                classification task: "SST-2" | "QNLI" | "MNLI", default is ""
+    --task_type                task type: "classification" | "ner", default is "classification"
+    --task_name                classification or ner task: "SST-2" | "QNLI" | "MNLI" | "TNEWS" | "CLUENER", default is ""
+    --assessment_method        assessment method to do evaluation: accuracy | bf1 | mf1
     --schema_dir               path to schema.json file, PATH, default is ""
     --dataset_type             the dataset type which can be tfrecord/mindrecord, default is tfrecord
 ```
@@ -249,6 +252,7 @@ Parameters for optimizer:
 Parameters for bert network:
     seq_length                      length of input sequence: N, default is 128
-    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
+    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use.
+                                    Usually, we use 21128 for CN vocabs and 30522 for EN vocabs according to the origin paper. Default is 30522
     hidden_size                     size of bert encoder layers: N
     num_hidden_layers               number of hidden layers: N
     num_attention_heads             number of attention heads: N, default is 12
export.py

@@ -22,7 +22,7 @@ from mindspore import Tensor, context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

 from src.td_config import td_student_net_cfg
-from src.tinybert_model import BertModelCLS
+from src.tinybert_model import BertModelCLS, BertModelNER

 parser = argparse.ArgumentParser(description='tinybert task distill')
 parser.add_argument("--device_id", type=int, default=0, help="Device id")
@@ -31,7 +31,10 @@ parser.add_argument("--file_name", type=str, default="tinybert", help="output fi
 parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
 parser.add_argument("--device_target", type=str, default="Ascend",
                     choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
-parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
+parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
+                    help="The type of the task to train.")
+parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
+                    help="The name of the task to train.")
 args = parser.parse_args()

 context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@@ -43,7 +46,9 @@ DEFAULT_SEQ_LENGTH = 128
 DEFAULT_BS = 32
 task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                "QNLI": {"num_labels": 2, "seq_length": 128},
-               "MNLI": {"num_labels": 3, "seq_length": 128}}
+               "MNLI": {"num_labels": 3, "seq_length": 128},
+               "TNEWS": {"num_labels": 15, "seq_length": 128},
+               "CLUENER": {"num_labels": 10, "seq_length": 128}}

 class Task:
     """
@@ -68,8 +73,13 @@ if __name__ == '__main__':
     task = Task(args.task_name)
     td_student_net_cfg.seq_length = task.seq_length
     td_student_net_cfg.batch_size = DEFAULT_BS

-    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    if args.task_type == "classification":
+        eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    elif args.task_type == "ner":
+        eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    else:
+        raise ValueError(f"Unsupported task type: {args.task_type}")
     param_dict = load_checkpoint(args.ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
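The hunks above are the whole change to export.py. As a rough end-to-end illustration of the flow the script now implements for an NER checkpoint, here is a condensed sketch; everything in it is illustrative: the checkpoint path is a placeholder, num_labels=10 is the CLUENER entry from task_params above, and the script's parameter-renaming loop is omitted.

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

from src.td_config import td_student_net_cfg
from src.tinybert_model import BertModelNER

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

td_student_net_cfg.seq_length = 128      # matches task_params["CLUENER"]
td_student_net_cfg.batch_size = 32       # DEFAULT_BS in export.py
eval_model = BertModelNER(td_student_net_cfg, False, 10, 0.0, phase_type="student")

param_dict = load_checkpoint("/path/to/cluener_td.ckpt")   # placeholder path
load_param_into_net(eval_model, param_dict)

# dummy inputs that fix the exported graph's input shapes
input_ids = Tensor(np.zeros((32, 128), np.int32))
input_mask = Tensor(np.zeros((32, 128), np.int32))
token_type_id = Tensor(np.zeros((32, 128), np.int32))
export(eval_model, input_ids, input_mask, token_type_id,
       file_name="tinybert_ner", file_format="MINDIR")
```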
run_general_distill.py

@@ -33,14 +33,10 @@ from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
 from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell


-def run_general_distill():
-    """
-    run general distill
-    """
+def get_argument():
+    """Tinybert general distill argument parser."""
     parser = argparse.ArgumentParser(description='tinybert general distill')
-    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                         help="Run distribute, default is false.")
@@ -61,20 +57,21 @@ def run_general_distill():
     parser.add_argument("--dataset_type", type=str, default="tfrecord",
                         help="dataset type tfrecord/mindrecord, default is tfrecord")
     args_opt = parser.parse_args()
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-
-    context.set_context(reserve_class_name_in_scope=False)
+    return args_opt
+
+
+def run_general_distill():
+    """
+    run general distill
+    """
+    args_opt = get_argument()
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+                        reserve_class_name_in_scope=False)
+    context.set_context(device_id=args_opt.device_id)

     save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                  datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
             D.init()
@@ -104,6 +101,14 @@ def run_general_distill():
             # and the loss scale is not necessary
             enable_loss_scale = False

+    if args_opt.device_target == "CPU":
+        logger.warning('CPU only supports float32 temporarily; running with float32.')
+        bert_teacher_net_cfg.dtype = mstype.float32
+        bert_teacher_net_cfg.compute_type = mstype.float32
+        bert_student_net_cfg.dtype = mstype.float32
+        bert_student_net_cfg.compute_type = mstype.float32
+        enable_loss_scale = False
+
     netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                          teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                          student_config=bert_student_net_cfg,
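This CPU float32 fallback reappears verbatim in run_task_distill.py further down. If a third copy were ever needed, it could be factored into a helper along these lines; this is an editorial sketch, not code from the commit, and force_fp32_for_cpu is a made-up name.

```python
import mindspore.common.dtype as mstype
from mindspore import log as logger

def force_fp32_for_cpu(*net_cfgs):
    """Downgrade BERT configs to float32, since CPU kernels currently run float32 only."""
    logger.warning('CPU only supports float32 temporarily; running with float32.')
    for cfg in net_cfgs:
        cfg.dtype = mstype.float32
        cfg.compute_type = mstype.float32
    # pure-float32 training does not overflow, so loss scaling can be disabled
    return False

# enable_loss_scale = force_fp32_for_cpu(bert_teacher_net_cfg, bert_student_net_cfg)
```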
run_task_distill.py

@@ -28,10 +28,10 @@ from mindspore.nn.optim import AdamWeightDecay
 from mindspore import log as logger
 from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
-from src.assessment_method import Accuracy
+from src.assessment_method import Accuracy, F1
 from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
 from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
-from src.tinybert_model import BertModelCLS
+from src.tinybert_model import BertModelCLS, BertModelNER

 _cur_dir = os.getcwd()
 td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
@@ -46,7 +46,7 @@ def parse_args():
     parse args
     """
     parser = argparse.ArgumentParser(description='tinybert task distill')
-    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
+    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"],
                         help="Do train task, default is true.")
@@ -69,21 +69,46 @@ def parse_args():
     parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
-    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
+    parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
+                        help="The type of the task to train.")
+    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
                         help="The name of the task to train.")
+    parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "bf1", "mf1"],
+                        help="assessment_method include: [accuracy, bf1, mf1], default is accuracy")
     parser.add_argument("--dataset_type", type=str, default="tfrecord",
                         help="dataset type tfrecord/mindrecord, default is tfrecord")
     args = parser.parse_args()
+    if args.do_train.lower() != "true" and args.do_eval.lower() != "true":
+        raise ValueError("do train or do eval must have one be true, please confirm your config")
+    if args.task_name in ["SST-2", "QNLI", "MNLI", "TNEWS"] and args.task_type != "classification":
+        raise ValueError(f"{args.task_name} is a classification dataset, please set --task_type=classification")
+    if args.task_name in ["CLUENER"] and args.task_type != "ner":
+        raise ValueError(f"{args.task_name} is a ner dataset, please set --task_type=ner")
+    if args.task_name in ["SST-2", "QNLI", "MNLI"] and \
+            (td_teacher_net_cfg.vocab_size != 30522 or td_student_net_cfg.vocab_size != 30522):
+        logger.warning(f"{args.task_name} is an English dataset. Usually, we use 21128 for CN vocabs and 30522 for "
+                       "EN vocabs according to the origin paper.")
+    if args.task_name in ["TNEWS", "CLUENER"] and \
+            (td_teacher_net_cfg.vocab_size != 21128 or td_student_net_cfg.vocab_size != 21128):
+        logger.warning(f"{args.task_name} is a Chinese dataset. Usually, we use 21128 for CN vocabs and 30522 for "
+                       "EN vocabs according to the origin paper.")
     return args


 args_opt = parse_args()

+if args_opt.dataset_type == "tfrecord":
+    dataset_type = DataType.TFRECORD
+elif args_opt.dataset_type == "mindrecord":
+    dataset_type = DataType.MINDRECORD
+else:
+    raise Exception("dataset format is not supported yet")
 DEFAULT_NUM_LABELS = 2
 DEFAULT_SEQ_LENGTH = 128
 task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                "QNLI": {"num_labels": 2, "seq_length": 128},
-               "MNLI": {"num_labels": 3, "seq_length": 128}}
+               "MNLI": {"num_labels": 3, "seq_length": 128},
+               "TNEWS": {"num_labels": 15, "seq_length": 128},
+               "CLUENER": {"num_labels": 43, "seq_length": 128}}


 class Task:
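Taken together, the new checks pin each dataset to one task type. A condensed, hypothetical restatement of the mapping they enforce (TASK_TYPE_BY_DATASET and check_task are illustrative names, not part of the commit):

```python
# dataset -> required --task_type, per the checks in parse_args above
TASK_TYPE_BY_DATASET = {
    "SST-2": "classification",
    "QNLI": "classification",
    "MNLI": "classification",
    "TNEWS": "classification",
    "CLUENER": "ner",
}

def check_task(task_name, task_type):
    """Mirror of the parse_args validation (sketch)."""
    expected = TASK_TYPE_BY_DATASET.get(task_name)
    if expected is not None and task_type != expected:
        raise ValueError(f"{task_name} requires --task_type={expected}")

check_task("CLUENER", "ner")      # passes
# check_task("TNEWS", "ner")      # would raise ValueError
```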
@@ -112,29 +137,15 @@ def run_predistill():
     run predistill
     """
     cfg = phase1_cfg
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    context.set_context(reserve_class_name_in_scope=False)
     load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
     load_student_checkpoint_path = args_opt.load_gd_ckpt_path
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
-                                         is_training=True, task_type='classification',
+                                         is_training=True, task_type=args_opt.task_type,
                                          num_labels=task.num_labels, is_predistill=True)

     rank = 0
     device_num = 1
-
-    if args_opt.dataset_type == "tfrecord":
-        dataset_type = DataType.TFRECORD
-    elif args_opt.dataset_type == "mindrecord":
-        dataset_type = DataType.MINDRECORD
-    else:
-        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('td', cfg.batch_size,
                                       device_num, rank, args_opt.do_shuffle,
                                       args_opt.train_data_dir, args_opt.schema_dir,
@@ -190,25 +201,19 @@ def run_task_distill(ckpt_file):
         raise ValueError("Student ckpt file should not be None")
     cfg = phase2_cfg

-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-
     load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
     load_student_checkpoint_path = ckpt_file
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
-                                         is_training=True, task_type='classification',
+                                         is_training=True, task_type=args_opt.task_type,
                                          num_labels=task.num_labels, is_predistill=False)

     rank = 0
     device_num = 1
     train_dataset = create_tinybert_dataset('td', cfg.batch_size,
                                             device_num, rank, args_opt.do_shuffle,
-                                            args_opt.train_data_dir, args_opt.schema_dir)
+                                            args_opt.train_data_dir, args_opt.schema_dir,
+                                            data_type=dataset_type)

     dataset_size = train_dataset.get_dataset_size()
     print('td2 train dataset size: ', dataset_size)
@@ -238,7 +243,8 @@ def run_task_distill(ckpt_file):

     eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
-                                           args_opt.eval_data_dir, args_opt.schema_dir)
+                                           args_opt.eval_data_dir, args_opt.schema_dir,
+                                           data_type=dataset_type)
     print('td2 eval dataset size: ', eval_dataset.get_dataset_size())

     if args_opt.do_eval.lower() == "true":
@@ -263,6 +269,19 @@ def run_task_distill(ckpt_file):
                   dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                   sink_size=args_opt.data_sink_steps)

+def eval_result_print(assessment_method="accuracy", callback=None):
+    """print eval result"""
+    if assessment_method == "accuracy":
+        print("============== acc is {}".format(callback.acc_num / callback.total_num))
+    elif assessment_method == "bf1":
+        print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
+        print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
+        print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
+    elif assessment_method == "mf1":
+        print("F1 {:.6f} ".format(callback.eval()))
+    else:
+        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
+
 def do_eval_standalone():
     """
     do eval standalone
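For the bf1 branch, precision, recall, and F1 come straight from the accumulated counts. Worked numbers (made up) showing the arithmetic:

```python
# Made-up counts to illustrate the bf1 arithmetic above.
TP, FP, FN = 8, 2, 4
precision = TP / (TP + FP)            # 0.80
recall = TP / (TP + FN)               # ~0.67
f1 = 2 * TP / (2 * TP + FP + FN)      # ~0.73, the harmonic mean of the two
print(precision, recall, f1)
```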
@@ -270,13 +289,12 @@ def do_eval_standalone():
     ckpt_file = args_opt.load_td1_ckpt_path
     if ckpt_file == '':
         raise ValueError("Student ckpt file should not be None")
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
+    if args_opt.task_type == "classification":
+        eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    elif args_opt.task_type == "ner":
+        eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+        raise ValueError(f"Unsupported task type: {args_opt.task_type}")
     param_dict = load_checkpoint(ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
@@ -289,11 +307,18 @@ def do_eval_standalone():
     eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size,
                                            device_num=1, rank=0, do_shuffle="false",
                                            data_dir=args_opt.eval_data_dir,
-                                           schema_dir=args_opt.schema_dir)
+                                           schema_dir=args_opt.schema_dir,
+                                           data_type=dataset_type)
     print('eval dataset size: ', eval_dataset.get_dataset_size())
     print('eval dataset batch size: ', eval_dataset.get_batch_size())

-    callback = Accuracy()
+    if args_opt.assessment_method == "accuracy":
+        callback = Accuracy()
+    elif args_opt.assessment_method == "bf1":
+        callback = F1(num_labels=task.num_labels)
+    elif args_opt.assessment_method == "mf1":
+        callback = F1(num_labels=task.num_labels, mode="MultiLabel")
+    else:
+        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
     columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     for data in eval_dataset.create_dict_iterator(num_epochs=1):
         input_data = []
@@ -302,16 +327,16 @@ def do_eval_standalone():
         input_ids, input_mask, token_type_id, label_ids = input_data
         logits = eval_model(input_ids, token_type_id, input_mask)
         callback.update(logits, label_ids)
-    acc = callback.acc_num / callback.total_num
-    print("======================================")
-    print("============== acc is {}".format(acc))
-    print("======================================")
+    print("==============================================================")
+    eval_result_print(args_opt.assessment_method, callback)
+    print("==============================================================")


 if __name__ == '__main__':
-    if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
-        raise ValueError("do_train or do eval must have one be true, please confirm your config")
-
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+                        reserve_class_name_in_scope=False)
+    if args_opt.device_target == "Ascend":
+        context.set_context(device_id=args_opt.device_id)
     enable_loss_scale = True
     if args_opt.device_target == "GPU":
         if td_student_net_cfg.compute_type != mstype.float32:
@@ -321,6 +346,14 @@ if __name__ == '__main__':
             # and the loss scale is not necessary
             enable_loss_scale = False

+    if args_opt.device_target == "CPU":
+        logger.warning('CPU only supports float32 temporarily; running with float32.')
+        td_teacher_net_cfg.dtype = mstype.float32
+        td_teacher_net_cfg.compute_type = mstype.float32
+        td_student_net_cfg.dtype = mstype.float32
+        td_student_net_cfg.compute_type = mstype.float32
+        enable_loss_scale = False
+
     td_teacher_net_cfg.seq_length = task.seq_length
     td_student_net_cfg.seq_length = task.seq_length
scripts/run_standalone_td.sh

@@ -32,7 +32,6 @@ python ${PROJECT_DIR}/../run_task_distill.py \
    --do_eval="true" \
    --td_phase1_epoch_size=10 \
    --td_phase2_epoch_size=3 \
-   --task_name="" \
    --do_shuffle="true" \
    --enable_data_sink="true" \
    --data_sink_steps=100 \
@@ -44,5 +43,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
    --train_data_dir="" \
    --eval_data_dir="" \
    --schema_dir="" \
-   --dataset_type="tfrecord" > log.txt 2>&1 &
+   --dataset_type="tfrecord" \
+   --task_type="classification" \
+   --task_name="" \
+   --assessment_method="accuracy" > log.txt 2>&1 &
src/assessment_method.py

@@ -32,23 +32,56 @@ class Accuracy():
         self.total_num += len(labels)

 class F1():
-    """F1"""
-    def __init__(self):
+    '''
+    calculate F1 score
+    '''
+    def __init__(self, num_labels=2, mode="Binary"):
         self.TP = 0
         self.FP = 0
         self.FN = 0
+        self.num_labels = num_labels
+        self.P = 0
+        self.AP = 0
+        self.mode = mode
+        if self.mode.lower() not in ("binary", "multilabel"):
+            raise ValueError("Assessment mode not supported, support: [Binary, MultiLabel]")

     def update(self, logits, labels):
-        """Update F1 score"""
+        '''
+        update F1 score
+        '''
         labels = labels.asnumpy()
         labels = np.reshape(labels, -1)
         logits = logits.asnumpy()
         logit_id = np.argmax(logits, axis=-1)
         logit_id = np.reshape(logit_id, -1)
-        pos_eva = np.isin(logit_id, [2, 3, 4, 5, 6, 7])
-        pos_label = np.isin(labels, [2, 3, 4, 5, 6, 7])
-        self.TP += np.sum(pos_eva & pos_label)
-        self.FP += np.sum(pos_eva & (~pos_label))
-        self.FN += np.sum((~pos_eva) & pos_label)
-        print("-----------------precision is ", self.TP / (self.TP + self.FP))
-        print("-----------------recall is ", self.TP / (self.TP + self.FN))
+
+        if self.mode.lower() == "binary":
+            pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)])
+            pos_label = np.isin(labels, [i for i in range(1, self.num_labels)])
+            self.TP += np.sum(pos_eva & pos_label)
+            self.FP += np.sum(pos_eva & (~pos_label))
+            self.FN += np.sum((~pos_eva) & pos_label)
+        else:
+            target = np.zeros((len(labels), self.num_labels), dtype=np.int32)
+            pred = np.zeros((len(logit_id), self.num_labels), dtype=np.int32)
+            for i, label in enumerate(labels):
+                target[i][label] = 1
+            for i, label in enumerate(logit_id):
+                pred[i][label] = 1
+            positives = pred.sum(axis=0)
+            actual_positives = target.sum(axis=0)
+            true_positives = (target * pred).sum(axis=0)
+            self.TP += true_positives
+            self.P += positives
+            self.AP += actual_positives
+
+    def eval(self):
+        if self.mode.lower() == "binary":
+            f1 = 2 * self.TP / (2 * self.TP + self.FP + self.FN)
+        else:
+            tp = np.sum(self.TP)
+            p = np.sum(self.P)
+            ap = np.sum(self.AP)
+            f1 = 2 * tp / (ap + p)
+        return f1
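A minimal way to exercise the new callback outside the training loop, using fabricated logits and labels; nothing below is part of the commit:

```python
import numpy as np
from mindspore import Tensor

from src.assessment_method import F1

callback = F1(num_labels=4, mode="MultiLabel")
# fabricated batch: 8 examples, 4 labels
logits = Tensor(np.random.randn(8, 4).astype(np.float32))
labels = Tensor(np.random.randint(0, 4, (8, 1)).astype(np.int32))
callback.update(logits, labels)
print("micro F1:", callback.eval())
```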
src/tinybert_for_gd_td.py

@@ -28,7 +28,7 @@ from mindspore.communication.management import get_group_size
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from .tinybert_model import BertModel, TinyBertModel, BertModelCLS
+from .tinybert_model import BertModel, TinyBertModel, BertModelCLS, BertModelNER


 GRADIENT_CLIP_TYPE = 1
@@ -362,8 +362,18 @@ class BertNetworkWithLoss_td(nn.Cell):
                  temperature=1.0, dropout_prob=0.1):
         super(BertNetworkWithLoss_td, self).__init__()
         # load teacher model
-        self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
-                                    use_one_hot_embeddings, "teacher")
+        if task_type == "classification":
+            self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
+                                        use_one_hot_embeddings, "teacher")
+            self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
+                                     use_one_hot_embeddings, "student")
+        elif task_type == "ner":
+            self.teacher = BertModelNER(teacher_config, False, num_labels, dropout_prob,
+                                        use_one_hot_embeddings, "teacher")
+            self.bert = BertModelNER(student_config, is_training, num_labels, dropout_prob,
+                                     use_one_hot_embeddings, "student")
+        else:
+            raise ValueError(f"Unsupported task type: {task_type}")
         param_dict = load_checkpoint(teacher_ckpt)
         new_param_dict = {}
         for key, value in param_dict.items():
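For an NER run the cell now builds matching NER heads for teacher and student. A hypothetical construction (checkpoint paths are placeholders; num_labels=43 is the CLUENER entry from run_task_distill.py):

```python
netwithloss = BertNetworkWithLoss_td(
    teacher_config=td_teacher_net_cfg,
    teacher_ckpt="/path/to/teacher.ckpt",       # placeholder
    student_config=td_student_net_cfg,
    student_ckpt="/path/to/gd_student.ckpt",    # placeholder
    is_training=True,
    task_type="ner",        # selects BertModelNER for both teacher and student
    num_labels=43,          # CLUENER label count from run_task_distill.py
    is_predistill=True)
```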
@@ -377,8 +387,6 @@ class BertNetworkWithLoss_td(nn.Cell):
         for param in params:
             param.requires_grad = False
         # load student model
-        self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
-                                 use_one_hot_embeddings, "student")
         param_dict = load_checkpoint(student_ckpt)
         if is_predistill:
             new_param_dict = {}
@@ -401,7 +409,7 @@ class BertNetworkWithLoss_td(nn.Cell):
         self.is_predistill = is_predistill
         self.is_att_fit = is_att_fit
         self.is_rep_fit = is_rep_fit
-        self.task_type = task_type
+        self.use_soft_cross_entropy = task_type in ["classification", "ner"]
         self.temperature = temperature
         self.loss_mse = nn.MSELoss()
         self.select = P.Select()
@@ -456,7 +464,7 @@ class BertNetworkWithLoss_td(nn.Cell):
                 rep_loss += self.loss_mse(student_rep, teacher_rep)
             total_loss += rep_loss
         else:
-            if self.task_type == "classification":
+            if self.use_soft_cross_entropy:
                 cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature)
             else:
                 cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1])
src/tinybert_model.py

@@ -926,3 +926,40 @@ class BertModelCLS(nn.Cell):
         if self._phase == 'train' or self.phase_type == "teacher":
             return seq_output, att_output, logits, log_probs
         return log_probs
+
+
+class BertModelNER(nn.Cell):
+    """
+    This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
+    The returned output represents the final logits, since the result of log_softmax is proportional to that of softmax.
+    """
+    def __init__(self, config, is_training, num_labels=11, dropout_prob=0.0,
+                 use_one_hot_embeddings=False, phase_type="student"):
+        super(BertModelNER, self).__init__()
+        if not is_training:
+            config.hidden_dropout_prob = 0.0
+            config.hidden_probs_dropout_prob = 0.0
+        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
+        self.cast = P.Cast()
+        self.weight_init = TruncatedNormal(config.initializer_range)
+        self.log_softmax = P.LogSoftmax(axis=-1)
+        self.dtype = config.dtype
+        self.num_labels = num_labels
+        self.phase_type = phase_type
+        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
+                                has_bias=True).to_float(config.compute_type)
+        self.dropout = nn.ReLU()
+        self.reshape = P.Reshape()
+        self.shape = (-1, config.hidden_size)
+        self.origin_shape = (-1, config.seq_length, self.num_labels)
+
+    def construct(self, input_ids, input_mask, token_type_id):
+        """Return the final logits as the results of log_softmax."""
+        sequence_output, _, _, encoder_outputs, attention_outputs = \
+            self.bert(input_ids, token_type_id, input_mask)
+        seq = self.dropout(sequence_output)
+        seq = self.reshape(seq, self.shape)
+        logits = self.dense_1(seq)
+        logits = self.cast(logits, self.dtype)
+        return_value = self.log_softmax(logits)
+        if self._phase == 'train' or self.phase_type == "teacher":
+            return encoder_outputs, attention_outputs, logits, return_value
+        return return_value
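A rough CPU smoke test for the new head. The config fields set here mirror those the run scripts touch; the exact values, and the assumption that td_student_net_cfg carries a batch_size field, are illustrative:

```python
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, context

from src.td_config import td_student_net_cfg
from src.tinybert_model import BertModelNER

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
td_student_net_cfg.dtype = mstype.float32          # CPU runs float32 only
td_student_net_cfg.compute_type = mstype.float32
td_student_net_cfg.seq_length = 128
td_student_net_cfg.batch_size = 1

net = BertModelNER(td_student_net_cfg, False, num_labels=10, phase_type="student")
input_ids = Tensor(np.zeros((1, 128), np.int32))
input_mask = Tensor(np.ones((1, 128), np.int32))
token_type_id = Tensor(np.zeros((1, 128), np.int32))
log_probs = net(input_ids, input_mask, token_type_id)
print(log_probs.shape)   # (batch * seq_length, num_labels), given the reshape above
```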