commit
9ab94fa076
|
@ -0,0 +1,129 @@
|
|||
# TinyBERT Example
|
||||
## Description
|
||||
This example implements general distillation and task distillation of [BERT-base](https://github.com/google-research/bert) (the base version of the BERT model).
|
||||
|
||||
## Requirements
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
- Download datasets for general distillation and task distillation, such as GLUE.
|
||||
- Prepare a pre-trained BERT model and a BERT model fine-tuned on the specific task, such as a GLUE task.
|
||||
|
||||
## Running the Example
|
||||
### General Distill
|
||||
- Set options in `src/gd_config.py`, including the loss scale, optimizer and network settings.
|
||||
|
||||
- Set options in `scripts/run_standalone_gd.sh`, including the device target, data sink configuration, checkpoint configuration and dataset paths. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about the dataset and the JSON schema file.
|
||||
|
||||
- Run `run_standalone_gd.sh` for non-distributed general distillation of the BERT-base model.
|
||||
|
||||
``` bash
|
||||
bash scripts/run_standalone_gd.sh
|
||||
```
|
||||
- Run `run_distribute_gd.sh` for distributed general distillation of the BERT-base model.
|
||||
|
||||
``` bash
|
||||
bash scripts/run_distribute_gd.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH
|
||||
```
|
||||
|
||||
### Task Distill
|
||||
Task distillation has two phases: pre-distillation (phase 1) and task distillation (phase 2).
|
||||
- Set options in `src/td_config.py`, including the loss scale and optimizer configurations of phase 1 and phase 2, as well as the network configuration.
|
||||
|
||||
- Run `run_standalone_td.sh` for task distillation of the BERT-base model.
|
||||
|
||||
```bash
|
||||
bash scripts/run_standalone_td.sh
|
||||
```
|
||||
|
||||
## Usage
|
||||
### General Distill
|
||||
```
|
||||
usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_TARGET]
|
||||
[--epoch_size N] [--device_id N]
|
||||
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
|
||||
[--save_checkpoint_steps N] [--max_ckpt_num N]
|
||||
[--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH]
|
||||
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]
|
||||
|
||||
options:
|
||||
--distribute whether to run in distributed mode: "true" | "false"
|
||||
--device_target target device to run on, currently only "Ascend" is supported
|
||||
--epoch_size epoch size: N, default is 1
|
||||
--device_id device id: N, default is 0
|
||||
--enable_data_sink enable data sink: "true" | "false", default is "true"
|
||||
--data_sink_steps set data sink steps: N, default is 1
|
||||
--load_teacher_ckpt_path path of teacher checkpoint to load: PATH, default is ""
|
||||
--data_dir path to dataset directory: PATH, default is ""
|
||||
--schema_dir path to schema.json file, PATH, default is ""
|
||||
|
||||
usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_TARGET]
|
||||
[--epoch_size N] [--device_id N] [--device_num N]
|
||||
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
|
||||
[--save_ckpt_steps N] [--max_ckpt_num N]
|
||||
[--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH]
|
||||
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]
|
||||
|
||||
options:
|
||||
--distribute whether to run in distributed mode: "true" | "false"
|
||||
--device_target target device to run on, currently only "Ascend" is supported
|
||||
--epoch_size epoch size: N, default is 1
|
||||
--device_id device id: N, default is 0
|
||||
--device_num number of devices to run the task: N, default is 1
|
||||
--enable_data_sink enable data sink: "true" | "false", default is "true"
|
||||
--data_sink_steps set data sink steps: N, default is 1
|
||||
--load_teacher_ckpt_path path of teacher checkpoint to load: PATH, default is ""
|
||||
--data_dir path to dataset directory: PATH, default is ""
|
||||
--schema_dir path to schema.json file, PATH, default is ""
|
||||
|
||||
```
|
||||
|
||||
## Options and Parameters
|
||||
`gd_config.py` and `td_config.py` contain parameters of the BERT model and options for the optimizer and loss scale.
|
||||
### Options:
|
||||
```
|
||||
Parameters for lossscale:
|
||||
loss_scale_value initial value of loss scale: N, default is 2^8
|
||||
scale_factor factor used to update loss scale: N, default is 2
|
||||
scale_window interval (in steps) between loss scale updates: N, default is 50
|
||||
|
||||
Parameters for task-specific config:
|
||||
load_teacher_ckpt_path teacher checkpoint to load
|
||||
load_student_ckpt_path student checkpoint to load
|
||||
data_dir training data dir
|
||||
eval_data_dir evaluation data dir
|
||||
schema_dir data schema path
|
||||
```
|
||||
|
||||
### Parameters:
|
||||
```
|
||||
Parameters for bert network:
|
||||
batch_size batch size of input dataset: N, default is 16
|
||||
seq_length length of input sequence: N, default is 128
|
||||
vocab_size size of the vocabulary: N, must be consistent with the dataset you use. Default is 30522
|
||||
hidden_size size of bert encoder layers: N
|
||||
num_hidden_layers number of hidden layers: N
|
||||
num_attention_heads number of attention heads: N, default is 12
|
||||
intermediate_size size of intermediate layer: N
|
||||
hidden_act activation function used: ACTIVATION, default is "gelu"
|
||||
hidden_dropout_prob dropout probability for BertOutput: Q
|
||||
attention_probs_dropout_prob dropout probability for BertAttention: Q
|
||||
max_position_embeddings maximum length of sequences: N, default is 512
|
||||
save_ckpt_step interval (in steps) for saving checkpoints: N, default is 100
|
||||
max_ckpt_num maximum number of checkpoints to keep: N, default is 1
|
||||
type_vocab_size size of token type vocab: N, default is 2
|
||||
initializer_range initialization value of TruncatedNormal: Q, default is 0.02
|
||||
use_relative_positions use relative positions or not: True | False, default is False
|
||||
input_mask_from_dataset use the input mask loaded from dataset or not: True | False, default is True
|
||||
token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True
|
||||
dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32
|
||||
compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
|
||||
enable_fused_layernorm use batch normalization instead of layer normalization to improve performance: True | False, default is False
|
||||
|
||||
Parameters for optimizer:
|
||||
optimizer optimizer used in the network: AdamWeightDecay
|
||||
learning_rate value of learning rate: Q
|
||||
end_learning_rate value of end learning rate: Q, must be positive
|
||||
power power of the polynomial learning rate decay: Q
|
||||
weight_decay weight decay: Q
|
||||
eps term added to the denominator to improve numerical stability: Q
|
||||
```
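For reference, below is a minimal sketch of how the loss scale and optimizer options above are consumed; it mirrors the calls made in `run_general_distill.py`. Here `netwithloss` stands for the network-with-loss instance built by the script, and the warmup/decay step counts are placeholder values (the script derives them from the dataset size and epoch size).

```python
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from src.gd_config import common_cfg
from src.utils import BertLearningRate

# Loss scale options -> dynamic loss scale cell
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                         scale_factor=common_cfg.scale_factor,
                                         scale_window=common_cfg.scale_window)

# Optimizer options -> polynomial-decay learning rate schedule and AdamWeightDecay
lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate,
                               end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
                               warmup_steps=1000,   # placeholder; the script uses dataset_size * epoch_size / 10
                               decay_steps=10000,   # placeholder; the script uses dataset_size * epoch_size
                               power=common_cfg.AdamWeightDecay.power)
optimizer = AdamWeightDecay(netwithloss.trainable_params(), learning_rate=lr_schedule,
                            eps=common_cfg.AdamWeightDecay.eps)
```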
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""general distill script"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import datetime
|
||||
import numpy
|
||||
import mindspore.communication.management as D
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import TimeMonitor
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.nn.optim import AdamWeightDecay
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from src.dataset import create_tinybert_dataset
|
||||
from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
|
||||
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
|
||||
from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd
|
||||
|
||||
def run_general_distill():
|
||||
"""
|
||||
run general distill
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='tinybert general distill')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
|
||||
parser.add_argument("--epoch_size", type=int, default="3", help="Epoch size, default is 1.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
||||
parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.")
|
||||
parser.add_argument("--max_ckpt_num", type=int, default=1, help="Enable data sink, default is true.")
|
||||
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
|
||||
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
||||
parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
|
||||
parser.add_argument("--save_ckpt_path", type=str, default="", help="Save checkpoint path")
|
||||
parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
|
||||
context.set_context(reserve_class_name_in_scope=False)
|
||||
context.set_context(variable_memory_max_size="30GB")
|
||||
|
||||
save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
|
||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
|
||||
if not os.path.exists(save_ckpt_dir):
|
||||
os.makedirs(save_ckpt_dir)
|
||||
|
||||
if args_opt.distribute == "true":
|
||||
D.init('hccl')
|
||||
device_num = args_opt.device_num
|
||||
rank = args_opt.device_id % device_num
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
device_num=device_num)
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
|
||||
teacher_ckpt=args_opt.load_teacher_ckpt_path,
|
||||
student_config=bert_student_net_cfg,
|
||||
is_training=True, use_one_hot_embeddings=False)
|
||||
|
||||
dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
|
||||
args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
|
||||
|
||||
dataset_size = dataset.get_dataset_size()
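# Note: with data sink enabled, model.train() is driven by sink iterations rather than epochs;
# repeat_count below is the number of sink iterations (epoch_size * steps_per_epoch / data_sink_steps),
# and each iteration pushes data_sink_steps batches to the device.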
|
||||
|
||||
if args_opt.enable_data_sink == "true":
|
||||
repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
|
||||
else:
|
||||
repeat_count = args_opt.epoch_size
|
||||
|
||||
lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate,
|
||||
end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
|
||||
warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
|
||||
decay_steps=int(dataset_size * args_opt.epoch_size),
|
||||
power=common_cfg.AdamWeightDecay.power)
|
||||
params = netwithloss.trainable_params()
|
||||
decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params))
|
||||
other_params = list(filter(lambda x: x not in decay_params, params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0},
|
||||
{'order_params': params}]
|
||||
|
||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps)
|
||||
|
||||
callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
|
||||
args_opt.save_ckpt_step,
|
||||
args_opt.max_ckpt_num,
|
||||
save_ckpt_dir)]
|
||||
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
|
||||
scale_factor=common_cfg.scale_factor,
|
||||
scale_window=common_cfg.scale_window)
|
||||
|
||||
netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
model = Model(netwithgrads)
|
||||
model.train(repeat_count, dataset, callbacks=callback,
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == "true"),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
if __name__ == '__main__':
|
||||
numpy.random.seed(0)
|
||||
run_general_distill()
|
|
@ -0,0 +1,249 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""task distill script"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.nn.optim import AdamWeightDecay
|
||||
from src.dataset import create_tinybert_dataset
|
||||
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
|
||||
from src.assessment_method import Accuracy
|
||||
from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
|
||||
from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td
|
||||
from src.tinybert_model import BertModelCLS
|
||||
|
||||
_cur_dir = os.getcwd()
|
||||
td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
|
||||
td_phase2_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase2_save_ckpt')
|
||||
if not os.path.exists(td_phase1_save_ckpt_dir):
|
||||
os.makedirs(td_phase1_save_ckpt_dir)
|
||||
if not os.path.exists(td_phase2_save_ckpt_dir):
|
||||
os.makedirs(td_phase2_save_ckpt_dir)
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
parse args
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='tinybert task distill')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.")
|
||||
parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.")
|
||||
parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.")
|
||||
parser.add_argument("--td_phase1_epoch_size", type=int, default=10,
|
||||
help="Epoch size for td phase 1, default is 10.")
|
||||
parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.")
|
||||
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
|
||||
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
||||
parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.")
|
||||
parser.add_argument("--max_ckpt_num", type=int, default=1, help="Enable data sink, default is true.")
|
||||
parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
|
||||
parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--load_gd_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--load_td1_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
args_opt = parse_args()
|
||||
def run_predistill():
|
||||
"""
|
||||
run predistill
|
||||
"""
|
||||
cfg = phase1_cfg
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
context.set_context(reserve_class_name_in_scope=False)
|
||||
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
|
||||
load_student_checkpoint_path = args_opt.load_gd_ckpt_path
|
||||
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
|
||||
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
|
||||
is_training=True, task_type='classification',
|
||||
num_labels=args_opt.num_labels, is_predistill=True)
|
||||
|
||||
rank = 0
|
||||
device_num = 1
|
||||
dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
|
||||
device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.train_data_dir, args_opt.schema_dir)
|
||||
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
if args_opt.enable_data_sink == 'true':
|
||||
repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
|
||||
else:
|
||||
repeat_count = args_opt.td_phase1_epoch_size
|
||||
|
||||
optimizer_cfg = cfg.optimizer_cfg
|
||||
|
||||
lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
|
||||
end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
|
||||
warmup_steps=int(dataset_size / 10),
|
||||
decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size),
|
||||
power=optimizer_cfg.AdamWeightDecay.power)
|
||||
params = netwithloss.trainable_params()
|
||||
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
|
||||
other_params = list(filter(lambda x: x not in decay_params, params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0},
|
||||
{'order_params': params}]
|
||||
|
||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
|
||||
callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
|
||||
args_opt.save_ckpt_step,
|
||||
args_opt.max_ckpt_num,
|
||||
td_phase1_save_ckpt_dir)]
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
||||
scale_factor=cfg.scale_factor,
|
||||
scale_window=cfg.scale_window)
|
||||
netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
model = Model(netwithgrads)
|
||||
model.train(repeat_count, dataset, callbacks=callback,
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
def run_task_distill(ckpt_file):
|
||||
"""
|
||||
run task distill
|
||||
"""
|
||||
if ckpt_file == '':
|
||||
raise ValueError("Student ckpt file should not be None")
|
||||
cfg = phase2_cfg
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
|
||||
load_student_checkpoint_path = ckpt_file
|
||||
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
|
||||
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
|
||||
is_training=True, task_type='classification',
|
||||
num_labels=args_opt.num_labels, is_predistill=False)
|
||||
|
||||
rank = 0
|
||||
device_num = 1
|
||||
train_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
|
||||
device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.train_data_dir, args_opt.schema_dir)
|
||||
|
||||
dataset_size = train_dataset.get_dataset_size()
|
||||
if args_opt.enable_data_sink == 'true':
|
||||
repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
|
||||
else:
|
||||
repeat_count = args_opt.td_phase2_epoch_size
|
||||
|
||||
optimizer_cfg = cfg.optimizer_cfg
|
||||
|
||||
lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
|
||||
end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
|
||||
warmup_steps=int(dataset_size * args_opt.td_phase2_epoch_size / 10),
|
||||
decay_steps=int(dataset_size * args_opt.td_phase2_epoch_size),
|
||||
power=optimizer_cfg.AdamWeightDecay.power)
|
||||
params = netwithloss.trainable_params()
|
||||
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
|
||||
other_params = list(filter(lambda x: x not in decay_params, params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0},
|
||||
{'order_params': params}]
|
||||
|
||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
|
||||
|
||||
eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
|
||||
device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.eval_data_dir, args_opt.schema_dir)
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
callback = [TimeMonitor(dataset_size), LossCallBack(),
|
||||
ModelSaveCkpt(netwithloss.bert,
|
||||
args_opt.save_ckpt_step,
|
||||
args_opt.max_ckpt_num,
|
||||
td_phase2_save_ckpt_dir),
|
||||
EvalCallBack(netwithloss.bert, eval_dataset)]
|
||||
else:
|
||||
callback = [TimeMonitor(dataset_size), LossCallBack(),
|
||||
ModelSaveCkpt(netwithloss.bert,
|
||||
args_opt.save_ckpt_step,
|
||||
args_opt.max_ckpt_num,
|
||||
td_phase2_save_ckpt_dir)]
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
||||
scale_factor=cfg.scale_factor,
|
||||
scale_window=cfg.scale_window)
|
||||
|
||||
netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
model = Model(netwithgrads)
|
||||
model.train(repeat_count, train_dataset, callbacks=callback,
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
def do_eval_standalone():
|
||||
"""
|
||||
do eval standalone
|
||||
"""
|
||||
ckpt_file = args_opt.load_td1_ckpt_path
|
||||
if ckpt_file == '':
|
||||
raise ValueError("Student ckpt file should not be None")
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student")
|
||||
param_dict = load_checkpoint(ckpt_file)
|
||||
new_param_dict = {}
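# The distillation graph saves student parameters under a 'bert.' scope with a 'tinybert_' prefix;
# rename the keys so they match the parameter names of the standalone BertModelCLS eval network.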
|
||||
for key, value in param_dict.items():
|
||||
new_key = re.sub('tinybert_', 'bert_', key)
|
||||
new_key = re.sub('^bert.', '', new_key)
|
||||
new_param_dict[new_key] = value
|
||||
load_param_into_net(eval_model, new_param_dict)
|
||||
eval_model.set_train(False)
|
||||
|
||||
eval_dataset = create_tinybert_dataset('td', batch_size=1,
|
||||
device_num=1, rank=0, do_shuffle="false",
|
||||
data_dir=args_opt.eval_data_dir,
|
||||
schema_dir=args_opt.schema_dir)
|
||||
callback = Accuracy()
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
||||
for data in eval_dataset.create_dict_iterator():
|
||||
input_data = []
|
||||
for i in columns_list:
|
||||
input_data.append(Tensor(data[i]))
|
||||
input_ids, input_mask, token_type_id, label_ids = input_data
|
||||
logits = eval_model(input_ids, token_type_id, input_mask)
|
||||
callback.update(logits[3], label_ids)
|
||||
acc = callback.acc_num / callback.total_num
|
||||
print("======================================")
|
||||
print("============== acc is {}".format(acc))
|
||||
print("======================================")
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
|
||||
raise ValueError("do_train or do eval must have one be true, please confirm your config")
|
||||
if args_opt.do_train == "true":
|
||||
# run predistill
|
||||
run_predistill()
|
||||
lists = os.listdir(td_phase1_save_ckpt_dir)
|
||||
if lists:
|
||||
lists.sort(key=lambda fn: os.path.getmtime(td_phase1_save_ckpt_dir+'/'+fn))
|
||||
name_ext = os.path.splitext(lists[-1])
|
||||
if name_ext[-1] != ".ckpt":
|
||||
raise ValueError("Invalid file, checkpoint file should be .ckpt file")
|
||||
newest_ckpt_file = os.path.join(td_phase1_save_ckpt_dir, lists[-1])
|
||||
# run task distill
|
||||
run_task_distill(newest_ckpt_file)
|
||||
else:
|
||||
raise ValueError("Checkpoint file not exists, please make sure ckpt file has been saved")
|
||||
else:
|
||||
do_eval_standalone()
|
|
@ -0,0 +1,72 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash scripts/run_distribute_gd.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH"
|
||||
echo "for example: bash scripts/run_distribute_gd.sh 8 40 /path/hccl.json"
|
||||
echo "It is better to use absolute path."
|
||||
echo "running....... please see details by LOG{}/log.txt"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
EPOCH_SIZE=$2
|
||||
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$3
|
||||
export RANK_TABLE_FILE=$3
|
||||
export RANK_SIZE=$1
|
||||
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
|
||||
echo "the number of logical core" $cores
|
||||
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
|
||||
core_gap=`expr $avg_core_per_rank \- 1`
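# Each rank is later pinned (via taskset) to a contiguous block of avg_core_per_rank logical cores.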
|
||||
echo "avg_core_per_rank" $avg_core_per_rank
|
||||
echo "core_gap" $core_gap
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
start=`expr $i \* $avg_core_per_rank`
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
export DEPLOY_MODE=0
|
||||
export GE_USE_STATIC_MEMORY=1
|
||||
end=`expr $start \+ $core_gap`
|
||||
cmdopt=$start"-"$end
|
||||
|
||||
rm -rf LOG$i
|
||||
mkdir ./LOG$i
|
||||
cp *.py ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
mkdir -p ms_log
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
env > env.log
|
||||
taskset -c $cmdopt python ${PROJECT_DIR}/../run_general_distill.py \
|
||||
--distribute="true" \
|
||||
--device_target="Ascend" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--device_num=$RANK_SIZE \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=100 \
|
||||
--save_ckpt_step=100 \
|
||||
--max_ckpt_num=1 \
|
||||
--save_ckpt_path="" \
|
||||
--load_teacher_ckpt_path="" \
|
||||
--data_dir="" \
|
||||
--schema_dir="" > log.txt 2>&1 &
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,42 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash scripts/run_standalone_gd.sh"
|
||||
echo "for example: bash scripts/run_standalone_gd.sh"
|
||||
echo "running....... please see details by log.txt"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
|
||||
mkdir -p ms_log
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
python ${PROJECT_DIR}/../run_general_distill.py \
|
||||
--distribute="false" \
|
||||
--device_target="Ascend" \
|
||||
--epoch_size=3 \
|
||||
--device_id=0 \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=100 \
|
||||
--save_ckpt_step=100 \
|
||||
--max_ckpt_num=1 \
|
||||
--save_ckpt_path="" \
|
||||
--load_teacher_ckpt_path="" \
|
||||
--data_dir="" \
|
||||
--schema_dir="" > log.txt 2>&1 &
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash scipts/run_standalone_td.sh"
|
||||
echo "for example: bash scipts/run_standalone_td.sh"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
mkdir -p ms_log
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
python ${PROJECT_DIR}/../run_task_distill.py \
|
||||
--device_target="Ascend" \
|
||||
--device_id=0 \
|
||||
--do_train="true" \
|
||||
--do_eval="true" \
|
||||
--td_phase1_epoch_size=10 \
|
||||
--td_phase2_epoch_size=3 \
|
||||
--num_labels=2 \
|
||||
--do_shuffle="true" \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=100 \
|
||||
--save_ckpt_step=100 \
|
||||
--max_ckpt_num=1 \
|
||||
--load_teacher_ckpt_path="" \
|
||||
--load_gd_ckpt_path="" \
|
||||
--load_td1_ckpt_path="" \
|
||||
--train_data_dir="" \
|
||||
--eval_data_dir="" \
|
||||
--schema_dir="" > log.txt 2>&1 &
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""assessment methods"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
class Accuracy():
|
||||
"""Accuracy"""
|
||||
def __init__(self):
|
||||
self.acc_num = 0
|
||||
self.total_num = 0
|
||||
|
||||
def update(self, logits, labels):
|
||||
labels = labels.asnumpy()
|
||||
labels = np.reshape(labels, -1)
|
||||
logits = logits.asnumpy()
|
||||
logit_id = np.argmax(logits, axis=-1)
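# Count a prediction as correct when the arg-max class matches the label.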
|
||||
self.acc_num += np.sum(labels == logit_id)
|
||||
self.total_num += len(labels)
|
||||
|
||||
class F1():
|
||||
"""F1"""
|
||||
def __init__(self):
|
||||
self.TP = 0
|
||||
self.FP = 0
|
||||
self.FN = 0
|
||||
|
||||
def update(self, logits, labels):
|
||||
"""Update F1 score"""
|
||||
labels = labels.asnumpy()
|
||||
labels = np.reshape(labels, -1)
|
||||
logits = logits.asnumpy()
|
||||
logit_id = np.argmax(logits, axis=-1)
|
||||
logit_id = np.reshape(logit_id, -1)
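# Labels 2-7 are treated as the positive classes when accumulating TP/FP/FN;
# all remaining label ids are treated as negative.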
|
||||
pos_eva = np.isin(logit_id, [2, 3, 4, 5, 6, 7])
|
||||
pos_label = np.isin(labels, [2, 3, 4, 5, 6, 7])
|
||||
self.TP += np.sum(pos_eva & pos_label)
|
||||
self.FP += np.sum(pos_eva & (~pos_label))
|
||||
self.FN += np.sum((~pos_eva) & pos_label)
|
||||
print("-----------------precision is ", self.TP / (self.TP + self.FP))
|
||||
print("-----------------recall is ", self.TP / (self.TP + self.FN))
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""create tinybert dataset"""
|
||||
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine.datasets as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore import log as logger
|
||||
|
||||
def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
|
||||
do_shuffle="true", data_dir=None, schema_dir=None):
|
||||
"""create tinybert dataset"""
|
||||
files = os.listdir(data_dir)
|
||||
data_files = []
|
||||
for file_name in files:
|
||||
if "record" in file_name:
|
||||
data_files.append(os.path.join(data_dir, file_name))
|
||||
if task == "td":
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
||||
else:
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids"]
|
||||
|
||||
ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
|
||||
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
|
||||
shard_equal_rows=True)
|
||||
|
||||
ori_dataset_size = ds.get_dataset_size()
|
||||
print('original dataset size: ', ori_dataset_size)
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
if task == "td":
|
||||
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
|
||||
|
||||
return ds
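# Example usage (sketch; the data_dir and schema_dir paths below are placeholders):
#   ds = create_tinybert_dataset('gd', batch_size=32, device_num=1, rank=0,
#                                do_shuffle="true", data_dir="/path/to/tfrecords",
#                                schema_dir="/path/to/schema.json")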
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""fused layernorm"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops.primitive import constexpr
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.nn.cell import Cell
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
__all__ = ['FusedLayerNorm']
|
||||
|
||||
@constexpr
|
||||
def get_shape_for_norm(x_shape, begin_norm_axis):
|
||||
print("input_shape: ", x_shape)
|
||||
norm_shape = x_shape[begin_norm_axis:]
|
||||
output_shape = (1, -1, 1, int(np.prod(norm_shape)))
|
||||
print("output_shape: ", output_shape)
|
||||
return output_shape
|
||||
|
||||
class FusedLayerNorm(Cell):
|
||||
r"""
|
||||
Applies Layer Normalization over a mini-batch of inputs.
|
||||
|
||||
Layer normalization is widely used in recurrent neural networks. It applies
|
||||
normalization over a mini-batch of inputs for each single training case as described
|
||||
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
|
||||
normalization, layer normalization performs exactly the same computation at training and
|
||||
testing times. It can be described using the following formula. It is applied across all channels
|
||||
and pixels, but within a single sample rather than across the batch.
|
||||
|
||||
.. math::
|
||||
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
||||
|
||||
Args:
|
||||
normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
|
||||
`begin_norm_axis ... R - 1`.
|
||||
begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
|
||||
`begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
|
||||
begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
|
||||
will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
|
||||
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'zeros'.
|
||||
use_batch_norm (bool): Whether to use batch normalization instead of layer normalization. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
|
||||
and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
|
||||
>>> shape1 = x.shape[1:]
|
||||
>>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
|
||||
>>> m(x)
|
||||
"""
|
||||
def __init__(self,
|
||||
normalized_shape,
|
||||
begin_norm_axis=-1,
|
||||
begin_params_axis=-1,
|
||||
gamma_init='ones',
|
||||
beta_init='zeros',
|
||||
use_batch_norm=False):
|
||||
super(FusedLayerNorm, self).__init__()
|
||||
if not isinstance(normalized_shape, (tuple, list)):
|
||||
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
|
||||
.format(normalized_shape, type(normalized_shape)))
|
||||
self.normalized_shape = normalized_shape
|
||||
self.begin_norm_axis = begin_norm_axis
|
||||
self.begin_params_axis = begin_params_axis
|
||||
self.gamma = Parameter(initializer(
|
||||
gamma_init, normalized_shape), name="gamma")
|
||||
self.beta = Parameter(initializer(
|
||||
beta_init, normalized_shape), name="beta")
|
||||
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
|
||||
|
||||
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
|
||||
self.use_batch_norm = use_batch_norm
|
||||
|
||||
def construct(self, input_x):
|
||||
"""fusedlayernorm"""
|
||||
if self.use_batch_norm and self.training:
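# Trick: reshape the input to (1, C, 1, prod(normalized_shape)) so that BatchNorm, which
# normalizes each channel over N*H*W, computes the same per-row statistics as LayerNorm;
# gamma and beta are then applied manually after reshaping back.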
|
||||
ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
|
||||
zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
|
||||
shape_x = F.shape(input_x)
|
||||
norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
|
||||
input_x = F.reshape(input_x, norm_shape)
|
||||
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
|
||||
output = F.reshape(output, shape_x)
|
||||
y = output * self.gamma + self.beta
|
||||
else:
|
||||
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
|
||||
return y
|
||||
|
||||
def extend_repr(self):
|
||||
"""Display instance object as string."""
|
||||
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
|
||||
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
|
||||
return s
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in dataset.py, run_general_distill.py and run_task_distill.py
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from easydict import EasyDict as edict
|
||||
from .tinybert_model import BertConfig
|
||||
|
||||
common_cfg = edict({
|
||||
'loss_scale_value': 2 ** 16,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 1000,
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 5e-5,
|
||||
'end_learning_rate': 1e-14,
|
||||
'power': 1.0,
|
||||
'weight_decay': 1e-4,
|
||||
'eps': 1e-6,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
}),
|
||||
})
|
||||
'''
|
||||
Two kinds of network are configured below:
|
||||
teacher network: The BERT-base network.
|
||||
student network: The smaller network distilled from the teacher network.
|
||||
'''
|
||||
bert_teacher_net_cfg = BertConfig(
|
||||
batch_size=32,
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
enable_fused_layernorm=False
|
||||
)
|
||||
bert_student_net_cfg = BertConfig(
|
||||
batch_size=32,
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=384,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=1536,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
enable_fused_layernorm=False
|
||||
)
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""config script for task distill"""
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from easydict import EasyDict as edict
|
||||
from .tinybert_model import BertConfig
|
||||
|
||||
phase1_cfg = edict({
|
||||
'loss_scale_value': 2 ** 8,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 50,
|
||||
'optimizer_cfg': edict({
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 5e-5,
|
||||
'end_learning_rate': 1e-14,
|
||||
'power': 1.0,
|
||||
'weight_decay': 1e-4,
|
||||
'eps': 1e-6,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
|
||||
phase2_cfg = edict({
|
||||
'loss_scale_value': 2 ** 16,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 50,
|
||||
'optimizer_cfg': edict({
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'end_learning_rate': 1e-14,
|
||||
'power': 1.0,
|
||||
'weight_decay': 1e-4,
|
||||
'eps': 1e-6,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
|
||||
'''
|
||||
Two kinds of network are configured below:
|
||||
teacher network: The BERT-base network fine-tuned on the downstream task.
|
||||
student network: The model produced by the general distillation (GD) phase.
|
||||
'''
|
||||
td_teacher_net_cfg = BertConfig(
|
||||
batch_size=32,
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
enable_fused_layernorm=False
|
||||
)
|
||||
td_student_net_cfg = BertConfig(
|
||||
batch_size=32,
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=384,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=1536,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
enable_fused_layernorm=False
|
||||
)
|
|
@ -0,0 +1,498 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Tinybert model"""
|
||||
|
||||
import re
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from .tinybert_model import BertModel, TinyBertModel, BertModelCLS
|
||||
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
# pylint: disable=consider-using-in
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
if clip_type != 0 and clip_type != 1:
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
|
||||
class ClipGradients(nn.Cell):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Args:
|
||||
grads (list): List of gradient tuples.
|
||||
clip_type (Tensor): The way to clip, 'value' or 'norm'.
|
||||
clip_value (Tensor): Specifies how much to clip.
|
||||
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ClipGradients, self).__init__()
|
||||
self.clip_by_norm = nn.ClipByNorm()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self,
|
||||
grads,
|
||||
clip_type,
|
||||
clip_value):
|
||||
"""clip gradients"""
|
||||
if clip_type != 0 and clip_type != 1:
|
||||
return grads
|
||||
new_grads = ()
|
||||
for grad in grads:
|
||||
dt = self.dtype(grad)
|
||||
if clip_type == 0:
|
||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
new_grads = new_grads + (t,)
|
||||
return new_grads
|
||||
|
||||
class SoftCrossEntropy(nn.Cell):
|
||||
"""SoftCrossEntropy loss"""
|
||||
def __init__(self):
|
||||
super(SoftCrossEntropy, self).__init__()
|
||||
self.log_softmax = P.LogSoftmax(axis=-1)
|
||||
self.softmax = P.Softmax(axis=-1)
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, predicts, targets):
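# Soft cross entropy between student logits (predicts) and teacher logits (targets):
# the mean over all elements of -softmax(targets) * log_softmax(predicts).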
|
||||
likelihood = self.log_softmax(predicts)
|
||||
target_prob = self.softmax(targets)
|
||||
loss = self.reduce_mean(-target_prob * likelihood)
|
||||
|
||||
return self.cast(loss, mstype.float32)
|
||||
|
||||
class BertNetworkWithLoss_gd(nn.Cell):
|
||||
"""
|
||||
Provide the general distillation loss through the network.
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
|
||||
Returns:
|
||||
Tensor, the loss of the network.
|
||||
"""
|
||||
def __init__(self, teacher_config, teacher_ckpt, student_config, is_training, use_one_hot_embeddings=False,
|
||||
is_att_fit=True, is_rep_fit=True):
|
||||
super(BertNetworkWithLoss_gd, self).__init__()
|
||||
# load teacher model
|
||||
self.teacher = BertModel(teacher_config, False, use_one_hot_embeddings)
|
||||
param_dict = load_checkpoint(teacher_ckpt)
|
||||
new_param_dict = {}
|
||||
for key, value in param_dict.items():
|
||||
new_key = re.sub('^bert.bert.', 'teacher.', key)
|
||||
new_param_dict[new_key] = value
|
||||
load_param_into_net(self.teacher, new_param_dict)
|
||||
# no_grad
|
||||
self.teacher.set_train(False)
|
||||
params = self.teacher.trainable_params()
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
# student model
|
||||
self.bert = TinyBertModel(student_config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
self.fit_dense = nn.Dense(student_config.hidden_size,
|
||||
teacher_config.hidden_size).to_float(teacher_config.compute_type)
|
||||
self.teacher_layers_num = teacher_config.num_hidden_layers
|
||||
self.student_layers_num = student_config.num_hidden_layers
|
||||
self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num)
|
||||
self.is_att_fit = is_att_fit
|
||||
self.is_rep_fit = is_rep_fit
|
||||
self.loss_mse = nn.MSELoss()
|
||||
self.select = P.Select()
|
||||
self.zeroslike = P.ZerosLike()
|
||||
self.dtype = teacher_config.dtype
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id):
|
||||
"""general distill network with loss"""
|
||||
# teacher model
|
||||
_, _, _, teacher_seq_output, teacher_att_output = self.teacher(input_ids, token_type_id, input_mask)
|
||||
# student model
|
||||
_, _, _, student_seq_output, student_att_output = self.bert(input_ids, token_type_id, input_mask)
|
||||
total_loss = 0
|
||||
if self.is_att_fit:
|
||||
selected_teacher_att_output = ()
|
||||
selected_student_att_output = ()
|
||||
for i in range(self.student_layers_num):
|
||||
selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],)
|
||||
selected_student_att_output += (student_att_output[i],)
|
||||
att_loss = 0
|
||||
for i in range(self.student_layers_num):
|
||||
student_att = selected_student_att_output[i]
|
||||
teacher_att = selected_teacher_att_output[i]
|
||||
student_att = self.select(student_att <= self.cast(-100.0, mstype.float32), self.zeroslike(student_att),
|
||||
student_att)
|
||||
teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32), self.zeroslike(teacher_att),
|
||||
teacher_att)
|
||||
att_loss += self.loss_mse(student_att, teacher_att)
|
||||
total_loss += att_loss
|
||||
if self.is_rep_fit:
|
||||
selected_teacher_seq_output = ()
|
||||
selected_student_seq_output = ()
|
||||
for i in range(self.student_layers_num + 1):
|
||||
selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],)
|
||||
fit_dense_out = self.fit_dense(student_seq_output[i])
|
||||
fit_dense_out = self.cast(fit_dense_out, self.dtype)
|
||||
selected_student_seq_output += (fit_dense_out,)
|
||||
rep_loss = 0
|
||||
for i in range(self.student_layers_num + 1):
|
||||
teacher_rep = selected_teacher_seq_output[i]
|
||||
student_rep = selected_student_seq_output[i]
|
||||
rep_loss += self.loss_mse(student_rep, teacher_rep)
|
||||
total_loss += rep_loss
|
||||
return self.cast(total_loss, mstype.float32)
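
# Illustrative note (not part of the original implementation): with a 12-layer teacher
# and a 4-layer student, layers_per_block = 3, so student attention layers 0..3 are
# matched to teacher attention layers 2, 5, 8, 11, while student hidden states 0..4
# (including the embedding output) are matched to teacher hidden states 0, 3, 6, 9, 12
# after being projected to the teacher's hidden size by fit_dense.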


class BertTrainWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network; after that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
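
# Illustrative note (not part of the original implementation): the NPU float-status
# buffer `init` records whether any inf/nan appeared while computing the gradients.
# If the (all-reduced) flag sum reaches 1, the step is treated as an overflow: the
# optimizer update is skipped and, when a scale_update_cell such as
# DynamicLossScaleUpdateCell is attached, the loss scale is adjusted downward.
# Otherwise the gradients are rescaled by grad_scale with scaling_sens * degree
# (typically dividing out the loss scale times the device count), clipped, and applied.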


class BertNetworkWithLoss_td(nn.Cell):
    """
    Provide the task distill loss through the teacher and student networks.

    Args:
        teacher_config (BertConfig): The config of the teacher BertModel.
        teacher_ckpt (str): Path of the teacher checkpoint to load.
        student_config (BertConfig): The config of the student BertModel.
        student_ckpt (str): Path of the student checkpoint to load.
        is_training (bool): Specifies whether to use the training mode.
        task_type (str): The task type, e.g. "classification".
        num_labels (int): The number of labels for the task.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
        is_predistill (bool): Specifies whether this is the pre-distill phase. Default: True.
        is_att_fit (bool): Specifies whether to fit the attention distributions. Default: True.
        is_rep_fit (bool): Specifies whether to fit the hidden representations. Default: True.
        temperature (float): The softmax temperature used for the distillation loss. Default: 1.0.
        dropout_prob (float): The dropout probability. Default: 0.1.

    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, teacher_config, teacher_ckpt, student_config, student_ckpt,
                 is_training, task_type, num_labels, use_one_hot_embeddings=False,
                 is_predistill=True, is_att_fit=True, is_rep_fit=True,
                 temperature=1.0, dropout_prob=0.1):
        super(BertNetworkWithLoss_td, self).__init__()
        # load teacher model
        self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
                                    use_one_hot_embeddings, "teacher")
        param_dict = load_checkpoint(teacher_ckpt)
        new_param_dict = {}
        for key, value in param_dict.items():
            new_key = re.sub('^bert.', 'teacher.', key)
            new_param_dict[new_key] = value
        load_param_into_net(self.teacher, new_param_dict)

        # no_grad
        self.teacher.set_train(False)
        params = self.teacher.trainable_params()
        for param in params:
            param.requires_grad = False
        # load student model
        self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
                                 use_one_hot_embeddings, "student")
        param_dict = load_checkpoint(student_ckpt)
        if is_predistill:
            new_param_dict = {}
            for key, value in param_dict.items():
                # new_key = re.sub('tinybert_', 'bert_', key)
                new_key = re.sub('tinybert_', 'bert_', 'bert.' + key)
                new_param_dict[new_key] = value
            load_param_into_net(self.bert, new_param_dict)
        else:
            new_param_dict = {}
            for key, value in param_dict.items():
                new_key = re.sub('tinybert_', 'bert_', key)
                # new_key = re.sub('tinybert_', 'bert_', 'bert.'+ key)
                new_param_dict[new_key] = value
            load_param_into_net(self.bert, new_param_dict)
        self.cast = P.Cast()
        self.fit_dense = nn.Dense(student_config.hidden_size,
                                  teacher_config.hidden_size).to_float(teacher_config.compute_type)
        self.teacher_layers_num = teacher_config.num_hidden_layers
        self.student_layers_num = student_config.num_hidden_layers
        self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num)
        self.is_predistill = is_predistill
        self.is_att_fit = is_att_fit
        self.is_rep_fit = is_rep_fit
        self.task_type = task_type
        self.temperature = temperature
        self.loss_mse = nn.MSELoss()
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.dtype = student_config.dtype
        self.num_labels = num_labels
        self.dtype = teacher_config.dtype
        self.soft_cross_entropy = SoftCrossEntropy()

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids):
        """task distill network with loss"""
        # teacher model
        teacher_seq_output, teacher_att_output, teacher_logits, _ = self.teacher(input_ids, token_type_id, input_mask)
        # student model
        student_seq_output, student_att_output, student_logits, _ = self.bert(input_ids, token_type_id, input_mask)
        total_loss = 0
        if self.is_predistill:
            if self.is_att_fit:
                selected_teacher_att_output = ()
                selected_student_att_output = ()
                for i in range(self.student_layers_num):
                    selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],)
                    selected_student_att_output += (student_att_output[i],)
                att_loss = 0
                for i in range(self.student_layers_num):
                    student_att = selected_student_att_output[i]
                    teacher_att = selected_teacher_att_output[i]
                    student_att = self.select(student_att <= self.cast(-100.0, mstype.float32),
                                              self.zeroslike(student_att),
                                              student_att)
                    teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32),
                                              self.zeroslike(teacher_att),
                                              teacher_att)
                    att_loss += self.loss_mse(student_att, teacher_att)
                total_loss += att_loss
            if self.is_rep_fit:
                selected_teacher_seq_output = ()
                selected_student_seq_output = ()
                for i in range(self.student_layers_num + 1):
                    selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],)
                    fit_dense_out = self.fit_dense(student_seq_output[i])
                    fit_dense_out = self.cast(fit_dense_out, self.dtype)
                    selected_student_seq_output += (fit_dense_out,)
                rep_loss = 0
                for i in range(self.student_layers_num + 1):
                    teacher_rep = selected_teacher_seq_output[i]
                    student_rep = selected_student_seq_output[i]
                    rep_loss += self.loss_mse(student_rep, teacher_rep)
                total_loss += rep_loss
        else:
            if self.task_type == "classification":
                cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature)
            else:
                cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1])
            total_loss += cls_loss
        return self.cast(total_loss, mstype.float32)
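
# Illustrative note (not part of the original implementation): in the task-distill
# phase (is_predistill=False) for classification tasks, both the student and teacher
# logits are divided by the temperature before the soft cross-entropy, so a larger
# temperature produces softer teacher probabilities for the student to match; with
# temperature=1.0 the teacher's softmax output is used as-is.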


class BertEvaluationCell(nn.Cell):
    """
    Specifically defined for fine-tuning where only four input tensors are needed.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertEvaluationCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)

@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""tinybert utils"""

import os
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.ops import operations as P
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from .assessment_method import Accuracy

class ModelSaveCkpt(Callback):
    """
    Save checkpoints during training.

    Args:
        network (Network): The network to save.
        save_ckpt_step (int): Save a checkpoint every `save_ckpt_step` steps.
        max_ckpt_num (int): The maximum number of checkpoints to keep; older ones are removed.
        output_dir (str): The directory to save checkpoints to.
    """
    def __init__(self, network, save_ckpt_step, max_ckpt_num, output_dir):
        super(ModelSaveCkpt, self).__init__()
        self.count = 0
        self.network = network
        self.save_ckpt_step = save_ckpt_step
        self.max_ckpt_num = max_ckpt_num
        self.output_dir = output_dir

    def step_end(self, run_context):
        """step end and save ckpt"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.save_ckpt_step == 0:
            saved_ckpt_num = cb_params.cur_step_num / self.save_ckpt_step
            if saved_ckpt_num > self.max_ckpt_num:
                oldest_ckpt_index = saved_ckpt_num - self.max_ckpt_num
                path = os.path.join(self.output_dir, "tiny_bert_{}_{}.ckpt".format(int(oldest_ckpt_index),
                                                                                   self.save_ckpt_step))
                if os.path.exists(path):
                    os.remove(path)
            _exec_save_checkpoint(self.network, os.path.join(self.output_dir,
                                                             "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
                                                                                           self.save_ckpt_step)))
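
# Illustrative note (not part of the original implementation): with save_ckpt_step=100
# and max_ckpt_num=3, a checkpoint is written every 100 steps; at step 400 the callback
# saves tiny_bert_4_100.ckpt and removes tiny_bert_1_100.ckpt, so at most three
# checkpoints remain in output_dir.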


class LossCallBack(Callback):
    """
    Monitor the loss in training.
    If the loss is NAN or INF, terminate training.

    Note:
        If per_print_times is 0, the loss is not printed.

    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        """step end and print loss"""
        cb_params = run_context.original_args()
        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num,
                                                           cb_params.cur_step_num,
                                                           str(cb_params.net_outputs)))


class EvalCallBack(Callback):
    """Evaluation callback"""
    def __init__(self, network, dataset):
        super(EvalCallBack, self).__init__()
        self.network = network
        self.global_acc = 0.0
        self.dataset = dataset

    def step_end(self, run_context):
        """step end and do evaluation"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % 100 == 0:
            callback = Accuracy()
            columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
            for data in self.dataset.create_dict_iterator():
                input_data = []
                for i in columns_list:
                    input_data.append(Tensor(data[i]))
                input_ids, input_mask, token_type_id, label_ids = input_data
                self.network.set_train(False)
                logits = self.network(input_ids, token_type_id, input_mask)
                callback.update(logits[3], label_ids)
            acc = callback.acc_num / callback.total_num
            with open("./eval.log", "a+") as f:
                f.write("acc_num {}, total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                           callback.acc_num / callback.total_num))
                f.write('\n')

            if acc > self.global_acc:
                self.global_acc = acc
                print("The best acc is {}".format(acc))
                _exec_save_checkpoint(self.network, "eval_model.ckpt")
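
# Illustrative usage sketch (not part of the original implementation; names such as
# `netwithgrads`, `eval_network`, `eval_dataset` and `train_dataset` are placeholders):
# the callbacks above are typically passed to Model.train together, e.g.
#
#     model = Model(netwithgrads)
#     callbacks = [LossCallBack(), ModelSaveCkpt(netwithgrads.network, 100, 3, "./ckpt"),
#                  EvalCallBack(eval_network, eval_dataset)]
#     model.train(epoch_size, train_dataset, callbacks=callbacks)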


class BertLearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for Bert network.
    """
    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        decay_lr = self.decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr
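
# Illustrative note (not part of the original implementation): while global_step is
# below warmup_steps, is_warmup == 1 and the warmup rate is used (WarmUpLR grows the
# learning rate roughly linearly from 0 towards `learning_rate`); afterwards the
# polynomial decay takes over, decaying from `learning_rate` towards
# `end_learning_rate` over `decay_steps` steps with the given `power`.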