!18003 finish q8bert task
Merge pull request !18003 from chenzhuo/q8bert
This commit is contained in:
commit
6103989fcc
|
@ -369,22 +369,21 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_param
|
|||
if quant_new_params is not None and not isinstance(quant_new_params, list):
|
||||
raise TypeError("quant_new_params must be list or None.")
|
||||
iterable_dict = {
|
||||
'weight': iter(list(filter(lambda item: item[0].endswith('weight'), params_dict.items()))),
|
||||
'bias': iter(list(filter(lambda item: item[0].endswith('bias'), params_dict.items()))),
|
||||
'gamma': iter(list(filter(lambda item: item[0].endswith('gamma'), params_dict.items()))),
|
||||
'beta': iter(list(filter(lambda item: item[0].endswith('beta'), params_dict.items()))),
|
||||
'moving_mean': iter(list(filter(lambda item: item[0].endswith('moving_mean'), params_dict.items()))),
|
||||
'moving_variance': iter(list(filter(lambda item: item[0].endswith('moving_variance'), params_dict.items()))),
|
||||
'minq': iter(list(filter(lambda item: item[0].endswith('minq'), params_dict.items()))),
|
||||
'maxq': iter(list(filter(lambda item: item[0].endswith('maxq'), params_dict.items()))),
|
||||
'quant_max': iter(list(filter(lambda item: item[0].endswith('quant_max'), params_dict.items())))
|
||||
}
|
||||
for param in params_dict.items():
|
||||
key_name = param[0].split(".")[-1]
|
||||
if key_name not in iterable_dict:
|
||||
iterable_dict[key_name] = iter(list(filter(lambda item, value=key_name: item[0].endswith(value),
|
||||
params_dict.items())))
|
||||
|
||||
for name, param in quant_model.parameters_and_names():
|
||||
key_name = name.split(".")[-1]
|
||||
if key_name not in iterable_dict.keys():
|
||||
if key_name not in quant_new_params:
|
||||
raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
|
||||
raise ValueError(f"Can't find match parameter in ckpt, param name = {name}")
|
||||
continue
|
||||
value_param = next(iterable_dict[key_name], None)
|
||||
if value_param:
|
||||
|
|
|
@ -635,8 +635,8 @@ class Conv2dBnFoldQuantOneConv(Cell):
|
|||
>>> input_data = Tensor(np.array([[[[1, 0, 3], [1, 4, 7], [2, 5, 2]]]]), mindspore.float32)
|
||||
>>> result = conv2d_bnfold(input_data)
|
||||
>>> print(result)
|
||||
[[[[5.9296875, 13.8359375]
|
||||
[11.859375, 17.78125]]]]
|
||||
[[[[5.9296875 13.8359375]
|
||||
[11.859375 17.78125]]]]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -875,8 +875,8 @@ class Conv2dBnFoldQuant(Cell):
|
|||
>>> input_data = Tensor(np.array([[[[1, 0, 3], [1, 4, 7], [2, 5, 2]]]]), mindspore.float32)
|
||||
>>> result = conv2d_bnfold(input_data)
|
||||
>>> print(result)
|
||||
[[[[5.9296875, 13.8359375]
|
||||
[11.859375, 17.78125]]]]
|
||||
[[[[5.9296875 13.8359375]
|
||||
[11.859375 17.78125]]]]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
@ -1,50 +1,53 @@
|
|||
|
||||
# Contents
|
||||
|
||||
- [Contents](#contents)
|
||||
- [Q8BERT Description](#q8bert-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Parameters](#parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [training Performance](#training-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
- [Contents](#Contents)
|
||||
- [Q8BERT Description](#Q8BERT-Description)
|
||||
- [Model Architecture](#Model-architecture)
|
||||
- [Dataset](#Dataset)
|
||||
- [Environment Requirements](#Environment-requirements)
|
||||
- [Quick Start](#Quick-start)
|
||||
- [Script Description](#Script-description)
|
||||
- [Script and Sample Code](#Script-and-sample-code)
|
||||
- [Parameters](#Parameters)
|
||||
- [Training Process](#Training-Process)
|
||||
- [Running on Ascend and GPU platform](#Running-on-Ascend-and-GPU-platform)
|
||||
- [Training based STS-B dataset](#Training-based-STS-B-dataset)
|
||||
- [Evaling Process](#Evaling-process)
|
||||
- [Evaling based STS-B dataset](#Evaling-based-STS-B-dataset)
|
||||
- [Export model](#Export-model)
|
||||
- [Performance](#Performance)
|
||||
- [Performance Evaluation](#Performance-Evaluation)
|
||||
- [Description of Random Situation](#Description-of-random-situation)
|
||||
- [ModelZoo Homepage](#Modelzoo-homepage)
|
||||
|
||||
# [Q8BERT Description](#contents)
|
||||
|
||||
[Q8BERT](https://arxiv.org/abs/1910.06188) is a quantization-aware training during the fine-tuning phase of [BERT](https://arxiv.org/abs/1810.04805)
|
||||
in order to compress BERT by 4× with minimal accuracy loss. Furthermore, the
|
||||
[Q8BERT](https://arxiv.org/abs/1910.06188) is the model where the quantization-aware training is applied into [BERT](https://arxiv.org/abs/1810.04805)
|
||||
in order to compress BERT size under minimal accuracy loss. Furthermore, the
|
||||
produced quantized model can accelerate inference speed if it is optimized for 8bit Integer supporting hardware.
|
||||
|
||||
[Paper](https://arxiv.org/abs/1910.06188): Ofir Zafrir, Guy Boudoukh, Peter Izsak and Moshe Wasserblat. [Q8BERT: Quantized 8Bit BERT](https://arxiv.org/abs/1910.06188). arXiv preprint arXiv:2009.12812.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
The backbone structure of Q8BERT is transformer, the transformer contains 12 encoder modules, one encoder contains one self-attention module and one self-attention module contains one attention module.
|
||||
The backbone structure of Q8BERT is transformer, the transformer contains 12 encoder modules, one encoder contains one self-attention module and one self-attention module contains one attention module.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
- Download glue dataset for task distillation. Convert dataset files from json format to tfrecord format, please refer to run_classifier.py which in [BERT](https://github.com/google-research/bert) repository.
|
||||
- Download glue dataset for fine-tuning. Convert dataset files from json format to tfrecord format, please refer to run_classifier.py which in [BERT](https://github.com/google-research/bert) repository.
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(GPU)
|
||||
- Prepare hardware environment with GPU processor.
|
||||
- Hardware
|
||||
- Prepare hardware environment with Ascend or GPU processor.
|
||||
- Framework
|
||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
- Software:
|
||||
- numpy
|
||||
- numpy, sklearn
|
||||
|
||||
# [Quick Start](#contents)
|
||||
|
||||
|
@ -53,10 +56,9 @@ After installing MindSpore via the official website, you can start training and
|
|||
```bash
|
||||
|
||||
# run training example
|
||||
run_train.sh
|
||||
|
||||
Before running the shell script, please set the `task_name`, `teacher_model_dir`, `student_model_dir` and `data_dir` in the run_train.sh file first.
|
||||
|
||||
bash run_standalone_train.sh [TASK_NAME] [DEVICE_TARGET] [TRAIN_DATA_DIR] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
# run evaling example
|
||||
bash run_eval.sh [TASK_NAME] [DEVICE_TARGET] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
@ -65,166 +67,147 @@ Before running the shell script, please set the `task_name`, `teacher_model_dir`
|
|||
|
||||
```text
|
||||
└─q8bert
|
||||
├─README.md
|
||||
├─README.md # document in English
|
||||
├─README_CN.md # document in Chinese
|
||||
├─scripts
|
||||
├─run_train.sh # shell script for training phase
|
||||
├─run_standalone_train.sh # shell script for training phase
|
||||
├─run_eval.sh # shell script for evaling phase
|
||||
├─src
|
||||
├─__init__.py
|
||||
├─dataset.py # data processing
|
||||
├─bert_model.py # backbone code of bert
|
||||
├─q8bert_model.py # quantization for Bert
|
||||
├─q8bert.py # backbone code of q8bert
|
||||
├─utils.py # some utils function of q8bert
|
||||
├─q8bert.py # backbone code of Q8BERT
|
||||
├─utils.py # some utils function of Q8BERT
|
||||
├─__init__.py
|
||||
├─run_train.py # train net for task distillation
|
||||
├─train.py # train net
|
||||
├─eval.py # eval net
|
||||
├─export.py # export model
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
## [Script Parameters]
|
||||
|
||||
### Train
|
||||
|
||||
```text
|
||||
|
||||
usage: run_train.py [--h] [--device_target {GPU,Ascend}][--epoch_num EPOCH_NUM] [--task_name {SST-2,QNLI,MNLI,COLA,QQP,"STS-B,RTE}][--do_shuffle {true,false}] [--enable_data_sink {true,false}][--do_eval {true,false}][--device_id DEVICE_ID] [--save_ckpt_step SAVE_CKPT_STEP] [--eval_ckpt_step EVAL_CKPT_STEP] [--max_ckpt_num MAX_CKPT_NUM] [--load_ckpt_path LOAD_CKPT_PATH] [--train_data_dir TRAIN_DATA_DIR] [--eval_data_dir EVAL_DATA_DIR] [--device_id DEVICE_ID] [--logging_step LOGGIND_STEP] [--do_quant {true,false}]
|
||||
usage:
|
||||
bash run_standalone_train.sh [TASK_NAME] [DEVICE_TARGET] [TRAIN_DATA_DIR] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
|
||||
options:
|
||||
[TASK_NAME] The name of the task to train: "STS-B"| "QNLI"| SST-2"
|
||||
[DEVICE_TARGET] Device where the code will be implemented: "GPU" | "Ascend"
|
||||
[TRAIN_DATA_DIR] Train Data directory
|
||||
[EVAL_DATA_DIR] Eval Data directory
|
||||
[LOAD_CKPT_PATH] The checkpoint directory of model
|
||||
or
|
||||
|
||||
python train.py [--h] [--device_target {GPU,Ascend}] [--epoch_num EPOCH_NUM] [--task_name {SST-2, QNLI, STS-B}]
|
||||
[--do_shuffle {True,False}] [--enable_data_sink {True,False}] [--do_eval {True,False}]
|
||||
[--device_id DEVICE_ID] [--save_ckpt_step SAVE_CKPT_STEP] [--eval_ckpt_step EVAL_CKPT_STEP]
|
||||
[--max_ckpt_num MAX_CKPT_NUM] [--load_ckpt_path LOAD_CKPT_PATH] [--train_data_dir TRAIN_DATA_DIR]
|
||||
[--eval_data_dir EVAL_DATA_DIR] [--device_id DEVICE_ID] [--logging_step LOGGIND_STEP]
|
||||
[--do_quant {True,False}]
|
||||
|
||||
options:
|
||||
--device_target Device where the code will be implemented: "GPU" | "Ascend", default is "GPU"
|
||||
--do_eval Do eval task during training or not: "true" | "false", default is "true"
|
||||
--do_eval Do eval task during training or not: "True" | "False", default is "True"
|
||||
--epoch_num Epoch num for train phase: N, default is 3
|
||||
--device_id Device id: N, default is 0
|
||||
--do_shuffle Enable shuffle for train dataset: "true" | "false", default is "true"
|
||||
--enable_data_sink Enable data sink: "true" | "false", default is "true"
|
||||
--save_ckpt_step If do_eval is false, the checkpoint will be saved every save_ckpt_step: N, default is 50
|
||||
--eval_ckpt_step If do_eval is true, the evaluation will be ran every eval_ckpt_step: N, default is 50
|
||||
--do_shuffle Enable shuffle for train dataset: "True" | "False", default is "True"
|
||||
--enable_data_sink Enable data sink: "True" | "False", default is "True"
|
||||
--save_ckpt_step If do_eval is False, the checkpoint will be saved every save_ckpt_step: N, default is 50
|
||||
--eval_ckpt_step If do_eval is True, the evaluation will be ran every eval_ckpt_step: N, default is 50
|
||||
--max_ckpt_num The number of checkpoints will not be larger than max_ckpt_num: N, default is 50
|
||||
--data_sink_steps Sink steps for each epoch: N, default is 1
|
||||
--load_ckpt_path The checkpoint directory of model: PATH, default is ""
|
||||
--train_data_dir Train Data directory: PATH, default is ""
|
||||
--eval_data_dir Eval Data directory: PATH, default is ""
|
||||
--task_name The name of the task to train: "SST-2"| "QNLI"| "MNLI"|"COLA"|"QQP"|"STS-B"|"RTE"
|
||||
--task_name The name of the task to train: "STS-B"| "QNLI"| SST-2"
|
||||
--dataset_type The name of the task to train: "tfrecord" | "mindrecord", default is "tfrecord"
|
||||
--train_batch_size Batch size for training: N, default is 16
|
||||
--eval_batch_size Eval Batch size in callback: N, default is 32
|
||||
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
`config.py`contains parameters of glue tasks, train, optimizer, eval, teacher BERT model and student BERT model.
|
||||
|
||||
```text
|
||||
|
||||
Parameters for glue task:
|
||||
num_labels the numbers of labels: N.
|
||||
seq_length length of input sequence: N
|
||||
task_type the type of task: "classification" | "regression"
|
||||
metrics the eval metric for task: Accuracy | F1 | Pearsonr | Matthews
|
||||
|
||||
Parameters for train:
|
||||
batch_size batch size of input dataset: N, default is 16
|
||||
loss_scale_value initial value of loss scale: N, default is 2^16
|
||||
scale_factor factor used to update loss scale: N, default is 2
|
||||
scale_window steps for once updatation of loss scale: N, default is 50
|
||||
|
||||
Parameters for optimizer:
|
||||
learning_rate value of learning rate: Q, default is 5e-5
|
||||
end_learning_rate value of end learning rate: Q, must be positive, default is 1e-14
|
||||
power power: Q, default is 1.0
|
||||
weight_decay weight decay: Q, default is 1e-4
|
||||
eps term added to the denominator to improve numerical stability: Q, default is 1e-6
|
||||
warmup_ratio the ratio of warmup steps to total steps: Q, default is 0.1
|
||||
|
||||
Parameters for eval:
|
||||
batch_size batch size of input dataset: N, default is 32
|
||||
|
||||
Parameters for teacher bert network:
|
||||
seq_length length of input sequence: N, default is 128
|
||||
vocab_size size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
|
||||
hidden_size size of bert encoder layers: N
|
||||
num_hidden_layers number of hidden layers: N
|
||||
num_attention_heads number of attention heads: N, default is 12
|
||||
intermediate_size size of intermediate layer: N
|
||||
hidden_act activation function used: ACTIVATION, default is "gelu"
|
||||
hidden_dropout_prob dropout probability for BertOutput: Q
|
||||
attention_probs_dropout_prob dropout probability for BertAttention: Q
|
||||
max_position_embeddings maximum length of sequences: N, default is 512
|
||||
save_ckpt_step number for saving checkponit: N, default is 100
|
||||
max_ckpt_num maximum number for saving checkpoint: N, default is 1
|
||||
type_vocab_size size of token type vocab: N, default is 2
|
||||
initializer_range initialization value of TruncatedNormal: Q, default is 0.02
|
||||
use_relative_positions use relative positions or not: True | False, default is False
|
||||
dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32
|
||||
compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float32
|
||||
|
||||
Parameters for student bert network:
|
||||
seq_length length of input sequence: N, default is 128
|
||||
vocab_size size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
|
||||
hidden_size size of bert encoder layers: N
|
||||
num_hidden_layers number of hidden layers: N
|
||||
num_attention_heads number of attention heads: N, default is 12
|
||||
intermediate_size size of intermediate layer: N
|
||||
hidden_act activation function used: ACTIVATION, default is "gelu"
|
||||
hidden_dropout_prob dropout probability for BertOutput: Q
|
||||
attention_probs_dropout_prob dropout probability for BertAttention: Q
|
||||
max_position_embeddings maximum length of sequences: N, default is 512
|
||||
save_ckpt_step number for saving checkponit: N, default is 100
|
||||
max_ckpt_num maximum number for saving checkpoint: N, default is 1
|
||||
type_vocab_size size of token type vocab: N, default is 2
|
||||
initializer_range initialization value of TruncatedNormal: Q, default is 0.02
|
||||
use_relative_positions use relative positions or not: True | False, default is False
|
||||
dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32
|
||||
compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float32
|
||||
do_quant do activation quantilization or not: True | False, default is True
|
||||
embedding_bits the quant bits of embedding: N, default is 2
|
||||
weight_bits the quant bits of weight: N, default is 2
|
||||
cls_dropout_prob dropout probability for BertModelCLS: Q
|
||||
activation_init initialization value of activation quantilization: Q, default is 2.5
|
||||
is_lgt_fit use label ground truth loss or not: True | False, default is False
|
||||
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### Training
|
||||
### Running-on-Ascend-and-GPU-platform
|
||||
|
||||
Before running the command below, please check `data_dir` and 'load_ckpt_path' has been set. Please set the path to be the absolute full path, e.g:"/home/xxx/model_dir/".
|
||||
Before running the command below, please check that all required parameters have been set. The parameter of path would better be the absolute path. The options of parameter DEVICE_TARGET contains Ascend and GPU, which means the model will run in Ascend and GPU platform respectively.
|
||||
|
||||
### Training based STS-B dataset
|
||||
|
||||
This model currently supports STS-B, QNLI and SST-2 datasets, following example is based on STS-B dataset.
|
||||
|
||||
```text
|
||||
|
||||
python
|
||||
python ./run_train.py --device_target="GPU" --do_eval="true" --epoch_num=3 --task_name="STS-B" --do_shuffle="true" --enable_data_sink="true" --data_sink_steps=100 --save_ckpt_step=100 --max_ckpt_num=1 --load_ckpt_path="sts-b.ckpt" --train_data_dir="sts-b/train.tf_record" --eval_data_dir="sts-b/eval.tf_record" --device_id=0 --logging_step=100 --do_quant="true"
|
||||
shell
|
||||
sh run_train.sh
|
||||
bash run_standalone_train.sh [TASK_NAME] [DEVICE_TARGET] [TRAIN_DATA_DIR] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
example:
|
||||
bash run_standalone_train.sh STS-B Ascend /path/sts-b/train.tf_record /path/sts-b/eval.tf_record /path/xxx.ckpt
|
||||
|
||||
```
|
||||
|
||||
The shell command above will run in the background, you can view the results the file log.txt. The python command will run in the console, you can view the results on the interface. After training, you will get some checkpoint files under the script folder by default. The eval metric value will be achieved as follows:
|
||||
The shell command above will run in the background, you can view the results the file train_log.txt. The python command will run in the console, you can view the results on the interface.
|
||||
|
||||
```text
|
||||
|
||||
epoch: 1, step: 100, loss are (Tensor(shape=[], dtype=Float32, value= 0.526506), Tensor(shape=[], dtype=Bool, value= False)) The current result is {'pearson': 0.8407084843799768, 'spearmanr': 0.8405771469597393, 'corr': 0.840642815669858} epoch time: 66421.602 ms, per step time: 664.216 ms
|
||||
epoch: 2, step: 200, loss are (Tensor(shape=[], dtype=Float32, value= 0.406012), Tensor(shape=[], dtype=Bool, value= False)) The current result is {'pearson': 0.826509808575773, 'spearmanr': 0.8274141859302444, 'corr': 0.8269619972530087} epoch time: 47488.633 ms, per step time: 474.886 ms
|
||||
epoch: 1, step: 100, loss: 0.526506
|
||||
The current result is {'pearson': 0.8407084843799768, 'spearmanr': 0.8405771469597393, 'corr': 0.840642815669858}, the best result is 0.8407084843799768
|
||||
epoch time: 147446.514 ms, per step time: 1474.465 ms
|
||||
epoch: 2, step: 200, loss: 0.406012
|
||||
The current result is {'pearson': 0.826509808575773, 'spearmanr': 0.8274141859302444, 'corr': 0.8269619972530087}, the best result is 0.8407084843799768
|
||||
epoch time: 93688.080 ms, per step time: 936.881 ms
|
||||
...
|
||||
best pearson:0.8753269455187238
|
||||
|
||||
After training, checkpoint files will be saved at relative folder under the root folder of the project.
|
||||
|
||||
```
|
||||
|
||||
## [Model Description](#contents)
|
||||
## [Evaling Process](#contents)
|
||||
|
||||
### Evaling based STS-B dataset
|
||||
|
||||
```text
|
||||
|
||||
shell
|
||||
bash run_eval.sh [TASK_NAME] [DEVICE_TARGET] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
example:
|
||||
bash run_eval.sh STS-B Ascend /path/sts-b/eval.tf_record /path/xxx.ckpt
|
||||
|
||||
```
|
||||
|
||||
The shell command above will run in the background, you can view the results the file eval_log.txt. The python command will run in the console, you can view the results on the interface.
|
||||
|
||||
```text
|
||||
|
||||
The current result is {'pearson': 0.826509808575773, 'spearmanr': 0.8274141859302444, 'corr': 0.8269619972530087}, the best result is 0.8407084843799768
|
||||
|
||||
```
|
||||
|
||||
## [Export model](#contents)
|
||||
|
||||
```text
|
||||
|
||||
python export.py --task_name [TASK_NAME] --ckpt_file [CKPT_FILE] --file_format [FILE_FORMAT]
|
||||
|
||||
```
|
||||
|
||||
The file_format parameter should be inside ["AIR", "MINDIR"]
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### training Performance
|
||||
### Performance Evaluation
|
||||
|
||||
| Parameters | GPU |
|
||||
| ----------------- | :---------------------------------------------------- |
|
||||
| Model Version | Q8BERT |
|
||||
| Resource | NV GeForce GTX1080ti |
|
||||
| uploaded Date | 03/01/2020 |
|
||||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | STS-B |
|
||||
| batch_size | 16 |
|
||||
| Metric value | 87.5833 |
|
||||
| Speed | 0.47s/step |
|
||||
| Total time | 9.1min(3epoch, 1p) |
|
||||
| Parameters | Ascend | GPU |
|
||||
| -------------------------- | ---------------------------------------------------------- | ------------------------- |
|
||||
| Model Version | Q8BERT | Q8BERT |
|
||||
| Resource | Ascend 910, cpu 2.60GHz, cores 172, mem 755G, os Euler2.8 | NV GeForce GTX1080ti, cpu 2.00GHz, cores 56, mem 251G, os Ubuntu 16.04 |
|
||||
| Date | 2021-6-8 | 2021-6-8 |
|
||||
| MindSpore Version | 1.2.0 | 1.2.0 |
|
||||
| Dataset | STS-B | STS-B |
|
||||
| Total Time | 11mins (3epoch, 1p) | 18mins (3epoch, 1p) |
|
||||
| Metric value | 89.14 | 89.18 |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [TinyBERT概述](#tinybert概述)
|
||||
- [Q8BERT概述](#Q8BERT概述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
|
@ -12,27 +12,17 @@
|
|||
- [脚本说明](#脚本说明)
|
||||
- [脚本和样例代码](#脚本和样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [一般蒸馏](#一般蒸馏)
|
||||
- [任务蒸馏](#任务蒸馏)
|
||||
- [选项及参数](#选项及参数)
|
||||
- [选项](#选项)
|
||||
- [参数](#参数)
|
||||
- [训练流程](#训练流程)
|
||||
- [用法](#用法)
|
||||
- [Ascend处理器上运行](#ascend处理器上运行)
|
||||
- [在GPU处理器上运行](#在gpu处理器上运行)
|
||||
- [分布式训练](#分布式训练)
|
||||
- [Ascend处理器上运行](#ascend处理器上运行-1)
|
||||
- [GPU处理器上运行](#gpu处理器上运行)
|
||||
- [评估过程](#评估过程)
|
||||
- [用法](#用法-1)
|
||||
- [基于SST-2数据集进行评估](#基于sst-2数据集进行评估)
|
||||
- [基于MNLI数据集进行评估](#基于mnli数据集进行评估)
|
||||
- [基于QNLI数据集进行评估](#基于qnli数据集进行评估)
|
||||
- [模型描述](#模型描述)
|
||||
- [Ascend和GPU平台上运行](#Ascend和GPU平台上运行)
|
||||
- [基于STS-B数据集进行训练](#基于STS-B数据集进行训练)
|
||||
- [评估流程](#评估流程)
|
||||
- [基于STS-B数据集进行评估](#基于STS-B数据集进行评估)
|
||||
- [模型导出](#模型导出)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [推理性能](#推理性能)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
|
@ -40,9 +30,9 @@
|
|||
|
||||
# Q8BERT概述
|
||||
|
||||
[Q8BERT](https://arxiv.org/abs/1910.06188)是一种在finetune阶段使用量化训练BERT后的模型,最后是训练出来的模型在保证精度损失的情况下,模型大小压缩4倍,而且使用这种算法训练出来的模型在含有8bit算子的硬件上,推理速度也可以相应提高
|
||||
[Q8BERT](https://arxiv.org/abs/1910.06188)是一种将训练中量化策略应用到BERT的模型,训练生成的模型在精度损失较小的情况下,可以减小模型存储尺寸,而且在支持8bit量化算子的硬件平台上,可以加速推理。
|
||||
|
||||
[论文](https://arxiv.org/abs/1910.06188): Ofir Zafrir, Guy Boudoukh, Peter Izsak and Moshe Wasserblat. [Q8BERT: Quantized 8Bit BERT](https://arxiv.org/abs/1910.06188). arXiv preprint arXiv:2009.12812.
|
||||
[论文](https://arxiv.org/abs/1910.06188): Ofir Zafrir, Guy Boudoukh, Peter Izsak and Moshe Wasserblat. [Q8BERT: Quantized 8Bit BERT](https://arxiv.org/abs/1910.06188).arXiv preprint arXiv:2009.12812.
|
||||
|
||||
# 模型架构
|
||||
|
||||
|
@ -50,17 +40,19 @@ Q8BERT模型的主干结构是transformer,一个转换器包含12个编码器
|
|||
|
||||
# 数据集
|
||||
|
||||
- 下载GLUE数据集进行任务蒸馏。将数据集由JSON格式转化为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码库中的run_classifier.py文件。
|
||||
- 下载GLUE数据集进行微调。将数据集由JSON格式转化为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码库中的run_classifier.py文件。
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend或GPU)
|
||||
- 使用Ascend或GPU处理器准备硬件环境。
|
||||
- 硬件
|
||||
- 使用Ascend或GPU平台。
|
||||
- 框架
|
||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||
- 更多关于Mindspore的信息,请查看以下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
- 软件包:
|
||||
- numpy, sklearn
|
||||
|
||||
# 快速入门
|
||||
|
||||
|
@ -68,47 +60,10 @@ Q8BERT模型的主干结构是transformer,一个转换器包含12个编码器
|
|||
|
||||
```bash
|
||||
# 运行训练脚本
|
||||
run_train.sh
|
||||
bash run_standalone_train.sh [TASK_NAME] [DEVICE_TARGET] [TRAIN_DATA_DIR] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
# 运行推理脚本
|
||||
bash run_eval.sh [TASK_NAME] [DEVICE_TARGET] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
|
||||
Before running the shell script, please set the `task_name`, `teacher_model_dir`, `student_model_dir` and `data_dir` in the run_train.sh file first.
|
||||
|
||||
```
|
||||
|
||||
若在Ascend设备上运行分布式训练,请提前创建JSON格式的HCCL配置文件。
|
||||
详情参见如下链接:
|
||||
https:gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.
|
||||
|
||||
如需设置数据集格式和参数,请创建JSON格式的视图配置文件,详见[TFRecord](https://www.mindspore.cn/doc/programming_guide/zh-CN/master/dataset_loading.html#tfrecord) 格式。
|
||||
|
||||
```text
|
||||
For general task, schema file contains ["input_ids", "input_mask", "segment_ids"].
|
||||
|
||||
For task distill and eval phase, schema file contains ["input_ids", "input_mask", "segment_ids", "label_ids"].
|
||||
|
||||
`numRows` is the only option which could be set by user, the others value must be set according to the dataset.
|
||||
|
||||
For example, the dataset is cn-wiki-128, the schema file for general distill phase as following:
|
||||
{
|
||||
"datasetType": "TF",
|
||||
"numRows": 7680,
|
||||
"columns": {
|
||||
"input_ids": {
|
||||
"type": "int64",
|
||||
"rank": 1,
|
||||
"shape": [256]
|
||||
},
|
||||
"input_mask": {
|
||||
"type": "int64",
|
||||
"rank": 1,
|
||||
"shape": [256]
|
||||
},
|
||||
"segment_ids": {
|
||||
"type": "int64",
|
||||
"rank": 1,
|
||||
"shape": [256]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
# 脚本说明
|
||||
|
@ -116,144 +71,146 @@ For example, the dataset is cn-wiki-128, the schema file for general distill pha
|
|||
## 脚本和样例代码
|
||||
|
||||
```shell
|
||||
.
|
||||
|
||||
└─q8bert
|
||||
├─README.md
|
||||
├─README.md # 英文说明文档
|
||||
├─README_CN.md # 中文说明文档
|
||||
├─scripts
|
||||
├─run_train.sh # 运行shell脚本
|
||||
├─run_standalone_train.sh # 运行shell训练脚本
|
||||
├─run_eval.sh # 运行shell推理脚本
|
||||
├─src
|
||||
├─__init__.py
|
||||
├─dataset.py # 数据处理
|
||||
├─bert_model.py # bert模型主体结构
|
||||
├─q8bert_model.py # bert模型量化感知算法
|
||||
├─q8bert.py # q8bert主体结构
|
||||
├─q8bert.py # Q8BERT主体结构
|
||||
├─utils.py # utils函数
|
||||
├─__init__.py
|
||||
├─run_train.py # 运行main函数
|
||||
|
||||
├─train.py # 执行训练
|
||||
├─eval.py # 执行推理
|
||||
├─export.py # 模型导出
|
||||
```
|
||||
|
||||
## 脚本和脚本参数
|
||||
|
||||
```text
|
||||
|
||||
用法: run_train.py [--h] [--device_target {GPU,Ascend}][--epoch_num EPOCH_NUM] [--task_name {SST-2,QNLI,MNLI,COLA,QQP,"STS-B,RTE}][--do_shuffle {true,false}] [--enable_data_sink {true,false}][--do_eval {true,false}][--device_id DEVICE_ID] [--save_ckpt_step SAVE_CKPT_STEP] [--eval_ckpt_step EVAL_CKPT_STEP] [--max_ckpt_num MAX_CKPT_NUM] [--load_ckpt_path LOAD_CKPT_PATH] [--train_data_dir TRAIN_DATA_DIR] [--eval_data_dir EVAL_DATA_DIR] [--device_id DEVICE_ID] [--logging_step LOGGIND_STEP] [--do_quant {true,false}]
|
||||
用法
|
||||
bash run_standalone_train.sh [TASK_NAME] [DEVICE_TARGET] [TRAIN_DATA_DIR] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
|
||||
选项:
|
||||
--device_target 代码实现设备,可选项为Ascend或CPU。默认为GPU
|
||||
[TASK_NAME] Glue数据集任务: "STS-B"| "QNLI"| SST-2"
|
||||
[DEVICE_TARGET] 代码运行平台,可选项为Ascend或GPU
|
||||
[TRAIN_DATA_DIR] 训练集路径
|
||||
[EVAL_DATA_DIR] 验证集路径
|
||||
[LOAD_CKPT_PATH] 加载检查点文件的路径
|
||||
|
||||
或者
|
||||
|
||||
python train.py [--h] [--device_target {GPU,Ascend}] [--epoch_num EPOCH_NUM] [--task_name {SST-2, QNLI, STS-B}]
|
||||
[--do_shuffle {True,False}] [--enable_data_sink {True,False}] [--do_eval {True,False}]
|
||||
[--device_id DEVICE_ID] [--save_ckpt_step SAVE_CKPT_STEP] [--eval_ckpt_step EVAL_CKPT_STEP]
|
||||
[--max_ckpt_num MAX_CKPT_NUM] [--load_ckpt_path LOAD_CKPT_PATH] [--train_data_dir TRAIN_DATA_DIR]
|
||||
[--eval_data_dir EVAL_DATA_DIR] [--device_id DEVICE_ID] [--logging_step LOGGIND_STEP]
|
||||
[--do_quant {True,False}]
|
||||
选项:
|
||||
--device_target 代码运行平台,可选项为Ascend或GPU,默认为Ascend
|
||||
--do_eval 是否在训练的过程中加上推理默认为是
|
||||
--epoch_num Epoch数,默认为3
|
||||
--epoch_num Epoch数,默认为3
|
||||
--device_id 设备ID,默认为0
|
||||
--do_shuffle 是否使能轮换,可选项为true或false,默认为true
|
||||
--enable_data_sink 是否使能数据下沉,可选项为true或false,默认为true
|
||||
--save_ckpt_step 保存检查点文件的步数,默认为1000
|
||||
--eval_ckpt_step 如过do_eval为是, 在训练过程中执行推理的步数
|
||||
--eval_ckpt_step 如过do_eval为是, 在训练过程中执行推理的步数
|
||||
--max_ckpt_num 保存检查点文件的最大数,默认为1
|
||||
--data_sink_steps 设置数据下沉步数,默认为1
|
||||
--load_ckpt_path 加载检查点文件的路径,默认为""
|
||||
--train_data_dir 训练集路径, 默认为 ""
|
||||
--eval_data_dir 验证集路径, 默认为 ""
|
||||
--task_name Glue数据集任务: "SST-2"| "QNLI"| "MNLI"|"COLA"|"QQP"|"STS-B"|"RTE"
|
||||
--train_data_dir 训练集路径, 默认为 ""
|
||||
--eval_data_dir 验证集路径, 默认为 ""
|
||||
--task_name Glue数据集任务: "STS-B"| "QNLI"| SST-2"
|
||||
--dataset_type 数据集类型,可选项为tfrecord或mindrecord,默认为tfrecord
|
||||
--train_batch_size 训练batchsize,默认16
|
||||
--eval_batch_size 推理batchsize,默认32
|
||||
|
||||
```
|
||||
|
||||
## 选项及参数
|
||||
|
||||
`config.py` 包含BERT模型参数与优化器和损失缩放选项。
|
||||
|
||||
### 选项
|
||||
|
||||
```text
|
||||
|
||||
batch_size 输入数据集的批次大小,默认为16
|
||||
Parameters for lossscale:
|
||||
loss_scale_value 损失放大初始值,默认为
|
||||
scale_factor 损失放大的更新因子,默认为2
|
||||
scale_window 损失放大的一次更新步数,默认为50
|
||||
|
||||
Parameters for optimizer:
|
||||
learning_rate 学习率
|
||||
end_learning_rate 结束学习率,取值需为正数
|
||||
power 幂
|
||||
weight_decay 权重衰减
|
||||
eps 增加分母,提高小数稳定性
|
||||
|
||||
```
|
||||
|
||||
### 参数
|
||||
|
||||
```text
|
||||
|
||||
Parameters for bert network:
|
||||
seq_length 输入序列的长度,默认为128
|
||||
vocab_size 各内嵌向量大小,需与所采用的数据集相同。默认为30522
|
||||
hidden_size BERT的encoder层数
|
||||
num_hidden_layers 隐藏层数
|
||||
num_attention_heads 注意头的数量,默认为12
|
||||
intermediate_size 中间层数
|
||||
hidden_act 所采用的激活函数,默认为gelu
|
||||
hidden_dropout_prob BERT输出的随机失活可能性
|
||||
attention_probs_dropout_prob BERT注意的随机失活可能性
|
||||
max_position_embeddings 序列最大长度,默认为512
|
||||
save_ckpt_step 保存检查点数量,默认为100
|
||||
max_ckpt_num 保存检查点最大数量,默认为1
|
||||
type_vocab_size 标记类型的词汇表大小,默认为2
|
||||
initializer_range TruncatedNormal的初始值,默认为0.02
|
||||
use_relative_positions 是否采用相对位置,可选项为true或false,默认为False
|
||||
dtype 输入的数据类型,可选项为mstype.float16或mstype.float32,默认为mstype.float32
|
||||
compute_type Bert Transformer的计算类型,可选项为mstype.float16或mstype.float32,默认为mstype.float16
|
||||
--train_batch_size 训练batchsize,默认16
|
||||
--eval_batch_size 推理batchsize,默认32
|
||||
|
||||
```
|
||||
|
||||
## 训练流程
|
||||
|
||||
### 用法
|
||||
### Ascend和GPU平台上运行
|
||||
|
||||
#### Ascend处理器上运行
|
||||
运行以下命令前,确保已设置所有必需参数。建议路径参数设置成绝对路径。DEVICE_TARGET参数可选项为Ascend和GPU,分别代表模型在Ascend和GPU平台运行。
|
||||
|
||||
运行以下命令前,确保已设置'data_dir'和'load_ckpt_path'。请将路径设置为绝对全路径,例如/username/checkpoint_100_300.ckpt。
|
||||
### 基于STS-B数据集进行训练
|
||||
|
||||
本模型目前支持”STS-B“,”QNLI“,“SST-2”数据集,以”STS-B“为例进行评估。
|
||||
|
||||
```text
|
||||
|
||||
python
|
||||
python ./run_train.py --device_target="GPU" --do_eval="true" --epoch_num=3 --task_name="STS-B" --do_shuffle="true" --enable_data_sink="true" --data_sink_steps=100 --save_ckpt_step=100 --max_ckpt_num=1 --load_ckpt_path="sts-b.ckpt" --train_data_dir="sts-b/train.tf_record" --eval_data_dir="sts-b/eval.tf_record" --device_id=0 --logging_step=100 --do_quant="true"
|
||||
shell
|
||||
sh run_train.sh
|
||||
|
||||
以上命令后台运行,您可以在log.txt文件中查看运行结果。训练结束后,您可以在默认脚本文件夹中找到检查点文件。得到如下损失值:
|
||||
epoch: 1, step: 100, loss are (Tensor(shape=[], dtype=Float32, value= 0.526506), Tensor(shape=[], dtype=Bool, value= False)) The current result is {'pearson': 0.8407084843799768, 'spearmanr': 0.8405771469597393, 'corr': 0.840642815669858} epoch time: 66421.602 ms, per step time: 664.216 ms
|
||||
epoch: 2, step: 200, loss are (Tensor(shape=[], dtype=Float32, value= 0.406012), Tensor(shape=[], dtype=Bool, value= False)) The current result is {'pearson': 0.826509808575773, 'spearmanr': 0.8274141859302444, 'corr': 0.8269619972530087} epoch time: 47488.633 ms, per step time: 474.886 ms
|
||||
...
|
||||
best pearson:0.8753269455187238
|
||||
bash run_standalone_train.sh [TASK_NAME] [DEVICE_TARGET] [TRAIN_DATA_DIR] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
example:
|
||||
bash run_standalone_train.sh STS-B Ascend /path/sts-b/train.tf_record /path/sts-b/eval.tf_record /path/xxx.ckpt
|
||||
|
||||
```
|
||||
|
||||
## 模型描述
|
||||
以上命令后台运行,可以在train_log.txt文件中查看运行结果:
|
||||
|
||||
```text
|
||||
epoch: 1, step: 100, loss: 0.526506
|
||||
The current result is {'pearson': 0.8407084843799768, 'spearmanr': 0.8405771469597393, 'corr': 0.840642815669858}, the best result is 0.8407084843799768
|
||||
epoch time: 147446.514 ms, per step time: 1474.465 ms
|
||||
epoch: 2, step: 200, loss: 0.406012
|
||||
The current result is {'pearson': 0.826509808575773, 'spearmanr': 0.8274141859302444, 'corr': 0.8269619972530087}, the best result is 0.8407084843799768
|
||||
epoch time: 93688.080 ms, per step time: 936.881 ms
|
||||
...
|
||||
|
||||
训练结束后,可以在工程根目录对应的文件夹中找到检查点文件。
|
||||
|
||||
```
|
||||
|
||||
## 评估流程
|
||||
|
||||
### 基于STS-B数据集进行评估
|
||||
|
||||
```text
|
||||
shell
|
||||
bash run_eval.sh [TASK_NAME] [DEVICE_TARGET] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]
|
||||
example:
|
||||
bash run_eval.sh STS-B Ascend /path/sts-b/eval.tf_record /path/xxx.ckpt
|
||||
```
|
||||
|
||||
以上命令后台运行,可以在eval_log.txt文件中查看运行结果:
|
||||
|
||||
```text
|
||||
The current result is {'pearson': 0.826509808575773, 'spearmanr': 0.8274141859302444, 'corr': 0.8269619972530087}, the best result is 0.8407084843799768
|
||||
|
||||
```
|
||||
|
||||
## 模型导出
|
||||
|
||||
```text
|
||||
python export.py --task_name [TASK_NAME] --ckpt_file [CKPT_FILE] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
模型导出格式选项:["AIR", "MINDIR"]
|
||||
|
||||
## 性能
|
||||
|
||||
### 评估性能
|
||||
|
||||
| Parameters | GPU |
|
||||
| ----------------- | :---------------------------------------------------- |
|
||||
| 模型 | Q8BERT |
|
||||
| 资源 | NV GeForce GTX1080ti |
|
||||
| 测试时间 | 03/01/2020 |
|
||||
| MindSpore版本 | 1.1.0 |
|
||||
| 数据集 | STS-B |
|
||||
| batch size | 16 |
|
||||
| 结果 | 87.5833 |
|
||||
| 速度 | 0.47s/step |
|
||||
| 总时间 | 9.1min(3epoch, 1p) |
|
||||
| 参数 | Ascend | GPU |
|
||||
| -------------------------- | ---------------------------------------------------------- | ------------------------- |
|
||||
| 模型版本 | Q8BERT | Q8BERT |
|
||||
| 资源 | Ascend 910,cpu 2.60GHz,192核,内存 755G,系统 Euler2.8 | NV GeForce GTX1080ti,cpu 2.00GHz,56核,内存 251G,系统 Ubuntu16.04 |
|
||||
| 上传日期 | 2021-6-8 | 2021-6-8 |
|
||||
| MindSpore版本 | 1.2.0 | 1.2.0 |
|
||||
| 数据集 | STS-B | STS-B |
|
||||
| 总时长 | 11分钟 (3轮, 1卡) | 18分钟 (3轮, 1卡) |
|
||||
| 精度 | 89.14 | 89.18 |
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
run_train.py脚本中设置了do_shuffle来轮换数据集。
|
||||
run_train.py脚本中设置了do_shuffle参数用于轮换数据集。
|
||||
|
||||
config.py文件中设置了hidden_dropout_prob和attention_pros_dropout_prob,使网点随机失活。
|
||||
config.py文件中设置hidden_dropout_prob和attention_pros_dropout_prob参数,使网络节点随机失活。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
"""q8bert eval"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.q8bert import BertNetworkWithLoss_td
|
||||
from src.config import eval_cfg, model_cfg, glue_output_modes, task_params
|
||||
from src.utils import glue_compute_metrics
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
parse args
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='Q8Bert task eval')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--do_shuffle", type=str, default="False", choices=["True", "False"],
|
||||
help="Enable shuffle for dataset, default is True.")
|
||||
parser.add_argument("--eval_data_dir", type=str, default="",
|
||||
help="Eval data path, it is better to use absolute path")
|
||||
parser.add_argument("--load_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--do_quant", type=str, default="False", help="Do quant for model")
|
||||
parser.add_argument("--task_name", type=str, default="STS-B", choices=["STS-B", "QNLI", "SST-2"],
|
||||
help="The name of the task to eval.")
|
||||
parser.add_argument("--dataset_type", type=str, default="tfrecord",
|
||||
help="dataset type tfrecord/mindrecord, default is tfrecord")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
args_opt = parse_args()
|
||||
|
||||
DEFAULT_NUM_LABELS = 2
|
||||
DEFAULT_SEQ_LENGTH = 128
|
||||
|
||||
|
||||
class Task:
|
||||
"""
|
||||
Encapsulation class of get the task parameter.
|
||||
"""
|
||||
def __init__(self, task_name):
|
||||
self.task_name = task_name
|
||||
|
||||
@property
|
||||
def num_labels(self):
|
||||
if self.task_name in task_params and "num_labels" in task_params[self.task_name]:
|
||||
return task_params[self.task_name]["num_labels"]
|
||||
return DEFAULT_NUM_LABELS
|
||||
|
||||
@property
|
||||
def seq_length(self):
|
||||
if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
|
||||
return task_params[self.task_name]["seq_length"]
|
||||
return DEFAULT_SEQ_LENGTH
|
||||
|
||||
|
||||
task = Task(args_opt.task_name)
|
||||
|
||||
|
||||
def do_eval():
|
||||
"""
|
||||
do eval
|
||||
"""
|
||||
ckpt_file = args_opt.load_ckpt_path
|
||||
|
||||
if ckpt_file == '':
|
||||
raise ValueError("Student ckpt file should not be None")
|
||||
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
elif args_opt.device_target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
else:
|
||||
raise Exception("Target error, GPU or Ascend is supported.")
|
||||
|
||||
load_student_checkpoint_path = ckpt_file
|
||||
netwithloss = BertNetworkWithLoss_td(student_config=model_cfg, student_ckpt=load_student_checkpoint_path,
|
||||
do_quant=args_opt.do_quant, is_training=True,
|
||||
task_type=glue_output_modes[args_opt.task_name.lower()],
|
||||
num_labels=task.num_labels, is_predistill=False)
|
||||
eval_network = netwithloss.bert
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
eval_dataset = create_dataset(eval_cfg.batch_size,
|
||||
device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.eval_data_dir,
|
||||
data_type=args_opt.dataset_type,
|
||||
seq_length=task.seq_length,
|
||||
drop_remainder=False)
|
||||
dataset_size = eval_dataset.get_dataset_size()
|
||||
print('eval dataset size: ', dataset_size)
|
||||
|
||||
label_nums = 2
|
||||
if args_opt.task_name.lower == 'mnli':
|
||||
label_nums = 3
|
||||
eval_network.set_train(False)
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
for data in eval_dataset.create_dict_iterator(num_epochs=1):
|
||||
input_data = []
|
||||
for i in columns_list:
|
||||
input_data.append(data[i])
|
||||
input_ids, input_mask, token_type_id, label_ids = input_data
|
||||
_, _, logits, _ = eval_network(input_ids, token_type_id, input_mask)
|
||||
if preds is None:
|
||||
preds = logits.asnumpy()
|
||||
preds = np.reshape(preds, [-1, label_nums])
|
||||
out_label_ids = label_ids.asnumpy()
|
||||
else:
|
||||
preds = np.concatenate((preds, np.reshape(logits.asnumpy(), [-1, label_nums])), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, label_ids.asnumpy())
|
||||
if glue_output_modes[args_opt.task_name.lower()] == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif glue_output_modes[args_opt.task_name.lower()] == "regression":
|
||||
preds = np.reshape(preds, [-1])
|
||||
result = glue_compute_metrics(args_opt.task_name.lower(), preds, out_label_ids)
|
||||
print("The current result is {}".format(result))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_cfg.seq_length = task.seq_length
|
||||
do_eval()
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
"""export checkpoint file into model"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import model_cfg, task_params
|
||||
from src.q8bert_model import BertModelCLS
|
||||
|
||||
parser = argparse.ArgumentParser(description="Q8Bert export model")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
|
||||
help="device where the code will be implemented. (Default: Ascend)")
|
||||
parser.add_argument("--task_name", type=str, default="STS-B", choices=["STS-B", "QNLI", "SST-2"],
|
||||
help="The name of the task to eval.")
|
||||
parser.add_argument("--file_name", type=str, default="q8bert", help="The name of the output file.")
|
||||
parser.add_argument("--file_format", type=str, default="AIR", choices=["AIR", "MINDIR"],
|
||||
help="output model type")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="pretrained checkpoint file")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
|
||||
DEFAULT_NUM_LABELS = 2
|
||||
DEFAULT_SEQ_LENGTH = 128
|
||||
DEFAULT_BS = 32
|
||||
|
||||
|
||||
class Task:
|
||||
"""
|
||||
Encapsulation class of get the task parameter.
|
||||
"""
|
||||
|
||||
def __init__(self, task_name):
|
||||
self.task_name = task_name
|
||||
|
||||
@property
|
||||
def num_labels(self):
|
||||
if self.task_name in task_params and "num_labels" in task_params[self.task_name]:
|
||||
return task_params[self.task_name]["num_labels"]
|
||||
return DEFAULT_NUM_LABELS
|
||||
|
||||
@property
|
||||
def seq_length(self):
|
||||
if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
|
||||
return task_params[self.task_name]["seq_length"]
|
||||
return DEFAULT_SEQ_LENGTH
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
task = Task(args.task_name)
|
||||
model_cfg.seq_length = task.seq_length
|
||||
model_cfg.batch_size = DEFAULT_BS
|
||||
eval_model = BertModelCLS(model_cfg, False, task.num_labels, 0.0, phase_type="student")
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
load_param_into_net(eval_model, param_dict)
|
||||
eval_model.set_train(False)
|
||||
|
||||
input_ids = Tensor(np.zeros((model_cfg.batch_size, task.seq_length), np.int32))
|
||||
token_type_id = Tensor(np.zeros((model_cfg.batch_size, task.seq_length), np.int32))
|
||||
input_mask = Tensor(np.zeros((model_cfg.batch_size, task.seq_length), np.int32))
|
||||
|
||||
input_data = [input_ids, token_type_id, input_mask]
|
||||
export(eval_model, *input_data, file_name=args.file_name, file_format=args.file_format, quant_model="QUANT")
|
|
@ -14,27 +14,30 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_eval.sh [TASK_NAME] [DEVICE_TARGET] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]"
|
||||
echo "for example: bash run_eval.sh STS-B Ascend /path/sts-b/eval.tf_record /path/xxx.ckpt"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
|
||||
task_name=$1
|
||||
device_target=$2
|
||||
eval_data_dir=$3
|
||||
load_ckpt_path=$4
|
||||
|
||||
mkdir -p ms_log
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
|
||||
python ${PROJECT_DIR}/../run_train.py \
|
||||
--device_target="Ascend" \
|
||||
python ${PROJECT_DIR}/../eval.py \
|
||||
--device_target=$device_target \
|
||||
--device_id=0 \
|
||||
--do_eval="true" \
|
||||
--epoch_num=3 \
|
||||
--task_name="" \
|
||||
--do_shuffle="true" \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=100 \
|
||||
--save_ckpt_step=100 \
|
||||
--max_ckpt_num=1 \
|
||||
--load_ckpt_path="" \
|
||||
--train_data_dir="" \
|
||||
--eval_data_dir="" \
|
||||
--device_id="" \
|
||||
--logging_step=100\
|
||||
--do_quant="true" > log.txt 2>&1 &
|
||||
--task_name=$task_name \
|
||||
--do_shuffle="False" \
|
||||
--load_ckpt_path=$load_ckpt_path \
|
||||
--eval_data_dir=$eval_data_dir \
|
||||
--do_quant="True" > eval_log.txt 2>&1 &
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_standalone_train.sh [TASK_NAME] [DEVICE_TARGET] [TRAIN_DATA_DIR] [EVAL_DATA_DIR] [LOAD_CKPT_PATH]"
|
||||
echo "for example: bash run_standalone_train.sh STS-B Ascend /path/sts-b/train.tf_record /path/sts-b/eval.tf_record /path/xxx.ckpt"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
task_name=$1
|
||||
device_target=$2
|
||||
train_data_dir=$3
|
||||
eval_data_dir=$4
|
||||
load_ckpt_path=$5
|
||||
|
||||
mkdir -p ms_log
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
|
||||
python ${PROJECT_DIR}/../train.py \
|
||||
--device_target=$device_target \
|
||||
--device_id=0 \
|
||||
--do_eval="True" \
|
||||
--epoch_num=3 \
|
||||
--task_name=$task_name \
|
||||
--do_shuffle="True" \
|
||||
--enable_data_sink="True" \
|
||||
--data_sink_steps=100 \
|
||||
--save_ckpt_step=100 \
|
||||
--max_ckpt_num=1 \
|
||||
--load_ckpt_path=$load_ckpt_path \
|
||||
--train_data_dir=$train_data_dir \
|
||||
--eval_data_dir=$eval_data_dir \
|
||||
--logging_step=100 \
|
||||
--do_quant="True" > train_log.txt 2>&1 &
|
||||
|
|
@ -19,7 +19,7 @@ import numpy as np
|
|||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore.common.initializer import Normal, initializer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -49,7 +49,7 @@ class BertConfig:
|
|||
max_position_embeddings (int): Maximum length of sequences used in this
|
||||
model. Default: 512.
|
||||
type_vocab_size (int): Size of token type vocab. Default: 16.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
|
@ -98,7 +98,7 @@ class EmbeddingLookup(nn.Cell):
|
|||
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
|
||||
each embedding vector.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -111,11 +111,11 @@ class EmbeddingLookup(nn.Cell):
|
|||
self.vocab_size = vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
(Normal(initializer_range),
|
||||
[vocab_size, embedding_size]))
|
||||
self.expand = P.ExpandDims()
|
||||
self.shape_flat = (-1,)
|
||||
self.gather = P.GatherV2()
|
||||
self.gather = P.Gather()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
|
@ -148,7 +148,7 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
|
||||
token_type_vocab_size (int): Size of token type vocab. Default: 16.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
max_position_embeddings (int): Maximum length of sequences used in this
|
||||
model. Default: 512.
|
||||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
|
@ -170,7 +170,7 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
(Normal(initializer_range),
|
||||
[token_type_vocab_size,
|
||||
embedding_size]))
|
||||
self.shape_flat = (-1,)
|
||||
|
@ -182,11 +182,11 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
self.shape = tuple(embedding_shape)
|
||||
self.layernorm = nn.LayerNorm((embedding_size,))
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.gather = P.GatherV2()
|
||||
self.gather = P.Gather()
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.slice = P.StridedSlice()
|
||||
self.full_position_embeddings = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
(Normal(initializer_range),
|
||||
[max_position_embeddings,
|
||||
embedding_size]))
|
||||
|
||||
|
@ -220,7 +220,7 @@ class BertOutput(nn.Cell):
|
|||
Args:
|
||||
in_channels (int): Input channels.
|
||||
out_channels (int): Output channels.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
"""
|
||||
|
@ -233,9 +233,9 @@ class BertOutput(nn.Cell):
|
|||
compute_type=mstype.float32,):
|
||||
super(BertOutput, self).__init__()
|
||||
self.dense = nn.Dense(in_channels, out_channels,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
weight_init=Normal(initializer_range)).to_float(compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.add = P.TensorAdd()
|
||||
self.add = P.Add()
|
||||
self.is_gpu = context.get_context('device_target') == "GPU"
|
||||
if self.is_gpu:
|
||||
self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
|
||||
|
@ -302,7 +302,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
|||
length (int): Length of one dim for the matrix to be generated.
|
||||
depth (int): Size of each attention head.
|
||||
max_relative_position (int): Maxmum value of relative position.
|
||||
initializer_range (float): Initialization value of TruncatedNormal.
|
||||
initializer_range (float): Initialization value of Normal.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
|
||||
|
@ -317,7 +317,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
|||
self.vocab_size = max_relative_position * 2 + 1
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embeddings_table = Parameter(
|
||||
initializer(TruncatedNormal(initializer_range),
|
||||
initializer(Normal(initializer_range),
|
||||
[self.vocab_size, self.depth]))
|
||||
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
|
||||
max_relative_position=max_relative_position)
|
||||
|
@ -326,7 +326,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
|||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.shape = P.Shape()
|
||||
self.gather = P.GatherV2() # index_select
|
||||
self.gather = P.Gather() # index_select
|
||||
self.matmul = P.BatchMatMul()
|
||||
|
||||
def construct(self):
|
||||
|
@ -393,7 +393,7 @@ class BertAttention(nn.Cell):
|
|||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.0.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
|
||||
tensor. Default: False.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
|
@ -428,7 +428,7 @@ class BertAttention(nn.Cell):
|
|||
self.reshape = P.Reshape()
|
||||
self.shape_from_2d = (-1, from_tensor_width)
|
||||
self.shape_to_2d = (-1, to_tensor_width)
|
||||
weight = TruncatedNormal(initializer_range)
|
||||
weight = Normal(initializer_range)
|
||||
units = num_attention_heads * size_per_head
|
||||
self.query_layer = nn.Dense(from_tensor_width,
|
||||
units,
|
||||
|
@ -457,7 +457,7 @@ class BertAttention(nn.Cell):
|
|||
if self.has_attention_mask:
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.sub = P.Sub()
|
||||
self.add = P.TensorAdd()
|
||||
self.add = P.Add()
|
||||
self.cast = P.Cast()
|
||||
self.get_dtype = P.DType()
|
||||
if do_return_2d_tensor:
|
||||
|
@ -565,7 +565,7 @@ class BertSelfAttention(nn.Cell):
|
|||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
|
||||
|
@ -628,7 +628,7 @@ class BertEncoderCell(nn.Cell):
|
|||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.02.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
hidden_act (str): Activation function. Default: "gelu".
|
||||
|
@ -661,7 +661,7 @@ class BertEncoderCell(nn.Cell):
|
|||
self.intermediate = nn.Dense(in_channels=hidden_size,
|
||||
out_channels=intermediate_size,
|
||||
activation=hidden_act,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
weight_init=Normal(initializer_range)).to_float(compute_type)
|
||||
self.output = BertOutput(in_channels=intermediate_size,
|
||||
out_channels=hidden_size,
|
||||
initializer_range=initializer_range,
|
||||
|
@ -692,7 +692,7 @@ class BertTransformer(nn.Cell):
|
|||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
|
||||
|
@ -840,7 +840,7 @@ class BertModel(nn.Cell):
|
|||
self.squeeze_1 = P.Squeeze(axis=1)
|
||||
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
|
||||
activation="tanh",
|
||||
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
|
||||
weight_init=Normal(config.initializer_range)).to_float(config.compute_type)
|
||||
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
|
||||
|
||||
def construct(self, input_ids, token_type_ids, input_mask):
|
||||
|
@ -936,7 +936,7 @@ class TinyBertModel(nn.Cell):
|
|||
self.squeeze_1 = P.Squeeze(axis=1)
|
||||
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
|
||||
activation="tanh",
|
||||
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
|
||||
weight_init=Normal(config.initializer_range)).to_float(config.compute_type)
|
||||
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
|
||||
|
||||
def construct(self, input_ids, token_type_ids, input_mask):
|
||||
|
@ -981,7 +981,7 @@ class BertModelCLS(nn.Cell):
|
|||
super(BertModelCLS, self).__init__()
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.weight_init = Normal(config.initializer_range)
|
||||
self.log_softmax = P.LogSoftmax(axis=-1)
|
||||
self.dtype = config.dtype
|
||||
self.num_labels = num_labels
|
||||
|
@ -1008,5 +1008,4 @@ class BertModelCLS(nn.Cell):
|
|||
if self._phase == 'train' or self.phase_type == "teacher":
|
||||
return seq_output, att_output, logits, log_probs
|
||||
# return log_probs
|
||||
|
||||
return seq_output, att_output, logits, log_probs
|
||||
|
|
|
@ -20,12 +20,12 @@ from easydict import EasyDict as edict
|
|||
from .q8bert_model import BertConfig
|
||||
train_cfg = edict({
|
||||
'batch_size': 16,
|
||||
'loss_scale_value': 2 ** 16,
|
||||
'loss_scale_value': 1,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 50,
|
||||
'optimizer_cfg': edict({
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 5e-5,
|
||||
'learning_rate': 1e-5,
|
||||
'end_learning_rate': 1e-14,
|
||||
'power': 1.0,
|
||||
'weight_decay': 1e-4,
|
||||
|
@ -44,12 +44,12 @@ model_cfg = BertConfig(
|
|||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=6,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
|
@ -57,3 +57,20 @@ model_cfg = BertConfig(
|
|||
dtype=mstype.float32,
|
||||
compute_type=mstype.float32,
|
||||
)
|
||||
|
||||
glue_output_modes = {
|
||||
"cola": "classification",
|
||||
"mnli": "classification",
|
||||
"mnli-mm": "classification",
|
||||
"mrpc": "classification",
|
||||
"sst-2": "classification",
|
||||
"sts-b": "regression",
|
||||
"qqp": "classification",
|
||||
"qnli": "classification",
|
||||
"rte": "classification",
|
||||
"wnli": "classification",
|
||||
}
|
||||
|
||||
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
|
||||
"QNLI": {"num_labels": 2, "seq_length": 128},
|
||||
"STS-B": {"num_labels": 1, "seq_length": 128}}
|
||||
|
|
|
@ -13,22 +13,23 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
"""create tinybert dataset"""
|
||||
"""create q8bert dataset"""
|
||||
|
||||
from enum import Enum
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine.datasets as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
"""Enumerate supported dataset format"""
|
||||
TFRECORD = 1
|
||||
MINDRECORD = 2
|
||||
|
||||
def create_tinybert_dataset(batch_size=32, device_num=1, rank=0,
|
||||
do_shuffle="true", data_dir=None, schema_dir=None,
|
||||
data_type=DataType.TFRECORD, seq_length=128, task_type=mstype.int32, drop_remainder=True):
|
||||
"""create tinybert dataset"""
|
||||
|
||||
def create_dataset(batch_size=32, device_num=1, rank=0, do_shuffle="True", data_dir=None, data_type=DataType.TFRECORD,
|
||||
seq_length=128, drop_remainder=True):
|
||||
"""create q8bert dataset"""
|
||||
if isinstance(data_dir, list):
|
||||
data_files = data_dir
|
||||
else:
|
||||
|
@ -36,56 +37,25 @@ def create_tinybert_dataset(batch_size=32, device_num=1, rank=0,
|
|||
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
||||
shard_equal_rows = True
|
||||
shuffle = (do_shuffle == "true")
|
||||
shuffle = (do_shuffle == "True")
|
||||
if device_num == 1:
|
||||
shard_equal_rows = False
|
||||
shuffle = False
|
||||
if data_type == DataType.MINDRECORD:
|
||||
ds = de.MindDataset(data_files, columns_list=columns_list,
|
||||
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
|
||||
ds = de.MindDataset(data_files, columns_list=columns_list, shuffle=shuffle,
|
||||
num_shards=device_num, shard_id=rank)
|
||||
else:
|
||||
ds = de.TFRecordDataset(data_files, None, columns_list=columns_list,
|
||||
shuffle=shuffle, num_shards=device_num, shard_id=rank,
|
||||
ds = de.TFRecordDataset(data_files, None, columns_list=columns_list, shuffle=shuffle,
|
||||
num_shards=device_num, shard_id=rank,
|
||||
shard_equal_rows=shard_equal_rows)
|
||||
if device_num == 1 and shuffle is True:
|
||||
ds = ds.shuffle(10000)
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
slice_op = C.Slice(slice(0, seq_length, 1))
|
||||
label_type = mstype.float32
|
||||
# label_type = mstype.int32 if task_type == 'classification' else mstype.float32
|
||||
ds = ds.map(operations=[type_cast_op, slice_op], input_columns=["segment_ids"])
|
||||
ds = ds.map(operations=[type_cast_op, slice_op], input_columns=["input_mask"])
|
||||
ds = ds.map(operations=[type_cast_op, slice_op], input_columns=["input_ids"])
|
||||
ds = ds.map(operations=[C.TypeCast(label_type), slice_op], input_columns=["label_ids"])
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
|
||||
|
||||
return ds
|
||||
|
||||
def generator_squad(data_features):
|
||||
for feature in data_features:
|
||||
yield (feature.input_ids, feature.input_mask, feature.segment_ids, feature.unique_id)
|
||||
|
||||
|
||||
def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None,
|
||||
is_training=True, do_shuffle=True):
|
||||
"""create finetune or evaluation dataset"""
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
if is_training:
|
||||
data_set = ds.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "start_positions",
|
||||
"end_positions", "unique_ids", "is_impossible"],
|
||||
shuffle=do_shuffle)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="start_positions")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="end_positions")
|
||||
else:
|
||||
data_set = ds.GeneratorDataset(generator_squad(data_file_path), shuffle=do_shuffle,
|
||||
column_names=["input_ids", "input_mask", "segment_ids", "unique_ids"])
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_mask")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="unique_ids")
|
||||
data_set = data_set.repeat(repeat_count)
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
return data_set
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
"""Tinybert model"""
|
||||
"""q8bert model"""
|
||||
|
||||
import re
|
||||
import mindspore.nn as nn
|
||||
|
@ -27,8 +27,9 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from .q8bert_model import BertModel, TinyBertModel, BertModelCLS
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from .q8bert_model import BertModelCLS
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
@ -57,13 +58,23 @@ def _clip_grad(clip_type, clip_value, grad):
|
|||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
|
||||
|
||||
class ClipGradients(nn.Cell):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
@ -100,6 +111,7 @@ class ClipGradients(nn.Cell):
|
|||
new_grads = new_grads + (t,)
|
||||
return new_grads
|
||||
|
||||
|
||||
class SoftCrossEntropy(nn.Cell):
|
||||
"""SoftCrossEntropy loss"""
|
||||
def __init__(self):
|
||||
|
@ -116,88 +128,6 @@ class SoftCrossEntropy(nn.Cell):
|
|||
|
||||
return self.cast(loss, mstype.float32)
|
||||
|
||||
class BertNetworkWithLoss_gd(nn.Cell):
|
||||
"""
|
||||
Provide bert pre-training loss through network.
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
|
||||
Returns:
|
||||
Tensor, the loss of the network.
|
||||
"""
|
||||
def __init__(self, teacher_config, teacher_ckpt, student_config, is_training, use_one_hot_embeddings=False,
|
||||
is_att_fit=True, is_rep_fit=True):
|
||||
super(BertNetworkWithLoss_gd, self).__init__()
|
||||
# load teacher model
|
||||
self.teacher = BertModel(teacher_config, False, use_one_hot_embeddings)
|
||||
param_dict = load_checkpoint(teacher_ckpt)
|
||||
new_param_dict = {}
|
||||
for key, value in param_dict.items():
|
||||
new_key = re.sub('^bert.bert.', 'teacher.', key)
|
||||
new_param_dict[new_key] = value
|
||||
load_param_into_net(self.teacher, new_param_dict)
|
||||
# no_grad
|
||||
self.teacher.set_train(False)
|
||||
params = self.teacher.trainable_params()
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
# student model
|
||||
self.bert = TinyBertModel(student_config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
self.fit_dense = nn.Dense(student_config.hidden_size,
|
||||
teacher_config.hidden_size).to_float(teacher_config.compute_type)
|
||||
self.teacher_layers_num = teacher_config.num_hidden_layers
|
||||
self.student_layers_num = student_config.num_hidden_layers
|
||||
self.layers_per_block = int(self.teacher_layers_num / self.student_layers_num)
|
||||
self.is_att_fit = is_att_fit
|
||||
self.is_rep_fit = is_rep_fit
|
||||
self.loss_mse = nn.MSELoss()
|
||||
self.select = P.Select()
|
||||
self.zeroslike = P.ZerosLike()
|
||||
self.dtype = teacher_config.dtype
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id):
|
||||
"""general distill network with loss"""
|
||||
# teacher model
|
||||
_, _, _, teacher_seq_output, teacher_att_output = self.teacher(input_ids, token_type_id, input_mask)
|
||||
# student model
|
||||
_, _, _, student_seq_output, student_att_output = self.bert(input_ids, token_type_id, input_mask)
|
||||
total_loss = 0
|
||||
if self.is_att_fit:
|
||||
selected_teacher_att_output = ()
|
||||
selected_student_att_output = ()
|
||||
for i in range(self.student_layers_num):
|
||||
selected_teacher_att_output += (teacher_att_output[(i + 1) * self.layers_per_block - 1],)
|
||||
selected_student_att_output += (student_att_output[i],)
|
||||
att_loss = 0
|
||||
for i in range(self.student_layers_num):
|
||||
student_att = selected_student_att_output[i]
|
||||
teacher_att = selected_teacher_att_output[i]
|
||||
student_att = self.select(student_att <= self.cast(-100.0, mstype.float32), self.zeroslike(student_att),
|
||||
student_att)
|
||||
teacher_att = self.select(teacher_att <= self.cast(-100.0, mstype.float32), self.zeroslike(teacher_att),
|
||||
teacher_att)
|
||||
att_loss += self.loss_mse(student_att, teacher_att)
|
||||
total_loss += att_loss
|
||||
if self.is_rep_fit:
|
||||
selected_teacher_seq_output = ()
|
||||
selected_student_seq_output = ()
|
||||
for i in range(self.student_layers_num + 1):
|
||||
selected_teacher_seq_output += (teacher_seq_output[i * self.layers_per_block],)
|
||||
fit_dense_out = self.fit_dense(student_seq_output[i])
|
||||
fit_dense_out = self.cast(fit_dense_out, self.dtype)
|
||||
selected_student_seq_output += (fit_dense_out,)
|
||||
rep_loss = 0
|
||||
for i in range(self.student_layers_num + 1):
|
||||
teacher_rep = selected_teacher_seq_output[i]
|
||||
student_rep = selected_student_seq_output[i]
|
||||
rep_loss += self.loss_mse(student_rep, teacher_rep)
|
||||
total_loss += rep_loss
|
||||
return self.cast(total_loss, mstype.float32)
|
||||
|
||||
class BertTrainWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
|
@ -235,7 +165,7 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.depend_parameter_use = P.Depend()
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
@ -289,6 +219,7 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
|
||||
class BertTrainCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
@ -343,6 +274,7 @@ class BertTrainCell(nn.Cell):
|
|||
succ = self.optimizer(grads)
|
||||
return F.depend(loss, succ)
|
||||
|
||||
|
||||
class BertNetworkWithLoss_td(nn.Cell):
|
||||
"""
|
||||
Provide bert pre-training loss through network.
|
||||
|
@ -377,13 +309,13 @@ class BertNetworkWithLoss_td(nn.Cell):
|
|||
for key, value in param_dict.items():
|
||||
new_key = re.sub('tinybert_', 'bert_', 'bert.' + key)
|
||||
new_param_dict[new_key] = value
|
||||
load_param_into_net(self.bert, new_param_dict)
|
||||
load_nonquant_param_into_quant_net(self.bert, new_param_dict)
|
||||
else:
|
||||
new_param_dict = {}
|
||||
for key, value in param_dict.items():
|
||||
new_key = re.sub('tinybert_', 'bert_', key)
|
||||
new_param_dict[new_key] = value
|
||||
load_param_into_net(self.bert, new_param_dict)
|
||||
load_nonquant_param_into_quant_net(self.bert, new_param_dict)
|
||||
self.cast = P.Cast()
|
||||
self.student_layers_num = student_config.num_hidden_layers
|
||||
self.is_predistill = is_predistill
|
||||
|
@ -421,6 +353,7 @@ class BertNetworkWithLoss_td(nn.Cell):
|
|||
total_loss += cls_loss
|
||||
return self.cast(total_loss, mstype.float32)
|
||||
|
||||
|
||||
class BertEvaluationWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
specifically defined for finetuning where only four inputs tensor are needed.
|
||||
|
@ -445,11 +378,18 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
|
|||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.gpu_target = False
|
||||
if context.get_context("device_target") == "GPU":
|
||||
self.gpu_target = True
|
||||
self.float_status = P.FloatStatus()
|
||||
self.addn = P.AddN()
|
||||
self.reshape = P.Reshape()
|
||||
else:
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.depend_parameter_use = P.Depend()
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
@ -467,6 +407,7 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
|
|||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
init = False
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
|
@ -475,9 +416,12 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
|
|||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
if not self.gpu_target:
|
||||
init = self.alloc_status()
|
||||
clear_before_grad = self.clear_before_grad(init)
|
||||
F.depend(loss, init)
|
||||
self.depend_parameter_use(clear_before_grad, scaling_sens)
|
||||
# alloc status and clear should be right before gradoperation
|
||||
init = self.alloc_status()
|
||||
self.clear_before_grad(init)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
|
@ -485,11 +429,19 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
|
|||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
if not self.gpu_target:
|
||||
flag = self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
F.depend(grads, flag)
|
||||
F.depend(flag, flag_sum)
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
flag_sum = self.reshape(flag_sum, (()))
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
|
@ -503,7 +455,7 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
|
|||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
ret = (loss, cond)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
|
||||
|
|
|
@ -19,15 +19,13 @@ import numpy as np
|
|||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore.compression.common import QuantDtype
|
||||
from mindspore.ops import operations as P, Primitive
|
||||
from mindspore.common.initializer import Normal, initializer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore import context
|
||||
from mindspore.nn.layer.quant import FakeQuantWithMinMaxObserver as FakeQuantWithMinMax, quant_config_default
|
||||
from mindspore.nn.layer.quant import FakeQuantWithMinMaxObserver as FakeQuantWithMinMax
|
||||
|
||||
|
||||
class BertConfig:
|
||||
|
@ -52,7 +50,7 @@ class BertConfig:
|
|||
max_position_embeddings (int): Maximum length of sequences used in this
|
||||
model. Default: 512.
|
||||
type_vocab_size (int): Size of token type vocab. Default: 16.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
|
@ -101,7 +99,7 @@ class EmbeddingLookup(nn.Cell):
|
|||
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
|
||||
each embedding vector.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -114,11 +112,11 @@ class EmbeddingLookup(nn.Cell):
|
|||
self.vocab_size = vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
(Normal(initializer_range),
|
||||
[vocab_size, embedding_size]))
|
||||
self.expand = P.ExpandDims()
|
||||
self.shape_flat = (-1,)
|
||||
self.gather = P.GatherV2()
|
||||
self.gather = P.Gather()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
|
@ -151,7 +149,7 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
|
||||
token_type_vocab_size (int): Size of token type vocab. Default: 16.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
max_position_embeddings (int): Maximum length of sequences used in this
|
||||
model. Default: 512.
|
||||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
|
@ -173,7 +171,7 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
(Normal(initializer_range),
|
||||
[token_type_vocab_size,
|
||||
embedding_size]))
|
||||
self.shape_flat = (-1,)
|
||||
|
@ -185,11 +183,11 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
self.shape = tuple(embedding_shape)
|
||||
self.layernorm = nn.LayerNorm((embedding_size,))
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.gather = P.GatherV2()
|
||||
self.gather = P.Gather()
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.slice = P.StridedSlice()
|
||||
self.full_position_embeddings = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
(Normal(initializer_range),
|
||||
[max_position_embeddings,
|
||||
embedding_size]))
|
||||
|
||||
|
@ -217,94 +215,6 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
return output
|
||||
|
||||
|
||||
class QuantDense(nn.Cell):
|
||||
"""
|
||||
The fake quant fully connected layer.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input space.
|
||||
out_channels (int): The number of channels in the output space.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
activation (Function): activate function applied to the output of the fully connected layer, e.g. 'ReLU'.
|
||||
Default: None.
|
||||
quant_config (QuantConfig): default quant config.
|
||||
quant_dtype (QuantDtype): the bits of quantization, Default: 8bit.
|
||||
activation_init (float): init activate quant value. Default: 6.
|
||||
"""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
has_bias=True,
|
||||
activation=None,
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8,
|
||||
activation_init=2.5):
|
||||
super(QuantDense, self).__init__()
|
||||
self.in_channels = Validator.check_positive_int(in_channels)
|
||||
self.out_channels = Validator.check_positive_int(out_channels)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
weight_init.shape[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
self.weight = Parameter(initializer(
|
||||
weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
self.bias = Parameter(initializer(
|
||||
bias_init, [out_channels]), name="bias")
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
self.activation = nn.get_activation(activation) if isinstance(activation, str) else activation
|
||||
if activation is not None and not isinstance(self.activation, (nn.Cell, Primitive)):
|
||||
raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
|
||||
self.activation_flag = self.activation is not None
|
||||
self.fake_quant_weight = quant_config.weight(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
channel_axis=0,
|
||||
num_channels=out_channels,
|
||||
quant_dtype=quant_dtype)
|
||||
self.quant_input = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
|
||||
def construct(self, x):
|
||||
"""Use operators to construct the Dense layer."""
|
||||
output = self.fake_quant_weight(self.weight)
|
||||
x = self.quant_input(x)
|
||||
output = self.matmul(x, output)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
if self.activation_flag:
|
||||
return self.activation(output)
|
||||
return output
|
||||
|
||||
def extend_repr(self):
|
||||
"""A pretty print for Dense layer."""
|
||||
s = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format(
|
||||
self.in_channels, self.out_channels, self.weight, self.has_bias)
|
||||
if self.has_bias:
|
||||
s += ', bias={}'.format(self.bias)
|
||||
if self.activation_flag:
|
||||
s += ', activation={}'.format(self.activation)
|
||||
return s
|
||||
|
||||
|
||||
class BertOutput(nn.Cell):
|
||||
"""
|
||||
Apply a linear computation to hidden status and a residual computation to input.
|
||||
|
@ -312,7 +222,7 @@ class BertOutput(nn.Cell):
|
|||
Args:
|
||||
in_channels (int): Input channels.
|
||||
out_channels (int): Output channels.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
"""
|
||||
|
@ -325,11 +235,10 @@ class BertOutput(nn.Cell):
|
|||
compute_type=mstype.float32,
|
||||
activation_init=2.5):
|
||||
super(BertOutput, self).__init__()
|
||||
self.dense = QuantDense(in_channels, out_channels,
|
||||
weight_init=TruncatedNormal(initializer_range),
|
||||
activation_init=activation_init).to_float(compute_type)
|
||||
self.dense = nn.DenseQuant(in_channels, out_channels,
|
||||
weight_init=Normal(initializer_range)).to_float(compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.add = P.TensorAdd()
|
||||
self.add = P.Add()
|
||||
self.is_gpu = context.get_context('device_target') == "GPU"
|
||||
if self.is_gpu:
|
||||
self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
|
||||
|
@ -337,10 +246,18 @@ class BertOutput(nn.Cell):
|
|||
else:
|
||||
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
|
||||
self.cast = P.Cast()
|
||||
self.quant_bert_in = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self.quant_bert_out = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
|
||||
def construct(self, hidden_status, input_tensor):
|
||||
"""bert output"""
|
||||
hidden_status = self.quant_bert_in(hidden_status)
|
||||
output = self.dense(hidden_status)
|
||||
output = self.quant_bert_out(output)
|
||||
output = self.dropout(output)
|
||||
output = self.add(input_tensor, output)
|
||||
output = self.layernorm(output)
|
||||
|
@ -396,7 +313,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
|||
length (int): Length of one dim for the matrix to be generated.
|
||||
depth (int): Size of each attention head.
|
||||
max_relative_position (int): Maxmum value of relative position.
|
||||
initializer_range (float): Initialization value of TruncatedNormal.
|
||||
initializer_range (float): Initialization value of Normal.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
|
||||
|
@ -411,7 +328,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
|||
self.vocab_size = max_relative_position * 2 + 1
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embeddings_table = Parameter(
|
||||
initializer(TruncatedNormal(initializer_range),
|
||||
initializer(Normal(initializer_range),
|
||||
[self.vocab_size, self.depth]))
|
||||
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
|
||||
max_relative_position=max_relative_position)
|
||||
|
@ -420,7 +337,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
|||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.shape = P.Shape()
|
||||
self.gather = P.GatherV2() # index_select
|
||||
self.gather = P.Gather() # index_select
|
||||
self.matmul = P.BatchMatMul()
|
||||
|
||||
def construct(self):
|
||||
|
@ -484,12 +401,10 @@ class BertAttention(nn.Cell):
|
|||
key_act (str): Activation function for the key transform. Default: None.
|
||||
value_act (str): Activation function for the value transform. Default: None.
|
||||
has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.0.
|
||||
attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.0.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
|
||||
tensor. Default: False.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d tensor. Default: False.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
|
||||
"""
|
||||
|
@ -519,44 +434,18 @@ class BertAttention(nn.Cell):
|
|||
self.size_per_head = size_per_head
|
||||
self.has_attention_mask = has_attention_mask
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
|
||||
self.compute_type = compute_type
|
||||
self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=self.compute_type)
|
||||
self.reshape = P.Reshape()
|
||||
self.shape_from_2d = (-1, from_tensor_width)
|
||||
self.shape_to_2d = (-1, to_tensor_width)
|
||||
weight = TruncatedNormal(initializer_range)
|
||||
weight = Normal(initializer_range)
|
||||
units = num_attention_heads * size_per_head
|
||||
self.do_quant = True
|
||||
if self.do_quant:
|
||||
self.quant_from_tensor_2d = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self.quant_to_tensor_2d = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self.quant_query_layer = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self.quant_key_layer = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self.quant_attention_probs = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self.quant_value_layer = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self.query_layer = nn.Dense(from_tensor_width,
|
||||
units,
|
||||
activation=query_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
self.key_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=key_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
self.value_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=value_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
# do quant:
|
||||
self.activation_init = activation_init
|
||||
self.activation = {"query_act": query_act, "key_act": key_act, "value_act": value_act}
|
||||
self._quant_init(from_tensor_width, to_tensor_width, units, weight)
|
||||
|
||||
self.shape_from = (-1, from_seq_length, num_attention_heads, size_per_head)
|
||||
self.shape_to = (-1, to_seq_length, num_attention_heads, size_per_head)
|
||||
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
|
||||
|
@ -572,7 +461,7 @@ class BertAttention(nn.Cell):
|
|||
if self.has_attention_mask:
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.sub = P.Sub()
|
||||
self.add = P.TensorAdd()
|
||||
self.add = P.Add()
|
||||
self.cast = P.Cast()
|
||||
self.get_dtype = P.DType()
|
||||
if do_return_2d_tensor:
|
||||
|
@ -582,30 +471,51 @@ class BertAttention(nn.Cell):
|
|||
self.cast_compute_type = SaturateCast(dst_type=compute_type)
|
||||
if self.use_relative_positions:
|
||||
self._generate_relative_positions_embeddings = \
|
||||
RelaPosEmbeddingsGenerator(length=to_seq_length,
|
||||
depth=size_per_head,
|
||||
max_relative_position=16,
|
||||
RelaPosEmbeddingsGenerator(length=to_seq_length, depth=size_per_head, max_relative_position=16,
|
||||
initializer_range=initializer_range,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
def _quant_init(self, from_tensor_width, to_tensor_width, units, weight):
|
||||
"""Init quantization operations"""
|
||||
self.quant_from_tensor_2d = FakeQuantWithMinMax(min_init=-self.activation_init,
|
||||
max_init=self.activation_init,
|
||||
ema=True)
|
||||
self.quant_to_tensor_2d = FakeQuantWithMinMax(min_init=-self.activation_init,
|
||||
max_init=self.activation_init,
|
||||
ema=True)
|
||||
self.quant_query_out = FakeQuantWithMinMax(min_init=-self.activation_init,
|
||||
max_init=self.activation_init,
|
||||
ema=True)
|
||||
self.quant_key_out = FakeQuantWithMinMax(min_init=-self.activation_init,
|
||||
max_init=self.activation_init,
|
||||
ema=True)
|
||||
self.quant_value_out = FakeQuantWithMinMax(min_init=-self.activation_init,
|
||||
max_init=self.activation_init,
|
||||
ema=True)
|
||||
self.query_layer = nn.DenseQuant(from_tensor_width, units, activation=self.activation["query_act"],
|
||||
weight_init=weight).to_float(self.compute_type)
|
||||
self.key_layer = nn.DenseQuant(to_tensor_width, units, activation=self.activation["key_act"],
|
||||
weight_init=weight).to_float(self.compute_type)
|
||||
self.value_layer = nn.DenseQuant(to_tensor_width, units, activation=self.activation["value_act"],
|
||||
weight_init=weight).to_float(self.compute_type)
|
||||
|
||||
def construct(self, from_tensor, to_tensor, attention_mask):
|
||||
"""bert attention"""
|
||||
# reshape 2d/3d input tensors to 2d
|
||||
from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
|
||||
to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
|
||||
if self.do_quant:
|
||||
from_tensor_2d = self.quant_from_tensor_2d(from_tensor_2d)
|
||||
to_tensor_2d = self.quant_to_tensor_2d(to_tensor_2d)
|
||||
query_out = self.query_layer(from_tensor_2d)
|
||||
# do quant:
|
||||
to_tensor_2d = self.quant_to_tensor_2d(to_tensor_2d)
|
||||
query_out = self.query_layer(self.quant_from_tensor_2d(from_tensor_2d))
|
||||
query_out = self.quant_query_out(query_out)
|
||||
key_out = self.key_layer(to_tensor_2d)
|
||||
value_out = self.value_layer(to_tensor_2d)
|
||||
key_out = self.quant_key_out(key_out)
|
||||
value_out = self.quant_value_out(self.value_layer(to_tensor_2d))
|
||||
|
||||
query_layer = self.reshape(query_out, self.shape_from)
|
||||
query_layer = self.transpose(query_layer, self.trans_shape)
|
||||
key_layer = self.reshape(key_out, self.shape_to)
|
||||
key_layer = self.transpose(key_layer, self.trans_shape)
|
||||
if self.do_quant:
|
||||
query_layer = self.quant_query_layer(query_layer)
|
||||
key_layer = self.quant_key_layer(key_layer)
|
||||
attention_scores = self.matmul_trans_b(query_layer, key_layer)
|
||||
# use_relative_position, supplementary logic
|
||||
if self.use_relative_positions:
|
||||
|
@ -615,22 +525,14 @@ class BertAttention(nn.Cell):
|
|||
# query_layer_t is [F, B, N, H]
|
||||
query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
|
||||
# query_layer_r is [F, B * N, H]
|
||||
query_layer_r = self.reshape(query_layer_t,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.size_per_head))
|
||||
query_layer_r = self.reshape(query_layer_t, (self.from_seq_length, -1, self.size_per_head))
|
||||
# key_position_scores is [F, B * N, F|T]
|
||||
key_position_scores = self.matmul_trans_b(query_layer_r,
|
||||
relations_keys)
|
||||
key_position_scores = self.matmul_trans_b(query_layer_r, relations_keys)
|
||||
# key_position_scores_r is [F, B, N, F|T]
|
||||
key_position_scores_r = self.reshape(key_position_scores,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.num_attention_heads,
|
||||
self.from_seq_length))
|
||||
key_position_scores_r = self.reshape(key_position_scores, (self.from_seq_length, -1,
|
||||
self.num_attention_heads, self.from_seq_length))
|
||||
# key_position_scores_r_t is [B, N, F, F|T]
|
||||
key_position_scores_r_t = self.transpose(key_position_scores_r,
|
||||
self.trans_shape_position)
|
||||
key_position_scores_r_t = self.transpose(key_position_scores_r, self.trans_shape_position)
|
||||
attention_scores = attention_scores + key_position_scores_r_t
|
||||
attention_scores = self.multiply(self.scores_mul, attention_scores)
|
||||
if self.has_attention_mask:
|
||||
|
@ -643,9 +545,6 @@ class BertAttention(nn.Cell):
|
|||
attention_probs = self.dropout(attention_probs)
|
||||
value_layer = self.reshape(value_out, self.shape_to)
|
||||
value_layer = self.transpose(value_layer, self.trans_shape)
|
||||
if self.do_quant:
|
||||
attention_probs = self.quant_attention_probs(attention_probs)
|
||||
value_layer = self.quant_value_layer(value_layer)
|
||||
context_layer = self.matmul(attention_probs, value_layer)
|
||||
# use_relative_position, supplementary logic
|
||||
if self.use_relative_positions:
|
||||
|
@ -655,23 +554,12 @@ class BertAttention(nn.Cell):
|
|||
# attention_probs_t is [F, B, N, T]
|
||||
attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
|
||||
# attention_probs_r is [F, B * N, T]
|
||||
attention_probs_r = self.reshape(
|
||||
attention_probs_t,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.to_seq_length))
|
||||
# value_position_scores is [F, B * N, H]
|
||||
value_position_scores = self.matmul(attention_probs_r,
|
||||
relations_values)
|
||||
# value_position_scores_r is [F, B, N, H]
|
||||
value_position_scores_r = self.reshape(value_position_scores,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.num_attention_heads,
|
||||
self.size_per_head))
|
||||
# value_position_scores_r_t is [B, N, F, H]
|
||||
value_position_scores_r_t = self.transpose(value_position_scores_r,
|
||||
self.trans_shape_position)
|
||||
attention_probs_r = self.reshape(attention_probs_t, (self.from_seq_length, -1, self.to_seq_length))
|
||||
value_position_scores = self.matmul(attention_probs_r, relations_values)
|
||||
value_position_scores_r = self.reshape(value_position_scores, (self.from_seq_length, -1,
|
||||
self.num_attention_heads,
|
||||
self.size_per_head))
|
||||
value_position_scores_r_t = self.transpose(value_position_scores_r, self.trans_shape_position)
|
||||
context_layer = context_layer + value_position_scores_r_t
|
||||
context_layer = self.transpose(context_layer, self.trans_shape)
|
||||
context_layer = self.reshape(context_layer, self.shape_return)
|
||||
|
@ -689,7 +577,7 @@ class BertSelfAttention(nn.Cell):
|
|||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
|
||||
|
@ -756,7 +644,7 @@ class BertEncoderCell(nn.Cell):
|
|||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.02.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
hidden_act (str): Activation function. Default: "gelu".
|
||||
|
@ -788,22 +676,31 @@ class BertEncoderCell(nn.Cell):
|
|||
use_relative_positions=use_relative_positions,
|
||||
compute_type=compute_type,
|
||||
activation_init=activation_init)
|
||||
self.intermediate = QuantDense(in_channels=hidden_size, out_channels=intermediate_size,
|
||||
activation=hidden_act,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
self.intermediate = nn.DenseQuant(in_channels=hidden_size,
|
||||
out_channels=intermediate_size,
|
||||
activation=hidden_act,
|
||||
weight_init=Normal(initializer_range)).to_float(compute_type)
|
||||
self.output = BertOutput(in_channels=intermediate_size,
|
||||
out_channels=hidden_size,
|
||||
initializer_range=initializer_range,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
compute_type=compute_type,
|
||||
activation_init=activation_init)
|
||||
self.quant_encoder_in = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self.quant_encoder_out = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
|
||||
def construct(self, hidden_states, attention_mask):
|
||||
"""bert encoder cell"""
|
||||
# self-attention
|
||||
attention_output, attention_scores = self.attention(hidden_states, attention_mask)
|
||||
# feed construct
|
||||
attention_output = self.quant_encoder_in(attention_output)
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
intermediate_output = self.quant_encoder_out(intermediate_output)
|
||||
# add and normalize
|
||||
output = self.output(intermediate_output, attention_output)
|
||||
return output, attention_scores
|
||||
|
@ -822,7 +719,7 @@ class BertTransformer(nn.Cell):
|
|||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
initializer_range (float): Initialization value of Normal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
|
||||
|
@ -843,8 +740,7 @@ class BertTransformer(nn.Cell):
|
|||
use_relative_positions=False,
|
||||
hidden_act="gelu",
|
||||
compute_type=mstype.float32,
|
||||
return_all_encoders=False,
|
||||
activation_init=2.5):
|
||||
return_all_encoders=False):
|
||||
super(BertTransformer, self).__init__()
|
||||
self.return_all_encoders = return_all_encoders
|
||||
layers = []
|
||||
|
@ -920,7 +816,7 @@ class BertModel(nn.Cell):
|
|||
def __init__(self,
|
||||
config,
|
||||
is_training,
|
||||
use_one_hot_embeddings=False):
|
||||
use_one_hot_embeddings=False, activation_init=2.5):
|
||||
super(BertModel, self).__init__()
|
||||
config = copy.deepcopy(config)
|
||||
if not is_training:
|
||||
|
@ -969,9 +865,15 @@ class BertModel(nn.Cell):
|
|||
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
|
||||
self.slice = P.StridedSlice()
|
||||
self.squeeze_1 = P.Squeeze(axis=1)
|
||||
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
|
||||
activation="tanh",
|
||||
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
|
||||
self.dense = nn.DenseQuant(self.hidden_size, self.hidden_size,
|
||||
activation="tanh",
|
||||
weight_init=Normal(config.initializer_range)).to_float(config.compute_type)
|
||||
self.quant_model_in = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self.quant_model_out = FakeQuantWithMinMax(min_init=-activation_init,
|
||||
max_init=activation_init,
|
||||
ema=True)
|
||||
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
|
||||
|
||||
def construct(self, input_ids, token_type_ids, input_mask):
|
||||
|
@ -992,104 +894,9 @@ class BertModel(nn.Cell):
|
|||
(batch_size, 1, self.hidden_size),
|
||||
(1, 1, 1))
|
||||
first_token = self.squeeze_1(sequence_slice)
|
||||
first_token = self.quant_model_in(first_token)
|
||||
pooled_output = self.dense(first_token)
|
||||
pooled_output = self.cast(pooled_output, self.dtype)
|
||||
encoder_outputs = ()
|
||||
for output in encoder_layers:
|
||||
encoder_outputs += (self.cast(output, self.dtype),)
|
||||
attention_outputs = ()
|
||||
for output in layer_atts:
|
||||
attention_outputs += (self.cast(output, self.dtype),)
|
||||
return sequence_output, pooled_output, embedding_tables, encoder_outputs, attention_outputs
|
||||
|
||||
|
||||
class TinyBertModel(nn.Cell):
|
||||
"""
|
||||
Bidirectional Encoder Representations from Transformers.
|
||||
|
||||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
is_training (bool): True for training mode. False for eval mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
is_training,
|
||||
use_one_hot_embeddings=False):
|
||||
super(TinyBertModel, self).__init__()
|
||||
config = copy.deepcopy(config)
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
config.attention_probs_dropout_prob = 0.0
|
||||
self.seq_length = config.seq_length
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.embedding_size = config.hidden_size
|
||||
self.token_type_ids = None
|
||||
self.last_idx = self.num_hidden_layers - 1
|
||||
output_embedding_shape = [-1, self.seq_length,
|
||||
self.embedding_size]
|
||||
self.tinybert_embedding_lookup = EmbeddingLookup(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=self.embedding_size,
|
||||
embedding_shape=output_embedding_shape,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=config.initializer_range)
|
||||
self.tinybert_embedding_postprocessor = EmbeddingPostprocessor(
|
||||
use_relative_positions=config.use_relative_positions,
|
||||
embedding_size=self.embedding_size,
|
||||
embedding_shape=output_embedding_shape,
|
||||
use_token_type=True,
|
||||
token_type_vocab_size=config.type_vocab_size,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=0.02,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
dropout_prob=config.hidden_dropout_prob)
|
||||
self.tinybert_encoder = BertTransformer(
|
||||
hidden_size=self.hidden_size,
|
||||
seq_length=self.seq_length,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
intermediate_size=config.intermediate_size,
|
||||
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=config.initializer_range,
|
||||
hidden_dropout_prob=config.hidden_dropout_prob,
|
||||
use_relative_positions=config.use_relative_positions,
|
||||
hidden_act=config.hidden_act,
|
||||
compute_type=config.compute_type,
|
||||
return_all_encoders=True)
|
||||
self.cast = P.Cast()
|
||||
self.dtype = config.dtype
|
||||
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
|
||||
self.slice = P.StridedSlice()
|
||||
self.squeeze_1 = P.Squeeze(axis=1)
|
||||
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
|
||||
activation="tanh",
|
||||
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
|
||||
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
|
||||
|
||||
def construct(self, input_ids, token_type_ids, input_mask):
|
||||
"""tiny bert model"""
|
||||
# embedding
|
||||
word_embeddings, embedding_tables = self.tinybert_embedding_lookup(input_ids)
|
||||
embedding_output = self.tinybert_embedding_postprocessor(token_type_ids,
|
||||
word_embeddings)
|
||||
# attention mask [batch_size, seq_length, seq_length]
|
||||
attention_mask = self._create_attention_mask_from_input_mask(input_mask)
|
||||
# bert encoder
|
||||
encoder_output, encoder_layers, layer_atts = self.tinybert_encoder(self.cast_compute_type(embedding_output),
|
||||
attention_mask)
|
||||
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
|
||||
# pooler
|
||||
batch_size = P.Shape()(input_ids)[0]
|
||||
sequence_slice = self.slice(sequence_output,
|
||||
(0, 0, 0),
|
||||
(batch_size, 1, self.hidden_size),
|
||||
(1, 1, 1))
|
||||
first_token = self.squeeze_1(sequence_slice)
|
||||
pooled_output = self.dense(first_token)
|
||||
pooled_output = self.quant_model_out(pooled_output)
|
||||
pooled_output = self.cast(pooled_output, self.dtype)
|
||||
encoder_outputs = ()
|
||||
for output in encoder_layers:
|
||||
|
@ -1110,9 +917,10 @@ class BertModelCLS(nn.Cell):
|
|||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0,
|
||||
use_one_hot_embeddings=False, phase_type="student"):
|
||||
super(BertModelCLS, self).__init__()
|
||||
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.weight_init = Normal(config.initializer_range)
|
||||
self.log_softmax = P.LogSoftmax(axis=-1)
|
||||
self.dtype = config.dtype
|
||||
self.num_labels = num_labels
|
||||
|
@ -1121,8 +929,8 @@ class BertModelCLS(nn.Cell):
|
|||
self.dense = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
else:
|
||||
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
self.dense_1 = nn.DenseQuant(config.hidden_size, self.num_labels, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
self.dropout = nn.ReLU()
|
||||
|
||||
def construct(self, input_ids, token_type_id, input_mask):
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
"""tinybert utils"""
|
||||
"""q8bert utils"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
|
@ -24,7 +24,7 @@ from mindspore.train.callback import Callback
|
|||
from mindspore.train.serialization import save_checkpoint
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
|
||||
import mindspore.nn as nn
|
||||
from .config import glue_output_modes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -36,15 +36,16 @@ except (AttributeError, ImportError) as e:
|
|||
logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
|
||||
_has_sklearn = False
|
||||
|
||||
|
||||
def is_sklearn_available():
|
||||
return _has_sklearn
|
||||
|
||||
|
||||
if _has_sklearn:
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def acc_and_f1(preds, labels):
|
||||
acc = simple_accuracy(preds, labels)
|
||||
f1 = f1_score(y_true=labels, y_pred=preds)
|
||||
|
@ -54,7 +55,6 @@ if _has_sklearn:
|
|||
"acc_and_f1": (acc + f1) / 2,
|
||||
}
|
||||
|
||||
|
||||
def pearson_and_spearman(preds, labels):
|
||||
pearson_corr = pearsonr(preds, labels)[0]
|
||||
spearman_corr = spearmanr(preds, labels)[0]
|
||||
|
@ -64,7 +64,6 @@ if _has_sklearn:
|
|||
"corr": (pearson_corr + spearman_corr) / 2,
|
||||
}
|
||||
|
||||
|
||||
def glue_compute_metrics(task_name, preds, labels):
|
||||
"""different dataset evaluation."""
|
||||
assert len(preds) == len(labels)
|
||||
|
@ -92,19 +91,21 @@ if _has_sklearn:
|
|||
raise KeyError(task_name)
|
||||
return result
|
||||
|
||||
glue_output_modes = {
|
||||
"cola": "classification",
|
||||
"mnli": "classification",
|
||||
"mnli-mm": "classification",
|
||||
"mrpc": "classification",
|
||||
"sst-2": "classification",
|
||||
"sts-b": "regression",
|
||||
"qqp": "classification",
|
||||
"qnli": "classification",
|
||||
"rte": "classification",
|
||||
"wnli": "classification",
|
||||
|
||||
prior_index = {
|
||||
"cola": "mcc",
|
||||
"mnli": "acc",
|
||||
"mnli-mm": "acc",
|
||||
"mrpc": "acc",
|
||||
"sst-2": "acc",
|
||||
"sts-b": "pearson",
|
||||
"qqp": "acc",
|
||||
"qnli": "acc",
|
||||
"rte": "acc",
|
||||
"wnli": "acc",
|
||||
}
|
||||
|
||||
|
||||
class ModelSaveCkpt(Callback):
|
||||
"""
|
||||
Saves checkpoint.
|
||||
|
@ -129,13 +130,14 @@ class ModelSaveCkpt(Callback):
|
|||
saved_ckpt_num = cb_params.cur_step_num / self.save_ckpt_step
|
||||
if saved_ckpt_num > self.max_ckpt_num:
|
||||
oldest_ckpt_index = saved_ckpt_num - self.max_ckpt_num
|
||||
path = os.path.join(self.output_dir, "tiny_bert_{}_{}.ckpt".format(int(oldest_ckpt_index),
|
||||
self.save_ckpt_step))
|
||||
path = os.path.join(self.output_dir, "q8bert_{}_{}.ckpt".format(int(oldest_ckpt_index),
|
||||
cb_params.cur_step_num))
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
save_checkpoint(self.network, os.path.join(self.output_dir,
|
||||
"tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
|
||||
self.save_ckpt_step)))
|
||||
"q8bert_{}_{}.ckpt".format(int(saved_ckpt_num),
|
||||
cb_params.cur_step_num)))
|
||||
|
||||
|
||||
def make_directory(path: str):
|
||||
"""Make directory."""
|
||||
|
@ -161,6 +163,7 @@ def make_directory(path: str):
|
|||
raise TypeError("No write permission on the directory.")
|
||||
return real_path
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
@ -179,19 +182,23 @@ class LossCallBack(Callback):
|
|||
def step_end(self, run_context):
|
||||
"""step end and print loss"""
|
||||
cb_params = run_context.original_args()
|
||||
print("epoch: {}, step: {}, loss are {}".format(cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)))
|
||||
loss, _ = cb_params.net_outputs
|
||||
print("epoch: {}, step: {}, loss: {}".format(cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
loss))
|
||||
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""Evaluation callback"""
|
||||
def __init__(self, network, dataset, task_name, logging_step):
|
||||
def __init__(self, network, dataset, task_name, logging_step, save_ckpt_dir):
|
||||
super(EvalCallBack, self).__init__()
|
||||
self.network = network
|
||||
self.global_acc = 0.0
|
||||
self.dataset = dataset
|
||||
self.task_name = task_name
|
||||
self.logging_step = logging_step
|
||||
self.best_result = 0.0
|
||||
self.save_ckpt_dir = save_ckpt_dir
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""step end and do evaluation"""
|
||||
|
@ -200,6 +207,7 @@ class EvalCallBack(Callback):
|
|||
if self.task_name.lower == 'mnli':
|
||||
label_nums = 3
|
||||
if cb_params.cur_step_num % self.logging_step == 0:
|
||||
self.network.set_train(False)
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
|
@ -208,7 +216,6 @@ class EvalCallBack(Callback):
|
|||
for i in columns_list:
|
||||
input_data.append(data[i])
|
||||
input_ids, input_mask, token_type_id, label_ids = input_data
|
||||
self.network.set_train(False)
|
||||
_, _, logits, _ = self.network(input_ids, token_type_id, input_mask)
|
||||
if preds is None:
|
||||
preds = logits.asnumpy()
|
||||
|
@ -222,35 +229,17 @@ class EvalCallBack(Callback):
|
|||
elif glue_output_modes[self.task_name.lower()] == "regression":
|
||||
preds = np.reshape(preds, [-1])
|
||||
result = glue_compute_metrics(self.task_name.lower(), preds, out_label_ids)
|
||||
print("The current result is {}".format(result))
|
||||
prior_result = result[prior_index[self.task_name.lower()]]
|
||||
if prior_result > self.best_result:
|
||||
self.best_result = prior_result
|
||||
eval_model_ckpt_file = os.path.join(self.save_ckpt_dir, self.task_name.lower() + "_eval_model.ckpt")
|
||||
if os.path.exists(eval_model_ckpt_file):
|
||||
os.remove(eval_model_ckpt_file)
|
||||
save_checkpoint(self.network, eval_model_ckpt_file)
|
||||
print("The current result is {}, the best result is {}".format(result, self.best_result))
|
||||
self.network.set_train(True)
|
||||
|
||||
|
||||
def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
|
||||
"""
|
||||
Find the ckpt finetune generated and load it into eval network.
|
||||
"""
|
||||
files = os.listdir(load_finetune_checkpoint_dir)
|
||||
pre_len = len(prefix)
|
||||
max_num = 0
|
||||
for filename in files:
|
||||
name_ext = os.path.splitext(filename)
|
||||
if name_ext[-1] != ".ckpt":
|
||||
continue
|
||||
if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
|
||||
index = filename[pre_len:].find("-")
|
||||
if index == 0 and max_num == 0:
|
||||
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
|
||||
elif index not in (0, -1):
|
||||
name_split = name_ext[-2].split('_')
|
||||
if (steps_per_epoch != int(name_split[len(name_split)-1])) \
|
||||
or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])):
|
||||
continue
|
||||
num = filename[pre_len + 1:pre_len + index]
|
||||
if int(num) > max_num:
|
||||
max_num = int(num)
|
||||
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
|
||||
return load_finetune_checkpoint_path
|
||||
|
||||
class BertLearningRate(LearningRateSchedule):
|
||||
"""
|
||||
Warmup-decay learning rate for Bert network.
|
||||
|
@ -277,31 +266,3 @@ class BertLearningRate(LearningRateSchedule):
|
|||
else:
|
||||
lr = decay_lr
|
||||
return lr
|
||||
|
||||
class CrossEntropyCalculation(nn.Cell):
|
||||
"""
|
||||
Cross Entropy loss
|
||||
"""
|
||||
def __init__(self, is_training=True):
|
||||
super(CrossEntropyCalculation, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reshape = P.Reshape()
|
||||
self.last_idx = (-1,)
|
||||
self.neg = P.Neg()
|
||||
self.cast = P.Cast()
|
||||
self.is_training = is_training
|
||||
|
||||
def construct(self, logits, label_ids, num_labels):
|
||||
if self.is_training:
|
||||
label_ids = self.reshape(label_ids, self.last_idx)
|
||||
one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
|
||||
per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
|
||||
loss = self.reduce_mean(per_example_loss, self.last_idx)
|
||||
return_value = self.cast(loss, mstype.float32)
|
||||
else:
|
||||
return_value = logits * 1.0
|
||||
return return_value
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
"""task distill script"""
|
||||
"""q8bert train"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
@ -24,44 +24,40 @@ from mindspore.nn.optim import AdamWeightDecay
|
|||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.train.callback import TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from src.dataset import create_tinybert_dataset
|
||||
from src.dataset import create_dataset
|
||||
from src.q8bert import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
|
||||
from src.config import train_cfg, eval_cfg, model_cfg
|
||||
from src.config import train_cfg, eval_cfg, model_cfg, glue_output_modes, task_params
|
||||
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
|
||||
|
||||
_cur_dir = os.getcwd()
|
||||
save_ckpt_dir = os.path.join(_cur_dir, 'Q8Bert_save_ckpt')
|
||||
if not os.path.exists(save_ckpt_dir):
|
||||
os.makedirs(save_ckpt_dir)
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
parse args
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='Q8Bert task distill')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--do_eval", type=str, default="true", choices=["true", "false"],
|
||||
help="Do eval task, default is true.")
|
||||
parser = argparse.ArgumentParser(description="Q8Bert task train")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
|
||||
help="device where the code will be implemented. (Default: Ascend)")
|
||||
parser.add_argument("--do_eval", type=str, default="True", choices=["True", "False"],
|
||||
help="Do eval task, default is True.")
|
||||
parser.add_argument("--epoch_num", type=int, default=3, help="default is 3.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable shuffle for dataset, default is true.")
|
||||
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable data sink, default is true.")
|
||||
parser.add_argument("--do_shuffle", type=str, default="True", choices=["True", "False"],
|
||||
help="Enable shuffle for dataset, default is True.")
|
||||
parser.add_argument("--enable_data_sink", type=str, default="True", choices=["True", "False"],
|
||||
help="Enable data sink, default is True.")
|
||||
parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable save ckpt.")
|
||||
parser.add_argument("--max_ckpt_num", type=int, default=1, help="Enable data sink, default is true.")
|
||||
parser.add_argument("--max_ckpt_num", type=int, default=1, help="Enable data sink, default is True.")
|
||||
parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
|
||||
parser.add_argument("--load_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--train_data_dir", type=str, default="",
|
||||
help="Train data path, it is better to use absolute path")
|
||||
parser.add_argument("--eval_data_dir", type=str, default="",
|
||||
help="Eval data path, it is better to use absolute path")
|
||||
parser.add_argument("--do_quant", type=str, default="false", help="Do quant for model")
|
||||
parser.add_argument("--do_quant", type=str, default="True", help="Do quant for model")
|
||||
parser.add_argument("--logging_step", type=int, default=100, help="Do evalate each logging step")
|
||||
parser.add_argument("--task_name", type=str, default="COLA",
|
||||
choices=["SST-2", "QNLI", "MNLI", "COLA", "QQP", "STS-B", "RTE"],
|
||||
parser.add_argument("--task_name", type=str, default="STS-B", choices=["STS-B", "QNLI", "SST-2"],
|
||||
help="The name of the task to train.")
|
||||
parser.add_argument("--dataset_type", type=str, default="tfrecord",
|
||||
help="dataset type tfrecord/mindrecord, default is tfrecord")
|
||||
|
@ -70,26 +66,13 @@ def parse_args():
|
|||
|
||||
|
||||
args_opt = parse_args()
|
||||
_cur_dir = os.getcwd()
|
||||
save_ckpt_dir = os.path.join(_cur_dir, "Q8Bert_" + args_opt.task_name + "_model")
|
||||
if not os.path.exists(save_ckpt_dir):
|
||||
os.makedirs(save_ckpt_dir)
|
||||
|
||||
DEFAULT_NUM_LABELS = 2
|
||||
DEFAULT_SEQ_LENGTH = 128
|
||||
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
|
||||
"QNLI": {"num_labels": 2, "seq_length": 128},
|
||||
"MNLI": {"num_labels": 3, "seq_length": 128},
|
||||
"STS-B": {"num_labels": 1, "seq_length": 128}}
|
||||
|
||||
glue_output_modes = {
|
||||
"cola": "classification",
|
||||
"mnli": "classification",
|
||||
"mnli-mm": "classification",
|
||||
"mrpc": "classification",
|
||||
"sst-2": "classification",
|
||||
"sts-b": "regression",
|
||||
"qqp": "classification",
|
||||
"qnli": "classification",
|
||||
"rte": "classification",
|
||||
"wnli": "classification",
|
||||
}
|
||||
|
||||
|
||||
class Task:
|
||||
|
@ -110,9 +93,10 @@ class Task:
|
|||
if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
|
||||
return task_params[self.task_name]["seq_length"]
|
||||
return DEFAULT_SEQ_LENGTH
|
||||
task = Task(args_opt.task_name)
|
||||
|
||||
|
||||
task = Task(args_opt.task_name)
|
||||
|
||||
|
||||
def do_train():
|
||||
"""
|
||||
|
@ -120,12 +104,14 @@ def do_train():
|
|||
"""
|
||||
ckpt_file = args_opt.load_ckpt_path
|
||||
|
||||
if ckpt_file == '':
|
||||
if not ckpt_file:
|
||||
raise ValueError("Student ckpt file should not be None")
|
||||
cfg = train_cfg
|
||||
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
model_cfg.compute_type = mstype.float16
|
||||
train_cfg.loss_scale_value = 2 ** 10
|
||||
elif args_opt.device_target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
else:
|
||||
|
@ -138,16 +124,16 @@ def do_train():
|
|||
num_labels=task.num_labels, is_predistill=False)
|
||||
rank = 0
|
||||
device_num = 1
|
||||
train_dataset = create_tinybert_dataset(cfg.batch_size,
|
||||
device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.train_data_dir, None, seq_length=task.seq_length,
|
||||
task_type='classification',
|
||||
drop_remainder=True)
|
||||
train_dataset = create_dataset(cfg.batch_size,
|
||||
device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.train_data_dir,
|
||||
data_type=args_opt.dataset_type,
|
||||
seq_length=task.seq_length,
|
||||
drop_remainder=True)
|
||||
|
||||
dataset_size = train_dataset.get_dataset_size()
|
||||
print('td2 train dataset size: ', dataset_size)
|
||||
print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())
|
||||
if args_opt.enable_data_sink == 'true':
|
||||
print("train dataset size: ", dataset_size)
|
||||
if args_opt.enable_data_sink == "True":
|
||||
repeat_count = args_opt.epoch_num * train_dataset.get_dataset_size() // args_opt.data_sink_steps
|
||||
time_monitor_steps = args_opt.data_sink_steps
|
||||
else:
|
||||
|
@ -164,24 +150,24 @@ def do_train():
|
|||
params = netwithloss.trainable_params()
|
||||
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
|
||||
other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0},
|
||||
{'order_params': params}]
|
||||
group_params = [{"params": decay_params, "weight_decay": optimizer_cfg.AdamWeightDecay.weight_decay},
|
||||
{"params": other_params, "weight_decay": 0.0},
|
||||
{"order_params": params}]
|
||||
|
||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
|
||||
|
||||
eval_dataset = create_tinybert_dataset(eval_cfg.batch_size,
|
||||
device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.eval_data_dir, None,
|
||||
data_type=args_opt.dataset_type,
|
||||
seq_length=task.seq_length,
|
||||
task_type='classification',
|
||||
drop_remainder=False)
|
||||
print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
|
||||
eval_dataset = create_dataset(eval_cfg.batch_size,
|
||||
device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.eval_data_dir,
|
||||
data_type=args_opt.dataset_type,
|
||||
seq_length=task.seq_length,
|
||||
drop_remainder=False)
|
||||
print("eval dataset size: ", eval_dataset.get_dataset_size())
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
if args_opt.do_eval == "True":
|
||||
callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
|
||||
EvalCallBack(netwithloss.bert, eval_dataset, args_opt.task_name, args_opt.logging_step)]
|
||||
EvalCallBack(netwithloss.bert, eval_dataset, args_opt.task_name, args_opt.logging_step,
|
||||
save_ckpt_dir)]
|
||||
else:
|
||||
callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
|
||||
ModelSaveCkpt(netwithloss.bert,
|
||||
|
@ -192,16 +178,16 @@ def do_train():
|
|||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
||||
scale_factor=cfg.scale_factor,
|
||||
scale_window=cfg.scale_window)
|
||||
|
||||
netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
else:
|
||||
netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
|
||||
model = Model(netwithgrads)
|
||||
model.train(repeat_count, train_dataset, callbacks=callback,
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == "True"),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
set_seed(1)
|
||||
enable_loss_scale = True
|
||||
model_cfg.seq_length = task.seq_length
|
Loading…
Reference in New Issue