forked from mindspore-Ecosystem/mindspore
!22522 add dynamic ranker network
Merge pull request !22522 from zhangyinxia/master
This commit is contained in:
commit
52cd372740
|
@ -0,0 +1,205 @@
|
|||
|
||||
# 目录
|
||||
|
||||
[View English](./README.md)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [DYR概述](#dyr概述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [训练过程](#训练过程)
|
||||
- [导出模型](#导出模型)
|
||||
- [推理过程](#推理过程)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [参数说明](#参数说明)
|
||||
- [训练性能](#训练性能)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# DYR概述
|
||||
|
||||
DYR(Dynamic Ranker)模型是一款基于对比学习的分布式语义排序框架,它在2021年由华为泊松实验室提出,并联合分布式并行计算实验室进行开源发布。
|
||||
|
||||
# 模型架构
|
||||
|
||||
DYR模型主要由两个模块构成,一是正负样本块的横纵分布式切分模块;二是负样本多级压缩模块。通过这两个模块实现了高吞吐量和模型精度。
|
||||
|
||||
# 数据集
|
||||
|
||||
- 生成预训练数据集
|
||||
- 将需要训练和推理的数据集进行处理并转换为MindRecord格式。
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend处理器)
|
||||
- 准备Ascend处理器搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||
- 更多关于Mindspore的信息,请查看以下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
从官网下载安装MindSpore之后,您可以按照如下步骤在ModelArts上进行训练和评估,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
## 训练过程
|
||||
|
||||
- 在ModelArts上使用8卡训练
|
||||
|
||||
```python
|
||||
# (1) 上传你的代码到 s3 桶上
|
||||
# (2) 在ModelArts上创建训练任务
|
||||
# (3) 选择代码目录 /{path}/DYR
|
||||
# (4) 选择启动文件 /{path}/DYR/run_dyr.py
|
||||
# (5) 执行a或b
|
||||
# a. 在 /{path}/DYR/dyr_config.yaml 文件中设置参数
|
||||
# b. 设置 ”enable_modelarts=True“
|
||||
# c. 添加其它参数,其它参数配置可以参考参数说明文档
|
||||
# (6) 上传你的 数据 到 s3 桶上
|
||||
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径
|
||||
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
|
||||
# (9) 在网页上的’资源池选择‘项目下, 选择8卡规格的资源
|
||||
# (10) 创建训练作业
|
||||
# 训练结束后会在'训练输出文件路径'下保存训练的权重
|
||||
```
|
||||
|
||||
- 在ModelArts上运行过程中,您可以在ModelArts上查看训练日志,得到如下损失值:
|
||||
|
||||
```text
|
||||
# grep "epoch" *.log
|
||||
epoch: 1, current epoch percent: 1.000, step: 83002, outputs are (Tensor(shape=[], dtype=Float32, value= 2.19216), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 2048))
|
||||
epoch: 1, current epoch percent: 1.000, step: 83002, outputs are (Tensor(shape=[], dtype=Float32, value= 4.4673), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 2048))
|
||||
...
|
||||
```
|
||||
|
||||
## 导出模型
|
||||
|
||||
- 在ModelArts上导出模型
|
||||
设置推理验证集路径`save_finetune_checkpoint_path`,参数`ckpt_file` 是必需的,`EXPORT_FORMAT` 必须在 ["AIR", "MINDIR"]中进行选择。
|
||||
完成训练后,你将在{save_finetune_checkpoint_path}下看到 'dyr*.ckpt'文件
|
||||
|
||||
## 推理过程
|
||||
|
||||
- 在ModelArts上进行推理
|
||||
设置推理验证集路径`eval_data_file_path`和`do_eval=true`,ModelArts上会执行推理操作。
|
||||
完成推理后,可在ModelArts上日志中看到最终精度结果。
|
||||
|
||||
```eval log
|
||||
mrr@100:0.4306179881095886, mrr@10:0.42366212606430054
|
||||
```
|
||||
|
||||
## 脚本说明
|
||||
|
||||
```shell
|
||||
.
|
||||
└─DYR
|
||||
├─README.md
|
||||
├─README_CN.md
|
||||
├─src
|
||||
├─model_utils
|
||||
├── config.py # 解析 *.yaml参数配置文件
|
||||
├── devcie_adapter.py # 区分本地/ModelArts训练
|
||||
├── local_adapter.py # 本地训练获取相关环境变量
|
||||
└── moxing_adapter.py # ModelArts训练获取相关环境变量、交换数据
|
||||
├─dynamic_ranker.py # 网络骨干编码
|
||||
├─bert_model.py # 网络骨干编码
|
||||
├─dataset.py # 数据预处理
|
||||
├─utils.py # util函数
|
||||
├─dyr_config.yaml # 训练评估参数配置
|
||||
└─run_dyr.py # dyr任务的训练和评估网络
|
||||
```
|
||||
|
||||
## 参数说明
|
||||
|
||||
- dyr_config.yaml参数详解
|
||||
|
||||
```text
|
||||
数据集和网络参数(训练/评估):
|
||||
dyr_version dyr版本,支持"dyr_base"和"dyr",默认为"dyr_base"
|
||||
do_train 是否执行训练操作,默认执行
|
||||
do_eval 是否执行训练操作,默认执行
|
||||
device_id 执行机器device,默认为0
|
||||
epoch_num 训练epoch的个数,默认为1
|
||||
group_size 选择正负样本个数,默认为8
|
||||
group_num 选择分组个数,默认为1
|
||||
train_data_shuffle 训练数据集是否执行shuffle,默认为true
|
||||
eval_data_shuffle 推理数据集是否执行shuffle,默认为false
|
||||
train_batch_size 输入训练数据集的批次大小,默认为1
|
||||
eval_batch_size 输入推理数据集的批次大小,默认为1
|
||||
save_finetune_checkpoint_path 保存训练checkpoint路径
|
||||
load_pretrain_checkpoint_path 加载预训练模型路径
|
||||
load_finetune_checkpoint_path 加载推理模型路径
|
||||
train_data_file_path 训练数据集路径
|
||||
eval_data_file_path 推理数据集路径
|
||||
eval_ids_path 推理数据集对应ids文件路径
|
||||
eval_qrels_path 推理数据集对应qrels文件路径
|
||||
save_score_path 保存结果文件路径
|
||||
schema_file_path 数据预处理配置文件路径
|
||||
optimizer 网络中采用的优化器,可选项为AdamWerigtDecayDynamicLR、Lamb、或Momentum,默认为Lamb
|
||||
seq_length 输入序列的长度,默认为512
|
||||
vocab_size 各内嵌向量大小,需与所采用的数据集相同。默认为30522
|
||||
hidden_size BERT的encoder层数,默认为768
|
||||
num_hidden_layers 隐藏层数,默认为12
|
||||
num_attention_heads 注意头的数量,默认为12
|
||||
intermediate_size 中间层数,默认为3072
|
||||
hidden_act 所采用的激活函数,默认为gelu
|
||||
hidden_dropout_prob BERT输出的随机失活可能性,默认为0.1
|
||||
attention_probs_dropout_prob BERT注意的随机失活可能性,默认为0.1
|
||||
max_position_embeddings 序列最大长度,默认为512
|
||||
type_vocab_size 标记类型的词汇表大小,默认为16
|
||||
initializer_range TruncatedNormal的初始值,默认为0.02
|
||||
use_relative_positions 是否采用相对位置,可选项为true或false,默认为False
|
||||
dtype 输入的数据类型,可选项为mstype.float16或mstype.float32,默认为mstype.float32
|
||||
compute_type Bert Transformer的计算类型,可选项为mstype.float16或mstype.float32,默认为mstype.float16
|
||||
|
||||
Parameters for optimizer:
|
||||
AdamWeightDecay:
|
||||
decay_steps 学习率开始衰减的步数
|
||||
learning_rate 学习率
|
||||
end_learning_rate 结束学习率,取值需为正数
|
||||
power 幂
|
||||
warmup_steps 热身学习率步数
|
||||
weight_decay 权重衰减
|
||||
eps 增加分母,提高小数稳定性
|
||||
Lamb:
|
||||
decay_steps 学习率开始衰减的步数
|
||||
learning_rate 学习率
|
||||
end_learning_rate 结束学习率
|
||||
power 幂
|
||||
warmup_steps 热身学习率步数
|
||||
weight_decay 权重衰减
|
||||
Momentum:
|
||||
learning_rate 学习率
|
||||
momentum 平均移动动量
|
||||
```
|
||||
|
||||
# 训练性能
|
||||
|
||||
| 参数 | Ascend |
|
||||
| -------------------------- | ---------------------------------------------------------- |
|
||||
| 模型版本 | dyr_base |
|
||||
| 资源 | Ascend 910;CPU 2.60GHz,192核;内存 755GB;系统 Euler2.8 |
|
||||
| 上传日期 | 2021-08-27 |
|
||||
| MindSpore版本 | 1.3.0 |
|
||||
| 数据集 | |
|
||||
| 训练参数 | dyr_config.yaml |
|
||||
| 优化器 | Lamb |
|
||||
| 损失函数 | SoftmaxCrossEntropyWithLogits |
|
||||
| 输出 | 概率 |
|
||||
| 轮次 | 2 |
|
||||
| Batch_size | 72*8 |
|
||||
| 损失 | 1.7 |
|
||||
| 速度 | 435毫秒/步 |
|
||||
| 总时长 | 10小时 |
|
||||
| 参数(M) | 110 |
|
||||
| 微调检查点 | 1.2G(.ckpt文件) |
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,106 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
description: "run_dyr"
|
||||
# DYR version: "dry_base" or "dyr"
|
||||
dyr_version: "dry_base"
|
||||
do_train: "true"
|
||||
do_eval: "true"
|
||||
device_id: 0
|
||||
epoch_num: 1
|
||||
group_size: 8
|
||||
group_num: 1
|
||||
train_data_shuffle: "true"
|
||||
eval_data_shuffle: "false"
|
||||
train_batch_size: 1
|
||||
eval_batch_size: 1
|
||||
save_finetune_checkpoint_path: "./classifier_finetune/ckpt/"
|
||||
load_pretrain_checkpoint_path: ""
|
||||
load_finetune_checkpoint_path: ""
|
||||
train_data_file_path: ""
|
||||
eval_data_file_path: ""
|
||||
eval_ids_path: "ids.tsv"
|
||||
eval_qrels_path: "msmarco-docdev-qrels.tsv"
|
||||
save_score_path: "score.txt"
|
||||
schema_file_path: ""
|
||||
|
||||
optimizer_cfg:
|
||||
optimizer: 'Lamb'
|
||||
AdamWeightDecay:
|
||||
learning_rate: 0.00001 # 1e-5
|
||||
end_learning_rate: 0.0000000001 # 1e-10
|
||||
power: 1.0
|
||||
weight_decay: 0.01 # 1e-5
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
eps: 0.000001 # 1e-6
|
||||
Lamb:
|
||||
learning_rate: 0.00001 # 1e-5,
|
||||
end_learning_rate: 0.0000001 # 1e-7
|
||||
power: 1.0
|
||||
weight_decay: 0.01
|
||||
decay_filter: ['layernorm', 'bias']
|
||||
Momentum:
|
||||
learning_rate: 0.00002 # 2e-5
|
||||
momentum: 0.9
|
||||
|
||||
dyr_net_cfg:
|
||||
seq_length: 512
|
||||
vocab_size: 30522
|
||||
hidden_size: 768
|
||||
num_hidden_layers: 12
|
||||
num_attention_heads: 12
|
||||
intermediate_size: 3072
|
||||
hidden_act: "gelu"
|
||||
hidden_dropout_prob: 0.1
|
||||
attention_probs_dropout_prob: 0.1
|
||||
max_position_embeddings: 512
|
||||
type_vocab_size: 2
|
||||
initializer_range: 0.02
|
||||
use_relative_positions: False
|
||||
dtype: mstype.float32
|
||||
compute_type: mstype.float16
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
dyr_version: "dyr version"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend or GPU, now only supports Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
do_train: "Enable train, default is false"
|
||||
do_eval: "Enable eval, default is false"
|
||||
device_id: "Device id, default is 0."
|
||||
epoch_num: "Epoch number, default is 3."
|
||||
group_size: "Sample number in one block."
|
||||
group_num: "Cards are divided into groups."
|
||||
train_data_shuffle: "Enable train data shuffle, default is true"
|
||||
eval_data_shuffle: "Enable eval data shuffle, default is false"
|
||||
train_batch_size: "Train batch size, default is 32"
|
||||
eval_batch_size: "Eval batch size, default is 1"
|
||||
save_finetune_checkpoint_path: "Save checkpoint path"
|
||||
load_pretrain_checkpoint_path: "Load checkpoint file path"
|
||||
load_finetune_checkpoint_path: "Load checkpoint file path"
|
||||
train_data_file_path: "Data path, it is better to use absolute path"
|
||||
eval_data_file_path: "Data path, it is better to use absolute path"
|
||||
eval_ids_path: "Ids path, it is better to use absolute path"
|
||||
eval_qrels_path: "Qrels path, it is better to use absolute path"
|
||||
save_score_path: "Score path, it is better to use absolute path"
|
||||
schema_file_path: "Schema path, it is better to use absolute path"
|
||||
---
|
||||
device_target: ['Ascend']
|
||||
do_train: ["true", "false"]
|
||||
do_eval: ["true", "false"]
|
||||
train_data_shuffle: ["true", "false"]
|
||||
eval_data_shuffle: ["true", "false"]
|
|
@ -0,0 +1,215 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
dynamic ranker train and evaluation script.
|
||||
'''
|
||||
|
||||
import os
|
||||
import mindspore.communication.management as D
|
||||
from mindspore import context
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.common import set_seed
|
||||
from src.dynamic_ranker import DynamicRankerPredict, DynamicRankerFinetuneCell, DynamicRankerBase, DynamicRanker
|
||||
from src.dataset import create_dyr_base_dataset, create_dyr_dataset_predict, create_dyr_dataset
|
||||
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, DynamicRankerLearningRate, MRR
|
||||
from src.model_utils.config import config as args_opt, optimizer_cfg, dyr_net_cfg
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
|
||||
_cur_dir = os.getcwd()
|
||||
|
||||
|
||||
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1, prefix="dyr"):
|
||||
""" do train """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
|
||||
steps_per_epoch = dataset.get_dataset_size()
|
||||
# optimizer
|
||||
if optimizer_cfg.optimizer == 'AdamWeightDecay':
|
||||
lr_schedule = DynamicRankerLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
|
||||
end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
|
||||
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
|
||||
decay_steps=steps_per_epoch * epoch_num,
|
||||
power=optimizer_cfg.AdamWeightDecay.power)
|
||||
params = network.trainable_params()
|
||||
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
|
||||
other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0}]
|
||||
|
||||
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
|
||||
elif optimizer_cfg.optimizer == 'Lamb':
|
||||
lr_schedule = DynamicRankerLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
|
||||
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
|
||||
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
|
||||
decay_steps=steps_per_epoch * epoch_num,
|
||||
power=optimizer_cfg.Lamb.power)
|
||||
optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
|
||||
elif optimizer_cfg.optimizer == 'Momentum':
|
||||
optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
|
||||
momentum=optimizer_cfg.Momentum.momentum)
|
||||
else:
|
||||
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
|
||||
|
||||
# load checkpoint into network
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=prefix,
|
||||
directory=None if save_checkpoint_path == "" else save_checkpoint_path,
|
||||
config=ckpt_config)
|
||||
param_dict = load_checkpoint(load_checkpoint_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**16, scale_factor=2, scale_window=1000)
|
||||
netwithgrads = DynamicRankerFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
model = Model(netwithgrads)
|
||||
callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
|
||||
model.train(epoch_num, dataset, callbacks=callbacks)
|
||||
|
||||
def do_predict(rank_id=0, dataset=None, network=None, load_checkpoint_path="",
|
||||
eval_ids_path="", eval_qrels_path="", save_score_path=""):
|
||||
""" do eval """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
|
||||
net_for_pretraining = network(dyr_net_cfg, False, dropout_prob=0.0)
|
||||
net_for_pretraining.set_train(False)
|
||||
param_dict = load_checkpoint(load_checkpoint_path)
|
||||
load_param_into_net(net_for_pretraining, param_dict)
|
||||
model = Model(net_for_pretraining)
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids"]
|
||||
loss = []
|
||||
for data in dataset.create_dict_iterator(num_epochs=1):
|
||||
input_data = []
|
||||
for i in columns_list:
|
||||
input_data.append(data[i])
|
||||
input_ids, input_mask, token_type_id = input_data
|
||||
logits = model.predict(input_ids, input_mask, token_type_id)
|
||||
print(logits)
|
||||
logits = logits[0][0].asnumpy()
|
||||
loss.append(logits)
|
||||
pred_qids = []
|
||||
pred_pids = []
|
||||
with open(eval_ids_path) as f:
|
||||
for l in f:
|
||||
q, p = l.split()
|
||||
pred_qids.append(q)
|
||||
pred_pids.append(p)
|
||||
if len(pred_qids) != len(loss):
|
||||
raise ValueError("len(pred_qids) != len(loss)!")
|
||||
|
||||
with open(save_score_path, "w") as writer:
|
||||
for qid, pid, score in zip(pred_qids, pred_pids, loss):
|
||||
writer.write(f'{qid}\t{pid}\t{score}\n')
|
||||
|
||||
mrr = MRR()
|
||||
mrr.accuracy(qrels_path=eval_qrels_path, scores_path=save_score_path)
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
args_opt.device_id = get_device_id()
|
||||
_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
args_opt.load_pretrain_checkpoint_path = os.path.join(_file_dir, args_opt.load_pretrain_checkpoint_path)
|
||||
args_opt.load_finetune_checkpoint_path = os.path.join(args_opt.output_path, args_opt.load_finetune_checkpoint_path)
|
||||
args_opt.save_finetune_checkpoint_path = os.path.join(args_opt.output_path, args_opt.save_finetune_checkpoint_path)
|
||||
if args_opt.schema_file_path:
|
||||
args_opt.schema_file_path = os.path.join(args_opt.data_path, args_opt.schema_file_path)
|
||||
args_opt.train_data_file_path = os.path.join(args_opt.data_path, args_opt.train_data_file_path)
|
||||
args_opt.eval_data_file_path = os.path.join(args_opt.data_path, args_opt.eval_data_file_path)
|
||||
args_opt.save_score_path = os.path.join(args_opt.output_path, args_opt.save_score_path)
|
||||
args_opt.eval_ids_path = os.path.join(args_opt.data_path, args_opt.eval_ids_path)
|
||||
args_opt.eval_qrels_path = os.path.join(args_opt.data_path, args_opt.eval_qrels_path)
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_dyr():
|
||||
"""run dyr task"""
|
||||
set_seed(1234)
|
||||
if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
|
||||
raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
|
||||
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
|
||||
raise ValueError("'train_data_file_path' must be set when do finetune task")
|
||||
if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
|
||||
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
|
||||
epoch_num = args_opt.epoch_num
|
||||
load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
|
||||
save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
|
||||
load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path
|
||||
eval_ids_path = args_opt.eval_ids_path
|
||||
eval_qrels_path = args_opt.eval_qrels_path
|
||||
save_score_path = args_opt.save_score_path
|
||||
target = args_opt.device_target
|
||||
if target == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
||||
device_id=args_opt.device_id, save_graphs=False)
|
||||
else:
|
||||
raise Exception("Target error, Ascend is supported.")
|
||||
D.init()
|
||||
device_num = D.get_group_size()
|
||||
rank_id = D.get_rank()
|
||||
save_finetune_checkpoint_path = os.path.join(args_opt.save_finetune_checkpoint_path, 'ckpt_' + str(rank_id))
|
||||
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=device_num)
|
||||
|
||||
create_dataset = create_dyr_base_dataset
|
||||
dyr_network = DynamicRankerBase
|
||||
if args_opt.dyr_version.lower() == "dyr":
|
||||
create_dataset = create_dyr_dataset
|
||||
dyr_network = DynamicRanker
|
||||
|
||||
netwithloss = dyr_network(dyr_net_cfg, True, dropout_prob=0.1,
|
||||
batch_size=args_opt.train_batch_size,
|
||||
group_size=args_opt.group_size,
|
||||
group_num=args_opt.group_num,
|
||||
rank_id=rank_id,
|
||||
device_num=device_num)
|
||||
|
||||
if args_opt.do_train.lower() == "true":
|
||||
ds = create_dataset(device_num, rank_id, batch_size=args_opt.train_batch_size, repeat_count=1,
|
||||
data_file_path=args_opt.train_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path,
|
||||
do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
|
||||
group_size=args_opt.group_size, group_num=args_opt.group_num)
|
||||
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path,
|
||||
epoch_num, args_opt.dyr_version.lower())
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
if save_finetune_checkpoint_path == "":
|
||||
load_finetune_checkpoint_dir = _cur_dir
|
||||
else:
|
||||
load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path)
|
||||
load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,
|
||||
ds.get_dataset_size(), epoch_num,
|
||||
args_opt.dyr_version.lower())
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
if rank_id == 0:
|
||||
ds = create_dyr_dataset_predict(batch_size=args_opt.eval_batch_size, repeat_count=1,
|
||||
data_file_path=args_opt.eval_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path,
|
||||
do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
|
||||
do_predict(rank_id, ds, DynamicRankerPredict, load_finetune_checkpoint_path,
|
||||
eval_ids_path, eval_qrels_path, save_score_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_dyr()
|
|
@ -0,0 +1,874 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bert model."""
|
||||
|
||||
import math
|
||||
import copy
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
class BertConfig:
|
||||
"""
|
||||
Configuration for `BertModel`.
|
||||
|
||||
Args:
|
||||
seq_length (int): Length of input sequence. Default: 128.
|
||||
vocab_size (int): The shape of each embedding vector. Default: 32000.
|
||||
hidden_size (int): Size of the bert encoder layers. Default: 768.
|
||||
num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
|
||||
cell. Default: 12.
|
||||
num_attention_heads (int): Number of attention heads in the BertTransformer
|
||||
encoder cell. Default: 12.
|
||||
intermediate_size (int): Size of intermediate layer in the BertTransformer
|
||||
encoder cell. Default: 3072.
|
||||
hidden_act (str): Activation function used in the BertTransformer encoder
|
||||
cell. Default: "gelu".
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
max_position_embeddings (int): Maximum length of sequences used in this
|
||||
model. Default: 512.
|
||||
type_vocab_size (int): Size of token type vocab. Default: 16.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
seq_length=128,
|
||||
vocab_size=32000,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float32):
|
||||
self.seq_length = seq_length
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.dtype = dtype
|
||||
self.compute_type = compute_type
|
||||
|
||||
|
||||
class EmbeddingLookup(nn.Cell):
|
||||
"""
|
||||
A embeddings lookup table with a fixed dictionary and size.
|
||||
|
||||
Args:
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
|
||||
each embedding vector.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
"""
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
embedding_size,
|
||||
embedding_shape,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02):
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
[vocab_size, embedding_size]))
|
||||
self.expand = P.ExpandDims()
|
||||
self.shape_flat = (-1,)
|
||||
self.gather = P.Gather()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = tuple(embedding_shape)
|
||||
|
||||
def construct(self, input_ids):
|
||||
"""Get output and embeddings lookup table"""
|
||||
extended_ids = self.expand(input_ids, -1)
|
||||
flat_ids = self.reshape(extended_ids, self.shape_flat)
|
||||
if self.use_one_hot_embeddings:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
output_for_reshape = self.array_mul(
|
||||
one_hot_ids, self.embedding_table)
|
||||
else:
|
||||
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
|
||||
output = self.reshape(output_for_reshape, self.shape)
|
||||
return output, self.embedding_table
|
||||
|
||||
|
||||
class EmbeddingPostprocessor(nn.Cell):
|
||||
"""
|
||||
Postprocessors apply positional and token type embeddings to word embeddings.
|
||||
|
||||
Args:
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
|
||||
each embedding vector.
|
||||
use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
|
||||
token_type_vocab_size (int): Size of token type vocab. Default: 16.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
max_position_embeddings (int): Maximum length of sequences used in this
|
||||
model. Default: 512.
|
||||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
"""
|
||||
def __init__(self,
|
||||
embedding_size,
|
||||
embedding_shape,
|
||||
use_relative_positions=False,
|
||||
use_token_type=False,
|
||||
token_type_vocab_size=16,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
max_position_embeddings=512,
|
||||
dropout_prob=0.1):
|
||||
super(EmbeddingPostprocessor, self).__init__()
|
||||
self.use_token_type = use_token_type
|
||||
self.token_type_vocab_size = token_type_vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.token_type_embedding = nn.Embedding(
|
||||
vocab_size=token_type_vocab_size,
|
||||
embedding_size=embedding_size,
|
||||
use_one_hot=use_one_hot_embeddings)
|
||||
self.shape_flat = (-1,)
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.1, mstype.float32)
|
||||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = tuple(embedding_shape)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.gather = P.Gather()
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.slice = P.StridedSlice()
|
||||
_, seq, _ = self.shape
|
||||
self.full_position_embedding = nn.Embedding(
|
||||
vocab_size=max_position_embeddings,
|
||||
embedding_size=embedding_size,
|
||||
use_one_hot=False)
|
||||
self.layernorm = nn.LayerNorm((embedding_size,))
|
||||
self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
|
||||
self.add = P.Add()
|
||||
|
||||
def construct(self, token_type_ids, word_embeddings):
|
||||
"""Postprocessors apply positional and token type embeddings to word embeddings."""
|
||||
output = word_embeddings
|
||||
if self.use_token_type:
|
||||
token_type_embeddings = self.token_type_embedding(token_type_ids)
|
||||
output = self.add(output, token_type_embeddings)
|
||||
if not self.use_relative_positions:
|
||||
position_embeddings = self.full_position_embedding(self.position_ids)
|
||||
output = self.add(output, position_embeddings)
|
||||
output = self.layernorm(output)
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
|
||||
class BertOutput(nn.Cell):
|
||||
"""
|
||||
Apply a linear computation to hidden status and a residual computation to input.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels.
|
||||
out_channels (int): Output channels.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
initializer_range=0.02,
|
||||
dropout_prob=0.1,
|
||||
compute_type=mstype.float32):
|
||||
super(BertOutput, self).__init__()
|
||||
self.dense = nn.Dense(in_channels, out_channels,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.dropout_prob = dropout_prob
|
||||
self.add = P.Add()
|
||||
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, hidden_status, input_tensor):
|
||||
output = self.dense(hidden_status)
|
||||
output = self.dropout(output)
|
||||
output = self.add(input_tensor, output)
|
||||
output = self.layernorm(output)
|
||||
return output
|
||||
|
||||
|
||||
class RelaPosMatrixGenerator(nn.Cell):
|
||||
"""
|
||||
Generates matrix of relative positions between inputs.
|
||||
|
||||
Args:
|
||||
length (int): Length of one dim for the matrix to be generated.
|
||||
max_relative_position (int): Max value of relative position.
|
||||
"""
|
||||
def __init__(self, length, max_relative_position):
|
||||
super(RelaPosMatrixGenerator, self).__init__()
|
||||
self._length = length
|
||||
self._max_relative_position = max_relative_position
|
||||
self._min_relative_position = -max_relative_position
|
||||
self.range_length = -length + 1
|
||||
|
||||
self.tile = P.Tile()
|
||||
self.range_mat = P.Reshape()
|
||||
self.sub = P.Sub()
|
||||
self.expanddims = P.ExpandDims()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self):
|
||||
"""Generates matrix of relative positions between inputs."""
|
||||
range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
|
||||
range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
|
||||
tile_row_out = self.tile(range_vec_row_out, (self._length,))
|
||||
tile_col_out = self.tile(range_vec_col_out, (1, self._length))
|
||||
range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
|
||||
transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
|
||||
distance_mat = self.sub(range_mat_out, transpose_out)
|
||||
|
||||
distance_mat_clipped = C.clip_by_value(distance_mat,
|
||||
self._min_relative_position,
|
||||
self._max_relative_position)
|
||||
|
||||
# Shift values to be >=0. Each integer still uniquely identifies a
|
||||
# relative position difference.
|
||||
final_mat = distance_mat_clipped + self._max_relative_position
|
||||
return final_mat
|
||||
|
||||
|
||||
class RelaPosEmbeddingsGenerator(nn.Cell):
|
||||
"""
|
||||
Generates tensor of size [length, length, depth].
|
||||
|
||||
Args:
|
||||
length (int): Length of one dim for the matrix to be generated.
|
||||
depth (int): Size of each attention head.
|
||||
max_relative_position (int): Maxmum value of relative position.
|
||||
initializer_range (float): Initialization value of TruncatedNormal.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
def __init__(self,
|
||||
length,
|
||||
depth,
|
||||
max_relative_position,
|
||||
initializer_range,
|
||||
use_one_hot_embeddings=False):
|
||||
super(RelaPosEmbeddingsGenerator, self).__init__()
|
||||
self.depth = depth
|
||||
self.vocab_size = max_relative_position * 2 + 1
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
|
||||
self.embeddings_table = Parameter(
|
||||
initializer(TruncatedNormal(initializer_range),
|
||||
[self.vocab_size, self.depth]))
|
||||
|
||||
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
|
||||
max_relative_position=max_relative_position)
|
||||
self.reshape = P.Reshape()
|
||||
self.one_hot = nn.OneHot(depth=self.vocab_size)
|
||||
self.shape = P.Shape()
|
||||
self.gather = P.Gather() # index_select
|
||||
self.matmul = P.BatchMatMul()
|
||||
|
||||
def construct(self):
|
||||
"""Generate embedding for each relative position of dimension depth."""
|
||||
relative_positions_matrix_out = self.relative_positions_matrix()
|
||||
|
||||
if self.use_one_hot_embeddings:
|
||||
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
|
||||
one_hot_relative_positions_matrix = self.one_hot(
|
||||
flat_relative_positions_matrix)
|
||||
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
|
||||
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
|
||||
embeddings = self.reshape(embeddings, my_shape)
|
||||
else:
|
||||
embeddings = self.gather(self.embeddings_table,
|
||||
relative_positions_matrix_out, 0)
|
||||
return embeddings
|
||||
|
||||
|
||||
class SaturateCast(nn.Cell):
|
||||
"""
|
||||
Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
|
||||
the danger that the value will overflow or underflow.
|
||||
|
||||
Args:
|
||||
src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
|
||||
dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
|
||||
super(SaturateCast, self).__init__()
|
||||
np_type = mstype.dtype_to_nptype(dst_type)
|
||||
|
||||
self.tensor_min_type = float(np.finfo(np_type).min)
|
||||
self.tensor_max_type = float(np.finfo(np_type).max)
|
||||
|
||||
self.min_op = P.Minimum()
|
||||
self.max_op = P.Maximum()
|
||||
self.cast = P.Cast()
|
||||
self.dst_type = dst_type
|
||||
|
||||
def construct(self, x):
|
||||
out = self.max_op(x, self.tensor_min_type)
|
||||
out = self.min_op(out, self.tensor_max_type)
|
||||
return self.cast(out, self.dst_type)
|
||||
|
||||
|
||||
class BertAttention(nn.Cell):
|
||||
"""
|
||||
Apply multi-headed attention from "from_tensor" to "to_tensor".
|
||||
|
||||
Args:
|
||||
from_tensor_width (int): Size of last dim of from_tensor.
|
||||
to_tensor_width (int): Size of last dim of to_tensor.
|
||||
from_seq_length (int): Length of from_tensor sequence.
|
||||
to_seq_length (int): Length of to_tensor sequence.
|
||||
num_attention_heads (int): Number of attention heads. Default: 1.
|
||||
size_per_head (int): Size of each attention head. Default: 512.
|
||||
query_act (str): Activation function for the query transform. Default: None.
|
||||
key_act (str): Activation function for the key transform. Default: None.
|
||||
value_act (str): Activation function for the value transform. Default: None.
|
||||
has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.0.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
|
||||
tensor. Default: False.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
from_tensor_width,
|
||||
to_tensor_width,
|
||||
from_seq_length,
|
||||
to_seq_length,
|
||||
num_attention_heads=1,
|
||||
size_per_head=512,
|
||||
query_act=None,
|
||||
key_act=None,
|
||||
value_act=None,
|
||||
has_attention_mask=False,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
do_return_2d_tensor=False,
|
||||
use_relative_positions=False,
|
||||
compute_type=mstype.float32):
|
||||
|
||||
super(BertAttention, self).__init__()
|
||||
self.from_seq_length = from_seq_length
|
||||
self.to_seq_length = to_seq_length
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.size_per_head = size_per_head
|
||||
self.has_attention_mask = has_attention_mask
|
||||
self.use_relative_positions = use_relative_positions
|
||||
|
||||
self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
|
||||
self.reshape = P.Reshape()
|
||||
self.shape_from_2d = (-1, from_tensor_width)
|
||||
self.shape_to_2d = (-1, to_tensor_width)
|
||||
weight = TruncatedNormal(initializer_range)
|
||||
units = num_attention_heads * size_per_head
|
||||
self.query_layer = nn.Dense(from_tensor_width,
|
||||
units,
|
||||
activation=query_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
self.key_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=key_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
self.value_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=value_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
|
||||
self.shape_from = (-1, from_seq_length, num_attention_heads, size_per_head)
|
||||
self.shape_to = (-1, to_seq_length, num_attention_heads, size_per_head)
|
||||
|
||||
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
|
||||
self.multiply = P.Mul()
|
||||
self.transpose = P.Transpose()
|
||||
self.trans_shape = (0, 2, 1, 3)
|
||||
self.trans_shape_relative = (2, 0, 1, 3)
|
||||
self.trans_shape_position = (1, 2, 0, 3)
|
||||
self.multiply_data = -10000.0
|
||||
self.matmul = P.BatchMatMul()
|
||||
|
||||
self.softmax = nn.Softmax()
|
||||
self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
|
||||
|
||||
if self.has_attention_mask:
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.sub = P.Sub()
|
||||
self.add = P.Add()
|
||||
self.cast = P.Cast()
|
||||
self.get_dtype = P.DType()
|
||||
if do_return_2d_tensor:
|
||||
self.shape_return = (-1, num_attention_heads * size_per_head)
|
||||
else:
|
||||
self.shape_return = (-1, from_seq_length, num_attention_heads * size_per_head)
|
||||
|
||||
self.cast_compute_type = SaturateCast(dst_type=compute_type)
|
||||
if self.use_relative_positions:
|
||||
self._generate_relative_positions_embeddings = \
|
||||
RelaPosEmbeddingsGenerator(length=to_seq_length,
|
||||
depth=size_per_head,
|
||||
max_relative_position=16,
|
||||
initializer_range=initializer_range,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
def construct(self, from_tensor, to_tensor, attention_mask):
|
||||
"""reshape 2d/3d input tensors to 2d"""
|
||||
from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
|
||||
to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
|
||||
query_out = self.query_layer(from_tensor_2d)
|
||||
key_out = self.key_layer(to_tensor_2d)
|
||||
value_out = self.value_layer(to_tensor_2d)
|
||||
|
||||
query_layer = self.reshape(query_out, self.shape_from)
|
||||
query_layer = self.transpose(query_layer, self.trans_shape)
|
||||
key_layer = self.reshape(key_out, self.shape_to)
|
||||
key_layer = self.transpose(key_layer, self.trans_shape)
|
||||
|
||||
attention_scores = self.matmul_trans_b(query_layer, key_layer)
|
||||
|
||||
# use_relative_position, supplementary logic
|
||||
if self.use_relative_positions:
|
||||
# relations_keys is [F|T, F|T, H]
|
||||
relations_keys = self._generate_relative_positions_embeddings()
|
||||
relations_keys = self.cast_compute_type(relations_keys)
|
||||
# query_layer_t is [F, B, N, H]
|
||||
query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
|
||||
# query_layer_r is [F, B * N, H]
|
||||
query_layer_r = self.reshape(query_layer_t,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.size_per_head))
|
||||
# key_position_scores is [F, B * N, F|T]
|
||||
key_position_scores = self.matmul_trans_b(query_layer_r,
|
||||
relations_keys)
|
||||
# key_position_scores_r is [F, B, N, F|T]
|
||||
key_position_scores_r = self.reshape(key_position_scores,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.num_attention_heads,
|
||||
self.from_seq_length))
|
||||
# key_position_scores_r_t is [B, N, F, F|T]
|
||||
key_position_scores_r_t = self.transpose(key_position_scores_r,
|
||||
self.trans_shape_position)
|
||||
attention_scores = attention_scores + key_position_scores_r_t
|
||||
|
||||
attention_scores = self.multiply(self.scores_mul, attention_scores)
|
||||
|
||||
if self.has_attention_mask:
|
||||
attention_mask = self.expand_dims(attention_mask, 1)
|
||||
multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
|
||||
self.cast(attention_mask, self.get_dtype(attention_scores)))
|
||||
|
||||
adder = self.multiply(multiply_out, self.multiply_data)
|
||||
attention_scores = self.add(adder, attention_scores)
|
||||
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
value_layer = self.reshape(value_out, self.shape_to)
|
||||
value_layer = self.transpose(value_layer, self.trans_shape)
|
||||
context_layer = self.matmul(attention_probs, value_layer)
|
||||
|
||||
# use_relative_position, supplementary logic
|
||||
if self.use_relative_positions:
|
||||
# relations_values is [F|T, F|T, H]
|
||||
relations_values = self._generate_relative_positions_embeddings()
|
||||
relations_values = self.cast_compute_type(relations_values)
|
||||
# attention_probs_t is [F, B, N, T]
|
||||
attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
|
||||
# attention_probs_r is [F, B * N, T]
|
||||
attention_probs_r = self.reshape(
|
||||
attention_probs_t,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.to_seq_length))
|
||||
# value_position_scores is [F, B * N, H]
|
||||
value_position_scores = self.matmul(attention_probs_r,
|
||||
relations_values)
|
||||
# value_position_scores_r is [F, B, N, H]
|
||||
value_position_scores_r = self.reshape(value_position_scores,
|
||||
(self.from_seq_length,
|
||||
-1,
|
||||
self.num_attention_heads,
|
||||
self.size_per_head))
|
||||
# value_position_scores_r_t is [B, N, F, H]
|
||||
value_position_scores_r_t = self.transpose(value_position_scores_r,
|
||||
self.trans_shape_position)
|
||||
context_layer = context_layer + value_position_scores_r_t
|
||||
|
||||
context_layer = self.transpose(context_layer, self.trans_shape)
|
||||
context_layer = self.reshape(context_layer, self.shape_return)
|
||||
|
||||
return context_layer
|
||||
|
||||
|
||||
class BertSelfAttention(nn.Cell):
|
||||
"""
|
||||
Apply self-attention.
|
||||
|
||||
Args:
|
||||
seq_length (int): Length of input sequence.
|
||||
hidden_size (int): Size of the bert encoder layers.
|
||||
num_attention_heads (int): Number of attention heads. Default: 12.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
seq_length,
|
||||
hidden_size,
|
||||
num_attention_heads=12,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
hidden_dropout_prob=0.1,
|
||||
use_relative_positions=False,
|
||||
compute_type=mstype.float32):
|
||||
super(BertSelfAttention, self).__init__()
|
||||
if hidden_size % num_attention_heads != 0:
|
||||
raise ValueError("The hidden size (%d) is not a multiple of the number "
|
||||
"of attention heads (%d)" % (hidden_size, num_attention_heads))
|
||||
|
||||
self.size_per_head = int(hidden_size / num_attention_heads)
|
||||
|
||||
self.attention = BertAttention(
|
||||
from_tensor_width=hidden_size,
|
||||
to_tensor_width=hidden_size,
|
||||
from_seq_length=seq_length,
|
||||
to_seq_length=seq_length,
|
||||
num_attention_heads=num_attention_heads,
|
||||
size_per_head=self.size_per_head,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
use_relative_positions=use_relative_positions,
|
||||
has_attention_mask=True,
|
||||
do_return_2d_tensor=True,
|
||||
compute_type=compute_type)
|
||||
|
||||
self.output = BertOutput(in_channels=hidden_size,
|
||||
out_channels=hidden_size,
|
||||
initializer_range=initializer_range,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
compute_type=compute_type)
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, hidden_size)
|
||||
|
||||
def construct(self, input_tensor, attention_mask):
|
||||
input_tensor = self.reshape(input_tensor, self.shape)
|
||||
attention_output = self.attention(input_tensor, input_tensor, attention_mask)
|
||||
output = self.output(attention_output, input_tensor)
|
||||
return output
|
||||
|
||||
|
||||
class BertEncoderCell(nn.Cell):
|
||||
"""
|
||||
Encoder cells used in BertTransformer.
|
||||
|
||||
Args:
|
||||
hidden_size (int): Size of the bert encoder layers. Default: 768.
|
||||
seq_length (int): Length of input sequence. Default: 512.
|
||||
num_attention_heads (int): Number of attention heads. Default: 12.
|
||||
intermediate_size (int): Size of intermediate layer. Default: 3072.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.02.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
hidden_act (str): Activation function. Default: "gelu".
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
seq_length=512,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
attention_probs_dropout_prob=0.02,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
hidden_dropout_prob=0.1,
|
||||
use_relative_positions=False,
|
||||
hidden_act="gelu",
|
||||
compute_type=mstype.float32):
|
||||
super(BertEncoderCell, self).__init__()
|
||||
self.attention = BertSelfAttention(
|
||||
hidden_size=hidden_size,
|
||||
seq_length=seq_length,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
use_relative_positions=use_relative_positions,
|
||||
compute_type=compute_type)
|
||||
self.intermediate = nn.Dense(in_channels=hidden_size,
|
||||
out_channels=intermediate_size,
|
||||
activation=hidden_act,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
self.output = BertOutput(in_channels=intermediate_size,
|
||||
out_channels=hidden_size,
|
||||
initializer_range=initializer_range,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
compute_type=compute_type)
|
||||
|
||||
def construct(self, hidden_states, attention_mask):
|
||||
# self-attention
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
# feed construct
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
# add and normalize
|
||||
output = self.output(intermediate_output, attention_output)
|
||||
return output
|
||||
|
||||
|
||||
class BertTransformer(nn.Cell):
|
||||
"""
|
||||
Multi-layer bert transformer.
|
||||
|
||||
Args:
|
||||
hidden_size (int): Size of the encoder layers.
|
||||
seq_length (int): Length of input sequence.
|
||||
num_hidden_layers (int): Number of hidden layers in encoder cells.
|
||||
num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
|
||||
intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
|
||||
attention_probs_dropout_prob (float): The dropout probability for
|
||||
BertAttention. Default: 0.1.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
|
||||
"""
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
seq_length,
|
||||
num_hidden_layers,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
hidden_dropout_prob=0.1,
|
||||
use_relative_positions=False,
|
||||
hidden_act="gelu",
|
||||
compute_type=mstype.float32,
|
||||
return_all_encoders=False):
|
||||
super(BertTransformer, self).__init__()
|
||||
self.return_all_encoders = return_all_encoders
|
||||
|
||||
layers = []
|
||||
for _ in range(num_hidden_layers):
|
||||
layer = BertEncoderCell(hidden_size=hidden_size,
|
||||
seq_length=seq_length,
|
||||
num_attention_heads=num_attention_heads,
|
||||
intermediate_size=intermediate_size,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
use_relative_positions=use_relative_positions,
|
||||
hidden_act=hidden_act,
|
||||
compute_type=compute_type)
|
||||
layers.append(layer)
|
||||
|
||||
self.layers = nn.CellList(layers)
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, hidden_size)
|
||||
self.out_shape = (-1, seq_length, hidden_size)
|
||||
|
||||
def construct(self, input_tensor, attention_mask):
|
||||
"""Multi-layer bert transformer."""
|
||||
prev_output = self.reshape(input_tensor, self.shape)
|
||||
|
||||
all_encoder_layers = ()
|
||||
for layer_module in self.layers:
|
||||
layer_output = layer_module(prev_output, attention_mask)
|
||||
prev_output = layer_output
|
||||
|
||||
if self.return_all_encoders:
|
||||
layer_output = self.reshape(layer_output, self.out_shape)
|
||||
all_encoder_layers = all_encoder_layers + (layer_output,)
|
||||
|
||||
if not self.return_all_encoders:
|
||||
prev_output = self.reshape(prev_output, self.out_shape)
|
||||
all_encoder_layers = all_encoder_layers + (prev_output,)
|
||||
return all_encoder_layers
|
||||
|
||||
|
||||
class CreateAttentionMaskFromInputMask(nn.Cell):
|
||||
"""
|
||||
Create attention mask according to input mask.
|
||||
|
||||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(CreateAttentionMaskFromInputMask, self).__init__()
|
||||
self.input_mask = None
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, 1, config.seq_length)
|
||||
|
||||
def construct(self, input_mask):
|
||||
attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
|
||||
return attention_mask
|
||||
|
||||
|
||||
class BertModel(nn.Cell):
|
||||
"""
|
||||
Bidirectional Encoder Representations from Transformers.
|
||||
|
||||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
is_training (bool): True for training mode. False for eval mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
def __init__(self,
|
||||
config,
|
||||
is_training,
|
||||
use_one_hot_embeddings=False):
|
||||
super(BertModel, self).__init__()
|
||||
config = copy.deepcopy(config)
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
config.attention_probs_dropout_prob = 0.0
|
||||
|
||||
self.seq_length = config.seq_length
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.embedding_size = config.hidden_size
|
||||
self.token_type_ids = None
|
||||
|
||||
self.last_idx = self.num_hidden_layers - 1
|
||||
output_embedding_shape = [-1, self.seq_length, self.embedding_size]
|
||||
|
||||
self.bert_embedding_lookup = nn.Embedding(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=self.embedding_size,
|
||||
use_one_hot=use_one_hot_embeddings,
|
||||
embedding_table=TruncatedNormal(config.initializer_range))
|
||||
|
||||
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
|
||||
embedding_size=self.embedding_size,
|
||||
embedding_shape=output_embedding_shape,
|
||||
use_relative_positions=config.use_relative_positions,
|
||||
use_token_type=True,
|
||||
token_type_vocab_size=config.type_vocab_size,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=0.02,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
dropout_prob=config.hidden_dropout_prob)
|
||||
|
||||
self.bert_encoder = BertTransformer(
|
||||
hidden_size=self.hidden_size,
|
||||
seq_length=self.seq_length,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
intermediate_size=config.intermediate_size,
|
||||
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=config.initializer_range,
|
||||
hidden_dropout_prob=config.hidden_dropout_prob,
|
||||
use_relative_positions=config.use_relative_positions,
|
||||
hidden_act=config.hidden_act,
|
||||
compute_type=config.compute_type,
|
||||
return_all_encoders=True)
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.dtype = config.dtype
|
||||
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
|
||||
self.slice = P.StridedSlice()
|
||||
|
||||
self.squeeze_1 = P.Squeeze(axis=1)
|
||||
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
|
||||
activation="tanh",
|
||||
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
|
||||
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
|
||||
|
||||
def construct(self, input_ids, token_type_ids, input_mask):
|
||||
"""Bidirectional Encoder Representations from Transformers."""
|
||||
# embedding
|
||||
embedding_tables = self.bert_embedding_lookup.embedding_table
|
||||
word_embeddings = self.bert_embedding_lookup(input_ids)
|
||||
embedding_output = self.bert_embedding_postprocessor(token_type_ids,
|
||||
word_embeddings)
|
||||
|
||||
# attention mask [batch_size, seq_length, seq_length]
|
||||
attention_mask = self._create_attention_mask_from_input_mask(input_mask)
|
||||
|
||||
# bert encoder
|
||||
encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output),
|
||||
attention_mask)
|
||||
|
||||
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
|
||||
|
||||
# pooler
|
||||
batch_size = P.Shape()(input_ids)[0]
|
||||
sequence_slice = self.slice(sequence_output,
|
||||
(0, 0, 0),
|
||||
(batch_size, 1, self.hidden_size),
|
||||
(1, 1, 1))
|
||||
first_token = self.squeeze_1(sequence_slice)
|
||||
pooled_output = self.dense(first_token)
|
||||
pooled_output = self.cast(pooled_output, self.dtype)
|
||||
|
||||
return sequence_output, pooled_output, embedding_tables
|
|
@ -0,0 +1,334 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in run_pretrain.py
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
|
||||
# samples in one block
|
||||
POS_SIZE = 1
|
||||
# max seq length
|
||||
SEQ_LEN = 512
|
||||
# rand id
|
||||
RANK_ID = 0
|
||||
# pos and neg samples in one minibatch on one device id
|
||||
GROUP_SIZE = 8
|
||||
# group number
|
||||
GROUP_NUM = 1
|
||||
# device number
|
||||
DEVICE_NUM = 1
|
||||
# batch size
|
||||
BATCH_SIZE = 1
|
||||
|
||||
def process_samples_base(input_ids, input_mask, segment_ids, label_ids):
|
||||
"""create block of samples"""
|
||||
random.seed(1)
|
||||
global GROUP_SIZE, SEQ_LEN
|
||||
neg_len = GROUP_SIZE - 1
|
||||
input_ids = input_ids.reshape(-1, SEQ_LEN)
|
||||
input_mask = input_mask.reshape(-1, SEQ_LEN)
|
||||
segment_ids = segment_ids.reshape(-1, SEQ_LEN)
|
||||
label_ids = label_ids.reshape(-1, 1)
|
||||
input_ids_l = input_ids.tolist()
|
||||
input_mask_l = input_mask.tolist()
|
||||
segment_ids_l = segment_ids.tolist()
|
||||
label_ids_l = label_ids.tolist()
|
||||
|
||||
temp = []
|
||||
for i in range(1, len(input_ids_l)):
|
||||
temp.append({"input_ids": input_ids_l[i],
|
||||
"input_mask": input_mask_l[i],
|
||||
"segment_ids": segment_ids_l[i],
|
||||
"label_ids": label_ids_l[i]})
|
||||
negs = []
|
||||
if len(temp) < neg_len:
|
||||
negs = random.choices(temp, k=neg_len)
|
||||
else:
|
||||
negs = random.sample(temp, k=neg_len)
|
||||
input_ids_n = [input_ids_l.pop(0)]
|
||||
input_mask_n = [input_mask_l.pop(0)]
|
||||
segment_ids_n = [segment_ids_l.pop(0)]
|
||||
label_ids_n = [label_ids_l.pop(0)]
|
||||
for i in range(neg_len):
|
||||
input_ids_n.append(negs[i]["input_ids"])
|
||||
input_mask_n.append(negs[i]["input_mask"])
|
||||
segment_ids_n.append(negs[i]["segment_ids"])
|
||||
label_ids_n.append(negs[i]["label_ids"])
|
||||
input_ids = np.array(input_ids_n, dtype=np.int64)
|
||||
input_mask = np.array(input_mask_n, dtype=np.int64)
|
||||
segment_ids = np.array(segment_ids_n, dtype=np.int64)
|
||||
label_ids = np.array(label_ids_n, dtype=np.int64)
|
||||
|
||||
input_ids = input_ids.reshape(-1, SEQ_LEN)
|
||||
input_mask = input_mask.reshape(-1, SEQ_LEN)
|
||||
segment_ids = segment_ids.reshape(-1, SEQ_LEN)
|
||||
label_ids = label_ids.reshape(-1, POS_SIZE)
|
||||
return input_ids, input_mask, segment_ids, label_ids
|
||||
|
||||
def samples_base(input_ids, input_mask, segment_ids, label_ids):
|
||||
"""split samples for device"""
|
||||
global GROUP_SIZE, GROUP_NUM, RANK_ID, SEQ_LEN, BATCH_SIZE, DEVICE_NUM
|
||||
out_ids = []
|
||||
out_mask = []
|
||||
out_seg = []
|
||||
out_label = []
|
||||
assert len(input_ids) >= len(input_mask)
|
||||
assert len(input_ids) >= len(segment_ids)
|
||||
assert len(input_ids) >= len(label_ids)
|
||||
group_id = RANK_ID * GROUP_NUM // DEVICE_NUM
|
||||
begin0 = BATCH_SIZE * group_id
|
||||
end0 = (group_id + 1) * BATCH_SIZE
|
||||
begin = (RANK_ID % (DEVICE_NUM//GROUP_NUM)) * GROUP_NUM * GROUP_SIZE // DEVICE_NUM
|
||||
end = ((RANK_ID % (DEVICE_NUM//GROUP_NUM)) + 1) * GROUP_NUM * GROUP_SIZE // DEVICE_NUM
|
||||
begin_temp = begin
|
||||
end_temp = end
|
||||
for i in range(begin0, end0):
|
||||
ids, mask, seg, lab = input_ids[i], input_mask[i], segment_ids[i], label_ids[i]
|
||||
if begin_temp > len(input_ids[i]):
|
||||
begin_temp = begin_temp - len(input_ids[i])
|
||||
end_temp = end_temp - len(input_ids[i])
|
||||
continue
|
||||
ids = ids.reshape(-1, SEQ_LEN)
|
||||
mask = mask.reshape(-1, SEQ_LEN)
|
||||
seg = seg.reshape(-1, SEQ_LEN)
|
||||
lab = lab.reshape(-1, 1)
|
||||
ids = ids[begin_temp:end_temp]
|
||||
mask = mask[begin_temp:end_temp]
|
||||
seg = seg[begin_temp:end_temp]
|
||||
lab = lab[begin_temp:end_temp]
|
||||
out_ids.append(ids)
|
||||
out_mask.append(mask)
|
||||
out_seg.append(seg)
|
||||
out_label.append(lab)
|
||||
begin_temp = begin
|
||||
end_temp = end
|
||||
input_ids = np.array(out_ids, dtype=np.int64)
|
||||
input_mask = np.array(out_mask, dtype=np.int64)
|
||||
segment_ids = np.array(out_seg, dtype=np.int64)
|
||||
label_ids = np.array(out_label, dtype=np.int64)
|
||||
return input_ids, input_mask, segment_ids, label_ids
|
||||
|
||||
def create_dyr_base_dataset(device_num=1, rank=0, batch_size=1, repeat_count=1, dataset_format="mindrecord",
|
||||
data_file_path=None, schema_file_path=None, do_shuffle=True,
|
||||
group_size=1, group_num=1, seq_len=512):
|
||||
"""create finetune dataset"""
|
||||
global GROUP_SIZE, GROUP_NUM, RANK_ID, SEQ_LEN, BATCH_SIZE, DEVICE_NUM
|
||||
GROUP_SIZE = group_size
|
||||
GROUP_NUM = group_num
|
||||
RANK_ID = rank
|
||||
SEQ_LEN = seq_len
|
||||
BATCH_SIZE = batch_size
|
||||
DEVICE_NUM = device_num
|
||||
print("device_num = %d, rank_id = %d, batch_size = %d" %(device_num, rank, batch_size))
|
||||
print("group_size = %d, group_num = %d, seq_len = %d" %(group_size, group_num, seq_len))
|
||||
|
||||
divide = (group_size * group_num) % device_num
|
||||
assert divide == 0
|
||||
assert device_num >= group_num
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds.config.set_seed(1000)
|
||||
random.seed(1)
|
||||
data_files = []
|
||||
if ".mindrecord" in data_file_path:
|
||||
data_files.append(data_file_path)
|
||||
else:
|
||||
files = os.listdir(data_file_path)
|
||||
for file_name in files:
|
||||
if "mindrecord" in file_name and "mindrecord.db" not in file_name:
|
||||
data_files.append(os.path.join(data_file_path, file_name))
|
||||
|
||||
if dataset_format == "mindrecord":
|
||||
data_set = ds.MindDataset(data_files,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"],
|
||||
shuffle=do_shuffle)
|
||||
else:
|
||||
data_set = ds.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"],
|
||||
shuffle=do_shuffle)
|
||||
|
||||
data_set = data_set.map(operations=process_samples_base,
|
||||
input_columns=["input_ids", "input_mask", "segment_ids", "label_ids"])
|
||||
batch_size = batch_size * group_num
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
data_set = data_set.map(operations=samples_base,
|
||||
input_columns=["input_ids", "input_mask", "segment_ids", "label_ids"])
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_mask")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_ids")
|
||||
data_set = data_set.repeat(repeat_count)
|
||||
# apply batch operations
|
||||
return data_set
|
||||
|
||||
def process_samples(input_ids, input_mask, segment_ids, label_ids):
|
||||
"""create block of samples"""
|
||||
random.seed(None)
|
||||
rand_id = random.sample(range(0, 15), 15)
|
||||
random.seed(1)
|
||||
global GROUP_SIZE, SEQ_LEN
|
||||
neg_len = GROUP_SIZE - 1
|
||||
input_ids = input_ids.reshape(-1, SEQ_LEN)
|
||||
input_mask = input_mask.reshape(-1, SEQ_LEN)
|
||||
segment_ids = segment_ids.reshape(-1, SEQ_LEN)
|
||||
label_ids = label_ids.reshape(-1, 1)
|
||||
input_ids_l = input_ids.tolist()
|
||||
input_mask_l = input_mask.tolist()
|
||||
segment_ids_l = segment_ids.tolist()
|
||||
label_ids_l = label_ids.tolist()
|
||||
|
||||
temp = []
|
||||
for i in range(1, len(input_ids_l)):
|
||||
temp.append({"input_ids": input_ids_l[i],
|
||||
"input_mask": input_mask_l[i],
|
||||
"segment_ids": segment_ids_l[i],
|
||||
"label_ids": label_ids_l[i]})
|
||||
negs = []
|
||||
if len(temp) < neg_len:
|
||||
negs = random.choices(temp, k=neg_len)
|
||||
else:
|
||||
negs = random.sample(temp, k=neg_len)
|
||||
input_ids_n = [input_ids_l.pop(0)]
|
||||
input_mask_n = [input_mask_l.pop(0)]
|
||||
segment_ids_n = [segment_ids_l.pop(0)]
|
||||
label_ids_n = [label_ids_l.pop(0)]
|
||||
for i in range(neg_len):
|
||||
input_ids_n.append(negs[i]["input_ids"])
|
||||
input_mask_n.append(negs[i]["input_mask"])
|
||||
segment_ids_n.append(negs[i]["segment_ids"])
|
||||
label_ids_n.append(negs[i]["label_ids"])
|
||||
input_ids = np.array(input_ids_n, dtype=np.int64)
|
||||
input_mask = np.array(input_mask_n, dtype=np.int64)
|
||||
segment_ids = np.array(segment_ids_n, dtype=np.int64)
|
||||
label_ids = np.array(label_ids_n, dtype=np.int64)
|
||||
|
||||
input_ids = input_ids.reshape(-1, SEQ_LEN)
|
||||
input_mask = input_mask.reshape(-1, SEQ_LEN)
|
||||
segment_ids = segment_ids.reshape(-1, SEQ_LEN)
|
||||
label_ids = label_ids.reshape(-1, POS_SIZE)
|
||||
|
||||
label_ids = np.array(rand_id, dtype=np.int64)
|
||||
label_ids = label_ids.reshape(-1, 15)
|
||||
return input_ids, input_mask, segment_ids, label_ids
|
||||
|
||||
def samples(input_ids, input_mask, segment_ids, label_ids):
|
||||
"""split samples for device"""
|
||||
global GROUP_SIZE, GROUP_NUM, RANK_ID, SEQ_LEN, BATCH_SIZE, DEVICE_NUM
|
||||
out_ids = []
|
||||
out_mask = []
|
||||
out_seg = []
|
||||
out_label = []
|
||||
assert len(input_ids) >= len(input_mask)
|
||||
assert len(input_ids) >= len(segment_ids)
|
||||
assert len(input_ids) >= len(label_ids)
|
||||
group_id = RANK_ID * GROUP_NUM // DEVICE_NUM
|
||||
begin0 = BATCH_SIZE * group_id
|
||||
end0 = (group_id + 1) * BATCH_SIZE
|
||||
begin = (RANK_ID % (DEVICE_NUM // GROUP_NUM)) * GROUP_NUM * GROUP_SIZE // DEVICE_NUM
|
||||
end = ((RANK_ID % (DEVICE_NUM // GROUP_NUM)) + 1) * GROUP_NUM * GROUP_SIZE // DEVICE_NUM
|
||||
begin_temp = begin
|
||||
end_temp = end
|
||||
for i in range(begin0, end0):
|
||||
ids, mask, seg, lab = input_ids[i], input_mask[i], segment_ids[i], label_ids[i]
|
||||
if begin_temp > len(input_ids[i]):
|
||||
begin_temp = begin_temp - len(input_ids[i])
|
||||
end_temp = end_temp - len(input_ids[i])
|
||||
continue
|
||||
ids = ids.reshape(-1, SEQ_LEN)
|
||||
mask = mask.reshape(-1, SEQ_LEN)
|
||||
seg = seg.reshape(-1, SEQ_LEN)
|
||||
lab = lab.reshape(-1, 15)
|
||||
ids = ids[begin_temp:end_temp]
|
||||
mask = mask[begin_temp:end_temp]
|
||||
seg = seg[begin_temp:end_temp]
|
||||
out_ids.append(ids)
|
||||
out_mask.append(mask)
|
||||
out_seg.append(seg)
|
||||
out_label.append(lab)
|
||||
begin_temp = begin
|
||||
end_temp = end
|
||||
input_ids = np.array(out_ids, dtype=np.int64)
|
||||
input_mask = np.array(out_mask, dtype=np.int64)
|
||||
segment_ids = np.array(out_seg, dtype=np.int64)
|
||||
label_ids = np.array(out_label, dtype=np.int64)
|
||||
return input_ids, input_mask, segment_ids, label_ids
|
||||
|
||||
def create_dyr_dataset(device_num=1, rank=0, batch_size=1, repeat_count=1, dataset_format="mindrecord",
|
||||
data_file_path=None, schema_file_path=None, do_shuffle=True,
|
||||
group_size=1, group_num=1, seq_len=512):
|
||||
"""create finetune dataset"""
|
||||
global GROUP_SIZE, GROUP_NUM, RANK_ID, SEQ_LEN, BATCH_SIZE, DEVICE_NUM
|
||||
GROUP_SIZE = group_size
|
||||
GROUP_NUM = group_num
|
||||
RANK_ID = rank
|
||||
SEQ_LEN = seq_len
|
||||
BATCH_SIZE = batch_size
|
||||
DEVICE_NUM = device_num
|
||||
print("device_num = %d, rank_id = %d, batch_size = %d" %(device_num, rank, batch_size))
|
||||
print("group_size = %d, group_num = %d, seq_len = %d" %(group_size, group_num, seq_len))
|
||||
|
||||
divide = (group_size * group_num) % device_num
|
||||
assert divide == 0
|
||||
assert device_num >= group_num
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds.config.set_seed(1000)
|
||||
random.seed(1)
|
||||
data_files = []
|
||||
if ".mindrecord" in data_file_path:
|
||||
data_files.append(data_file_path)
|
||||
else:
|
||||
files = os.listdir(data_file_path)
|
||||
for file_name in files:
|
||||
if "mindrecord" in file_name and "mindrecord.db" not in file_name:
|
||||
data_files.append(os.path.join(data_file_path, file_name))
|
||||
|
||||
if dataset_format == "mindrecord":
|
||||
data_set = ds.MindDataset(data_files,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"],
|
||||
shuffle=do_shuffle)
|
||||
else:
|
||||
data_set = ds.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"],
|
||||
shuffle=do_shuffle)
|
||||
|
||||
data_set = data_set.map(operations=process_samples,
|
||||
input_columns=["input_ids", "input_mask", "segment_ids", "label_ids"])
|
||||
batch_size = batch_size * group_num
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
data_set = data_set.map(operations=samples, input_columns=["input_ids", "input_mask", "segment_ids", "label_ids"])
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_mask")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_ids")
|
||||
data_set = data_set.repeat(repeat_count)
|
||||
return data_set
|
||||
def create_dyr_dataset_predict(batch_size=1, repeat_count=1, dataset_format="mindrecord",
|
||||
data_file_path=None, schema_file_path=None, do_shuffle=True):
|
||||
"""create evaluation dataset"""
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
data_set = ds.MindDataset([data_file_path],
|
||||
columns_list=["input_ids", "input_mask", "segment_ids"],
|
||||
shuffle=do_shuffle)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_mask")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_ids")
|
||||
data_set = data_set.repeat(repeat_count)
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
return data_set
|
|
@ -0,0 +1,323 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
DynamicRanker network script.
|
||||
'''
|
||||
import numpy as np
|
||||
import mindspore.ops as ops
|
||||
import mindspore.nn as nn
|
||||
import mindspore.numpy as mnp
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.communication.management import get_group_size
|
||||
from src.bert_model import BertModel
|
||||
|
||||
class DynamicRankerModel(nn.Cell):
|
||||
"""
|
||||
This class is responsible for DynamicRanker task evaluation.
|
||||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
is_training (bool): True for training mode. False for eval mode.
|
||||
dropout_prob (float): The dropout probability for DynamicRankerModel. Default: 0.0.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
def __init__(self, config, is_training, dropout_prob=0.0, use_one_hot_embeddings=False):
|
||||
super(DynamicRankerModel, self).__init__()
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
config.hidden_probs_dropout_prob = 0.0
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.dtype = config.dtype
|
||||
self.dense_1 = nn.Dense(config.hidden_size, 1, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id):
|
||||
_, pooled_output, _ = \
|
||||
self.bert(input_ids, token_type_id, input_mask)
|
||||
cls = self.cast(pooled_output, self.dtype)
|
||||
cls = self.dropout(cls)
|
||||
logits = self.dense_1(cls)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
return logits
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
if clip_type not in (0, 1):
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
|
||||
|
||||
class DynamicRankerFinetuneCell(nn.TrainOneStepWithLossScaleCell):
|
||||
"""
|
||||
Especially defined for finetuning where only four inputs tensor are needed.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Different from the builtin loss_scale wrapper cell, we apply grad_clip before the optimization.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
|
||||
super(DynamicRankerFinetuneCell, self).__init__(network, optimizer, scale_update_cell)
|
||||
self.cast = P.Cast()
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids,
|
||||
sens=None):
|
||||
"""DynamicRanker Finetune"""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
|
||||
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = C.clip_by_global_norm(grads, 1.0, None)
|
||||
cond = self.get_overflow_status(status, grads)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
class CrossEntropyLoss(nn.Cell):
|
||||
"""
|
||||
Calculate the cross entropy loss
|
||||
Inputs:
|
||||
logits: the output logits of the backbone
|
||||
label: the ground truth label of the sample
|
||||
Returns:
|
||||
loss: Tensor, the corrsponding cross entropy loss
|
||||
"""
|
||||
def __init__(self):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean()
|
||||
self.one_hot = P.OneHot()
|
||||
self.one = Tensor(1.0, mstype.float32)
|
||||
self.zero = Tensor(0.0, mstype.float32)
|
||||
|
||||
def construct(self, logits, label):
|
||||
label = self.one_hot(label, F.shape(logits)[1], self.one, self.zero)
|
||||
loss = self.cross_entropy(logits, label)[0]
|
||||
loss = self.mean(loss, (-1,))
|
||||
return loss
|
||||
|
||||
class DynamicRankerBase(nn.Cell):
|
||||
"""
|
||||
Train interface for DynamicRanker base finetuning task.
|
||||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
is_training (bool): True for training mode. False for eval mode.
|
||||
dropout_prob (float): The dropout probability for DynamicRankerModel. Default: 0.0.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
batch_size (int): size of input batch.
|
||||
group_size (int): group size of block.
|
||||
group_num (int): group number of block.
|
||||
rank_id (int): rank id of device.
|
||||
device_num (int): number of device.
|
||||
seq_len (int): Length of input sequence.
|
||||
"""
|
||||
def __init__(self, config, is_training, dropout_prob=0.0, use_one_hot_embeddings=False,
|
||||
batch_size=64, group_size=8, group_num=2, rank_id=0, device_num=1, seq_len=512):
|
||||
super(DynamicRankerBase, self).__init__()
|
||||
self.bert = DynamicRankerModel(config, is_training, dropout_prob, use_one_hot_embeddings)
|
||||
self.is_training = is_training
|
||||
self.labels = Tensor(np.zeros([batch_size]).astype(np.int32))
|
||||
self.reshape = P.Reshape()
|
||||
self.shape_flat = (batch_size, -1)
|
||||
self.log_softmax = P.LogSoftmax(axis=-1)
|
||||
self.allgather = ops.AllGather()
|
||||
self.loss = CrossEntropyLoss()
|
||||
self.slice = ops.Slice()
|
||||
self.group_id = rank_id * group_num // device_num
|
||||
self.begin = group_size * batch_size * self.group_id
|
||||
self.size = group_size * batch_size
|
||||
self.transpose = P.Transpose()
|
||||
self.shape1 = (device_num // group_num, batch_size, -1)
|
||||
self.shape2 = (batch_size, -1)
|
||||
self.trans_shape = (1, 0, 2)
|
||||
self.seq_len = seq_len
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id, label_ids):
|
||||
"""
|
||||
construct interface for DynamicRanker base finetuning task.
|
||||
"""
|
||||
input_ids = P.Reshape()(input_ids, (-1, self.seq_len))
|
||||
input_mask = P.Reshape()(input_mask, (-1, self.seq_len))
|
||||
token_type_id = P.Reshape()(token_type_id, (-1, self.seq_len))
|
||||
logits = self.bert(input_ids, input_mask, token_type_id)
|
||||
logits = self.allgather(logits)
|
||||
logits = self.slice(logits, [self.begin, 0], [self.size, 1])
|
||||
logits = self.reshape(logits, self.shape1)
|
||||
logits = self.transpose(logits, self.trans_shape)
|
||||
logits = self.reshape(logits, self.shape2)
|
||||
loss = self.loss(logits, self.labels)
|
||||
return loss
|
||||
|
||||
class DynamicRanker(nn.Cell):
|
||||
"""
|
||||
Train interface for DynamicRanker v3 finetuning task.
|
||||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
is_training (bool): True for training mode. False for eval mode.
|
||||
dropout_prob (float): The dropout probability for DynamicRankerModel. Default: 0.0.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
batch_size (int): size of input batch.
|
||||
group_size (int): group size of block.
|
||||
group_num (int): group number of block.
|
||||
rank_id (int): rank id of device.
|
||||
device_num (int): number of device.
|
||||
seq_len (int): Length of input sequence.
|
||||
"""
|
||||
def __init__(self, config, is_training, dropout_prob=0.0, use_one_hot_embeddings=False,
|
||||
batch_size=64, group_size=8, group_num=2, rank_id=0, device_num=1, seq_len=512):
|
||||
super(DynamicRanker, self).__init__()
|
||||
self.bert = DynamicRankerModel(config, is_training, dropout_prob, use_one_hot_embeddings)
|
||||
self.is_training = is_training
|
||||
self.labels = Tensor(np.zeros([batch_size]).astype(np.int32))
|
||||
self.reshape = P.Reshape()
|
||||
self.shape_flat = (batch_size, -1)
|
||||
self.log_softmax = P.LogSoftmax(axis=-1)
|
||||
self.allgather = ops.AllGather()
|
||||
self.loss = CrossEntropyLoss()
|
||||
self.slice = ops.Slice()
|
||||
self.group_id = rank_id * group_num // device_num
|
||||
self.begin = group_size * batch_size * self.group_id
|
||||
self.size = group_size * batch_size
|
||||
self.transpose = P.Transpose()
|
||||
self.shape1 = (device_num // group_num, batch_size, -1)
|
||||
self.shape2 = (batch_size, -1)
|
||||
self.trans_shape = (1, 0, 2)
|
||||
self.batch_size = batch_size
|
||||
self.group_size = group_size
|
||||
self.topk = ops.TopK(sorted=True)
|
||||
self.concat = ops.Concat(axis=1)
|
||||
self.cast = ops.Cast()
|
||||
self.seq_len = seq_len
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id, label_ids):
|
||||
"""
|
||||
construct interface for DynamicRanker v3 finetuning task.
|
||||
"""
|
||||
input_ids = P.Reshape()(input_ids, (-1, self.seq_len))
|
||||
input_mask = P.Reshape()(input_mask, (-1, self.seq_len))
|
||||
token_type_id = P.Reshape()(token_type_id, (-1, self.seq_len))
|
||||
logits = self.bert(input_ids, input_mask, token_type_id)
|
||||
logits = self.allgather(logits)
|
||||
logits = self.slice(logits, [self.begin, 0], [self.size, 1])
|
||||
logits = self.reshape(logits, self.shape1)
|
||||
logits = self.transpose(logits, self.trans_shape)
|
||||
logits = self.reshape(logits, self.shape2)
|
||||
pos_sample = self.slice(logits, [0, 0], [self.batch_size, 1])
|
||||
res_sample = self.slice(logits, [0, 1], [self.batch_size, self.group_size - 1])
|
||||
values, _ = self.topk(res_sample, 15)
|
||||
label_ids = P.Reshape()(label_ids, (-1, 15))
|
||||
indices_ = self.cast(label_ids, mstype.float32)
|
||||
_, indices = self.topk(indices_, 15)
|
||||
values = mnp.take_along_axis(values, indices, 1)
|
||||
c2_score = self.concat((pos_sample, values))
|
||||
loss = self.loss(c2_score, self.labels)
|
||||
return loss
|
||||
|
||||
|
||||
class DynamicRankerPredict(nn.Cell):
|
||||
"""
|
||||
Predict interface for DynamicRanker finetuning task.
|
||||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
is_training (bool): True for training mode. False for eval mode.
|
||||
dropout_prob (float): The dropout probability for DynamicRankerModel. Default: 0.0.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
def __init__(self, config, is_training, dropout_prob=0.0, use_one_hot_embeddings=False):
|
||||
super(DynamicRankerPredict, self).__init__()
|
||||
self.bert = DynamicRankerModel(config, is_training, dropout_prob, use_one_hot_embeddings)
|
||||
def construct(self, input_ids, input_mask, token_type_id):
|
||||
logits = self.bert(input_ids, input_mask, token_type_id)
|
||||
return logits
|
|
@ -0,0 +1,171 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pformat
|
||||
import yaml
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.bert_model import BertConfig
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="pretrain_base_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
# print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def parse_dtype(dtype):
|
||||
if dtype not in ["mstype.float32", "mstype.float16"]:
|
||||
raise ValueError("Not supported dtype")
|
||||
|
||||
if dtype == "mstype.float32":
|
||||
return mstype.float32
|
||||
if dtype == "mstype.float16":
|
||||
return mstype.float16
|
||||
return None
|
||||
|
||||
def extra_operations(cfg):
|
||||
"""
|
||||
Do extra work on config
|
||||
|
||||
Args:
|
||||
config: Object after instantiation of class 'Config'.
|
||||
"""
|
||||
def create_filter_fun(keywords):
|
||||
return lambda x: not (True in [key in x.name.lower() for key in keywords])
|
||||
if cfg.description == 'run_dyr':
|
||||
cfg.optimizer_cfg.AdamWeightDecay.decay_filter = \
|
||||
create_filter_fun(cfg.optimizer_cfg.AdamWeightDecay.decay_filter)
|
||||
cfg.optimizer_cfg.Lamb.decay_filter = create_filter_fun(cfg.optimizer_cfg.Lamb.decay_filter)
|
||||
cfg.dyr_net_cfg.dtype = mstype.float32
|
||||
cfg.dyr_net_cfg.compute_type = mstype.float16
|
||||
cfg.dyr_net_cfg = BertConfig(**cfg.dyr_net_cfg.__dict__)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
def get_abs_path(path_relative):
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
return os.path.join(current_dir, path_relative)
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
parser.add_argument("--config_path", type=get_abs_path, default="../../dyr_config.yaml",
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
config_obj = Config(final_config)
|
||||
extra_operations(config_obj)
|
||||
return config_obj
|
||||
|
||||
|
||||
config = get_config()
|
||||
dyr_net_cfg = config.dyr_net_cfg
|
||||
if config.description == 'run_dyr':
|
||||
optimizer_cfg = config.optimizer_cfg
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(config)
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from src.model_utils.config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from mindspore.profiler import Profiler
|
||||
from src.model_utils.config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
# print("os.mknod({}) success".format(sync_lock))
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler = Profiler()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler.analyse()
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -0,0 +1,209 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Functional Cells used in dyr train and evaluation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
|
||||
|
||||
def make_directory(path: str):
|
||||
"""Make directory."""
|
||||
if path is None or not isinstance(path, str) or path.strip() == "":
|
||||
logger.error("The path(%r) is invalid type.", path)
|
||||
raise TypeError("Input path is invalid type")
|
||||
|
||||
# convert the relative paths
|
||||
path = os.path.realpath(path)
|
||||
logger.debug("The abs path is %r", path)
|
||||
|
||||
# check the path is exist and write permissions?
|
||||
if os.path.exists(path):
|
||||
real_path = path
|
||||
else:
|
||||
# All exceptions need to be caught because create directory maybe have some limit(permissions)
|
||||
logger.debug("The directory(%s) doesn't exist, will create it", path)
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
real_path = path
|
||||
except PermissionError as e:
|
||||
logger.error("No write permission on the directory(%r), error = %r", path, e)
|
||||
raise TypeError("No write permission on the directory.")
|
||||
return real_path
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss in NAN or INF terminating training.
|
||||
Note:
|
||||
if per_print_times is 0 do not print loss.
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
def __init__(self, dataset_size=-1):
|
||||
super(LossCallBack, self).__init__()
|
||||
self._dataset_size = dataset_size
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Print loss after each step
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
if self._dataset_size > 0:
|
||||
percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
|
||||
if percent == 0:
|
||||
percent = 1
|
||||
epoch_num -= 1
|
||||
print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
|
||||
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
|
||||
flush=True)
|
||||
else:
|
||||
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)), flush=True)
|
||||
|
||||
def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
|
||||
"""
|
||||
Find the ckpt finetune generated and load it into eval network.
|
||||
"""
|
||||
files = os.listdir(load_finetune_checkpoint_dir)
|
||||
pre_len = len(prefix)
|
||||
max_num = 0
|
||||
for filename in files:
|
||||
name_ext = os.path.splitext(filename)
|
||||
if name_ext[-1] != ".ckpt":
|
||||
continue
|
||||
if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
|
||||
index = filename[pre_len:].find("-")
|
||||
if index == 0 and max_num == 0:
|
||||
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
|
||||
elif index not in (0, -1):
|
||||
name_split = name_ext[-2].split('_')
|
||||
if (steps_per_epoch != int(name_split[len(name_split)-1])) \
|
||||
or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])):
|
||||
continue
|
||||
num = filename[pre_len + 1:pre_len + index]
|
||||
if int(num) > max_num:
|
||||
max_num = int(num)
|
||||
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
|
||||
return load_finetune_checkpoint_path
|
||||
|
||||
|
||||
class DynamicRankerLearningRate(LearningRateSchedule):
|
||||
"""
|
||||
Warmup-decay learning rate for DynamicRanker network.
|
||||
"""
|
||||
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
|
||||
super(DynamicRankerLearningRate, self).__init__()
|
||||
self.warmup_flag = False
|
||||
if warmup_steps > 0:
|
||||
self.warmup_flag = True
|
||||
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
|
||||
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
|
||||
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
|
||||
|
||||
self.greater = P.Greater()
|
||||
self.one = Tensor(np.array([1.0]).astype(np.float32))
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, global_step):
|
||||
decay_lr = self.decay_lr(global_step)
|
||||
if self.warmup_flag:
|
||||
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
|
||||
warmup_lr = self.warmup_lr(global_step)
|
||||
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
|
||||
else:
|
||||
lr = decay_lr
|
||||
return lr
|
||||
|
||||
class MRR():
|
||||
"""
|
||||
Calculate MRR@100 and MRR@10.
|
||||
"""
|
||||
def _mrr(self, gt, pred, val):
|
||||
"""
|
||||
Calculate MRR
|
||||
"""
|
||||
score = 0.0
|
||||
for rank, item in enumerate(pred[:val]):
|
||||
if item in gt:
|
||||
score = 1.0 / (rank + 1.0)
|
||||
break
|
||||
return score
|
||||
|
||||
def _get_qrels(self, qrels_path):
|
||||
"""
|
||||
Get qrels
|
||||
"""
|
||||
qrels = {}
|
||||
with open(qrels_path) as qf:
|
||||
for line in qf:
|
||||
qid, _, docid, _ = line.strip().split()
|
||||
if qid in qrels:
|
||||
qrels[qid].append(docid)
|
||||
else:
|
||||
qrels[qid] = [docid]
|
||||
return qrels
|
||||
|
||||
def _get_scores(self, scores_path):
|
||||
"""
|
||||
Get scores
|
||||
"""
|
||||
scores = {}
|
||||
with open(scores_path) as sf:
|
||||
for line in sf:
|
||||
qid, docid, score = line.strip().split()
|
||||
if qid in scores:
|
||||
scores[qid] += [(docid, float(score))]
|
||||
else:
|
||||
scores[qid] = [(docid, float(score))]
|
||||
for qid in scores:
|
||||
scores[qid] = sorted(scores[qid], key=lambda x: x[1], reverse=True)
|
||||
return scores
|
||||
|
||||
def _calc(self, qrels, scores):
|
||||
"""
|
||||
Calculate MRR@100 and MRR@10.
|
||||
"""
|
||||
cn = 0
|
||||
mrr100 = []
|
||||
mrr10 = []
|
||||
for qid in scores:
|
||||
if qid in qrels:
|
||||
gold_set = set(qrels[qid])
|
||||
y = [s[0] for s in scores[qid]]
|
||||
mrr100 += [self._mrr(gt=gold_set, pred=y, val=100)]
|
||||
mrr10 += [self._mrr(gt=gold_set, pred=y, val=10)]
|
||||
else:
|
||||
cn += 1
|
||||
return mrr100, mrr10
|
||||
|
||||
def accuracy(self, qrels_path, scores_path):
|
||||
"""
|
||||
Calculate MRR@100 and MRR@10.
|
||||
Args:
|
||||
qrels_path : Path of qrels file.
|
||||
score_path : Path of scores file.
|
||||
"""
|
||||
qrels = self._get_qrels(qrels_path)
|
||||
scores = self._get_scores(scores_path)
|
||||
mrr100, mrr10 = self._calc(qrels, scores)
|
||||
print(f"mrr@100:{np.mean(mrr100)}, mrr@10:{np.mean(mrr10)} ")
|
Loading…
Reference in New Issue