forked from mindspore-Ecosystem/mindspore
add ternarybert to model zoo
parent 21addb331d
commit 65c7eb2461

@ -0,0 +1,395 @@

# Contents

- [Contents](#contents)
- [TernaryBERT Description](#ternarybert-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
        - [Train](#train)
        - [Eval](#eval)
    - [Options and Parameters](#options-and-parameters)
        - [Parameters](#parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
            - [evaluation on STS-B dataset](#evaluation-on-sts-b-dataset)
            - [evaluation on QNLI dataset](#evaluation-on-qnli-dataset)
            - [evaluation on MNLI dataset](#evaluation-on-mnli-dataset)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [TernaryBERT Description](#contents)

[TernaryBERT](https://arxiv.org/abs/2009.12812) ternarizes the weights in a fine-tuned [BERT](https://arxiv.org/abs/1810.04805) or [TinyBERT](https://arxiv.org/abs/1909.10351) model and achieves competitive performance on natural language processing tasks. TernaryBERT outperforms other BERT quantization methods, and even achieves performance comparable to the full-precision model while being 14.9x smaller.

[Paper](https://arxiv.org/abs/2009.12812): Wei Zhang, Lu Hou, Yichun Yin, Lifeng Shang, Xiao Chen, Xin Jiang and Qun Liu. [TernaryBERT: Distillation-aware Ultra-low Bit BERT](https://arxiv.org/abs/2009.12812). arXiv preprint arXiv:2009.12812.

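For intuition, here is a minimal NumPy sketch of the kind of weight ternarization TernaryBERT builds on (a TWN-style threshold rule; this is only an illustration, not the repository's `QuantizeWeightCell`, which additionally clips weights and is trained with distillation):

```python
import numpy as np

def ternarize(weight, thresh_ratio=0.7):
    """Approximate a float weight tensor with values in {-alpha, 0, +alpha}."""
    delta = thresh_ratio * np.abs(weight).mean()   # TWN-style threshold
    mask = np.abs(weight) > delta                  # entries kept non-zero
    ternary = np.sign(weight) * mask               # values in {-1, 0, +1}
    # alpha minimizes ||W - alpha * ternary||^2 over the kept entries
    alpha = np.abs(weight[mask]).mean() if mask.any() else 0.0
    return alpha * ternary
```
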
# [Model Architecture](#contents)

The backbone of TernaryBERT is the Transformer. The network contains six encoder modules; each encoder contains a self-attention module, and each self-attention module contains an attention module.

# [Dataset](#contents)

- Download the GLUE dataset for task distillation. To convert the dataset files from JSON format to TFRecord format, please refer to run_classifier.py in the [BERT](https://github.com/google-research/bert) repository.

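The converted records should carry the four columns that `src/dataset.py` later reads. As a quick sanity check, a minimal sketch for inspecting one converted record (the file path is a placeholder; the column names are taken from `src/dataset.py`):

```python
import mindspore.dataset.engine.datasets as de

# Columns consumed by src/dataset.py; the path below is a placeholder.
columns = ["input_ids", "input_mask", "segment_ids", "label_ids"]
ds = de.TFRecordDataset(["/home/xxx/data_dir/sts-b/eval.tf_record"],
                        columns_list=columns, shuffle=False)
for record in ds.create_dict_iterator():
    print({name: record[name].shape for name in columns})
    break
```
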
# [Environment Requirements](#contents)

- Hardware (GPU)
    - Prepare hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
- Software:
    - sklearn

# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```bash
# run training example
# Before running the shell script, please set the `task_name`, `teacher_model_dir`, `student_model_dir` and `data_dir` in the run_train.sh file first.
sh scripts/run_train.sh

# run evaluation example
# Before running the shell script, please set the `task_name`, `model_dir` and `data_dir` in the run_eval.sh file first.
sh scripts/run_eval.sh
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
.
└─bert
  ├─README.md
  ├─scripts
  │ ├─run_train.sh                  # shell script for training phase
  │ └─run_eval.sh                   # shell script for evaluation phase
  ├─src
  │ ├─__init__.py
  │ ├─assessment_method.py          # assessment method for evaluation
  │ ├─cell_wrapper.py               # cell for training
  │ ├─config.py                     # parameter configuration for training and evaluation phase
  │ ├─dataset.py                    # data processing
  │ ├─quant.py                      # function for quantization
  │ ├─tinybert_model.py             # backbone code of network
  │ └─utils.py                      # util function
  ├─__init__.py
  ├─train.py                        # train net for task distillation
  └─eval.py                         # evaluate net after task distillation
```

## [Script Parameters](#contents)

### Train

```text
usage: train.py [-h]
                [--device_target {GPU,Ascend}]
                [--do_eval {true,false}]
                [--epoch_size EPOCH_SIZE]
                [--device_id DEVICE_ID]
                [--do_shuffle {true,false}]
                [--enable_data_sink {true,false}]
                [--save_ckpt_step SAVE_CKPT_STEP]
                [--eval_ckpt_step EVAL_CKPT_STEP]
                [--max_ckpt_num MAX_CKPT_NUM]
                [--data_sink_steps DATA_SINK_STEPS]
                [--teacher_model_dir TEACHER_MODEL_DIR]
                [--student_model_dir STUDENT_MODEL_DIR]
                [--data_dir DATA_DIR]
                [--output_dir OUTPUT_DIR]
                [--task_name {sts-b,qnli,mnli}]
                [--dataset_type DATASET_TYPE]
                [--seed SEED]
                [--train_batch_size TRAIN_BATCH_SIZE]
                [--eval_batch_size EVAL_BATCH_SIZE]

options:
    --device_target       Device where the code will be implemented: "GPU" | "Ascend", default is "GPU"
    --do_eval             Do eval task during training or not: "true" | "false", default is "true"
    --epoch_size          Epoch size for train phase: N, default is 3
    --device_id           Device id: N, default is 0
    --do_shuffle          Enable shuffle for train dataset: "true" | "false", default is "true"
    --enable_data_sink    Enable data sink: "true" | "false", default is "true"
    --save_ckpt_step      If do_eval is false, the checkpoint will be saved every save_ckpt_step: N, default is 50
    --eval_ckpt_step      If do_eval is true, the evaluation will be run every eval_ckpt_step: N, default is 50
    --max_ckpt_num        The number of checkpoints will not be larger than max_ckpt_num: N, default is 50
    --data_sink_steps     Sink steps for each epoch: N, default is 1
    --teacher_model_dir   The checkpoint directory of teacher model: PATH, default is ""
    --student_model_dir   The checkpoint directory of student model: PATH, default is ""
    --data_dir            Data directory: PATH, default is ""
    --output_dir          The output checkpoint directory: PATH, default is "./"
    --task_name           The name of the task to train: "sts-b" | "qnli" | "mnli", default is "sts-b"
    --dataset_type        The format of the dataset: "tfrecord" | "mindrecord", default is "tfrecord"
    --seed                The random seed: N, default is 1
    --train_batch_size    Batch size for training: N, default is 16
    --eval_batch_size     Eval batch size in callback: N, default is 32
```

### Eval

```text
usage: eval.py [-h]
               [--device_target {GPU,Ascend}]
               [--device_id DEVICE_ID]
               [--model_dir MODEL_DIR]
               [--data_dir DATA_DIR]
               [--task_name {sts-b,qnli,mnli}]
               [--dataset_type DATASET_TYPE]
               [--batch_size BATCH_SIZE]

options:
    --device_target    Device where the code will be implemented: "GPU" | "Ascend", default is "GPU"
    --device_id        Device id: N, default is 0
    --model_dir        The checkpoint directory of model: PATH, default is ""
    --data_dir         Data directory: PATH, default is ""
    --task_name        The name of the task to evaluate: "sts-b" | "qnli" | "mnli", default is "sts-b"
    --dataset_type     The format of the dataset: "tfrecord" | "mindrecord", default is "tfrecord"
    --batch_size       Batch size for evaluating: N, default is 32
```

## [Options and Parameters](#contents)

### Parameters

`config.py` contains parameters of glue tasks, train, optimizer, eval, teacher BERT model and student BERT model.

```text
Parameters for glue task:
    num_labels                      the number of labels: N
    seq_length                      length of input sequence: N
    task_type                       the type of task: "classification" | "regression"
    metrics                         the eval metric for task: Accuracy | F1 | Pearsonr | Matthews

Parameters for train:
    batch_size                      batch size of input dataset: N, default is 16
    loss_scale_value                initial value of loss scale: N, default is 2^16
    scale_factor                    factor used to update loss scale: N, default is 2
    scale_window                    steps between two updates of loss scale: N, default is 50

Parameters for optimizer:
    learning_rate                   value of learning rate: Q, default is 5e-5
    end_learning_rate               value of end learning rate: Q, must be positive, default is 1e-14
    power                           power: Q, default is 1.0
    weight_decay                    weight decay: Q, default is 1e-4
    eps                             term added to the denominator to improve numerical stability: Q, default is 1e-6
    warmup_ratio                    the ratio of warmup steps to total steps: Q, default is 0.1

Parameters for eval:
    batch_size                      batch size of input dataset: N, default is 32

Parameters for teacher bert network:
    seq_length                      length of input sequence: N, default is 128
    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
    hidden_size                     size of bert encoder layers: N
    num_hidden_layers               number of hidden layers: N
    num_attention_heads             number of attention heads: N, default is 12
    intermediate_size               size of intermediate layer: N
    hidden_act                      activation function used: ACTIVATION, default is "gelu"
    hidden_dropout_prob             dropout probability for BertOutput: Q
    attention_probs_dropout_prob    dropout probability for BertAttention: Q
    max_position_embeddings         maximum length of sequences: N, default is 512
    save_ckpt_step                  number for saving checkpoint: N, default is 100
    max_ckpt_num                    maximum number for saving checkpoint: N, default is 1
    type_vocab_size                 size of token type vocab: N, default is 2
    initializer_range               initialization value of TruncatedNormal: Q, default is 0.02
    use_relative_positions          use relative positions or not: True | False, default is False
    dtype                           data type of input: mstype.float16 | mstype.float32, default is mstype.float32
    compute_type                    compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float32

Parameters for student bert network:
    seq_length                      length of input sequence: N, default is 128
    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
    hidden_size                     size of bert encoder layers: N
    num_hidden_layers               number of hidden layers: N
    num_attention_heads             number of attention heads: N, default is 12
    intermediate_size               size of intermediate layer: N
    hidden_act                      activation function used: ACTIVATION, default is "gelu"
    hidden_dropout_prob             dropout probability for BertOutput: Q
    attention_probs_dropout_prob    dropout probability for BertAttention: Q
    max_position_embeddings         maximum length of sequences: N, default is 512
    save_ckpt_step                  number for saving checkpoint: N, default is 100
    max_ckpt_num                    maximum number for saving checkpoint: N, default is 1
    type_vocab_size                 size of token type vocab: N, default is 2
    initializer_range               initialization value of TruncatedNormal: Q, default is 0.02
    use_relative_positions          use relative positions or not: True | False, default is False
    dtype                           data type of input: mstype.float16 | mstype.float32, default is mstype.float32
    compute_type                    compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float32
    do_quant                        do activation quantization or not: True | False, default is True
    embedding_bits                  the quant bits of embedding: N, default is 2
    weight_bits                     the quant bits of weight: N, default is 2
    cls_dropout_prob                dropout probability for BertModelCLS: Q
    activation_init                 initialization value of activation quantization: Q, default is 2.5
    is_lgt_fit                      use label ground truth loss or not: True | False, default is False
```

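For intuition, a minimal sketch (an illustration, not the repository's implementation) of how `loss_scale_value`, `scale_factor` and `scale_window` interact in the usual dynamic loss-scaling scheme: the scale shrinks by `scale_factor` on overflow and grows by the same factor after `scale_window` consecutive overflow-free steps:

```python
def update_loss_scale(scale, overflow, good_steps, scale_factor=2, scale_window=50):
    """One step of the dynamic loss-scale update rule (sketch)."""
    if overflow:
        # gradients overflowed: shrink the scale and restart the window
        return max(scale / scale_factor, 1), 0
    good_steps += 1
    if good_steps >= scale_window:
        # a full window of stable steps: grow the scale
        return scale * scale_factor, 0
    return scale, good_steps
```
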
## [Training Process](#contents)

### Training

Before running the command below, please check that `teacher_model_dir`, `student_model_dir` and `data_dir` have been set. Please set them to absolute full paths, e.g. "/home/xxx/model_dir/".

```bash
# python
python train.py --task_name='sts-b' --teacher_model_dir='/home/xxx/model_dir/' --student_model_dir='/home/xxx/model_dir/' --data_dir='/home/xxx/data_dir/'
# shell
sh scripts/run_train.sh
```

The shell command above will run in the background; you can view the results in the file log.txt. The python command will run in the console; you can view the results on the interface. After training, you will get some checkpoint files under the script folder by default. The eval metric value will be logged as follows:

```text
step: 50, Pearsonr 72.50008506516072, best_Pearsonr 72.50008506516072
step 100, Pearsonr 81.3580301181608, best_Pearsonr 81.3580301181608
step 150, Pearsonr 83.60461724688754, best_Pearsonr 83.60461724688754
step 200, Pearsonr 82.23210161651377, best_Pearsonr 83.60461724688754
...
step 1050, Pearsonr 87.5606067964618332, best_Pearsonr 87.58388835685436
```

## [Evaluation Process](#contents)

### Evaluation

If you want to evaluate the model after training, run the commands below.

#### evaluation on STS-B dataset

```bash
# python
python eval.py --task_name='sts-b' --model_dir='/home/xxx/model_dir/' --data_dir='/home/xxx/data_dir/'
# shell
sh scripts/run_eval.sh
```

The shell command above will run in the background; you can view the results in the file log.txt. The python command will run in the console; you can view the results on the interface. The metric value of the test dataset will be as follows:

```text
eval step: 0, Pearsonr: 96.91109003302263
eval step: 1, Pearsonr: 95.6800637493701
eval step: 2, Pearsonr: 94.23823082886167
...
The best Pearsonr: 87.58388835685437
```

#### evaluation on QNLI dataset

```bash
# python
python eval.py --task_name='qnli' --model_dir='/home/xxx/model_dir/' --data_dir='/home/xxx/data_dir/'
# shell
sh scripts/run_eval.sh
```

The shell command above will run in the background; you can view the results in the file log.txt. The python command will run in the console; you can view the results on the interface. The metric value of the test dataset will be as follows:

```text
eval step: 0, Accuracy: 96.875
eval step: 1, Accuracy: 89.0625
eval step: 2, Accuracy: 89.58333333333334
...
The best Accuracy: 90.426505583013
```

#### evaluation on MNLI dataset

```bash
# python
python eval.py --task_name='mnli' --model_dir='/home/xxx/model_dir/' --data_dir='/home/xxx/data_dir/'
# shell
sh scripts/run_eval.sh
```

The shell command above will run in the background; you can view the results in the file log.txt. The python command will run in the console; you can view the results on the interface. The metric value of the test dataset will be as follows:

```text
eval step: 0, Accuracy: 90.625
eval step: 1, Accuracy: 81.25
eval step: 2, Accuracy: 79.16666666666666
...
The best Accuracy: 83.70860927152319
```

## [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters        | GPU                                                    |
| ----------------- | :----------------------------------------------------- |
| Model Version     | TernaryBERT                                            |
| Resource          | NV SMX2 V100-32G                                       |
| Uploaded Date     | 08/20/2020                                             |
| MindSpore Version | 1.1.0                                                  |
| Dataset           | STS-B, QNLI, MNLI                                      |
| batch_size        | 16, 16, 16                                             |
| Metric value      | 87.58388835685437, 90.426505583013, 83.70860927152319  |
| Speed             |                                                        |
| Total time        |                                                        |

### Inference Performance

| Parameters        | GPU                                                    |
| ----------------- | :----------------------------------------------------- |
| Model Version     | TernaryBERT                                            |
| Resource          | NV SMX2 V100-32G                                       |
| Uploaded Date     | 08/20/2020                                             |
| MindSpore Version | 1.1.0                                                  |
| Dataset           | STS-B, QNLI, MNLI                                      |
| batch_size        | 32, 32, 32                                             |
| Accuracy          | 87.58388835685437, 90.426505583013, 83.70860927152319  |
| Speed             |                                                        |
| Total time        |                                                        |

# [Description of Random Situation](#contents)

In train.py, we set do_shuffle to shuffle the dataset.

In config.py, we set hidden_dropout_prob, attention_probs_dropout_prob and cls_dropout_prob to drop out some network nodes.

# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,107 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""eval standalone script"""

import os
import re
import argparse
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_tinybert_dataset
from src.config import eval_cfg, student_net_cfg, task_cfg
from src.tinybert_model import BertModelCLS

DATA_NAME = 'eval.tf_record'


def parse_args():
    """
    parse args
    """
    parser = argparse.ArgumentParser(description='ternarybert evaluation')
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
                        help='Device where the code will be implemented. (Default: GPU)')
    parser.add_argument('--device_id', type=int, default=0, help='Device id. (Default: 0)')
    parser.add_argument('--model_dir', type=str, default='', help='The checkpoint directory of model.')
    parser.add_argument('--data_dir', type=str, default='', help='Data directory.')
    parser.add_argument('--task_name', type=str, default='sts-b', choices=['sts-b', 'qnli', 'mnli'],
                        help='The name of the task to evaluate. (Default: sts-b)')
    parser.add_argument('--dataset_type', type=str, default='tfrecord', choices=['tfrecord', 'mindrecord'],
                        help='The format of the dataset. (Default: tfrecord)')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for evaluating. (Default: 32)')
    return parser.parse_args()


def get_ckpt(ckpt_file):
    """Return the path of the most recently modified checkpoint in a directory."""
    lists = os.listdir(ckpt_file)
    lists.sort(key=lambda fn: os.path.getmtime(ckpt_file + '/' + fn))
    return os.path.join(ckpt_file, lists[-1])


def do_eval_standalone(args_opt):
    """
    do eval standalone
    """
    ckpt_file = os.path.join(args_opt.model_dir, args_opt.task_name)
    ckpt_file = get_ckpt(ckpt_file)
    print('ckpt file:', ckpt_file)
    task = task_cfg[args_opt.task_name]
    student_net_cfg.seq_length = task.seq_length
    eval_cfg.batch_size = args_opt.batch_size
    eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, DATA_NAME)

    # Use args_opt here (not the global `args`), so the function also works when imported.
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)

    eval_dataset = create_tinybert_dataset(batch_size=eval_cfg.batch_size,
                                           device_num=1,
                                           rank=0,
                                           do_shuffle='false',
                                           data_dir=eval_data_dir,
                                           data_type=args_opt.dataset_type,
                                           seq_length=task.seq_length,
                                           task_type=task.task_type,
                                           drop_remainder=False)
    print('eval dataset size:', eval_dataset.get_dataset_size())
    print('eval dataset batch size:', eval_dataset.get_batch_size())

    eval_model = BertModelCLS(student_net_cfg, False, task.num_labels, 0.0, phase_type='student')
    param_dict = load_checkpoint(ckpt_file)
    new_param_dict = {}
    # Rename checkpoint parameters to match the evaluation network's parameter names.
    for key, value in param_dict.items():
        new_key = re.sub('tinybert_', 'bert_', key)
        new_key = re.sub('^bert.', '', new_key)
        new_param_dict[new_key] = value
    load_param_into_net(eval_model, new_param_dict)
    eval_model.set_train(False)

    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    callback = task.metrics()
    for step, data in enumerate(eval_dataset.create_dict_iterator()):
        input_data = []
        for i in columns_list:
            input_data.append(data[i])
        input_ids, input_mask, token_type_id, label_ids = input_data
        _, _, logits, _ = eval_model(input_ids, token_type_id, input_mask)
        callback.update(logits, label_ids)
        print('eval step: {}, {}: {}'.format(step, callback.name, callback.get_metrics()))
    metrics = callback.get_metrics()
    print('The best {}: {}'.format(callback.name, metrics))


if __name__ == '__main__':
    args = parse_args()
    do_eval_standalone(args)

@ -0,0 +1,57 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Bert hub interface for bert base"""

import mindspore.common.dtype as mstype
from src.tinybert_model import BertModel
from src.tinybert_model import BertConfig

tinybert_student_net_cfg = BertConfig(
    seq_length=128,
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    dtype=mstype.float32,
    compute_type=mstype.float32,
    do_quant=True,
    embedding_bits=2,
    weight_bits=2,
    weight_clip_value=3.0,
    cls_dropout_prob=0.1,
    activation_init=2.5,
    is_lgt_fit=False
)


def create_network(name, *args, **kwargs):
    """
    Create tinybert network.
    """
    if name == "ternarybert":
        if "seq_length" in kwargs:
            tinybert_student_net_cfg.seq_length = kwargs["seq_length"]
        is_training = kwargs.get("is_training", False)
        return BertModel(tinybert_student_net_cfg, is_training, *args)
    raise NotImplementedError(f"{name} is not implemented in the repo")
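
# A hedged usage sketch (illustrative, not part of the original file):
#     net = create_network("ternarybert", is_training=False, seq_length=128)
#     # -> a BertModel built from tinybert_student_net_cfg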

@ -0,0 +1,26 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../eval.py \
    --task_name=sts-b \
    --device_id=0 \
    --model_dir="" \
    --data_dir="" > log.txt

@ -0,0 +1,27 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../train.py \
    --task_name=sts-b \
    --device_id=0 \
    --teacher_model_dir="" \
    --student_model_dir="" \
    --data_dir="" > log.txt

@ -0,0 +1,115 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""assessment methods"""

import numpy as np


class Accuracy:
    """Accuracy"""
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0
        self.name = 'Accuracy'

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logit_id = np.argmax(logits, axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)

    def get_metrics(self):
        return self.acc_num / self.total_num * 100.0


class F1:
    """F1"""
    def __init__(self):
        self.logits_array = np.array([])
        self.labels_array = np.array([])
        self.name = 'F1'

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.argmax(logits, axis=1)
        # Cast to the builtin bool (np.bool is a deprecated alias).
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
        self.logits_array = np.concatenate([self.logits_array, logits]).astype(bool)

    def get_metrics(self):
        if len(self.labels_array) < 2:
            return 0.0
        tp = np.sum(self.labels_array & self.logits_array)
        # False positive: predicted positive, labeled negative.
        fp = np.sum((~self.labels_array) & self.logits_array)
        # False negative: predicted negative, labeled positive.
        fn = np.sum(self.labels_array & (~self.logits_array))
        p = tp / (tp + fp)
        r = tp / (tp + fn)
        return 2.0 * p * r / (p + r) * 100.0


class Pearsonr:
    """Pearsonr"""
    def __init__(self):
        self.logits_array = np.array([])
        self.labels_array = np.array([])
        self.name = 'Pearsonr'

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.reshape(logits, -1)
        self.labels_array = np.concatenate([self.labels_array, labels])
        self.logits_array = np.concatenate([self.logits_array, logits])

    def get_metrics(self):
        if len(self.labels_array) < 2:
            return 0.0
        # Pearson correlation: dot product of the two mean-centered, L2-normalized vectors.
        x_mean = self.logits_array.mean()
        y_mean = self.labels_array.mean()
        xm = self.logits_array - x_mean
        ym = self.labels_array - y_mean
        norm_xm = np.linalg.norm(xm)
        norm_ym = np.linalg.norm(ym)
        return np.dot(xm / norm_xm, ym / norm_ym) * 100.0


class Matthews:
    """Matthews"""
    def __init__(self):
        self.logits_array = np.array([])
        self.labels_array = np.array([])
        self.name = 'Matthews'

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.argmax(logits, axis=1)
        # Cast to the builtin bool (np.bool is a deprecated alias).
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
        self.logits_array = np.concatenate([self.logits_array, logits]).astype(bool)

    def get_metrics(self):
        if len(self.labels_array) < 2:
            return 0.0
        tp = np.sum(self.labels_array & self.logits_array)
        fp = np.sum((~self.labels_array) & self.logits_array)
        fn = np.sum(self.labels_array & (~self.logits_array))
        tn = np.sum((~self.labels_array) & (~self.logits_array))
        return (tp * tn - fp * fn) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) * 100.0
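

if __name__ == '__main__':
    # A hedged self-check (not part of the original file): compare F1 against
    # sklearn, which the README already lists as a requirement.
    from sklearn.metrics import f1_score

    class _FakeTensor:
        """Minimal stand-in for a MindSpore tensor in this offline check."""
        def __init__(self, arr):
            self._arr = np.asarray(arr)

        def asnumpy(self):
            return self._arr

    _labels = np.array([1, 0, 0, 1, 0, 1])
    _logits = np.array([[0.2, 0.8], [0.9, 0.1], [0.3, 0.7],
                        [0.1, 0.9], [0.6, 0.4], [0.8, 0.2]])
    _metric = F1()
    _metric.update(_FakeTensor(_logits), _FakeTensor(_labels))
    _preds = np.argmax(_logits, axis=1)
    assert np.isclose(_metric.get_metrics(), f1_score(_labels, _preds) * 100.0)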

@ -0,0 +1,525 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Train Cell."""

import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from .tinybert_model import BertModelCLS
from .quant import QuantizeWeightCell
from .config import gradient_cfg


class ClipByNorm(nn.Cell):
    r"""
    Clips tensor values to a maximum :math:`L_2`-norm.

    The output of this layer remains the same if the :math:`L_2`-norm of the input tensor
    is not greater than the argument clip_norm. Otherwise the tensor will be normalized as:

    .. math::
        \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},

    where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

    Args:
        axis (Union[None, int, tuple(int)]): Compute the L2-norm along the specific dimension.
            Default: None, all dimensions to calculate.

    Inputs:
        - **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`.
          Or a tensor shape can be broadcast to input shape.

    Outputs:
        Tensor, clipped tensor with the same shape as the input, whose type is float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = ClipByNorm()
        >>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
        >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
        >>> output = net(input, clip_norm)
        >>> print(output.shape)
        (4, 16)
    """

    def __init__(self):
        super(ClipByNorm, self).__init__()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.select_ = P.Select()
        self.greater_ = P.Greater()
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.max_op = P.Maximum()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.fill = P.Fill()
        self.expand_dims = P.ExpandDims()
        self.dtype = P.DType()

    def construct(self, x, clip_norm):
        """add ms_function decorator for pynative mode"""
        mul_x = F.square(x)
        if mul_x.shape == (1,):
            l2sum = self.cast(mul_x, mstype.float32)
        else:
            l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
        cond = self.greater_(l2sum, 0)
        ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
        # guard against sqrt(0) by substituting 1.0 where the squared sum is zero
        l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
        l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)

        intermediate = x * clip_norm

        max_norm = self.max_op(l2norm, clip_norm)
        values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
        values_clip = self.reshape(values_clip, self.shape(x))
        values_clip = F.identity(values_clip)
        return values_clip


clip_grad = C.MultitypeFuncGraph("clip_grad")
# pylint: disable=consider-using-in


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    if clip_type != 0 and clip_type != 1:
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
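
# Note (illustrative comment, not original code): in this repository, gradient_cfg in
# src/config.py sets clip_type=1 and clip_value=1.0, so the training cells below clip
# each gradient tensor to an L2-norm of at most 1.0 through ClipByNorm.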


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


class ClipGradients(nn.Cell):
    """
    Clip gradients.

    Inputs:
        grads (list): List of gradient tuples.
        clip_type (Tensor): The way to clip, 'value' or 'norm'.
        clip_value (Tensor): Specifies how much to clip.

    Returns:
        List, a list of clipped_grad tuples.
    """
    def __init__(self):
        super(ClipGradients, self).__init__()
        self.clip_by_norm = nn.ClipByNorm()
        self.cast = P.Cast()
        self.dtype = P.DType()

    def construct(self,
                  grads,
                  clip_type,
                  clip_value):
        """clip gradients"""
        if clip_type != 0 and clip_type != 1:
            return grads
        new_grads = ()
        for grad in grads:
            dt = self.dtype(grad)
            if clip_type == 0:
                t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
                                    self.cast(F.tuple_to_array((clip_value,)), dt))
            else:
                t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
            new_grads = new_grads + (t,)
        return new_grads


class SoftmaxCrossEntropy(nn.Cell):
    """SoftmaxCrossEntropy loss"""
    def __init__(self):
        super(SoftmaxCrossEntropy, self).__init__()
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.softmax = P.Softmax(axis=-1)
        self.reduce_mean = P.ReduceMean()
        self.cast = P.Cast()

    def construct(self, predicts, targets):
        # soft targets: the target logits are converted to a probability distribution
        likelihood = self.log_softmax(predicts)
        target_prob = self.softmax(targets)
        loss = self.reduce_mean(-target_prob * likelihood)

        return self.cast(loss, mstype.float32)
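
# For intuition (illustrative comment, not original code): given teacher logits t and
# student logits s, SoftmaxCrossEntropy computes mean(-softmax(t) * log_softmax(s)),
# the soft cross-entropy that BertNetworkWithLoss below uses to distill the
# teacher's output distribution into the student.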


class BertNetworkWithLoss(nn.Cell):
    """
    Provide the task-distillation loss through the network.

    Args:
        teacher_config (BertConfig): The config of the teacher BertModel.
        teacher_ckpt (str): The checkpoint file of the teacher model.
        student_config (BertConfig): The config of the student BertModel.
        student_ckpt (str): The checkpoint file of the student model.
        is_training (bool): Specifies whether to use the training mode.
        task_type (str): The type of the task: "classification" | "regression".
        num_labels (int): The number of labels.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
        temperature (float): Softmax temperature for logits distillation. Default: 1.0.
        dropout_prob (float): Dropout probability. Default: 0.1.

    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, teacher_config, teacher_ckpt, student_config, student_ckpt,
                 is_training, task_type, num_labels, use_one_hot_embeddings=False,
                 temperature=1.0, dropout_prob=0.1):
        super(BertNetworkWithLoss, self).__init__()
        # load teacher model
        self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
                                    use_one_hot_embeddings, "teacher")
        param_dict = load_checkpoint(teacher_ckpt)
        new_param_dict = {}
        for key, value in param_dict.items():
            new_key = 'teacher.' + key
            new_param_dict[new_key] = value
        load_param_into_net(self.teacher, new_param_dict)

        # freeze the teacher: no gradients flow into it
        self.teacher.set_train(False)
        params = self.teacher.trainable_params()
        for param in params:
            param.requires_grad = False
        # load student model
        self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
                                 use_one_hot_embeddings, "student")
        param_dict = load_checkpoint(student_ckpt)
        new_param_dict = {}
        for key, value in param_dict.items():
            new_key = 'bert.' + key
            new_param_dict[new_key] = value
        load_param_into_net(self.bert, new_param_dict)
        self.cast = P.Cast()
        self.teacher_layers_num = teacher_config.num_hidden_layers
        self.student_layers_num = student_config.num_hidden_layers
        self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num)
        self.is_att_fit = student_config.is_att_fit
        self.is_rep_fit = student_config.is_rep_fit
        self.is_lgt_fit = student_config.is_lgt_fit
        self.task_type = task_type
        self.temperature = temperature
        self.loss_mse = nn.MSELoss()
        self.lgt_fct = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.dtype = student_config.dtype
        self.num_labels = num_labels
        self.soft_cross_entropy = SoftmaxCrossEntropy()
        self.compute_type = student_config.compute_type
        self.embedding_bits = student_config.embedding_bits
        self.weight_bits = student_config.weight_bits
        self.weight_clip_value = student_config.weight_clip_value
        self.reshape = P.Reshape()

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids):
        """task distill network with loss"""
        # teacher model
        teacher_seq_output, teacher_att_output, teacher_logits, _ = self.teacher(input_ids, token_type_id, input_mask)
        # student model
        student_seq_output, student_att_output, student_logits, _ = self.bert(input_ids, token_type_id, input_mask)
        total_loss = 0
        if self.is_att_fit:
            # attention distillation: match student attention maps to evenly spaced teacher layers
            selected_teacher_att_output = ()
            selected_student_att_output = ()
            for i in range(self.student_layers_num):
                selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],)
                selected_student_att_output += (student_att_output[i],)
            att_loss = 0
            for i in range(self.student_layers_num):
                student_att = selected_student_att_output[i]
                teacher_att = selected_teacher_att_output[i]
                # zero out the large negative values used to mask padded positions
                student_att = self.select(student_att <= self.cast(-100.0, mstype.float32),
                                          self.zeroslike(student_att),
                                          student_att)
                teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32),
                                          self.zeroslike(teacher_att),
                                          teacher_att)
                att_loss += self.loss_mse(student_att, teacher_att)
            total_loss += att_loss
        if self.is_rep_fit:
            # hidden-state distillation: match student layer outputs (plus embeddings) to the teacher
            selected_teacher_seq_output = ()
            selected_student_seq_output = ()
            for i in range(self.student_layers_num + 1):
                selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],)
                selected_student_seq_output += (student_seq_output[i],)
            rep_loss = 0
            for i in range(self.student_layers_num + 1):
                student_rep = selected_student_seq_output[i]
                teacher_rep = selected_teacher_seq_output[i]
                rep_loss += self.loss_mse(student_rep, teacher_rep)
            total_loss += rep_loss
        if self.task_type == 'classification':
            # logits distillation with temperature-scaled soft cross-entropy
            cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature)
            if self.is_lgt_fit:
                # optional hard-label loss on the ground truth
                student_logits = self.cast(student_logits, mstype.float32)
                label_ids_reshape = self.reshape(self.cast(label_ids, mstype.int32), (-1,))
                lgt_loss = self.lgt_fct(student_logits, label_ids_reshape)
                total_loss += lgt_loss
        else:
            # regression tasks distill logits with MSE
            student_logits = self.reshape(student_logits, (-1,))
            label_ids = self.reshape(label_ids, (-1,))
            cls_loss = self.loss_mse(student_logits, label_ids)
        total_loss += cls_loss
        return self.cast(total_loss, mstype.float32)


class BertTrainWithLossScaleCell(nn.Cell):
    """
    Training cell with loss scale, specifically defined for fine-tuning where
    only four input tensors are needed.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

        # back-up copies of the full-precision weights, and indices of the parameters to quantize
        self.saved_params = self.weights.clone(prefix='saved')
        self.length = len(self.weights)
        self.quant_embedding_list = []
        self.quant_weight_list = []
        for i, key in enumerate(self.saved_params):
            if 'embedding_lookup' in key.name:
                self.quant_embedding_list.append(i)
            elif 'weight' in key.name and 'dense_1' not in key.name:
                self.quant_weight_list.append(i)
        self.quant_embedding_list_length = len(self.quant_embedding_list)
        self.quant_weight_list_length = len(self.quant_weight_list)

        self.quantize_embedding = QuantizeWeightCell(num_bits=network.embedding_bits,
                                                     compute_type=network.compute_type,
                                                     clip_value=network.weight_clip_value)
        self.quantize_weight = QuantizeWeightCell(num_bits=network.weight_bits,
                                                  compute_type=network.compute_type,
                                                  clip_value=network.weight_clip_value)

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        # 1. save the full-precision weights
        saved = ()
        for i in range(self.length):
            saved = saved + (F.assign(self.saved_params[i], weights[i]),)
        # 2. quantize embeddings and weights in place so the forward/backward pass sees ternary values
        assign_embedding = ()
        for i in range(self.quant_embedding_list_length):
            quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
            assign_embedding = assign_embedding + (F.assign(weights[self.quant_embedding_list[i]], quant_embedding),)
            F.control_depend(saved, assign_embedding[i])
        assign_weight = ()
        for i in range(self.quant_weight_list_length):
            quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
            assign_weight = assign_weight + (F.assign(weights[self.quant_weight_list[i]], quant_weight),)
            F.control_depend(saved, assign_weight[i])
        for i in range(self.quant_embedding_list_length):
            F.control_depend(assign_embedding[i], input_ids)
        for i in range(self.quant_weight_list_length):
            F.control_depend(assign_weight[i], input_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before grad operation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        F.control_depend(input_ids, grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, gradient_cfg.clip_type, gradient_cfg.clip_value), grads)
        # 3. restore the full-precision weights before the optimizer update
        restore = ()
        for i in range(self.length):
            restore = restore + (F.assign(weights[i], self.saved_params[i]),)
            F.control_depend(grads, restore[i])
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        for i in range(self.length):
            F.control_depend(restore[i], succ)
        return succ


class BertTrainCell(nn.Cell):
    """
    Training cell without loss scale, specifically defined for fine-tuning where
    only four input tensors are needed.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(BertTrainCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.sens = sens
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

        # back-up copies of the full-precision weights, and indices of the parameters to quantize
        self.saved_params = self.weights.clone(prefix='saved')
        self.length = len(self.weights)
        self.quant_embedding_list = []
        self.quant_weight_list = []
        for i, key in enumerate(self.saved_params):
            if 'embedding_lookup' in key.name and 'min' not in key.name and 'max' not in key.name:
                self.quant_embedding_list.append(i)
            elif 'weight' in key.name and 'dense_1' not in key.name:
                self.quant_weight_list.append(i)
        self.quant_embedding_list_length = len(self.quant_embedding_list)
        self.quant_weight_list_length = len(self.quant_weight_list)

        self.quantize_embedding = QuantizeWeightCell(num_bits=network.embedding_bits,
                                                     compute_type=network.compute_type,
                                                     clip_value=network.weight_clip_value)
        self.quantize_weight = QuantizeWeightCell(num_bits=network.weight_bits,
                                                  compute_type=network.compute_type,
                                                  clip_value=network.weight_clip_value)

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids):
        """Defines the computation performed."""
        weights = self.weights
        # 1. save the full-precision weights
        saved = ()
        for i in range(self.length):
            saved = saved + (F.assign(self.saved_params[i], weights[i]),)
        # 2. quantize embeddings and weights in place so the forward/backward pass sees ternary values
        assign_embedding = ()
        for i in range(self.quant_embedding_list_length):
            quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
            assign_embedding = assign_embedding + (F.assign(weights[self.quant_embedding_list[i]], quant_embedding),)
            F.control_depend(saved, assign_embedding[i])
        assign_weight = ()
        for i in range(self.quant_weight_list_length):
            quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
            assign_weight = assign_weight + (F.assign(weights[self.quant_weight_list[i]], quant_weight),)
            F.control_depend(saved, assign_weight[i])
        for i in range(self.quant_embedding_list_length):
            F.control_depend(assign_embedding[i], input_ids)
        for i in range(self.quant_weight_list_length):
            F.control_depend(assign_weight[i], input_ids)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        F.control_depend(input_ids, grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(clip_grad, gradient_cfg.clip_type, gradient_cfg.clip_value), grads)
        # 3. restore the full-precision weights; the optimizer then updates them
        restore = ()
        for i in range(self.length):
            restore = restore + (F.assign(weights[i], self.saved_params[i]),)
            F.control_depend(grads, restore[i])
        succ = self.optimizer(grads)
        for i in range(self.length):
            F.control_depend(restore[i], succ)
        return succ

@ -0,0 +1,103 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""config script"""

import mindspore.common.dtype as mstype
from easydict import EasyDict as edict
from .tinybert_model import BertConfig
from .assessment_method import Accuracy, F1, Pearsonr, Matthews


gradient_cfg = edict({
    'clip_type': 1,
    'clip_value': 1.0
})

task_cfg = edict({
    "sst-2": edict({"num_labels": 2, "seq_length": 64, "task_type": "classification", "metrics": Accuracy}),
    "qnli": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": Accuracy}),
    "mnli": edict({"num_labels": 3, "seq_length": 128, "task_type": "classification", "metrics": Accuracy}),
    "cola": edict({"num_labels": 2, "seq_length": 64, "task_type": "classification", "metrics": Matthews}),
    "mrpc": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": F1}),
    "sts-b": edict({"num_labels": 1, "seq_length": 128, "task_type": "regression", "metrics": Pearsonr}),
    "qqp": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": F1}),
    "rte": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": Accuracy})
})

train_cfg = edict({
    'batch_size': 16,
    'loss_scale_value': 2 ** 16,
    'scale_factor': 2,
    'scale_window': 50,
    'optimizer_cfg': edict({
        'AdamWeightDecay': edict({
            'learning_rate': 5e-5,
            'end_learning_rate': 1e-14,
            'power': 1.0,
            'weight_decay': 1e-4,
            'eps': 1e-6,
            'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
            'warmup_ratio': 0.1
        }),
    }),
})

eval_cfg = edict({
    'batch_size': 32,
})

teacher_net_cfg = BertConfig(
    seq_length=128,
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    dtype=mstype.float32,
    compute_type=mstype.float32,
    do_quant=False
)
student_net_cfg = BertConfig(
    seq_length=128,
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    dtype=mstype.float32,
    compute_type=mstype.float32,
    do_quant=True,
    embedding_bits=2,
    weight_bits=2,
    weight_clip_value=3.0,
    cls_dropout_prob=0.1,
    activation_init=2.5,
    is_lgt_fit=False
)
|
|
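For orientation, here is a minimal usage sketch (assumed, mirroring the train script later in this commit) of how these objects are consumed: each `task_cfg` entry carries the label count, sequence length, task type, and the metric *class*, which callers instantiate.

```python
# Usage sketch (assumed): select a GLUE task and propagate its settings.
from src.config import task_cfg, train_cfg, teacher_net_cfg, student_net_cfg

task = task_cfg['sts-b']
teacher_net_cfg.seq_length = task.seq_length      # 128 for sts-b
student_net_cfg.seq_length = task.seq_length
metric = task.metrics()                           # Pearsonr instance for sts-b
print(task.task_type, task.num_labels, train_cfg.batch_size)
```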
@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create tinybert dataset"""

from enum import Enum
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C


class DataType(Enum):
    """Enumerate supported dataset format"""
    TFRECORD = 1
    MINDRECORD = 2


def create_tinybert_dataset(batch_size=32, device_num=1, rank=0, do_shuffle="true", data_dir=None,
                            data_type='tfrecord', seq_length=128, task_type='classification',
                            drop_remainder=True):
    """create tinybert dataset"""
    if isinstance(data_dir, list):
        data_files = data_dir
    else:
        data_files = [data_dir]

    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]

    shuffle = (do_shuffle == "true")

    if data_type == 'mindrecord':
        ds = de.MindDataset(data_files, columns_list=columns_list, shuffle=shuffle, num_shards=device_num,
                            shard_id=rank)
    else:
        ds = de.TFRecordDataset(data_files, columns_list=columns_list, shuffle=shuffle, num_shards=device_num,
                                shard_id=rank, shard_equal_rows=(device_num == 1))

    if device_num == 1 and shuffle is True:
        ds = ds.shuffle(10000)

    # Truncate every column to seq_length and cast to the dtype the network expects;
    # regression tasks (e.g. sts-b) need float labels, classification tasks int labels.
    type_cast_op = C.TypeCast(mstype.int32)
    slice_op = C.Slice(slice(0, seq_length, 1))
    label_type = mstype.int32 if task_type == 'classification' else mstype.float32
    ds = ds.map(operations=[type_cast_op, slice_op], input_columns=["segment_ids"])
    ds = ds.map(operations=[type_cast_op, slice_op], input_columns=["input_mask"])
    ds = ds.map(operations=[type_cast_op, slice_op], input_columns=["input_ids"])
    ds = ds.map(operations=[C.TypeCast(label_type), slice_op], input_columns=["label_ids"])
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=drop_remainder)

    return ds
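A short usage sketch (paths assumed, following the `<data_dir>/<task_name>/eval.tf_record` layout used by the train script):

```python
# Sketch (assumed paths): build a batched eval pipeline and read it as
# dicts of Tensors, the way EvalCallBack in src/utils.py consumes it.
from src.dataset import create_tinybert_dataset

eval_ds = create_tinybert_dataset(batch_size=32, data_dir='data/sts-b/eval.tf_record',
                                  data_type='tfrecord', seq_length=128,
                                  task_type='regression', drop_remainder=False)
for batch in eval_ds.create_dict_iterator():
    input_ids, label_ids = batch["input_ids"], batch["label_ids"]
    break  # one batch is enough for a smoke test
```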
@@ -0,0 +1,171 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization function."""

from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import nn


class QuantizeWeightCell(nn.Cell):
    """
    The ternary fake quant op for weight.

    Args:
        num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 8.
        compute_type (:class:`mindspore.dtype`): Compute type in QuantizeWeightCell. Default: mstype.float32.
        clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

    Inputs:
        - **weight** (Parameter) - Parameter of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Parameter of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
    """

    def __init__(self, num_bits=8, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
        super(QuantizeWeightCell, self).__init__()
        self.num_bits = num_bits
        self.compute_type = compute_type
        self.clip_value = clip_value
        self.per_channel = per_channel

        self.clamp = C.clip_by_value
        self.abs = P.Abs()
        self.sum = P.ReduceSum()
        self.nelement = F.size
        self.div = P.Div()
        self.cast = P.Cast()
        self.max = P.ReduceMax()
        self.min = P.ReduceMin()
        self.round = P.Round()
        self.reshape = P.Reshape()

    def construct(self, weight):
        """quantize weight cell"""
        tensor = self.clamp(weight, -self.clip_value, self.clip_value)
        if self.num_bits == 2:
            # Ternarization: threshold at 0.7 * mean(|w|); weights above the
            # threshold become +alpha, below -threshold become -alpha, where
            # alpha is the mean absolute value of the surviving weights.
            if self.per_channel:
                n = self.nelement(tensor[0])
                m = self.div(self.sum(self.abs(tensor), 1), n)
                thres = 0.7 * m
                pos = self.cast(tensor[:] > thres[0], self.compute_type)
                neg = self.cast(tensor[:] < -thres[0], self.compute_type)
                mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
                alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
                output = alpha * pos - alpha * neg
            else:
                n = self.nelement(tensor)
                m = self.div(self.sum(self.abs(tensor)), n)
                thres = 0.7 * m
                pos = self.cast(tensor > thres, self.compute_type)
                neg = self.cast(tensor < -thres, self.compute_type)
                mask = self.cast(self.abs(tensor) > thres, self.compute_type)
                alpha = self.sum(self.abs(mask * self.cast(tensor, self.compute_type))) / self.sum(mask)
                output = alpha * pos - alpha * neg
        else:
            # Uniform (min-max) quantization for 3 to 8 bits.
            tensor_max = self.cast(self.max(tensor), self.compute_type)
            tensor_min = self.cast(self.min(tensor), self.compute_type)
            s = (tensor_max - tensor_min) / (2 ** self.cast(self.num_bits, self.compute_type) - 1)
            output = self.round(self.div(tensor - tensor_min, s)) * s + tensor_min
        return output


class QuantizeWeight:
    """
    Quantize weight into specified bit.

    Args:
        num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 2.
        compute_type (:class:`mindspore.dtype`): Compute type in QuantizeWeight. Default: mstype.float32.
        clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

    Inputs:
        - **weight** (Parameter) - Parameter of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Parameter of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
    """

    def __init__(self, num_bits=2, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
        self.num_bits = num_bits
        self.compute_type = compute_type
        self.clip_value = clip_value
        self.per_channel = per_channel

        self.clamp = C.clip_by_value
        self.abs = P.Abs()
        self.sum = P.ReduceSum()
        self.nelement = F.size
        self.div = P.Div()
        self.cast = P.Cast()
        self.max = P.ReduceMax()
        self.min = P.ReduceMin()
        self.floor = P.Floor()
        self.reshape = P.Reshape()

    def construct(self, weight):
        """quantize weight"""
        tensor = self.clamp(weight, -self.clip_value, self.clip_value)
        if self.num_bits == 2:
            if self.per_channel:
                n = self.nelement(tensor[0])
                m = self.div(self.sum(self.abs(tensor), 1), n)
                thres = 0.7 * m
                pos = self.cast(tensor[:] > thres[0], self.compute_type)
                neg = self.cast(tensor[:] < -thres[0], self.compute_type)
                mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
                alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
                output = alpha * pos - alpha * neg
            else:
                n = self.nelement(tensor)
                m = self.div(self.sum(self.abs(tensor)), n)
                thres = 0.7 * m
                pos = self.cast(tensor > thres, self.compute_type)
                neg = self.cast(tensor < -thres, self.compute_type)
                mask = self.cast(self.abs(tensor) > thres, self.compute_type)
                alpha = self.sum(self.abs(mask * tensor)) / self.sum(mask)
                output = alpha * pos - alpha * neg
        else:
            tensor_max = self.max(tensor)
            tensor_min = self.min(tensor)
            s = (tensor_max - tensor_min) / (2 ** self.num_bits - 1)
            output = self.floor(self.div((tensor - tensor_min), s) + 0.5) * s + tensor_min
        return output


def convert_network(network, embedding_bits=2, weight_bits=2, clip_value=1.0):
    """Quantize the embedding and weight parameters of the network in place."""
    quantize_embedding = QuantizeWeight(num_bits=embedding_bits, clip_value=clip_value)
    quantize_weight = QuantizeWeight(num_bits=weight_bits, clip_value=clip_value)
    for name, param in network.parameters_and_names():
        if 'bert_embedding_lookup' in name and 'min' not in name and 'max' not in name:
            quantized_param = quantize_embedding.construct(param)
            param.set_data(quantized_param)
        elif 'weight' in name and 'dense_1' not in name:
            quantized_param = quantize_weight.construct(param)
            param.set_data(quantized_param)


def save_params(network):
    """Snapshot all parameters so they can be restored after quantization."""
    return {name: Parameter(param, 'saved_params') for name, param in network.parameters_and_names()}


def restore_params(network, params_dict):
    """Restore parameters saved by save_params."""
    for name, param in network.parameters_and_names():
        param.set_data(params_dict[name])
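The callbacks in `src/utils.py` use these helpers as a save/quantize/restore round trip; a sketch (with `student_network` assumed to be a loaded student BERT cell):

```python
# Sketch: snapshot full-precision weights, ternarize for evaluation or
# checkpointing, then restore so training continues in full precision.
from src.quant import convert_network, save_params, restore_params

fp_params = save_params(student_network)   # student_network: assumed loaded cell
convert_network(student_network, embedding_bits=2, weight_bits=2, clip_value=3.0)
# ... evaluate or save a checkpoint of the ternarized network here ...
restore_params(student_network, fp_params)
```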
File diff suppressed because it is too large
@@ -0,0 +1,187 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ternarybert utils"""

import os
import time
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.train.serialization import save_checkpoint
from mindspore.ops import operations as P
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from .quant import convert_network, save_params, restore_params


class ModelSaveCkpt(Callback):
    """
    Saves a quantized checkpoint every `save_ckpt_step` steps, keeping at most
    `max_ckpt_num` checkpoints on disk.

    Args:
        network (Network): The train network for training.
        save_ckpt_step (int): The step interval between checkpoints.
        max_ckpt_num (int): The maximum number of checkpoints to keep.
        output_dir (str): The directory to save checkpoints to.
        embedding_bits (int): The quantization bits for the embedding. Default: 2.
        weight_bits (int): The quantization bits for the weights. Default: 2.
        clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
    """
    def __init__(self, network, save_ckpt_step, max_ckpt_num, output_dir, embedding_bits=2, weight_bits=2,
                 clip_value=1.0):
        super(ModelSaveCkpt, self).__init__()
        self.count = 0
        self.network = network
        self.save_ckpt_step = save_ckpt_step
        self.max_ckpt_num = max_ckpt_num
        self.output_dir = output_dir
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        self.embedding_bits = embedding_bits
        self.weight_bits = weight_bits
        self.clip_value = clip_value

    def step_end(self, run_context):
        """step end and save ckpt"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.save_ckpt_step == 0:
            saved_ckpt_num = cb_params.cur_step_num // self.save_ckpt_step
            if saved_ckpt_num > self.max_ckpt_num:
                # Delete the oldest checkpoint before writing a new one.
                oldest_ckpt_index = saved_ckpt_num - self.max_ckpt_num
                path = os.path.join(self.output_dir, "ternary_bert_{}_{}.ckpt".format(oldest_ckpt_index,
                                                                                      self.save_ckpt_step))
                if os.path.exists(path):
                    os.remove(path)
            # Save the ternarized weights, then restore full precision for training.
            params_dict = save_params(self.network)
            convert_network(self.network, self.embedding_bits, self.weight_bits, self.clip_value)
            save_checkpoint(self.network, os.path.join(self.output_dir,
                                                       "ternary_bert_{}_{}.ckpt".format(saved_ckpt_num,
                                                                                        self.save_ckpt_step)))
            restore_params(self.network, params_dict)


class LossCallBack(Callback):
    """
    Monitor the loss in training.

    Args:
        per_print_times (int): Print loss every `per_print_times` steps. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        """step end and print loss"""
        cb_params = run_context.original_args()
        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num,
                                                           cb_params.cur_step_num,
                                                           str(cb_params.net_outputs)))


class StepCallBack(Callback):
    """
    Monitor the time cost of each step in training.
    """
    def __init__(self):
        super(StepCallBack, self).__init__()
        self.start_time = 0.0

    def step_begin(self, run_context):
        self.start_time = time.time()

    def step_end(self, run_context):
        time_cost = time.time() - self.start_time
        cb_params = run_context.original_args()
        print("step: {}, second_per_step: {}".format(cb_params.cur_step_num, time_cost))


class EvalCallBack(Callback):
    """Evaluation callback"""
    def __init__(self, network, dataset, eval_ckpt_step, save_ckpt_dir, embedding_bits=2, weight_bits=2,
                 clip_value=1.0, metrics=None):
        super(EvalCallBack, self).__init__()
        self.network = network
        self.global_metrics = 0.0
        self.dataset = dataset
        self.eval_ckpt_step = eval_ckpt_step
        self.save_ckpt_dir = save_ckpt_dir
        self.embedding_bits = embedding_bits
        self.weight_bits = weight_bits
        self.clip_value = clip_value
        self.metrics = metrics
        if not os.path.exists(save_ckpt_dir):
            os.makedirs(save_ckpt_dir)

    def step_end(self, run_context):
        """step end and do evaluation"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.eval_ckpt_step == 0:
            # Evaluate the ternarized network, then restore full-precision weights.
            params_dict = save_params(self.network)
            convert_network(self.network, self.embedding_bits, self.weight_bits, self.clip_value)
            self.network.set_train(False)
            callback = self.metrics()
            columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
            for data in self.dataset:
                input_data = []
                for i in columns_list:
                    input_data.append(data[i])
                input_ids, input_mask, token_type_id, label_ids = input_data
                _, _, logits, _ = self.network(input_ids, token_type_id, input_mask)
                callback.update(logits, label_ids)
            metrics = callback.get_metrics()

            if metrics > self.global_metrics:
                # Keep only the best checkpoint seen so far.
                self.global_metrics = metrics
                eval_model_ckpt_file = os.path.join(self.save_ckpt_dir, 'eval_model.ckpt')
                if os.path.exists(eval_model_ckpt_file):
                    os.remove(eval_model_ckpt_file)
                save_checkpoint(self.network, eval_model_ckpt_file)
            print('step {}, {} {}, best_{} {}'.format(cb_params.cur_step_num,
                                                      callback.name,
                                                      metrics,
                                                      callback.name,
                                                      self.global_metrics))
            restore_params(self.network, params_dict)
            self.network.set_train(True)


class BertLearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for Bert network.
    """
    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        decay_lr = self.decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr
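A sketch (illustrative numbers, not from the repository) of how the schedule's warmup and decay horizons are derived, mirroring the wiring in the train script below: the first `warmup_ratio` fraction of all training steps warms up linearly, and the remainder decays polynomially toward `end_learning_rate`.

```python
# Sketch (illustrative numbers): derive warmup/decay steps from the config.
from src.utils import BertLearningRate

dataset_size = 1000   # assumed number of batches per epoch
epoch_size = 3
warmup_steps = int(dataset_size * epoch_size * 0.1)   # warmup_ratio = 0.1
decay_steps = int(dataset_size * epoch_size)
lr_schedule = BertLearningRate(learning_rate=5e-5, end_learning_rate=1e-14,
                               warmup_steps=warmup_steps, decay_steps=decay_steps,
                               power=1.0)
```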
@@ -0,0 +1,165 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""task distill script"""

import os
import argparse
from mindspore import context
from mindspore.train.model import Model
from mindspore.nn.optim import AdamWeightDecay
from mindspore import set_seed
from src.dataset import create_tinybert_dataset
from src.utils import StepCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
from src.config import train_cfg, eval_cfg, teacher_net_cfg, student_net_cfg, task_cfg
from src.cell_wrapper import BertNetworkWithLoss, BertTrainCell

WEIGHTS_NAME = 'eval_model.ckpt'
EVAL_DATA_NAME = 'eval.tf_record'
TRAIN_DATA_NAME = 'train.tf_record'


def parse_args():
    """
    parse args
    """
    parser = argparse.ArgumentParser(description='ternarybert task distill')
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
                        help='Device where the code will be implemented. (Default: GPU)')
    parser.add_argument('--do_eval', type=str, default='true', choices=['true', 'false'],
                        help='Do eval task during training or not. (Default: true)')
    parser.add_argument('--epoch_size', type=int, default=3, help='Epoch size for train phase. (Default: 3)')
    parser.add_argument('--device_id', type=int, default=0, help='Device id. (Default: 0)')
    parser.add_argument('--do_shuffle', type=str, default='true', choices=['true', 'false'],
                        help='Enable shuffle for train dataset. (Default: true)')
    parser.add_argument('--enable_data_sink', type=str, default='true', choices=['true', 'false'],
                        help='Enable data sink. (Default: true)')
    parser.add_argument('--save_ckpt_step', type=int, default=50,
                        help='If do_eval is false, the checkpoint will be saved every save_ckpt_step. (Default: 50)')
    parser.add_argument('--eval_ckpt_step', type=int, default=50,
                        help='If do_eval is true, the evaluation will be run every eval_ckpt_step. (Default: 50)')
    parser.add_argument('--max_ckpt_num', type=int, default=10,
                        help='The number of checkpoints will not be larger than max_ckpt_num. (Default: 10)')
    parser.add_argument('--data_sink_steps', type=int, default=1, help='Sink steps for each epoch. (Default: 1)')
    parser.add_argument('--teacher_model_dir', type=str, default='', help='The checkpoint directory of teacher model.')
    parser.add_argument('--student_model_dir', type=str, default='', help='The checkpoint directory of student model.')
    parser.add_argument('--data_dir', type=str, default='', help='Data directory.')
    parser.add_argument('--output_dir', type=str, default='./', help='The output checkpoint directory.')
    parser.add_argument('--task_name', type=str, default='sts-b', choices=['sts-b', 'qnli', 'mnli'],
                        help='The name of the task to train. (Default: sts-b)')
    parser.add_argument('--dataset_type', type=str, default='tfrecord', choices=['tfrecord', 'mindrecord'],
                        help='The format of the dataset files. (Default: tfrecord)')
    parser.add_argument('--seed', type=int, default=1, help='The random seed. (Default: 1)')
    parser.add_argument('--train_batch_size', type=int, default=16, help='Batch size for training. (Default: 16)')
    parser.add_argument('--eval_batch_size', type=int, default=32,
                        help='Batch size for evaluation in the callback. (Default: 32)')
    return parser.parse_args()


def run_task_distill(args_opt):
    """
    run task distill
    """
    task = task_cfg[args_opt.task_name]
    teacher_net_cfg.seq_length = task.seq_length
    student_net_cfg.seq_length = task.seq_length
    train_cfg.batch_size = args_opt.train_batch_size
    eval_cfg.batch_size = args_opt.eval_batch_size
    teacher_ckpt = os.path.join(args_opt.teacher_model_dir, args_opt.task_name, WEIGHTS_NAME)
    student_ckpt = os.path.join(args_opt.student_model_dir, args_opt.task_name, WEIGHTS_NAME)
    train_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, TRAIN_DATA_NAME)
    eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, EVAL_DATA_NAME)
    save_ckpt_dir = os.path.join(args_opt.output_dir, args_opt.task_name)

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)

    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset(batch_size=train_cfg.batch_size,
                                            device_num=device_num,
                                            rank=rank,
                                            do_shuffle=args_opt.do_shuffle,
                                            data_dir=train_data_dir,
                                            data_type=args_opt.dataset_type,
                                            seq_length=task.seq_length,
                                            task_type=task.task_type,
                                            drop_remainder=True)
    dataset_size = train_dataset.get_dataset_size()
    print('train dataset size:', dataset_size)
    eval_dataset = create_tinybert_dataset(batch_size=eval_cfg.batch_size,
                                           device_num=device_num,
                                           rank=rank,
                                           do_shuffle=args_opt.do_shuffle,
                                           data_dir=eval_data_dir,
                                           data_type=args_opt.dataset_type,
                                           seq_length=task.seq_length,
                                           task_type=task.task_type,
                                           drop_remainder=False)
    print('eval dataset size:', eval_dataset.get_dataset_size())

    if args_opt.enable_data_sink == 'true':
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size

    netwithloss = BertNetworkWithLoss(teacher_config=teacher_net_cfg, teacher_ckpt=teacher_ckpt,
                                      student_config=student_net_cfg, student_ckpt=student_ckpt,
                                      is_training=True, task_type=task.task_type, num_labels=task.num_labels)
    params = netwithloss.trainable_params()
    optimizer_cfg = train_cfg.optimizer_cfg
    lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size * args_opt.epoch_size *
                                                    optimizer_cfg.AdamWeightDecay.warmup_ratio),
                                   decay_steps=int(dataset_size * args_opt.epoch_size),
                                   power=optimizer_cfg.AdamWeightDecay.power)
    # Apply weight decay only to parameters selected by decay_filter
    # (everything except LayerNorm and bias parameters).
    decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)

    netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)

    if args_opt.do_eval == 'true':
        eval_dataset = list(eval_dataset.create_dict_iterator())
        callback = [EvalCallBack(network=netwithloss.bert,
                                 dataset=eval_dataset,
                                 eval_ckpt_step=args_opt.eval_ckpt_step,
                                 save_ckpt_dir=save_ckpt_dir,
                                 embedding_bits=student_net_cfg.embedding_bits,
                                 weight_bits=student_net_cfg.weight_bits,
                                 clip_value=student_net_cfg.weight_clip_value,
                                 metrics=task.metrics)]
    else:
        callback = [StepCallBack(),
                    ModelSaveCkpt(network=netwithloss.bert,
                                  save_ckpt_step=args_opt.save_ckpt_step,
                                  max_ckpt_num=args_opt.max_ckpt_num,
                                  output_dir=save_ckpt_dir,
                                  embedding_bits=student_net_cfg.embedding_bits,
                                  weight_bits=student_net_cfg.weight_bits,
                                  clip_value=student_net_cfg.weight_clip_value)]
    model = Model(netwithgrads)
    model.train(repeat_count, train_dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)


if __name__ == '__main__':
    args = parse_args()
    set_seed(args.seed)
    run_task_distill(args)
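Assuming the script above is saved as `train.py` (the filename is not visible in this diff), a typical invocation looks like:

```bash
python train.py --device_target GPU --task_name sts-b \
    --teacher_model_dir /path/to/teacher_model \
    --student_model_dir /path/to/student_model \
    --data_dir /path/to/glue_data \
    --output_dir ./output \
    --do_eval true
```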