diff --git a/model_zoo/official/nlp/tinybert/README.md b/model_zoo/official/nlp/tinybert/README.md
new file mode 100644
index 0000000000..fd5a54ea1b
--- /dev/null
+++ b/model_zoo/official/nlp/tinybert/README.md
@@ -0,0 +1,129 @@
# TinyBERT Example
## Description
This example implements general distill and task distill of [BERT-base](https://github.com/google-research/bert) (the base version of the BERT model).

## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download datasets for general distill and task distill, such as GLUE.
- Prepare a pre-trained BERT model and a BERT model fine-tuned on a specific task such as GLUE.

## Running the Example
### General Distill
- Set options in `src/gd_config.py`, including loss scale, optimizer and network config.

- Set options in `scripts/run_standalone_gd.sh`, including device target, data sink config, checkpoint config and dataset. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about the dataset and the json schema file.

- Run `run_standalone_gd.sh` for non-distributed general distill of the BERT-base model.

    ``` bash
    bash scripts/run_standalone_gd.sh
    ```
- Run `run_distribute_gd.sh` for distributed general distill of the BERT-base model.

    ``` bash
    bash scripts/run_distribute_gd.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH
    ```

### Task Distill
Task distill has two phases, pre-distill and task distill.
- Set options in `src/td_config.py`, including loss scale, optimizer config of phase 1 and phase 2, as well as network config.

- Run `run_standalone_td.sh` for task distill of the BERT-base model.

    ```bash
    bash scripts/run_standalone_td.sh
    ```

## Usage
### General Distill
```
usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_TARGET]
                            [--epoch_size N] [--device_id N]
                            [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
                            [--save_ckpt_step N] [--max_ckpt_num N]
                            [--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH]
                            [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]

options:
    --distribute                whether to run in distributed mode: "true" | "false"
    --device_target             target device to run, currently only supports "Ascend"
    --epoch_size                epoch size: N, default is 3
    --device_id                 device id: N, default is 0
    --enable_data_sink          enable data sink: "true" | "false", default is "true"
    --data_sink_steps           set data sink steps: N, default is 1
    --load_teacher_ckpt_path    path of teacher checkpoint to load: PATH, default is ""
    --data_dir                  path to dataset directory: PATH, default is ""
    --schema_dir                path to schema.json file: PATH, default is ""

usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_TARGET]
                            [--epoch_size N] [--device_id N] [--device_num N]
                            [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
                            [--save_ckpt_step N] [--max_ckpt_num N]
                            [--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH]
                            [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]

options:
    --distribute                whether to run in distributed mode: "true" | "false"
    --device_target             target device to run, currently only supports "Ascend"
    --epoch_size                epoch size: N, default is 3
    --device_id                 device id: N, default is 0
    --device_num                number of devices to use: N, default is 1
    --enable_data_sink          enable data sink: "true" | "false", default is "true"
    --data_sink_steps           set data sink steps: N, default is 1
    --load_teacher_ckpt_path    path of teacher checkpoint to load: PATH, default is ""
    --data_dir                  path to dataset directory: PATH, default is ""
    --schema_dir                path to schema.json file: PATH, default is ""

```

## Options and Parameters
`gd_config.py` and `td_config.py` contain parameters of the BERT model and options for the optimizer and loss scale.
### Options:
```
Parameters for loss scale:
    loss_scale_value                initial value of loss scale: N, default is 2^8
    scale_factor                    factor used to update loss scale: N, default is 2
    scale_window                    steps between two updates of loss scale: N, default is 50

Parameters for task-specific config:
    load_teacher_ckpt_path          teacher checkpoint to load
    load_student_ckpt_path          student checkpoint to load
    data_dir                        training data dir
    eval_data_dir                   evaluation data dir
    schema_dir                      data schema path
```

### Parameters:
```
Parameters for bert network:
    batch_size                      batch size of input dataset: N, default is 16
    seq_length                      length of input sequence: N, default is 128
    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
    hidden_size                     size of bert encoder layers: N
    num_hidden_layers               number of hidden layers: N
    num_attention_heads             number of attention heads: N, default is 12
    intermediate_size               size of intermediate layer: N
    hidden_act                      activation function used: ACTIVATION, default is "gelu"
    hidden_dropout_prob             dropout probability for BertOutput: Q
    attention_probs_dropout_prob    dropout probability for BertAttention: Q
    max_position_embeddings         maximum length of sequences: N, default is 512
    save_ckpt_step                  steps between saving checkpoints: N, default is 100
    max_ckpt_num                    maximum number of checkpoints to keep: N, default is 1
    type_vocab_size                 size of token type vocab: N, default is 2
    initializer_range               initialization value of TruncatedNormal: Q, default is 0.02
    use_relative_positions          use relative positions or not: True | False, default is False
    input_mask_from_dataset         use the input mask loaded from dataset or not: True | False, default is True
    token_type_ids_from_dataset     use the token type ids loaded from dataset or not: True | False, default is True
    dtype                           data type of input: mstype.float16 | mstype.float32, default is mstype.float32
    compute_type                    compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
    enable_fused_layernorm          use batchnorm instead of layernorm to improve performance, default is False

Parameters for optimizer:
    optimizer                       optimizer used in the network: AdamWeightDecay
    learning_rate                   value of learning rate: Q
    end_learning_rate               value of end learning rate: Q, must be positive
    power                           power: Q
    weight_decay                    weight decay: Q
    eps                             term added to the denominator to improve numerical stability: Q
```

diff --git a/model_zoo/official/nlp/tinybert/__init__.py b/model_zoo/official/nlp/tinybert/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/model_zoo/official/nlp/tinybert/run_general_distill.py b/model_zoo/official/nlp/tinybert/run_general_distill.py
new file mode 100644
index 0000000000..adaf4bcd5b
--- /dev/null
+++ b/model_zoo/official/nlp/tinybert/run_general_distill.py
@@ -0,0 +1,124 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""general distill script"""
+
+import os
+import argparse
+import datetime
+import numpy
+import mindspore.communication.management as D
+from mindspore import context
+from mindspore.train.model import Model
+from mindspore.train.callback import TimeMonitor
+from mindspore.train.parallel_utils import ParallelMode
+from mindspore.nn.optim import AdamWeightDecay
+from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
+from src.dataset import create_tinybert_dataset
+from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
+from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
+from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd
+
+def run_general_distill():
+    """
+    run general distill
+    """
+    parser = argparse.ArgumentParser(description='tinybert general distill')
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
+    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
+    parser.add_argument("--epoch_size", type=int, default=3, help="Epoch size, default is 3.")
+    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
+    parser.add_argument("--device_num", type=int, default=1, help="Number of devices to use, default is 1.")
+    parser.add_argument("--save_ckpt_step", type=int, default=100, help="Steps between checkpoint saves, default is 100.")
+    parser.add_argument("--max_ckpt_num", type=int, default=1, help="Maximum number of checkpoints to keep, default is 1.")
+    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
+    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
+    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
+    parser.add_argument("--save_ckpt_path", type=str, default="", help="Save checkpoint path")
+    parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
+    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
+    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    args_opt = parser.parse_args()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
+    context.set_context(reserve_class_name_in_scope=False)
+    context.set_context(variable_memory_max_size="30GB")
+
+    save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
+                                 datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+
+    if not os.path.exists(save_ckpt_dir):
+        os.makedirs(save_ckpt_dir)
+
+    if args_opt.distribute == "true":
+        D.init('hccl')
+        device_num = args_opt.device_num
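+        # run_distribute_gd.sh launches one process per Ascend device and passes the
+        # total device count via --device_num; each process derives its rank in the
+        # data-parallel group from its device id below.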
+ rank = args_opt.device_id % device_num + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + device_num=device_num) + else: + rank = 0 + device_num = 1 + + netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg, + teacher_ckpt=args_opt.load_teacher_ckpt_path, + student_config=bert_student_net_cfg, + is_training=True, use_one_hot_embeddings=False) + + dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank, + args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir) + + dataset_size = dataset.get_dataset_size() + + if args_opt.enable_data_sink == "true": + repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps + else: + repeat_count = args_opt.epoch_size + + lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(dataset_size * args_opt.epoch_size / 10), + decay_steps=int(dataset_size * args_opt.epoch_size), + power=common_cfg.AdamWeightDecay.power) + params = netwithloss.trainable_params() + decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: x not in decay_params, params)) + group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}, + {'order_params': params}] + + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps) + + callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert, + args_opt.save_ckpt_step, + args_opt.max_ckpt_num, + save_ckpt_dir)] + + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value, + scale_factor=common_cfg.scale_factor, + scale_window=common_cfg.scale_window) + + netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + model.train(repeat_count, dataset, callbacks=callback, + dataset_sink_mode=(args_opt.enable_data_sink == "true"), + sink_size=args_opt.data_sink_steps) + +if __name__ == '__main__': + numpy.random.seed(0) + run_general_distill() diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py new file mode 100644 index 0000000000..b09c418d03 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/run_task_distill.py @@ -0,0 +1,249 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""task distill script""" + +import os +import re +import argparse +from mindspore import Tensor +from mindspore import context +from mindspore.train.model import Model +from mindspore.train.callback import TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.nn.optim import AdamWeightDecay +from src.dataset import create_tinybert_dataset +from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate +from src.assessment_method import Accuracy +from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg +from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td +from src.tinybert_model import BertModelCLS + +_cur_dir = os.getcwd() +td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt') +td_phase2_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase2_save_ckpt') +if not os.path.exists(td_phase1_save_ckpt_dir): + os.makedirs(td_phase1_save_ckpt_dir) +if not os.path.exists(td_phase2_save_ckpt_dir): + os.makedirs(td_phase2_save_ckpt_dir) + +def parse_args(): + """ + parse args + """ + parser = argparse.ArgumentParser(description='tinybert task distill') + parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.") + parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.") + parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.") + parser.add_argument("--td_phase1_epoch_size", type=int, default=10, + help="Epoch size for td phase 1, default is 10.") + parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.") + parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") + parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") + parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.") + parser.add_argument("--max_ckpt_num", type=int, default=1, help="Enable data sink, default is true.") + parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.") + parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_gd_ckpt_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_td1_ckpt_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path") + parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path") + parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") + + args = parser.parse_args() + return args + +args_opt = parse_args() +def run_predistill(): + """ + run predistill + """ + cfg = phase1_cfg + context.set_context(mode=context.GRAPH_MODE, 
device_target=args_opt.device_target, device_id=args_opt.device_id) + context.set_context(reserve_class_name_in_scope=False) + load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path + load_student_checkpoint_path = args_opt.load_gd_ckpt_path + netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, + student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, + is_training=True, task_type='classification', + num_labels=args_opt.num_labels, is_predistill=True) + + rank = 0 + device_num = 1 + dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, + device_num, rank, args_opt.do_shuffle, + args_opt.train_data_dir, args_opt.schema_dir) + + dataset_size = dataset.get_dataset_size() + if args_opt.enable_data_sink == 'true': + repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps + else: + repeat_count = args_opt.td_phase1_epoch_size + + optimizer_cfg = cfg.optimizer_cfg + + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(dataset_size / 10), + decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size), + power=optimizer_cfg.AdamWeightDecay.power) + params = netwithloss.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: x not in decay_params, params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}, + {'order_params': params}] + + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) + callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert, + args_opt.save_ckpt_step, + args_opt.max_ckpt_num, + td_phase1_save_ckpt_dir)] + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) + netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + model.train(repeat_count, dataset, callbacks=callback, + dataset_sink_mode=(args_opt.enable_data_sink == 'true'), + sink_size=args_opt.data_sink_steps) + +def run_task_distill(ckpt_file): + """ + run task distill + """ + if ckpt_file == '': + raise ValueError("Student ckpt file should not be None") + cfg = phase2_cfg + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path + load_student_checkpoint_path = ckpt_file + netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, + student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, + is_training=True, task_type='classification', + num_labels=args_opt.num_labels, is_predistill=False) + + rank = 0 + device_num = 1 + train_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, + device_num, rank, args_opt.do_shuffle, + args_opt.train_data_dir, args_opt.schema_dir) + + dataset_size = train_dataset.get_dataset_size() + if args_opt.enable_data_sink == 'true': + repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps + else: + repeat_count = 
args_opt.td_phase2_epoch_size + + optimizer_cfg = cfg.optimizer_cfg + + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(dataset_size * args_opt.td_phase2_epoch_size / 10), + decay_steps=int(dataset_size * args_opt.td_phase2_epoch_size), + power=optimizer_cfg.AdamWeightDecay.power) + params = netwithloss.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: x not in decay_params, params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}, + {'order_params': params}] + + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) + + eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, + device_num, rank, args_opt.do_shuffle, + args_opt.eval_data_dir, args_opt.schema_dir) + if args_opt.do_eval.lower() == "true": + callback = [TimeMonitor(dataset_size), LossCallBack(), + ModelSaveCkpt(netwithloss.bert, + args_opt.save_ckpt_step, + args_opt.max_ckpt_num, + td_phase2_save_ckpt_dir), + EvalCallBack(netwithloss.bert, eval_dataset)] + else: + callback = [TimeMonitor(dataset_size), LossCallBack(), + ModelSaveCkpt(netwithloss.bert, + args_opt.save_ckpt_step, + args_opt.max_ckpt_num, + td_phase2_save_ckpt_dir)] + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) + + netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + model.train(repeat_count, train_dataset, callbacks=callback, + dataset_sink_mode=(args_opt.enable_data_sink == 'true'), + sink_size=args_opt.data_sink_steps) + +def do_eval_standalone(): + """ + do eval standalone + """ + ckpt_file = args_opt.load_td1_ckpt_path + if ckpt_file == '': + raise ValueError("Student ckpt file should not be None") + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student") + param_dict = load_checkpoint(ckpt_file) + new_param_dict = {} + for key, value in param_dict.items(): + new_key = re.sub('tinybert_', 'bert_', key) + new_key = re.sub('^bert.', '', new_key) + new_param_dict[new_key] = value + load_param_into_net(eval_model, new_param_dict) + eval_model.set_train(False) + + eval_dataset = create_tinybert_dataset('td', batch_size=1, + device_num=1, rank=0, do_shuffle="false", + data_dir=args_opt.eval_data_dir, + schema_dir=args_opt.schema_dir) + callback = Accuracy() + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + for data in eval_dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, token_type_id, label_ids = input_data + logits = eval_model(input_ids, token_type_id, input_mask) + callback.update(logits[3], label_ids) + acc = callback.acc_num / callback.total_num + print("======================================") + print("============== acc is {}".format(acc)) + print("======================================") + +if __name__ == '__main__': + if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true": + raise 
ValueError("do_train or do eval must have one be true, please confirm your config") + if args_opt.do_train == "true": + # run predistill + run_predistill() + lists = os.listdir(td_phase1_save_ckpt_dir) + if lists: + lists.sort(key=lambda fn: os.path.getmtime(td_phase1_save_ckpt_dir+'/'+fn)) + name_ext = os.path.splitext(lists[-1]) + if name_ext[-1] != ".ckpt": + raise ValueError("Invalid file, checkpoint file should be .ckpt file") + newest_ckpt_file = os.path.join(td_phase1_save_ckpt_dir, lists[-1]) + # run task distill + run_task_distill(newest_ckpt_file) + else: + raise ValueError("Checkpoint file not exists, please make sure ckpt file has been saved") + else: + do_eval_standalone() diff --git a/model_zoo/official/nlp/tinybert/scripts/run_distribute_gd.sh b/model_zoo/official/nlp/tinybert/scripts/run_distribute_gd.sh new file mode 100644 index 0000000000..d45c280723 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/scripts/run_distribute_gd.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_distribute_gd.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH" +echo "for example: bash scripts/run_distribute_gd.sh 8 40 /path/hccl.json" +echo "It is better to use absolute path." +echo "running....... 
please see details by LOG{}/log.txt" +echo "==============================================================================================================" + +EPOCH_SIZE=$2 + +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export MINDSPORE_HCCL_CONFIG_PATH=$3 +export RANK_TABLE_FILE=$3 +export RANK_SIZE=$1 +cores=`cat /proc/cpuinfo|grep "processor" |wc -l` +echo "the number of logical core" $cores +avg_core_per_rank=`expr $cores \/ $RANK_SIZE` +core_gap=`expr $avg_core_per_rank \- 1` +echo "avg_core_per_rank" $avg_core_per_rank +echo "core_gap" $core_gap +for((i=0;i env.log + taskset -c $cmdopt python ${PROJECT_DIR}/../run_general_distill.py \ + --distribute="true" \ + --device_target="Ascend" \ + --epoch_size=$EPOCH_SIZE \ + --device_id=$DEVICE_ID \ + --device_num=$RANK_SIZE \ + --enable_data_sink="true" \ + --data_sink_steps=100 \ + --save_ckpt_step=100 \ + --max_ckpt_num=1 \ + --save_ckpt_path="" \ + --load_teacher_ckpt_path="" \ + --data_dir="" \ + --schema_dir="" > log.txt 2>&1 & + cd ../ +done diff --git a/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd.sh b/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd.sh new file mode 100644 index 0000000000..343d1ed7ca --- /dev/null +++ b/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_standalone_gd.sh" +echo "for example: bash scripts/run_standalone_gd.sh" +echo "running....... please see details by log.txt" +echo "==============================================================================================================" + + +mkdir -p ms_log +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +CUR_DIR=`pwd` +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_general_distill.py \ + --distribute="false" \ + --device_target="Ascend" \ + --epoch_size=3 \ + --device_id=0 \ + --enable_data_sink="true" \ + --data_sink_steps=100 \ + --save_ckpt_step=100 \ + --max_ckpt_num=1 \ + --save_ckpt_path="" \ + --load_teacher_ckpt_path="" \ + --data_dir="" \ + --schema_dir="" > log.txt 2>&1 & diff --git a/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh b/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh new file mode 100644 index 0000000000..dcc01163db --- /dev/null +++ b/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scipts/run_standalone_td.sh" +echo "for example: bash scipts/run_standalone_td.sh" +echo "==============================================================================================================" + +mkdir -p ms_log +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +CUR_DIR=`pwd` +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_task_distill.py \ + --device_target="Ascend" \ + --device_id=0 \ + --do_train="true" \ + --do_eval="true" \ + --td_phase1_epoch_size=10 \ + --td_phase2_epoch_size=3 \ + --num_labels=2 \ + --do_shuffle="true" \ + --enable_data_sink="true" \ + --data_sink_steps=100 \ + --save_ckpt_step=100 \ + --max_ckpt_num=1 \ + --load_teacher_ckpt_path="" \ + --load_gd_ckpt_path="" \ + --load_td1_ckpt_path="" \ + --train_data_dir="" \ + --eval_data_dir="" \ + --schema_dir="" > log.txt 2>&1 & + diff --git a/model_zoo/official/nlp/tinybert/src/__init__.py b/model_zoo/official/nlp/tinybert/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/nlp/tinybert/src/assessment_method.py b/model_zoo/official/nlp/tinybert/src/assessment_method.py new file mode 100644 index 0000000000..748666e3ce --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/assessment_method.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""assessment methods""" + +import numpy as np + +class Accuracy(): + """Accuracy""" + def __init__(self): + self.acc_num = 0 + self.total_num = 0 + + def update(self, logits, labels): + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + self.acc_num += np.sum(labels == logit_id) + self.total_num += len(labels) + +class F1(): + """F1""" + def __init__(self): + self.TP = 0 + self.FP = 0 + self.FN = 0 + + def update(self, logits, labels): + """Update F1 score""" + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + logit_id = np.reshape(logit_id, -1) + pos_eva = np.isin(logit_id, [2, 3, 4, 5, 6, 7]) + pos_label = np.isin(labels, [2, 3, 4, 5, 6, 7]) + self.TP += np.sum(pos_eva & pos_label) + self.FP += np.sum(pos_eva & (~pos_label)) + self.FN += np.sum((~pos_eva) & pos_label) + print("-----------------precision is ", self.TP / (self.TP + self.FP)) + print("-----------------recall is ", self.TP / (self.TP + self.FN)) diff --git a/model_zoo/official/nlp/tinybert/src/dataset.py b/model_zoo/official/nlp/tinybert/src/dataset.py new file mode 100644 index 0000000000..576d2ee6d9 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/dataset.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""create tinybert dataset""" + +import os +import mindspore.common.dtype as mstype +import mindspore.dataset.engine.datasets as de +import mindspore.dataset.transforms.c_transforms as C +from mindspore import log as logger + +def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, + do_shuffle="true", data_dir=None, schema_dir=None): + """create tinybert dataset""" + files = os.listdir(data_dir) + data_files = [] + for file_name in files: + if "record" in file_name: + data_files.append(os.path.join(data_dir, file_name)) + if task == "td": + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + else: + columns_list = ["input_ids", "input_mask", "segment_ids"] + + ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, + shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, + shard_equal_rows=True) + + ori_dataset_size = ds.get_dataset_size() + print('origin dataset size: ', ori_dataset_size) + type_cast_op = C.TypeCast(mstype.int32) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + if task == "td": + ds = ds.map(input_columns="label_ids", operations=type_cast_op) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + logger.info("data size: {}".format(ds.get_dataset_size())) + logger.info("repeatcount: {}".format(ds.get_repeat_count())) + + return ds diff --git a/model_zoo/official/nlp/tinybert/src/fused_layer_norm.py b/model_zoo/official/nlp/tinybert/src/fused_layer_norm.py new file mode 100644 index 0000000000..d290842c58 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/fused_layer_norm.py @@ -0,0 +1,122 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""fused layernorm""" +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer +from mindspore.ops.primitive import constexpr +import mindspore.common.dtype as mstype +from mindspore.nn.cell import Cell + +import numpy as np + + +__all__ = ['FusedLayerNorm'] + +@constexpr +def get_shape_for_norm(x_shape, begin_norm_axis): + print("input_shape: ", x_shape) + norm_shape = x_shape[begin_norm_axis:] + output_shape = (1, -1, 1, int(np.prod(norm_shape))) + print("output_shape: ", output_shape) + return output_shape + +class FusedLayerNorm(Cell): + r""" + Applies Layer Normalization over a mini-batch of inputs. + + Layer normalization is widely used in recurrent neural networks. It applies + normalization over a mini-batch of inputs for each single training case as described + in the paper `Layer Normalization `_. 
Unlike batch
+    normalization, layer normalization performs exactly the same computation at training and
+    testing time. It can be described by the formula below and is applied per sample across
+    all channels and pixels.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Args:
+        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
+            `begin_norm_axis ... R - 1`.
+        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
+            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
+        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
+            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
+            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        use_batch_norm (bool): Whether to use batchnorm instead of layernorm during training. Default: False.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
+          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
+
+    Outputs:
+        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
+
+    Examples:
+        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
+        >>> shape1 = x.shape[1:]
+        >>> m = FusedLayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
+        >>> m(x)
+    """
+    def __init__(self,
+                 normalized_shape,
+                 begin_norm_axis=-1,
+                 begin_params_axis=-1,
+                 gamma_init='ones',
+                 beta_init='zeros',
+                 use_batch_norm=False):
+        super(FusedLayerNorm, self).__init__()
+        if not isinstance(normalized_shape, (tuple, list)):
+            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
+ .format(normalized_shape, type(normalized_shape))) + self.normalized_shape = normalized_shape + self.begin_norm_axis = begin_norm_axis + self.begin_params_axis = begin_params_axis + self.gamma = Parameter(initializer( + gamma_init, normalized_shape), name="gamma") + self.beta = Parameter(initializer( + beta_init, normalized_shape), name="beta") + self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis) + + self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5) + self.use_batch_norm = use_batch_norm + + def construct(self, input_x): + """fusedlayernorm""" + if self.use_batch_norm and self.training: + ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0) + zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0) + shape_x = F.shape(input_x) + norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis) + input_x = F.reshape(input_x, norm_shape) + output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None) + output = F.reshape(output, shape_x) + y = output * self.gamma + self.beta + else: + y, _, _ = self.layer_norm(input_x, self.gamma, self.beta) + return y + + def extend_repr(self): + """Display instance object as string.""" + s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format( + self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta) + return s diff --git a/model_zoo/official/nlp/tinybert/src/gd_config.py b/model_zoo/official/nlp/tinybert/src/gd_config.py new file mode 100644 index 0000000000..d2dc09d8fa --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/gd_config.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in dataset.py, run_general_distill.py and run_task_distill.py +""" +import mindspore.common.dtype as mstype +from easydict import EasyDict as edict +from .tinybert_model import BertConfig + +common_cfg = edict({ + 'loss_scale_value': 2 ** 16, + 'scale_factor': 2, + 'scale_window': 1000, + 'AdamWeightDecay': edict({ + 'learning_rate': 5e-5, + 'end_learning_rate': 1e-14, + 'power': 1.0, + 'weight_decay': 1e-4, + 'eps': 1e-6, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + }), +}) +''' +Including two kinds of network: \ +teacher network: The BERT-base network. +student network: The network which is inherited from teacher network. 
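Both configs keep the same batch size, sequence length, vocabulary and number of attention
heads; the student only shrinks hidden_size (768 -> 384), intermediate_size (3072 -> 1536)
and num_hidden_layers (12 -> 4).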
+''' +bert_teacher_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=False +) +bert_student_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=30522, + hidden_size=384, + num_hidden_layers=4, + num_attention_heads=12, + intermediate_size=1536, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=False +) diff --git a/model_zoo/official/nlp/tinybert/src/td_config.py b/model_zoo/official/nlp/tinybert/src/td_config.py new file mode 100644 index 0000000000..2a9046587e --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/td_config.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""config script for task distill""" + +import mindspore.common.dtype as mstype +from easydict import EasyDict as edict +from .tinybert_model import BertConfig + +phase1_cfg = edict({ + 'loss_scale_value': 2 ** 8, + 'scale_factor': 2, + 'scale_window': 50, + 'optimizer_cfg': edict({ + 'AdamWeightDecay': edict({ + 'learning_rate': 5e-5, + 'end_learning_rate': 1e-14, + 'power': 1.0, + 'weight_decay': 1e-4, + 'eps': 1e-6, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + }), + }), +}) + +phase2_cfg = edict({ + 'loss_scale_value': 2 ** 16, + 'scale_factor': 2, + 'scale_window': 50, + 'optimizer_cfg': edict({ + 'AdamWeightDecay': edict({ + 'learning_rate': 2e-5, + 'end_learning_rate': 1e-14, + 'power': 1.0, + 'weight_decay': 1e-4, + 'eps': 1e-6, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + }), + }), +}) + +''' +Including two kinds of network: \ +teacher network: The BERT-base network with finetune. +student network: The model which is producted by GD phase. 
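Phase 1 (pre-distill) and phase 2 share these network configs; they differ only in the
optimizer settings above, e.g. learning rate 5e-5 vs. 2e-5 and initial loss scale 2^8 vs. 2^16.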
+''' +td_teacher_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=False +) +td_student_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=30522, + hidden_size=384, + num_hidden_layers=4, + num_attention_heads=12, + intermediate_size=1536, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=False +) diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py new file mode 100644 index 0000000000..55da0f3db9 --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py @@ -0,0 +1,498 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tinybert model""" + +import re +import mindspore.nn as nn +from mindspore import context +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter +from mindspore.communication.management import get_group_size +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.train.parallel_utils import ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from .tinybert_model import BertModel, TinyBertModel, BertModelCLS + + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 + +clip_grad = C.MultitypeFuncGraph("clip_grad") +# pylint: disable=consider-using-in +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor], clipped gradients. 
+ """ + if clip_type != 0 and clip_type != 1: + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + +class ClipGradients(nn.Cell): + """ + Clip gradients. + + Args: + grads (list): List of gradient tuples. + clip_type (Tensor): The way to clip, 'value' or 'norm'. + clip_value (Tensor): Specifies how much to clip. + + Returns: + List, a list of clipped_grad tuples. + """ + def __init__(self): + super(ClipGradients, self).__init__() + self.clip_by_norm = nn.ClipByNorm() + self.cast = P.Cast() + self.dtype = P.DType() + + def construct(self, + grads, + clip_type, + clip_value): + """clip gradients""" + if clip_type != 0 and clip_type != 1: + return grads + new_grads = () + for grad in grads: + dt = self.dtype(grad) + if clip_type == 0: + t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt), + self.cast(F.tuple_to_array((clip_value,)), dt)) + else: + t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt)) + new_grads = new_grads + (t,) + return new_grads + +class SoftCrossEntropy(nn.Cell): + """SoftCrossEntropy loss""" + def __init__(self): + super(SoftCrossEntropy, self).__init__() + self.log_softmax = P.LogSoftmax(axis=-1) + self.softmax = P.Softmax(axis=-1) + self.reduce_mean = P.ReduceMean() + self.cast = P.Cast() + + def construct(self, predicts, targets): + likelihood = self.log_softmax(predicts) + target_prob = self.softmax(targets) + loss = self.reduce_mean(-target_prob * likelihood) + + return self.cast(loss, mstype.float32) + +class BertNetworkWithLoss_gd(nn.Cell): + """ + Provide bert pre-training loss through network. + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + Returns: + Tensor, the loss of the network. 
+ """ + def __init__(self, teacher_config, teacher_ckpt, student_config, is_training, use_one_hot_embeddings=False, + is_att_fit=True, is_rep_fit=True): + super(BertNetworkWithLoss_gd, self).__init__() + # load teacher model + self.teacher = BertModel(teacher_config, False, use_one_hot_embeddings) + param_dict = load_checkpoint(teacher_ckpt) + new_param_dict = {} + for key, value in param_dict.items(): + new_key = re.sub('^bert.bert.', 'teacher.', key) + new_param_dict[new_key] = value + load_param_into_net(self.teacher, new_param_dict) + # no_grad + self.teacher.set_train(False) + params = self.teacher.trainable_params() + for param in params: + param.requires_grad = False + # student model + self.bert = TinyBertModel(student_config, is_training, use_one_hot_embeddings) + self.cast = P.Cast() + self.fit_dense = nn.Dense(student_config.hidden_size, + teacher_config.hidden_size).to_float(teacher_config.compute_type) + self.teacher_layers_num = teacher_config.num_hidden_layers + self.student_layers_num = student_config.num_hidden_layers + self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num) + self.is_att_fit = is_att_fit + self.is_rep_fit = is_rep_fit + self.loss_mse = nn.MSELoss() + self.select = P.Select() + self.zeroslike = P.ZerosLike() + self.dtype = teacher_config.dtype + + def construct(self, + input_ids, + input_mask, + token_type_id): + """general distill network with loss""" + # teacher model + _, _, _, teacher_seq_output, teacher_att_output = self.teacher(input_ids, token_type_id, input_mask) + # student model + _, _, _, student_seq_output, student_att_output = self.bert(input_ids, token_type_id, input_mask) + total_loss = 0 + if self.is_att_fit: + selected_teacher_att_output = () + selected_student_att_output = () + for i in range(self.student_layers_num): + selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],) + selected_student_att_output += (student_att_output[i],) + att_loss = 0 + for i in range(self.student_layers_num): + student_att = selected_student_att_output[i] + teacher_att = selected_teacher_att_output[i] + student_att = self.select(student_att <= self.cast(-100.0, mstype.float32), self.zeroslike(student_att), + student_att) + teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32), self.zeroslike(teacher_att), + teacher_att) + att_loss += self.loss_mse(student_att, teacher_att) + total_loss += att_loss + if self.is_rep_fit: + selected_teacher_seq_output = () + selected_student_seq_output = () + for i in range(self.student_layers_num + 1): + selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],) + fit_dense_out = self.fit_dense(student_seq_output[i]) + fit_dense_out = self.cast(fit_dense_out, self.dtype) + selected_student_seq_output += (fit_dense_out,) + rep_loss = 0 + for i in range(self.student_layers_num + 1): + teacher_rep = selected_teacher_seq_output[i] + student_rep = selected_student_seq_output[i] + rep_loss += self.loss_mse(student_rep, teacher_rep) + total_loss += rep_loss + return self.cast(total_loss, mstype.float32) + +class BertTrainWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. 
Default: None. + """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) + +class BertNetworkWithLoss_td(nn.Cell): + """ + Provide bert pre-training loss through network. + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + Returns: + Tensor, the loss of the network. 
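+        Note: with is_predistill=True the loss matches attention maps and projected hidden
+        states against the fine-tuned teacher; otherwise the student logits are fitted to the
+        teacher logits with the SoftCrossEntropy loss defined above.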
+ """ + def __init__(self, teacher_config, teacher_ckpt, student_config, student_ckpt, + is_training, task_type, num_labels, use_one_hot_embeddings=False, + is_predistill=True, is_att_fit=True, is_rep_fit=True, + temperature=1.0, dropout_prob=0.1): + super(BertNetworkWithLoss_td, self).__init__() + # load teacher model + self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob, + use_one_hot_embeddings, "teacher") + param_dict = load_checkpoint(teacher_ckpt) + new_param_dict = {} + for key, value in param_dict.items(): + new_key = re.sub('^bert.', 'teacher.', key) + new_param_dict[new_key] = value + load_param_into_net(self.teacher, new_param_dict) + + # no_grad + self.teacher.set_train(False) + params = self.teacher.trainable_params() + for param in params: + param.requires_grad = False + # load student model + self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob, + use_one_hot_embeddings, "student") + param_dict = load_checkpoint(student_ckpt) + if is_predistill: + new_param_dict = {} + for key, value in param_dict.items(): + # new_key = re.sub('tinybert_', 'bert_', key) + new_key = re.sub('tinybert_', 'bert_', 'bert.' + key) + new_param_dict[new_key] = value + load_param_into_net(self.bert, new_param_dict) + else: + new_param_dict = {} + for key, value in param_dict.items(): + new_key = re.sub('tinybert_', 'bert_', key) + # new_key = re.sub('tinybert_', 'bert_', 'bert.'+ key) + new_param_dict[new_key] = value + load_param_into_net(self.bert, new_param_dict) + self.cast = P.Cast() + self.fit_dense = nn.Dense(student_config.hidden_size, + teacher_config.hidden_size).to_float(teacher_config.compute_type) + self.teacher_layers_num = teacher_config.num_hidden_layers + self.student_layers_num = student_config.num_hidden_layers + self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num) + self.is_predistill = is_predistill + self.is_att_fit = is_att_fit + self.is_rep_fit = is_rep_fit + self.task_type = task_type + self.temperature = temperature + self.loss_mse = nn.MSELoss() + self.select = P.Select() + self.zeroslike = P.ZerosLike() + self.dtype = student_config.dtype + self.num_labels = num_labels + self.dtype = teacher_config.dtype + self.soft_cross_entropy = SoftCrossEntropy() + + def construct(self, + input_ids, + input_mask, + token_type_id, + label_ids): + """task distill network with loss""" + # teacher model + teacher_seq_output, teacher_att_output, teacher_logits, _ = self.teacher(input_ids, token_type_id, input_mask) + # student model + student_seq_output, student_att_output, student_logits, _ = self.bert(input_ids, token_type_id, input_mask) + total_loss = 0 + if self.is_predistill: + if self.is_att_fit: + selected_teacher_att_output = () + selected_student_att_output = () + for i in range(self.student_layers_num): + selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],) + selected_student_att_output += (student_att_output[i],) + att_loss = 0 + for i in range(self.student_layers_num): + student_att = selected_student_att_output[i] + teacher_att = selected_teacher_att_output[i] + student_att = self.select(student_att <= self.cast(-100.0, mstype.float32), + self.zeroslike(student_att), + student_att) + teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32), + self.zeroslike(teacher_att), + teacher_att) + att_loss += self.loss_mse(student_att, teacher_att) + total_loss += att_loss + if self.is_rep_fit: + selected_teacher_seq_output = () + selected_student_seq_output 
= () + for i in range(self.student_layers_num + 1): + selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],) + fit_dense_out = self.fit_dense(student_seq_output[i]) + fit_dense_out = self.cast(fit_dense_out, self.dtype) + selected_student_seq_output += (fit_dense_out,) + rep_loss = 0 + for i in range(self.student_layers_num + 1): + teacher_rep = selected_teacher_seq_output[i] + student_rep = selected_student_seq_output[i] + rep_loss += self.loss_mse(student_rep, teacher_rep) + total_loss += rep_loss + else: + if self.task_type == "classification": + cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature) + else: + cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1]) + total_loss += cls_loss + return self.cast(total_loss, mstype.float32) + +class BertEvaluationCell(nn.Cell): + """ + Especifically defined for finetuning where only four inputs tensor are needed. + """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertEvaluationCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + label_ids, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + label_ids) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + label_ids, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = 
cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_model.py b/model_zoo/official/nlp/tinybert/src/tinybert_model.py new file mode 100644 index 0000000000..cc5477bc4f --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/tinybert_model.py @@ -0,0 +1,1054 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert model.""" +import math +import copy +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from .fused_layer_norm import FusedLayerNorm + + +class BertConfig: + """ + Configuration for `BertModel`. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. Default: 128. + vocab_size (int): The shape of each embedding vector. Default: 32000. + hidden_size (int): Size of the bert encoder layers. Default: 768. + num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder + cell. Default: 12. + num_attention_heads (int): Number of attention heads in the BertTransformer + encoder cell. Default: 12. + intermediate_size (int): Size of intermediate layer in the BertTransformer + encoder cell. Default: 3072. + hidden_act (str): Activation function used in the BertTransformer encoder + cell. Default: "gelu". + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + type_vocab_size (int): Size of token type vocab. Default: 16. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from + dataset. Default: True. + token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded + from dataset. Default: True. + dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. 
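+        enable_fused_layernorm (bool): Specifies whether to use the fused (batch-norm based) LayerNorm
+            implementation to improve performance. Default: False.
+
+    Examples:
+        >>> # Illustrative student configuration only; the actual settings are read from
+        >>> # src/gd_config.py and src/td_config.py.
+        >>> student_config = BertConfig(batch_size=32, seq_length=128, vocab_size=30522,
+        ...                             hidden_size=384, num_hidden_layers=4, num_attention_heads=12,
+        ...                             intermediate_size=1536, compute_type=mstype.float16)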
+ """ + def __init__(self, + batch_size, + seq_length=128, + vocab_size=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float32, + enable_fused_layernorm=False): + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.input_mask_from_dataset = input_mask_from_dataset + self.token_type_ids_from_dataset = token_type_ids_from_dataset + self.use_relative_positions = use_relative_positions + self.dtype = dtype + self.compute_type = compute_type + self.enable_fused_layernorm = enable_fused_layernorm + + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + def __init__(self, + vocab_size, + embedding_size, + embedding_shape, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [vocab_size, embedding_size]), + name='embedding_table') + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.GatherV2() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + + def construct(self, input_ids): + """embedding lookup""" + extended_ids = self.expand(input_ids, -1) + flat_ids = self.reshape(extended_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul( + one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + output = self.reshape(output_for_reshape, self.shape) + return output, self.embedding_table + + +class EmbeddingPostprocessor(nn.Cell): + """ + Postprocessors apply positional and token type embeddings to word embeddings. + + Args: + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_token_type (bool): Specifies whether to use token type embeddings. Default: False. 
+ token_type_vocab_size (int): Size of token type vocab. Default: 16. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + dropout_prob (float): The dropout probability. Default: 0.1. + """ + def __init__(self, + use_relative_positions, + embedding_size, + embedding_shape, + use_token_type=False, + token_type_vocab_size=16, + use_one_hot_embeddings=False, + initializer_range=0.02, + max_position_embeddings=512, + dropout_prob=0.1): + super(EmbeddingPostprocessor, self).__init__() + self.use_token_type = use_token_type + self.token_type_vocab_size = token_type_vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.max_position_embeddings = max_position_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [token_type_vocab_size, + embedding_size]), + name='embedding_table') + self.shape_flat = (-1,) + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.1, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + self.layernorm = nn.LayerNorm((embedding_size,)) + self.dropout = nn.Dropout(1 - dropout_prob) + self.gather = P.GatherV2() + self.use_relative_positions = use_relative_positions + self.slice = P.StridedSlice() + self.full_position_embeddings = Parameter(initializer + (TruncatedNormal(initializer_range), + [max_position_embeddings, + embedding_size]), + name='full_position_embeddings') + + def construct(self, token_type_ids, word_embeddings): + """embedding postprocessor""" + output = word_embeddings + if self.use_token_type: + flat_ids = self.reshape(token_type_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, + self.token_type_vocab_size, self.on_value, self.off_value) + token_type_embeddings = self.array_mul(one_hot_ids, + self.embedding_table) + else: + token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0) + token_type_embeddings = self.reshape(token_type_embeddings, self.shape) + output += token_type_embeddings + if not self.use_relative_positions: + _, seq, width = self.shape + position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1)) + position_embeddings = self.reshape(position_embeddings, (1, seq, width)) + output += position_embeddings + output = self.layernorm(output) + output = self.dropout(output) + return output + + +class BertOutput(nn.Cell): + """ + Apply a linear computation to hidden status and a residual computation to input. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + dropout_prob (float): The dropout probability. Default: 0.1. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. 
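+        enable_fused_layernorm (bool): Specifies whether FusedLayerNorm uses its batch-norm based
+            implementation; only takes effect when compute_type is mstype.float16. Default: False.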
+ """ + def __init__(self, + in_channels, + out_channels, + initializer_range=0.02, + dropout_prob=0.1, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertOutput, self).__init__() + self.dense = nn.Dense(in_channels, out_channels, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.dropout = nn.Dropout(1 - dropout_prob) + self.add = P.TensorAdd() + if compute_type == mstype.float16: + self.layernorm = FusedLayerNorm((out_channels,), + use_batch_norm=enable_fused_layernorm).to_float(compute_type) + else: + self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) + + self.cast = P.Cast() + + def construct(self, hidden_status, input_tensor): + """bert output""" + output = self.dense(hidden_status) + output = self.dropout(output) + output = self.add(input_tensor, output) + output = self.layernorm(output) + return output + + +class RelaPosMatrixGenerator(nn.Cell): + """ + Generates matrix of relative positions between inputs. + + Args: + length (int): Length of one dim for the matrix to be generated. + max_relative_position (int): Max value of relative position. + """ + def __init__(self, length, max_relative_position): + super(RelaPosMatrixGenerator, self).__init__() + self._length = length + self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32) + self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32) + self.range_length = -length + 1 + self.tile = P.Tile() + self.range_mat = P.Reshape() + self.sub = P.Sub() + self.expanddims = P.ExpandDims() + self.cast = P.Cast() + + def construct(self): + """position matrix generator""" + range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32) + range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1)) + tile_row_out = self.tile(range_vec_row_out, (self._length,)) + tile_col_out = self.tile(range_vec_col_out, (1, self._length)) + range_mat_out = self.range_mat(tile_row_out, (self._length, self._length)) + transpose_out = self.range_mat(tile_col_out, (self._length, self._length)) + distance_mat = self.sub(range_mat_out, transpose_out) + distance_mat_clipped = C.clip_by_value(distance_mat, + self._min_relative_position, + self._max_relative_position) + # Shift values to be >=0. Each integer still uniquely identifies a + # relative position difference. + final_mat = distance_mat_clipped + self._max_relative_position + return final_mat + + +class RelaPosEmbeddingsGenerator(nn.Cell): + """ + Generates tensor of size [length, length, depth]. + + Args: + length (int): Length of one dim for the matrix to be generated. + depth (int): Size of each attention head. + max_relative_position (int): Maxmum value of relative position. + initializer_range (float): Initialization value of TruncatedNormal. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. 
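+
+    Note:
+        Relative distances are clipped to [-max_relative_position, max_relative_position], so the
+        embedding table has 2 * max_relative_position + 1 rows. With max_relative_position=16, as used
+        by BertAttention, distances are clipped to [-16, 16] and the table has 33 rows.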
+ """ + def __init__(self, + length, + depth, + max_relative_position, + initializer_range, + use_one_hot_embeddings=False): + super(RelaPosEmbeddingsGenerator, self).__init__() + self.depth = depth + self.vocab_size = max_relative_position * 2 + 1 + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embeddings_table = Parameter( + initializer(TruncatedNormal(initializer_range), + [self.vocab_size, self.depth]), + name='embeddings_for_position') + self.relative_positions_matrix = RelaPosMatrixGenerator(length=length, + max_relative_position=max_relative_position) + self.reshape = P.Reshape() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.shape = P.Shape() + self.gather = P.GatherV2() # index_select + self.matmul = P.BatchMatMul() + + def construct(self): + """position embedding generation""" + relative_positions_matrix_out = self.relative_positions_matrix() + # Generate embedding for each relative position of dimension depth. + if self.use_one_hot_embeddings: + flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) + one_hot_relative_positions_matrix = self.one_hot( + flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value) + embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) + my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) + embeddings = self.reshape(embeddings, my_shape) + else: + embeddings = self.gather(self.embeddings_table, + relative_positions_matrix_out, 0) + return embeddings + + +class SaturateCast(nn.Cell): + """ + Performs a safe saturating cast. This operation applies proper clamping before casting to prevent + the danger that the value will overflow or underflow. + + Args: + src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. + dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. + """ + def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + super(SaturateCast, self).__init__() + np_type = mstype.dtype_to_nptype(dst_type) + min_type = np.finfo(np_type).min + max_type = np.finfo(np_type).max + self.tensor_min_type = Tensor([min_type], dtype=src_type) + self.tensor_max_type = Tensor([max_type], dtype=src_type) + self.min_op = P.Minimum() + self.max_op = P.Maximum() + self.cast = P.Cast() + self.dst_type = dst_type + + def construct(self, x): + """saturate cast""" + out = self.max_op(x, self.tensor_min_type) + out = self.min_op(out, self.tensor_max_type) + return self.cast(out, self.dst_type) + + +class BertAttention(nn.Cell): + """ + Apply multi-headed attention from "from_tensor" to "to_tensor". + + Args: + batch_size (int): Batch size of input datasets. + from_tensor_width (int): Size of last dim of from_tensor. + to_tensor_width (int): Size of last dim of to_tensor. + from_seq_length (int): Length of from_tensor sequence. + to_seq_length (int): Length of to_tensor sequence. + num_attention_heads (int): Number of attention heads. Default: 1. + size_per_head (int): Size of each attention head. Default: 512. + query_act (str): Activation function for the query transform. Default: None. + key_act (str): Activation function for the key transform. Default: None. + value_act (str): Activation function for the value transform. Default: None. + has_attention_mask (bool): Specifies whether to use attention mask. Default: False. 
+ attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.0. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d + tensor. Default: False. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + from_tensor_width, + to_tensor_width, + from_seq_length, + to_seq_length, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + has_attention_mask=False, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + do_return_2d_tensor=False, + use_relative_positions=False, + compute_type=mstype.float32): + super(BertAttention, self).__init__() + self.batch_size = batch_size + self.from_seq_length = from_seq_length + self.to_seq_length = to_seq_length + self.num_attention_heads = num_attention_heads + self.size_per_head = size_per_head + self.has_attention_mask = has_attention_mask + self.use_relative_positions = use_relative_positions + self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) + self.reshape = P.Reshape() + self.shape_from_2d = (-1, from_tensor_width) + self.shape_to_2d = (-1, to_tensor_width) + weight = TruncatedNormal(initializer_range) + units = num_attention_heads * size_per_head + self.query_layer = nn.Dense(from_tensor_width, + units, + activation=query_act, + weight_init=weight).to_float(compute_type) + self.key_layer = nn.Dense(to_tensor_width, + units, + activation=key_act, + weight_init=weight).to_float(compute_type) + self.value_layer = nn.Dense(to_tensor_width, + units, + activation=value_act, + weight_init=weight).to_float(compute_type) + self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) + self.shape_to = ( + batch_size, to_seq_length, num_attention_heads, size_per_head) + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = Tensor([-10000.0,], dtype=compute_type) + self.batch_num = batch_size * num_attention_heads + self.matmul = P.BatchMatMul() + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) + if self.has_attention_mask: + self.expand_dims = P.ExpandDims() + self.sub = P.Sub() + self.add = P.TensorAdd() + self.cast = P.Cast() + self.get_dtype = P.DType() + if do_return_2d_tensor: + self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) + else: + self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) + self.cast_compute_type = SaturateCast(dst_type=compute_type) + if self.use_relative_positions: + self._generate_relative_positions_embeddings = \ + RelaPosEmbeddingsGenerator(length=to_seq_length, + depth=size_per_head, + max_relative_position=16, + initializer_range=initializer_range, + use_one_hot_embeddings=use_one_hot_embeddings) + + def construct(self, from_tensor, to_tensor, attention_mask): + """bert attention""" + # reshape 2d/3d input tensors to 2d + from_tensor_2d = self.reshape(from_tensor, 
self.shape_from_2d) + to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) + query_out = self.query_layer(from_tensor_2d) + key_out = self.key_layer(to_tensor_2d) + value_out = self.value_layer(to_tensor_2d) + query_layer = self.reshape(query_out, self.shape_from) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, self.shape_to) + key_layer = self.transpose(key_layer, self.trans_shape) + attention_scores = self.matmul_trans_b(query_layer, key_layer) + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_keys' = [F|T, F|T, H] + relations_keys = self._generate_relative_positions_embeddings() + relations_keys = self.cast_compute_type(relations_keys) + # query_layer_t is [F, B, N, H] + query_layer_t = self.transpose(query_layer, self.trans_shape_relative) + # query_layer_r is [F, B * N, H] + query_layer_r = self.reshape(query_layer_t, + (self.from_seq_length, + self.batch_num, + self.size_per_head)) + # key_position_scores is [F, B * N, F|T] + key_position_scores = self.matmul_trans_b(query_layer_r, + relations_keys) + # key_position_scores_r is [F, B, N, F|T] + key_position_scores_r = self.reshape(key_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.from_seq_length)) + # key_position_scores_r_t is [B, N, F, F|T] + key_position_scores_r_t = self.transpose(key_position_scores_r, + self.trans_shape_position) + attention_scores = attention_scores + key_position_scores_r_t + attention_scores = self.multiply(self.scores_mul, attention_scores) + if self.has_attention_mask: + attention_mask = self.expand_dims(attention_mask, 1) + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + attention_probs = self.softmax(attention_scores) + attention_probs = self.dropout(attention_probs) + value_layer = self.reshape(value_out, self.shape_to) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_values' = [F|T, F|T, H] + relations_values = self._generate_relative_positions_embeddings() + relations_values = self.cast_compute_type(relations_values) + # attention_probs_t is [F, B, N, T] + attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative) + # attention_probs_r is [F, B * N, T] + attention_probs_r = self.reshape( + attention_probs_t, + (self.from_seq_length, + self.batch_num, + self.to_seq_length)) + # value_position_scores is [F, B * N, H] + value_position_scores = self.matmul(attention_probs_r, + relations_values) + # value_position_scores_r is [F, B, N, H] + value_position_scores_r = self.reshape(value_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.size_per_head)) + # value_position_scores_r_t is [B, N, F, H] + value_position_scores_r_t = self.transpose(value_position_scores_r, + self.trans_shape_position) + context_layer = context_layer + value_position_scores_r_t + context_layer = self.transpose(context_layer, self.trans_shape) + context_layer = self.reshape(context_layer, self.shape_return) + return context_layer, attention_scores + +class BertSelfAttention(nn.Cell): + """ + Apply self-attention. 
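+    Each attention head works on hidden_size // num_attention_heads channels, so hidden_size must be
+    divisible by num_attention_heads.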
+ + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. + hidden_size (int): Size of the bert encoder layers. + num_attention_heads (int): Number of attention heads. Default: 12. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + seq_length, + hidden_size, + num_attention_heads=12, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError("The hidden size (%d) is not a multiple of the number " + "of attention heads (%d)" % (hidden_size, num_attention_heads)) + self.size_per_head = int(hidden_size / num_attention_heads) + self.attention = BertAttention( + batch_size=batch_size, + from_tensor_width=hidden_size, + to_tensor_width=hidden_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + num_attention_heads=num_attention_heads, + size_per_head=self.size_per_head, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + use_relative_positions=use_relative_positions, + has_attention_mask=True, + do_return_2d_tensor=True, + compute_type=compute_type) + self.output = BertOutput(in_channels=hidden_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + + def construct(self, input_tensor, attention_mask): + """bert self attention""" + input_tensor = self.reshape(input_tensor, self.shape) + attention_output, attention_scores = self.attention(input_tensor, input_tensor, attention_mask) + output = self.output(attention_output, input_tensor) + return output, attention_scores + + +class BertEncoderCell(nn.Cell): + """ + Encoder cells used in BertTransformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the bert encoder layers. Default: 768. + seq_length (int): Length of input sequence. Default: 512. + num_attention_heads (int): Number of attention heads. Default: 12. + intermediate_size (int): Size of intermediate layer. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function. Default: "gelu". 
+ compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + hidden_size=768, + seq_length=512, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.02, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertEncoderCell, self).__init__() + self.attention = BertSelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.intermediate = nn.Dense(in_channels=hidden_size, + out_channels=intermediate_size, + activation=hidden_act, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.output = BertOutput(in_channels=intermediate_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + def construct(self, hidden_states, attention_mask): + """bert encoder cell""" + # self-attention + attention_output, attention_scores = self.attention(hidden_states, attention_mask) + # feed construct + intermediate_output = self.intermediate(attention_output) + # add and normalize + output = self.output(intermediate_output, attention_output) + return output, attention_scores + + +class BertTransformer(nn.Cell): + """ + Multi-layer bert transformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + seq_length (int): Length of input sequence. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + return_all_encoders (bool): Specifies whether to return all encoders. Default: False. 
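+
+    Returns:
+        tuple, (all_encoder_layers, all_encoder_outputs, all_encoder_atts): the encoder layer outputs
+        reshaped to [batch_size, seq_length, hidden_size], the flattened layer outputs starting with the
+        embedding output, and the attention scores of each layer. When return_all_encoders is False only
+        the final layer output is collected.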
+ """ + def __init__(self, + batch_size, + hidden_size, + seq_length, + num_hidden_layers, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + return_all_encoders=False, + enable_fused_layernorm=False): + super(BertTransformer, self).__init__() + self.return_all_encoders = return_all_encoders + layers = [] + for _ in range(num_hidden_layers): + layer = BertEncoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + hidden_act=hidden_act, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + layers.append(layer) + self.layers = nn.CellList(layers) + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + self.out_shape = (batch_size, seq_length, hidden_size) + def construct(self, input_tensor, attention_mask): + """bert transformer""" + prev_output = self.reshape(input_tensor, self.shape) + all_encoder_layers = () + all_encoder_atts = () + all_encoder_outputs = () + all_encoder_outputs += (prev_output,) + for layer_module in self.layers: + layer_output, encoder_att = layer_module(prev_output, attention_mask) + prev_output = layer_output + if self.return_all_encoders: + all_encoder_outputs += (layer_output,) + layer_output = self.reshape(layer_output, self.out_shape) + all_encoder_layers += (layer_output,) + all_encoder_atts += (encoder_att,) + if not self.return_all_encoders: + prev_output = self.reshape(prev_output, self.out_shape) + all_encoder_layers += (prev_output,) + return all_encoder_layers, all_encoder_outputs, all_encoder_atts + + +class CreateAttentionMaskFromInputMask(nn.Cell): + """ + Create attention mask according to input mask. + + Args: + config (Class): Configuration for BertModel. + """ + def __init__(self, config): + super(CreateAttentionMaskFromInputMask, self).__init__() + self.input_mask_from_dataset = config.input_mask_from_dataset + self.input_mask = None + if not self.input_mask_from_dataset: + self.input_mask = initializer( + "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor() + self.cast = P.Cast() + self.reshape = P.Reshape() + self.shape = (config.batch_size, 1, config.seq_length) + self.broadcast_ones = initializer( + "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor() + self.batch_matmul = P.BatchMatMul() + def construct(self, input_mask): + if not self.input_mask_from_dataset: + input_mask = self.input_mask + input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32) + attention_mask = self.batch_matmul(self.broadcast_ones, input_mask) + return attention_mask + +class BertModel(nn.Cell): + """ + Bidirectional Encoder Representations from Transformers. + + Args: + config (Class): Configuration for BertModel. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. 
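+
+    Returns:
+        tuple, (sequence_output, pooled_output, embedding_tables, encoder_outputs, attention_outputs):
+        the output of the last encoder layer, the first-token output pooled through a tanh dense layer,
+        the word embedding table, the embedding output together with every encoder layer output, and the
+        attention scores of every layer.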
+ """ + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(BertModel, self).__init__() + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + self.input_mask_from_dataset = config.input_mask_from_dataset + self.token_type_ids_from_dataset = config.token_type_ids_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.token_type_ids = None + self.last_idx = self.num_hidden_layers - 1 + output_embedding_shape = [self.batch_size, self.seq_length, + self.embedding_size] + if not self.token_type_ids_from_dataset: + self.token_type_ids = initializer( + "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() + self.bert_embedding_lookup = EmbeddingLookup( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range) + self.bert_embedding_postprocessor = EmbeddingPostprocessor( + use_relative_positions=config.use_relative_positions, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_token_type=True, + token_type_vocab_size=config.type_vocab_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + self.bert_encoder = BertTransformer( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + use_relative_positions=config.use_relative_positions, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + return_all_encoders=True, + enable_fused_layernorm=config.enable_fused_layernorm) + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.slice = P.StridedSlice() + self.squeeze_1 = P.Squeeze(axis=1) + self.dense = nn.Dense(self.hidden_size, self.hidden_size, + activation="tanh", + weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + + def construct(self, input_ids, token_type_ids, input_mask): + """bert model""" + # embedding + if not self.token_type_ids_from_dataset: + token_type_ids = self.token_type_ids + word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids) + embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings) + # attention mask [batch_size, seq_length, seq_length] + attention_mask = self._create_attention_mask_from_input_mask(input_mask) + # bert encoder + encoder_output, encoder_layers, layer_atts = self.bert_encoder(self.cast_compute_type(embedding_output), + attention_mask) + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + # pooler + sequence_slice = self.slice(sequence_output, + (0, 0, 0), + (self.batch_size, 1, self.hidden_size), 
+ (1, 1, 1)) + first_token = self.squeeze_1(sequence_slice) + pooled_output = self.dense(first_token) + pooled_output = self.cast(pooled_output, self.dtype) + encoder_outputs = () + for output in encoder_layers: + encoder_outputs += (self.cast(output, self.dtype),) + attention_outputs = () + for output in layer_atts: + attention_outputs += (self.cast(output, self.dtype),) + return sequence_output, pooled_output, embedding_tables, encoder_outputs, attention_outputs + + +class TinyBertModel(nn.Cell): + """ + Bidirectional Encoder Representations from Transformers. + + Args: + config (Class): Configuration for BertModel. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(TinyBertModel, self).__init__() + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + self.input_mask_from_dataset = config.input_mask_from_dataset + self.token_type_ids_from_dataset = config.token_type_ids_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.token_type_ids = None + self.last_idx = self.num_hidden_layers - 1 + output_embedding_shape = [self.batch_size, self.seq_length, + self.embedding_size] + if not self.token_type_ids_from_dataset: + self.token_type_ids = initializer( + "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() + self.tinybert_embedding_lookup = EmbeddingLookup( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range) + self.tinybert_embedding_postprocessor = EmbeddingPostprocessor( + use_relative_positions=config.use_relative_positions, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_token_type=True, + token_type_vocab_size=config.type_vocab_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + self.tinybert_encoder = BertTransformer( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + use_relative_positions=config.use_relative_positions, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + return_all_encoders=True, + enable_fused_layernorm=config.enable_fused_layernorm) + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.slice = P.StridedSlice() + self.squeeze_1 = P.Squeeze(axis=1) + self.dense = nn.Dense(self.hidden_size, self.hidden_size, + activation="tanh", + weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) + self._create_attention_mask_from_input_mask = 
CreateAttentionMaskFromInputMask(config) + + def construct(self, input_ids, token_type_ids, input_mask): + """tiny bert model""" + # embedding + if not self.token_type_ids_from_dataset: + token_type_ids = self.token_type_ids + word_embeddings, embedding_tables = self.tinybert_embedding_lookup(input_ids) + embedding_output = self.tinybert_embedding_postprocessor(token_type_ids, + word_embeddings) + # attention mask [batch_size, seq_length, seq_length] + attention_mask = self._create_attention_mask_from_input_mask(input_mask) + # bert encoder + encoder_output, encoder_layers, layer_atts = self.tinybert_encoder(self.cast_compute_type(embedding_output), + attention_mask) + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + # pooler + sequence_slice = self.slice(sequence_output, + (0, 0, 0), + (self.batch_size, 1, self.hidden_size), + (1, 1, 1)) + first_token = self.squeeze_1(sequence_slice) + pooled_output = self.dense(first_token) + pooled_output = self.cast(pooled_output, self.dtype) + encoder_outputs = () + for output in encoder_layers: + encoder_outputs += (self.cast(output, self.dtype),) + attention_outputs = () + for output in layer_atts: + attention_outputs += (self.cast(output, self.dtype),) + return sequence_output, pooled_output, embedding_tables, encoder_outputs, attention_outputs + + +class BertModelCLS(nn.Cell): + """ + This class is responsible for classification task evaluation, + i.e. XNLI(num_labels=3), LCQMC(num_labels=2), Chnsenti(num_labels=2). + The returned output represents the final logits as the results of log_softmax is propotional to that of softmax. + """ + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, + use_one_hot_embeddings=False, phase_type="teacher"): + super(BertModelCLS, self).__init__() + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cast = P.Cast() + self.weight_init = TruncatedNormal(config.initializer_range) + self.log_softmax = P.LogSoftmax(axis=-1) + self.dtype = config.dtype + self.num_labels = num_labels + self.phase_type = phase_type + if self.phase_type == "teacher": + self.dense = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, + has_bias=True).to_float(config.compute_type) + else: + self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, + has_bias=True).to_float(config.compute_type) + self.dropout = nn.ReLU() + + def construct(self, input_ids, token_type_id, input_mask): + """classification bert model""" + _, pooled_output, _, seq_output, att_output = self.bert(input_ids, token_type_id, input_mask) + cls = self.cast(pooled_output, self.dtype) + cls = self.dropout(cls) + if self.phase_type == "teacher": + logits = self.dense(cls) + else: + logits = self.dense_1(cls) + logits = self.cast(logits, self.dtype) + log_probs = self.log_softmax(logits) + return seq_output, att_output, logits, log_probs diff --git a/model_zoo/official/nlp/tinybert/src/utils.py b/model_zoo/official/nlp/tinybert/src/utils.py new file mode 100644 index 0000000000..d10fb8642e --- /dev/null +++ b/model_zoo/official/nlp/tinybert/src/utils.py @@ -0,0 +1,140 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""tinybert utils""" + +import os +import numpy as np +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.train.callback import Callback +from mindspore.train.serialization import _exec_save_checkpoint +from mindspore.ops import operations as P +from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR +from .assessment_method import Accuracy + +class ModelSaveCkpt(Callback): + """ + Saves checkpoint. + If the loss in NAN or INF terminating training. + Args: + network (Network): The train network for training. + save_ckpt_num (int): The number to save checkpoint, default is 1000. + max_ckpt_num (int): The max checkpoint number, default is 3. + """ + def __init__(self, network, save_ckpt_step, max_ckpt_num, output_dir): + super(ModelSaveCkpt, self).__init__() + self.count = 0 + self.network = network + self.save_ckpt_step = save_ckpt_step + self.max_ckpt_num = max_ckpt_num + self.output_dir = output_dir + + def step_end(self, run_context): + """step end and save ckpt""" + cb_params = run_context.original_args() + if cb_params.cur_step_num % self.save_ckpt_step == 0: + saved_ckpt_num = cb_params.cur_step_num / self.save_ckpt_step + if saved_ckpt_num > self.max_ckpt_num: + oldest_ckpt_index = saved_ckpt_num - self.max_ckpt_num + path = os.path.join(self.output_dir, "tiny_bert_{}_{}.ckpt".format(int(oldest_ckpt_index), + self.save_ckpt_step)) + if os.path.exists(path): + os.remove(path) + _exec_save_checkpoint(self.network, os.path.join(self.output_dir, + "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num), + self.save_ckpt_step))) + +class LossCallBack(Callback): + """ + Monitor the loss in training. + If the loss in NAN or INF terminating training. + Note: + if per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. 
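+
+    Raises:
+        ValueError: If per_print_times is not an int or is less than 0.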
+ """ + def __init__(self, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0") + self._per_print_times = per_print_times + + def step_end(self, run_context): + """step end and print loss""" + cb_params = run_context.original_args() + print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, + cb_params.cur_step_num, + str(cb_params.net_outputs))) + +class EvalCallBack(Callback): + """Evaluation callback""" + def __init__(self, network, dataset): + super(EvalCallBack, self).__init__() + self.network = network + self.global_acc = 0.0 + self.dataset = dataset + + def step_end(self, run_context): + """step end and do evaluation""" + cb_params = run_context.original_args() + if cb_params.cur_step_num % 100 == 0: + callback = Accuracy() + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + for data in self.dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, token_type_id, label_ids = input_data + self.network.set_train(False) + logits = self.network(input_ids, token_type_id, input_mask) + callback.update(logits[3], label_ids) + acc = callback.acc_num / callback.total_num + with open("./eval.log", "a+") as f: + f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(callback.acc_num, callback.total_num, + callback.acc_num / callback.total_num)) + f.write('\n') + + if acc > self.global_acc: + self.global_acc = acc + print("The best acc is {}".format(acc)) + _exec_save_checkpoint(self.network, "eval_model.ckpt") + +class BertLearningRate(LearningRateSchedule): + """ + Warmup-decay learning rate for Bert network. + """ + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_flag = False + if warmup_steps > 0: + self.warmup_flag = True + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + decay_lr = self.decay_lr(global_step) + if self.warmup_flag: + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + else: + lr = decay_lr + return lr
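+
+# A minimal usage sketch of the schedule above (illustrative values only; the real settings come from
+# src/gd_config.py and src/td_config.py): the learning rate warms up linearly for `warmup_steps` steps,
+# then follows PolynomialDecayLR from `learning_rate` down to `end_learning_rate`.
+#
+#   lr_schedule = BertLearningRate(learning_rate=5e-5, end_learning_rate=1e-7,
+#                                  warmup_steps=1000, decay_steps=10000, power=1.0)
+#   optimizer = AdamWeightDecay(network.trainable_params(), learning_rate=lr_schedule)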