forked from mindspore-Ecosystem/mindspore
commit
82715410e9
|
@ -0,0 +1,417 @@
|
|||
|
||||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [概述](#概述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据准备](#数据准备)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本和样例代码](#脚本和样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [预训练](#预训练)
|
||||
- [微调与评估](#微调与评估)
|
||||
- [选项及参数](#选项及参数)
|
||||
- [选项](#选项)
|
||||
- [参数](#参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [用法](#用法)
|
||||
- [Ascend处理器上运行](#ascend处理器上运行)
|
||||
- [GPU上运行](#GPU上运行)
|
||||
- [评估过程](#评估过程)
|
||||
- [用法](#用法-1)
|
||||
- [Ascend处理器上运行后评估各个任务的模型](#Ascend处理器上运行后评估各个任务的模型)
|
||||
- [GPU上运行后评估各个任务的模型](#GPU上运行后评估各个任务的模型)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [预训练性能](#预训练性能)
|
||||
- [推理性能](#推理性能)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# 概述
|
||||
|
||||
对话系统 (Dialogue System) 常常需要根据应用场景的变化去解决多种多样的任务。任务的多样性(意图识别、槽填充、行为识别、状态追踪等等),以及领域训练数据的稀少,给Dialogue System的研究和应用带来了巨大的困难和挑战,要使得Dialogue System得到更好的发展,基于BERT的对话通用理解模型 (DGU: Dialogue General Understanding),通过实验表明,使用base-model (BERT)并结合常见的学习范式,可以实现一个通用的对话理解模型。
|
||||
|
||||
DGU模型内共包含4个任务,全部基于公开数据集在mindspore1.1.1上完成训练及评估,详细说明如下:
|
||||
|
||||
udc: 使用UDC (Ubuntu Corpus V1) 数据集完成对话匹配 (Dialogue Response Selection) 任务;
|
||||
atis_intent: 使用ATIS (Airline Travel Information System) 数据集完成对话意图识别 (Dialogue Intent Detection) 任务;
|
||||
mrda: 使用MRDAC (Meeting Recorder Dialogue Act Corpus) 数据集完成对话行为识别 (Dialogue Act Detection) 任务;
|
||||
swda: 使用SwDAC (Switchboard Dialogue Act Corpus) 数据集完成对话行为识别 (Dialogue Act Detection) 任务;
|
||||
|
||||
# 模型架构
|
||||
|
||||
BERT的主干结构为Transformer。对于BERT_base,Transformer包含12个编码器模块,每个模块包含一个自注意模块,每个自注意模块包含一个注意模块。
|
||||
|
||||
# 数据准备
|
||||
|
||||
- 下载数据集压缩包并解压后,DGU_datasets目录下共存在6个目录,分别对应每个任务的训练集train.txt、评估集dev.txt和测试集test.txt。
|
||||
wget https://paddlenlp.bj.bcebos.com/datasets/DGU_datasets.tar.gz
|
||||
tar -zxf DGU_datasets.tar.gz
|
||||
- 下载数据集进行微调和评估,如udc、atis_intent、mrda、swda等。将数据集文件从JSON格式转换为MindRecord格式。详见src/dataconvert.py文件。
|
||||
- BERT模型训练的词汇表bert-base-uncased-vocab.txt 下载地址:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt
|
||||
- bert-base-uncased预训练模型原始权重 下载地址:https://paddlenlp.bj.bcebos.com/models/transformers/bert-base-uncased.pdparams
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(GPU处理器)
|
||||
- 准备GPU处理器搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||
- 更多关于Mindspore的信息,请查看以下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
从官网下载安装MindSpore之后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- 在GPU上运行
|
||||
|
||||
```bash
|
||||
# 运行微调和评估示例
|
||||
- 如需运行微调任务,请先准备预训练生成的权重文件(ckpt)。
|
||||
- 在`finetune_eval_config.py`中设置BERT网络配置和优化器超参。
|
||||
- 运行下载数据脚本:
|
||||
|
||||
bash scripts/download_data.sh
|
||||
- 运行数据预处理脚本:
|
||||
|
||||
bash scripts/run_data_preprocess.sh
|
||||
- 运行下载及转换预训练模型脚本(转换需要paddle环境):
|
||||
|
||||
bash scripts/download_pretrain_model.sh
|
||||
|
||||
- dgu:在scripts/run_dgu.sh中设置任务相关的超参,可完成进行针对不同任务的微调。
|
||||
- 运行`bash scripts/run_dgu_gpu.sh`,对BERT-base模型进行微调。
|
||||
|
||||
bash scripts/run_dgu_gpu.sh
|
||||
|
||||
```
|
||||
|
||||
在Ascend设备上做分布式训练时,请提前创建JSON格式的HCCL配置文件。
|
||||
|
||||
在Ascend设备上做单机分布式训练时,请参考[here](https://gitee.com/mindspore/mindspore/tree/master/config/hccl_single_machine_multi_rank.json)创建HCCL配置文件。
|
||||
|
||||
在Ascend设备上做多机分布式训练时,训练命令需要在很短的时间间隔内在各台设备上执行。因此,每台设备上都需要准备HCCL配置文件。请参考[here](https://gitee.com/mindspore/mindspore/tree/master/config/hccl_multi_machine_multi_rank.json)创建多机的HCCL配置文件。
|
||||
|
||||
如需设置数据集格式和参数,请创建JSON格式的模式配置文件,详见[TFRecord](https://www.mindspore.cn/doc/programming_guide/zh-CN/master/dataset_loading.html#tfrecord)格式。
|
||||
|
||||
```text
|
||||
For pretraining, schema file contains ["input_ids", "input_mask", "segment_ids", "next_sentence_labels", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"].
|
||||
|
||||
For ner or classification task, schema file contains ["input_ids", "input_mask", "segment_ids", "label_ids"].
|
||||
|
||||
For squad task, training: schema file contains ["start_positions", "end_positions", "input_ids", "input_mask", "segment_ids"], evaluation: schema file contains ["input_ids", "input_mask", "segment_ids"].
|
||||
|
||||
`numRows` is the only option which could be set by user, other values must be set according to the dataset.
|
||||
|
||||
For example, the schema file of cn-wiki-128 dataset for pretraining shows as follows:
|
||||
{
|
||||
"datasetType": "TF",
|
||||
"numRows": 7680,
|
||||
"columns": {
|
||||
"input_ids": {
|
||||
"type": "int64",
|
||||
"rank": 1,
|
||||
"shape": [128]
|
||||
},
|
||||
"input_mask": {
|
||||
"type": "int64",
|
||||
"rank": 1,
|
||||
"shape": [128]
|
||||
},
|
||||
"segment_ids": {
|
||||
"type": "int64",
|
||||
"rank": 1,
|
||||
"shape": [128]
|
||||
},
|
||||
"next_sentence_labels": {
|
||||
"type": "int64",
|
||||
"rank": 1,
|
||||
"shape": [1]
|
||||
},
|
||||
"masked_lm_positions": {
|
||||
"type": "int64",
|
||||
"rank": 1,
|
||||
"shape": [20]
|
||||
},
|
||||
"masked_lm_ids": {
|
||||
"type": "int64",
|
||||
"rank": 1,
|
||||
"shape": [20]
|
||||
},
|
||||
"masked_lm_weights": {
|
||||
"type": "float32",
|
||||
"rank": 1,
|
||||
"shape": [20]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 脚本说明
|
||||
|
||||
## 脚本和样例代码
|
||||
|
||||
```shell
|
||||
.
|
||||
└─dgu
|
||||
├─README_CN.md
|
||||
├─scripts
|
||||
├─run_dgu.sh # Ascend上单机DGU任务shell脚本
|
||||
├─run_dgu_gpu.sh # GPU上单机DGU任务shell脚本
|
||||
├─download_data.sh # 下载数据集shell脚本
|
||||
├─download_pretrain_model.sh # 下载预训练模型权重shell脚本
|
||||
├─export.sh # export脚本
|
||||
├─eval.sh # Ascend上单机DGU任务评估shell脚本
|
||||
└─run_data_preprocess.sh # 数据集预处理shell脚本
|
||||
├─src
|
||||
├─__init__.py
|
||||
├─adam.py # 优化器
|
||||
├─args.py # 代码运行参数设置
|
||||
├─bert_for_finetune.py # 网络骨干编码
|
||||
├─bert_for_pre_training.py # 网络骨干编码
|
||||
├─bert_model.py # 网络骨干编码
|
||||
├─config.py # 预训练参数配置
|
||||
├─data_util.py # 数据预处理util函数
|
||||
├─dataset.py # 数据预处理
|
||||
├─dataconvert.py # 数据转换
|
||||
├─finetune_eval_config.py # 微调参数配置
|
||||
├─finetune_eval_model.py # 网络骨干编码
|
||||
├─metric.py # 评估过程的测评方法
|
||||
├─pretrainmodel_convert.py # 预训练模型权重转换
|
||||
├─tokenizer.py # tokenizer函数
|
||||
└─utils.py # util函数
|
||||
└─run_dgu.py # DGU模型的微调和评估网络
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
### 微调与评估
|
||||
|
||||
```shell
|
||||
用法:dataconvert.py [--task_name TASK_NAME]
|
||||
[--data_dir DATA_DIR]
|
||||
[--vocab_file_path VOCAB_FILE_PATH]
|
||||
[--output_dir OUTPUT_DIR]
|
||||
[--max_seq_len N]
|
||||
[--eval_max_seq_len N]
|
||||
选项:
|
||||
--task_name 训练任务的名称
|
||||
--data_dir 原始数据集路径
|
||||
--vocab_file_path BERT模型训练的词汇表
|
||||
--output_dir 保存生成mindRecord格式数据的路径
|
||||
--max_seq_len train数据集的max_seq_len
|
||||
--eval_max_seq_len dev或test数据集的max_seq_len
|
||||
|
||||
用法:run_dgu.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [----do_eval DO_EVAL]
|
||||
[--device_id N] [--epoch_num N]
|
||||
[--train_data_shuffle TRAIN_DATA_SHUFFLE]
|
||||
[--eval_data_shuffle EVAL_DATA_SHUFFLE]
|
||||
[--checkpoint_path CHECKPOINT_PATH]
|
||||
[--model_name_or_path MODEL_NAME_OR_PATH]
|
||||
[--train_data_file_path TRAIN_DATA_FILE_PATH]
|
||||
[--eval_data_file_path EVAL_DATA_FILE_PATH]
|
||||
[--eval_ckpt_path EVAL_CKPT_PATH]
|
||||
[--is_modelarts_work IS_MODELARTS_WORK]
|
||||
选项:
|
||||
--task_name 训练任务的名称
|
||||
--device_target 代码实现设备,可选项为Ascend或CPU。默认为Ascend
|
||||
--do_train 是否基于训练集开始训练,可选项为true或false
|
||||
--do_eval 是否基于开发集开始评估,可选项为true或false
|
||||
--epoch_num 训练轮次总数
|
||||
--train_data_shuffle 是否使能训练数据集轮换,默认为true
|
||||
--eval_data_shuffle 是否使能评估数据集轮换,默认为false
|
||||
--checkpoint_path 保存生成微调检查点的路径
|
||||
--model_name_or_path 初始检查点的文件路径(通常来自预训练BERT模型
|
||||
--train_data_file_path 用于保存训练数据的mindRecord文件,如train1.1.mindrecord
|
||||
--eval_data_file_path 用于保存预测数据的mindRecord文件,如dev1.1.mindrecord
|
||||
--eval_ckpt_path 如仅执行评估,提供用于评估的微调检查点的路径
|
||||
--is_modelarts_work 是否使用ModelArts线上训练环境,默认为false
|
||||
```
|
||||
|
||||
## 选项及参数
|
||||
|
||||
可以在`config.py`和`finetune_eval_config.py`文件中分别配置训练和评估参数。
|
||||
|
||||
### 选项
|
||||
|
||||
```text
|
||||
config for lossscale and etc.
|
||||
bert_network BERT模型版本,可选项为base或nezha,默认为base
|
||||
batch_size 输入数据集的批次大小,默认为16
|
||||
loss_scale_value 损失放大初始值,默认为2^32
|
||||
scale_factor 损失放大的更新因子,默认为2
|
||||
scale_window 损失放大的一次更新步数,默认为1000
|
||||
optimizer 网络中采用的优化器,可选项为AdamWerigtDecayDynamicLR、Lamb、或Momentum,默认为Lamb
|
||||
```
|
||||
|
||||
### 参数
|
||||
|
||||
```text
|
||||
数据集和网络参数(预训练/微调/评估):
|
||||
seq_length 输入序列的长度,默认为128
|
||||
vocab_size 各内嵌向量大小,需与所采用的数据集相同。默认为21136
|
||||
hidden_size BERT的encoder层数,默认为768
|
||||
num_hidden_layers 隐藏层数,默认为12
|
||||
num_attention_heads 注意头的数量,默认为12
|
||||
intermediate_size 中间层数,默认为3072
|
||||
hidden_act 所采用的激活函数,默认为gelu
|
||||
hidden_dropout_prob BERT输出的随机失活可能性,默认为0.1
|
||||
attention_probs_dropout_prob BERT注意的随机失活可能性,默认为0.1
|
||||
max_position_embeddings 序列最大长度,默认为512
|
||||
type_vocab_size 标记类型的词汇表大小,默认为16
|
||||
initializer_range TruncatedNormal的初始值,默认为0.02
|
||||
use_relative_positions 是否采用相对位置,可选项为true或false,默认为False
|
||||
dtype 输入的数据类型,可选项为mstype.float16或mstype.float32,默认为mstype.float32
|
||||
compute_type Bert Transformer的计算类型,可选项为mstype.float16或mstype.float32,默认为mstype.float16
|
||||
|
||||
Parameters for optimizer:
|
||||
AdamWeightDecay:
|
||||
decay_steps 学习率开始衰减的步数
|
||||
learning_rate 学习率
|
||||
end_learning_rate 结束学习率,取值需为正数
|
||||
power 幂
|
||||
warmup_steps 热身学习率步数
|
||||
weight_decay 权重衰减
|
||||
eps 增加分母,提高小数稳定性
|
||||
|
||||
Lamb:
|
||||
decay_steps 学习率开始衰减的步数
|
||||
learning_rate 学习率
|
||||
end_learning_rate 结束学习率
|
||||
power 幂
|
||||
warmup_steps 热身学习率步数
|
||||
weight_decay 权重衰减
|
||||
|
||||
Momentum:
|
||||
learning_rate 学习率
|
||||
momentum 平均移动动量
|
||||
```
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 用法
|
||||
|
||||
#### Ascend处理器上运行
|
||||
|
||||
```bash
|
||||
bash scripts/run_dgu.sh
|
||||
```
|
||||
|
||||
以上命令后台运行,您可以在task_name.log中查看训练日志。训练结束后,您可以在默认脚本路径下脚本文件夹中找到检查点文件,得到如下损失值:
|
||||
|
||||
```text
|
||||
# grep "epoch" task_name.log
|
||||
epoch: 0.0, current epoch percent: 0.000, step: 1, outputs are (Tensor(shape=[1], dtype=Float32, [ 1.0856101e+01]), Tensor(shape=[], dtype=Bool, False), Tensor(shape=[], dtype=Float32, 65536))
|
||||
epoch: 0.0, current epoch percent: 0.000, step: 2, outputs are (Tensor(shape=[1], dtype=Float32, [ 1.0821701e+01]), Tensor(shape=[], dtype=Bool, False), Tensor(shape=[], dtype=Float32, 65536))
|
||||
...
|
||||
```
|
||||
|
||||
> **注意**如果所运行的数据集较大,建议添加一个外部环境变量,确保HCCL不会超时。
|
||||
>
|
||||
> ```bash
|
||||
> export HCCL_CONNECT_TIMEOUT=600
|
||||
> ```
|
||||
>
|
||||
> 将HCCL的超时时间从默认的120秒延长到600秒。
|
||||
> **注意**若使用的BERT模型较大,保存检查点时可能会出现protobuf错误,可尝试使用下面的环境集。
|
||||
>
|
||||
> ```bash
|
||||
> export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
> ```
|
||||
|
||||
#### GPU上运行
|
||||
|
||||
```bash
|
||||
bash scripts/run_dgu_gpu.sh
|
||||
```
|
||||
|
||||
以上命令后台运行,您可以在task_name.log中查看训练日志。训练结束后,您可以在默认脚本路径下脚本文件夹中找到检查点文件,得到如下损失值:
|
||||
|
||||
```text
|
||||
# grep "epoch" task_name.log
|
||||
epoch: 0, current epoch percent: 1.000, step: 6094, outputs are (Tensor(shape=[], dtype=Float32, value= 0.714172), Tensor(shape=[], dtype=Bool, value= False))
|
||||
epoch time: 1702423.561 ms, per step time: 279.361 ms
|
||||
epoch: 1, current epoch percent: 1.000, step: 12188, outputs are (Tensor(shape=[], dtype=Float32, value= 0.788653), Tensor(shape=[], dtype=Bool, value= False))
|
||||
epoch time: 1684662.219 ms, per step time: 276.446 ms
|
||||
epoch: 2, current epoch percent: 1.000, step: 18282, outputs are (Tensor(shape=[], dtype=Float32, value= 0.618005), Tensor(shape=[], dtype=Bool, value= False))
|
||||
epoch time: 1711860.908 ms, per step time: 280.909 ms
|
||||
...
|
||||
```
|
||||
|
||||
> **注意**如果所运行的数据集较大,建议添加一个外部环境变量,确保HCCL不会超时。
|
||||
>
|
||||
> ```bash
|
||||
> export HCCL_CONNECT_TIMEOUT=600
|
||||
> ```
|
||||
>
|
||||
> 将HCCL的超时时间从默认的120秒延长到600秒。
|
||||
> **注意**若使用的BERT模型较大,保存检查点时可能会出现protobuf错误,可尝试使用下面的环境集。
|
||||
>
|
||||
> ```bash
|
||||
> export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
> ```
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 用法
|
||||
|
||||
#### Ascend处理器上运行后评估各个任务的模型
|
||||
|
||||
运行以下命令前,确保已设置加载与训练检查点路径。若将检查点路径设置为绝对全路径,例如,/username/pretrain/checkpoint_100_300.ckpt,则评估指定的检查点;若将检查点路径设置为文件夹路径,则评估文件夹中所有检查点。
|
||||
修改eval.sh中task_name为将要评估的任务名以及修改相应的测试数据路径,修改device_target为"Ascend"。
|
||||
|
||||
```bash
|
||||
bash scripts/eval.sh
|
||||
```
|
||||
|
||||
可得到如下结果:
|
||||
|
||||
```text
|
||||
eval model: /home/dgu/checkpoints/swda/swda_3-2_6094.ckpt
|
||||
loading...
|
||||
evaling...
|
||||
==============================================================
|
||||
(w/o first and last) elapsed time: 2.3705036640167236, per step time : 0.017053983194364918
|
||||
==============================================================
|
||||
Accuracy : 0.8092150215136715
|
||||
```
|
||||
|
||||
#### GPU上运行后评估各个任务的模型
|
||||
|
||||
运行以下命令前,确保已设置加载与训练检查点路径。请将检查点路径设置为绝对全路径,例如,/username/pretrain/checkpoint_100_300.ckpt,则评估指定的检查点;若将检查点路径设置为文件夹路径,则评估文件夹中所有检查点。
|
||||
修改eval.sh中task_name为将要评估的任务名以及修改相应的测试数据路径,修改device_target为"GPU"。
|
||||
|
||||
```bash
|
||||
bash scripts/eval.sh
|
||||
```
|
||||
|
||||
可得到如下结果:
|
||||
|
||||
```text
|
||||
eval model: /home/dgu/checkpoints/swda/swda-2_6094.ckpt
|
||||
loading...
|
||||
evaling...
|
||||
==============================================================
|
||||
(w/o first and last) elapsed time: 10.98917531967163, per step time : 0.0790588152494362
|
||||
==============================================================
|
||||
Accuracy : 0.8082890070921985
|
||||
```
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
run_dgu.sh中设置train_data_shuffle为true,eval_data_shuffle为false,默认对数据集进行轮换操作。
|
||||
|
||||
config.py中,默认将hidden_dropout_prob和note_pros_dropout_prob设置为0.1,丢弃部分网络节点。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into models"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, context, load_checkpoint, export
|
||||
|
||||
from src.finetune_eval_config import bert_net_cfg
|
||||
from src.finetune_eval_model import BertCLSModel
|
||||
parser = argparse.ArgumentParser(description="Bert export")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
|
||||
parser.add_argument("--number_labels", type=int, default=16, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Bert ckpt file.")
|
||||
parser.add_argument("--file_name", type=str, default="Bert", help="bert output air name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = BertCLSModel(bert_net_cfg, False, num_labels=args.number_labels)
|
||||
|
||||
load_checkpoint(args.ckpt_file, net=net)
|
||||
net.set_train(False)
|
||||
|
||||
input_ids = Tensor(np.zeros([args.batch_size, bert_net_cfg.seq_length]), mstype.int32)
|
||||
input_mask = Tensor(np.zeros([args.batch_size, bert_net_cfg.seq_length]), mstype.int32)
|
||||
token_type_id = Tensor(np.zeros([args.batch_size, bert_net_cfg.seq_length]), mstype.int32)
|
||||
label_ids = Tensor(np.zeros([args.batch_size, bert_net_cfg.seq_length]), mstype.int32)
|
||||
|
||||
input_data = [input_ids, input_mask, token_type_id]
|
||||
export(net, *input_data, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,226 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
Bert finetune and evaluation script.
|
||||
'''
|
||||
|
||||
import os
|
||||
import time
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.ops as P
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore.nn import Accuracy
|
||||
from mindspore.nn.optim import AdamWeightDecay
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.train.callback import (CheckpointConfig, ModelCheckpoint,
|
||||
TimeMonitor)
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
import src.dataset as data
|
||||
import src.metric as metric
|
||||
from src.args import parse_args, set_default_args
|
||||
from src.bert_for_finetune import BertCLS, BertFinetuneCell
|
||||
from src.finetune_eval_config import (bert_net_cfg, bert_net_udc_cfg,
|
||||
optimizer_cfg)
|
||||
from src.utils import (CustomWarmUpLR, GetAllCkptPath, LossCallBack,
|
||||
create_classification_dataset, make_directory)
|
||||
|
||||
|
||||
def do_train(dataset=None, network=None, load_checkpoint_path="base-BertCLS-111.ckpt",
|
||||
save_checkpoint_path="", epoch_num=1):
|
||||
""" do train """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
|
||||
print("load pretrain model: ", load_checkpoint_path)
|
||||
steps_per_epoch = args_opt.save_steps
|
||||
num_examples = dataset.get_dataset_size() * args_opt.train_batch_size
|
||||
max_train_steps = epoch_num * dataset.get_dataset_size()
|
||||
warmup_steps = int(max_train_steps * args_opt.warmup_proportion)
|
||||
print("Num train examples: %d" % num_examples)
|
||||
print("Max train steps: %d" % max_train_steps)
|
||||
print("Num warmup steps: %d" % warmup_steps)
|
||||
#warmup and optimizer
|
||||
lr_schedule = CustomWarmUpLR(learning_rate=args_opt.learning_rate, \
|
||||
warmup_steps=warmup_steps, max_train_steps=max_train_steps)
|
||||
params = network.trainable_params()
|
||||
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
|
||||
other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0}]
|
||||
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
|
||||
#ckpt config
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=10)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=args_opt.task_name,
|
||||
directory=None if save_checkpoint_path == "" else save_checkpoint_path,
|
||||
config=ckpt_config)
|
||||
# load checkpoint into network
|
||||
param_dict = load_checkpoint(load_checkpoint_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
model = Model(netwithgrads)
|
||||
callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
|
||||
model.train(epoch_num, dataset, callbacks=callbacks)
|
||||
|
||||
def eval_result_print(eval_metric, result):
|
||||
if args_opt.task_name.lower() in ['atis_intent', 'mrda', 'swda']:
|
||||
metric_name = "Accuracy"
|
||||
else:
|
||||
metric_name = eval_metric.name()
|
||||
print(metric_name, " :", result)
|
||||
if args_opt.task_name.lower() == "udc":
|
||||
print("R1@10: ", result[0])
|
||||
print("R2@10: ", result[1])
|
||||
print("R5@10: ", result[2])
|
||||
|
||||
def do_eval(dataset=None, network=None, num_class=5, eval_metric=None, load_checkpoint_path=""):
|
||||
""" do eval """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
|
||||
print("eval model: ", load_checkpoint_path)
|
||||
print("loading... ")
|
||||
net_for_pretraining = network(eval_net_cfg, False, num_class)
|
||||
net_for_pretraining.set_train(False)
|
||||
param_dict = load_checkpoint(load_checkpoint_path)
|
||||
load_param_into_net(net_for_pretraining, param_dict)
|
||||
model = Model(net_for_pretraining)
|
||||
|
||||
print("evaling... ")
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
||||
eval_metric.clear()
|
||||
evaluate_times = []
|
||||
for data_item in dataset.create_dict_iterator(num_epochs=1):
|
||||
input_data = []
|
||||
for i in columns_list:
|
||||
input_data.append(data_item[i])
|
||||
input_ids, input_mask, token_type_id, label_ids = input_data
|
||||
squeeze = P.Squeeze(-1)
|
||||
label_ids = squeeze(label_ids)
|
||||
time_begin = time.time()
|
||||
logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
|
||||
time_end = time.time()
|
||||
evaluate_times.append(time_end - time_begin)
|
||||
eval_metric.update(logits, label_ids)
|
||||
print("==============================================================")
|
||||
print("(w/o first and last) elapsed time: {}, per step time : {}".format(
|
||||
sum(evaluate_times[1:-1]), sum(evaluate_times[1:-1])/(len(evaluate_times) - 2)))
|
||||
print("==============================================================")
|
||||
result = eval_metric.eval()
|
||||
eval_result_print(eval_metric, result)
|
||||
return result
|
||||
|
||||
|
||||
def run_dgu(args_input):
|
||||
"""run_dgu main function """
|
||||
dataset_class, metric_class = TASK_CLASSES[args_input.task_name]
|
||||
epoch_num = args_input.epochs
|
||||
num_class = dataset_class.num_classes()
|
||||
|
||||
target = args_input.device_target
|
||||
if target == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_input.device_id)
|
||||
elif target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_input.device_id)
|
||||
if net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||
net_cfg.compute_type = mstype.float32
|
||||
else:
|
||||
raise Exception("Target error, GPU or Ascend is supported.")
|
||||
|
||||
if args_input.do_train.lower() == "true":
|
||||
netwithloss = BertCLS(net_cfg, True, num_labels=num_class, dropout_prob=0.1)
|
||||
train_ds = create_classification_dataset(batch_size=args_input.train_batch_size, repeat_count=1, \
|
||||
data_file_path=args_input.train_data_file_path, \
|
||||
do_shuffle=(args_input.train_data_shuffle.lower() == "true"))
|
||||
do_train(train_ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
|
||||
|
||||
if args_input.do_eval.lower() == "true":
|
||||
eval_ds = create_classification_dataset(batch_size=args_input.eval_batch_size, repeat_count=1, \
|
||||
data_file_path=args_input.eval_data_file_path, \
|
||||
do_shuffle=(args_input.eval_data_shuffle.lower() == "true"))
|
||||
if args_input.task_name in ['atis_intent', 'mrda', 'swda']:
|
||||
eval_metric = metric_class("classification")
|
||||
else:
|
||||
eval_metric = metric_class()
|
||||
#load model from path and eval
|
||||
if args_input.eval_ckpt_path:
|
||||
do_eval(eval_ds, BertCLS, num_class, eval_metric, args_input.eval_ckpt_path)
|
||||
#eval all saved models
|
||||
else:
|
||||
ckpt_list = GetAllCkptPath(save_finetune_checkpoint_path)
|
||||
print("saved models:", ckpt_list)
|
||||
for filepath in ckpt_list:
|
||||
eval_result = do_eval(eval_ds, BertCLS, num_class, eval_metric, filepath)
|
||||
eval_file_dict[filepath] = str(eval_result)
|
||||
print(eval_file_dict)
|
||||
if args_input.is_modelarts_work == 'true':
|
||||
for filename in eval_file_dict:
|
||||
ckpt_result = eval_file_dict[filename].replace('[', '').replace(']', '').replace(', ', '_', 2)
|
||||
save_file_name = args_input.train_url + ckpt_result + "_" + filename.split('/')[-1]
|
||||
mox.file.copy_parallel(filename, save_file_name)
|
||||
print("upload model " + filename + " to " + save_file_name)
|
||||
|
||||
def print_args_input(args_input):
|
||||
print('----------- Configuration Arguments -----------')
|
||||
for arg, value in sorted(vars(args_input).items()):
|
||||
print('%s: %s' % (arg, value))
|
||||
print('------------------------------------------------')
|
||||
|
||||
def set_bert_cfg():
|
||||
"""set bert cfg"""
|
||||
global net_cfg
|
||||
global eval_net_cfg
|
||||
if args_opt.task_name == 'udc':
|
||||
net_cfg = bert_net_udc_cfg
|
||||
eval_net_cfg = bert_net_udc_cfg
|
||||
print("use udc_bert_cfg")
|
||||
else:
|
||||
net_cfg = bert_net_cfg
|
||||
eval_net_cfg = bert_net_cfg
|
||||
return net_cfg, eval_net_cfg
|
||||
|
||||
if __name__ == '__main__':
|
||||
TASK_CLASSES = {
|
||||
'udc': (data.UDCv1, metric.RecallAtK),
|
||||
'atis_intent': (data.ATIS_DID, Accuracy),
|
||||
'mrda': (data.MRDA, Accuracy),
|
||||
'swda': (data.SwDA, Accuracy)
|
||||
}
|
||||
os.environ['GLOG_v'] = '3'
|
||||
eval_file_dict = {}
|
||||
args_opt = parse_args()
|
||||
set_default_args(args_opt)
|
||||
net_cfg, eval_net_cfg = set_bert_cfg()
|
||||
load_pretrain_checkpoint_path = args_opt.model_name_or_path
|
||||
save_finetune_checkpoint_path = args_opt.checkpoints_path + args_opt.task_name
|
||||
save_finetune_checkpoint_path = make_directory(save_finetune_checkpoint_path)
|
||||
if args_opt.is_modelarts_work == 'true':
|
||||
import moxing as mox
|
||||
local_load_pretrain_checkpoint_path = args_opt.local_model_name_or_path
|
||||
local_data_path = '/cache/data/' + args_opt.task_name
|
||||
mox.file.copy_parallel(args_opt.data_url + args_opt.task_name, local_data_path)
|
||||
mox.file.copy_parallel('obs:/' + load_pretrain_checkpoint_path, local_load_pretrain_checkpoint_path)
|
||||
load_pretrain_checkpoint_path = local_load_pretrain_checkpoint_path
|
||||
if not args_opt.train_data_file_path:
|
||||
args_opt.train_data_file_path = local_data_path + '/' + args_opt.task_name + '_train.mindrecord'
|
||||
if not args_opt.eval_data_file_path:
|
||||
args_opt.eval_data_file_path = local_data_path + '/' + args_opt.task_name + '_test.mindrecord'
|
||||
print_args_input(args_opt)
|
||||
run_dgu(args_opt)
|
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# download dataset file to ./
|
||||
DATA_URL=https://paddlenlp.bj.bcebos.com/datasets/DGU_datasets.tar.gz
|
||||
wget --no-check-certificate ${DATA_URL}
|
||||
# unzip dataset file to ./DGU_datasets
|
||||
tar -zxvf DGU_datasets.tar.gz
|
||||
|
||||
cd src
|
||||
# download vocab file to ./src/
|
||||
VOCAB_URL=https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt
|
||||
wget --no-check-certificate ${VOCAB_URL}
|
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
mkdir -p pretrainModel
|
||||
cd pretrainModel
|
||||
|
||||
# download pretrain model file to ./pretrainModel/
|
||||
MODEL_BERT_BASE="https://paddlenlp.bj.bcebos.com/models/transformers/bert-base-uncased.pdparams"
|
||||
wget --no-check-certificate ${MODEL_BERT_BASE}
|
||||
# convert pdparams to mindspore ckpt
|
||||
python ../src/pretrainmodel_convert.py
|
|
@ -0,0 +1,78 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export GLOG_v=3
|
||||
|
||||
python3 ./run_dgu.py \
|
||||
--task_name=udc \
|
||||
--do_train="false" \
|
||||
--do_eval="true" \
|
||||
--device_target="GPU" \
|
||||
--device_id=0 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/udc/udc_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_batch_size=100 \
|
||||
--eval_data_file_path=./data/udc/udc_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=2 \
|
||||
--is_modelarts_work="false" \
|
||||
--eval_ckpt_path=./checkpoints/udc/udc-2_31250.ckpt
|
||||
|
||||
python3 ./run_dgu.py \
|
||||
--task_name=atis_intent \
|
||||
--do_train="false" \
|
||||
--do_eval="true" \
|
||||
--device_target="GPU" \
|
||||
--device_id=0 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/atis_intent/atis_intent_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/atis_intent/atis_intent_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=20 \
|
||||
--is_modelarts_work="false" \
|
||||
--eval_ckpt_path=./checkpoints/atis_intent/atis_intent-17_155.ckpt
|
||||
|
||||
python3 ./run_dgu.py \
|
||||
--task_name=mrda \
|
||||
--do_train="false" \
|
||||
--do_eval="true" \
|
||||
--device_target="GPU" \
|
||||
--device_id=0 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/mrda/mrda_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/mrda/mrda_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=7 \
|
||||
--is_modelarts_work="false" \
|
||||
--eval_ckpt_path=./checkpoints/mrda/mrda-3_2364.ckpt
|
||||
|
||||
python3 ./run_dgu.py \
|
||||
--task_name=swda \
|
||||
--do_train="false" \
|
||||
--do_eval="true" \
|
||||
--device_target="GPU" \
|
||||
--device_id=0 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/swda/swda_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/swda/swda_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=3 \
|
||||
--is_modelarts_work="false" \
|
||||
--eval_ckpt_path=./checkpoints/swda/swda-3_6094.ckpt
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
python export.py --device_id=0 \
|
||||
--batch_size=32 \
|
||||
--number_labels=26 \
|
||||
--ckpt_file=/home/ma-user/work/ckpt/atis_intent/0.9791666666666666_atis_intent-11_155.ckpt \
|
||||
--file_name=atis_intent.mindir \
|
||||
--file_format=MINDIR \
|
||||
--device_target=Ascend
|
|
@ -0,0 +1,50 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
CUR_DIR=`pwd`
|
||||
|
||||
#udc
|
||||
python3 ${CUR_DIR}/src/dataconvert.py \
|
||||
--data_dir=${CUR_DIR}/DGU_datasets/ \
|
||||
--output_dir=${CUR_DIR}/data/ \
|
||||
--vocab_file_dir=${CUR_DIR}/src/bert-base-uncased-vocab.txt \
|
||||
--task_name=udc \
|
||||
--max_seq_len=224 \
|
||||
--eval_max_seq_len=224
|
||||
|
||||
#atis_intent
|
||||
python3 ${CUR_DIR}/src/dataconvert.py \
|
||||
--data_dir=${CUR_DIR}/DGU_datasets/ \
|
||||
--output_dir=${CUR_DIR}/data/ \
|
||||
--vocab_file_dir=${CUR_DIR}/src/bert-base-uncased-vocab.txt \
|
||||
--task_name=atis_intent \
|
||||
--max_seq_len=128
|
||||
|
||||
#mrda
|
||||
python3 ${CUR_DIR}/src/dataconvert.py \
|
||||
--data_dir=${CUR_DIR}/DGU_datasets/ \
|
||||
--output_dir=${CUR_DIR}/data/ \
|
||||
--vocab_file_dir=${CUR_DIR}/src/bert-base-uncased-vocab.txt \
|
||||
--task_name=mrda \
|
||||
--max_seq_len=128
|
||||
|
||||
#swda
|
||||
python3 ${CUR_DIR}/src/dataconvert.py \
|
||||
--data_dir=${CUR_DIR}/DGU_datasets/ \
|
||||
--output_dir=${CUR_DIR}/data/ \
|
||||
--vocab_file_dir=${CUR_DIR}/src/bert-base-uncased-vocab.txt \
|
||||
--task_name=swda \
|
||||
--max_seq_len=128
|
|
@ -0,0 +1,73 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export GLOG_v=3
|
||||
|
||||
nohup python3 ./run_dgu.py \
|
||||
--task_name=udc \
|
||||
--do_train="true" \
|
||||
--do_eval="true" \
|
||||
--device_target="Ascend" \
|
||||
--device_id=0 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/udc/udc_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/udc/udc_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=2 \
|
||||
--is_modelarts_work="false" >udc_output.log 2>&1 &
|
||||
|
||||
nohup python3 ./run_dgu.py \
|
||||
--task_name=atis_intent \
|
||||
--do_train="true" \
|
||||
--do_eval="true" \
|
||||
--device_target="Ascend" \
|
||||
--device_id=1 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/atis_intent/atis_intent_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/atis_intent/atis_intent_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=20 \
|
||||
--is_modelarts_work="false" >atisintent_output.log 2>&1 &
|
||||
|
||||
nohup python3 ./run_dgu.py \
|
||||
--task_name=mrda \
|
||||
--do_train="true" \
|
||||
--do_eval="true" \
|
||||
--device_target="Ascend" \
|
||||
--device_id=2 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/mrda/mrda_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/mrda/mrda_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=7 \
|
||||
--is_modelarts_work="false" >mrda_output.log 2>&1 &
|
||||
|
||||
nohup python3 ./run_dgu.py \
|
||||
--task_name=swda \
|
||||
--do_train="true" \
|
||||
--do_eval="true" \
|
||||
--device_target="Ascend" \
|
||||
--device_id=3 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/swda/swda_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/swda/swda_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=3 \
|
||||
--is_modelarts_work="false" >swda_output.log 2>&1 &
|
|
@ -0,0 +1,73 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export GLOG_v=3
|
||||
|
||||
nohup python3 ./run_dgu.py \
|
||||
--task_name=udc \
|
||||
--do_train="true" \
|
||||
--do_eval="true" \
|
||||
--device_target="GPU" \
|
||||
--device_id=0 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/udc/udc_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/udc/udc_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=2 \
|
||||
--is_modelarts_work="false" >udc_output.log 2>&1 &
|
||||
|
||||
nohup python3 ./run_dgu.py \
|
||||
--task_name=atis_intent \
|
||||
--do_train="true" \
|
||||
--do_eval="true" \
|
||||
--device_target="GPU" \
|
||||
--device_id=1 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/atis_intent/atis_intent_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/atis_intent/atis_intent_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=20 \
|
||||
--is_modelarts_work="false" >atisintent_output.log 2>&1 &
|
||||
|
||||
nohup python3 ./run_dgu.py \
|
||||
--task_name=mrda \
|
||||
--do_train="true" \
|
||||
--do_eval="true" \
|
||||
--device_target="GPU" \
|
||||
--device_id=2 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/mrda/mrda_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/mrda/mrda_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=7 \
|
||||
--is_modelarts_work="false" >mrda_output.log 2>&1 &
|
||||
|
||||
nohup python3 ./run_dgu.py \
|
||||
--task_name=swda \
|
||||
--do_train="true" \
|
||||
--do_eval="true" \
|
||||
--device_target="GPU" \
|
||||
--device_id=3 \
|
||||
--model_name_or_path=./pretrainModel/base-BertCLS-111.ckpt \
|
||||
--train_data_file_path=./data/swda/swda_train.mindrecord \
|
||||
--train_batch_size=32 \
|
||||
--eval_data_file_path=./data/swda/swda_test.mindrecord \
|
||||
--checkpoints_path=./checkpoints/ \
|
||||
--epochs=3 \
|
||||
--is_modelarts_work="false" >swda_output.log 2>&1 &
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bert Init."""
|
||||
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
|
||||
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
|
||||
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
||||
BertTrainOneStepWithLossScaleCellForAdam
|
||||
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
|
||||
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
|
||||
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
|
||||
SaturateCast, CreateAttentionMaskFromInputMask
|
||||
from .adam import AdamWeightDecayForBert
|
||||
__all__ = [
|
||||
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
|
||||
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
|
||||
"BertTrainOneStepWithLossScaleCell",
|
||||
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
|
||||
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
|
||||
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",
|
||||
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask",
|
||||
"BertTrainOneStepWithLossScaleCellForAdam"
|
||||
]
|
|
@ -0,0 +1,307 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
|
||||
_adam_opt = C.MultitypeFuncGraph("adam_opt")
|
||||
_scaler_one = Tensor(1, mstype.int32)
|
||||
_scaler_ten = Tensor(10, mstype.float32)
|
||||
|
||||
|
||||
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Bool", "Bool")
|
||||
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
|
||||
"""
|
||||
Update parameters.
|
||||
|
||||
Args:
|
||||
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
|
||||
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
|
||||
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||
lr (Tensor): Learning rate.
|
||||
overflow (Tensor): Whether overflow occurs.
|
||||
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
|
||||
param (Tensor): Parameters.
|
||||
m (Tensor): m value of parameters.
|
||||
v (Tensor): v value of parameters.
|
||||
gradient (Tensor): Gradient of parameters.
|
||||
decay_flag (bool): Applies weight decay or not.
|
||||
optim_filter (bool): Applies parameter update or not.
|
||||
|
||||
Returns:
|
||||
Tensor, the new value of v after updating.
|
||||
"""
|
||||
if optim_filter:
|
||||
op_mul = P.Mul()
|
||||
op_square = P.Square()
|
||||
op_sqrt = P.Sqrt()
|
||||
op_cast = P.Cast()
|
||||
op_reshape = P.Reshape()
|
||||
op_shape = P.Shape()
|
||||
op_select = P.Select()
|
||||
|
||||
param_fp32 = op_cast(param, mstype.float32)
|
||||
m_fp32 = op_cast(m, mstype.float32)
|
||||
v_fp32 = op_cast(v, mstype.float32)
|
||||
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||
|
||||
cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
|
||||
next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\
|
||||
op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32))
|
||||
|
||||
next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\
|
||||
op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)))
|
||||
|
||||
update = next_m / (eps + op_sqrt(next_v))
|
||||
if decay_flag:
|
||||
update = op_mul(weight_decay, param_fp32) + update
|
||||
|
||||
update_with_lr = op_mul(lr, update)
|
||||
zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
|
||||
next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))
|
||||
|
||||
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
|
||||
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
|
||||
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
|
||||
|
||||
return op_cast(next_param, F.dtype(param))
|
||||
return gradient
|
||||
|
||||
|
||||
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
|
||||
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
|
||||
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
indices = gradient.indices
|
||||
values = gradient.values
|
||||
if ps_parameter and not cache_enable:
|
||||
op_shape = P.Shape()
|
||||
shapes = (op_shape(param), op_shape(m), op_shape(v),
|
||||
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
|
||||
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
|
||||
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, values, indices), shapes), param))
|
||||
return success
|
||||
|
||||
if not target:
|
||||
success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, values, indices))
|
||||
else:
|
||||
op_mul = P.Mul()
|
||||
op_square = P.Square()
|
||||
op_sqrt = P.Sqrt()
|
||||
scatter_add = P.ScatterAdd(use_locking)
|
||||
|
||||
assign_m = F.assign(m, op_mul(beta1, m))
|
||||
assign_v = F.assign(v, op_mul(beta2, v))
|
||||
|
||||
grad_indices = gradient.indices
|
||||
grad_value = gradient.values
|
||||
|
||||
next_m = scatter_add(m,
|
||||
grad_indices,
|
||||
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
|
||||
|
||||
next_v = scatter_add(v,
|
||||
grad_indices,
|
||||
op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))
|
||||
|
||||
if use_nesterov:
|
||||
m_temp = next_m * _scaler_ten
|
||||
assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
|
||||
div_value = scatter_add(m,
|
||||
op_mul(grad_indices, _scaler_one),
|
||||
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
|
||||
param_update = div_value / (op_sqrt(next_v) + eps)
|
||||
|
||||
m_recover = F.assign(m, m_temp / _scaler_ten)
|
||||
|
||||
F.control_depend(m_temp, assign_m_nesterov)
|
||||
F.control_depend(assign_m_nesterov, div_value)
|
||||
F.control_depend(param_update, m_recover)
|
||||
else:
|
||||
param_update = next_m / (op_sqrt(next_v) + eps)
|
||||
|
||||
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
|
||||
|
||||
next_param = param - lr_t * param_update
|
||||
|
||||
F.control_depend(assign_m, next_m)
|
||||
F.control_depend(assign_v, next_v)
|
||||
|
||||
success = F.depend(success, F.assign(param, next_param))
|
||||
success = F.depend(success, F.assign(m, next_m))
|
||||
success = F.depend(success, F.assign(v, next_v))
|
||||
|
||||
return success
|
||||
|
||||
|
||||
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
|
||||
beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
|
||||
moment1, moment2, ps_parameter, cache_enable):
|
||||
"""Apply adam optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
if ps_parameter and not cache_enable:
|
||||
op_shape = P.Shape()
|
||||
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
|
||||
(op_shape(param), op_shape(moment1), op_shape(moment2))), param))
|
||||
else:
|
||||
success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, gradient))
|
||||
return success
|
||||
|
||||
|
||||
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor")
|
||||
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
|
||||
"""Apply AdamOffload optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
|
||||
success = F.depend(success, F.assign_add(param, delat_param))
|
||||
return success
|
||||
|
||||
|
||||
def _check_param_value(beta1, beta2, eps, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
|
||||
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
|
||||
validator.check_positive_float(eps, "eps", prim_name)
|
||||
|
||||
class AdamWeightDecayForBert(Optimizer):
|
||||
"""
|
||||
Implements the Adam algorithm to fix the weight decay.
|
||||
|
||||
Note:
|
||||
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
|
||||
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
|
||||
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
|
||||
|
||||
To improve parameter groups performance, the customized order of parameters can be supported.
|
||||
|
||||
Args:
|
||||
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
|
||||
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
|
||||
"lr", "weight_decay" and "order_params" are the keys can be parsed.
|
||||
|
||||
- params: Required. The value must be a list of `Parameter`.
|
||||
|
||||
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
|
||||
If not, the `learning_rate` in the API will be used.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the API will be used.
|
||||
|
||||
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
|
||||
the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
|
||||
which in the 'order_params' must be in one of group parameters.
|
||||
|
||||
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
|
||||
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
|
||||
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
|
||||
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
|
||||
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
|
||||
dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
|
||||
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
|
||||
Default: 1e-3.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
|
||||
Should be in range (0.0, 1.0).
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
|
||||
Should be in range (0.0, 1.0).
|
||||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
- **overflow** (tuple[Tensor]) - The overflow flag in dynamiclossscale.
|
||||
|
||||
Outputs:
|
||||
tuple[bool], all elements are True.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> #1) All parameters use the same learning rate and weight decay
|
||||
>>> optim = nn.AdamWeightDecay(params=net.trainable_params())
|
||||
>>>
|
||||
>>> #2) Use parameter groups and set different values
|
||||
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
|
||||
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
|
||||
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
|
||||
... {'params': no_conv_params, 'lr': 0.01},
|
||||
... {'order_params': net.trainable_params()}]
|
||||
>>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
|
||||
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
|
||||
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
|
||||
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
|
||||
>>>
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
||||
"""
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
|
||||
super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
|
||||
_check_param_value(beta1, beta2, eps, self.cls_name)
|
||||
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.op_select = P.Select()
|
||||
self.op_cast = P.Cast()
|
||||
self.op_reshape = P.Reshape()
|
||||
self.op_shape = P.Shape()
|
||||
|
||||
def construct(self, gradients, overflow):
|
||||
"""AdamWeightDecayForBert"""
|
||||
lr = self.get_lr()
|
||||
cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\
|
||||
self.op_reshape(overflow, (())), mstype.bool_)
|
||||
beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
|
||||
beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
|
||||
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
|
||||
self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
|
||||
self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
if self.use_parallel:
|
||||
self.broadcast_params(optim_result)
|
||||
return optim_result
|
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Args used in Bert finetune and evaluation.
|
||||
"""
|
||||
import argparse
|
||||
|
||||
def parse_args():
|
||||
"""Parse args."""
|
||||
parser = argparse.ArgumentParser(__doc__)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default="udc",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train.")
|
||||
parser.add_argument(
|
||||
"--device_target",
|
||||
default="GPU",
|
||||
type=str,
|
||||
help="The device to train.")
|
||||
parser.add_argument(
|
||||
"--device_id",
|
||||
default=0,
|
||||
type=int,
|
||||
help="The device id to use.")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default='bert-base-uncased.ckpt',
|
||||
type=str,
|
||||
help="Path to pre-trained bert model or shortcut name.")
|
||||
parser.add_argument(
|
||||
"--local_model_name_or_path",
|
||||
default='/cache/pretrainModel/bert-BertCLS-111.ckpt',
|
||||
type=str,
|
||||
help="local Path to pre-trained bert model or shortcut name, for online work.")
|
||||
parser.add_argument(
|
||||
"--checkpoints_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The output directory where the checkpoints will be saved.")
|
||||
parser.add_argument(
|
||||
"--eval_ckpt_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The path of checkpoint to be loaded.")
|
||||
parser.add_argument(
|
||||
"--max_seq_len",
|
||||
default=None,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization for trainng.\
|
||||
Sequences longer than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument(
|
||||
"--eval_max_seq_len",
|
||||
default=None,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization for evaling.\
|
||||
Sequences longer than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
default=None,
|
||||
type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--warmup_proportion",
|
||||
default=0.1,
|
||||
type=float,
|
||||
help="The proportion of warmup.")
|
||||
parser.add_argument(
|
||||
"--do_train", default="true", type=str, help="Whether training.")
|
||||
parser.add_argument(
|
||||
"--do_eval", default="true", type=str, help="Whether evaluation.")
|
||||
|
||||
parser.add_argument(
|
||||
"--train_data_shuffle", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable train data shuffle, default is true")
|
||||
parser.add_argument(
|
||||
"--train_data_file_path", type=str, default="",
|
||||
help="Data path, it is better to use absolute path")
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=32, help="Train batch size, default is 32")
|
||||
parser.add_argument(
|
||||
"--eval_batch_size", type=int, default=None,
|
||||
help="Eval batch size, default is None. if the eval_batch_size parameter is not passed in,\
|
||||
It will be assigned the same value as train_batch_size")
|
||||
parser.add_argument(
|
||||
"--eval_data_file_path", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument(
|
||||
"--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
|
||||
help="Enable eval data shuffle, default is false")
|
||||
|
||||
parser.add_argument(
|
||||
"--is_modelarts_work", type=str, default="false", help="Whether modelarts online work.")
|
||||
parser.add_argument(
|
||||
"--train_url", type=str, default="",
|
||||
help="save_model path, it is better to use absolute path, for modelarts online work.")
|
||||
parser.add_argument(
|
||||
"--data_url", type=str, default="", help="data path, for modelarts online work")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def set_default_args(args):
|
||||
"""set default args."""
|
||||
args.task_name = args.task_name.lower()
|
||||
if args.task_name == 'udc':
|
||||
if not args.save_steps:
|
||||
args.save_steps = 1000
|
||||
if not args.epochs:
|
||||
args.epochs = 2
|
||||
if not args.max_seq_len:
|
||||
args.max_seq_len = 224
|
||||
if not args.eval_batch_size:
|
||||
args.eval_batch_size = 100
|
||||
elif args.task_name == 'atis_intent':
|
||||
if not args.save_steps:
|
||||
args.save_steps = 100
|
||||
if not args.epochs:
|
||||
args.epochs = 20
|
||||
elif args.task_name == 'mrda':
|
||||
if not args.save_steps:
|
||||
args.save_steps = 500
|
||||
if not args.epochs:
|
||||
args.epochs = 7
|
||||
elif args.task_name == 'swda':
|
||||
if not args.save_steps:
|
||||
args.save_steps = 500
|
||||
if not args.epochs:
|
||||
args.epochs = 3
|
||||
else:
|
||||
raise ValueError('Not support task: %s.' % args.task_name)
|
||||
|
||||
if not args.checkpoints_path:
|
||||
args.checkpoints_path = './checkpoints/' + args.task_name
|
||||
if not args.learning_rate:
|
||||
args.learning_rate = 2e-5
|
||||
if not args.max_seq_len:
|
||||
args.max_seq_len = 128
|
||||
if not args.eval_max_seq_len:
|
||||
args.eval_max_seq_len = args.max_seq_len
|
||||
if not args.eval_batch_size:
|
||||
args.eval_batch_size = args.train_batch_size
|
|
@ -0,0 +1,339 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
Bert for finetune script.
|
||||
'''
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore import context
|
||||
from .bert_for_pre_training import clip_grad
|
||||
from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel
|
||||
from .utils import CrossEntropyCalculation
|
||||
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
class BertFinetuneCell(nn.Cell):
|
||||
"""
|
||||
Especially defined for finetuning where only four inputs tensor are needed.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Different from the builtin loss_scale wrapper cell, we apply grad_clip before the optimization.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
|
||||
super(BertFinetuneCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("gradients_mean")
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.cast = P.Cast()
|
||||
self.gpu_target = False
|
||||
if context.get_context("device_target") == "GPU":
|
||||
self.gpu_target = True
|
||||
self.float_status = P.FloatStatus()
|
||||
self.addn = P.AddN()
|
||||
self.reshape = P.Reshape()
|
||||
else:
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_status = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids,
|
||||
sens=None):
|
||||
"""Bert Finetune"""
|
||||
|
||||
weights = self.weights
|
||||
init = False
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
|
||||
if not self.gpu_target:
|
||||
init = self.alloc_status()
|
||||
init = F.depend(init, loss)
|
||||
clear_status = self.clear_status(init)
|
||||
scaling_sens = F.depend(scaling_sens, clear_status)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
if not self.gpu_target:
|
||||
init = F.depend(init, grads)
|
||||
get_status = self.get_status(init)
|
||||
init = F.depend(init, get_status)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
flag_sum = self.reshape(flag_sum, (()))
|
||||
if self.is_distributed:
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
class BertSquadCell(nn.Cell):
|
||||
"""
|
||||
specifically defined for finetuning where only four inputs tensor are needed.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertSquadCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("gradients_mean")
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_status = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
start_position,
|
||||
end_position,
|
||||
unique_id,
|
||||
is_impossible,
|
||||
sens=None):
|
||||
"""BertSquad"""
|
||||
weights = self.weights
|
||||
init = self.alloc_status()
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
start_position,
|
||||
end_position,
|
||||
unique_id,
|
||||
is_impossible)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
init = F.depend(init, loss)
|
||||
clear_status = self.clear_status(init)
|
||||
scaling_sens = F.depend(scaling_sens, clear_status)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
start_position,
|
||||
end_position,
|
||||
unique_id,
|
||||
is_impossible,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
init = F.depend(init, grads)
|
||||
get_status = self.get_status(init)
|
||||
init = F.depend(init, get_status)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
if self.is_distributed:
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
class BertCLS(nn.Cell):
|
||||
"""
|
||||
Train interface for classification finetuning task.
|
||||
"""
|
||||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False,
|
||||
assessment_method=""):
|
||||
super(BertCLS, self).__init__()
|
||||
self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings,
|
||||
assessment_method)
|
||||
self.loss = CrossEntropyCalculation(is_training)
|
||||
self.num_labels = num_labels
|
||||
self.assessment_method = assessment_method
|
||||
self.is_training = is_training
|
||||
def construct(self, input_ids, input_mask, token_type_id, label_ids):
|
||||
logits = self.bert(input_ids, input_mask, token_type_id)
|
||||
if self.assessment_method == "spearman_correlation":
|
||||
if self.is_training:
|
||||
loss = self.loss(logits, label_ids)
|
||||
else:
|
||||
loss = logits
|
||||
else:
|
||||
loss = self.loss(logits, label_ids, self.num_labels)
|
||||
return loss
|
||||
|
||||
|
||||
class BertNER(nn.Cell):
|
||||
"""
|
||||
Train interface for sequence labeling finetuning task.
|
||||
"""
|
||||
def __init__(self, config, batch_size, is_training, num_labels=11, use_crf=False,
|
||||
tag_to_index=None, dropout_prob=0.0, use_one_hot_embeddings=False):
|
||||
super(BertNER, self).__init__()
|
||||
self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings)
|
||||
if use_crf:
|
||||
if not tag_to_index:
|
||||
raise Exception("The dict for tag-index mapping should be provided for CRF.")
|
||||
from src.CRF import CRF
|
||||
self.loss = CRF(tag_to_index, batch_size, config.seq_length, is_training)
|
||||
else:
|
||||
self.loss = CrossEntropyCalculation(is_training)
|
||||
self.num_labels = num_labels
|
||||
self.use_crf = use_crf
|
||||
def construct(self, input_ids, input_mask, token_type_id, label_ids):
|
||||
logits = self.bert(input_ids, input_mask, token_type_id)
|
||||
if self.use_crf:
|
||||
loss = self.loss(logits, label_ids)
|
||||
else:
|
||||
loss = self.loss(logits, label_ids, self.num_labels)
|
||||
return loss
|
||||
|
||||
class BertSquad(nn.Cell):
|
||||
'''
|
||||
Train interface for SQuAD finetuning task.
|
||||
'''
|
||||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
|
||||
super(BertSquad, self).__init__()
|
||||
self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
|
||||
self.loss = CrossEntropyCalculation(is_training)
|
||||
self.num_labels = num_labels
|
||||
self.seq_length = config.seq_length
|
||||
self.is_training = is_training
|
||||
self.total_num = Parameter(Tensor([0], mstype.float32))
|
||||
self.start_num = Parameter(Tensor([0], mstype.float32))
|
||||
self.end_num = Parameter(Tensor([0], mstype.float32))
|
||||
self.sum = P.ReduceSum()
|
||||
self.equal = P.Equal()
|
||||
self.argmax = P.ArgMaxWithValue(axis=1)
|
||||
self.squeeze = P.Squeeze(axis=-1)
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible):
|
||||
"""interface for SQuAD finetuning task"""
|
||||
logits = self.bert(input_ids, input_mask, token_type_id)
|
||||
if self.is_training:
|
||||
unstacked_logits_0 = self.squeeze(logits[:, :, 0:1])
|
||||
unstacked_logits_1 = self.squeeze(logits[:, :, 1:2])
|
||||
start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length)
|
||||
end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length)
|
||||
total_loss = (start_loss + end_loss) / 2.0
|
||||
else:
|
||||
start_logits = self.squeeze(logits[:, :, 0:1])
|
||||
start_logits = start_logits + 100 * input_mask
|
||||
end_logits = self.squeeze(logits[:, :, 1:2])
|
||||
end_logits = end_logits + 100 * input_mask
|
||||
total_loss = (unique_id, start_logits, end_logits)
|
||||
return total_loss
|
|
@ -0,0 +1,807 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bert for pretraining."""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore import context
|
||||
from .bert_model import BertModel
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
if clip_type not in (0, 1):
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
|
||||
class GetMaskedLMOutput(nn.Cell):
|
||||
"""
|
||||
Get masked lm output.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
|
||||
Returns:
|
||||
Tensor, masked lm output.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(GetMaskedLMOutput, self).__init__()
|
||||
self.width = config.hidden_size
|
||||
self.reshape = P.Reshape()
|
||||
self.gather = P.Gather()
|
||||
|
||||
weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.dense = nn.Dense(self.width,
|
||||
config.hidden_size,
|
||||
weight_init=weight_init,
|
||||
activation=config.hidden_act).to_float(config.compute_type)
|
||||
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
|
||||
self.output_bias = Parameter(
|
||||
initializer(
|
||||
'zero',
|
||||
config.vocab_size))
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.log_softmax = nn.LogSoftmax(axis=-1)
|
||||
self.shape_flat_offsets = (-1, 1)
|
||||
self.last_idx = (-1,)
|
||||
self.shape_flat_sequence_tensor = (-1, self.width)
|
||||
self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32))
|
||||
self.cast = P.Cast()
|
||||
self.compute_type = config.compute_type
|
||||
self.dtype = config.dtype
|
||||
|
||||
def construct(self,
|
||||
input_tensor,
|
||||
output_weights,
|
||||
positions):
|
||||
"""Get output log_probs"""
|
||||
rng = F.tuple_to_array(F.make_range(P.Shape()(input_tensor)[0]))
|
||||
flat_offsets = self.reshape(rng * self.seq_length_tensor, self.shape_flat_offsets)
|
||||
flat_position = self.reshape(positions + flat_offsets, self.last_idx)
|
||||
flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
|
||||
input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
|
||||
input_tensor = self.cast(input_tensor, self.compute_type)
|
||||
output_weights = self.cast(output_weights, self.compute_type)
|
||||
input_tensor = self.dense(input_tensor)
|
||||
input_tensor = self.layernorm(input_tensor)
|
||||
logits = self.matmul(input_tensor, output_weights)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
logits = logits + self.output_bias
|
||||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
||||
|
||||
class GetNextSentenceOutput(nn.Cell):
|
||||
"""
|
||||
Get next sentence output.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of Bert.
|
||||
|
||||
Returns:
|
||||
Tensor, next sentence output.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(GetNextSentenceOutput, self).__init__()
|
||||
self.log_softmax = P.LogSoftmax()
|
||||
weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.dense = nn.Dense(config.hidden_size, 2,
|
||||
weight_init=weight_init, has_bias=True).to_float(config.compute_type)
|
||||
self.dtype = config.dtype
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
logits = self.dense(input_tensor)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
log_prob = self.log_softmax(logits)
|
||||
return log_prob
|
||||
|
||||
|
||||
class BertPreTraining(nn.Cell):
|
||||
"""
|
||||
Bert pretraining network.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
|
||||
|
||||
Returns:
|
||||
Tensor, prediction_scores, seq_relationship_score.
|
||||
"""
|
||||
|
||||
def __init__(self, config, is_training, use_one_hot_embeddings):
|
||||
super(BertPreTraining, self).__init__()
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cls1 = GetMaskedLMOutput(config)
|
||||
self.cls2 = GetNextSentenceOutput(config)
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id,
|
||||
masked_lm_positions):
|
||||
sequence_output, pooled_output, embedding_table = \
|
||||
self.bert(input_ids, token_type_id, input_mask)
|
||||
prediction_scores = self.cls1(sequence_output,
|
||||
embedding_table,
|
||||
masked_lm_positions)
|
||||
seq_relationship_score = self.cls2(pooled_output)
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class BertPretrainingLoss(nn.Cell):
|
||||
"""
|
||||
Provide bert pre-training loss.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
|
||||
Returns:
|
||||
Tensor, total loss.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertPretrainingLoss, self).__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reshape = P.Reshape()
|
||||
self.last_idx = (-1,)
|
||||
self.neg = P.Neg()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids,
|
||||
masked_lm_weights, next_sentence_labels):
|
||||
"""Defines the computation performed."""
|
||||
label_ids = self.reshape(masked_lm_ids, self.last_idx)
|
||||
label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32)
|
||||
one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
|
||||
per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
|
||||
numerator = self.reduce_sum(label_weights * per_example_loss, ())
|
||||
denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
|
||||
masked_lm_loss = numerator / denominator
|
||||
|
||||
# next_sentence_loss
|
||||
labels = self.reshape(next_sentence_labels, self.last_idx)
|
||||
one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value)
|
||||
per_example_loss = self.neg(self.reduce_sum(
|
||||
one_hot_labels * seq_relationship_score, self.last_idx))
|
||||
next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx)
|
||||
|
||||
# total_loss
|
||||
total_loss = masked_lm_loss + next_sentence_loss
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
class BertNetworkWithLoss(nn.Cell):
|
||||
"""
|
||||
Provide bert pre-training loss through network.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, the loss of the network.
|
||||
"""
|
||||
|
||||
def __init__(self, config, is_training, use_one_hot_embeddings=False):
|
||||
super(BertNetworkWithLoss, self).__init__()
|
||||
self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings)
|
||||
self.loss = BertPretrainingLoss(config)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights):
|
||||
"""Get pre-training loss"""
|
||||
prediction_scores, seq_relationship_score = \
|
||||
self.bert(input_ids, input_mask, token_type_id, masked_lm_positions)
|
||||
total_loss = self.loss(prediction_scores, seq_relationship_score,
|
||||
masked_lm_ids, masked_lm_weights, next_sentence_labels)
|
||||
return self.cast(total_loss, mstype.float32)
|
||||
|
||||
|
||||
class BertTrainOneStepCell(nn.TrainOneStepCell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
sens (Number): The adjust parameter. Default: 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
|
||||
self.cast = P.Cast()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def set_sens(self, value):
|
||||
self.sens = value
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(F.tuple_to_array((self.sens,)),
|
||||
mstype.float32))
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
grads = self.grad_reducer(grads)
|
||||
succ = self.optimizer(grads)
|
||||
return F.depend(loss, succ)
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
|
||||
class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
|
||||
self.cast = P.Cast()
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
|
||||
cond = self.get_overflow_status(status, grads)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
|
||||
class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow
|
||||
condition as input.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
|
||||
self.cast = P.Cast()
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
|
||||
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
cond = self.get_overflow_status(status, grads)
|
||||
overflow = cond
|
||||
if self.loss_scaling_manager is not None:
|
||||
overflow = self.loss_scaling_manager(scaling_sens, cond)
|
||||
succ = self.optimizer(grads, overflow)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
cast = P.Cast()
|
||||
add_grads = C.MultitypeFuncGraph("add_grads")
|
||||
|
||||
|
||||
@add_grads.register("Tensor", "Tensor")
|
||||
def _add_grads(accu_grad, grad):
|
||||
return accu_grad + cast(grad, mstype.float32)
|
||||
|
||||
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
|
||||
|
||||
@update_accu_grads.register("Tensor", "Tensor")
|
||||
def _update_accu_grads(accu_grad, grad):
|
||||
succ = True
|
||||
return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
|
||||
|
||||
accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
|
||||
|
||||
@accumulate_accu_grads.register("Tensor", "Tensor")
|
||||
def _accumulate_accu_grads(accu_grad, grad):
|
||||
succ = True
|
||||
return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
|
||||
|
||||
|
||||
zeroslike = P.ZerosLike()
|
||||
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
|
||||
|
||||
|
||||
@reset_accu_grads.register("Tensor")
|
||||
def _reset_accu_grads(accu_grad):
|
||||
succ = True
|
||||
return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
|
||||
|
||||
|
||||
class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
To mimic higher batch size, gradients are accumulated N times before weight update.
|
||||
|
||||
For distribution mode, allreduce will only be implemented in the weight updated step,
|
||||
i.e. the sub-step after gradients accumulated N times.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
|
||||
batch_size * accumulation_steps. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
|
||||
super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.accumulation_steps = accumulation_steps
|
||||
self.enable_global_norm = enable_global_norm
|
||||
self.one = Tensor(np.array([1]).astype(np.int32))
|
||||
self.zero = Tensor(np.array([0]).astype(np.int32))
|
||||
self.local_step = Parameter(initializer(0, [1], mstype.int32))
|
||||
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
|
||||
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
|
||||
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
|
||||
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = F.identity
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.overflow_reducer = F.identity
|
||||
if self.is_distributed:
|
||||
self.overflow_reducer = P.AllReduce()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_status = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.logical_or = P.LogicalOr()
|
||||
self.not_equal = P.NotEqual()
|
||||
self.select = P.Select()
|
||||
self.reshape = P.Reshape()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
# alloc status and clear should be right before gradoperation
|
||||
init = self.alloc_status()
|
||||
init = F.depend(init, loss)
|
||||
clear_status = self.clear_status(init)
|
||||
scaling_sens = F.depend(scaling_sens, clear_status)
|
||||
# update accumulation parameters
|
||||
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
|
||||
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
|
||||
mean_loss = self.accu_loss / self.local_step
|
||||
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
|
||||
accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
|
||||
mean_loss = F.depend(mean_loss, accu_succ)
|
||||
|
||||
init = F.depend(init, mean_loss)
|
||||
get_status = self.get_status(init)
|
||||
init = F.depend(init, get_status)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
overflow = self.less_equal(self.base, flag_sum)
|
||||
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
|
||||
accu_overflow = self.select(overflow, self.one, self.zero)
|
||||
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
|
||||
|
||||
if is_accu_step:
|
||||
succ = False
|
||||
else:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(self.accu_grads)
|
||||
scaling = scaling_sens * self.degree * self.accumulation_steps
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
|
||||
if self.enable_global_norm:
|
||||
grads = C.clip_by_global_norm(grads, 1.0, None)
|
||||
else:
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
accu_overflow = F.depend(accu_overflow, grads)
|
||||
accu_overflow = self.overflow_reducer(accu_overflow)
|
||||
overflow = self.less_equal(self.base, accu_overflow)
|
||||
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
|
||||
overflow = F.depend(overflow, accu_succ)
|
||||
overflow = self.reshape(overflow, (()))
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
|
||||
ret = (mean_loss, overflow, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
|
||||
class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
To mimic higher batch size, gradients are accumulated N times before weight update.
|
||||
|
||||
For distribution mode, allreduce will be implemented after each sub-step and the trailing time
|
||||
will be overided by backend optimization pass.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
|
||||
batch_size * accumulation_steps. Default: 1.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
|
||||
super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.accumulation_steps = accumulation_steps
|
||||
self.enable_global_norm = enable_global_norm
|
||||
self.one = Tensor(np.array([1]).astype(np.int32))
|
||||
self.zero = Tensor(np.array([0]).astype(np.int32))
|
||||
self.local_step = Parameter(initializer(0, [1], mstype.int32))
|
||||
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
|
||||
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
|
||||
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
|
||||
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = F.identity
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.overflow_reducer = F.identity
|
||||
if self.is_distributed:
|
||||
self.overflow_reducer = P.AllReduce()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.logical_or = P.LogicalOr()
|
||||
self.not_equal = P.NotEqual()
|
||||
self.select = P.Select()
|
||||
self.reshape = P.Reshape()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
@C.add_flags(has_effect=True)
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
|
||||
# update accumulation parameters
|
||||
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
|
||||
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
|
||||
mean_loss = self.accu_loss / self.local_step
|
||||
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||
|
||||
# alloc status and clear should be right before gradoperation
|
||||
init = self.alloc_status()
|
||||
self.clear_before_grad(init)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
|
||||
|
||||
accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
|
||||
scaling = scaling_sens * self.degree * self.accumulation_steps
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
flag_reduce = self.overflow_reducer(flag_sum)
|
||||
overflow = self.less_equal(self.base, flag_reduce)
|
||||
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
|
||||
accu_overflow = self.select(overflow, self.one, self.zero)
|
||||
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
|
||||
overflow = self.reshape(overflow, (()))
|
||||
|
||||
if is_accu_step:
|
||||
succ = False
|
||||
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
|
||||
succ = F.depend(succ, accu_succ)
|
||||
else:
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
if self.enable_global_norm:
|
||||
grads = C.clip_by_global_norm(grads, 1.0, None)
|
||||
else:
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
|
||||
succ = self.optimizer(grads)
|
||||
|
||||
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
|
||||
succ = F.depend(succ, accu_succ)
|
||||
|
||||
ret = (mean_loss, overflow, scaling_sens)
|
||||
return F.depend(ret, succ)
|
|
@ -0,0 +1,881 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bert model."""
|
||||
|
||||
import math
|
||||
import copy
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
class BertConfig:
|
||||
"""
|
||||
Configuration for `BertModel`.
|
||||
|
||||
Args:
|
||||
seq_length (int): Length of input sequence. Default: 128.
|
||||
vocab_size (int): The shape of each embedding vector. Default: 32000.
|
||||
hidden_size (int): Size of the bert encoder layers. Default: 768.
|
||||
num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
|
||||
cell. Default: 12.
|
||||
num_attention_heads (int): Number of attention heads in the BertTransformer
|
||||
encoder cell. Default: 12.
|
||||
intermediate_size (int): Size of intermediate layer in the BertTransformer
|
||||
encoder cell. Default: 3072.
|
||||
hidden_act (str): Activation function used in the BertTransformer encoder
|
||||
cell. Default: "gelu".
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
max_position_embeddings (int): Maximum length of sequences used in this
|
||||
model. Default: 512.
|
||||
type_vocab_size (int): Size of token type vocab. Default: 16.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
seq_length=128,
|
||||
vocab_size=32000,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float32):
|
||||
self.seq_length = seq_length
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.dtype = dtype
|
||||
self.compute_type = compute_type
|
||||
|
||||
|
||||
class EmbeddingLookup(nn.Cell):
|
||||
"""
|
||||
A embeddings lookup table with a fixed dictionary and size.
|
||||
|
||||
Args:
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
|
||||
each embedding vector.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
"""
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
embedding_size,
|
||||
embedding_shape,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02):
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
[vocab_size, embedding_size]))
|
||||
self.expand = P.ExpandDims()
|
||||
self.shape_flat = (-1,)
|
||||
self.gather = P.Gather()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = tuple(embedding_shape)
|
||||
|
||||
def construct(self, input_ids):
|
||||
"""Get output and embeddings lookup table"""
|
||||
extended_ids = self.expand(input_ids, -1)
|
||||
flat_ids = self.reshape(extended_ids, self.shape_flat)
|
||||
if self.use_one_hot_embeddings:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
output_for_reshape = self.array_mul(
|
||||
one_hot_ids, self.embedding_table)
|
||||
else:
|
||||
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
|
||||
output = self.reshape(output_for_reshape, self.shape)
|
||||
return output, self.embedding_table
|
||||
|
||||
|
||||
class EmbeddingPostprocessor(nn.Cell):
|
||||
"""
|
||||
Postprocessors apply positional and token type embeddings to word embeddings.
|
||||
|
||||
Args:
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
|
||||
each embedding vector.
|
||||
use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
|
||||
token_type_vocab_size (int): Size of token type vocab. Default: 16.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
max_position_embeddings (int): Maximum length of sequences used in this
|
||||
model. Default: 512.
|
||||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
"""
|
||||
def __init__(self,
|
||||
embedding_size,
|
||||
embedding_shape,
|
||||
use_relative_positions=False,
|
||||
use_token_type=False,
|
||||
token_type_vocab_size=16,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
max_position_embeddings=512,
|
||||
dropout_prob=0.1):
|
||||
super(EmbeddingPostprocessor, self).__init__()
|
||||
self.use_token_type = use_token_type
|
||||
self.token_type_vocab_size = token_type_vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
[token_type_vocab_size,
|
||||
embedding_size]))
|
||||
|
||||
self.shape_flat = (-1,)
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.1, mstype.float32)
|
||||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = tuple(embedding_shape)
|
||||
self.layernorm = nn.LayerNorm((embedding_size,))
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.gather = P.Gather()
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.slice = P.StridedSlice()
|
||||
self.full_position_embeddings = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
[max_position_embeddings,
|
||||
embedding_size]))
|
||||
|
||||
def construct(self, token_type_ids, word_embeddings):
|
||||
"""Postprocessors apply positional and token type embeddings to word embeddings."""
|
||||
output = word_embeddings
|
||||
if self.use_token_type:
|
||||
flat_ids = self.reshape(token_type_ids, self.shape_flat)
|
||||
if self.use_one_hot_embeddings:
|
||||
one_hot_ids = self.one_hot(flat_ids,
|
||||
self.token_type_vocab_size, self.on_value, self.off_value)
|
||||
token_type_embeddings = self.array_mul(one_hot_ids,
|
||||
self.embedding_table)
|
||||
else:
|
||||
token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
|
||||
token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
|
||||
output += token_type_embeddings
|
||||
if not self.use_relative_positions:
|
||||
_, seq, width = self.shape
|
||||
position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
|
||||
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
|
||||
output += position_embeddings
|
||||
output = self.layernorm(output)
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
|
||||
class BertOutput(nn.Cell):
|
||||
"""
|
||||
Apply a linear computation to hidden status and a residual computation to input.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels.
|
||||
out_channels (int): Output channels.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
initializer_range=0.02,
|
||||
dropout_prob=0.1,
|
||||
compute_type=mstype.float32):
|
||||
super(BertOutput, self).__init__()
|
||||
self.dense = nn.Dense(in_channels, out_channels,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.dropout_prob = dropout_prob
|
||||
self.add = P.Add()
|
||||
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, hidden_status, input_tensor):
|
||||
output = self.dense(hidden_status)
|
||||
output = self.dropout(output)
|
||||
output = self.add(input_tensor, output)
|
||||
output = self.layernorm(output)
|
||||
return output
|
||||
|
||||
|
||||
class RelaPosMatrixGenerator(nn.Cell):
|
||||
"""
|
||||
Generates matrix of relative positions between inputs.
|
||||
|
||||
Args:
|
||||
length (int): Length of one dim for the matrix to be generated.
|
||||
max_relative_position (int): Max value of relative position.
|
||||
"""
|
||||
def __init__(self, length, max_relative_position):
|
||||
super(RelaPosMatrixGenerator, self).__init__()
|
||||
self._length = length
|
||||
self._max_relative_position = max_relative_position
|
||||
self._min_relative_position = -max_relative_position
|
||||
self.range_length = -length + 1
|
||||
|
||||
self.tile = P.Tile()
|
||||
self.range_mat = P.Reshape()
|
||||
self.sub = P.Sub()
|
||||
self.expanddims = P.ExpandDims()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self):
|
||||
"""Generates matrix of relative positions between inputs."""
|
||||
range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
|
||||
range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
|
||||
tile_row_out = self.tile(range_vec_row_out, (self._length,))
|
||||
tile_col_out = self.tile(range_vec_col_out, (1, self._length))
|
||||
range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
|
||||
transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
|
||||
distance_mat = self.sub(range_mat_out, transpose_out)
|
||||
|
||||
distance_mat_clipped = C.clip_by_value(distance_mat,
|
||||
self._min_relative_position,
|
||||
self._max_relative_position)
|
||||
|
||||
# Shift values to be >=0. Each integer still uniquely identifies a
|
||||
# relative position difference.
|
||||
final_mat = distance_mat_clipped + self._max_relative_position
|
||||
return final_mat
|
||||
|
||||
|
||||
class RelaPosEmbeddingsGenerator(nn.Cell):
|
||||
"""
|
||||
Generates tensor of size [length, length, depth].
|
||||
|
||||
Args:
|
||||
length (int): Length of one dim for the matrix to be generated.
|
||||
depth (int): Size of each attention head.
|
||||
max_relative_position (int): Maxmum value of relative position.
|
||||
initializer_range (float): Initialization value of TruncatedNormal.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
def __init__(self,
|
||||
length,
|
||||
depth,
|
||||
max_relative_position,
|
||||
initializer_range,
|
||||
use_one_hot_embeddings=False):
|
||||
super(RelaPosEmbeddingsGenerator, self).__init__()
|
||||
self.depth = depth
|
||||
self.vocab_size = max_relative_position * 2 + 1
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
|
||||
self.embeddings_table = Parameter(
|
||||
initializer(TruncatedNormal(initializer_range),
|
||||
[self.vocab_size, self.depth]))
|
||||
|
||||
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
|
||||
max_relative_position=max_relative_position)
|
||||
self.reshape = P.Reshape()
|
||||
self.one_hot = nn.OneHot(depth=self.vocab_size)
|
||||
self.shape = P.Shape()
|
||||
self.gather = P.Gather() # index_select
|
||||
self.matmul = P.BatchMatMul()
|
||||
|
||||
def construct(self):
|
||||
"""Generate embedding for each relative position of dimension depth."""
|
||||
relative_positions_matrix_out = self.relative_positions_matrix()
|
||||
|
||||
if self.use_one_hot_embeddings:
|
||||
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
|
||||
one_hot_relative_positions_matrix = self.one_hot(
|
||||
flat_relative_positions_matrix)
|
||||
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
|
||||
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
|
||||
embeddings = self.reshape(embeddings, my_shape)
|
||||
else:
|
||||
embeddings = self.gather(self.embeddings_table,
|
||||
relative_positions_matrix_out, 0)
|
||||
return embeddings
|
||||
|
||||
|
||||
class SaturateCast(nn.Cell):
|
||||
"""
|
||||
Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
|
||||
the danger that the value will overflow or underflow.
|
||||
|
||||
Args:
|
||||
src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
|
||||
dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
|
||||
super(SaturateCast, self).__init__()
|
||||
np_type = mstype.dtype_to_nptype(dst_type)
|
||||
|
||||
self.tensor_min_type = float(np.finfo(np_type).min)
|
||||
self.tensor_max_type = float(np.finfo(np_type).max)
|
||||
|
||||
self.min_op = P.Minimum()
|
||||
self.max_op = P.Maximum()
|
||||
self.cast = P.Cast()
|
||||
self.dst_type = dst_type
|
||||
|
||||
def construct(self, x):
|
||||
out = self.max_op(x, self.tensor_min_type)
|
||||
out = self.min_op(out, self.tensor_max_type)
|
||||
return self.cast(out, self.dst_type)
|
||||
|
||||
|
||||
class BertAttention(nn.Cell):
|
||||
"""
|
||||
Apply multi-headed attention from "from_tensor" to "to_tensor".
|
||||
|
||||
Args:
|
||||
from_tensor_width (int): Size of last dim of from_tensor.
|
||||
to_tensor_width (int): Size of last dim of to_tensor.
|
||||
from_seq_length (int): Length of from_tensor sequence.
|
||||
to_seq_length (int): Length of to_tensor sequence.
|
||||
num_attention_heads (int): Number of attention heads. Default: 1.
|
||||
size_per_head (int): Size of each attention head. Default: 512.
|
||||
query_act (str): Activation function for the query transform. Default: None.
|
||||
key_act (str): Activation function for the key transform. Default: None.
|
||||
value_act (str): Activation function for the value transform. Default: None.
|
||||
has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.0.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
|
||||
tensor. Default: False.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
from_tensor_width,
|
||||
to_tensor_width,
|
||||
from_seq_length,
|
||||
to_seq_length,
|
||||
num_attention_heads=1,
|
||||
size_per_head=512,
|
||||
query_act=None,
|
||||
key_act=None,
|
||||
value_act=None,
|
||||
has_attention_mask=False,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
do_return_2d_tensor=False,
|
||||
use_relative_positions=False,
|
||||
compute_type=mstype.float32):
|
||||
|
||||
super(BertAttention, self).__init__()
|
||||
self.from_seq_length = from_seq_length
|
||||
self.to_seq_length = to_seq_length
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.size_per_head = size_per_head
|
||||
self.has_attention_mask = has_attention_mask
|
||||
self.use_relative_positions = use_relative_positions
|
||||
|
||||
self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
|
||||
self.reshape = P.Reshape()
|
||||
self.shape_from_2d = (-1, from_tensor_width)
|
||||
self.shape_to_2d = (-1, to_tensor_width)
|
||||
weight = TruncatedNormal(initializer_range)
|
||||
units = num_attention_heads * size_per_head
|
||||
self.query_layer = nn.Dense(from_tensor_width,
|
||||
units,
|
||||
activation=query_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
self.key_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=key_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
self.value_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=value_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
|
||||
self.shape_from = (-1, from_seq_length, num_attention_heads, size_per_head)
|
||||
self.shape_to = (-1, to_seq_length, num_attention_heads, size_per_head)
|
||||
|
||||
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
|
||||
self.multiply = P.Mul()
|
||||
self.transpose = P.Transpose()
|
||||
self.trans_shape = (0, 2, 1, 3)
|
||||
self.trans_shape_relative = (2, 0, 1, 3)
|
||||
self.trans_shape_position = (1, 2, 0, 3)
|
||||
self.multiply_data = -10000.0
|
||||
self.matmul = P.BatchMatMul()
|
||||
|
||||
self.softmax = nn.Softmax()
|
||||
self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
|
||||
|
||||
if self.has_attention_mask:
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.sub = P.Sub()
|
||||
self.add = P.Add()
|
||||
self.cast = P.Cast()
|
||||
self.get_dtype = P.DType()
|
||||
if do_return_2d_tensor:
|
||||
self.shape_return = (-1, num_attention_heads * size_per_head)
|
||||
else:
|
||||
self.shape_return = (-1, from_seq_length, num_attention_heads * size_per_head)
|
||||
|
||||
self.cast_compute_type = SaturateCast(dst_type=compute_type)
|
||||
if self.use_relative_positions:
|
||||
self._generate_relative_positions_embeddings = \
|
||||
RelaPosEmbeddingsGenerator(length=to_seq_length,
|
||||
depth=size_per_head,
|
||||
max_relative_position=16,
|
||||
initializer_range=initializer_range,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
def construct(self, from_tensor, to_tensor, attention_mask):
|
||||
"""reshape 2d/3d input tensors to 2d"""
|
||||
from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
|
||||
to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
|
||||
query_out = self.query_layer(from_tensor_2d)
|
||||
key_out = self.key_layer(to_tensor_2d)
|
||||
value_out = self.value_layer(to_tensor_2d)
|
||||
|
||||
query_layer = self.reshape(query_out, self.shape_from)
|
||||
query_layer = self.transpose(query_layer, self.trans_shape)
|
||||
key_layer = self.reshape(key_out, self.shape_to)
|
||||
key_layer = self.transpose(key_layer, self.trans_shape)
|
||||
|
||||
attention_scores = self.matmul_trans_b(query_layer, key_layer)
|
||||
|
||||
# use_relative_position, supplementary logic
|
||||
if self.use_relative_positions:
|
||||
# relations_keys is [F|T, F|T, H]
|
||||
relations_keys = self._generate_relative_positions_embeddings()
|
||||
relations_keys = self.cast_compute_type(relations_keys)
|
||||
# query_layer_t is [F, B, N, H]
|
||||
query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
|
||||
# query_layer_r is [F, B * N, H]
|
||||
query_layer_r = self.reshape(query_layer_t,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.size_per_head))
|
||||
# key_position_scores is [F, B * N, F|T]
|
||||
key_position_scores = self.matmul_trans_b(query_layer_r,
|
||||
relations_keys)
|
||||
# key_position_scores_r is [F, B, N, F|T]
|
||||
key_position_scores_r = self.reshape(key_position_scores,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.num_attention_heads,
|
||||
self.from_seq_length))
|
||||
# key_position_scores_r_t is [B, N, F, F|T]
|
||||
key_position_scores_r_t = self.transpose(key_position_scores_r,
|
||||
self.trans_shape_position)
|
||||
attention_scores = attention_scores + key_position_scores_r_t
|
||||
|
||||
attention_scores = self.multiply(self.scores_mul, attention_scores)
|
||||
|
||||
if self.has_attention_mask:
|
||||
attention_mask = self.expand_dims(attention_mask, 1)
|
||||
multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
|
||||
self.cast(attention_mask, self.get_dtype(attention_scores)))
|
||||
|
||||
adder = self.multiply(multiply_out, self.multiply_data)
|
||||
attention_scores = self.add(adder, attention_scores)
|
||||
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
value_layer = self.reshape(value_out, self.shape_to)
|
||||
value_layer = self.transpose(value_layer, self.trans_shape)
|
||||
context_layer = self.matmul(attention_probs, value_layer)
|
||||
|
||||
# use_relative_position, supplementary logic
|
||||
if self.use_relative_positions:
|
||||
# relations_values is [F|T, F|T, H]
|
||||
relations_values = self._generate_relative_positions_embeddings()
|
||||
relations_values = self.cast_compute_type(relations_values)
|
||||
# attention_probs_t is [F, B, N, T]
|
||||
attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
|
||||
# attention_probs_r is [F, B * N, T]
|
||||
attention_probs_r = self.reshape(
|
||||
attention_probs_t,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.to_seq_length))
|
||||
# value_position_scores is [F, B * N, H]
|
||||
value_position_scores = self.matmul(attention_probs_r,
|
||||
relations_values)
|
||||
# value_position_scores_r is [F, B, N, H]
|
||||
value_position_scores_r = self.reshape(value_position_scores,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.num_attention_heads,
|
||||
self.size_per_head))
|
||||
# value_position_scores_r_t is [B, N, F, H]
|
||||
value_position_scores_r_t = self.transpose(value_position_scores_r,
|
||||
self.trans_shape_position)
|
||||
context_layer = context_layer + value_position_scores_r_t
|
||||
|
||||
context_layer = self.transpose(context_layer, self.trans_shape)
|
||||
context_layer = self.reshape(context_layer, self.shape_return)
|
||||
|
||||
return context_layer
|
||||
|
||||
|
||||
class BertSelfAttention(nn.Cell):
|
||||
"""
|
||||
Apply self-attention.
|
||||
|
||||
Args:
|
||||
seq_length (int): Length of input sequence.
|
||||
hidden_size (int): Size of the bert encoder layers.
|
||||
num_attention_heads (int): Number of attention heads. Default: 12.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
seq_length,
|
||||
hidden_size,
|
||||
num_attention_heads=12,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
hidden_dropout_prob=0.1,
|
||||
use_relative_positions=False,
|
||||
compute_type=mstype.float32):
|
||||
super(BertSelfAttention, self).__init__()
|
||||
if hidden_size % num_attention_heads != 0:
|
||||
raise ValueError("The hidden size (%d) is not a multiple of the number "
|
||||
"of attention heads (%d)" % (hidden_size, num_attention_heads))
|
||||
|
||||
self.size_per_head = int(hidden_size / num_attention_heads)
|
||||
|
||||
self.attention = BertAttention(
|
||||
from_tensor_width=hidden_size,
|
||||
to_tensor_width=hidden_size,
|
||||
from_seq_length=seq_length,
|
||||
to_seq_length=seq_length,
|
||||
num_attention_heads=num_attention_heads,
|
||||
size_per_head=self.size_per_head,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
use_relative_positions=use_relative_positions,
|
||||
has_attention_mask=True,
|
||||
do_return_2d_tensor=True,
|
||||
compute_type=compute_type)
|
||||
|
||||
self.output = BertOutput(in_channels=hidden_size,
|
||||
out_channels=hidden_size,
|
||||
initializer_range=initializer_range,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
compute_type=compute_type)
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, hidden_size)
|
||||
|
||||
def construct(self, input_tensor, attention_mask):
|
||||
input_tensor = self.reshape(input_tensor, self.shape)
|
||||
attention_output = self.attention(input_tensor, input_tensor, attention_mask)
|
||||
output = self.output(attention_output, input_tensor)
|
||||
return output
|
||||
|
||||
|
||||
class BertEncoderCell(nn.Cell):
|
||||
"""
|
||||
Encoder cells used in BertTransformer.
|
||||
|
||||
Args:
|
||||
hidden_size (int): Size of the bert encoder layers. Default: 768.
|
||||
seq_length (int): Length of input sequence. Default: 512.
|
||||
num_attention_heads (int): Number of attention heads. Default: 12.
|
||||
intermediate_size (int): Size of intermediate layer. Default: 3072.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.02.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
hidden_act (str): Activation function. Default: "gelu".
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
seq_length=512,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
attention_probs_dropout_prob=0.02,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
hidden_dropout_prob=0.1,
|
||||
use_relative_positions=False,
|
||||
hidden_act="gelu",
|
||||
compute_type=mstype.float32):
|
||||
super(BertEncoderCell, self).__init__()
|
||||
self.attention = BertSelfAttention(
|
||||
hidden_size=hidden_size,
|
||||
seq_length=seq_length,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
use_relative_positions=use_relative_positions,
|
||||
compute_type=compute_type)
|
||||
self.intermediate = nn.Dense(in_channels=hidden_size,
|
||||
out_channels=intermediate_size,
|
||||
activation=hidden_act,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
self.output = BertOutput(in_channels=intermediate_size,
|
||||
out_channels=hidden_size,
|
||||
initializer_range=initializer_range,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
compute_type=compute_type)
|
||||
|
||||
def construct(self, hidden_states, attention_mask):
|
||||
# self-attention
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
# feed construct
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
# add and normalize
|
||||
output = self.output(intermediate_output, attention_output)
|
||||
return output
|
||||
|
||||
|
||||
class BertTransformer(nn.Cell):
|
||||
"""
|
||||
Multi-layer bert transformer.
|
||||
|
||||
Args:
|
||||
hidden_size (int): Size of the encoder layers.
|
||||
seq_length (int): Length of input sequence.
|
||||
num_hidden_layers (int): Number of hidden layers in encoder cells.
|
||||
num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
|
||||
intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
|
||||
"""
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
seq_length,
|
||||
num_hidden_layers,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
hidden_dropout_prob=0.1,
|
||||
use_relative_positions=False,
|
||||
hidden_act="gelu",
|
||||
compute_type=mstype.float32,
|
||||
return_all_encoders=False):
|
||||
super(BertTransformer, self).__init__()
|
||||
self.return_all_encoders = return_all_encoders
|
||||
|
||||
layers = []
|
||||
for _ in range(num_hidden_layers):
|
||||
layer = BertEncoderCell(hidden_size=hidden_size,
|
||||
seq_length=seq_length,
|
||||
num_attention_heads=num_attention_heads,
|
||||
intermediate_size=intermediate_size,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
use_relative_positions=use_relative_positions,
|
||||
hidden_act=hidden_act,
|
||||
compute_type=compute_type)
|
||||
layers.append(layer)
|
||||
|
||||
self.layers = nn.CellList(layers)
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, hidden_size)
|
||||
self.out_shape = (-1, seq_length, hidden_size)
|
||||
|
||||
def construct(self, input_tensor, attention_mask):
|
||||
"""Multi-layer bert transformer."""
|
||||
prev_output = self.reshape(input_tensor, self.shape)
|
||||
|
||||
all_encoder_layers = ()
|
||||
for layer_module in self.layers:
|
||||
layer_output = layer_module(prev_output, attention_mask)
|
||||
prev_output = layer_output
|
||||
|
||||
if self.return_all_encoders:
|
||||
layer_output = self.reshape(layer_output, self.out_shape)
|
||||
all_encoder_layers = all_encoder_layers + (layer_output,)
|
||||
|
||||
if not self.return_all_encoders:
|
||||
prev_output = self.reshape(prev_output, self.out_shape)
|
||||
all_encoder_layers = all_encoder_layers + (prev_output,)
|
||||
return all_encoder_layers
|
||||
|
||||
|
||||
class CreateAttentionMaskFromInputMask(nn.Cell):
|
||||
"""
|
||||
Create attention mask according to input mask.
|
||||
|
||||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(CreateAttentionMaskFromInputMask, self).__init__()
|
||||
self.input_mask = None
|
||||
self.cast = P.Cast()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, 1, config.seq_length)
|
||||
|
||||
def construct(self, input_mask):
|
||||
attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
|
||||
return attention_mask
|
||||
|
||||
|
||||
class BertModel(nn.Cell):
|
||||
"""
|
||||
Bidirectional Encoder Representations from Transformers.
|
||||
|
||||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
is_training (bool): True for training mode. False for eval mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
def __init__(self,
|
||||
config,
|
||||
is_training,
|
||||
use_one_hot_embeddings=False):
|
||||
super(BertModel, self).__init__()
|
||||
config = copy.deepcopy(config)
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
config.attention_probs_dropout_prob = 0.0
|
||||
|
||||
self.seq_length = config.seq_length
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.embedding_size = config.hidden_size
|
||||
self.token_type_ids = None
|
||||
|
||||
self.last_idx = self.num_hidden_layers - 1
|
||||
output_embedding_shape = [-1, self.seq_length, self.embedding_size]
|
||||
|
||||
self.bert_embedding_lookup = EmbeddingLookup(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=self.embedding_size,
|
||||
embedding_shape=output_embedding_shape,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=config.initializer_range)
|
||||
|
||||
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
|
||||
embedding_size=self.embedding_size,
|
||||
embedding_shape=output_embedding_shape,
|
||||
use_relative_positions=config.use_relative_positions,
|
||||
use_token_type=True,
|
||||
token_type_vocab_size=config.type_vocab_size,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=0.02,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
dropout_prob=config.hidden_dropout_prob)
|
||||
|
||||
self.bert_encoder = BertTransformer(
|
||||
hidden_size=self.hidden_size,
|
||||
seq_length=self.seq_length,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
intermediate_size=config.intermediate_size,
|
||||
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=config.initializer_range,
|
||||
hidden_dropout_prob=config.hidden_dropout_prob,
|
||||
use_relative_positions=config.use_relative_positions,
|
||||
hidden_act=config.hidden_act,
|
||||
compute_type=config.compute_type,
|
||||
return_all_encoders=True)
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.dtype = config.dtype
|
||||
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
|
||||
self.slice = P.StridedSlice()
|
||||
|
||||
self.squeeze_1 = P.Squeeze(axis=1)
|
||||
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
|
||||
activation="tanh",
|
||||
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
|
||||
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
|
||||
|
||||
def construct(self, input_ids, token_type_ids, input_mask):
|
||||
"""Bidirectional Encoder Representations from Transformers."""
|
||||
# embedding
|
||||
word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
|
||||
embedding_output = self.bert_embedding_postprocessor(token_type_ids,
|
||||
word_embeddings)
|
||||
|
||||
# attention mask [batch_size, seq_length, seq_length]
|
||||
attention_mask = self._create_attention_mask_from_input_mask(input_mask)
|
||||
|
||||
# bert encoder
|
||||
encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output),
|
||||
attention_mask)
|
||||
|
||||
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
|
||||
|
||||
# pooler
|
||||
batch_size = P.Shape()(input_ids)[0]
|
||||
sequence_slice = self.slice(sequence_output,
|
||||
(0, 0, 0),
|
||||
(batch_size, 1, self.hidden_size),
|
||||
(1, 1, 1))
|
||||
first_token = self.squeeze_1(sequence_slice)
|
||||
pooled_output = self.dense(first_token)
|
||||
pooled_output = self.cast(pooled_output, self.dtype)
|
||||
|
||||
return sequence_output, pooled_output, embedding_tables
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in dataset.py, run_pretrain.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
from bert_model import BertConfig
|
||||
cfg = edict({
|
||||
'batch_size': 32,
|
||||
'bert_network': 'base',
|
||||
'loss_scale_value': 65536,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 1000,
|
||||
'optimizer': 'Lamb',
|
||||
'enable_global_norm': False,
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 3e-5,
|
||||
'end_learning_rate': 0.0,
|
||||
'power': 5.0,
|
||||
'weight_decay': 1e-5,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
'eps': 1e-6,
|
||||
'warmup_steps': 10000,
|
||||
}),
|
||||
'Lamb': edict({
|
||||
'learning_rate': 3e-5,
|
||||
'end_learning_rate': 0.0,
|
||||
'power': 5.0,
|
||||
'warmup_steps': 10000,
|
||||
'weight_decay': 0.01,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
'eps': 1e-8,
|
||||
}),
|
||||
'Momentum': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'momentum': 0.9,
|
||||
}),
|
||||
'Thor': edict({
|
||||
'lr_max': 0.0034,
|
||||
'lr_min': 3.244e-5,
|
||||
'lr_power': 1.0,
|
||||
'lr_total_steps': 30000,
|
||||
'damping_max': 5e-2,
|
||||
'damping_min': 1e-6,
|
||||
'damping_power': 1.0,
|
||||
'damping_total_steps': 30000,
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 5e-4,
|
||||
'loss_scale': 1.0,
|
||||
'frequency': 100,
|
||||
}),
|
||||
})
|
||||
|
||||
'''
|
||||
Including two kinds of network: \
|
||||
base: Google BERT-base(the base version of BERT model).
|
||||
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \
|
||||
Functional Relative Posetional Encoding as an effective positional encoding scheme).
|
||||
'''
|
||||
if cfg.bert_network == 'base':
|
||||
cfg.batch_size = 64
|
||||
bert_net_cfg = BertConfig(
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
if cfg.bert_network == 'nezha':
|
||||
cfg.batch_size = 96
|
||||
bert_net_cfg = BertConfig(
|
||||
seq_length=128,
|
||||
vocab_size=21128,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
if cfg.bert_network == 'large':
|
||||
cfg.batch_size = 24
|
||||
bert_net_cfg = BertConfig(
|
||||
seq_length=512,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Data utils used in Bert finetune and evaluation.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
class Tuple():
|
||||
"""
|
||||
apply the functions to the corresponding input fields.
|
||||
"""
|
||||
def __init__(self, fn, *args):
|
||||
if isinstance(fn, (list, tuple)):
|
||||
assert args, 'Input pattern not understood. The input of Tuple can be ' \
|
||||
'Tuple(A, B, C) or Tuple([A, B, C]) or Tuple((A, B, C)). ' \
|
||||
'Received fn=%s, args=%s' % (str(fn), str(args))
|
||||
self._fn = fn
|
||||
else:
|
||||
self._fn = (fn,) + args
|
||||
for i, ele_fn in enumerate(self._fn):
|
||||
assert callable(
|
||||
ele_fn
|
||||
), 'Batchify functions must be callable! type(fn[%d]) = %s' % (
|
||||
i, str(type(ele_fn)))
|
||||
|
||||
def __call__(self, data):
|
||||
|
||||
assert len(data[0]) == len(self._fn),\
|
||||
'The number of attributes in each data sample should contain' \
|
||||
' {} elements'.format(len(self._fn))
|
||||
ret = []
|
||||
for i, ele_fn in enumerate(self._fn):
|
||||
result = ele_fn([ele[i] for ele in data])
|
||||
if isinstance(result, (tuple, list)):
|
||||
ret.extend(result)
|
||||
else:
|
||||
ret.append(result)
|
||||
return tuple(ret)
|
||||
|
||||
class Pad():
|
||||
"""
|
||||
pad the data with given value
|
||||
"""
|
||||
def __init__(self,
|
||||
pad_val=0,
|
||||
axis=0,
|
||||
ret_length=None,
|
||||
dtype=None,
|
||||
pad_right=True):
|
||||
self._pad_val = pad_val
|
||||
self._axis = axis
|
||||
self._ret_length = ret_length
|
||||
self._dtype = dtype
|
||||
self._pad_right = pad_right
|
||||
|
||||
def __call__(self, data):
|
||||
arrs = [np.asarray(ele) for ele in data]
|
||||
original_length = [ele.shape[self._axis] for ele in arrs]
|
||||
max_size = max(original_length)
|
||||
ret_shape = list(arrs[0].shape)
|
||||
ret_shape[self._axis] = max_size
|
||||
ret_shape = (len(arrs),) + tuple(ret_shape)
|
||||
ret = np.full(
|
||||
shape=ret_shape,
|
||||
fill_value=self._pad_val,
|
||||
dtype=arrs[0].dtype if self._dtype is None else self._dtype)
|
||||
for i, arr in enumerate(arrs):
|
||||
if arr.shape[self._axis] == max_size:
|
||||
ret[i] = arr
|
||||
else:
|
||||
slices = [slice(None) for _ in range(arr.ndim)]
|
||||
if self._pad_right:
|
||||
slices[self._axis] = slice(0, arr.shape[self._axis])
|
||||
else:
|
||||
slices[self._axis] = slice(max_size - arr.shape[self._axis],
|
||||
max_size)
|
||||
|
||||
if slices[self._axis].start != slices[self._axis].stop:
|
||||
slices = [slice(i, i + 1)] + slices
|
||||
ret[tuple(slices)] = arr
|
||||
if self._ret_length:
|
||||
return ret, np.asarray(
|
||||
original_length,
|
||||
dtype="int32") if self._ret_length else np.asarray(
|
||||
original_length, self._ret_length)
|
||||
return ret
|
||||
|
||||
class Stack():
|
||||
"""
|
||||
Stack the input data
|
||||
"""
|
||||
|
||||
def __init__(self, axis=0, dtype=None):
|
||||
self._axis = axis
|
||||
self._dtype = dtype
|
||||
|
||||
def __call__(self, data):
|
||||
data = np.stack(
|
||||
data,
|
||||
axis=self._axis).astype(self._dtype) if self._dtype else np.stack(
|
||||
data, axis=self._axis)
|
||||
return data
|
|
@ -0,0 +1,142 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
data convert to mindrecord file.
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import dataset as data
|
||||
from tokenizer import FullTokenizer
|
||||
from data_util import Tuple, Pad, Stack
|
||||
from mindspore.mindrecord import FileWriter
|
||||
TASK_CLASSES = {
|
||||
'udc': data.UDCv1,
|
||||
'dstc2': data.DSTC2,
|
||||
'atis_slot': data.ATIS_DSF,
|
||||
'atis_intent': data.ATIS_DID,
|
||||
'mrda': data.MRDA,
|
||||
'swda': data.SwDA,
|
||||
}
|
||||
|
||||
def data_save_to_file(data_file_path=None, vocab_file_path='bert-base-uncased-vocab.txt', \
|
||||
output_path=None, task_name=None, mode="train", max_seq_length=128):
|
||||
"""data save to mindrecord file."""
|
||||
MINDRECORD_FILE_PATH = output_path + task_name+"/" + task_name + "_" + mode + ".mindrecord"
|
||||
if not os.path.exists(output_path + task_name):
|
||||
os.makedirs(output_path + task_name)
|
||||
if os.path.exists(MINDRECORD_FILE_PATH):
|
||||
os.remove(MINDRECORD_FILE_PATH)
|
||||
os.remove(MINDRECORD_FILE_PATH + ".db")
|
||||
dataset_class = TASK_CLASSES[task_name]
|
||||
tokenizer = FullTokenizer(vocab_file=vocab_file_path, do_lower_case=True)
|
||||
dataset = dataset_class(data_file_path+task_name, mode=mode)
|
||||
applid_data = []
|
||||
datalist = []
|
||||
print(task_name + " " + mode + " data process begin")
|
||||
dataset_len = len(dataset)
|
||||
if args.task_name == 'atis_slot':
|
||||
batchify_fn = lambda samples, fn=Tuple(
|
||||
Pad(axis=0, pad_val=0), # input
|
||||
Pad(axis=0, pad_val=0), # mask
|
||||
Pad(axis=0, pad_val=0), # segment
|
||||
Pad(axis=0, pad_val=0, dtype='int64') # label
|
||||
): fn(samples)
|
||||
else:
|
||||
batchify_fn = lambda samples, fn=Tuple(
|
||||
Pad(axis=0, pad_val=0), # input
|
||||
Pad(axis=0, pad_val=0), # mask
|
||||
Pad(axis=0, pad_val=0), # segment
|
||||
Stack(dtype='int64') # label
|
||||
): fn(samples)
|
||||
for idx, example in enumerate(dataset):
|
||||
if idx % 1000 == 0:
|
||||
print("Reading example %d of %d" % (idx, dataset_len))
|
||||
data_example = dataset_class.convert_example(example=example, \
|
||||
tokenizer=tokenizer, max_seq_length=max_seq_length)
|
||||
applid_data.append(data_example)
|
||||
|
||||
applid_data = batchify_fn(applid_data)
|
||||
input_ids, input_mask, segment_ids, label_ids = applid_data
|
||||
|
||||
for idx in range(dataset_len):
|
||||
if idx % 1000 == 0:
|
||||
print("Processing example %d of %d" % (idx, dataset_len))
|
||||
sample = {
|
||||
"input_ids": np.array(input_ids[idx], dtype=np.int64),
|
||||
"input_mask": np.array(input_mask[idx], dtype=np.int64),
|
||||
"segment_ids": np.array(segment_ids[idx], dtype=np.int64),
|
||||
"label_ids": np.array([label_ids[idx]], dtype=np.int64),
|
||||
}
|
||||
datalist.append(sample)
|
||||
|
||||
print(task_name + " " + mode + " data process end")
|
||||
writer = FileWriter(file_name=MINDRECORD_FILE_PATH, shard_num=1)
|
||||
nlp_schema = {
|
||||
"input_ids": {"type": "int64", "shape": [-1]},
|
||||
"input_mask": {"type": "int64", "shape": [-1]},
|
||||
"segment_ids": {"type": "int64", "shape": [-1]},
|
||||
"label_ids": {"type": "int64", "shape": [-1]},
|
||||
}
|
||||
writer.add_schema(nlp_schema, "proprocessed classification dataset")
|
||||
writer.write_raw_data(datalist)
|
||||
writer.commit()
|
||||
print("write success")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="run classifier")
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train.")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The directory where the dataset will be load.")
|
||||
parser.add_argument(
|
||||
"--vocab_file_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The directory where the vocab will be load.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The directory where the mindrecord dataset file will be save.")
|
||||
parser.add_argument(
|
||||
"--max_seq_len",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization for trainng. ")
|
||||
parser.add_argument(
|
||||
"--eval_max_seq_len",
|
||||
default=None,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization for trainng. ")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.eval_max_seq_len is None:
|
||||
args.eval_max_seq_len = args.max_seq_len
|
||||
data_save_to_file(data_file_path=args.data_dir, vocab_file_path=args.vocab_file_dir, output_path=args.output_dir, \
|
||||
task_name=args.task_name, mode="train", max_seq_length=args.max_seq_len)
|
||||
data_save_to_file(data_file_path=args.data_dir, vocab_file_path=args.vocab_file_dir, output_path=args.output_dir, \
|
||||
task_name=args.task_name, mode="dev", max_seq_length=args.eval_max_seq_len)
|
||||
data_save_to_file(data_file_path=args.data_dir, vocab_file_path=args.vocab_file_dir, output_path=args.output_dir, \
|
||||
task_name=args.task_name, mode="test", max_seq_length=args.eval_max_seq_len)
|
|
@ -0,0 +1,608 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
dataset used in Bert finetune and evaluation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
import numpy as np
|
||||
|
||||
# The input data bigin with '[CLS]', using '[SEP]' split conversation content(
|
||||
# Previous part, current part, following part, etc.). If there are multiple
|
||||
# conversation in split part, using 'INNER_SEP' to further split.
|
||||
INNER_SEP = '[unused0]'
|
||||
|
||||
class Dataset():
|
||||
""" Dataset base class """
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __getitem__(self, idx):
|
||||
raise NotImplementedError("'{}' not implement in class " \
|
||||
"{}".format('__getitem__', self.__class__.__name__))
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError("'{}' not implement in class " \
|
||||
"{}".format('__len__', self.__class__.__name__))
|
||||
|
||||
def get_label_map(label_list):
|
||||
""" Create label maps """
|
||||
label_map = {}
|
||||
for (i, l) in enumerate(label_list):
|
||||
label_map[l] = i
|
||||
return label_map
|
||||
|
||||
|
||||
class UDCv1(Dataset):
|
||||
"""
|
||||
The UDCv1 dataset is using in task Dialogue Response Selection.
|
||||
The source dataset is UDCv1(Ubuntu Dialogue Corpus v1.0). See detail at
|
||||
http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/
|
||||
"""
|
||||
MAX_LEN_OF_RESPONSE = 60
|
||||
LABEL_MAP = get_label_map(['0', '1'])
|
||||
|
||||
def __init__(self, data_dir, mode='train', label_map_config=None):
|
||||
super(UDCv1, self).__init__()
|
||||
self._data_dir = data_dir
|
||||
self._mode = mode
|
||||
self.read_data()
|
||||
self.label_map = None
|
||||
if label_map_config:
|
||||
with open(label_map_config) as f:
|
||||
self.label_map = json.load(f)
|
||||
else:
|
||||
self.label_map = None
|
||||
#read data from file
|
||||
def read_data(self):
|
||||
"""read data from file"""
|
||||
if self._mode == 'train':
|
||||
data_path = os.path.join(self._data_dir, 'train.txt')
|
||||
elif self._mode == 'dev':
|
||||
data_path = os.path.join(self._data_dir, 'dev.txt-small')
|
||||
elif self._mode == 'test':
|
||||
data_path = os.path.join(self._data_dir, 'test.txt')
|
||||
self.data = []
|
||||
with open(data_path, 'r', encoding='utf8') as fin:
|
||||
for line in fin:
|
||||
if not line:
|
||||
continue
|
||||
arr = line.rstrip('\n').split('\t')
|
||||
if len(arr) < 3:
|
||||
print('Data format error: %s' % '\t'.join(arr))
|
||||
print(
|
||||
'Data row contains at least three parts: label\tconversation1\t.....\tresponse.'
|
||||
)
|
||||
continue
|
||||
label = arr[0]
|
||||
text_a = arr[1:-1]
|
||||
text_b = arr[-1]
|
||||
self.data.append([label, text_a, text_b])
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_label(cls, label):
|
||||
return cls.LABEL_MAP[label]
|
||||
|
||||
@classmethod
|
||||
def num_classes(cls):
|
||||
return len(cls.LABEL_MAP)
|
||||
@classmethod
|
||||
def convert_example(cls, example, tokenizer, max_seq_length=512):
|
||||
""" Convert a glue example into necessary features. """
|
||||
def _truncate_and_concat(text_a: List[str], text_b: str, tokenizer, max_seq_length):
|
||||
tokens_b = tokenizer.tokenize(text_b)
|
||||
tokens_b = tokens_b[:min(cls.MAX_LEN_OF_RESPONSE, len(tokens_b))]
|
||||
tokens_a = []
|
||||
for text in text_a:
|
||||
tokens_a.extend(tokenizer.tokenize(text))
|
||||
tokens_a.append(INNER_SEP)
|
||||
tokens_a = tokens_a[:-1]
|
||||
if len(tokens_a) > max_seq_length - len(tokens_b) - 3:
|
||||
tokens_a = tokens_a[len(tokens_a) - max_seq_length + len(tokens_b) + 3:]
|
||||
tokens, segment_ids = [], []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
if tokens_b:
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
return input_ids, input_mask, segment_ids
|
||||
|
||||
label, text_a, text_b = example
|
||||
label = np.array([cls.get_label(label)], dtype='int64')
|
||||
input_ids, input_mask, segment_ids = _truncate_and_concat(text_a, text_b, tokenizer, max_seq_length)
|
||||
return input_ids, input_mask, segment_ids, label
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class DSTC2(Dataset):
|
||||
"""
|
||||
The dataset DSTC2 is using in task Dialogue State Tracking.
|
||||
The source dataset is DSTC2(Dialog State Tracking Challenges 2). See detail at
|
||||
https://github.com/matthen/dstc
|
||||
"""
|
||||
LABEL_MAP = get_label_map([str(i) for i in range(217)])
|
||||
|
||||
def __init__(self, data_dir, mode='train'):
|
||||
super(DSTC2, self).__init__()
|
||||
self._data_dir = data_dir
|
||||
self._mode = mode
|
||||
self.read_data()
|
||||
|
||||
def read_data(self):
|
||||
"""read data from file"""
|
||||
def _concat_dialogues(examples):
|
||||
"""concat multi turns dialogues"""
|
||||
new_examples = []
|
||||
max_turns = 20
|
||||
example_len = len(examples)
|
||||
for i in range(example_len):
|
||||
multi_turns = examples[max(i - max_turns, 0):i + 1]
|
||||
new_qa = '\1'.join([example[0] for example in multi_turns])
|
||||
new_examples.append((new_qa.split('\1'), examples[i][1]))
|
||||
return new_examples
|
||||
|
||||
if self._mode == 'train':
|
||||
data_path = os.path.join(self._data_dir, 'train.txt')
|
||||
elif self._mode == 'dev':
|
||||
data_path = os.path.join(self._data_dir, 'dev.txt')
|
||||
elif self._mode == 'test':
|
||||
data_path = os.path.join(self._data_dir, 'test.txt')
|
||||
self.data = []
|
||||
with open(data_path, 'r', encoding='utf8') as fin:
|
||||
pre_idx = -1
|
||||
examples = []
|
||||
for line in fin:
|
||||
if not line:
|
||||
continue
|
||||
arr = line.rstrip('\n').split('\t')
|
||||
if len(arr) != 3:
|
||||
print('Data format error: %s' % '\t'.join(arr))
|
||||
print(
|
||||
'Data row should contains three parts: id\tquestion\1answer\tlabel1 label2 ...'
|
||||
)
|
||||
continue
|
||||
idx = arr[0]
|
||||
qa = arr[1]
|
||||
label_list = arr[2].split()
|
||||
if idx != pre_idx:
|
||||
if idx != 0:
|
||||
examples = _concat_dialogues(examples)
|
||||
self.data.extend(examples)
|
||||
examples = []
|
||||
pre_idx = idx
|
||||
examples.append((qa, label_list))
|
||||
if examples:
|
||||
examples = _concat_dialogues(examples)
|
||||
self.data.extend(examples)
|
||||
|
||||
@classmethod
|
||||
def get_label(cls, label):
|
||||
return cls.LABEL_MAP[label]
|
||||
|
||||
@classmethod
|
||||
def num_classes(cls):
|
||||
return len(cls.LABEL_MAP)
|
||||
|
||||
@classmethod
|
||||
def convert_example(cls, example, tokenizer, max_seq_length=512):
|
||||
""" Convert a glue example into necessary features. """
|
||||
|
||||
def _truncate_and_concat(texts: List[str], tokenizer, max_seq_length):
|
||||
tokens = []
|
||||
for text in texts:
|
||||
tokens.extend(tokenizer.tokenize(text))
|
||||
tokens.append(INNER_SEP)
|
||||
tokens = tokens[:-1]
|
||||
if len(tokens) > max_seq_length - 2:
|
||||
tokens = tokens[len(tokens) - max_seq_length + 2:]
|
||||
tokens_, segment_ids = [], []
|
||||
tokens_.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens:
|
||||
tokens_.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens_.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
tokens = tokens_
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
return input_ids, segment_ids
|
||||
|
||||
texts, labels = example
|
||||
input_ids, segment_ids = _truncate_and_concat(texts, tokenizer,
|
||||
max_seq_length)
|
||||
labels = [cls.get_label(l) for l in labels]
|
||||
label = np.zeros(cls.num_classes(), dtype='int64')
|
||||
for l in labels:
|
||||
label[l] = 1
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
return input_ids, input_mask, segment_ids, label
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
class ATIS_DSF(Dataset):
|
||||
"""
|
||||
The dataset ATIS_DSF is using in task Dialogue Slot Filling.
|
||||
The source dataset is ATIS(Airline Travel Information System). See detail at
|
||||
https://www.kaggle.com/siddhadev/ms-cntk-atis
|
||||
"""
|
||||
LABEL_MAP = get_label_map([str(i) for i in range(130)])
|
||||
|
||||
def __init__(self, data_dir, mode='train'):
|
||||
super(ATIS_DSF, self).__init__()
|
||||
self._data_dir = data_dir
|
||||
self._mode = mode
|
||||
self.read_data()
|
||||
|
||||
def read_data(self):
|
||||
"""read data from file"""
|
||||
if self._mode == 'train':
|
||||
data_path = os.path.join(self._data_dir, 'train.txt')
|
||||
elif self._mode == 'dev':
|
||||
data_path = os.path.join(self._data_dir, 'dev.txt')
|
||||
elif self._mode == 'test':
|
||||
data_path = os.path.join(self._data_dir, 'test.txt')
|
||||
self.data = []
|
||||
with open(data_path, 'r', encoding='utf8') as fin:
|
||||
for line in fin:
|
||||
if not line:
|
||||
continue
|
||||
arr = line.rstrip('\n').split('\t')
|
||||
if len(arr) != 2:
|
||||
print('Data format error: %s' % '\t'.join(arr))
|
||||
print(
|
||||
'Data row should contains two parts: conversation_content\tlabel1 label2 label3.'
|
||||
)
|
||||
continue
|
||||
text = arr[0]
|
||||
label_list = arr[1].split()
|
||||
self.data.append([text, label_list])
|
||||
|
||||
@classmethod
|
||||
def get_label(cls, label):
|
||||
return cls.LABEL_MAP[label]
|
||||
|
||||
@classmethod
|
||||
def num_classes(cls):
|
||||
return len(cls.LABEL_MAP)
|
||||
|
||||
@classmethod
|
||||
def convert_example(cls, example, tokenizer, max_seq_length=512):
|
||||
""" Convert a glue example into necessary features. """
|
||||
text, labels = example
|
||||
tokens, label_list = [], []
|
||||
words = text.split()
|
||||
assert len(words) == len(labels)
|
||||
for word, label in zip(words, labels):
|
||||
piece_words = tokenizer.tokenize(word)
|
||||
tokens.extend(piece_words)
|
||||
label = cls.get_label(label)
|
||||
label_list.extend([label] * len(piece_words))
|
||||
if len(tokens) > max_seq_length - 2:
|
||||
tokens = tokens[len(tokens) - max_seq_length + 2:]
|
||||
label_list = label_list[len(tokens) - max_seq_length + 2:]
|
||||
tokens_, segment_ids = [], []
|
||||
tokens_.append("[CLS]")
|
||||
for token in tokens:
|
||||
tokens_.append(token)
|
||||
tokens_.append("[SEP]")
|
||||
tokens = tokens_
|
||||
label_list = [0] + label_list + [0]
|
||||
segment_ids = [0] * len(tokens)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
label = np.array(label_list, dtype='int64')
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
return input_ids, input_mask, segment_ids, label
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class ATIS_DID(Dataset):
|
||||
"""
|
||||
The dataset ATIS_ID is using in task Dialogue Intent Detection.
|
||||
The source dataset is ATIS(Airline Travel Information System). See detail at
|
||||
https://www.kaggle.com/siddhadev/ms-cntk-atis
|
||||
"""
|
||||
LABEL_MAP = get_label_map([str(i) for i in range(26)])
|
||||
|
||||
def __init__(self, data_dir, mode='train'):
|
||||
super(ATIS_DID, self).__init__()
|
||||
self._data_dir = data_dir
|
||||
self._mode = mode
|
||||
self.read_data()
|
||||
|
||||
def read_data(self):
|
||||
"""read data from file"""
|
||||
if self._mode == 'train':
|
||||
data_path = os.path.join(self._data_dir, 'train.txt')
|
||||
elif self._mode == 'dev':
|
||||
data_path = os.path.join(self._data_dir, 'dev.txt')
|
||||
elif self._mode == 'test':
|
||||
data_path = os.path.join(self._data_dir, 'test.txt')
|
||||
self.data = []
|
||||
with open(data_path, 'r', encoding='utf8') as fin:
|
||||
for line in fin:
|
||||
if not line:
|
||||
continue
|
||||
arr = line.rstrip('\n').split('\t')
|
||||
if len(arr) != 2:
|
||||
print('Data format error: %s' % '\t'.join(arr))
|
||||
print(
|
||||
'Data row should contains two parts: label\tconversation_content.'
|
||||
)
|
||||
continue
|
||||
label = arr[0]
|
||||
text = arr[1]
|
||||
self.data.append([label, text])
|
||||
|
||||
@classmethod
|
||||
def get_label(cls, label):
|
||||
return cls.LABEL_MAP[label]
|
||||
|
||||
@classmethod
|
||||
def num_classes(cls):
|
||||
return len(cls.LABEL_MAP)
|
||||
|
||||
@classmethod
|
||||
def convert_example(cls, example, tokenizer, max_seq_length=512):
|
||||
""" Convert a glue example into necessary features. """
|
||||
label, text = example
|
||||
tokens = tokenizer.tokenize(text)
|
||||
if len(tokens) > max_seq_length - 2:
|
||||
tokens = tokens[len(tokens) - max_seq_length + 2:]
|
||||
tokens_, segment_ids = [], []
|
||||
tokens_.append("[CLS]")
|
||||
for token in tokens:
|
||||
tokens_.append(token)
|
||||
tokens_.append("[SEP]")
|
||||
tokens = tokens_
|
||||
segment_ids = [0] * len(tokens)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
label = np.array([cls.get_label(label)], dtype='int64')
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
return input_ids, input_mask, segment_ids, label
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def read_da_data(data_dir, mode):
|
||||
"""read data from file"""
|
||||
def _concat_dialogues(examples):
|
||||
"""concat multi turns dialogues"""
|
||||
new_examples = []
|
||||
example_len = len(examples)
|
||||
for i in range(example_len):
|
||||
label, caller, text = examples[i]
|
||||
cur_txt = "%s : %s" % (caller, text)
|
||||
pre_txt = [
|
||||
"%s : %s" % (item[1], item[2])
|
||||
for item in examples[max(0, i - 5):i]
|
||||
]
|
||||
suf_txt = [
|
||||
"%s : %s" % (item[1], item[2])
|
||||
for item in examples[i + 1:min(len(examples), i + 3)]
|
||||
]
|
||||
sample = [label, pre_txt, cur_txt, suf_txt]
|
||||
new_examples.append(sample)
|
||||
return new_examples
|
||||
|
||||
if mode == 'train':
|
||||
data_path = os.path.join(data_dir, 'train.txt')
|
||||
elif mode == 'dev':
|
||||
data_path = os.path.join(data_dir, 'dev.txt')
|
||||
elif mode == 'test':
|
||||
data_path = os.path.join(data_dir, 'test.txt')
|
||||
data = []
|
||||
with open(data_path, 'r', encoding='utf8') as fin:
|
||||
pre_idx = -1
|
||||
examples = []
|
||||
for line in fin:
|
||||
if not line:
|
||||
continue
|
||||
arr = line.rstrip('\n').split('\t')
|
||||
if len(arr) != 4:
|
||||
print('Data format error: %s' % '\t'.join(arr))
|
||||
print(
|
||||
'Data row should contains four parts: id\tlabel\tcaller\tconversation_content.'
|
||||
)
|
||||
continue
|
||||
idx, label, caller, text = arr
|
||||
if idx != pre_idx:
|
||||
if idx != 0:
|
||||
examples = _concat_dialogues(examples)
|
||||
data.extend(examples)
|
||||
examples = []
|
||||
pre_idx = idx
|
||||
examples.append((label, caller, text))
|
||||
if examples:
|
||||
examples = _concat_dialogues(examples)
|
||||
data.extend(examples)
|
||||
return data
|
||||
|
||||
|
||||
def truncate_and_concat(pre_txt: List[str],
|
||||
cur_txt: str,
|
||||
suf_txt: List[str],
|
||||
tokenizer,
|
||||
max_seq_length,
|
||||
max_len_of_cur_text):
|
||||
"""concat data"""
|
||||
cur_tokens = tokenizer.tokenize(cur_txt)
|
||||
cur_tokens = cur_tokens[:min(max_len_of_cur_text, len(cur_tokens))]
|
||||
pre_tokens = []
|
||||
for text in pre_txt:
|
||||
pre_tokens.extend(tokenizer.tokenize(text))
|
||||
pre_tokens.append(INNER_SEP)
|
||||
pre_tokens = pre_tokens[:-1]
|
||||
suf_tokens = []
|
||||
for text in suf_txt:
|
||||
suf_tokens.extend(tokenizer.tokenize(text))
|
||||
suf_tokens.append(INNER_SEP)
|
||||
suf_tokens = suf_tokens[:-1]
|
||||
if len(cur_tokens) + len(pre_tokens) + len(suf_tokens) > max_seq_length - 4:
|
||||
left_num = max_seq_length - 4 - len(cur_tokens)
|
||||
if len(pre_tokens) > len(suf_tokens):
|
||||
suf_num = int(left_num / 2)
|
||||
suf_tokens = suf_tokens[:suf_num]
|
||||
pre_num = left_num - len(suf_tokens)
|
||||
pre_tokens = pre_tokens[max(0, len(pre_tokens) - pre_num):]
|
||||
else:
|
||||
pre_num = int(left_num / 2)
|
||||
pre_tokens = pre_tokens[max(0, len(pre_tokens) - pre_num):]
|
||||
suf_num = left_num - len(pre_tokens)
|
||||
suf_tokens = suf_tokens[:suf_num]
|
||||
tokens, segment_ids = [], []
|
||||
tokens.append("[CLS]")
|
||||
for token in pre_tokens:
|
||||
tokens.append(token)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.extend([0] * len(tokens))
|
||||
for token in cur_tokens:
|
||||
tokens.append(token)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.extend([1] * (len(cur_tokens) + 1))
|
||||
if suf_tokens:
|
||||
for token in suf_tokens:
|
||||
tokens.append(token)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.extend([0] * (len(suf_tokens) + 1))
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
return input_ids, input_mask, segment_ids
|
||||
|
||||
|
||||
class MRDA(Dataset):
|
||||
"""
|
||||
The dataset MRDA is using in task Dialogue Act.
|
||||
The source dataset is MRDA(Meeting Recorder Dialogue Act). See detail at
|
||||
https://www.aclweb.org/anthology/W04-2319.pdf
|
||||
"""
|
||||
MAX_LEN_OF_CUR_TEXT = 50
|
||||
LABEL_MAP = get_label_map([str(i) for i in range(5)])
|
||||
|
||||
def __init__(self, data_dir, mode='train'):
|
||||
super(MRDA, self).__init__()
|
||||
self.data = read_da_data(data_dir, mode)
|
||||
|
||||
@classmethod
|
||||
def get_label(cls, label):
|
||||
return cls.LABEL_MAP[label]
|
||||
|
||||
@classmethod
|
||||
def num_classes(cls):
|
||||
return len(cls.LABEL_MAP)
|
||||
|
||||
@classmethod
|
||||
def convert_example(cls, example, tokenizer, max_seq_length=512):
|
||||
""" Convert a glue example into necessary features. """
|
||||
label, pre_txt, cur_txt, suf_txt = example
|
||||
label = np.array([cls.get_label(label)], dtype='int64')
|
||||
input_ids, input_mask, segment_ids = truncate_and_concat(pre_txt, cur_txt, suf_txt, \
|
||||
tokenizer, max_seq_length, cls.MAX_LEN_OF_CUR_TEXT)
|
||||
return input_ids, input_mask, segment_ids, label
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class SwDA(Dataset):
|
||||
"""
|
||||
The dataset SwDA is using in task Dialogue Act.
|
||||
The source dataset is SwDA(Switchboard Dialog Act). See detail at
|
||||
http://compprag.christopherpotts.net/swda.html
|
||||
"""
|
||||
MAX_LEN_OF_CUR_TEXT = 50
|
||||
LABEL_MAP = get_label_map([str(i) for i in range(42)])
|
||||
|
||||
def __init__(self, data_dir, mode='train'):
|
||||
super(SwDA, self).__init__()
|
||||
self.data = read_da_data(data_dir, mode)
|
||||
|
||||
@classmethod
|
||||
def get_label(cls, label):
|
||||
return cls.LABEL_MAP[label]
|
||||
|
||||
@classmethod
|
||||
def num_classes(cls):
|
||||
return len(cls.LABEL_MAP)
|
||||
|
||||
@classmethod
|
||||
def convert_example(cls, example, tokenizer, max_seq_length=512):
|
||||
""" Convert a glue example into necessary features. """
|
||||
label, pre_txt, cur_txt, suf_txt = example
|
||||
label = np.array([cls.get_label(label)], dtype='int64')
|
||||
input_ids, input_mask, segment_ids = truncate_and_concat(pre_txt, cur_txt, suf_txt, \
|
||||
tokenizer, max_seq_length, cls.MAX_LEN_OF_CUR_TEXT)
|
||||
return input_ids, input_mask, segment_ids, label
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
config settings, will be used in finetune.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
from .bert_model import BertConfig
|
||||
|
||||
optimizer_cfg = edict({
|
||||
'optimizer': 'Lamb',
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'end_learning_rate': 1e-7,
|
||||
'power': 1.0,
|
||||
'weight_decay': 1e-5,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
'eps': 1e-6,
|
||||
}),
|
||||
'Lamb': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'end_learning_rate': 1e-7,
|
||||
'power': 1.0,
|
||||
'weight_decay': 0.01,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
}),
|
||||
'Momentum': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'momentum': 0.9,
|
||||
}),
|
||||
})
|
||||
|
||||
bert_net_cfg = BertConfig(
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
)
|
||||
|
||||
bert_net_udc_cfg = BertConfig(
|
||||
seq_length=224,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
)
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
Bert finetune and evaluation model script.
|
||||
'''
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.ops import operations as P
|
||||
from .bert_model import BertModel
|
||||
|
||||
class BertCLSModel(nn.Cell):
|
||||
"""
|
||||
This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
|
||||
LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final
|
||||
logits as the results of log_softmax is proportional to that of softmax.
|
||||
"""
|
||||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False,
|
||||
assessment_method=""):
|
||||
super(BertCLSModel, self).__init__()
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
config.hidden_probs_dropout_prob = 0.0
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.log_softmax = P.LogSoftmax(axis=-1)
|
||||
self.dtype = config.dtype
|
||||
self.num_labels = num_labels
|
||||
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.assessment_method = assessment_method
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id):
|
||||
_, pooled_output, _ = \
|
||||
self.bert(input_ids, token_type_id, input_mask)
|
||||
cls = self.cast(pooled_output, self.dtype)
|
||||
cls = self.dropout(cls)
|
||||
logits = self.dense_1(cls)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
if self.assessment_method != "spearman_correlation":
|
||||
logits = self.log_softmax(logits)
|
||||
return logits
|
||||
|
||||
class BertSquadModel(nn.Cell):
|
||||
'''
|
||||
This class is responsible for SQuAD
|
||||
'''
|
||||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
|
||||
super(BertSquadModel, self).__init__()
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
config.hidden_probs_dropout_prob = 0.0
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.dense1 = nn.Dense(config.hidden_size, num_labels, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
self.num_labels = num_labels
|
||||
self.dtype = config.dtype
|
||||
self.log_softmax = P.LogSoftmax(axis=1)
|
||||
self.is_training = is_training
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id):
|
||||
sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask)
|
||||
batch_size, seq_length, hidden_size = P.Shape()(sequence_output)
|
||||
sequence = P.Reshape()(sequence_output, (-1, hidden_size))
|
||||
logits = self.dense1(sequence)
|
||||
logits = P.Cast()(logits, self.dtype)
|
||||
logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels))
|
||||
logits = self.log_softmax(logits)
|
||||
return logits
|
||||
|
||||
class BertNERModel(nn.Cell):
|
||||
"""
|
||||
This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
|
||||
The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.
|
||||
"""
|
||||
def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
|
||||
use_one_hot_embeddings=False):
|
||||
super(BertNERModel, self).__init__()
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
config.hidden_probs_dropout_prob = 0.0
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.log_softmax = P.LogSoftmax(axis=-1)
|
||||
self.dtype = config.dtype
|
||||
self.num_labels = num_labels
|
||||
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, config.hidden_size)
|
||||
self.use_crf = use_crf
|
||||
self.origin_shape = (-1, config.seq_length, self.num_labels)
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id):
|
||||
"""Return the final logits as the results of log_softmax."""
|
||||
sequence_output, _, _ = \
|
||||
self.bert(input_ids, token_type_id, input_mask)
|
||||
seq = self.dropout(sequence_output)
|
||||
seq = self.reshape(seq, self.shape)
|
||||
logits = self.dense_1(seq)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
if self.use_crf:
|
||||
return_value = self.reshape(logits, self.origin_shape)
|
||||
else:
|
||||
return_value = self.log_softmax(logits)
|
||||
return return_value
|
|
@ -0,0 +1,230 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Metric used in Bert finetune and evaluation.
|
||||
"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn.metrics.metric import Metric
|
||||
|
||||
class F1Score(Metric):
|
||||
"""
|
||||
F1-score is the harmonic mean of precision and recall. Micro-averaging is
|
||||
to create a global confusion matrix for all examples, and then calculate
|
||||
the F1-score. This class is using to evaluate the performance of Dialogue
|
||||
Slot Filling.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(F1Score, self).__init__(*args, **kwargs)
|
||||
self._name = 'F1Score'
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Resets all of the metric state.
|
||||
"""
|
||||
self.tp = {}
|
||||
self.fn = {}
|
||||
self.fp = {}
|
||||
|
||||
def update(self, logits, labels):
|
||||
"""
|
||||
Update the states based on the current mini-batch prediction results.
|
||||
|
||||
Args:
|
||||
logits (Tensor): The predicted value is a Tensor with
|
||||
shape [batch_size, seq_len, num_classes] and type float32 or
|
||||
float64.
|
||||
labels (Tensor): The ground truth value is a 2D Tensor,
|
||||
its shape is [batch_size, seq_len] and type is int64.
|
||||
"""
|
||||
output = logits.asnumpy()
|
||||
probs = output.argmax(axis=-1)
|
||||
labels = labels.asnumpy()
|
||||
assert probs.shape[0] == labels.shape[0]
|
||||
assert probs.shape[1] == labels.shape[1]
|
||||
for i in range(probs.shape[0]):
|
||||
start, end = 1, probs.shape[1]
|
||||
while end > start:
|
||||
if labels[i][end - 1] != 0:
|
||||
break
|
||||
end -= 1
|
||||
prob, label = probs[i][start:end], labels[i][start:end]
|
||||
for y_pred, y in zip(prob, label):
|
||||
if y_pred == y:
|
||||
self.tp[y] = self.tp.get(y, 0) + 1
|
||||
else:
|
||||
self.fp[y_pred] = self.fp.get(y_pred, 0) + 1
|
||||
self.fn[y] = self.fn.get(y, 0) + 1
|
||||
|
||||
def eval(self):
|
||||
"""
|
||||
Calculate the final micro F1 score.
|
||||
|
||||
Returns:
|
||||
A scaler float: results of the calculated micro F1 score.
|
||||
"""
|
||||
tp_total = sum(self.tp.values())
|
||||
fn_total = sum(self.fn.values())
|
||||
fp_total = sum(self.fp.values())
|
||||
p_total = float(tp_total) / (tp_total + fp_total)
|
||||
r_total = float(tp_total) / (tp_total + fn_total)
|
||||
if p_total + r_total == 0:
|
||||
return 0
|
||||
f1_micro = 2 * p_total * r_total / (p_total + r_total)
|
||||
return f1_micro
|
||||
|
||||
def name(self):
|
||||
"""
|
||||
Returns metric name
|
||||
"""
|
||||
return self._name
|
||||
|
||||
class JointAccuracy(Metric):
|
||||
"""
|
||||
The joint accuracy rate is used to evaluate the performance of multi-turn
|
||||
Dialogue State Tracking. For each turn, if and only if all state in
|
||||
state_list are correctly predicted, the dialog state prediction is
|
||||
considered correct. And the joint accuracy rate is equal to 1, otherwise
|
||||
it is equal to 0.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(JointAccuracy, self).__init__(*args, **kwargs)
|
||||
self._name = 'JointAccuracy'
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Resets all of the metric state.
|
||||
"""
|
||||
self.num_samples = 0
|
||||
self.correct_joint = 0.0
|
||||
|
||||
def update(self, logits, labels):
|
||||
"""
|
||||
Update the states based on the current mini-batch prediction results.
|
||||
|
||||
"""
|
||||
probs = self.sigmoid(logits)
|
||||
probs = probs.asnumpy()
|
||||
labels = labels.asnumpy()
|
||||
assert probs.shape[0] == labels.shape[0]
|
||||
assert probs.shape[1] == labels.shape[1]
|
||||
for i in range(probs.shape[0]):
|
||||
pred, refer = [], []
|
||||
for j in range(probs.shape[1]):
|
||||
if probs[i][j] >= 0.5:
|
||||
pred.append(j)
|
||||
if labels[i][j] == 1:
|
||||
refer.append(j)
|
||||
if not pred:
|
||||
pred = [np.argmax(probs[i])]
|
||||
if pred == refer:
|
||||
self.correct_joint += 1
|
||||
self.num_samples += probs.shape[0]
|
||||
|
||||
def eval(self):
|
||||
"""
|
||||
Returns the results of the calculated JointAccuracy.
|
||||
"""
|
||||
joint_acc = self.correct_joint / self.num_samples
|
||||
return joint_acc
|
||||
|
||||
def name(self):
|
||||
"""
|
||||
Returns metric name
|
||||
"""
|
||||
return self._name
|
||||
|
||||
class RecallAtK(Metric):
|
||||
"""
|
||||
Recall@K is the fraction of relevant results among the retrieved Top K
|
||||
results, using to evaluate the performance of Dialogue Response Selection.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RecallAtK, self).__init__(*args, **kwargs)
|
||||
self._name = 'Recall@K'
|
||||
self.softmax = nn.Softmax()
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Resets all of the metric state.
|
||||
"""
|
||||
self.num_sampls = 0
|
||||
self.p_at_1_in_10 = 0.0
|
||||
self.p_at_2_in_10 = 0.0
|
||||
self.p_at_5_in_10 = 0.0
|
||||
|
||||
def get_p_at_n_in_m(self, data, n, m, idx):
|
||||
"""
|
||||
calculate precision in recall n
|
||||
"""
|
||||
pos_score = data[idx][0]
|
||||
curr = data[idx:idx + m]
|
||||
curr = sorted(curr, key=lambda x: x[0], reverse=True)
|
||||
if curr[n - 1][0] <= pos_score:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def update(self, logits, labels):
|
||||
"""
|
||||
Update the states based on the current mini-batch prediction results.
|
||||
Args:
|
||||
logits (Tensor): The predicted value is a Tensor with
|
||||
shape [batch_size, 2] and type float32 or float64.
|
||||
labels (Tensor): The ground truth value is a 2D Tensor,
|
||||
its shape is [batch_size, 1] and type is int64.
|
||||
"""
|
||||
probs = self.softmax(logits)
|
||||
probs = probs.asnumpy()
|
||||
labels = labels.asnumpy()
|
||||
assert probs.shape[0] == labels.shape[0]
|
||||
data = []
|
||||
for prob, label in zip(probs, labels):
|
||||
data.append((prob[1], label))
|
||||
assert len(data) % 10 == 0
|
||||
|
||||
length = int(len(data) / 10)
|
||||
self.num_sampls += length
|
||||
for i in range(length):
|
||||
idx = i * 10
|
||||
assert data[idx][1] == 1
|
||||
self.p_at_1_in_10 += self.get_p_at_n_in_m(data, 1, 10, idx)
|
||||
self.p_at_2_in_10 += self.get_p_at_n_in_m(data, 2, 10, idx)
|
||||
self.p_at_5_in_10 += self.get_p_at_n_in_m(data, 5, 10, idx)
|
||||
|
||||
def eval(self):
|
||||
"""
|
||||
Calculate the final Recall@K.
|
||||
Returnsa list with scaler float: results of the calculated R1@K, R2@K, R5@K.
|
||||
"""
|
||||
metrics_out = [
|
||||
self.p_at_1_in_10 / self.num_sampls, self.p_at_2_in_10 /
|
||||
self.num_sampls, self.p_at_5_in_10 / self.num_sampls
|
||||
]
|
||||
return metrics_out
|
||||
|
||||
def name(self):
|
||||
"""
|
||||
Returns metric name
|
||||
"""
|
||||
return self._name
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
convert pretrain model from pdparams to mindspore ckpt
|
||||
"""
|
||||
import collections
|
||||
import os
|
||||
import paddle.fluid.dygraph as D
|
||||
from paddle import fluid
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
|
||||
def build_params_map(attention_num=12):
|
||||
"""
|
||||
build params map from paddle-paddle's BERT to mindspore's BERT
|
||||
:return:
|
||||
"""
|
||||
weight_map = collections.OrderedDict({
|
||||
'bert.embeddings.word_embeddings.weight': "bert.bert.bert_embedding_lookup.embedding_table",
|
||||
'bert.embeddings.token_type_embeddings.weight': "bert.bert.bert_embedding_postprocessor.embedding_table",
|
||||
'bert.embeddings.position_embeddings.weight': "bert.bert.bert_embedding_postprocessor.full_position_embeddings",
|
||||
'bert.embeddings.layer_norm.weight': 'bert.bert.bert_embedding_postprocessor.layernorm.gamma',
|
||||
'bert.embeddings.layer_norm.bias': 'bert.bert.bert_embedding_postprocessor.layernorm.beta',
|
||||
})
|
||||
# add attention layers
|
||||
for i in range(attention_num):
|
||||
weight_map[f'bert.encoder.layers.{i}.self_attn.q_proj.weight'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.attention.query_layer.weight'
|
||||
weight_map[f'bert.encoder.layers.{i}.self_attn.q_proj.bias'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.attention.query_layer.bias'
|
||||
weight_map[f'bert.encoder.layers.{i}.self_attn.k_proj.weight'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.attention.key_layer.weight'
|
||||
weight_map[f'bert.encoder.layers.{i}.self_attn.k_proj.bias'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.attention.key_layer.bias'
|
||||
weight_map[f'bert.encoder.layers.{i}.self_attn.v_proj.weight'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.attention.value_layer.weight'
|
||||
weight_map[f'bert.encoder.layers.{i}.self_attn.v_proj.bias'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.attention.value_layer.bias'
|
||||
weight_map[f'bert.encoder.layers.{i}.self_attn.out_proj.weight'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.output.dense.weight'
|
||||
weight_map[f'bert.encoder.layers.{i}.self_attn.out_proj.bias'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.output.dense.bias'
|
||||
weight_map[f'bert.encoder.layers.{i}.linear1.weight'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.intermediate.weight'
|
||||
weight_map[f'bert.encoder.layers.{i}.linear1.bias'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.intermediate.bias'
|
||||
weight_map[f'bert.encoder.layers.{i}.linear2.weight'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.output.dense.weight'
|
||||
weight_map[f'bert.encoder.layers.{i}.linear2.bias'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.output.dense.bias'
|
||||
weight_map[f'bert.encoder.layers.{i}.norm1.weight'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.output.layernorm.gamma'
|
||||
weight_map[f'bert.encoder.layers.{i}.norm1.bias'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.attention.output.layernorm.beta'
|
||||
weight_map[f'bert.encoder.layers.{i}.norm2.weight'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.output.layernorm.gamma'
|
||||
weight_map[f'bert.encoder.layers.{i}.norm2.bias'] = \
|
||||
f'bert.bert.bert_encoder.layers.{i}.output.layernorm.beta'
|
||||
# add pooler
|
||||
weight_map.update(
|
||||
{
|
||||
'bert.pooler.dense.weight': 'bert.bert.dense.weight',
|
||||
'bert.pooler.dense.bias': 'bert.bert.dense.bias'
|
||||
}
|
||||
)
|
||||
return weight_map
|
||||
|
||||
input_dir = '.'
|
||||
state_dict = []
|
||||
bert_weight_map = build_params_map(attention_num=12)
|
||||
with fluid.dygraph.guard():
|
||||
paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'bert-base-uncased'))
|
||||
for weight_name, weight_value in paddle_paddle_params.items():
|
||||
if 'weight' in weight_name:
|
||||
if 'encoder' in weight_name or 'pooler' in weight_name or \
|
||||
'predictions' in weight_name or 'seq_relationship' in weight_name:
|
||||
weight_value = weight_value.transpose()
|
||||
if weight_name in bert_weight_map.keys():
|
||||
state_dict.append({'name': bert_weight_map[weight_name], 'data': Tensor(weight_value)})
|
||||
print(weight_name, '->', bert_weight_map[weight_name], weight_value.shape)
|
||||
save_checkpoint(state_dict, 'base-BertCLS-111.ckpt')
|
|
@ -0,0 +1,302 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Tokenizer used in Bert finetune and evaluation.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import io
|
||||
import unicodedata
|
||||
|
||||
import six
|
||||
|
||||
class FullTokenizer():
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_by_vocab(self.vocab, tokens)
|
||||
|
||||
class BasicTokenizer():
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like the all of the other languages.
|
||||
if ((0x4E00 <= cp <= 0x9FFF) or #
|
||||
(0x3400 <= cp <= 0x4DBF) or #
|
||||
(0x20000 <= cp <= 0x2A6DF) or #
|
||||
(0x2A700 <= cp <= 0x2B73F) or #
|
||||
(0x2B740 <= cp <= 0x2B81F) or #
|
||||
(0x2B820 <= cp <= 0x2CEAF) or
|
||||
(0xF900 <= cp <= 0xFAFF) or #
|
||||
(0x2F800 <= cp <= 0x2FA1F)): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
class WordpieceTokenizer():
|
||||
"""Runs WordPiece tokenziation."""
|
||||
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer.
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
if isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
if six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
if isinstance(text, unicode):
|
||||
return text
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
fin = io.open(vocab_file, encoding="utf8")
|
||||
for num, line in enumerate(fin):
|
||||
items = convert_to_unicode(line.strip()).split("\t")
|
||||
if len(items) > 2:
|
||||
break
|
||||
token = items[0]
|
||||
index = items[1] if len(items) == 2 else num
|
||||
token = token.strip()
|
||||
vocab[token] = int(index)
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_by_vocab(vocab, items):
|
||||
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||
output = []
|
||||
for item in items:
|
||||
output.append(vocab[item])
|
||||
return output
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char in (' ', '\\t', '\\n', '\\r'):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char in ('\\t', '\\n', '\\r'):
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((33 <= cp <= 47) or (58 <= cp <= 64) or
|
||||
(91 <= cp <= 96) or (123 <= cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
|
@ -0,0 +1,284 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Functional Cells used in Bert finetune and evaluation.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import math
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as P
|
||||
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.learning_rate_schedule import (LearningRateSchedule,
|
||||
PolynomialDecayLR, WarmUpLR)
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
|
||||
def create_classification_dataset(batch_size=32, repeat_count=1,
|
||||
data_file_path=None, schema_file_path=None, do_shuffle=True):
|
||||
"""create finetune or evaluation dataset from mindrecord file"""
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
data_set = ds.MindDataset([data_file_path], \
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], shuffle=do_shuffle)
|
||||
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_mask")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_ids")
|
||||
#data_set = data_set.repeat(repeat_count)
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
return data_set
|
||||
|
||||
|
||||
class CustomWarmUpLR(LearningRateSchedule):
|
||||
"""
|
||||
apply the functions to the corresponding input fields.
|
||||
·
|
||||
"""
|
||||
def __init__(self, learning_rate, warmup_steps, max_train_steps):
|
||||
super(CustomWarmUpLR, self).__init__()
|
||||
if not isinstance(learning_rate, float):
|
||||
raise TypeError("learning_rate must be float.")
|
||||
validator.check_non_negative_float(learning_rate, "learning_rate", self.cls_name)
|
||||
validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name)
|
||||
self.warmup_steps = warmup_steps
|
||||
self.learning_rate = learning_rate
|
||||
self.max_train_steps = max_train_steps
|
||||
self.cast = P.Cast()
|
||||
def construct(self, current_step):
|
||||
if current_step < self.warmup_steps:
|
||||
warmup_percent = self.cast(current_step, mstype.float32)/ self.warmup_steps
|
||||
else:
|
||||
warmup_percent = 1 - self.cast(current_step, mstype.float32)/ self.max_train_steps
|
||||
|
||||
return self.learning_rate * warmup_percent
|
||||
|
||||
class CrossEntropyCalculation(nn.Cell):
|
||||
"""
|
||||
Cross Entropy loss
|
||||
"""
|
||||
def __init__(self, is_training=True):
|
||||
super(CrossEntropyCalculation, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reshape = P.Reshape()
|
||||
self.last_idx = (-1,)
|
||||
self.neg = P.Neg()
|
||||
self.cast = P.Cast()
|
||||
self.is_training = is_training
|
||||
|
||||
def construct(self, logits, label_ids, num_labels):
|
||||
if self.is_training:
|
||||
label_ids = self.reshape(label_ids, self.last_idx)
|
||||
one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
|
||||
per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
|
||||
loss = self.reduce_mean(per_example_loss, self.last_idx)
|
||||
return_value = self.cast(loss, mstype.float32)
|
||||
else:
|
||||
return_value = logits * 1.0
|
||||
return return_value
|
||||
|
||||
def make_directory(path: str):
|
||||
"""Make directory."""
|
||||
if path is None or not isinstance(path, str) or path.strip() == "":
|
||||
logger.error("The path(%r) is invalid type.", path)
|
||||
raise TypeError("Input path is invalid type")
|
||||
|
||||
# convert the relative paths
|
||||
path = os.path.realpath(path)
|
||||
logger.debug("The abs path is %r", path)
|
||||
|
||||
# check the path is exist and write permissions?
|
||||
if os.path.exists(path):
|
||||
real_path = path
|
||||
else:
|
||||
# All exceptions need to be caught because create directory maybe have some limit(permissions)
|
||||
logger.debug("The directory(%s) doesn't exist, will create it", path)
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
real_path = path
|
||||
except PermissionError as e:
|
||||
logger.error("No write permission on the directory(%r), error = %r", path, e)
|
||||
raise TypeError("No write permission on the directory.")
|
||||
return real_path
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss in NAN or INF terminating training.
|
||||
Note:
|
||||
if per_print_times is 0 do not print loss.
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
def __init__(self, dataset_size=-1):
|
||||
super(LossCallBack, self).__init__()
|
||||
self._dataset_size = dataset_size
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Print loss after each step
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
if self._dataset_size > 0:
|
||||
percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
|
||||
if percent == 0:
|
||||
percent = 1
|
||||
epoch_num -= 1
|
||||
print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
|
||||
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
|
||||
flush=True)
|
||||
else:
|
||||
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)), flush=True)
|
||||
|
||||
def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
|
||||
"""
|
||||
Find the ckpt finetune generated and load it into eval network.
|
||||
"""
|
||||
files = os.listdir(load_finetune_checkpoint_dir)
|
||||
pre_len = len(prefix)
|
||||
max_num = 0
|
||||
for filename in files:
|
||||
name_ext = os.path.splitext(filename)
|
||||
if name_ext[-1] != ".ckpt":
|
||||
continue
|
||||
if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
|
||||
index = filename[pre_len:].find("-")
|
||||
if index == 0 and max_num == 0:
|
||||
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
|
||||
elif index not in (0, -1):
|
||||
name_split = name_ext[-2].split('_')
|
||||
if (steps_per_epoch != int(name_split[len(name_split)-1])) \
|
||||
or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])):
|
||||
continue
|
||||
num = filename[pre_len + 1:pre_len + index]
|
||||
if int(num) > max_num:
|
||||
max_num = int(num)
|
||||
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
|
||||
return load_finetune_checkpoint_path
|
||||
|
||||
def GetAllCkptPath(save_finetune_checkpoint_path):
|
||||
files_list = os.listdir(save_finetune_checkpoint_path)
|
||||
ckpt_list = []
|
||||
for filename in files_list:
|
||||
if '.ckpt' in filename:
|
||||
load_finetune_checkpoint_dir = os.path.join(save_finetune_checkpoint_path, filename)
|
||||
ckpt_list.append(load_finetune_checkpoint_dir)
|
||||
#print(load_finetune_checkpoint_dir)
|
||||
return ckpt_list
|
||||
|
||||
class BertLearningRate(LearningRateSchedule):
|
||||
"""
|
||||
Warmup-decay learning rate for Bert network.
|
||||
"""
|
||||
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
|
||||
super(BertLearningRate, self).__init__()
|
||||
self.warmup_flag = False
|
||||
if warmup_steps > 0:
|
||||
self.warmup_flag = True
|
||||
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
|
||||
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
|
||||
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
|
||||
|
||||
self.greater = P.Greater()
|
||||
self.one = Tensor(np.array([1.0]).astype(np.float32))
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, global_step):
|
||||
decay_lr = self.decay_lr(global_step)
|
||||
if self.warmup_flag:
|
||||
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
|
||||
warmup_lr = self.warmup_lr(global_step)
|
||||
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
|
||||
else:
|
||||
lr = decay_lr
|
||||
return lr
|
||||
|
||||
def convert_labels_to_index(label_list):
|
||||
"""
|
||||
Convert label_list to indices for NER task.
|
||||
"""
|
||||
label2id = collections.OrderedDict()
|
||||
label2id["O"] = 0
|
||||
prefix = ["S_", "B_", "M_", "E_"]
|
||||
index = 0
|
||||
for label in label_list:
|
||||
for pre in prefix:
|
||||
index += 1
|
||||
sub_label = pre + label
|
||||
label2id[sub_label] = index
|
||||
return label2id
|
||||
|
||||
def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
global_step(int): current step
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_steps(int): number of warmup epochs
|
||||
total_steps(int): total epoch of training
|
||||
poly_power(int): poly learning rate power
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr = float(lr_max - lr_end) * (base ** poly_power)
|
||||
lr = lr + lr_end
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
|
||||
learning_rate = np.array(lr_each_step).astype(np.float32)
|
||||
current_step = global_step
|
||||
learning_rate = learning_rate[current_step:]
|
||||
return learning_rate
|
||||
|
||||
|
||||
def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000):
|
||||
learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0,
|
||||
total_steps=lr_total_steps, poly_power=lr_power)
|
||||
return Tensor(learning_rate)
|
||||
|
||||
|
||||
def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000):
|
||||
damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0,
|
||||
total_steps=damping_total_steps, poly_power=damping_power)
|
||||
return Tensor(damping)
|
Loading…
Reference in New Issue