forked from mindspore-Ecosystem/mindspore
add seq2seq model
This commit is contained in:
parent
20a5e30481
commit
db07b3400e
|
@ -0,0 +1,245 @@
|
|||
# 目录
|
||||
|
||||
[TOC]
|
||||
|
||||
# Seq2seq描述
|
||||
|
||||
Seq2seq是2014年由谷歌公司的研究人员Ilya Sutskever提出的NLP模型,主要用于英语-法语的机器翻译工作。
|
||||
|
||||
[论文](https://arxiv.org/abs/1409.3215):Ilya Sutskever, Oriol Vinyals, and Quoc V. Le. 2014. Sequence to sequence learning with neural networks. In <i>Proceedings of the 27th International Conference on Neural Information Processing Systems - Volume 2 (NIPS'14). MIT Press, Cambridge, MA, USA, 3104–3112.
|
||||
|
||||
# 模型架构
|
||||
|
||||
Seq2seq模型使用Encoder-Decoder结构,Encoder和Decoder均为4层LSTM。并且输出句子时采用BeamSearch机制搜索。
|
||||
|
||||
# 数据集
|
||||
|
||||
使用的数据集:[WMT14](http://www.statmt.org/wmt14/translation-task.html)
|
||||
|
||||
数据集下载:
|
||||
|
||||
```shell
|
||||
cd scripts
|
||||
bash wmt14_en_fr.sh
|
||||
```
|
||||
|
||||
- 数据集大小:
|
||||
- 训练集:400万行英语句子,400万行法语句子
|
||||
- 测试集:3003行英语句子,3003行法语句子
|
||||
- 数据格式:txt文件
|
||||
- 注:数据将在create_dataset.py中处理。
|
||||
|
||||
# 特性
|
||||
|
||||
## 混合精度
|
||||
|
||||
采用[混合精度](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
|
||||
以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend)
|
||||
- 使用Ascend处理器来搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```python
|
||||
# 运行训练示例
|
||||
python train.py > train.log 2>&1 &
|
||||
|
||||
# 运行分布式训练示例
|
||||
sh scripts/run_train.sh rank_table.json
|
||||
|
||||
# 运行评估示例
|
||||
python eval.py > eval.log 2>&1 &
|
||||
或
|
||||
sh run_eval.sh
|
||||
```
|
||||
|
||||
对于分布式训练,需要提前创建JSON格式的hccl配置文件。
|
||||
|
||||
请遵循以下链接中的说明:
|
||||
|
||||
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```bash
|
||||
├── seq2seq
|
||||
├── README.md // Introduction of Seq2seq model.
|
||||
├── config
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──config.py // Configuration instance definition.
|
||||
│ ├──config.json // Configuration file for pre-train or finetune.
|
||||
│ ├──config_test.json // Configuration file for test.
|
||||
├── src
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──dataset
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──base.py // Base class of data loader.
|
||||
│ ├──bi_data_loader.py // Bilingual data loader.
|
||||
│ ├──load_dataset.py // Dataset loader to feed into model.
|
||||
│ ├──schema.py // Define schema of mindrecord.
|
||||
│ ├──tokenizer.py // Tokenizer class.
|
||||
│ ├──seq2seq_model
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──beam_search.py // Beam search decoder for inferring.
|
||||
│ ├──bleu_calculate.py // Calculat the blue accuracy.
|
||||
│ ├──components.py // Components.
|
||||
│ ├──decoder.py // Seq2seq decoder component.
|
||||
│ ├──decoder_beam_infer.py // Seq2seq decoder component for beam search.
|
||||
│ ├──dynamic_rnn.py // DynamicRNN.
|
||||
│ ├──embedding.py // Embedding component.
|
||||
│ ├──encoder.py // seq2seq encoder component.
|
||||
│ ├──seq2seq.py // seq2seq model architecture.
|
||||
│ ├──seq2seq_for_infer.py // Use Seq2seq to infer.
|
||||
│ ├──seq2seq_for_train.py // Use Seq2seq to train.
|
||||
│ ├──utils
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──initializer.py // Parameters initializer.
|
||||
│ ├──load_weights.py // Load weights from a checkpoint or NPZ file.
|
||||
│ ├──loss_moniter.py // Callback of monitering loss during training step.
|
||||
│ ├──lr_scheduler.py // Learning rate scheduler.
|
||||
│ ├──optimizer.py // Optimizer.
|
||||
├── scripts
|
||||
│ ├──run_distributed_train_ascend.sh // Shell script for distributed train on ascend.
|
||||
│ ├──run_standalone_eval_ascend.sh // Shell script for standalone eval on ascend.
|
||||
│ ├──run_standalone_train_ascend.sh // Shell script for standalone eval on ascend.
|
||||
│ ├──wmt14_en_fr.sh // Shell script for download dataset.
|
||||
│ ├──filter_dataset.py // dataset filter
|
||||
├── create_dataset.py // Dataset preparation.
|
||||
├── eval.py // Infer API entry.
|
||||
├── export.py // Export checkpoint file into air models.
|
||||
├── mindspore_hub_conf.py // Hub config.
|
||||
├── requirements.txt // Requirements of third party package.
|
||||
├── train.py // Train API entry.
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
在config.py中可以同时配置训练参数和评估参数。
|
||||
|
||||
- 配置WMT14-en2fr数据集。
|
||||
|
||||
```json
|
||||
"random_seed": 20,
|
||||
"epochs": 8,
|
||||
"batch_size": 128,
|
||||
"dataset_sink_mode": false
|
||||
"seq_length": 51,
|
||||
"vocab_size": 32130,
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 4,
|
||||
"intermediate_size": 4096,
|
||||
"hidden_dropout_prob": 0.2,
|
||||
"initializer_range": 0.08,
|
||||
"label_smoothing": 0.1,
|
||||
"beam_width": 2,
|
||||
"length_penalty_weight": 0.8,
|
||||
"max_decode_length": 50
|
||||
```
|
||||
|
||||
更多配置细节请参考脚本`config.json`。
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
bash scripts/run_standalone_train_ascend.sh
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过scripts/train/log_seq2seq_network.log文件查看结果。loss值保存在scripts/train/loss.log
|
||||
|
||||
训练结束后,您可在默认脚本文件夹下找到检查点文件。模型检查点保存scripts/train/text_translation/ckpt_0下。
|
||||
|
||||
### 分布式训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
bash scripts/run_distributed_train_ascend rank_table.json
|
||||
```
|
||||
|
||||
上述shell脚本将在后台运行分布训练。您可以通过scripts/device[X]/log_seq2seq_network.log文件查看结果。loss值保存在scripts/device[X]/loss.log
|
||||
|
||||
训练结束后,您可在默认脚本文件夹下找到检查点文件。模型检查点保存scripts/device0/text_translation/ckpt_0下。
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 评估
|
||||
|
||||
- 在Ascend环境运行时评估,脚本示例如下
|
||||
|
||||
```bash
|
||||
sh run_standalone_eval_ascend.sh \
|
||||
seq2seq/dataset_menu/newstest2014.en.mindrecord \
|
||||
seq2seq/scripts/device0/text_translation/ckpt_0/seq2seq-8_3437.ckpt \
|
||||
seq2seq/dataset_menu/vocab.bpe.32000 \
|
||||
seq2seq/dataset_menu/bpe.32000 \
|
||||
seq2seq/dataset_menu/newstest2014.fr
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通scripts/eval/log_infer.log文件查看结果。测试数据集的准确性如下:
|
||||
|
||||
```bash
|
||||
# grep "accuracy:"
|
||||
BLEU scores is :12.9
|
||||
```
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 训练性能
|
||||
|
||||
| 参数 | Ascend |
|
||||
| ------------- | ------------------------------------------------------------ |
|
||||
| 模型版本 | Inception V1 |
|
||||
| 资源 | Ascend 910, CPU 2.60GHz, 56核, 内存:314G |
|
||||
| 上传日期 | 2021-3-29 |
|
||||
| MindSpore版本 | 1.1.1 |
|
||||
| 数据集 | WMT14 |
|
||||
| 训练参数 | epoch=8, steps=27496, batch_size=128, lr=2e-3 |
|
||||
| 优化器 | adam |
|
||||
| 损失函数 | LableSmooth交叉熵 |
|
||||
| 输出 | 翻译后的句子与BLEU值 |
|
||||
| 损失 | 50 |
|
||||
| 速度 | 单卡:169毫秒/步; 8卡:208毫秒/步 |
|
||||
| 总时长 | 8卡:2小时 |
|
||||
| 微调检查点 | 1.48G (.ckpt文件) |
|
||||
| 脚本 | [seq2seq脚本](https://gitee.com/honghu-zero/mindspore/tree/seq2seq_1.1/model_zoo/research/nlp/seq2seq) |
|
||||
|
||||
### 推理性能
|
||||
|
||||
| 参数 | Ascend |
|
||||
| ------------- | -------------- |
|
||||
| 模型版本 | Inception V1 |
|
||||
| 资源 | Ascend 910 |
|
||||
| 上传日期 | 2021-03-29 |
|
||||
| MindSpore版本 | 1.1.1 |
|
||||
| 数据集 | WMT14 |
|
||||
| batch_size | 128 |
|
||||
| 输出 | BLEU |
|
||||
| 准确性 | 8卡: BLEU=12.9 |
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
在train.py中我们设置了随机种子,可在config.json文件中更改随机种子。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""seq2seqConifg model configuration."""
|
||||
from .config import Seq2seqConfig
|
||||
|
||||
__all__ = [
|
||||
"Seq2seqConfig"
|
||||
]
|
|
@ -0,0 +1,48 @@
|
|||
{
|
||||
"dataset_config": {
|
||||
"random_seed": 20,
|
||||
"epochs": 8,
|
||||
"batch_size": 128,
|
||||
"pre_train_dataset": "dataset_menu/train.tok.clean.bpe.32000.en.mindrecord",
|
||||
"fine_tune_dataset": null,
|
||||
"valid_dataset": null,
|
||||
"dataset_sink_mode": false
|
||||
},
|
||||
"model_config": {
|
||||
"seq_length": 51,
|
||||
"vocab_size": 32130,
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 4,
|
||||
"intermediate_size": 4096,
|
||||
"hidden_dropout_prob": 0.2,
|
||||
"initializer_range": 0.08,
|
||||
"label_smoothing": 0.1,
|
||||
"beam_width": 2,
|
||||
"length_penalty_weight": 0.8,
|
||||
"max_decode_length": 50
|
||||
},
|
||||
"loss_scale_config": {
|
||||
"init_loss_scale": 65536,
|
||||
"loss_scale_factor": 2,
|
||||
"scale_window": 1000
|
||||
},
|
||||
"learn_rate_config": {
|
||||
"optimizer": "adam",
|
||||
"lr": 2e-3,
|
||||
"lr_scheduler": "WarmupMultiStepLR",
|
||||
"lr_scheduler_power": 0.5,
|
||||
"warmup_lr_remain_steps": 0.666,
|
||||
"warmup_lr_decay_interval": -1,
|
||||
"decay_steps": 4,
|
||||
"decay_start_step": -1,
|
||||
"warmup_steps": 200,
|
||||
"min_lr": 1e-6
|
||||
},
|
||||
"checkpoint_options": {
|
||||
"existed_ckpt": "",
|
||||
"save_ckpt_steps": 3452,
|
||||
"keep_ckpt_max": 6,
|
||||
"ckpt_prefix": "seq2seq",
|
||||
"ckpt_path": "text_translation"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,232 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Configuration class for Seq2seq."""
|
||||
import os
|
||||
import json
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
def _is_dataset_file(file: str):
|
||||
return "tfrecord" in file.lower() or "mindrecord" in file.lower()
|
||||
|
||||
|
||||
def _get_files_from_dir(folder: str):
|
||||
_files = []
|
||||
for file in os.listdir(folder):
|
||||
if _is_dataset_file(file):
|
||||
_files.append(os.path.join(folder, file))
|
||||
return _files
|
||||
|
||||
|
||||
def get_source_list(folder: str) -> List:
|
||||
"""
|
||||
Get file list from a folder.
|
||||
|
||||
Returns:
|
||||
list, file list.
|
||||
"""
|
||||
_list = []
|
||||
if not folder:
|
||||
return _list
|
||||
|
||||
if os.path.isdir(folder):
|
||||
_list = _get_files_from_dir(folder)
|
||||
else:
|
||||
if _is_dataset_file(folder):
|
||||
_list.append(folder)
|
||||
return _list
|
||||
|
||||
|
||||
PARAM_NODES = {"dataset_config",
|
||||
"model_config",
|
||||
"loss_scale_config",
|
||||
"learn_rate_config",
|
||||
"checkpoint_options"}
|
||||
|
||||
|
||||
class Seq2seqConfig:
|
||||
"""
|
||||
Configuration for `seq2seq`.
|
||||
|
||||
Args:
|
||||
random_seed (int): Random seed, it can be changed.
|
||||
epochs (int): Epoch number.
|
||||
batch_size (int): Batch size of input dataset.
|
||||
pre_train_dataset (str): Path of pre-training dataset file or folder.
|
||||
fine_tune_dataset (str): Path of fine-tune dataset file or folder.
|
||||
test_dataset (str): Path of test dataset file or folder.
|
||||
valid_dataset (str): Path of validation dataset file or folder.
|
||||
dataset_sink_mode (bool): Whether enable dataset sink mode.
|
||||
seq_length (int): Length of input sequence.
|
||||
vocab_size (int): The shape of each embedding vector.
|
||||
hidden_size (int): Size of embedding, attention, dim.
|
||||
num_hidden_layers (int): Encoder, Decoder layers.
|
||||
intermediate_size (int): Size of intermediate layer in the Transformer
|
||||
encoder/decoder cell.
|
||||
hidden_act (str): Activation function used in the Transformer encoder/decoder
|
||||
cell.
|
||||
hidden_dropout_prob (float): The dropout probability for hidden outputs.
|
||||
attention_dropout_prob (float): The dropout probability for Attention module.
|
||||
initializer_range (float): Initialization value of TruncatedNormal.
|
||||
label_smoothing (float): Label smoothing setting.
|
||||
beam_width (int): Beam width for beam search in inferring.
|
||||
length_penalty_weight (float): Penalty for sentence length.
|
||||
max_decode_length (int): Max decode length for inferring.
|
||||
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
|
||||
dataset.
|
||||
init_loss_scale (int): Initialized loss scale.
|
||||
loss_scale_factor (int): Loss scale factor.
|
||||
scale_window (int): Window size of loss scale.
|
||||
lr_scheduler (str): Learning rate scheduler. Please see the Note as follow.
|
||||
optimizer (str): Optimizer for training, e.g. Adam, Lamb, momentum. Default: Adam.
|
||||
lr (float): Initial learning rate.
|
||||
min_lr (float): Minimum learning rate.
|
||||
decay_steps (int): Decay steps.
|
||||
lr_scheduler_power(float): A value used to calculate decayed learning rate.
|
||||
warmup_lr_remain_steps (int or float): Start decay at 'remain_steps' iteration.
|
||||
warmup_lr_decay_interval (int):interval between LR decay steps.
|
||||
decay_start_step (int): Step to decay.
|
||||
warmup_steps (int): Warm up steps.
|
||||
existed_ckpt (str): Using existed checkpoint to keep training or not.
|
||||
save_ckpt_steps (int): Interval of saving ckpt.
|
||||
keep_ckpt_max (int): Max ckpt files number.
|
||||
ckpt_prefix (str): Prefix of ckpt file.
|
||||
ckpt_path (str): Checkpoints save path.
|
||||
save_graphs (bool): Whether to save graphs, please set to True if mindinsight
|
||||
is wanted.
|
||||
dtype (mstype): Data type of the input.
|
||||
|
||||
Note:
|
||||
There are three types of learning rate scheduler, square root scheduler, polynomial
|
||||
decay scheduler and warmup multistep learning rate scheduler.
|
||||
In square root scheduler, the following parameters can be used, lr, decay_start_step,
|
||||
warmup_steps and min_lr.
|
||||
In polynomial decay scheduler, the following parameters can be used, lr, min_lr, decay_steps,
|
||||
warmup_steps, lr_scheduler_power.
|
||||
In warmmup multistep learning rate scheduler, the following parameters can be used, lr, warmup_steps,
|
||||
warmup_lr_remain_steps, warmup_lr_decay_interval, decay_steps, lr_scheduler_power.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
random_seed=50,
|
||||
epochs=6, batch_size=128,
|
||||
pre_train_dataset: str = None,
|
||||
fine_tune_dataset: str = None,
|
||||
test_dataset: str = None,
|
||||
valid_dataset: str = None,
|
||||
dataset_sink_mode=True,
|
||||
seq_length=51, vocab_size=32320, hidden_size=1024,
|
||||
num_hidden_layers=4, intermediate_size=4096,
|
||||
hidden_act="tanh",
|
||||
hidden_dropout_prob=0.2, attention_dropout_prob=0.2,
|
||||
initializer_range=0.1,
|
||||
label_smoothing=0.1,
|
||||
beam_width=2,
|
||||
length_penalty_weight=0.6,
|
||||
max_decode_length=50,
|
||||
input_mask_from_dataset=False,
|
||||
init_loss_scale=65536,
|
||||
loss_scale_factor=2, scale_window=1000,
|
||||
lr_scheduler="WarmupMultiStepLR",
|
||||
optimizer="adam",
|
||||
lr=2e-3, min_lr=1e-6,
|
||||
decay_steps=4, lr_scheduler_power=0.5,
|
||||
warmup_lr_remain_steps=0.666, warmup_lr_decay_interval=-1,
|
||||
decay_start_step=-1, warmup_steps=200,
|
||||
existed_ckpt="", save_ckpt_steps=3452, keep_ckpt_max=6,
|
||||
ckpt_prefix="seq2seq", ckpt_path: str = None,
|
||||
save_graphs=False,
|
||||
dtype=mstype.float32):
|
||||
|
||||
self.save_graphs = save_graphs
|
||||
self.random_seed = random_seed
|
||||
self.pre_train_dataset = get_source_list(pre_train_dataset) # type: List[str]
|
||||
self.fine_tune_dataset = get_source_list(fine_tune_dataset) # type: List[str]
|
||||
self.valid_dataset = get_source_list(valid_dataset) # type: List[str]
|
||||
self.test_dataset = get_source_list(test_dataset) # type: List[str]
|
||||
|
||||
if not isinstance(epochs, int) and epochs < 0:
|
||||
raise ValueError("`epoch` must be type of int.")
|
||||
|
||||
self.epochs = epochs
|
||||
self.dataset_sink_mode = dataset_sink_mode
|
||||
|
||||
self.ckpt_path = ckpt_path
|
||||
self.keep_ckpt_max = keep_ckpt_max
|
||||
self.save_ckpt_steps = save_ckpt_steps
|
||||
self.ckpt_prefix = ckpt_prefix
|
||||
self.existed_ckpt = existed_ckpt
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_dropout_prob = attention_dropout_prob
|
||||
|
||||
self.initializer_range = initializer_range
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
self.beam_width = beam_width
|
||||
self.length_penalty_weight = length_penalty_weight
|
||||
self.max_decode_length = max_decode_length
|
||||
self.input_mask_from_dataset = input_mask_from_dataset
|
||||
self.compute_type = mstype.float16
|
||||
self.dtype = dtype
|
||||
|
||||
self.scale_window = scale_window
|
||||
self.loss_scale_factor = loss_scale_factor
|
||||
self.init_loss_scale = init_loss_scale
|
||||
|
||||
self.optimizer = optimizer
|
||||
self.lr = lr
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.min_lr = min_lr
|
||||
self.lr_scheduler_power = lr_scheduler_power
|
||||
self.warmup_lr_remain_steps = warmup_lr_remain_steps
|
||||
self.warmup_lr_decay_interval = warmup_lr_decay_interval
|
||||
self.decay_steps = decay_steps
|
||||
self.decay_start_step = decay_start_step
|
||||
self.warmup_steps = warmup_steps
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object: dict):
|
||||
"""Constructs a `TransformerConfig` from a Python dictionary of parameters."""
|
||||
_params = {}
|
||||
for node in PARAM_NODES:
|
||||
for key in json_object[node]:
|
||||
_params[key] = json_object[node][key]
|
||||
return cls(**_params)
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `TransformerConfig` from a json file of parameters."""
|
||||
with open(json_file, "r") as reader:
|
||||
return cls.from_dict(json.load(reader))
|
||||
|
||||
def to_dict(self):
|
||||
"""Serializes this instance to a Python dictionary."""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def to_json_string(self):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
|
@ -0,0 +1,50 @@
|
|||
{
|
||||
"dataset_config": {
|
||||
"random_seed": 50,
|
||||
"epochs": 8,
|
||||
"batch_size": 128,
|
||||
"pre_train_dataset": null,
|
||||
"fine_tune_dataset": null,
|
||||
"test_dataset": "dataset_menu/newstest2014.en.mindrecord",
|
||||
"valid_dataset": null,
|
||||
"dataset_sink_mode": true
|
||||
},
|
||||
"model_config": {
|
||||
"seq_length": 86,
|
||||
"vocab_size": 32130,
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 4,
|
||||
"intermediate_size": 4096,
|
||||
"hidden_dropout_prob": 0.2,
|
||||
"attention_dropout_prob": 0.2,
|
||||
"initializer_range": 0.1,
|
||||
"label_smoothing": 0.1,
|
||||
"beam_width": 2,
|
||||
"length_penalty_weight": 0.6,
|
||||
"max_decode_length": 80
|
||||
},
|
||||
"loss_scale_config": {
|
||||
"init_loss_scale": 65536,
|
||||
"loss_scale_factor": 2,
|
||||
"scale_window": 1000
|
||||
},
|
||||
"learn_rate_config": {
|
||||
"optimizer": "adam",
|
||||
"lr": 2e-3,
|
||||
"lr_scheduler": "WarmupMultiStepLR",
|
||||
"lr_scheduler_power": 0.5,
|
||||
"warmup_lr_remain_steps": 0.666,
|
||||
"warmup_lr_decay_interval": -1,
|
||||
"decay_steps": 4,
|
||||
"decay_start_step": -1,
|
||||
"warmup_steps": 200,
|
||||
"min_lr": 1e-6
|
||||
},
|
||||
"checkpoint_options": {
|
||||
"existed_ckpt": " ",
|
||||
"save_ckpt_steps": 3452,
|
||||
"keep_ckpt_max": 6,
|
||||
"ckpt_prefix": "seq2seq",
|
||||
"ckpt_path": "text_translation"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create Dataset."""
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from src.dataset.bi_data_loader import BiLingualDataLoader, TextDataLoader
|
||||
from src.dataset.tokenizer import Tokenizer
|
||||
|
||||
parser = argparse.ArgumentParser(description='Generate dataset file.')
|
||||
parser.add_argument("--src_folder", type=str, required=False,
|
||||
help="Raw corpus folder.")
|
||||
|
||||
parser.add_argument("--output_folder", type=str, required=False,
|
||||
help="Dataset output path.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
args, _ = parser.parse_known_args()
|
||||
if not os.path.exists(args.output_folder):
|
||||
os.makedirs(args.output_folder)
|
||||
dicts = []
|
||||
train_src_file = "train.tok.clean.bpe.32000.en"
|
||||
train_tgt_file = "train.tok.clean.bpe.32000.fr"
|
||||
test_src_file = "newstest2014.en"
|
||||
test_tgt_file = "newstest2014.fr"
|
||||
|
||||
vocab = args.src_folder + "/vocab.bpe.32000"
|
||||
bpe_codes = args.src_folder + "/bpe.32000"
|
||||
pad_vocab = 8
|
||||
tokenizer = Tokenizer(vocab, bpe_codes, src_en='en', tgt_fr='fr', vocab_pad=pad_vocab)
|
||||
|
||||
train = BiLingualDataLoader(
|
||||
src_filepath=os.path.join(args.src_folder, train_src_file),
|
||||
tgt_filepath=os.path.join(args.src_folder, train_tgt_file),
|
||||
tokenizer=tokenizer,
|
||||
source_max_sen_len=51,
|
||||
target_max_sen_len=50,
|
||||
schema_address=args.output_folder + "/" + train_src_file + ".json"
|
||||
)
|
||||
print(f" | It's writing, please wait a moment.")
|
||||
train.write_to_mindrecord(
|
||||
path=os.path.join(
|
||||
args.output_folder,
|
||||
os.path.basename(train_src_file) + ".mindrecord"
|
||||
),
|
||||
train_mode=True
|
||||
)
|
||||
test = TextDataLoader(
|
||||
src_filepath=os.path.join(args.src_folder, test_src_file),
|
||||
tokenizer=tokenizer,
|
||||
source_max_sen_len=None,
|
||||
schema_address=args.output_folder + "/" + test_src_file + ".json"
|
||||
)
|
||||
print(f" | It's writing, please wait a moment.")
|
||||
test.write_to_mindrecord(
|
||||
path=os.path.join(
|
||||
args.output_folder,
|
||||
os.path.basename(test_src_file) + ".mindrecord"
|
||||
),
|
||||
train_mode=False
|
||||
)
|
||||
print(f" | Vocabulary size: {tokenizer.vocab_size}.")
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation api."""
|
||||
import os
|
||||
# os.system("pip3 install subword-nmt")
|
||||
# os.system("pip3 install sacremoses")
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
import moxing as mox
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import context
|
||||
|
||||
from config import Seq2seqConfig
|
||||
from src.seq2seq_model import infer
|
||||
from src.seq2seq_model.bleu_calculate import bleu_calculate
|
||||
from src.dataset.tokenizer import Tokenizer
|
||||
|
||||
is_modelarts = False
|
||||
|
||||
if is_modelarts:
|
||||
parser = argparse.ArgumentParser(description='seq2seq')
|
||||
parser.add_argument("--config", type=str, required=True,
|
||||
help="model config json file path.")
|
||||
parser.add_argument("--data_url", type=str, required=True,
|
||||
help="data address.")
|
||||
parser.add_argument("--train_url", type=str, required=True,
|
||||
help="output address.")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='seq2seq')
|
||||
parser.add_argument("--config", type=str, required=True,
|
||||
help="model config json file path.")
|
||||
parser.add_argument("--test_dataset", type=str, required=True,
|
||||
help="test dataset address.")
|
||||
parser.add_argument("--existed_ckpt", type=str, required=True,
|
||||
help="existed checkpoint address.")
|
||||
parser.add_argument("--vocab", type=str, required=True,
|
||||
help="Vocabulary to use.")
|
||||
parser.add_argument("--bpe_codes", type=str, required=True,
|
||||
help="bpe codes to use.")
|
||||
parser.add_argument("--test_tgt", type=str, required=True,
|
||||
default=None,
|
||||
help="data file of the test target")
|
||||
parser.add_argument("--output", type=str, required=False,
|
||||
default="./output.npz",
|
||||
help="result file path.")
|
||||
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=True,
|
||||
device_target="Ascend",
|
||||
reserve_class_name_in_scope=True)
|
||||
|
||||
def get_config(config):
|
||||
config = Seq2seqConfig.from_json_file(config)
|
||||
config.compute_type = mstype.float16
|
||||
config.dtype = mstype.float32
|
||||
return config
|
||||
|
||||
def _check_args(config):
|
||||
if not os.path.exists(config):
|
||||
raise FileNotFoundError("`config` is not existed.")
|
||||
if not isinstance(config, str):
|
||||
raise ValueError("`config` must be type of str.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args, _ = parser.parse_known_args()
|
||||
_check_args(args.config)
|
||||
_config = get_config(args.config)
|
||||
|
||||
if is_modelarts:
|
||||
mox.file.copy_parallel(src_url=args.data_url, dst_url='/cache/dataset_menu/')
|
||||
_config.test_dataset = '/cache/dataset_menu/newstest2014.en.mindrecord'
|
||||
_config.existed_ckpt = '/cache/dataset_menu/seq2seq-7_1642.ckpt'
|
||||
|
||||
_config.test_dataset = args.test_dataset
|
||||
_config.existed_ckpt = args.existed_ckpt
|
||||
|
||||
result = infer(_config)
|
||||
|
||||
with open(args.output, "wb") as f:
|
||||
pickle.dump(result, f, 1)
|
||||
|
||||
result_npy_addr = args.output
|
||||
vocab = args.vocab
|
||||
bpe_codes = args.bpe_codes
|
||||
test_tgt = args.test_tgt
|
||||
tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'fr')
|
||||
scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
|
||||
print(f"BLEU scores is :{scores}")
|
||||
|
||||
if is_modelarts:
|
||||
result_npy_addr = output
|
||||
vocab = '/cache/dataset_menu/vocab.bpe.32000'
|
||||
bpe_codes = '/cache/dataset_menu/bpe.32000'
|
||||
test_tgt = '/cache/dataset_menu/newstest2014.fr'
|
||||
tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'fr')
|
||||
scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
|
||||
print(f"BLEU scores is :{scores}")
|
||||
mox.file.copy_parallel(src_url='/cache/infer_output/', dst_url=args.train_url)
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air models"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context, Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import export
|
||||
|
||||
from config import Seq2seqConfig
|
||||
from src.seq2seq_model.seq2seq import Seq2seqModel
|
||||
from src.seq2seq_model.seq2seq_for_infer import Seq2seqInferCell
|
||||
from src.utils import zero_weight
|
||||
from src.utils.load_weights import load_infer_weights
|
||||
|
||||
parser = argparse.ArgumentParser(description="seq2seq export")
|
||||
parser.add_argument("--file_name", type=str, default="seq2seq", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument('--infer_config', type=str, required=True, help='seq2seq config file')
|
||||
parser.add_argument("--existed_ckpt", type=str, required=True, help="existed checkpoint address.")
|
||||
parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file')
|
||||
parser.add_argument("--bpe_codes", type=str, required=True, help="bpe codes to use.")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend",
|
||||
reserve_class_name_in_scope=False)
|
||||
|
||||
|
||||
def get_config(config_file):
|
||||
tfm_config = Seq2seqConfig.from_json_file(config_file)
|
||||
tfm_config.compute_type = mstype.float16
|
||||
tfm_config.dtype = mstype.float32
|
||||
return tfm_config
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = get_config(args.infer_config)
|
||||
config.existed_ckpt = args.existed_ckpt
|
||||
vocab = args.vocab_file
|
||||
bpe_codes = args.bpe_codes
|
||||
|
||||
tfm_model = Seq2seqModel(
|
||||
config=config,
|
||||
is_training=False,
|
||||
use_one_hot_embeddings=False)
|
||||
|
||||
params = tfm_model.trainable_params()
|
||||
weights = load_infer_weights(config)
|
||||
|
||||
for param in params:
|
||||
value = param.data
|
||||
weights_name = param.name
|
||||
if weights_name not in weights:
|
||||
raise ValueError(f"{weights_name} is not found in weights.")
|
||||
if isinstance(value, Tensor):
|
||||
if weights_name in weights:
|
||||
assert weights_name in weights
|
||||
if isinstance(weights[weights_name], Parameter):
|
||||
if param.data.dtype == "Float32":
|
||||
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
|
||||
elif param.data.dtype == "Float16":
|
||||
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
|
||||
|
||||
elif isinstance(weights[weights_name], Tensor):
|
||||
param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
|
||||
elif isinstance(weights[weights_name], np.ndarray):
|
||||
param.set_data(Tensor(weights[weights_name], config.dtype))
|
||||
else:
|
||||
param.set_data(weights[weights_name])
|
||||
else:
|
||||
print("weight not found in checkpoint: " + weights_name)
|
||||
param.set_data(zero_weight(value.asnumpy().shape))
|
||||
|
||||
print(" | Load weights successfully.")
|
||||
tfm_infer = Seq2seqInferCell(tfm_model)
|
||||
tfm_infer.set_train(False)
|
||||
|
||||
source_ids = Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32))
|
||||
source_mask = Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32))
|
||||
|
||||
export(tfm_infer, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,4 @@
|
|||
numpy
|
||||
subword-nmt==0.3.7
|
||||
sacrebleu==1.4.14
|
||||
sacremoses==0.0.35
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dataset_filter"""
|
||||
import argparse
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Clean dataset')
|
||||
parser.add_argument('-f1', '--file1', help='file1')
|
||||
parser.add_argument('-f2', '--file2', help='file2')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def save_output(fname, data):
|
||||
with open(fname, 'w') as f:
|
||||
f.writelines(data)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Discards all pairs of sentences which can't be decoded by latin-1 encoder.
|
||||
|
||||
It aims to filter out sentences with rare unicode glyphs and pairs which
|
||||
are most likely not valid English-German sentences.
|
||||
"""
|
||||
args = parse_args()
|
||||
|
||||
c = Counter()
|
||||
skipped = 0
|
||||
valid = 0
|
||||
data1 = []
|
||||
data2 = []
|
||||
|
||||
with open(args.file1) as f1, open(args.file2) as f2:
|
||||
for idx, lines in enumerate(zip(f1, f2)):
|
||||
line1, line2 = lines
|
||||
if idx % 100000 == 1:
|
||||
print('Processed {} lines'.format(idx))
|
||||
try:
|
||||
line1.encode('latin1')
|
||||
line2.encode('latin1')
|
||||
except UnicodeEncodeError:
|
||||
skipped += 1
|
||||
else:
|
||||
data1.append(line1)
|
||||
data2.append(line2)
|
||||
valid += 1
|
||||
c.update(line1)
|
||||
|
||||
ratio = valid / (skipped + valid)
|
||||
print('Skipped: {}, Valid: {}, Valid ratio {}'.format(skipped, valid, ratio))
|
||||
print('Character frequency:', c)
|
||||
|
||||
save_output(args.file1, data1)
|
||||
save_output(args.file2, data2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET"
|
||||
echo "for example:"
|
||||
echo "sh run_distributed_train_ascend.sh \
|
||||
/home/workspace/rank_table_8p.json \
|
||||
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
RANK_TABLE_ADDR=$1
|
||||
PRE_TRAIN_DATASET=$2
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
export RANK_TABLE_FILE=$RANK_TABLE_ADDR
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$RANK_TABLE_ADDR
|
||||
|
||||
echo $RANK_TABLE_FILE
|
||||
export RANK_SIZE=8
|
||||
export GLOG_v=2
|
||||
|
||||
for((i=0;i<=7;i++));
|
||||
do
|
||||
rm -rf ${current_exec_path}/device$i
|
||||
mkdir ${current_exec_path}/device$i
|
||||
cd ${current_exec_path}/device$i || exit
|
||||
cp ../../*.py .
|
||||
cp -r ../../src .
|
||||
cp -r ../../config .
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python ../../train.py \
|
||||
--config=${current_exec_path}/device${i}/config/config.json \
|
||||
--pre_train_dataset=$PRE_TRAIN_DATASET > log_seq2seq_network${i}.log 2>&1 &
|
||||
cd ${current_exec_path} || exit
|
||||
done
|
||||
cd ${current_exec_path} || exit
|
|
@ -0,0 +1,61 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
|
||||
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
|
||||
echo "for example:"
|
||||
echo "sh run_standalone_eval_ascend.sh \
|
||||
/home/workspace/dataset_menu/newstest2014.en.mindrecord \
|
||||
/home/workspace/seq2seq/seq2seq-8_3452.ckpt \
|
||||
/home/workspace/wmt14_fr_en/vocab.bpe.32000 \
|
||||
/home/workspace/wmt14_fr_en/bpe.32000 \
|
||||
/home/workspace/wmt14_fr_en/newstest2014.fr"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
TEST_DATASET=$1
|
||||
EXISTED_CKPT_PATH=$2
|
||||
VOCAB_ADDR=$3
|
||||
BPE_CODE_ADDR=$4
|
||||
TEST_TARGET=$5
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
|
||||
export GLOG_v=2
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../config ./eval
|
||||
cd ./eval || exit
|
||||
echo "start for evaluation"
|
||||
env > env.log
|
||||
python3 eval.py \
|
||||
--config=${current_exec_path}/eval/config/config_test.json \
|
||||
--test_dataset=$TEST_DATASET \
|
||||
--existed_ckpt=$EXISTED_CKPT_PATH \
|
||||
--vocab=$VOCAB_ADDR \
|
||||
--bpe_codes=$BPE_CODE_ADDR \
|
||||
--test_tgt=$TEST_TARGET >log_infer.log 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,46 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET"
|
||||
echo "for example:"
|
||||
echo "sh run_standalone_train_ascend.sh \
|
||||
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
PRE_TRAIN_DATASET=$1
|
||||
|
||||
export GLOG_v=2
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../config ./train
|
||||
cd ./train || exit
|
||||
echo "start for training"
|
||||
env > env.log
|
||||
python train.py \
|
||||
--config=${current_exec_path}/train/config/config.json \
|
||||
--pre_train_dataset=$PRE_TRAIN_DATASET > log_seq2seq_network.log 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,128 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
export LANG=C.UTF-8
|
||||
export LC_ALL=C.UTF-8
|
||||
|
||||
OUTPUT_DIR=${1:-"data_new/wmt14_en_fr"}
|
||||
echo "Writing to ${OUTPUT_DIR}. To change this, set the OUTPUT_DIR environment variable."
|
||||
|
||||
OUTPUT_DIR_DATA="${OUTPUT_DIR}/data"
|
||||
|
||||
mkdir -p $OUTPUT_DIR_DATA
|
||||
|
||||
echo "Downloading Europarl v7. This may take a while..."
|
||||
wget -nc -nv -O ${OUTPUT_DIR_DATA}/europarl-v7-fr-en.tgz \
|
||||
http://www.statmt.org/europarl/v7/fr-en.tgz
|
||||
|
||||
echo "Downloading Common Crawl corpus. This may take a while..."
|
||||
wget -nc -nv -O ${OUTPUT_DIR_DATA}/common-crawl.tgz \
|
||||
http://www.statmt.org/wmt14/training-parallel-commoncrawl.tgz
|
||||
|
||||
echo "Downloading News Commentary v11. This may take a while..."
|
||||
wget -nc -nv -O ${OUTPUT_DIR_DATA}/nc-v9.tgz \
|
||||
http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz
|
||||
|
||||
echo "Downloading test sets"
|
||||
wget -nc -nv -O ${OUTPUT_DIR_DATA}/test.tgz \
|
||||
http://www.statmt.org/wmt14/test-full.tgz
|
||||
|
||||
# Extract everything
|
||||
echo "Extracting all files..."
|
||||
mkdir -p "${OUTPUT_DIR_DATA}/europarl-v7-fr-en"
|
||||
tar -xvzf "${OUTPUT_DIR_DATA}/europarl-v7-fr-en.tgz" -C "${OUTPUT_DIR_DATA}/europarl-v7-fr-en"
|
||||
mkdir -p "${OUTPUT_DIR_DATA}/common-crawl"
|
||||
tar -xvzf "${OUTPUT_DIR_DATA}/common-crawl.tgz" -C "${OUTPUT_DIR_DATA}/common-crawl"
|
||||
mkdir -p "${OUTPUT_DIR_DATA}/nc-v9"
|
||||
tar -xvzf "${OUTPUT_DIR_DATA}/nc-v9.tgz" -C "${OUTPUT_DIR_DATA}/nc-v9"
|
||||
mkdir -p "${OUTPUT_DIR_DATA}/test"
|
||||
tar -xvzf "${OUTPUT_DIR_DATA}/test.tgz" -C "${OUTPUT_DIR_DATA}/test"
|
||||
|
||||
# Concatenate Training data
|
||||
cat "${OUTPUT_DIR_DATA}/europarl-v7-fr-en/europarl-v7.fr-en.en" \
|
||||
"${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.fr-en.en" \
|
||||
"${OUTPUT_DIR_DATA}/nc-v9/training-parallel-nc-v9/news-commentary-v9.fr-en.en" \
|
||||
> "${OUTPUT_DIR}/train.en"
|
||||
wc -l "${OUTPUT_DIR}/train.en"
|
||||
|
||||
cat "${OUTPUT_DIR_DATA}/europarl-v7-de-en/europarl-v7.fr-en.fr" \
|
||||
"${OUTPUT_DIR_DATA}/common-crawl/commoncrawl.fr-en.fr" \
|
||||
"${OUTPUT_DIR_DATA}/nc-v9/training-parallel-nc-v9/news-commentary-v9.fr-en.fr" \
|
||||
> "${OUTPUT_DIR}/train.fr"
|
||||
wc -l "${OUTPUT_DIR}/train.fr"
|
||||
|
||||
# Clone Moses
|
||||
if [ ! -d "${OUTPUT_DIR}/mosesdecoder" ]; then
|
||||
echo "Cloning moses for data processing"
|
||||
git clone https://github.com/moses-smt/mosesdecoder.git "${OUTPUT_DIR}/mosesdecoder"
|
||||
cd ${OUTPUT_DIR}/mosesdecoder
|
||||
git reset --hard 8c5eaa1a122236bbf927bde4ec610906fea599e6
|
||||
cd -
|
||||
fi
|
||||
|
||||
# Convert newstest2014 data into raw text format
|
||||
${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
|
||||
< ${OUTPUT_DIR_DATA}/test/test/newstest2014-fren-src.fr.sgm \
|
||||
> ${OUTPUT_DIR_DATA}/test/test/newstest2014.fr
|
||||
${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
|
||||
< ${OUTPUT_DIR_DATA}/test/test/newstest2014-fren-ref.en.sgm \
|
||||
> ${OUTPUT_DIR_DATA}/test/test/newstest2014.en
|
||||
|
||||
cp ${OUTPUT_DIR_DATA}/test/test/newstest2014.fr ${OUTPUT_DIR}
|
||||
cp ${OUTPUT_DIR_DATA}/test/test/newstest2014.en ${OUTPUT_DIR}
|
||||
|
||||
# Tokenize data
|
||||
for f in ${OUTPUT_DIR}/*.fr; do
|
||||
echo "Tokenizing $f..."
|
||||
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l fr -threads 8 < $f > ${f%.*}.tok.fr
|
||||
done
|
||||
|
||||
for f in ${OUTPUT_DIR}/*.en; do
|
||||
echo "Tokenizing $f..."
|
||||
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l en -threads 8 < $f > ${f%.*}.tok.en
|
||||
done
|
||||
|
||||
# Clean all corpora
|
||||
for f in ${OUTPUT_DIR}/*.en; do
|
||||
fbase=${f%.*}
|
||||
echo "Cleaning ${fbase}..."
|
||||
${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $fbase fr en "${fbase}.clean" 1 80
|
||||
done
|
||||
|
||||
# Filter datasets
|
||||
python filter_dataset.py \
|
||||
-f1 ${OUTPUT_DIR}/train.tok.clean.en \
|
||||
-f2 ${OUTPUT_DIR}/train.tok.clean.fr
|
||||
|
||||
# Learn Shared BPE
|
||||
merge_ops=32000
|
||||
echo "Learning BPE with merge_ops=${merge_ops}. This may take a while..."
|
||||
cat "${OUTPUT_DIR}/train.tok.clean.fr" "${OUTPUT_DIR}/train.tok.clean.en" | \
|
||||
subword-nmt learn-bpe -s $merge_ops > "${OUTPUT_DIR}/bpe.${merge_ops}"
|
||||
|
||||
echo "Apply BPE with merge_ops=${merge_ops} to tokenized files..."
|
||||
for lang in en fr; do
|
||||
for f in ${OUTPUT_DIR}/*.tok.${lang} ${OUTPUT_DIR}/*.tok.clean.${lang}; do
|
||||
outfile="${f%.*}.bpe.${merge_ops}.${lang}"
|
||||
subword-nmt apply-bpe -c "${OUTPUT_DIR}/bpe.${merge_ops}" < $f > "${outfile}"
|
||||
echo ${outfile}
|
||||
done
|
||||
done
|
||||
# Create vocabulary file for BPE
|
||||
cat "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.en" "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.fr" | \
|
||||
subword-nmt get-vocab | cut -f1 -d ' ' > "${OUTPUT_DIR}/vocab.bpe.${merge_ops}"
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset Init."""
|
||||
from .bi_data_loader import BiLingualDataLoader, TextDataLoader
|
||||
from .load_dataset import load_dataset
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
__all__ = [
|
||||
"load_dataset",
|
||||
"BiLingualDataLoader",
|
||||
"TextDataLoader",
|
||||
"Tokenizer"
|
||||
]
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Base class of data loader."""
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from .schema import SCHEMA, TEST_SCHEMA
|
||||
|
||||
|
||||
class DataLoader:
|
||||
"""Data loader for dataset."""
|
||||
_SCHEMA = SCHEMA
|
||||
_TEST_SCHEMA = TEST_SCHEMA
|
||||
|
||||
def __init__(self):
|
||||
self._examples = []
|
||||
|
||||
def _load(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def padding(self, sen, padding_idx, need_sentence_len=None, dtype=np.int64):
|
||||
"""Padding <pad> to sentence."""
|
||||
if need_sentence_len is None:
|
||||
return None
|
||||
if sen.shape[0] > need_sentence_len:
|
||||
return None
|
||||
new_sen = np.array([padding_idx] * need_sentence_len, dtype=dtype)
|
||||
new_sen[:sen.shape[0]] = sen[:]
|
||||
return new_sen
|
||||
|
||||
def write_to_mindrecord(self, path, train_mode, shard_num=1, desc="seq2seq"):
|
||||
"""
|
||||
Write mindrecord file.
|
||||
|
||||
Args:
|
||||
path (str): File path.
|
||||
shard_num (int): Shard num.
|
||||
desc (str): Description.
|
||||
"""
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.abspath(path)
|
||||
|
||||
writer = FileWriter(file_name=path, shard_num=shard_num)
|
||||
if train_mode:
|
||||
writer.add_schema(self._SCHEMA, desc)
|
||||
else:
|
||||
writer.add_schema(self._TEST_SCHEMA, desc)
|
||||
if not self._examples:
|
||||
self._load()
|
||||
|
||||
writer.write_raw_data(self._examples)
|
||||
writer.commit()
|
||||
print(f"| Wrote to {path}.")
|
||||
|
||||
def _add_example(self, example):
|
||||
self._examples.append(example)
|
|
@ -0,0 +1,212 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bilingual data loader."""
|
||||
import numpy as np
|
||||
|
||||
from .base import DataLoader
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
|
||||
class BiLingualDataLoader(DataLoader):
|
||||
"""Loader for bilingual data."""
|
||||
|
||||
def __init__(self,
|
||||
src_filepath: str,
|
||||
tgt_filepath: str,
|
||||
tokenizer: Tokenizer,
|
||||
min_sen_len=0,
|
||||
source_max_sen_len=None,
|
||||
target_max_sen_len=80,
|
||||
schema_address=None):
|
||||
super(BiLingualDataLoader, self).__init__()
|
||||
self._src_filepath = src_filepath
|
||||
self._tgt_filepath = tgt_filepath
|
||||
self.tokenizer = tokenizer
|
||||
self.min_sen_len = min_sen_len
|
||||
self.source_max_sen_len = source_max_sen_len
|
||||
self.target_max_sen_len = target_max_sen_len
|
||||
self.schema_address = schema_address
|
||||
|
||||
def src_max_sen_isnll(self):
|
||||
"""source max sentence is null"""
|
||||
if self.source_max_sen_len is None:
|
||||
with open(self._src_filepath, "r") as _src_file:
|
||||
print(f" | count the max_sen_len of corpus {self._src_filepath}.")
|
||||
max_src = 0
|
||||
for _, _pair in enumerate(_src_file):
|
||||
src_tokens = [int(self.tokenizer.tok2idx[t]) for t in _pair.strip().split(" ") if t]
|
||||
src_len = len(src_tokens)
|
||||
if src_len > max_src:
|
||||
max_src = src_len
|
||||
self.source_max_sen_len = max_src + 2
|
||||
|
||||
def tgt_max_sen_isnll(self):
|
||||
"""target max sentence is null"""
|
||||
if self.target_max_sen_len is None:
|
||||
with open(self._src_filepath, "r") as _tgt_file:
|
||||
print(f" | count the max_sen_len of corpus {self._src_filepath}.")
|
||||
max_tgt = 0
|
||||
for _, _pair in enumerate(_tgt_file):
|
||||
src_tokens = [int(self.tokenizer.tok2idx[t]) for t in _pair.strip().split(" ") if t]
|
||||
tgt_len = len(src_tokens)
|
||||
if tgt_len > max_tgt:
|
||||
max_tgt = tgt_len
|
||||
self.target_max_sen_len = max_tgt + 1
|
||||
|
||||
def write_schema(self, count):
|
||||
"""write schema"""
|
||||
if self.schema_address is not None:
|
||||
provlist = [count, self.source_max_sen_len, self.source_max_sen_len,
|
||||
self.target_max_sen_len, self.target_max_sen_len, self.target_max_sen_len]
|
||||
columns = ["src", "src_padding", "prev_opt", "target", "tgt_padding"]
|
||||
with open(self.schema_address, "w", encoding="utf-8") as f:
|
||||
f.write("{\n")
|
||||
f.write(' "datasetType":"MS",\n')
|
||||
f.write(' "numRows":%s,\n' % provlist[0])
|
||||
f.write(' "columns":{\n')
|
||||
t = 1
|
||||
for name in columns:
|
||||
f.write(' "%s":{\n' % name)
|
||||
f.write(' "type":"int64",\n')
|
||||
f.write(' "rank":1,\n')
|
||||
f.write(' "shape":[%s]\n' % provlist[t])
|
||||
f.write(' }')
|
||||
if t < len(columns):
|
||||
f.write(',')
|
||||
f.write('\n')
|
||||
t += 1
|
||||
f.write(' }\n}\n')
|
||||
print(" | Write to " + self.schema_address)
|
||||
|
||||
def _load(self):
|
||||
count = 0
|
||||
self.src_max_sen_isnll()
|
||||
self.tgt_max_sen_isnll()
|
||||
|
||||
with open(self._src_filepath, "r") as _src_file:
|
||||
print("--Processing corpus--")
|
||||
with open(self._tgt_filepath, "r") as _tgt_file:
|
||||
for _, _pair in enumerate(zip(_src_file, _tgt_file)):
|
||||
src_tokens = [int(self.tokenizer.tok2idx[t]) for t in _pair[0].strip().split(" ") if t]
|
||||
tgt_tokens = [int(self.tokenizer.tok2idx[t]) for t in _pair[1].strip().split(" ") if t]
|
||||
src_tokens.insert(0, self.tokenizer.bos_index)
|
||||
src_tokens.append(self.tokenizer.eos_index)
|
||||
tgt_tokens.insert(0, self.tokenizer.bos_index)
|
||||
tgt_tokens.append(self.tokenizer.eos_index)
|
||||
src_tokens = np.array(src_tokens)
|
||||
tgt_tokens = np.array(tgt_tokens)
|
||||
src_len = src_tokens.shape[0]
|
||||
tgt_len = tgt_tokens.shape[0]
|
||||
|
||||
if (src_len > self.source_max_sen_len) or (src_len < self.min_sen_len) or (
|
||||
tgt_len > (self.target_max_sen_len + 1)) or (tgt_len < self.min_sen_len):
|
||||
print(f"+++++ delete! src_len={src_len}, tgt_len={tgt_len - 1}")
|
||||
continue
|
||||
# encoder inputs
|
||||
encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)
|
||||
src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
|
||||
for i in range(src_len):
|
||||
src_padding[i] = 1
|
||||
# decoder inputs
|
||||
decoder_input = self.padding(tgt_tokens[:-1], self.tokenizer.padding_index, self.target_max_sen_len)
|
||||
# decoder outputs
|
||||
decoder_output = self.padding(tgt_tokens[1:], self.tokenizer.padding_index, self.target_max_sen_len)
|
||||
tgt_padding = np.zeros(shape=self.target_max_sen_len + 1, dtype=np.int64)
|
||||
for j in range(tgt_len):
|
||||
tgt_padding[j] = 1
|
||||
tgt_padding = tgt_padding[1:]
|
||||
decoder_input = np.array(decoder_input, dtype=np.int64)
|
||||
decoder_output = np.array(decoder_output, dtype=np.int64)
|
||||
tgt_padding = np.array(tgt_padding, dtype=np.int64)
|
||||
example = {"src": encoder_input, "src_padding": src_padding, "prev_opt": decoder_input,
|
||||
"target": decoder_output, "tgt_padding": tgt_padding}
|
||||
self._add_example(example)
|
||||
count += 1
|
||||
self.write_schema(count)
|
||||
|
||||
|
||||
class TextDataLoader(DataLoader):
|
||||
"""Loader for text data."""
|
||||
|
||||
def __init__(self,
|
||||
src_filepath: str,
|
||||
tokenizer: Tokenizer,
|
||||
min_sen_len=0,
|
||||
source_max_sen_len=None,
|
||||
schema_address=None):
|
||||
super(TextDataLoader, self).__init__()
|
||||
self._src_filepath = src_filepath
|
||||
self.tokenizer = tokenizer
|
||||
self.min_sen_len = min_sen_len
|
||||
self.source_max_sen_len = source_max_sen_len
|
||||
self.schema_address = schema_address
|
||||
|
||||
def _load(self):
|
||||
count = 0
|
||||
if self.source_max_sen_len is None:
|
||||
with open(self._src_filepath, "r") as _src_file:
|
||||
print(f" | count the max_sen_len of corpus {self._src_filepath}.")
|
||||
max_src = 0
|
||||
for _, _pair in enumerate(_src_file):
|
||||
src_tokens = self.tokenizer.tokenize(_pair)
|
||||
src_len = len(src_tokens)
|
||||
if src_len > max_src:
|
||||
max_src = src_len
|
||||
self.source_max_sen_len = max_src
|
||||
|
||||
with open(self._src_filepath, "r") as _src_file:
|
||||
print(f" | Processing corpus {self._src_filepath}.")
|
||||
for _, _pair in enumerate(_src_file):
|
||||
src_tokens = self.tokenizer.tokenize(_pair)
|
||||
src_len = len(src_tokens)
|
||||
src_tokens = np.array(src_tokens)
|
||||
# encoder inputs
|
||||
encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)
|
||||
src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
|
||||
for i in range(src_len):
|
||||
src_padding[i] = 1
|
||||
|
||||
example = {
|
||||
"src": encoder_input,
|
||||
"src_padding": src_padding
|
||||
}
|
||||
self._add_example(example)
|
||||
count += 1
|
||||
|
||||
print(f" | source padding_len = {self.source_max_sen_len}.")
|
||||
print(f" | Total activate sen = {count}.")
|
||||
print(f" | Total sen = {count}.")
|
||||
|
||||
if self.schema_address is not None:
|
||||
provlist = [count, self.source_max_sen_len, self.source_max_sen_len]
|
||||
columns = ["src", "src_padding"]
|
||||
with open(self.schema_address, "w", encoding="utf-8") as f:
|
||||
f.write("{\n")
|
||||
f.write(' "datasetType":"MS",\n')
|
||||
f.write(' "numRows":%s,\n' % provlist[0])
|
||||
f.write(' "columns":{\n')
|
||||
t = 1
|
||||
for name in columns:
|
||||
f.write(' "%s":{\n' % name)
|
||||
f.write(' "type":"int64",\n')
|
||||
f.write(' "rank":1,\n')
|
||||
f.write(' "shape":[%s]\n' % provlist[t])
|
||||
f.write(' }')
|
||||
if t < len(columns):
|
||||
f.write(',')
|
||||
f.write('\n')
|
||||
t += 1
|
||||
f.write(' }\n}\n')
|
||||
print(" | Write to " + self.schema_address)
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset loader to feed into model."""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
|
||||
|
||||
def _load_dataset(input_files, batch_size, sink_mode=False,
|
||||
rank_size=1, rank_id=0, shuffle=True, drop_remainder=True,
|
||||
is_translate=False):
|
||||
"""
|
||||
Load dataset according to passed in params.
|
||||
|
||||
Args:
|
||||
input_files (list): Data files.
|
||||
batch_size (int): Batch size.
|
||||
sink_mode (bool): Whether enable sink mode.
|
||||
rank_size (int): Rank size.
|
||||
rank_id (int): Rank id.
|
||||
shuffle (bool): Whether shuffle dataset.
|
||||
drop_remainder (bool): Whether drop the last possibly incomplete batch.
|
||||
is_translate (bool): Whether translate the text.
|
||||
|
||||
Returns:
|
||||
Dataset, dataset instance.
|
||||
"""
|
||||
if not input_files:
|
||||
raise FileNotFoundError("Require at least one dataset.")
|
||||
|
||||
if not isinstance(sink_mode, bool):
|
||||
raise ValueError("`sink` must be type of bool.")
|
||||
|
||||
for datafile in input_files:
|
||||
print(f" | Loading {datafile}.")
|
||||
|
||||
if not is_translate:
|
||||
data_set = ds.MindDataset(
|
||||
input_files, columns_list=[
|
||||
"src", "src_padding",
|
||||
"prev_opt",
|
||||
"target", "tgt_padding"
|
||||
], shuffle=False, num_shards=rank_size, shard_id=rank_id,
|
||||
num_parallel_workers=8
|
||||
)
|
||||
|
||||
ori_dataset_size = data_set.get_dataset_size()
|
||||
print(f" | Dataset size: {ori_dataset_size}.")
|
||||
if shuffle:
|
||||
data_set = data_set.shuffle(buffer_size=ori_dataset_size // 20)
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
data_set = data_set.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
|
||||
data_set = data_set.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
|
||||
data_set = data_set.map(input_columns="prev_opt", operations=type_cast_op, num_parallel_workers=8)
|
||||
data_set = data_set.map(input_columns="target", operations=type_cast_op, num_parallel_workers=8)
|
||||
data_set = data_set.map(input_columns="tgt_padding", operations=type_cast_op, num_parallel_workers=8)
|
||||
|
||||
data_set = data_set.rename(
|
||||
input_columns=["src",
|
||||
"src_padding",
|
||||
"prev_opt",
|
||||
"target",
|
||||
"tgt_padding"],
|
||||
output_columns=["source_eos_ids",
|
||||
"source_eos_mask",
|
||||
"target_sos_ids",
|
||||
"target_eos_ids",
|
||||
"target_eos_mask"]
|
||||
)
|
||||
data_set = data_set.batch(batch_size, drop_remainder=drop_remainder)
|
||||
else:
|
||||
data_set = ds.MindDataset(
|
||||
input_files, columns_list=[
|
||||
"src", "src_padding"
|
||||
],
|
||||
shuffle=False, num_shards=rank_size, shard_id=rank_id,
|
||||
num_parallel_workers=8
|
||||
)
|
||||
|
||||
ori_dataset_size = data_set.get_dataset_size()
|
||||
print(f" | Dataset size: {ori_dataset_size}.")
|
||||
if shuffle:
|
||||
data_set = data_set.shuffle(buffer_size=ori_dataset_size // 20)
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
data_set = data_set.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
|
||||
data_set = data_set.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
|
||||
|
||||
data_set = data_set.rename(
|
||||
input_columns=["src",
|
||||
"src_padding"],
|
||||
output_columns=["source_eos_ids",
|
||||
"source_eos_mask"]
|
||||
)
|
||||
data_set = data_set.batch(batch_size, drop_remainder=drop_remainder)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
def load_dataset(data_files: list, batch_size: int, sink_mode: bool,
|
||||
rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False):
|
||||
"""
|
||||
Load dataset.
|
||||
|
||||
Args:
|
||||
data_files (list): Data files.
|
||||
batch_size (int): Batch size.
|
||||
sink_mode (bool): Whether enable sink mode.
|
||||
rank_size (int): Rank size.
|
||||
rank_id (int): Rank id.
|
||||
shuffle (bool): Whether shuffle dataset.
|
||||
|
||||
Returns:
|
||||
Dataset, dataset instance.
|
||||
"""
|
||||
return _load_dataset(data_files, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle,
|
||||
drop_remainder=drop_remainder, is_translate=is_translate)
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define schema of mindrecord."""
|
||||
|
||||
SCHEMA = {
|
||||
"src": {"type": "int64", "shape": [-1]},
|
||||
"src_padding": {"type": "int64", "shape": [-1]},
|
||||
"prev_opt": {"type": "int64", "shape": [-1]},
|
||||
"target": {"type": "int64", "shape": [-1]},
|
||||
"tgt_padding": {"type": "int64", "shape": [-1]},
|
||||
}
|
||||
|
||||
TEST_SCHEMA = {
|
||||
"src": {"type": "int64", "shape": [-1]},
|
||||
"src_padding": {"type": "int64", "shape": [-1]},
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Tokenizer."""
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
import subword_nmt.apply_bpe
|
||||
import sacremoses
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Constructor for the Tokenizer class.
|
||||
|
||||
Args:
|
||||
vocab_address: vocabulary address.
|
||||
bpe_code_address: path to the file with bpe codes.
|
||||
vocab_pad: pads vocabulary to a multiple of 'vocab_pad' tokens.
|
||||
isolator: tokenization isolator.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_address=None, bpe_code_address=None,
|
||||
src_en='en', tgt_fr='fr', vocab_pad=8, isolator='@@'):
|
||||
self.padding_index = 0
|
||||
self.unk_index = 1
|
||||
self.bos_index = 2
|
||||
self.eos_index = 3
|
||||
self.pad_word = '<pad>'
|
||||
self.unk_word = '<unk>'
|
||||
self.bos_word = '<s>'
|
||||
self.eos_word = r'<\s>'
|
||||
self.isolator = isolator
|
||||
self.init_bpe(bpe_code_address)
|
||||
self.vocab_establist(vocab_address, vocab_pad)
|
||||
# TODO src = english, tgt = french
|
||||
self.sacremoses_tokenizer = sacremoses.MosesTokenizer(src_en)
|
||||
self.sacremoses_detokenizer = sacremoses.MosesDetokenizer(tgt_fr)
|
||||
|
||||
def init_bpe(self, bpe_code_address):
|
||||
"""Init bpe."""
|
||||
if (bpe_code_address is not None) and os.path.exists(bpe_code_address):
|
||||
with open(bpe_code_address, 'r') as f1:
|
||||
self.bpe = subword_nmt.apply_bpe.BPE(f1)
|
||||
|
||||
def vocab_establist(self, vocab_address, vocab_pad):
|
||||
"""Establish vocabulary."""
|
||||
if (vocab_address is None) or (not os.path.exists(vocab_address)):
|
||||
return
|
||||
vocab_words = [self.pad_word, self.unk_word, self.bos_word, self.eos_word]
|
||||
with open(vocab_address) as f1:
|
||||
for sentence in f1:
|
||||
vocab_words.append(sentence.strip())
|
||||
vocab_size = len(vocab_words)
|
||||
padded_vocab_size = (vocab_size + vocab_pad - 1) // vocab_pad * vocab_pad
|
||||
for idx in range(0, padded_vocab_size - vocab_size):
|
||||
fil_token = f'filled{idx:04d}'
|
||||
vocab_words.append(fil_token)
|
||||
self.vocab_size = len(vocab_words)
|
||||
self.tok2idx = defaultdict(partial(int, self.unk_index))
|
||||
for idx, token in enumerate(vocab_words):
|
||||
self.tok2idx[token] = idx
|
||||
self.idx2tok = {}
|
||||
self.idx2tok = defaultdict(partial(str, ","))
|
||||
for token, idx in self.tok2idx.items():
|
||||
self.idx2tok[idx] = token
|
||||
|
||||
def tokenize(self, sentence):
|
||||
"""Tokenize sentence."""
|
||||
tokenized = self.sacremoses_tokenizer.tokenize(sentence, return_str=True)
|
||||
bpe = self.bpe.process_line(tokenized)
|
||||
sentence = bpe.strip().split()
|
||||
inputs = [self.tok2idx[i] for i in sentence]
|
||||
inputs = [self.bos_index] + inputs + [self.eos_index]
|
||||
return inputs
|
||||
|
||||
def detokenize(self, indexes, gap=' '):
|
||||
"""Detokenizes single sentence and removes token isolator characters."""
|
||||
reconstruction_bpe = gap.join([self.idx2tok[idx] for idx in indexes])
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.isolator + ' ', '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.isolator, '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.bos_word, '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.eos_word, '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.unk_word, '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.pad_word, '')
|
||||
reconstruction_bpe = reconstruction_bpe.strip()
|
||||
reconstruction_words = self.sacremoses_detokenizer.detokenize(reconstruction_bpe.split())
|
||||
return reconstruction_words
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Seq2seq Init."""
|
||||
from config.config import Seq2seqConfig
|
||||
from .seq2seq import Seq2seqModel
|
||||
from .seq2seq_for_train import Seq2seqTraining, LabelSmoothedCrossEntropyCriterion, \
|
||||
Seq2seqNetworkWithLoss, Seq2seqTrainOneStepWithLossScaleCell
|
||||
from .seq2seq_for_infer import infer
|
||||
from .bleu_calculate import bleu_calculate
|
||||
|
||||
__all__ = [
|
||||
"infer",
|
||||
"Seq2seqTraining",
|
||||
"LabelSmoothedCrossEntropyCriterion",
|
||||
"Seq2seqTrainOneStepWithLossScaleCell",
|
||||
"Seq2seqNetworkWithLoss",
|
||||
"Seq2seqModel",
|
||||
"Seq2seqConfig",
|
||||
"bleu_calculate"
|
||||
]
|
|
@ -0,0 +1,375 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Beam search decoder."""
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
INF = 65536.0
|
||||
|
||||
|
||||
class LengthPenalty(nn.Cell):
|
||||
"""
|
||||
Length penalty.
|
||||
|
||||
Args:
|
||||
weight (float): The length penalty weight.
|
||||
compute_type (mstype): Mindspore data type. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self, weight=1.0, compute_type=mstype.float32):
|
||||
super(LengthPenalty, self).__init__()
|
||||
self.weight = weight
|
||||
self.add = P.TensorAdd()
|
||||
self.pow = P.Pow()
|
||||
self.div = P.RealDiv()
|
||||
self.five = Tensor(5.0, mstype.float32)
|
||||
self.six = Tensor(6.0, mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, length_tensor):
|
||||
"""
|
||||
Process source sentence
|
||||
|
||||
Inputs:
|
||||
length_tensor (Tensor): the input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, after punishment of length.
|
||||
"""
|
||||
length_tensor = self.cast(length_tensor, mstype.float32)
|
||||
output = self.add(length_tensor, self.five)
|
||||
output = self.div(output, self.six)
|
||||
output = self.pow(output, self.weight)
|
||||
return output
|
||||
|
||||
class TileBeam(nn.Cell):
|
||||
"""
|
||||
Beam Tile operation.
|
||||
|
||||
Args:
|
||||
beam_width (int): The Number of beam.
|
||||
compute_type (mstype): Mindspore data type. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self, beam_width, compute_type=mstype.float32):
|
||||
super(TileBeam, self).__init__()
|
||||
self.beam_width = beam_width
|
||||
|
||||
self.expand = P.ExpandDims()
|
||||
self.tile = P.Tile()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
"""
|
||||
Process source sentence
|
||||
|
||||
Inputs:
|
||||
input_tensor (Tensor): with shape (N, T, D).
|
||||
|
||||
Returns:
|
||||
Tensor, tiled tensor.
|
||||
"""
|
||||
shape = self.shape(input_tensor)
|
||||
# add an dim
|
||||
input_tensor = self.expand(input_tensor, 1)
|
||||
# get tile shape: [1, beam, ...]
|
||||
# shape = self.shape(input_tensor)
|
||||
tile_shape = (1,) + (self.beam_width,)
|
||||
for _ in range(len(shape) - 1):
|
||||
tile_shape = tile_shape + (1,)
|
||||
# tile
|
||||
output = self.tile(input_tensor, tile_shape)
|
||||
# reshape to [batch*beam, ...]
|
||||
out_shape = (shape[0] * self.beam_width,) + shape[1:]
|
||||
output = self.reshape(output, out_shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Mod(nn.Cell):
|
||||
"""
|
||||
Mod operation.
|
||||
|
||||
Args:
|
||||
compute_type (mstype): Mindspore data type. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
compute_type=mstype.float32):
|
||||
super(Mod, self).__init__()
|
||||
self.compute_type = compute_type
|
||||
|
||||
self.floor_div = P.FloorDiv()
|
||||
self.sub = P.Sub()
|
||||
self.multiply = P.Mul()
|
||||
|
||||
def construct(self, input_x, input_y):
|
||||
"""
|
||||
Get the remainder of input_x and input_y.
|
||||
|
||||
Inputs:
|
||||
input_x (Tensor): Divisor.
|
||||
input_y (Tensor): Dividend.
|
||||
|
||||
Returns:
|
||||
Tensor, remainder.
|
||||
"""
|
||||
x = self.floor_div(input_x, input_y)
|
||||
x = self.multiply(x, input_y)
|
||||
x = self.sub(input_x, x)
|
||||
return x
|
||||
|
||||
|
||||
class BeamSearchDecoder(nn.Cell):
|
||||
"""
|
||||
Beam search decoder.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input dataset.
|
||||
seq_length (int): Length of input sequence.
|
||||
vocab_size (int): The shape of each embedding vector.
|
||||
decoder (Cell): The GNMT decoder.
|
||||
beam_width (int): Beam width for beam search in inferring. Default: 4.
|
||||
decoder_layers_nums (int): The nums of decoder layers.
|
||||
length_penalty_weight (float): Penalty for sentence length. Default: 0.6.
|
||||
max_decode_length (int): Max decode length for inferring. Default: 64.
|
||||
sos_id (int): The index of start label <SOS>. Default: 1.
|
||||
eos_id (int): The index of end label <EOS>. Default: 2.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32.
|
||||
|
||||
Returns:
|
||||
Tensor, predictions output.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
seq_length,
|
||||
vocab_size,
|
||||
decoder,
|
||||
beam_width=4,
|
||||
decoder_layers_nums=4,
|
||||
length_penalty_weight=0.6,
|
||||
hidden_size=1024,
|
||||
max_decode_length=100,
|
||||
sos_id=2,
|
||||
eos_id=3,
|
||||
is_using_while=True,
|
||||
compute_type=mstype.float32):
|
||||
super(BeamSearchDecoder, self).__init__()
|
||||
|
||||
self.encoder_length = seq_length
|
||||
self.hidden_size = hidden_size
|
||||
self.batch_size = batch_size
|
||||
self.vocab_size = vocab_size
|
||||
self.beam_width = beam_width
|
||||
self.decoder_layers_nums = decoder_layers_nums
|
||||
self.max_decode_length = max_decode_length
|
||||
self.decoder = decoder
|
||||
self.is_using_while = is_using_while
|
||||
|
||||
self.add = P.TensorAdd()
|
||||
self.expand = P.ExpandDims()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape_flat = (-1,)
|
||||
self.shape = P.Shape()
|
||||
|
||||
self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), mstype.float32)
|
||||
self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), mstype.float32)
|
||||
|
||||
self.select = P.Select()
|
||||
self.flat_shape = (batch_size, beam_width * vocab_size)
|
||||
self.topk = P.TopK(sorted=True)
|
||||
self.floor_div = P.FloorDiv()
|
||||
self.vocab_size_tensor = Tensor(self.vocab_size, mstype.int32)
|
||||
self.mod = Mod()
|
||||
self.equal = P.Equal()
|
||||
self.real_div = P.RealDiv()
|
||||
self.length_penalty = LengthPenalty(weight=length_penalty_weight)
|
||||
self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), mstype.int32)
|
||||
|
||||
beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1])
|
||||
self.beam_ids = Tensor(beam_ids, mstype.int32)
|
||||
|
||||
batch_ids = np.arange(batch_size * beam_width).reshape((batch_size, beam_width)) // beam_width
|
||||
self.batch_ids = Tensor(batch_ids, mstype.int32)
|
||||
|
||||
self.concat = P.Concat(axis=-1)
|
||||
self.gather_nd = P.GatherNd()
|
||||
|
||||
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
|
||||
if self.is_using_while:
|
||||
self.start = Tensor(0, dtype=mstype.int32)
|
||||
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length],
|
||||
sos_id), mstype.int32)
|
||||
else:
|
||||
self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
|
||||
|
||||
init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
|
||||
self.init_scores = Tensor(init_scores, mstype.float32)
|
||||
self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=np.bool))
|
||||
self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.zeroslike = P.ZerosLike()
|
||||
self.greater_equal = P.GreaterEqual()
|
||||
self.sub = P.Sub()
|
||||
self.state_concat = P.Concat(axis=0)
|
||||
|
||||
def one_step(self, cur_input_ids, state_log_probs, state_seq, state_length,
|
||||
idx, decoder_hidden_state, state_finished):
|
||||
"""
|
||||
Beam search one_step output.
|
||||
|
||||
Inputs:
|
||||
cur_input_ids (Tensor): with shape (batch_size * beam_width, 1).
|
||||
state_log_probs (Tensor): with shape (batch_size, beam_width).
|
||||
state_seq (Tensor): with shape (batch_size, beam_width, m).
|
||||
state_length (Tensor): with shape (batch_size, beam_width).
|
||||
idx (Tensor): with shape ().
|
||||
decoder_hidden_state (Tensor): with shape (decoder_layer_num, 2, batch_size * beam_width, D).
|
||||
state_finished (Tensor): with shape (batch_size, beam_width).
|
||||
"""
|
||||
|
||||
# log_probs, [batch_size * beam_width, 1, V]
|
||||
log_probs, all_decoder_state = self.decoder(cur_input_ids, decoder_hidden_state)
|
||||
# log_probs: [batch_size, beam_width, V]
|
||||
log_probs = self.reshape(log_probs, (-1, self.beam_width, self.vocab_size))
|
||||
# select topk indices, [batch_size, beam_width, V]
|
||||
total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1))
|
||||
# mask finished beams, [batch_size, beam_width]
|
||||
# t-1 has finished
|
||||
mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor)
|
||||
# save the t-1 probability
|
||||
total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1))
|
||||
# [batch, beam*vocab]
|
||||
flat_scores = self.reshape(total_log_probs, (-1, self.beam_width * self.vocab_size))
|
||||
# select topk, [batch, beam]
|
||||
topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)
|
||||
|
||||
# convert to beam and word indices, [batch, beam]
|
||||
temp = topk_indices
|
||||
beam_indices = self.zeroslike(topk_indices)
|
||||
for _ in range(self.beam_width - 1):
|
||||
temp = self.sub(temp, self.vocab_size_tensor)
|
||||
res = self.cast(self.greater_equal(temp, 0), mstype.int32)
|
||||
beam_indices = beam_indices + res
|
||||
word_indices = topk_indices - beam_indices * self.vocab_size_tensor
|
||||
|
||||
# mask finished indices, [batch, beam]
|
||||
beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
|
||||
word_indices = self.select(state_finished, self.eos_ids, word_indices)
|
||||
topk_scores = self.select(state_finished, state_log_probs, topk_scores)
|
||||
|
||||
# sort according to scores with -inf for finished beams, [batch, beam]
|
||||
tmp_log_probs = self.select(
|
||||
self.equal(word_indices, self.eos_ids),
|
||||
self.ninf_tensor,
|
||||
topk_scores)
|
||||
|
||||
_, tmp_indices = self.topk(tmp_log_probs, self.beam_width)
|
||||
# update, [batch_size, beam_width, 2]
|
||||
tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1)))
|
||||
# [batch_size, beam_width]
|
||||
beam_indices = self.gather_nd(beam_indices, tmp_gather_indices)
|
||||
word_indices = self.gather_nd(word_indices, tmp_gather_indices)
|
||||
topk_scores = self.gather_nd(topk_scores, tmp_gather_indices)
|
||||
|
||||
# gather indices for selecting alive beams
|
||||
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1)))
|
||||
|
||||
# length add 1 if not finished in the previous step, [batch_size, beam_width]
|
||||
length_add = self.add(state_length, self.one)
|
||||
state_length = self.select(state_finished, state_length, length_add)
|
||||
state_length = self.gather_nd(state_length, gather_indices)
|
||||
# concat seq
|
||||
seq = self.gather_nd(state_seq, gather_indices)
|
||||
|
||||
# update all_decoder_state
|
||||
all_decoder_state = self.reshape(all_decoder_state,
|
||||
(self.decoder_layers_nums * 2, self.batch_size, self.beam_width,
|
||||
self.hidden_size))
|
||||
for i in range(self.decoder_layers_nums * 2):
|
||||
all_decoder_state[i, :, :, :] = self.gather_nd(all_decoder_state[i, :, :, :], gather_indices)
|
||||
all_decoder_state = self.reshape(all_decoder_state,
|
||||
(self.decoder_layers_nums, 2, self.batch_size * self.beam_width,
|
||||
self.hidden_size))
|
||||
|
||||
# update state_seq
|
||||
if self.is_using_while:
|
||||
state_seq_new = self.cast(seq, mstype.float32)
|
||||
word_indices_fp32 = self.cast(word_indices, mstype.float32)
|
||||
state_seq_new[:, :, idx] = word_indices_fp32
|
||||
state_seq = self.cast(state_seq_new, mstype.int32)
|
||||
else:
|
||||
state_seq = self.concat((seq, self.expand(word_indices, -1)))
|
||||
|
||||
cur_input_ids = self.reshape(word_indices, (-1, 1))
|
||||
state_log_probs = topk_scores
|
||||
state_finished = self.equal(word_indices, self.eos_ids)
|
||||
|
||||
return cur_input_ids, state_log_probs, state_seq, state_length, all_decoder_state, state_finished
|
||||
|
||||
def construct(self, state):
|
||||
"""
|
||||
Process source sentence
|
||||
|
||||
Inputs:
|
||||
states (Tensor): Output of transformer encoder with shape (2, batch_size * beam_width, D).
|
||||
|
||||
Returns:
|
||||
Tensor, predictions output.
|
||||
"""
|
||||
# beam search start
|
||||
cur_input_ids = self.start_ids
|
||||
state_log_probs = self.init_scores
|
||||
state_seq = self.init_seq
|
||||
state_finished = self.init_finished
|
||||
state_length = self.init_length
|
||||
|
||||
decoder_hidden_state = self.state_concat((self.expand(state, 0), self.expand(state, 0)))
|
||||
decoder_hidden_state = self.state_concat((decoder_hidden_state, decoder_hidden_state))
|
||||
|
||||
if not self.is_using_while:
|
||||
for _ in range(self.max_decode_length + 1):
|
||||
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, \
|
||||
state_finished = self.one_step(cur_input_ids, state_log_probs, state_seq,
|
||||
state_length, None, decoder_hidden_state, state_finished)
|
||||
else:
|
||||
idx = self.start + 1
|
||||
ends = self.start + self.max_decode_length + 1
|
||||
while idx < ends:
|
||||
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, \
|
||||
state_finished = self.one_step(cur_input_ids, state_log_probs, state_seq,
|
||||
state_length, idx, decoder_hidden_state, state_finished)
|
||||
idx = idx + 1
|
||||
# add length penalty scores
|
||||
penalty_len = self.length_penalty(state_length)
|
||||
log_probs = self.real_div(state_log_probs, penalty_len)
|
||||
# sort according to scores
|
||||
_, top_beam_indices = self.topk(log_probs, self.beam_width)
|
||||
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
|
||||
# sort sequence
|
||||
predicted_ids = self.gather_nd(state_seq, gather_indices)
|
||||
if not self.is_using_while:
|
||||
predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]
|
||||
else:
|
||||
predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length]
|
||||
|
||||
return predicted_ids
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Calculate the blue scores"""
|
||||
import subprocess
|
||||
import numpy as np
|
||||
|
||||
from src.dataset.tokenizer import Tokenizer
|
||||
|
||||
|
||||
def load_result_data(result_npy_addr):
|
||||
# load the numpy to list.
|
||||
result = np.load(result_npy_addr, allow_pickle=True)
|
||||
return result
|
||||
|
||||
|
||||
def get_bleu_data(tokenizer: Tokenizer, result_npy_addr):
|
||||
"""
|
||||
Detokenizer the prediction.
|
||||
|
||||
Args:
|
||||
tokenizer (Tokenizer): tokenizer operations.
|
||||
result_npy_addr (string): Path to the predict file.
|
||||
|
||||
Returns:
|
||||
List, the predict text context.
|
||||
"""
|
||||
|
||||
result = load_result_data(result_npy_addr)
|
||||
prediction_list = []
|
||||
for _, info in enumerate(result):
|
||||
# prediction detokenize
|
||||
prediction = info["prediction"]
|
||||
prediction_str = tokenizer.detokenize(prediction)
|
||||
prediction_list.append(prediction_str)
|
||||
|
||||
return prediction_list
|
||||
|
||||
|
||||
def calculate_sacrebleu(predict_path, target_path):
|
||||
"""
|
||||
Calculate the BLEU scores.
|
||||
|
||||
Args:
|
||||
predict_path (string): Path to the predict file.
|
||||
target_path (string): Path to the target file.
|
||||
|
||||
Returns:
|
||||
Float32, bleu scores.
|
||||
"""
|
||||
|
||||
sacrebleu_params = '--score-only -lc --tokenize intl'
|
||||
sacrebleu = subprocess.run([f'sacrebleu --input {predict_path} \
|
||||
{target_path} {sacrebleu_params}'],
|
||||
stdout=subprocess.PIPE, shell=True)
|
||||
bleu_scores = round(float(sacrebleu.stdout.strip()), 2)
|
||||
return bleu_scores
|
||||
|
||||
|
||||
def bleu_calculate(tokenizer, result_npy_addr, target_addr=None):
|
||||
"""
|
||||
Calculate the BLEU scores.
|
||||
|
||||
Args:
|
||||
tokenizer (Tokenizer): tokenizer operations.
|
||||
result_npy_addr (string): Path to the predict file.
|
||||
target_addr (string): Path to the target file.
|
||||
|
||||
Returns:
|
||||
Float32, bleu scores.
|
||||
"""
|
||||
|
||||
prediction = get_bleu_data(tokenizer, result_npy_addr)
|
||||
print("predict:\n", prediction)
|
||||
|
||||
eval_path = './predict.txt'
|
||||
with open(eval_path, 'w') as eval_file:
|
||||
lines = [line + '\n' for line in prediction]
|
||||
eval_file.writelines(lines)
|
||||
reference_path = target_addr
|
||||
bleu_scores = calculate_sacrebleu(eval_path, reference_path)
|
||||
return bleu_scores
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Components of model."""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class SaturateCast(nn.Cell):
|
||||
"""Cast wrapper."""
|
||||
|
||||
def __init__(self, dst_type=mstype.float32):
|
||||
super(SaturateCast, self).__init__()
|
||||
self.cast = P.Cast()
|
||||
self.dst_type = dst_type
|
||||
|
||||
def construct(self, x):
|
||||
return self.cast(x, self.dst_type)
|
||||
|
||||
|
||||
class LayerNorm(nn.Cell):
|
||||
"""
|
||||
Do layer norm.
|
||||
|
||||
Args:
|
||||
in_channels (int): In channels number of layer norm.
|
||||
return_2d (bool): Whether return 2d tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, output.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels=None, return_2d=False):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.return_2d = return_2d
|
||||
self.layer_norm = nn.LayerNorm((in_channels,))
|
||||
self.cast = P.Cast()
|
||||
self.get_dtype = P.DType()
|
||||
self.reshape = P.Reshape()
|
||||
self.get_shape = P.Shape()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
"""Do layer norm."""
|
||||
shape = self.get_shape(input_tensor)
|
||||
batch_size = shape[0]
|
||||
max_len = shape[1]
|
||||
embed_dim = shape[2]
|
||||
|
||||
output = self.reshape(input_tensor, (-1, embed_dim))
|
||||
output = self.cast(output, mstype.float32)
|
||||
output = self.layer_norm(output)
|
||||
output = self.cast(output, self.get_dtype(input_tensor))
|
||||
if not self.return_2d:
|
||||
output = self.reshape(output, (batch_size, max_len, embed_dim))
|
||||
return output
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Decoder of Seq2seq."""
|
||||
import copy
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common.initializer import Uniform
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from config.config import Seq2seqConfig
|
||||
from .dynamic_rnn import DynamicRNNNet
|
||||
|
||||
|
||||
class Seq2seqDecoder(nn.Cell):
|
||||
"""
|
||||
Implements of decoder.
|
||||
|
||||
Args:
|
||||
decoder_layers (int): Decoder layers.
|
||||
intermediate_size (int): Hidden size of FFN.
|
||||
initializer_range (float): Initial range. Default: 0.02.
|
||||
dropout_prob (float): Dropout rate between layers. Default: 0.1.
|
||||
hidden_act (str): Non-linear activation function in FFN. Default: "relu".
|
||||
compute_type (mstype): Mindspore data type. Default: mstype.float32.
|
||||
|
||||
Returns:
|
||||
Tensor, shape of (N, T', D).
|
||||
"""
|
||||
def __init__(self,
|
||||
config: Seq2seqConfig,
|
||||
is_training: bool,
|
||||
use_one_hot_embeddings: bool = False,
|
||||
initializer_range=0.1,
|
||||
infer_beam_width=1,
|
||||
compute_type=mstype.float16):
|
||||
|
||||
super(Seq2seqDecoder, self).__init__()
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
|
||||
self.is_training = is_training
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.hidden_dropout_prob = config.hidden_dropout_prob
|
||||
self.vocab_size = config.vocab_size
|
||||
self.seq_length = config.max_decode_length
|
||||
# batchsize* beam_width for beam_search.
|
||||
self.batch_size = config.batch_size * infer_beam_width
|
||||
self.word_embed_dim = config.hidden_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.transpose = P.Transpose()
|
||||
self.transpose_orders = (1, 0, 2)
|
||||
self.reshape = P.Reshape()
|
||||
self.concat = P.Concat(axis=-1)
|
||||
self.oneslike = P.OnesLike()
|
||||
self.state_concat = P.Concat(axis=0)
|
||||
self.all_decoder_state = Tensor(np.zeros([self.num_layers, 2, self.batch_size, config.hidden_size]),
|
||||
mstype.float32)
|
||||
|
||||
decoder_layers = []
|
||||
for _ in range(0, self.num_layers):
|
||||
layer = DynamicRNNNet(
|
||||
seq_length=self.seq_length,
|
||||
batchsize=self.batch_size,
|
||||
word_embed_dim=self.word_embed_dim,
|
||||
hidden_size=self.word_embed_dim)
|
||||
decoder_layers.append(layer)
|
||||
|
||||
self.decoder_layers = nn.CellList(decoder_layers)
|
||||
self.dropout = nn.Dropout(keep_prob=1.0 - config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(config.hidden_size,
|
||||
config.vocab_size,
|
||||
has_bias=True,
|
||||
weight_init=Uniform(initializer_range),
|
||||
bias_init=Uniform(initializer_range)).to_float(compute_type)
|
||||
self.cast = P.Cast()
|
||||
self.shape_op = P.Shape()
|
||||
self.expand = P.ExpandDims()
|
||||
self.squeeze = P.Squeeze(0)
|
||||
|
||||
def construct(self, tgt_embeddings, decoder_init_state=None):
|
||||
"""Decoder."""
|
||||
# tgt_embeddings: [T',N,D], state: [2,N,D]
|
||||
query_shape = self.shape_op(tgt_embeddings)
|
||||
if decoder_init_state is None:
|
||||
hidden_state = self.all_decoder_state
|
||||
else:
|
||||
hidden_state = decoder_init_state
|
||||
|
||||
decoder_outputs = self.dropout(tgt_embeddings)
|
||||
decoder_outputs, state_0 = self.decoder_layers[0](decoder_outputs,
|
||||
self.squeeze(hidden_state[0:1, :, :, :]))
|
||||
all_decoder_state = self.expand(state_0, 0)
|
||||
|
||||
for i in range(1, self.num_layers):
|
||||
decoder_outputs = self.dropout(decoder_outputs)
|
||||
decoder_outputs, state = self.decoder_layers[i](decoder_outputs,
|
||||
self.squeeze(hidden_state[i:i+1, :, :, :]))
|
||||
all_decoder_state = self.state_concat((all_decoder_state, self.expand(state, 0)))
|
||||
|
||||
decoder_outputs = self.reshape(decoder_outputs, (-1, self.word_embed_dim))
|
||||
|
||||
if self.is_training:
|
||||
decoder_outputs = self.cast(decoder_outputs, mstype.float16)
|
||||
decoder_outputs = self.classifier(decoder_outputs)
|
||||
if self.is_training:
|
||||
decoder_outputs = self.cast(decoder_outputs, mstype.float32)
|
||||
|
||||
# [m, batch_size * beam_width, V]
|
||||
decoder_outputs = self.reshape(decoder_outputs, (query_shape[0], query_shape[1], self.vocab_size))
|
||||
|
||||
return decoder_outputs, all_decoder_state
|
|
@ -0,0 +1,150 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Decoder for beam_search of seq2seq."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from .embedding import EmbeddingLookup
|
||||
from .decoder import Seq2seqDecoder
|
||||
from .components import SaturateCast
|
||||
|
||||
|
||||
class PredLogProbs(nn.Cell):
|
||||
"""
|
||||
Get log probs.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input dataset.
|
||||
seq_length (int): The length of sequences.
|
||||
width (int): Number of parameters of a layer
|
||||
compute_type (int): Type of input type.
|
||||
dtype (int): Type of MindSpore output type.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
seq_length,
|
||||
width,
|
||||
compute_type=mstype.float32,
|
||||
dtype=mstype.float32):
|
||||
super(PredLogProbs, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.width = width
|
||||
self.compute_type = compute_type
|
||||
self.dtype = dtype
|
||||
self.log_softmax = nn.LogSoftmax(axis=-1)
|
||||
# self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, logits):
|
||||
"""
|
||||
Calculate the log_softmax.
|
||||
|
||||
Inputs:
|
||||
input_tensor (Tensor): A batch of sentences with shape (N, T).
|
||||
output_weights (Tensor): A batch of masks with shape (N, T).
|
||||
|
||||
Returns:
|
||||
Tensor, the prediction probability with shape (N, T').
|
||||
"""
|
||||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
||||
|
||||
class BeamDecoderStep(nn.Cell):
|
||||
"""
|
||||
Multi-layer transformer decoder step.
|
||||
|
||||
Args:
|
||||
config (Seq2seqConfig).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
use_one_hot_embeddings,
|
||||
compute_type=mstype.float32):
|
||||
super(BeamDecoderStep, self).__init__(auto_prefix=True)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.word_embed_dim = config.hidden_size
|
||||
self.embedding_lookup = EmbeddingLookup(
|
||||
is_training=False,
|
||||
vocab_size=config.vocab_size,
|
||||
embed_dim=self.word_embed_dim,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
self.projection = PredLogProbs(
|
||||
batch_size=config.batch_size * config.beam_width,
|
||||
seq_length=1,
|
||||
width=config.vocab_size,
|
||||
compute_type=config.compute_type)
|
||||
|
||||
self.seq_length = config.max_decode_length
|
||||
self.decoder = Seq2seqDecoder(config,
|
||||
is_training=False,
|
||||
infer_beam_width=config.beam_width)
|
||||
|
||||
self.ones_like = P.OnesLike()
|
||||
self.shape = P.Shape()
|
||||
|
||||
self.expand = P.ExpandDims()
|
||||
self.multiply = P.Mul()
|
||||
|
||||
ones = np.ones(shape=(config.max_decode_length, config.max_decode_length))
|
||||
self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
|
||||
|
||||
self.cast_compute_type = SaturateCast(dst_type=compute_type)
|
||||
|
||||
self.transpose = P.Transpose()
|
||||
self.transpose_orders = (1, 0, 2)
|
||||
|
||||
def construct(self, input_ids, decoder_hidden_state):
|
||||
"""
|
||||
Get log probs.
|
||||
|
||||
Args:
|
||||
input_ids: [batch_size * beam_width, m]
|
||||
|
||||
Returns:
|
||||
Tensor, the log_probs. [batch_size * beam_width, 1, vocabulary_size]
|
||||
"""
|
||||
|
||||
# process embedding. input_embedding: [batch_size * beam_width, m, D], embedding_tables: [V, D]
|
||||
input_embedding, _ = self.embedding_lookup(input_ids)
|
||||
input_embedding = self.cast_compute_type(input_embedding)
|
||||
|
||||
input_shape = self.shape(input_ids)
|
||||
input_len = input_shape[1]
|
||||
# [m, batch_size * beam_width, D]
|
||||
input_embedding = self.transpose(input_embedding, self.transpose_orders)
|
||||
|
||||
# decoder_output: [m, batch_size*beam_width, V], all_decoder_state:[4,2,b*beam_width,D]
|
||||
decoder_output, all_decoder_state = self.decoder(input_embedding, decoder_hidden_state)
|
||||
# [batch_size * beam_width, m, v]
|
||||
decoder_output = self.transpose(decoder_output, self.transpose_orders)
|
||||
|
||||
# take the last step, [batch_size * beam_width, 1, V]
|
||||
decoder_output = decoder_output[:, (input_len - 1):input_len, :]
|
||||
|
||||
# projection and log_prob
|
||||
log_probs = self.projection(decoder_output)
|
||||
|
||||
# [batch_size * beam_width, 1, vocabulary_size]
|
||||
return log_probs, all_decoder_state
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DynamicRNN."""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
class DynamicRNNCell(nn.Cell):
|
||||
"""
|
||||
DynamicRNN Cell.
|
||||
|
||||
Args:
|
||||
num_setp (int): Lengths of setences.
|
||||
batch_size (int): Batch size.
|
||||
word_embed_dim (int): Input size.
|
||||
hidden_size (int): Hidden size .
|
||||
initializer_range (float): Initial range. Default: 0.02
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_setp=50,
|
||||
batch_size=128,
|
||||
word_embed_dim=1024,
|
||||
hidden_size=1024,
|
||||
initializer_range=0.1):
|
||||
super(DynamicRNNCell, self).__init__()
|
||||
self.rnn = P.DynamicRNN()
|
||||
self.num_step = num_setp
|
||||
self.batch_size = batch_size
|
||||
self.input_size = word_embed_dim
|
||||
self.hidden_size = hidden_size
|
||||
# w
|
||||
dynamicRNN_w = np.random.uniform(-initializer_range, initializer_range,
|
||||
size=[self.input_size + self.hidden_size, 4 * self.hidden_size])
|
||||
self.dynamicRNN_w = Parameter(Tensor(dynamicRNN_w, mstype.float32), name="w")
|
||||
# b
|
||||
dynamicRNN_b = np.random.uniform(-initializer_range, initializer_range, size=[4 * self.hidden_size])
|
||||
self.dynamicRNN_b = Parameter(Tensor(dynamicRNN_b, mstype.float32), name="b")
|
||||
|
||||
self.dynamicRNN_h = Tensor(np.zeros((1, self.batch_size, self.hidden_size)), mstype.float32)
|
||||
self.dynamicRNN_c = Tensor(np.zeros((1, self.batch_size, self.hidden_size)), mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x, init_h=None, init_c=None):
|
||||
w = self.cast(self.dynamicRNN_w, mstype.float16)
|
||||
b = self.cast(self.dynamicRNN_b, mstype.float16)
|
||||
if init_h is None or init_c is None:
|
||||
init_h = self.cast(self.dynamicRNN_h, mstype.float16)
|
||||
init_c = self.cast(self.dynamicRNN_c, mstype.float16)
|
||||
out = self.rnn(x, w, b, None, init_h, init_c)
|
||||
return out[0], out[1], out[2]
|
||||
|
||||
|
||||
class DynamicRNNNet(nn.Cell):
|
||||
"""
|
||||
DynamicRNN Network.
|
||||
|
||||
Args:
|
||||
seq_length (int): Lengths of setences.
|
||||
batchsize (int): Batch size.
|
||||
word_embed_dim (int): Input size.
|
||||
hidden_size (int): Hidden size.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
seq_length=80,
|
||||
batchsize=128,
|
||||
word_embed_dim=1024,
|
||||
hidden_size=1024):
|
||||
super(DynamicRNNNet, self).__init__()
|
||||
self.max_length = seq_length
|
||||
self.hidden_size = hidden_size
|
||||
self.cast = P.Cast()
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.get_shape = P.Shape()
|
||||
self.print = P.Print()
|
||||
self.net = DynamicRNNCell(num_setp=seq_length,
|
||||
batch_size=batchsize,
|
||||
word_embed_dim=word_embed_dim,
|
||||
hidden_size=hidden_size)
|
||||
|
||||
def construct(self, inputs, init_state=None):
|
||||
"""DynamicRNN Network."""
|
||||
inputs = self.cast(inputs, mstype.float16)
|
||||
if init_state is not None:
|
||||
init_h = self.cast(init_state[0:1, :, :], mstype.float16)
|
||||
init_c = self.cast(init_state[-1:, :, :], mstype.float16)
|
||||
out, state_h, state_c = self.net(inputs, init_h, init_c)
|
||||
else:
|
||||
out, state_h, state_c = self.net(inputs)
|
||||
out = self.cast(out, mstype.float32)
|
||||
state = self.concat((state_h[-1:, :, :], state_c[-1:, :, :]))
|
||||
state = self.cast(state, mstype.float32)
|
||||
# out:[T,b,D], state:[2,b,D]
|
||||
return out, state
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Word embedding for seq2seq."""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
class EmbeddingLookup(nn.Cell):
|
||||
"""
|
||||
Embeddings lookup table with a fixed dictionary and size.
|
||||
|
||||
Args:
|
||||
is_training (bool): Whether to train.
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
embed_dim (int): The size of word embedding.
|
||||
initializer_range (int): The initialize range of parameters.
|
||||
use_one_hot_embeddings (bool): Whether use one-hot embedding. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
is_training,
|
||||
vocab_size,
|
||||
embed_dim,
|
||||
initializer_range=0.1,
|
||||
use_one_hot_embeddings=False):
|
||||
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
self.is_training = is_training
|
||||
self.embedding_dim = embed_dim
|
||||
self.vocab_size = vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
|
||||
init_weight = np.random.normal(-initializer_range, initializer_range, size=[vocab_size, embed_dim])
|
||||
self.embedding_table = Parameter(Tensor(init_weight, mstype.float32), name="embedding_table")
|
||||
self.expand = P.ExpandDims()
|
||||
self.gather = P.GatherV2()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.get_shape = P.Shape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, input_ids):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
input_ids (Tensor): A batch of sentences with shape (N, T).
|
||||
|
||||
Returns:
|
||||
Tensor, word embeddings with shape (N, T, D)
|
||||
"""
|
||||
_shape = self.get_shape(input_ids) # (N, T).
|
||||
_batch_size = _shape[0]
|
||||
_max_len = _shape[1]
|
||||
if self.is_training:
|
||||
embedding_table = self.cast(self.embedding_table, mstype.float16)
|
||||
else:
|
||||
embedding_table = self.embedding_table
|
||||
|
||||
flat_ids = self.reshape(input_ids, (_batch_size * _max_len,))
|
||||
if self.use_one_hot_embeddings:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
if self.is_training:
|
||||
one_hot_ids = self.cast(one_hot_ids, mstype.float16)
|
||||
output_for_reshape = self.array_mul(
|
||||
one_hot_ids, embedding_table)
|
||||
else:
|
||||
output_for_reshape = self.gather(embedding_table, flat_ids, 0)
|
||||
|
||||
output = self.reshape(output_for_reshape, (_batch_size, _max_len, self.embedding_dim))
|
||||
if self.is_training:
|
||||
output = self.cast(output, mstype.float32)
|
||||
embedding_table = self.cast(embedding_table, mstype.float32)
|
||||
return output, embedding_table
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Encoder of Seq2seq."""
|
||||
import copy
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from config.config import Seq2seqConfig
|
||||
from .dynamic_rnn import DynamicRNNNet
|
||||
|
||||
class Seq2seqEncoder(nn.Cell):
|
||||
"""
|
||||
Implements of Seq2seq encoder.
|
||||
|
||||
Args:
|
||||
config (Seq2seqConfig): Configuration of Seq2seq network.
|
||||
is_training (bool): Whether to train.
|
||||
compute_type (mstype): Mindspore data type.
|
||||
|
||||
Returns:
|
||||
Tensor, shape of (2, T, D).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: Seq2seqConfig,
|
||||
is_training: bool,
|
||||
compute_type=mstype.float32):
|
||||
super(Seq2seqEncoder, self).__init__()
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.hidden_dropout_prob = config.hidden_dropout_prob
|
||||
self.seq_length = config.seq_length
|
||||
self.batch_size = config.batch_size
|
||||
self.word_embed_dim = config.hidden_size
|
||||
|
||||
encoder_layers = []
|
||||
for _ in range(0, self.num_layers):
|
||||
layer = DynamicRNNNet(seq_length=self.seq_length,
|
||||
batchsize=self.batch_size,
|
||||
word_embed_dim=self.word_embed_dim,
|
||||
hidden_size=self.word_embed_dim)
|
||||
encoder_layers.append(layer)
|
||||
|
||||
self.encoder_layers = nn.CellList(encoder_layers)
|
||||
self.dropout = nn.Dropout(keep_prob=1.0 - config.hidden_dropout_prob)
|
||||
self.reverse_v2 = P.ReverseV2(axis=[0])
|
||||
|
||||
def construct(self, inputs):
|
||||
"""Encoder."""
|
||||
inputs_r = self.reverse_v2(inputs)
|
||||
encoder_outputs = inputs_r
|
||||
state = 0
|
||||
|
||||
for i in range(0, self.num_layers):
|
||||
encoder_outputs = self.dropout(encoder_outputs)
|
||||
# [T,N,D] -> [T,N,D]
|
||||
encoder_outputs, state = self.encoder_layers[i](encoder_outputs)
|
||||
|
||||
return encoder_outputs, state
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""seq2seq model"""
|
||||
import copy
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from config.config import Seq2seqConfig
|
||||
from .embedding import EmbeddingLookup
|
||||
from .beam_search import BeamSearchDecoder, TileBeam
|
||||
from .encoder import Seq2seqEncoder
|
||||
from .decoder import Seq2seqDecoder
|
||||
from .components import SaturateCast
|
||||
from .decoder_beam_infer import BeamDecoderStep
|
||||
|
||||
|
||||
class Seq2seqModel(nn.Cell):
|
||||
"""
|
||||
Seq2seq with encoder and decoder.
|
||||
|
||||
Args:
|
||||
config (Seq2seqConfig): Model config.
|
||||
is_training (bool): Whether is training.
|
||||
use_one_hot_embeddings (bool): Whether use one-hot embedding.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor], network outputs.
|
||||
"""
|
||||
def __init__(self,
|
||||
config: Seq2seqConfig,
|
||||
is_training: bool = False,
|
||||
use_one_hot_embeddings: bool = False,
|
||||
compute_type=mstype.float32):
|
||||
super(Seq2seqModel, self).__init__()
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
self.is_training = is_training
|
||||
self.vocab_size = config.vocab_size
|
||||
self.seq_length = config.seq_length
|
||||
self.batch_size = config.batch_size
|
||||
self.max_decode_length = config.max_decode_length
|
||||
self.word_embed_dim = config.hidden_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.beam_width = config.beam_width
|
||||
self.expand = P.ExpandDims()
|
||||
self.state_concat = P.Concat(axis=0)
|
||||
self.transpose = P.Transpose()
|
||||
self.transpose_orders = (1, 0, 2)
|
||||
|
||||
self.embedding_lookup = EmbeddingLookup(
|
||||
is_training=self.is_training,
|
||||
vocab_size=self.vocab_size,
|
||||
embed_dim=self.word_embed_dim,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
self.seq2seq_encoder = Seq2seqEncoder(config, is_training)
|
||||
|
||||
if self.is_training:
|
||||
# use for train.
|
||||
self.seq2seq_decoder = Seq2seqDecoder(config, is_training)
|
||||
|
||||
else:
|
||||
# use for infer.
|
||||
self.reshape = P.Reshape()
|
||||
self.tile_beam = TileBeam(beam_width=config.beam_width)
|
||||
self.cast_compute_type = SaturateCast(dst_type=compute_type)
|
||||
|
||||
beam_decoder_cell = BeamDecoderStep(config,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
# link beam_search after decoder
|
||||
self.beam_decoder = BeamSearchDecoder(
|
||||
batch_size=config.batch_size,
|
||||
seq_length=config.seq_length,
|
||||
vocab_size=config.vocab_size,
|
||||
decoder=beam_decoder_cell,
|
||||
beam_width=config.beam_width,
|
||||
decoder_layers_nums=config.num_hidden_layers,
|
||||
length_penalty_weight=config.length_penalty_weight,
|
||||
hidden_size=config.hidden_size,
|
||||
max_decode_length=config.max_decode_length)
|
||||
self.beam_decoder.add_flags(loop_can_unroll=True)
|
||||
|
||||
def construct(self, source_ids, source_mask=None, target_ids=None):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
In this method, T = src_max_len, T' = tgt_max_len.
|
||||
|
||||
Args:
|
||||
source_ids (Tensor): Source sentences with shape (N, T).
|
||||
source_mask (Tensor): Source sentences padding mask with shape (N, T),
|
||||
where 0 indicates padding position.
|
||||
target_ids (Tensor): Target sentences with shape (N, T').
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor], network outputs.
|
||||
"""
|
||||
# Process source sentences. src_embeddings:[N, T, D].
|
||||
src_embeddings, _ = self.embedding_lookup(source_ids)
|
||||
# T, N, D
|
||||
inputs = self.transpose(src_embeddings, self.transpose_orders)
|
||||
# encoder. encoder_outputs: [T, N, D]
|
||||
_, state = self.seq2seq_encoder(inputs)
|
||||
|
||||
decoder_init_state = self.state_concat((self.expand(state, 0), self.expand(state, 0)))
|
||||
decoder_init_state = self.state_concat((decoder_init_state, decoder_init_state))
|
||||
|
||||
# decoder.
|
||||
if self.is_training:
|
||||
# training
|
||||
# process target input sentences. N, T, D
|
||||
tgt_embeddings, _ = self.embedding_lookup(target_ids)
|
||||
# T, N, D
|
||||
tgt_embeddings = self.transpose(tgt_embeddings, self.transpose_orders)
|
||||
# cell: [T,N,D].
|
||||
cell, _ = self.seq2seq_decoder(tgt_embeddings, decoder_init_state)
|
||||
# decoder_output: (N, T', V).
|
||||
decoder_outputs = self.transpose(cell, self.transpose_orders)
|
||||
out = decoder_outputs
|
||||
else:
|
||||
#infer
|
||||
beam_state = self.transpose(state, self.transpose_orders)
|
||||
# bean search for state, [N*beam_width, 2, D]
|
||||
beam_state = self.tile_beam(beam_state)
|
||||
beam_state = self.transpose(beam_state, self.transpose_orders)
|
||||
#[2, N*beam_width, D]
|
||||
predicted_ids = self.beam_decoder(beam_state)
|
||||
predicted_ids = self.reshape(predicted_ids, (-1, self.max_decode_length))
|
||||
out = predicted_ids
|
||||
|
||||
return out
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Infer api."""
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import context, Parameter
|
||||
from mindspore.train.model import Model
|
||||
|
||||
from src.dataset import load_dataset
|
||||
from .seq2seq import Seq2seqModel
|
||||
from ..utils import zero_weight
|
||||
from ..utils.load_weights import load_infer_weights
|
||||
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend",
|
||||
reserve_class_name_in_scope=False)
|
||||
|
||||
|
||||
class Seq2seqInferCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of Seq2seqModel network infer.
|
||||
|
||||
Args:
|
||||
network (nn.Cell): Seq2seqModel model.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor], predicted_ids and predicted_probs.
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(Seq2seqInferCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self,
|
||||
source_ids,
|
||||
source_mask):
|
||||
"""Defines the computation performed."""
|
||||
|
||||
predicted_ids = self.network(source_ids,
|
||||
source_mask)
|
||||
|
||||
return predicted_ids
|
||||
|
||||
|
||||
def seq2seq_infer(config, dataset):
|
||||
"""
|
||||
Run infer with Seq2seqModel.
|
||||
|
||||
Args:
|
||||
config (Seq2seqConfig): Config.
|
||||
dataset (Dataset): Dataset.
|
||||
|
||||
Returns:
|
||||
List[Dict], prediction, each example has 4 keys, "source",
|
||||
"target", "prediction" and "prediction_prob".
|
||||
"""
|
||||
tfm_model = Seq2seqModel(
|
||||
config=config,
|
||||
is_training=False,
|
||||
use_one_hot_embeddings=False)
|
||||
|
||||
params = tfm_model.trainable_params()
|
||||
weights = load_infer_weights(config)
|
||||
for param in params:
|
||||
value = param.data
|
||||
weights_name = param.name
|
||||
if weights_name not in weights:
|
||||
raise ValueError(f"{weights_name} is not found in weights.")
|
||||
if isinstance(value, Tensor):
|
||||
if weights_name in weights:
|
||||
assert weights_name in weights
|
||||
if isinstance(weights[weights_name], Parameter):
|
||||
if param.data.dtype == "Float32":
|
||||
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
|
||||
elif param.data.dtype == "Float16":
|
||||
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
|
||||
|
||||
elif isinstance(weights[weights_name], Tensor):
|
||||
param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
|
||||
elif isinstance(weights[weights_name], np.ndarray):
|
||||
param.set_data(Tensor(weights[weights_name], config.dtype))
|
||||
else:
|
||||
param.set_data(weights[weights_name])
|
||||
else:
|
||||
print("weight not found in checkpoint: " + weights_name)
|
||||
param.set_data(zero_weight(value.asnumpy().shape))
|
||||
|
||||
print(" | Load weights successfully.")
|
||||
tfm_infer = Seq2seqInferCell(tfm_model)
|
||||
model = Model(tfm_infer)
|
||||
|
||||
predictions = []
|
||||
source_sentences = []
|
||||
|
||||
shape = P.Shape()
|
||||
concat = P.Concat(axis=0)
|
||||
batch_index = 1
|
||||
pad_idx = 0
|
||||
sos_idx = 2
|
||||
eos_idx = 3
|
||||
source_ids_pad = Tensor(np.tile(np.array([[sos_idx, eos_idx] + [pad_idx] * (config.seq_length - 2)]),
|
||||
[config.batch_size, 1]), mstype.int32)
|
||||
source_mask_pad = Tensor(np.tile(np.array([[1, 1] + [0] * (config.seq_length - 2)]),
|
||||
[config.batch_size, 1]), mstype.int32)
|
||||
for batch in dataset.create_dict_iterator():
|
||||
source_sentences.append(batch["source_eos_ids"].asnumpy())
|
||||
source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
|
||||
source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
|
||||
|
||||
active_num = shape(source_ids)[0]
|
||||
if active_num < config.batch_size:
|
||||
source_ids = concat((source_ids, source_ids_pad[active_num:, :]))
|
||||
source_mask = concat((source_mask, source_mask_pad[active_num:, :]))
|
||||
|
||||
start_time = time.time()
|
||||
predicted_ids = model.predict(source_ids, source_mask)
|
||||
|
||||
print(f" | BatchIndex = {batch_index}, Batch size: {config.batch_size}, active_num={active_num}, "
|
||||
f"Time cost: {time.time() - start_time}.")
|
||||
if active_num < config.batch_size:
|
||||
predicted_ids = predicted_ids[:active_num, :]
|
||||
batch_index = batch_index + 1
|
||||
predictions.append(predicted_ids.asnumpy())
|
||||
|
||||
output = []
|
||||
for inputs, batch_out in zip(source_sentences, predictions):
|
||||
for i, _ in enumerate(batch_out):
|
||||
if batch_out.ndim == 3:
|
||||
batch_out = batch_out[:, 0]
|
||||
|
||||
example = {
|
||||
"source": inputs[i].tolist(),
|
||||
"prediction": batch_out[i].tolist()
|
||||
}
|
||||
output.append(example)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def infer(config):
|
||||
"""
|
||||
Seq2seqModel infer api.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Config.
|
||||
|
||||
Returns:
|
||||
list, result with
|
||||
"""
|
||||
eval_dataset = load_dataset(data_files=config.test_dataset,
|
||||
batch_size=config.batch_size,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
drop_remainder=False,
|
||||
is_translate=True,
|
||||
shuffle=False) if config.test_dataset else None
|
||||
prediction = seq2seq_infer(config, eval_dataset)
|
||||
return prediction
|
|
@ -0,0 +1,377 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""seq2seq for training."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
||||
|
||||
from .seq2seq import Seq2seqModel
|
||||
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 5.0
|
||||
|
||||
class ClipGradients(nn.Cell):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Args:
|
||||
grads (list): List of gradient tuples.
|
||||
clip_type (Tensor): The way to clip, 'value' or 'norm'.
|
||||
clip_value (Tensor): Specifies how much to clip.
|
||||
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ClipGradients, self).__init__()
|
||||
self.clip_by_norm = nn.ClipByNorm()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self,
|
||||
grads,
|
||||
clip_type,
|
||||
clip_value):
|
||||
"""Defines the gradients clip."""
|
||||
if clip_type not in (0, 1):
|
||||
return grads
|
||||
|
||||
new_grads = ()
|
||||
for grad in grads:
|
||||
dt = self.dtype(grad)
|
||||
if clip_type == 0:
|
||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
new_grads = new_grads + (t,)
|
||||
|
||||
return new_grads
|
||||
|
||||
class PredLogProbs(nn.Cell):
|
||||
"""
|
||||
Get log probs.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): The config of GNMT.
|
||||
|
||||
Returns:
|
||||
Tensor, log softmax output.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(PredLogProbs, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.log_softmax = nn.LogSoftmax(axis=-1)
|
||||
self.get_shape = P.Shape()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
input_tensor (Tensor): Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, log softmax output.
|
||||
"""
|
||||
shape = self.get_shape(input_tensor)
|
||||
logits = self.reshape(input_tensor, (shape[0] * shape[1], shape[2]))
|
||||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
||||
|
||||
class LabelSmoothedCrossEntropyCriterion(nn.Cell):
|
||||
"""
|
||||
Label Smoothed Cross-Entropy Criterion.
|
||||
|
||||
Args:
|
||||
config (Seq2seqConfig): The config of Seq2seq.
|
||||
|
||||
Returns:
|
||||
Tensor, final loss.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(LabelSmoothedCrossEntropyCriterion, self).__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.batch_size = config.batch_size
|
||||
self.smoothing = 0.1
|
||||
self.confidence = 0.9
|
||||
self.last_idx = (-1,)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reshape = P.Reshape()
|
||||
self.neg = P.Neg()
|
||||
self.cast = P.Cast()
|
||||
self.index_ids = Tensor(np.arange(config.batch_size * config.max_decode_length).reshape((-1, 1)), mstype.int32)
|
||||
self.gather_nd = P.GatherNd()
|
||||
self.expand = P.ExpandDims()
|
||||
self.concat = P.Concat(axis=-1)
|
||||
|
||||
def construct(self, prediction_scores, label_ids, label_weights):
|
||||
"""
|
||||
Construct network to calculate loss.
|
||||
|
||||
Args:
|
||||
prediction_scores (Tensor): Prediction scores. [batchsize, seq_len, vocab_size]
|
||||
label_ids (Tensor): Labels. [batchsize, seq_len]
|
||||
label_weights (Tensor): Mask tensor. [batchsize, seq_len]
|
||||
|
||||
Returns:
|
||||
Tensor, final loss.
|
||||
"""
|
||||
prediction_scores = self.reshape(prediction_scores, (-1, self.vocab_size))
|
||||
label_ids = self.reshape(label_ids, (-1, 1))
|
||||
label_weights = self.reshape(label_weights, (-1,))
|
||||
tmp_gather_indices = self.concat((self.index_ids, label_ids))
|
||||
nll_loss = self.neg(self.gather_nd(prediction_scores, tmp_gather_indices))
|
||||
nll_loss = label_weights * nll_loss
|
||||
smooth_loss = self.neg(self.reduce_mean(prediction_scores, self.last_idx))
|
||||
smooth_loss = label_weights * smooth_loss
|
||||
loss = self.reduce_sum(self.confidence * nll_loss + self.smoothing * smooth_loss, ())
|
||||
loss = loss / self.batch_size
|
||||
return loss
|
||||
|
||||
|
||||
class Seq2seqTraining(nn.Cell):
|
||||
"""
|
||||
seq2seq training network.
|
||||
|
||||
Args:
|
||||
config (seq2seqConfig): The config of seq2seq.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
|
||||
|
||||
Returns:
|
||||
Tensor, prediction_scores.
|
||||
"""
|
||||
|
||||
def __init__(self, config, is_training, use_one_hot_embeddings):
|
||||
super(Seq2seqTraining, self).__init__()
|
||||
self.seq2seq = Seq2seqModel(config, is_training, use_one_hot_embeddings)
|
||||
self.projection = PredLogProbs(config)
|
||||
|
||||
def construct(self, source_ids, source_mask, target_ids):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
source_ids (Tensor): Source sentence.
|
||||
source_mask (Tensor): Source padding mask.
|
||||
target_ids (Tensor): Target sentence.
|
||||
|
||||
Returns:
|
||||
Tensor, prediction_scores.
|
||||
"""
|
||||
decoder_outputs = self.seq2seq(source_ids, source_mask, target_ids)
|
||||
prediction_scores = self.projection(decoder_outputs)
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class Seq2seqNetworkWithLoss(nn.Cell):
|
||||
"""
|
||||
Provide seq2seq training loss through network.
|
||||
|
||||
Args:
|
||||
config (seq2seqconfig): The config of seq2seq.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, the loss of the network.
|
||||
"""
|
||||
|
||||
def __init__(self, config, is_training, use_one_hot_embeddings=False):
|
||||
super(Seq2seqNetworkWithLoss, self).__init__()
|
||||
self.seq2seq = Seq2seqTraining(config, is_training, use_one_hot_embeddings)
|
||||
self.loss = LabelSmoothedCrossEntropyCriterion(config)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self,
|
||||
source_ids,
|
||||
source_mask,
|
||||
target_ids,
|
||||
label_ids,
|
||||
label_weights):
|
||||
prediction_scores = self.seq2seq(source_ids, source_mask, target_ids)
|
||||
total_loss = self.loss(prediction_scores, label_ids, label_weights)
|
||||
return self.cast(total_loss, mstype.float32)
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
|
||||
class Seq2seqTrainOneStepWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of seq2seq network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network: Cell. The training network. Note that loss function should have
|
||||
been added.
|
||||
optimizer: Optimizer. Optimizer for updating the weights.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
|
||||
super(Seq2seqTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.all_reduce = P.AllReduce()
|
||||
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(),
|
||||
dtype=mstype.float32), name="loss_scale")
|
||||
self.add_flags(has_effect=True)
|
||||
|
||||
self.loss_scalar = P.ScalarSummary()
|
||||
|
||||
def construct(self,
|
||||
source_eos_ids,
|
||||
source_eos_mask,
|
||||
target_sos_ids,
|
||||
target_eos_ids,
|
||||
target_eos_mask,
|
||||
sens=None):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
source_eos_ids (Tensor): Source sentence.
|
||||
source_eos_mask (Tensor): Source padding mask.
|
||||
target_sos_ids (Tensor): Target sentence.
|
||||
target_eos_ids (Tensor): Prediction sentence.
|
||||
target_eos_mask (Tensor): Prediction padding mask.
|
||||
sens (Tensor): Loss sen.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.
|
||||
"""
|
||||
source_ids = source_eos_ids
|
||||
source_mask = source_eos_mask
|
||||
target_ids = target_sos_ids
|
||||
label_ids = target_eos_ids
|
||||
label_weights = target_eos_mask
|
||||
|
||||
weights = self.weights
|
||||
loss = self.network(source_ids,
|
||||
source_mask,
|
||||
target_ids,
|
||||
label_ids,
|
||||
label_weights)
|
||||
# Alloc status.
|
||||
init = self.alloc_status()
|
||||
# Clear overflow buffer.
|
||||
self.clear_before_grad(init)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
grads = self.grad(self.network, weights)(source_ids,
|
||||
source_mask,
|
||||
target_ids,
|
||||
label_ids,
|
||||
label_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
if self.reducer_flag:
|
||||
# Apply grad reducer on grads.
|
||||
grads = self.grad_reducer(grads)
|
||||
self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
|
||||
if self.is_distributed:
|
||||
# Sum overflow flag over devices.
|
||||
flag_reduce = self.all_reduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
|
||||
self.loss_scalar("loss", loss)
|
||||
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils for gnmt model."""
|
||||
|
||||
from .lr_scheduler import square_root_schedule
|
||||
from .loss_monitor import LossCallBack
|
||||
from .initializer import zero_weight, one_weight, normal_weight, weight_variable
|
||||
|
||||
__all__ = [
|
||||
"square_root_schedule",
|
||||
"LossCallBack",
|
||||
"one_weight",
|
||||
"zero_weight",
|
||||
"normal_weight",
|
||||
"weight_variable"
|
||||
]
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Initializer."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
def _compute_fans(shape):
|
||||
"""
|
||||
Computes the number of input and output units for a weight shape.
|
||||
|
||||
Args:
|
||||
shape (tuple): Integer shape tuple or MS tensor shape.
|
||||
|
||||
Returns:
|
||||
tuple, integer scalars (fan_in, fan_out).
|
||||
"""
|
||||
if not shape:
|
||||
fan_in = fan_out = 1
|
||||
elif len(shape) == 1:
|
||||
fan_in = fan_out = shape[0]
|
||||
elif len(shape) == 2:
|
||||
fan_in = shape[0]
|
||||
fan_out = shape[1]
|
||||
else:
|
||||
# Assuming convolution kernels (2D, 3D, or more).
|
||||
# kernel shape: (..., input_depth, depth)
|
||||
receptive_field_size = 1
|
||||
for dim in shape[:-2]:
|
||||
receptive_field_size *= dim
|
||||
fan_in = shape[-2] * receptive_field_size
|
||||
fan_out = shape[-1] * receptive_field_size
|
||||
return int(fan_in), int(fan_out)
|
||||
|
||||
|
||||
def weight_variable(shape):
|
||||
"""
|
||||
Generate weight var.
|
||||
|
||||
Args:
|
||||
shape (tuple): Shape.
|
||||
|
||||
Returns:
|
||||
Tensor, var.
|
||||
"""
|
||||
# scale_shape = shape
|
||||
# fan_in, fan_out = _compute_fans(scale_shape)
|
||||
# scale = 1.0 / max(1., (fan_in + fan_out) / 2.)
|
||||
# limit = math.sqrt(3.0 * scale)
|
||||
limit = 0.1
|
||||
values = np.random.uniform(-limit, limit, shape)
|
||||
return values
|
||||
|
||||
|
||||
def one_weight(shape):
|
||||
"""
|
||||
Generate weight with ones.
|
||||
|
||||
Args:
|
||||
shape (tuple): Shape.
|
||||
|
||||
Returns:
|
||||
Tensor, var.
|
||||
"""
|
||||
ones = np.ones(shape).astype(np.float32)
|
||||
return Tensor(ones)
|
||||
|
||||
|
||||
def zero_weight(shape):
|
||||
"""
|
||||
Generate weight with zeros.
|
||||
|
||||
Args:
|
||||
shape (tuple): Shape.
|
||||
|
||||
Returns:
|
||||
Tensor, var.
|
||||
"""
|
||||
zeros = np.zeros(shape).astype(np.float32)
|
||||
return Tensor(zeros)
|
||||
|
||||
|
||||
def normal_weight(shape, num_units):
|
||||
"""
|
||||
Generate weight with normal dist.
|
||||
|
||||
Args:
|
||||
shape (tuple): Shape.
|
||||
num_units (int): Dimension.
|
||||
|
||||
Returns:
|
||||
Tensor, var.
|
||||
"""
|
||||
norm = np.random.normal(0.0, num_units ** -0.5, shape).astype(np.float32)
|
||||
return Tensor(norm)
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Weight loader."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
|
||||
|
||||
def load_infer_weights(config):
|
||||
"""
|
||||
Load weights from ckpt or npz.
|
||||
|
||||
Args:
|
||||
config (Seq2seqConfig): Config.
|
||||
|
||||
Returns:
|
||||
dict, weights.
|
||||
"""
|
||||
model_path = config.existed_ckpt
|
||||
if model_path.endswith(".npz"):
|
||||
ms_ckpt = np.load(model_path)
|
||||
is_npz = True
|
||||
else:
|
||||
ms_ckpt = load_checkpoint(model_path)
|
||||
is_npz = False
|
||||
weights = {}
|
||||
for param_name in ms_ckpt:
|
||||
infer_name = param_name.replace("seq2seq.seq2seq.", "")
|
||||
if infer_name.startswith("embedding_lookup."):
|
||||
if is_npz:
|
||||
weights[infer_name] = ms_ckpt[param_name]
|
||||
else:
|
||||
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
||||
infer_name = "beam_decoder.decoder." + infer_name
|
||||
if is_npz:
|
||||
weights[infer_name] = ms_ckpt[param_name]
|
||||
else:
|
||||
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
||||
continue
|
||||
elif not infer_name.startswith("seq2seq_encoder"):
|
||||
if infer_name.startswith("seq2seq_decoder."):
|
||||
infer_name = infer_name.replace("seq2seq_decoder.", "decoder.")
|
||||
infer_name = "beam_decoder.decoder." + infer_name
|
||||
|
||||
if is_npz:
|
||||
weights[infer_name] = ms_ckpt[param_name]
|
||||
else:
|
||||
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
||||
return weights
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Loss monitor."""
|
||||
import time
|
||||
|
||||
from mindspore.train.callback import Callback
|
||||
from config import Seq2seqConfig
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF terminating training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
time_stamp_init = False
|
||||
time_stamp_first = 0
|
||||
|
||||
def __init__(self, config: Seq2seqConfig, per_print_times: int = 1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self.config = config
|
||||
self._per_print_times = per_print_times
|
||||
|
||||
if not self.time_stamp_init:
|
||||
self.time_stamp_first = self._get_ms_timestamp()
|
||||
self.time_stamp_init = True
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""step end."""
|
||||
cb_params = run_context.original_args()
|
||||
file_name = "./loss.log"
|
||||
with open(file_name, "a+") as f:
|
||||
time_stamp_current = self._get_ms_timestamp()
|
||||
f.write("time: {}, epoch: {}, step: {}, outputs: [loss: {}, overflow: {}, loss scale value: {} ].\n".format(
|
||||
time_stamp_current - self.time_stamp_first,
|
||||
cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs[0].asnumpy()),
|
||||
str(cb_params.net_outputs[1].asnumpy()),
|
||||
str(cb_params.net_outputs[2].asnumpy())
|
||||
))
|
||||
|
||||
@staticmethod
|
||||
def _get_ms_timestamp():
|
||||
t = time.time()
|
||||
return int(round(t * 1000))
|
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Learning scheduler."""
|
||||
from math import ceil
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def convert_float2int(values, total_steps):
|
||||
if isinstance(values, float):
|
||||
values = int(values * total_steps)
|
||||
return values
|
||||
|
||||
|
||||
def square_root_schedule(lr, update_num, decay_start_step,
|
||||
warmup_steps=2000,
|
||||
min_lr=1e-5):
|
||||
"""
|
||||
Decay the LR based on the ISR(inverse square root).
|
||||
|
||||
During warm-up::
|
||||
lrs = np.linspace(0, lr, warmup_steps)
|
||||
|
||||
After warm-up:
|
||||
decay_factor = lr * sqrt(warmup_steps)
|
||||
lr = decay_factor / sqrt(step) if step >= decay_start_step else lr
|
||||
|
||||
Args:
|
||||
lr (float): Init learning rate.
|
||||
update_num (int): Total steps.
|
||||
decay_start_step (int): Decay begins after `decay_start_step` steps.
|
||||
warmup_steps (int): Warm up steps.
|
||||
min_lr (float): Min learning rate.
|
||||
|
||||
Returns:
|
||||
np.ndarray, learning rate array.
|
||||
"""
|
||||
warmup_end_lr = lr
|
||||
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
|
||||
|
||||
# If warmup_init_lr > lr, then lr_step is negative.
|
||||
# Otherwise, it's positive.
|
||||
lr_step = (warmup_end_lr - warmup_init_lr) / warmup_steps
|
||||
decay_factor = lr * warmup_steps ** 0.5
|
||||
|
||||
lrs = np.empty(shape=update_num, dtype=np.float32)
|
||||
_start_step = 0
|
||||
if 0 < warmup_steps < update_num:
|
||||
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
|
||||
_start_step = warmup_steps
|
||||
|
||||
for step in range(_start_step, update_num):
|
||||
if step < warmup_steps:
|
||||
_lr = warmup_init_lr + step * lr_step
|
||||
elif step < decay_start_step:
|
||||
_lr = lr
|
||||
else:
|
||||
_lr = decay_factor * step ** -0.5
|
||||
if _lr < min_lr:
|
||||
_lr = min_lr
|
||||
lrs[step] = _lr
|
||||
|
||||
return lrs
|
||||
|
||||
|
||||
def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0):
|
||||
"""
|
||||
Implements of polynomial decay learning rate scheduler which cycles by default.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate.
|
||||
warmup_steps (int): Warmup steps.
|
||||
decay_steps (int): Decay steps.
|
||||
total_update_num (int): Total update steps.
|
||||
min_lr (float): Min learning.
|
||||
power (float): Power factor.
|
||||
|
||||
Returns:
|
||||
np.ndarray, learning rate of each step.
|
||||
"""
|
||||
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
|
||||
|
||||
if decay_steps <= 0:
|
||||
raise ValueError("`decay_steps` must larger than 1.")
|
||||
|
||||
_start_step = 0
|
||||
if 0 < warmup_steps < total_update_num:
|
||||
warmup_end_lr = lr
|
||||
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
|
||||
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
|
||||
_start_step = warmup_steps
|
||||
|
||||
decay_steps = decay_steps
|
||||
for step in range(_start_step, total_update_num):
|
||||
_step = step - _start_step # 2999
|
||||
ratio = ceil(_step / decay_steps) # 3
|
||||
ratio = 1 if ratio < 1 else ratio
|
||||
_decay_steps = decay_steps * ratio # 3000
|
||||
lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr
|
||||
|
||||
return lrs
|
||||
|
||||
|
||||
def Warmup_MultiStepLR_scheduler(base_lr=0.002, total_update_num=200, warmup_steps=200, remain_steps=1.0,
|
||||
decay_interval=-1, decay_steps=4, decay_factor=0.5):
|
||||
"""
|
||||
Implements of polynomial decay learning rate scheduler which cycles by default.
|
||||
|
||||
Args:
|
||||
base_lr (float): Initial learning rate.
|
||||
total_update_num (int): Total update steps.
|
||||
warmup_steps (int or float): Warmup steps.
|
||||
remain_steps (int or float): start decay at 'remain_steps' iteration
|
||||
decay_interval (int): interval between LR decay steps
|
||||
decay_steps (int): Decay steps.
|
||||
decay_factor (float): decay factor
|
||||
|
||||
Returns:
|
||||
np.ndarray, learning rate of each step.
|
||||
"""
|
||||
|
||||
if decay_steps <= 0:
|
||||
raise ValueError("`decay_steps` must larger than 1.")
|
||||
remain_steps = convert_float2int(remain_steps, total_update_num)
|
||||
warmup_steps = convert_float2int(warmup_steps, total_update_num)
|
||||
if warmup_steps > remain_steps:
|
||||
warmup_steps = remain_steps
|
||||
|
||||
if decay_interval < 0:
|
||||
decay_iterations = total_update_num - remain_steps
|
||||
decay_interval = decay_iterations // decay_steps
|
||||
decay_interval = max(decay_interval, 1)
|
||||
else:
|
||||
decay_interval = convert_float2int(decay_interval, total_update_num)
|
||||
|
||||
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
|
||||
_start_step = 0
|
||||
for last_epoch in range(_start_step, total_update_num):
|
||||
if last_epoch < warmup_steps:
|
||||
if warmup_steps != 0:
|
||||
warmup_factor = math.exp(math.log(0.01) / warmup_steps)
|
||||
else:
|
||||
warmup_factor = 1.0
|
||||
inv_decay = warmup_factor ** (warmup_steps - last_epoch)
|
||||
lrs[last_epoch] = base_lr * inv_decay
|
||||
elif last_epoch >= remain_steps:
|
||||
decay_iter = last_epoch - remain_steps
|
||||
num_decay_step = decay_iter // decay_interval + 1
|
||||
num_decay_step = min(num_decay_step, decay_steps)
|
||||
lrs[last_epoch] = base_lr * (decay_factor ** num_decay_step)
|
||||
else:
|
||||
lrs[last_epoch] = base_lr
|
||||
return lrs
|
|
@ -0,0 +1,420 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""adam"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.nn import Optimizer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
|
||||
_learning_rate_update_func = ['linear', 'cos', 'sin']
|
||||
|
||||
adam_opt = C.MultitypeFuncGraph("adam_opt")
|
||||
|
||||
|
||||
@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
|
||||
"""
|
||||
Update parameters.
|
||||
|
||||
Args:
|
||||
beta1 (Tensor): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
|
||||
beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
|
||||
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||
lr (Tensor): Learning rate.
|
||||
weight_decay_tensor (Tensor): Weight decay. Should be equal to or greater than 0.
|
||||
param (Tensor): Parameters.
|
||||
m (Tensor): m value of parameters.
|
||||
v (Tensor): v value of parameters.
|
||||
gradient (Tensor): Gradient of parameters.
|
||||
|
||||
Returns:
|
||||
Tensor, the new value of v after updating.
|
||||
"""
|
||||
op_mul = P.Mul()
|
||||
op_square = P.Square()
|
||||
op_sqrt = P.Sqrt()
|
||||
op_cast = P.Cast()
|
||||
op_reshape = P.Reshape()
|
||||
op_shape = P.Shape()
|
||||
|
||||
param_fp32 = op_cast(param, mstype.float32)
|
||||
m_fp32 = op_cast(m, mstype.float32)
|
||||
v_fp32 = op_cast(v, mstype.float32)
|
||||
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||
|
||||
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
|
||||
|
||||
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
|
||||
- beta2, op_square(gradient_fp32))
|
||||
|
||||
update = next_m / (op_sqrt(next_v) + eps)
|
||||
if decay_flag:
|
||||
update = update + op_mul(weight_decay_tensor, param_fp32)
|
||||
|
||||
update_with_lr = op_mul(lr, update)
|
||||
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
|
||||
|
||||
next_v = F.depend(next_v, F.assign(param, next_param))
|
||||
next_v = F.depend(next_v, F.assign(m, next_m))
|
||||
next_v = F.depend(next_v, F.assign(v, next_v))
|
||||
return next_v
|
||||
|
||||
|
||||
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||
|
||||
|
||||
|
||||
def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_float_positive('learning_rate', learning_rate, prim_name)
|
||||
validator.check_float_legal_value('learning_rate', learning_rate, prim_name)
|
||||
validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
|
||||
validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
|
||||
validator.check_float_positive('power', power, prim_name)
|
||||
validator.check_float_legal_value('power', power, prim_name)
|
||||
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
|
||||
|
||||
|
||||
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor")
|
||||
def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1,
|
||||
moment2):
|
||||
"""Apply adam optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, gradient))
|
||||
return success
|
||||
|
||||
|
||||
class Adam(Optimizer):
|
||||
r"""
|
||||
Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
|
||||
|
||||
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
|
||||
|
||||
The updating formulas are as follows,
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
m = \beta_1 * m + (1 - \beta_1) * g \\
|
||||
v = \beta_2 * v + (1 - \beta_2) * g * g \\
|
||||
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
|
||||
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
|
||||
\end{array}
|
||||
|
||||
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
|
||||
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
|
||||
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
|
||||
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
|
||||
:math:`\epsilon` represents `eps`.
|
||||
|
||||
Note:
|
||||
The Adam optimizer supports separating parameter groups. Different parameter groups can set different
|
||||
`learning_rate` and `weight_decay`.
|
||||
|
||||
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
|
||||
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
|
||||
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
|
||||
|
||||
Args:
|
||||
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
|
||||
the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
|
||||
"lr" and "weight_decay" are the keys can be parsed.
|
||||
|
||||
- params: Required. The value should be a list of `Parameter`.
|
||||
|
||||
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
|
||||
If not, the `learning_rate` in the API will be used.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the API will be used.
|
||||
|
||||
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a Tensor
|
||||
but the dims of the Tensor is 0, use fixed learning rate.
|
||||
Other cases are not supported. Default: 1e-3.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
|
||||
0.9.
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
|
||||
0.999.
|
||||
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
|
||||
1e-8.
|
||||
use_locking (bool): Whether to enable a lock to protect updating variable tensors.
|
||||
If True, updating of the var, m, and v tensors will be protected by a lock.
|
||||
If False, the result is unpredictable. Default: False.
|
||||
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
||||
If True, updates the gradients using NAG.
|
||||
If False, updates the gradients without using NAG. Default: False.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
|
||||
1.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
Tensor[bool], the value is True.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> #1) All parameters use the same learning rate and weight decay
|
||||
>>> optim = nn.Adam(params=net.trainable_params())
|
||||
>>>
|
||||
>>> #2) Use parameter groups and set different values
|
||||
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
|
||||
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
|
||||
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
|
||||
>>> {'params': no_conv_params}]
|
||||
>>> opt = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
|
||||
>>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01
|
||||
>>> # the no_cov_params's parameters don't set learning and weight decay. So they will use a
|
||||
>>> # learning rate of 0.1 and a weight decay of 0.0.
|
||||
>>>
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
||||
"""
|
||||
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
|
||||
use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
|
||||
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
|
||||
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
|
||||
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
|
||||
# validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
|
||||
self.beta1 = Tensor(beta1, mstype.float32)
|
||||
self.beta2 = Tensor(beta2, mstype.float32)
|
||||
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta2")
|
||||
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta1")
|
||||
self.eps = eps
|
||||
|
||||
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
|
||||
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
|
||||
self.pow = P.Pow()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.one = Tensor(np.array([1.0]).astype(np.float32))
|
||||
self.realdiv = P.RealDiv()
|
||||
|
||||
self.lr_scalar = P.ScalarSummary()
|
||||
|
||||
def construct(self, gradients):
|
||||
"""Adam optimizer."""
|
||||
params = self.parameters
|
||||
moment1 = self.moment1
|
||||
moment2 = self.moment2
|
||||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
|
||||
self.lr_scalar("learning_rate", lr)
|
||||
|
||||
beta1_power = self.beta1_power * self.beta1
|
||||
self.beta1_power = beta1_power
|
||||
beta2_power = self.beta2_power * self.beta2
|
||||
self.beta2_power = beta2_power
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
|
||||
self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2)
|
||||
else:
|
||||
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
|
||||
self.beta2, self.eps, lr),
|
||||
gradients, params, moment1, moment2)
|
||||
return success
|
||||
|
||||
|
||||
class AdamWeightDecay(Optimizer):
|
||||
"""
|
||||
Implements Adam algorithm weight decay fix.
|
||||
|
||||
Args:
|
||||
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
|
||||
should be class mindspore.Parameter.
|
||||
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a Tensor
|
||||
but the dims of the Tensor is 0, use fixed learning rate.
|
||||
Other cases are not supported. Default: 1e-3.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
|
||||
Should be in range (0.0, 1.0).
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
|
||||
Should be in range (0.0, 1.0).
|
||||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
|
||||
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> optim = nn.AdamWeightDecay(params=net.trainable_params())
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
"""
|
||||
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(AdamWeightDecay, self).__init__(learning_rate, params)
|
||||
if self.is_group:
|
||||
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
|
||||
|
||||
self.params = self.parameters
|
||||
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
|
||||
self.decay_flag = tuple(decay_filter(x) for x in self.params)
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, gradients):
|
||||
"""Adam Weight Decay"""
|
||||
lr = self.get_lr()
|
||||
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
|
||||
return updated_velocity
|
||||
|
||||
|
||||
class AdamWeightDecayDynamicLR(Optimizer):
|
||||
"""
|
||||
Adam Weight Decay Dynamic Learning Rate (LR).
|
||||
|
||||
Args:
|
||||
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
|
||||
should be class mindspore.Parameter.
|
||||
decay_steps (int): The steps of the decay.
|
||||
learning_rate (float): A floating point value for the learning rate. Default: 0.001.
|
||||
end_learning_rate (float): A floating point value for the end learning rate. Default: 0.0001.
|
||||
power (float): Power. Default: 10.0.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
|
||||
Should be in range (0.0, 1.0).
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
|
||||
Should be in range (0.0, 1.0).
|
||||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
|
||||
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> optim = nn.AdamWeightDecayDynamicLR(params=net.trainable_params(), decay_steps=10)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
params,
|
||||
decay_steps,
|
||||
learning_rate=0.001,
|
||||
end_learning_rate=0.0001,
|
||||
power=10.0,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
eps=1e-6,
|
||||
weight_decay=0.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name,
|
||||
warmup_steps=0):
|
||||
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
|
||||
if self.is_group:
|
||||
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
_check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
|
||||
# turn them to scalar when me support scalar/tensor mix operations
|
||||
self.global_step = Parameter(initializer(0, [1]), name="global_step")
|
||||
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
|
||||
self.warmup_flag = False
|
||||
if warmup_steps > 0:
|
||||
self.warmup_flag = True
|
||||
self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32))
|
||||
self.end_learning_rate = Tensor(np.array([end_learning_rate]).astype(np.float32))
|
||||
self.diff_learning_rate = Tensor(np.array([learning_rate - end_learning_rate]).astype(np.float32))
|
||||
self.power = power
|
||||
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
|
||||
self.params = self.parameters
|
||||
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
|
||||
self.decay_flag = tuple(decay_filter(x) for x in self.params)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.min = P.Minimum()
|
||||
self.pow = P.Pow()
|
||||
self.greater = P.Greater()
|
||||
self.one = Tensor(np.array([1.0]).astype(np.float32))
|
||||
self.cast = P.Cast()
|
||||
self.start_learning_rate = Tensor(np.array([learning_rate]).astype(np.float32))
|
||||
|
||||
def construct(self, gradients):
|
||||
"""Adam Weight Decay Dynamic LR."""
|
||||
step = self.min(self.global_step, self.decay_steps)
|
||||
p = step / self.decay_steps
|
||||
lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
|
||||
if self.warmup_flag:
|
||||
warmup_percent = self.global_step / self.warmup_steps
|
||||
warmup_lr = self.start_learning_rate * warmup_percent
|
||||
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
|
||||
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
|
||||
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
|
||||
added_global_step = self.global_step + self.one
|
||||
F.control_depend(lr, added_global_step)
|
||||
self.global_step = added_global_step
|
||||
|
||||
return updated_velocity
|
|
@ -0,0 +1,357 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train api."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import moxing as mox
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn import Momentum
|
||||
from mindspore.nn.optim import Lamb
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
|
||||
from mindspore.train.callback import LossMonitor, SummaryCollector
|
||||
from mindspore import context, Parameter
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication import management as MultiAscend
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from config.config import Seq2seqConfig
|
||||
from src.dataset import load_dataset
|
||||
from src.seq2seq_model.seq2seq_for_train import Seq2seqNetworkWithLoss, Seq2seqTrainOneStepWithLossScaleCell
|
||||
from src.utils import LossCallBack
|
||||
from src.utils import one_weight, weight_variable
|
||||
from src.utils.lr_scheduler import square_root_schedule, polynomial_decay_scheduler, Warmup_MultiStepLR_scheduler
|
||||
from src.utils.optimizer import Adam
|
||||
|
||||
parser = argparse.ArgumentParser(description='Seq2seq train entry point.')
|
||||
|
||||
is_modelarts = False
|
||||
|
||||
if is_modelarts:
|
||||
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
|
||||
parser.add_argument("--data_url", type=str, required=True, help="pre-train dataset address.")
|
||||
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
|
||||
|
||||
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
|
||||
parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
|
||||
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=True,
|
||||
device_target="Ascend",
|
||||
reserve_class_name_in_scope=True)
|
||||
|
||||
def get_config(config):
|
||||
config = Seq2seqConfig.from_json_file(config)
|
||||
config.compute_type = mstype.float16
|
||||
config.dtype = mstype.float32
|
||||
return config
|
||||
|
||||
def _train(model, config: Seq2seqConfig,
|
||||
pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
|
||||
callbacks: list = None):
|
||||
"""
|
||||
Train model.
|
||||
|
||||
Args:
|
||||
model (Model): MindSpore model instance.
|
||||
config (seq2seqConfig): Config of mass model.
|
||||
pre_training_dataset (Dataset): Pre-training dataset.
|
||||
fine_tune_dataset (Dataset): Fine-tune dataset.
|
||||
test_dataset (Dataset): Test dataset.
|
||||
callbacks (list): A list of callbacks.
|
||||
"""
|
||||
callbacks = callbacks if callbacks else []
|
||||
|
||||
if pre_training_dataset is not None:
|
||||
print(" | Start pre-training job.")
|
||||
epoch_size = pre_training_dataset.get_repeat_count()
|
||||
print("epoch size ", epoch_size)
|
||||
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
|
||||
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
|
||||
model.train(config.epochs, pre_training_dataset,
|
||||
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)
|
||||
|
||||
if fine_tune_dataset is not None:
|
||||
print(" | Start fine-tuning job.")
|
||||
epoch_size = fine_tune_dataset.get_repeat_count()
|
||||
|
||||
model.train(config.epochs, fine_tune_dataset,
|
||||
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)
|
||||
|
||||
|
||||
def _load_checkpoint_to_net(config, network):
|
||||
"""load parameters to network from checkpoint."""
|
||||
if config.existed_ckpt:
|
||||
if config.existed_ckpt.endswith(".npz"):
|
||||
weights = np.load(config.existed_ckpt)
|
||||
else:
|
||||
weights = load_checkpoint(config.existed_ckpt)
|
||||
for param in network.trainable_params():
|
||||
weights_name = param.name
|
||||
if weights_name not in weights:
|
||||
raise ValueError(f"Param {weights_name} is not found in ckpt file.")
|
||||
|
||||
if isinstance(weights[weights_name], Parameter):
|
||||
param.set_data(weights[weights_name].data)
|
||||
elif isinstance(weights[weights_name], Tensor):
|
||||
param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
|
||||
elif isinstance(weights[weights_name], np.ndarray):
|
||||
param.set_data(Tensor(weights[weights_name], config.dtype))
|
||||
else:
|
||||
param.set_data(weights[weights_name])
|
||||
else:
|
||||
for param in network.trainable_params():
|
||||
name = param.name
|
||||
value = param.data
|
||||
if isinstance(value, Tensor):
|
||||
if name.endswith(".gamma"):
|
||||
param.set_data(one_weight(value.asnumpy().shape))
|
||||
elif name.endswith(".beta") or name.endswith(".bias"):
|
||||
# param.set_data(zero_weight(value.asnumpy().shape))
|
||||
if param.data.dtype == "Float32":
|
||||
param.set_data((weight_variable(value.asnumpy().shape).astype(np.float32)))
|
||||
elif param.data.dtype == "Float16":
|
||||
param.set_data((weight_variable(value.asnumpy().shape).astype(np.float16)))
|
||||
else:
|
||||
if param.data.dtype == "Float32":
|
||||
param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float32)))
|
||||
elif param.data.dtype == "Float16":
|
||||
param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float16)))
|
||||
|
||||
|
||||
def _get_lr(config, update_steps):
|
||||
"""generate learning rate."""
|
||||
if config.lr_scheduler == "isr":
|
||||
lr = Tensor(square_root_schedule(lr=config.lr,
|
||||
update_num=update_steps,
|
||||
decay_start_step=config.decay_start_step,
|
||||
warmup_steps=config.warmup_steps,
|
||||
min_lr=config.min_lr), dtype=mstype.float32)
|
||||
elif config.lr_scheduler == "poly":
|
||||
lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
|
||||
min_lr=config.min_lr,
|
||||
decay_steps=config.decay_steps,
|
||||
total_update_num=update_steps,
|
||||
warmup_steps=config.warmup_steps,
|
||||
power=config.lr_scheduler_power), dtype=mstype.float32)
|
||||
elif config.lr_scheduler == "WarmupMultiStepLR":
|
||||
lr = Tensor(Warmup_MultiStepLR_scheduler(base_lr=config.lr,
|
||||
total_update_num=update_steps,
|
||||
warmup_steps=config.warmup_steps,
|
||||
remain_steps=config.warmup_lr_remain_steps,
|
||||
decay_interval=config.warmup_lr_decay_interval,
|
||||
decay_steps=config.decay_steps,
|
||||
decay_factor=config.lr_scheduler_power), dtype=mstype.float32)
|
||||
else:
|
||||
lr = config.lr
|
||||
return lr
|
||||
|
||||
|
||||
def _get_optimizer(config, network, lr):
|
||||
"""get gnmt optimizer, support Adam, Lamb, Momentum."""
|
||||
if config.optimizer.lower() == "adam":
|
||||
optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
|
||||
elif config.optimizer.lower() == "lamb":
|
||||
optimizer = Lamb(network.trainable_params(), learning_rate=lr,
|
||||
eps=1e-6)
|
||||
elif config.optimizer.lower() == "momentum":
|
||||
optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
|
||||
else:
|
||||
raise ValueError(f"optimizer only support `adam` and `momentum` now.")
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def _build_training_pipeline(config: Seq2seqConfig,
|
||||
pre_training_dataset=None,
|
||||
fine_tune_dataset=None,
|
||||
test_dataset=None):
|
||||
"""
|
||||
Build training pipeline.
|
||||
|
||||
Args:
|
||||
config (seq2seqConfig): Config of seq2seq model.
|
||||
pre_training_dataset (Dataset): Pre-training dataset.
|
||||
fine_tune_dataset (Dataset): Fine-tune dataset.
|
||||
test_dataset (Dataset): Test dataset.
|
||||
"""
|
||||
net_with_loss = Seq2seqNetworkWithLoss(config, is_training=True, use_one_hot_embeddings=True)
|
||||
net_with_loss.init_parameters_data()
|
||||
_load_checkpoint_to_net(config, net_with_loss)
|
||||
|
||||
dataset = pre_training_dataset if pre_training_dataset is not None \
|
||||
else fine_tune_dataset
|
||||
|
||||
if dataset is None:
|
||||
raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")
|
||||
|
||||
update_steps = config.epochs * dataset.get_dataset_size()
|
||||
|
||||
lr = _get_lr(config, update_steps)
|
||||
optimizer = _get_optimizer(config, net_with_loss, lr)
|
||||
|
||||
# Dynamic loss scale.
|
||||
scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
|
||||
scale_factor=config.loss_scale_factor,
|
||||
scale_window=config.scale_window)
|
||||
net_with_grads = Seq2seqTrainOneStepWithLossScaleCell(
|
||||
network=net_with_loss, optimizer=optimizer,
|
||||
scale_update_cell=scale_manager.get_update_cell()
|
||||
)
|
||||
net_with_grads.set_train(True)
|
||||
model = Model(net_with_grads, amp_level="O2")
|
||||
loss_monitor = LossCallBack(config)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
time_cb = TimeMonitor(data_size=dataset_size)
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
|
||||
keep_checkpoint_max=config.keep_ckpt_max)
|
||||
|
||||
rank_size = os.getenv('RANK_SIZE')
|
||||
callbacks = [time_cb, loss_monitor]
|
||||
callbacks.append(LossMonitor(1642))
|
||||
|
||||
if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
|
||||
ckpt_callback = ModelCheckpoint(
|
||||
prefix=config.ckpt_prefix,
|
||||
directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
|
||||
config=ckpt_config)
|
||||
callbacks.append(ckpt_callback)
|
||||
summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
|
||||
callbacks.append(summary_callback)
|
||||
|
||||
if rank_size is None or int(rank_size) == 1:
|
||||
ckpt_callback = ModelCheckpoint(
|
||||
prefix=config.ckpt_prefix,
|
||||
directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
|
||||
config=ckpt_config)
|
||||
callbacks.append(ckpt_callback)
|
||||
summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
|
||||
callbacks.append(summary_callback)
|
||||
|
||||
print(f" | ALL SET, PREPARE TO TRAIN.")
|
||||
_train(model=model, config=config,
|
||||
pre_training_dataset=pre_training_dataset,
|
||||
fine_tune_dataset=fine_tune_dataset,
|
||||
test_dataset=test_dataset,
|
||||
callbacks=callbacks)
|
||||
|
||||
|
||||
def _setup_parallel_env():
|
||||
context.reset_auto_parallel_context()
|
||||
MultiAscend.init()
|
||||
context.set_auto_parallel_context(
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
device_num=MultiAscend.get_group_size(),
|
||||
gradients_mean=True
|
||||
)
|
||||
|
||||
|
||||
def train_parallel(config: Seq2seqConfig):
|
||||
"""
|
||||
Train model with multi ascend chips.
|
||||
|
||||
Args:
|
||||
config (seq2seqConfig): Config for Seq2seq model.
|
||||
"""
|
||||
_setup_parallel_env()
|
||||
print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.")
|
||||
|
||||
pre_train_dataset = load_dataset(
|
||||
data_files=config.pre_train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank()
|
||||
) if config.pre_train_dataset else None
|
||||
fine_tune_dataset = load_dataset(
|
||||
data_files=config.fine_tune_dataset,
|
||||
batch_size=config.batch_size,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank()
|
||||
) if config.fine_tune_dataset else None
|
||||
test_dataset = load_dataset(
|
||||
data_files=config.test_dataset,
|
||||
batch_size=config.batch_size,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank()
|
||||
) if config.test_dataset else None
|
||||
|
||||
_build_training_pipeline(config=config,
|
||||
pre_training_dataset=pre_train_dataset,
|
||||
fine_tune_dataset=fine_tune_dataset,
|
||||
test_dataset=test_dataset)
|
||||
|
||||
|
||||
def train_single(config: Seq2seqConfig):
|
||||
"""
|
||||
Train model on single device.
|
||||
|
||||
Args:
|
||||
config (seq2seqConfig): Config for seq2seq model.
|
||||
"""
|
||||
print(" | Starting training on single device.")
|
||||
|
||||
pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
sink_mode=config.dataset_sink_mode) if config.pre_train_dataset else None
|
||||
fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
|
||||
batch_size=config.batch_size,
|
||||
sink_mode=config.dataset_sink_mode) if config.fine_tune_dataset else None
|
||||
test_dataset = load_dataset(data_files=config.test_dataset,
|
||||
batch_size=config.batch_size,
|
||||
sink_mode=config.dataset_sink_mode) if config.test_dataset else None
|
||||
|
||||
_build_training_pipeline(config=config,
|
||||
pre_training_dataset=pre_train_dataset,
|
||||
fine_tune_dataset=fine_tune_dataset,
|
||||
test_dataset=test_dataset)
|
||||
|
||||
|
||||
def _check_args(config):
|
||||
if not os.path.exists(config):
|
||||
raise FileNotFoundError("`config` is not existed.")
|
||||
if not isinstance(config, str):
|
||||
raise ValueError("`config` must be type of str.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
_rank_size = os.getenv('RANK_SIZE')
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if is_modelarts:
|
||||
mox.file.copy_parallel(src_url=args.data_url, dst_url='/cache/dataset_menu/')
|
||||
_config.pre_train_dataset = '/cache/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord'
|
||||
_config.ckpt_path = '/cache/train_output/'
|
||||
|
||||
_check_args(args.config)
|
||||
_config = get_config(args.config)
|
||||
_config.pre_train_dataset = args.pre_train_dataset
|
||||
|
||||
set_seed(_config.random_seed)
|
||||
|
||||
if _rank_size is not None and int(_rank_size) > 1:
|
||||
train_parallel(_config)
|
||||
else:
|
||||
train_single(_config)
|
||||
|
||||
if is_modelarts:
|
||||
mox.file.copy_parallel(src_url='/cache/train_output/', dst_url=args.train_url)
|
Loading…
Reference in New Issue