forked from mindspore-Ecosystem/mindspore
upload gnmt_v2
This commit is contained in:
parent
64d078da79
commit
6ac5be72d9
|
@ -43,6 +43,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
|
|||
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp)
|
||||
- [BERT[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/bert/README.md)
|
||||
- [TinyBERT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/tinybert/README.md)
|
||||
- [GNMT V2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2/README.md)
|
||||
- [LSTM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm/README.md)
|
||||
- [MASS](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/mass/README.md)
|
||||
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
|
||||
|
|
|
@ -0,0 +1,263 @@
|
|||
![](https://www.mindspore.cn/static/img/logo.a3e472c9.png)
|
||||
|
||||
<!-- TOC -->
|
||||
- [GNMT v2 For MindSpore](#gnmt-v2-for-mindspore)
|
||||
- [Model Structure](#model-structure)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Platform](#platform)
|
||||
- [Software](#software)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Dataset Preparation](#dataset-preparation)
|
||||
- [Configuration File](#configuration-file)
|
||||
- [Training Process](#training-process)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Result](#result)
|
||||
- [Training Performance](#training-performance)
|
||||
- [Inference Performance](#inference-performance)
|
||||
- [Practice](#practice)
|
||||
- [Dataset Preprocessing](#dataset-preprocessing)
|
||||
- [Training](#training-1)
|
||||
- [Inference](#inference-1)
|
||||
- [Random Situation Description](#random-situation-description)
|
||||
- [Others](#others)
|
||||
- [ModelZoo](#modelzoo)
|
||||
<!-- /TOC -->
|
||||
|
||||
|
||||
# GNMT v2 For MindSpore
|
||||
The GNMT v2 model is similar to the model described in [Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation](https://arxiv.org/abs/1609.08144), which is mainly used for corpus translation.
|
||||
|
||||
# Model Structure
|
||||
The GNMTv2 model mainly consists of an encoder, a decoder, and an attention mechanism, where the encoder and the decoder use a shared word embedding vector.
|
||||
Encoder: consists of four long short-term memory (LSTM) layers. The first LSTM layer is bidirectional, while the other three layers are unidirectional.
|
||||
Decoder: consists of four unidirectional LSTM layers and a fully connected classifier. The output embedding dimension of LSTM is 1024.
|
||||
Attention mechanism: uses the standardized Bahdanau attention mechanism. First, the first layer output of the decoder is used as the input of the attention mechanism. Then, the computing result of the attention mechanism is connected to the input of the decoder LSTM, which is used as the input of the subsequent LSTM layer.
|
||||
|
||||
# Dataset
|
||||
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
- *WMT Englis-German* for training.
|
||||
- *WMT newstest2014* for evaluation.
|
||||
|
||||
# Environment Requirements
|
||||
## Platform
|
||||
- Hardware (Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you could get the resources for trial.
|
||||
- Framework
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/en/master/index.html)
|
||||
- [MindSpore API](https://www.mindspore.cn/api/en/master/index.html)
|
||||
|
||||
## Software
|
||||
```txt
|
||||
numpy
|
||||
sacrebleu==1.2.10
|
||||
sacremoses==0.0.19
|
||||
subword_nmt==0.3.7
|
||||
```
|
||||
|
||||
# [Quick Start](#contents)
|
||||
After dataset preparation, you can start training and evaluation as follows:
|
||||
```bash
|
||||
# run training example
|
||||
python train.py --config /home/workspace/gnmt_v2/config/config.json
|
||||
|
||||
# run distributed training example
|
||||
cd ./scripts
|
||||
sh run_distributed_train_ascend.sh
|
||||
|
||||
# run evaluation example
|
||||
cd ./scripts
|
||||
sh run_standalone_eval_ascend.sh
|
||||
```
|
||||
|
||||
# Script Description
|
||||
The GNMT network script and code result are as follows:
|
||||
```text
|
||||
├── gnmt
|
||||
├── README.md // Introduction of GNMTv2 model.
|
||||
├── config
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──config.py // Configuration instance definition.
|
||||
│ ├──config.json // Configuration file for pre-train or finetune.
|
||||
│ ├──config_test.json // Configuration file for test.
|
||||
├── src
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──dataset
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──base.py // Base class of data loader.
|
||||
│ ├──bi_data_loader.py // Bilingual data loader.
|
||||
│ ├──load_dataset.py // Dataset loader to feed into model.
|
||||
│ ├──schema.py // Define schema of mindrecord.
|
||||
│ ├──tokenizer.py // Tokenizer class.
|
||||
│ ├──gnmt_model
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──attention.py // Bahdanau attention mechanism.
|
||||
│ ├──beam_search.py // Beam search decoder for inferring.
|
||||
│ ├──bleu_calculate.py // Calculat the blue accuracy.
|
||||
│ ├──components.py // Components.
|
||||
│ ├──create_attention.py // Recurrent attention.
|
||||
│ ├──create_attn_padding.py // Create attention paddings from input paddings.
|
||||
│ ├──decoder.py // GNMT decoder component.
|
||||
│ ├──decoder_beam_infer.py // GNMT decoder component for beam search.
|
||||
│ ├──dynamic_rnn.py // DynamicRNN.
|
||||
│ ├──embedding.py // Embedding component.
|
||||
│ ├──encoder.py // GNMT encoder component.
|
||||
│ ├──gnmt.py // GNMT model architecture.
|
||||
│ ├──gnmt_for_infer.py // Use GNMT to infer.
|
||||
│ ├──gnmt_for_train.py // Use GNMT to train.
|
||||
│ ├──grad_clip.py // Gradient clip
|
||||
│ ├──utils
|
||||
│ ├──__init__.py // User interface.
|
||||
│ ├──initializer.py // Parameters initializer.
|
||||
│ ├──load_weights.py // Load weights from a checkpoint or NPZ file.
|
||||
│ ├──loss_moniter.py // Callback of monitering loss during training step.
|
||||
│ ├──lr_scheduler.py // Learning rate scheduler.
|
||||
│ ├──optimizer.py // Optimizer.
|
||||
├── scripts
|
||||
│ ├──run_distributed_train_ascend.sh // shell script for distributed train on ascend.
|
||||
│ ├──run_standalone_eval_ascend.sh // shell script for standalone eval on ascend.
|
||||
│ ├──run_standalone_train_ascend.sh // shell script for standalone eval on ascend.
|
||||
├── create_dataset.py // dataset preparation.
|
||||
├── eval.py // Infer API entry.
|
||||
├── requirements.txt // Requirements of third party package.
|
||||
├── train.py // Train API entry.
|
||||
```
|
||||
|
||||
## Dataset Preparation
|
||||
You may use this [shell script](https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Translation/GNMT/scripts/wmt16_en_de.sh) to download and preprocess WMT English-German dataset. Assuming you get the following files:
|
||||
- train.tok.clean.bpe.32000.en
|
||||
- train.tok.clean.bpe.32000.de
|
||||
- vocab.bpe.32000
|
||||
- bpe.32000
|
||||
- newstest2014.en
|
||||
- newstest2014.de
|
||||
|
||||
- Convert the original data to tfrecord for training and evaluation:
|
||||
|
||||
``` bash
|
||||
python create_dataset.py --src_folder /home/workspace/wmt16_de_en --output_folder /home/workspace/dataset_menu
|
||||
```
|
||||
|
||||
## Configuration File
|
||||
The JSON file in the `config/` directory is the template configuration file.
|
||||
Almost all required options and parameters can be easily assigned, including the training platform, dataset and model configuration, and optimizer parameters. By setting the corresponding options, you can also obtain optional functions such as loss scale and checkpoint.
|
||||
For more information about attributes, see the `config/config.py` file.
|
||||
|
||||
## Training Process
|
||||
The model training requires the shell script `scripts/run_standalone_train_ascend.sh`. In this script, set environment variables and the training script `train.py` to be executed in `gnmt_v2/`.
|
||||
Start task training on a single device and run the following command in bash:
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train_ascend.sh
|
||||
```
|
||||
or multiple devices
|
||||
Task training on multiple devices and run the following command in bash to be executed in `scripts/`.:
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_distributed_train_ascend.sh
|
||||
```
|
||||
Note: Ensure that the hccl_json file is assigned when distributed training is running.
|
||||
Currently, inconsecutive device IDs are not supported in `scripts/run_distributed_train_ascend.sh`. The device ID must start from 0 in the `distribute_script/rank_table_8p.json` file.
|
||||
|
||||
## Evaluation Process
|
||||
|
||||
Set options in `config/config_test.json`. Make sure the 'existed_ckpt', 'dataset_schema' and 'test_dataset' are set to your own path.
|
||||
|
||||
Run `scripts/run_standalone_eval_ascend.sh` to process the output token ids to get the BLEU scores.
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_eval_ascend.sh
|
||||
```
|
||||
|
||||
# Model Description
|
||||
## Performance
|
||||
### Result
|
||||
#### Training Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | -------------------------------------------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | WMT Englis-German |
|
||||
| Training Parameters | epoch=6, batch_size=128 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| BLEU Score | 24.05 |
|
||||
| Speed | 344ms/step (8pcs) |
|
||||
| Loss | 63.35 |
|
||||
| Params (M) | 613 |
|
||||
| Checkpoint for inference | 1.8G (.ckpt file) |
|
||||
| Scripts | [gnmt_v2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2) |
|
||||
|
||||
#### Inference Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 11/06/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | WMT newstest2014 |
|
||||
| batch_size | 128 |
|
||||
| outputs | BLEU score |
|
||||
| Accuracy | BLEU= 24.05 |
|
||||
|
||||
## Practice
|
||||
The process of GNMTv2 performing the text translation task is as follows:
|
||||
1. Download the wmt16 data corpus and extract the dataset. For details, see the chapter "_Dataset_" above.
|
||||
2. Dataset preprocessing.
|
||||
3. Perform training.
|
||||
4. Perform inference.
|
||||
|
||||
### Dataset Preprocessing
|
||||
For a pre-trained model, configure the following options in the `config.json` file:
|
||||
```
|
||||
python create_dataset.py --src_folder /home/work_space/wmt16_de_en --output_folder /home/work_space/dataset_menu
|
||||
```
|
||||
|
||||
### Training
|
||||
For a pre-trained model, configure the following options in the `config/config.json` file:
|
||||
- Assign `pre_train_dataset` and `dataset_schema` to the training dataset path.
|
||||
- Select an optimizer ('momentum/adam/lamb' is available).
|
||||
- Specify `ckpt_prefix` and `ckpt_path` in `checkpoint_path` to save the model file.
|
||||
- Set other parameters, including dataset configuration and network configuration.
|
||||
- If a pre-trained model exists, assign `existed_ckpt` to the path of the existing model during fine-tuning.
|
||||
|
||||
Run the shell script `run.sh`:
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train_ascend.sh
|
||||
```
|
||||
|
||||
### Inference
|
||||
For inference using a trained model on multiple hardware platforms, such as GPU, Ascend 910, and Ascend 310, see [Network Migration](https://www.mindspore.cn/tutorial/en/master/advanced_use/network_migration.html).
|
||||
For inference interruption, configure the following options in the `config/config.json` file:
|
||||
- Assign `test_dataset` and the `dataset_schema` to the inference dataset path.
|
||||
- Assign `existed_ckpt` and the `checkpoint_path` to the path of the model file generated during training.
|
||||
- Set other parameters, including dataset configuration and network configuration.
|
||||
|
||||
Run the shell script `run.sh`:
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_eval_ascend.sh
|
||||
```
|
||||
|
||||
# Random Situation Description
|
||||
There are three random situations:
|
||||
- Shuffle of the dataset.
|
||||
- Initialization of some model weights.
|
||||
- Dropout operations.
|
||||
Some seeds have already been set in train.py to avoid the randomness of dataset shuffle and weight initialization. If you want to disable dropout, please set the corresponding dropout_prob parameter to 0 in config/config.json.
|
||||
|
||||
# Others
|
||||
This model has been validated in the Ascend environment and is not validated on the CPU and GPU.
|
||||
|
||||
# ModelZoo 主页
|
||||
[链接](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo)
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GNMTv2 model configuration."""
|
||||
from .config import GNMTConfig
|
||||
|
||||
__all__ = [
|
||||
"GNMTConfig"
|
||||
]
|
|
@ -0,0 +1,55 @@
|
|||
{
|
||||
"training_platform": {
|
||||
"modelarts": false
|
||||
},
|
||||
"dataset_config": {
|
||||
"random_seed": 50,
|
||||
"epochs": 6,
|
||||
"batch_size": 128,
|
||||
"dataset_schema": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",
|
||||
"pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001",
|
||||
"fine_tune_dataset": null,
|
||||
"test_dataset": null,
|
||||
"valid_dataset": null,
|
||||
"dataset_sink_mode": true,
|
||||
"dataset_sink_step": 2
|
||||
},
|
||||
"model_config": {
|
||||
"seq_length": 51,
|
||||
"vocab_size": 32320,
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 4,
|
||||
"intermediate_size": 4096,
|
||||
"hidden_dropout_prob": 0.2,
|
||||
"attention_dropout_prob": 0.2,
|
||||
"initializer_range": 0.1,
|
||||
"label_smoothing": 0.1,
|
||||
"beam_width": 2,
|
||||
"length_penalty_weight": 0.6,
|
||||
"max_decode_length": 50
|
||||
},
|
||||
"loss_scale_config": {
|
||||
"init_loss_scale": 65536,
|
||||
"loss_scale_factor": 2,
|
||||
"scale_window": 1000
|
||||
},
|
||||
"learn_rate_config": {
|
||||
"optimizer": "adam",
|
||||
"lr": 2e-3,
|
||||
"lr_scheduler": "WarmupMultiStepLR",
|
||||
"lr_scheduler_power": 0.5,
|
||||
"warmup_lr_remain_steps": 0.666,
|
||||
"warmup_lr_decay_interval": -1,
|
||||
"decay_steps": 4,
|
||||
"decay_start_step": -1,
|
||||
"warmup_steps": 200,
|
||||
"min_lr": 1e-6
|
||||
},
|
||||
"checkpoint_options": {
|
||||
"existed_ckpt": "",
|
||||
"save_ckpt_steps": 3452,
|
||||
"keep_ckpt_max": 6,
|
||||
"ckpt_prefix": "gnmt",
|
||||
"ckpt_path": "text_translation"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,228 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Configuration class for GNMT."""
|
||||
import os
|
||||
import json
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
def _is_dataset_file(file: str):
|
||||
return "tfrecord" in file.lower() or "mindrecord" in file.lower()
|
||||
|
||||
|
||||
def _get_files_from_dir(folder: str):
|
||||
_files = []
|
||||
for file in os.listdir(folder):
|
||||
if _is_dataset_file(file):
|
||||
_files.append(os.path.join(folder, file))
|
||||
return _files
|
||||
|
||||
|
||||
def get_source_list(folder: str) -> List:
|
||||
"""
|
||||
Get file list from a folder.
|
||||
|
||||
Returns:
|
||||
list, file list.
|
||||
"""
|
||||
_list = []
|
||||
if not folder:
|
||||
return _list
|
||||
|
||||
if os.path.isdir(folder):
|
||||
_list = _get_files_from_dir(folder)
|
||||
else:
|
||||
if _is_dataset_file(folder):
|
||||
_list.append(folder)
|
||||
return _list
|
||||
|
||||
|
||||
PARAM_NODES = {"dataset_config",
|
||||
"training_platform",
|
||||
"model_config",
|
||||
"loss_scale_config",
|
||||
"learn_rate_config",
|
||||
"checkpoint_options"}
|
||||
|
||||
|
||||
class GNMTConfig:
|
||||
"""
|
||||
Configuration for `GNMT`.
|
||||
|
||||
Args:
|
||||
random_seed (int): Random seed.
|
||||
batch_size (int): Batch size of input dataset.
|
||||
epochs (int): Epoch number.
|
||||
dataset_sink_mode (bool): Whether enable dataset sink mode.
|
||||
dataset_sink_step (int): Dataset sink step.
|
||||
lr_scheduler (str): Whether use lr_scheduler, only support "ISR" now.
|
||||
lr (float): Initial learning rate.
|
||||
min_lr (float): Minimum learning rate.
|
||||
decay_start_step (int): Step to decay.
|
||||
warmup_steps (int): Warm up steps.
|
||||
dataset_schema (str): Path of dataset schema file.
|
||||
pre_train_dataset (str): Path of pre-training dataset file or folder.
|
||||
fine_tune_dataset (str): Path of fine-tune dataset file or folder.
|
||||
test_dataset (str): Path of test dataset file or folder.
|
||||
valid_dataset (str): Path of validation dataset file or folder.
|
||||
ckpt_path (str): Checkpoints save path.
|
||||
save_ckpt_steps (int): Interval of saving ckpt.
|
||||
ckpt_prefix (str): Prefix of ckpt file.
|
||||
keep_ckpt_max (int): Max ckpt files number.
|
||||
seq_length (int): Length of input sequence. Default: 64.
|
||||
vocab_size (int): The shape of each embedding vector. Default: 46192.
|
||||
hidden_size (int): Size of embedding, attention, dim. Default: 512.
|
||||
num_hidden_layers (int): Encoder, Decoder layers.
|
||||
|
||||
intermediate_size (int): Size of intermediate layer in the Transformer
|
||||
encoder/decoder cell. Default: 4096.
|
||||
hidden_act (str): Activation function used in the Transformer encoder/decoder
|
||||
cell. Default: "relu".
|
||||
init_loss_scale (int): Initialized loss scale.
|
||||
loss_scale_factor (int): Loss scale factor.
|
||||
scale_window (int): Window size of loss scale.
|
||||
beam_width (int): Beam width for beam search in inferring. Default: 4.
|
||||
length_penalty_weight (float): Penalty for sentence length. Default: 1.0.
|
||||
label_smoothing (float): Label smoothing setting. Default: 0.1.
|
||||
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
|
||||
dataset. Default: True.
|
||||
save_graphs (bool): Whether to save graphs, please set to True if mindinsight
|
||||
is wanted.
|
||||
dtype (mstype): Data type of the input. Default: mstype.float32.
|
||||
max_decode_length (int): Max decode length for inferring. Default: 64.
|
||||
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
|
||||
attention_dropout_prob (float): The dropout probability for
|
||||
Multi-head Self-Attention. Default: 0.1.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
modelarts=False, random_seed=74,
|
||||
epochs=6, batch_size=64,
|
||||
dataset_schema: str = None,
|
||||
pre_train_dataset: str = None,
|
||||
fine_tune_dataset: str = None,
|
||||
test_dataset: str = None,
|
||||
valid_dataset: str = None,
|
||||
dataset_sink_mode=True, dataset_sink_step=1,
|
||||
seq_length=51, vocab_size=32320, hidden_size=1024,
|
||||
num_hidden_layers=4, intermediate_size=4096,
|
||||
hidden_act="tanh",
|
||||
hidden_dropout_prob=0.2, attention_dropout_prob=0.2,
|
||||
initializer_range=0.1,
|
||||
label_smoothing=0.1,
|
||||
beam_width=5,
|
||||
length_penalty_weight=1.0,
|
||||
max_decode_length=50,
|
||||
input_mask_from_dataset=False,
|
||||
init_loss_scale=2 ** 10,
|
||||
loss_scale_factor=2, scale_window=128,
|
||||
lr_scheduler="", optimizer="adam",
|
||||
lr=1e-4, min_lr=1e-6,
|
||||
decay_steps=4, lr_scheduler_power=1,
|
||||
warmup_lr_remain_steps=0.666, warmup_lr_decay_interval=-1,
|
||||
decay_start_step=-1, warmup_steps=200,
|
||||
existed_ckpt="", save_ckpt_steps=2000, keep_ckpt_max=20,
|
||||
ckpt_prefix="gnmt", ckpt_path: str = None,
|
||||
save_step=10000,
|
||||
save_graphs=False,
|
||||
dtype=mstype.float32):
|
||||
|
||||
self.save_graphs = save_graphs
|
||||
self.random_seed = random_seed
|
||||
self.modelarts = modelarts
|
||||
self.save_step = save_step
|
||||
self.dataset_schema = dataset_schema
|
||||
self.pre_train_dataset = get_source_list(pre_train_dataset) # type: List[str]
|
||||
self.fine_tune_dataset = get_source_list(fine_tune_dataset) # type: List[str]
|
||||
self.valid_dataset = get_source_list(valid_dataset) # type: List[str]
|
||||
self.test_dataset = get_source_list(test_dataset) # type: List[str]
|
||||
|
||||
if not isinstance(epochs, int) and epochs < 0:
|
||||
raise ValueError("`epoch` must be type of int.")
|
||||
|
||||
self.epochs = epochs
|
||||
self.dataset_sink_mode = dataset_sink_mode
|
||||
self.dataset_sink_step = dataset_sink_step
|
||||
|
||||
self.ckpt_path = ckpt_path
|
||||
self.keep_ckpt_max = keep_ckpt_max
|
||||
self.save_ckpt_steps = save_ckpt_steps
|
||||
self.ckpt_prefix = ckpt_prefix
|
||||
self.existed_ckpt = existed_ckpt
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_dropout_prob = attention_dropout_prob
|
||||
|
||||
self.initializer_range = initializer_range
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
self.beam_width = beam_width
|
||||
self.length_penalty_weight = length_penalty_weight
|
||||
self.max_decode_length = max_decode_length
|
||||
self.input_mask_from_dataset = input_mask_from_dataset
|
||||
self.compute_type = mstype.float16
|
||||
self.dtype = dtype
|
||||
|
||||
self.scale_window = scale_window
|
||||
self.loss_scale_factor = loss_scale_factor
|
||||
self.init_loss_scale = init_loss_scale
|
||||
|
||||
self.optimizer = optimizer
|
||||
self.lr = lr
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.min_lr = min_lr
|
||||
self.lr_scheduler_power = lr_scheduler_power
|
||||
self.warmup_lr_remain_steps = warmup_lr_remain_steps
|
||||
self.warmup_lr_decay_interval = warmup_lr_decay_interval
|
||||
self.decay_steps = decay_steps
|
||||
self.decay_start_step = decay_start_step
|
||||
self.warmup_steps = warmup_steps
|
||||
|
||||
self.train_url = ""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object: dict):
|
||||
"""Constructs a `TransformerConfig` from a Python dictionary of parameters."""
|
||||
_params = {}
|
||||
for node in PARAM_NODES:
|
||||
for key in json_object[node]:
|
||||
_params[key] = json_object[node][key]
|
||||
return cls(**_params)
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `TransformerConfig` from a json file of parameters."""
|
||||
with open(json_file, "r") as reader:
|
||||
return cls.from_dict(json.load(reader))
|
||||
|
||||
def to_dict(self):
|
||||
"""Serializes this instance to a Python dictionary."""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def to_json_string(self):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
|
@ -0,0 +1,55 @@
|
|||
{
|
||||
"training_platform": {
|
||||
"modelarts": false
|
||||
},
|
||||
"dataset_config": {
|
||||
"random_seed": 50,
|
||||
"epochs": 6,
|
||||
"batch_size": 128,
|
||||
"dataset_schema": "/home/workspace/dataset_menu/newstest2014.en.json",
|
||||
"pre_train_dataset": null,
|
||||
"fine_tune_dataset": null,
|
||||
"test_dataset": "/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001",
|
||||
"valid_dataset": null,
|
||||
"dataset_sink_mode": true,
|
||||
"dataset_sink_step": 2
|
||||
},
|
||||
"model_config": {
|
||||
"seq_length": 107,
|
||||
"vocab_size": 32320,
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 4,
|
||||
"intermediate_size": 4096,
|
||||
"hidden_dropout_prob": 0.2,
|
||||
"attention_dropout_prob": 0.2,
|
||||
"initializer_range": 0.1,
|
||||
"label_smoothing": 0.1,
|
||||
"beam_width": 2,
|
||||
"length_penalty_weight": 0.6,
|
||||
"max_decode_length": 80
|
||||
},
|
||||
"loss_scale_config": {
|
||||
"init_loss_scale": 8192,
|
||||
"loss_scale_factor": 2,
|
||||
"scale_window": 128
|
||||
},
|
||||
"learn_rate_config": {
|
||||
"optimizer": "adam",
|
||||
"lr": 2e-3,
|
||||
"lr_scheduler": "WarmupMultiStepLR",
|
||||
"lr_scheduler_power": 0.5,
|
||||
"warmup_lr_remain_steps": 0.666,
|
||||
"warmup_lr_decay_interval": -1,
|
||||
"decay_steps": 4,
|
||||
"decay_start_step": -1,
|
||||
"warmup_steps": 200,
|
||||
"min_lr": 1e-6
|
||||
},
|
||||
"checkpoint_options": {
|
||||
"existed_ckpt": "/home/workspace/gnmt_v2/gnmt-6_3452.ckpt",
|
||||
"save_ckpt_steps": 3452,
|
||||
"keep_ckpt_max": 6,
|
||||
"ckpt_prefix": "gnmt",
|
||||
"ckpt_path": "text_translation"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create Dataset."""
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from src.dataset.bi_data_loader import BiLingualDataLoader, TextDataLoader
|
||||
from src.dataset.tokenizer import Tokenizer
|
||||
|
||||
parser = argparse.ArgumentParser(description='Generate dataset file.')
|
||||
parser.add_argument("--src_folder", type=str, default="/home/workspace/wmt16_de_en", required=False,
|
||||
help="Raw corpus folder.")
|
||||
|
||||
parser.add_argument("--output_folder", type=str, default="/home/workspace/dataset_menu",
|
||||
required=False,
|
||||
help="Dataset output path.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
args, _ = parser.parse_known_args()
|
||||
if not os.path.exists(args.output_folder):
|
||||
os.makedirs(args.output_folder)
|
||||
dicts = []
|
||||
train_src_file = "train.tok.clean.bpe.32000.en"
|
||||
train_tgt_file = "train.tok.clean.bpe.32000.de"
|
||||
test_src_file = "newstest2014.en"
|
||||
test_tgt_file = "newstest2014.de"
|
||||
|
||||
vocab = args.src_folder + "/vocab.bpe.32000"
|
||||
bpe_codes = args.src_folder + "/bpe.32000"
|
||||
pad_vocab = 8
|
||||
tokenizer = Tokenizer(vocab, bpe_codes, src_en='en', tgt_de='de', vocab_pad=pad_vocab)
|
||||
|
||||
test = TextDataLoader(
|
||||
src_filepath=os.path.join(args.src_folder, test_src_file),
|
||||
tokenizer=tokenizer,
|
||||
source_max_sen_len=None,
|
||||
schema_address=args.output_folder + "/" + test_src_file + ".json"
|
||||
)
|
||||
print(f" | It's writing, please wait a moment.")
|
||||
test.write_to_tfrecord(
|
||||
path=os.path.join(
|
||||
args.output_folder,
|
||||
os.path.basename(test_src_file) + ".tfrecord"
|
||||
)
|
||||
)
|
||||
|
||||
train = BiLingualDataLoader(
|
||||
src_filepath=os.path.join(args.src_folder, train_src_file),
|
||||
tgt_filepath=os.path.join(args.src_folder, train_tgt_file),
|
||||
tokenizer=tokenizer,
|
||||
source_max_sen_len=51,
|
||||
target_max_sen_len=50,
|
||||
schema_address=args.output_folder + "/" + train_src_file + ".json"
|
||||
)
|
||||
print(f" | It's writing, please wait a moment.")
|
||||
train.write_to_tfrecord(
|
||||
path=os.path.join(
|
||||
args.output_folder,
|
||||
os.path.basename(train_src_file) + ".tfrecord"
|
||||
)
|
||||
)
|
||||
|
||||
print(f" | Vocabulary size: {tokenizer.vocab_size}.")
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation api."""
|
||||
import argparse
|
||||
import pickle
|
||||
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from config import GNMTConfig
|
||||
from src.gnmt_model import infer
|
||||
from src.gnmt_model.bleu_calculate import bleu_calculate
|
||||
from src.dataset.tokenizer import Tokenizer
|
||||
|
||||
parser = argparse.ArgumentParser(description='gnmt')
|
||||
parser.add_argument("--config", type=str, required=True,
|
||||
help="model config json file path.")
|
||||
parser.add_argument("--vocab", type=str, required=True,
|
||||
help="Vocabulary to use.")
|
||||
parser.add_argument("--bpe_codes", type=str, required=True,
|
||||
help="bpe codes to use.")
|
||||
parser.add_argument("--test_tgt", type=str, required=False,
|
||||
default=None,
|
||||
help="data file of the test target")
|
||||
parser.add_argument("--output", type=str, required=False,
|
||||
default="./output.npz",
|
||||
help="result file path.")
|
||||
|
||||
|
||||
def get_config(config):
|
||||
config = GNMTConfig.from_json_file(config)
|
||||
config.compute_type = mstype.float16
|
||||
config.dtype = mstype.float32
|
||||
return config
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args, _ = parser.parse_known_args()
|
||||
_config = get_config(args.config)
|
||||
result = infer(_config)
|
||||
|
||||
with open(args.output, "wb") as f:
|
||||
pickle.dump(result, f, 1)
|
||||
|
||||
result_npy_addr = args.output
|
||||
vocab = args.vocab
|
||||
bpe_codes = args.bpe_codes
|
||||
test_tgt = args.test_tgt
|
||||
tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')
|
||||
scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
|
||||
print(f"BLEU scores is :{scores}")
|
|
@ -0,0 +1,6 @@
|
|||
nltk
|
||||
jieba
|
||||
numpy
|
||||
subword-nmt==0.3.7
|
||||
sacrebleu==1.2.10
|
||||
sacremoses==0.0.19
|
|
@ -0,0 +1,39 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
export RANK_TABLE_FILE=/home/workspace/rank_table_8p.json
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=/home/workspace/rank_table_8p.json
|
||||
|
||||
echo $RANK_TABLE_FILE
|
||||
export RANK_SIZE=8
|
||||
|
||||
for((i=0;i<=7;i++));
|
||||
do
|
||||
rm -rf ${current_exec_path}/device$i
|
||||
mkdir ${current_exec_path}/device$i
|
||||
cd ${current_exec_path}/device$i || exit
|
||||
cp ../../*.py .
|
||||
cp ../../*.sh .
|
||||
cp -r ../../src .
|
||||
cp -r ../../config .
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python ../../train.py --config /home/workspace/gnmt_v2/config/config.json > log_gnmt_network${i}.log 2>&1 &
|
||||
cd ${current_exec_path} || exit
|
||||
done
|
||||
cd ${current_exec_path} || exit
|
|
@ -0,0 +1,33 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=5
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../config ./eval
|
||||
cd ./eval || exit
|
||||
echo "start eval for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python eval.py --config /home/workspace/gnmt_v2/config/config_test.json --vocab /home/workspace/wmt16_de_en/vocab.bpe.32000 --bpe_codes /home/workspace/wmt16_de_en/bpe.32000 --test_tgt /home/workspace/wmt16_de_en/newstest2014.de >log_infer.log 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,33 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=4
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../config ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --config /home/workspace/gnmt_v2/config/config.json > log_gnmt_network.log 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GNMTv2 Init."""
|
||||
from .dataset import load_dataset
|
||||
from .dataset import bi_data_loader
|
||||
from .gnmt_model import GNMT, infer, GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
|
||||
from .gnmt_model import LabelSmoothedCrossEntropyCriterion
|
||||
|
||||
__all__ = [
|
||||
"load_dataset",
|
||||
"bi_data_loader",
|
||||
"GNMT",
|
||||
"infer",
|
||||
"GNMTNetworkWithLoss",
|
||||
"GNMTTrainOneStepWithLossScaleCell",
|
||||
"LabelSmoothedCrossEntropyCriterion"
|
||||
]
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset Init."""
|
||||
from .bi_data_loader import BiLingualDataLoader, TextDataLoader
|
||||
from .load_dataset import load_dataset
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
__all__ = [
|
||||
"load_dataset",
|
||||
"BiLingualDataLoader",
|
||||
"TextDataLoader",
|
||||
"Tokenizer"
|
||||
]
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Base class of data loader."""
|
||||
import os
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from .schema import SCHEMA
|
||||
|
||||
|
||||
class DataLoader:
|
||||
"""Data loader for dataset."""
|
||||
_SCHEMA = SCHEMA
|
||||
|
||||
def __init__(self):
|
||||
self._examples = []
|
||||
|
||||
def _load(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def padding(self, sen, padding_idx, need_sentence_len=None, dtype=np.int64):
|
||||
"""Padding <pad> to sentence."""
|
||||
if need_sentence_len is None:
|
||||
return None
|
||||
if sen.shape[0] > need_sentence_len:
|
||||
return None
|
||||
new_sen = np.array([padding_idx] * need_sentence_len, dtype=dtype)
|
||||
new_sen[:sen.shape[0]] = sen[:]
|
||||
return new_sen
|
||||
|
||||
def write_to_mindrecord(self, path, shard_num=1, desc=""):
|
||||
"""
|
||||
Write mindrecord file.
|
||||
|
||||
Args:
|
||||
path (str): File path.
|
||||
shard_num (int): Shard num.
|
||||
desc (str): Description.
|
||||
"""
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.abspath(path)
|
||||
|
||||
writer = FileWriter(file_name=path, shard_num=shard_num)
|
||||
writer.add_schema(self._SCHEMA, desc)
|
||||
if not self._examples:
|
||||
self._load()
|
||||
|
||||
writer.write_raw_data(self._examples)
|
||||
writer.commit()
|
||||
print(f"| Wrote to {path}.")
|
||||
|
||||
def write_to_tfrecord(self, path, shard_num=1):
|
||||
"""
|
||||
Write to tfrecord.
|
||||
|
||||
Args:
|
||||
path (str): Output file path.
|
||||
shard_num (int): Shard num.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.abspath(path)
|
||||
output_files = []
|
||||
for i in range(shard_num):
|
||||
output_file = path + "-%03d-of-%03d" % (i + 1, shard_num)
|
||||
output_files.append(output_file)
|
||||
# create writers
|
||||
writers = []
|
||||
for output_file in output_files:
|
||||
writers.append(tf.io.TFRecordWriter(output_file))
|
||||
|
||||
if not self._examples:
|
||||
self._load()
|
||||
|
||||
# create feature
|
||||
features = collections.OrderedDict()
|
||||
for example in self._examples:
|
||||
for key in example:
|
||||
features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=example[key].tolist()))
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
for writer in writers:
|
||||
writer.write(tf_example.SerializeToString())
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
for p in output_files:
|
||||
print(f" | Write to {p}.")
|
||||
|
||||
def _add_example(self, example):
|
||||
self._examples.append(example)
|
|
@ -0,0 +1,233 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bilingual data loader."""
|
||||
import numpy as np
|
||||
|
||||
from .base import DataLoader
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
|
||||
class BiLingualDataLoader(DataLoader):
|
||||
"""Loader for bilingual data."""
|
||||
|
||||
def __init__(self,
|
||||
src_filepath: str,
|
||||
tgt_filepath: str,
|
||||
tokenizer: Tokenizer,
|
||||
min_sen_len=0,
|
||||
source_max_sen_len=None,
|
||||
target_max_sen_len=80,
|
||||
schema_address=None):
|
||||
super(BiLingualDataLoader, self).__init__()
|
||||
self._src_filepath = src_filepath
|
||||
self._tgt_filepath = tgt_filepath
|
||||
self.tokenizer = tokenizer
|
||||
self.min_sen_len = min_sen_len
|
||||
self.source_max_sen_len = source_max_sen_len
|
||||
self.target_max_sen_len = target_max_sen_len
|
||||
self.schema_address = schema_address
|
||||
|
||||
def _load(self):
|
||||
count = 0
|
||||
if self.source_max_sen_len is None:
|
||||
with open(self._src_filepath, "r") as _src_file:
|
||||
print(f" | count the max_sen_len of corpus {self._src_filepath}.")
|
||||
max_src = 0
|
||||
for _, _pair in enumerate(_src_file):
|
||||
src_tokens = [
|
||||
int(self.tokenizer.tok2idx[t])
|
||||
for t in _pair.strip().split(" ") if t
|
||||
]
|
||||
src_len = len(src_tokens)
|
||||
if src_len > max_src:
|
||||
max_src = src_len
|
||||
self.source_max_sen_len = max_src + 2
|
||||
|
||||
if self.target_max_sen_len is None:
|
||||
with open(self._src_filepath, "r") as _tgt_file:
|
||||
print(f" | count the max_sen_len of corpus {self._src_filepath}.")
|
||||
max_tgt = 0
|
||||
for _, _pair in enumerate(_tgt_file):
|
||||
src_tokens = [
|
||||
int(self.tokenizer.tok2idx[t])
|
||||
for t in _pair.strip().split(" ") if t
|
||||
]
|
||||
tgt_len = len(src_tokens)
|
||||
if tgt_len > max_tgt:
|
||||
max_tgt = tgt_len
|
||||
self.target_max_sen_len = max_tgt + 1
|
||||
|
||||
with open(self._src_filepath, "r") as _src_file:
|
||||
print(f" | Processing corpus {self._src_filepath}.")
|
||||
print(f" | Processing corpus {self._tgt_filepath}.")
|
||||
with open(self._tgt_filepath, "r") as _tgt_file:
|
||||
for _, _pair in enumerate(zip(_src_file, _tgt_file)):
|
||||
|
||||
src_tokens = [
|
||||
int(self.tokenizer.tok2idx[t])
|
||||
for t in _pair[0].strip().split(" ") if t
|
||||
]
|
||||
tgt_tokens = [
|
||||
int(self.tokenizer.tok2idx[t])
|
||||
for t in _pair[1].strip().split(" ") if t
|
||||
]
|
||||
src_tokens.insert(0, self.tokenizer.bos_index)
|
||||
src_tokens.append(self.tokenizer.eos_index)
|
||||
tgt_tokens.insert(0, self.tokenizer.bos_index)
|
||||
tgt_tokens.append(self.tokenizer.eos_index)
|
||||
src_tokens = np.array(src_tokens)
|
||||
tgt_tokens = np.array(tgt_tokens)
|
||||
src_len = src_tokens.shape[0]
|
||||
tgt_len = tgt_tokens.shape[0]
|
||||
|
||||
if (src_len > self.source_max_sen_len) or (src_len < self.min_sen_len) or (
|
||||
tgt_len > (self.target_max_sen_len + 1)) or (tgt_len < self.min_sen_len):
|
||||
print(f"+++++ delete! src_len={src_len}, tgt_len={tgt_len - 1}, "
|
||||
f"source_max_sen_len={self.source_max_sen_len},"
|
||||
f"target_max_sen_len={self.target_max_sen_len}")
|
||||
continue
|
||||
# encoder inputs
|
||||
encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)
|
||||
src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
|
||||
for i in range(src_len):
|
||||
src_padding[i] = 1
|
||||
src_length = np.array([src_len], dtype=np.int64)
|
||||
# decoder inputs
|
||||
decoder_input = self.padding(tgt_tokens[:-1], self.tokenizer.padding_index, self.target_max_sen_len)
|
||||
# decoder outputs
|
||||
decoder_output = self.padding(tgt_tokens[1:], self.tokenizer.padding_index, self.target_max_sen_len)
|
||||
tgt_padding = np.zeros(shape=self.target_max_sen_len + 1, dtype=np.int64)
|
||||
for j in range(tgt_len):
|
||||
tgt_padding[j] = 1
|
||||
tgt_padding = tgt_padding[1:]
|
||||
decoder_input = np.array(decoder_input, dtype=np.int64)
|
||||
decoder_output = np.array(decoder_output, dtype=np.int64)
|
||||
tgt_padding = np.array(tgt_padding, dtype=np.int64)
|
||||
|
||||
example = {
|
||||
"src": encoder_input,
|
||||
"src_padding": src_padding,
|
||||
"src_length": src_length,
|
||||
"prev_opt": decoder_input,
|
||||
"target": decoder_output,
|
||||
"tgt_padding": tgt_padding
|
||||
}
|
||||
self._add_example(example)
|
||||
count += 1
|
||||
|
||||
print(f" | source padding_len = {self.source_max_sen_len}.")
|
||||
print(f" | target padding_len = {self.target_max_sen_len}.")
|
||||
print(f" | Total activate sen = {count}.")
|
||||
print(f" | Total sen = {count}.")
|
||||
|
||||
if self.schema_address is not None:
|
||||
provlist = [count, self.source_max_sen_len, self.source_max_sen_len, 1,
|
||||
self.target_max_sen_len, self.target_max_sen_len, self.target_max_sen_len]
|
||||
columns = ["src", "src_padding", "src_length", "prev_opt", "target", "tgt_padding"]
|
||||
with open(self.schema_address, "w", encoding="utf-8") as f:
|
||||
f.write("{\n")
|
||||
f.write(' "datasetType":"TF",\n')
|
||||
f.write(' "numRows":%s,\n' % provlist[0])
|
||||
f.write(' "columns":{\n')
|
||||
t = 1
|
||||
for name in columns:
|
||||
f.write(' "%s":{\n' % name)
|
||||
f.write(' "type":"int64",\n')
|
||||
f.write(' "rank":1,\n')
|
||||
f.write(' "shape":[%s]\n' % provlist[t])
|
||||
f.write(' }')
|
||||
if t < len(columns):
|
||||
f.write(',')
|
||||
f.write('\n')
|
||||
t += 1
|
||||
f.write(' }\n}\n')
|
||||
print(" | Write to " + self.schema_address)
|
||||
|
||||
|
||||
class TextDataLoader(DataLoader):
|
||||
"""Loader for text data."""
|
||||
|
||||
def __init__(self,
|
||||
src_filepath: str,
|
||||
tokenizer: Tokenizer,
|
||||
min_sen_len=0,
|
||||
source_max_sen_len=None,
|
||||
schema_address=None):
|
||||
super(TextDataLoader, self).__init__()
|
||||
self._src_filepath = src_filepath
|
||||
self.tokenizer = tokenizer
|
||||
self.min_sen_len = min_sen_len
|
||||
self.source_max_sen_len = source_max_sen_len
|
||||
self.schema_address = schema_address
|
||||
|
||||
def _load(self):
|
||||
count = 0
|
||||
if self.source_max_sen_len is None:
|
||||
with open(self._src_filepath, "r") as _src_file:
|
||||
print(f" | count the max_sen_len of corpus {self._src_filepath}.")
|
||||
max_src = 0
|
||||
for _, _pair in enumerate(_src_file):
|
||||
src_tokens = self.tokenizer.tokenize(_pair)
|
||||
src_len = len(src_tokens)
|
||||
if src_len > max_src:
|
||||
max_src = src_len
|
||||
self.source_max_sen_len = max_src
|
||||
|
||||
with open(self._src_filepath, "r") as _src_file:
|
||||
print(f" | Processing corpus {self._src_filepath}.")
|
||||
for _, _pair in enumerate(_src_file):
|
||||
src_tokens = self.tokenizer.tokenize(_pair)
|
||||
src_len = len(src_tokens)
|
||||
src_tokens = np.array(src_tokens)
|
||||
# encoder inputs
|
||||
encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)
|
||||
src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
|
||||
for i in range(src_len):
|
||||
src_padding[i] = 1
|
||||
src_length = np.array([src_len], dtype=np.int64)
|
||||
|
||||
example = {
|
||||
"src": encoder_input,
|
||||
"src_padding": src_padding,
|
||||
"src_length": src_length
|
||||
}
|
||||
self._add_example(example)
|
||||
count += 1
|
||||
|
||||
print(f" | source padding_len = {self.source_max_sen_len}.")
|
||||
print(f" | Total activate sen = {count}.")
|
||||
print(f" | Total sen = {count}.")
|
||||
|
||||
if self.schema_address is not None:
|
||||
provlist = [count, self.source_max_sen_len, self.source_max_sen_len, 1]
|
||||
columns = ["src", "src_padding", "src_length"]
|
||||
with open(self.schema_address, "w", encoding="utf-8") as f:
|
||||
f.write("{\n")
|
||||
f.write(' "datasetType":"TF",\n')
|
||||
f.write(' "numRows":%s,\n' % provlist[0])
|
||||
f.write(' "columns":{\n')
|
||||
t = 1
|
||||
for name in columns:
|
||||
f.write(' "%s":{\n' % name)
|
||||
f.write(' "type":"int64",\n')
|
||||
f.write(' "rank":1,\n')
|
||||
f.write(' "shape":[%s]\n' % provlist[t])
|
||||
f.write(' }')
|
||||
if t < len(columns):
|
||||
f.write(',')
|
||||
f.write('\n')
|
||||
t += 1
|
||||
f.write(' }\n}\n')
|
||||
print(" | Write to " + self.schema_address)
|
|
@ -0,0 +1,147 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset loader to feed into model."""
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
|
||||
|
||||
def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
|
||||
sink_mode=False, sink_step=1, rank_size=1, rank_id=0, shuffle=True,
|
||||
drop_remainder=True, is_translate=False):
|
||||
"""
|
||||
Load dataset according to passed in params.
|
||||
|
||||
Args:
|
||||
input_files (list): Data files.
|
||||
schema_file (str): Schema file path.
|
||||
batch_size (int): Batch size.
|
||||
epoch_count (int): Epoch count.
|
||||
sink_mode (bool): Whether enable sink mode.
|
||||
sink_step (int): Step to sink.
|
||||
rank_size (int): Rank size.
|
||||
rank_id (int): Rank id.
|
||||
shuffle (bool): Whether shuffle dataset.
|
||||
drop_remainder (bool): Whether drop the last possibly incomplete batch.
|
||||
is_translate (bool): Whether translate the text.
|
||||
|
||||
Returns:
|
||||
Dataset, dataset instance.
|
||||
"""
|
||||
if not input_files:
|
||||
raise FileNotFoundError("Require at least one dataset.")
|
||||
|
||||
if not (schema_file and
|
||||
os.path.exists(schema_file)
|
||||
and os.path.isfile(schema_file)
|
||||
and os.path.basename(schema_file).endswith(".json")):
|
||||
raise FileNotFoundError("`dataset_schema` must be a existed json file.")
|
||||
|
||||
if not isinstance(sink_mode, bool):
|
||||
raise ValueError("`sink` must be type of bool.")
|
||||
|
||||
for datafile in input_files:
|
||||
print(f" | Loading {datafile}.")
|
||||
|
||||
if not is_translate:
|
||||
ds = de.TFRecordDataset(
|
||||
input_files, schema_file,
|
||||
columns_list=[
|
||||
"src", "src_padding", "src_length",
|
||||
"prev_opt",
|
||||
"target", "tgt_padding"
|
||||
],
|
||||
shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
|
||||
shard_equal_rows=True, num_parallel_workers=8)
|
||||
|
||||
ori_dataset_size = ds.get_dataset_size()
|
||||
print(f" | Dataset size: {ori_dataset_size}.")
|
||||
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="src_length", operations=type_cast_op, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="prev_opt", operations=type_cast_op, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="target", operations=type_cast_op, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="tgt_padding", operations=type_cast_op, num_parallel_workers=8)
|
||||
|
||||
ds = ds.rename(
|
||||
input_columns=["src",
|
||||
"src_padding",
|
||||
"src_length",
|
||||
"prev_opt",
|
||||
"target",
|
||||
"tgt_padding"],
|
||||
output_columns=["source_eos_ids",
|
||||
"source_eos_mask",
|
||||
"source_eos_length",
|
||||
"target_sos_ids",
|
||||
"target_eos_ids",
|
||||
"target_eos_mask"]
|
||||
)
|
||||
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
|
||||
else:
|
||||
ds = de.TFRecordDataset(
|
||||
input_files, schema_file,
|
||||
columns_list=[
|
||||
"src", "src_padding", "src_length"
|
||||
],
|
||||
shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
|
||||
shard_equal_rows=True, num_parallel_workers=8)
|
||||
|
||||
ori_dataset_size = ds.get_dataset_size()
|
||||
print(f" | Dataset size: {ori_dataset_size}.")
|
||||
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="src_length", operations=type_cast_op, num_parallel_workers=8)
|
||||
|
||||
ds = ds.rename(
|
||||
input_columns=["src",
|
||||
"src_padding",
|
||||
"src_length"],
|
||||
output_columns=["source_eos_ids",
|
||||
"source_eos_mask",
|
||||
"source_eos_length"]
|
||||
)
|
||||
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
|
||||
|
||||
return ds
|
||||
|
||||
|
||||
def load_dataset(data_files: list, schema: str, batch_size: int, epoch_count: int, sink_mode: bool, sink_step: int = 1,
|
||||
rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False):
|
||||
"""
|
||||
Load dataset.
|
||||
|
||||
Args:
|
||||
data_files (list): Data files.
|
||||
schema (str): Schema file path.
|
||||
batch_size (int): Batch size.
|
||||
epoch_count (int): Epoch count.
|
||||
sink_mode (bool): Whether enable sink mode.
|
||||
sink_step (int): Step to sink.
|
||||
rank_size (int): Rank size.
|
||||
rank_id (int): Rank id.
|
||||
shuffle (bool): Whether shuffle dataset.
|
||||
|
||||
Returns:
|
||||
Dataset, dataset instance.
|
||||
"""
|
||||
return _load_dataset(data_files, schema, batch_size, epoch_count, sink_mode,
|
||||
sink_step, rank_size, rank_id, shuffle=shuffle,
|
||||
drop_remainder=drop_remainder, is_translate=is_translate)
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define schema of mindrecord."""
|
||||
|
||||
SCHEMA = {
|
||||
"src": {"type": "int64", "shape": [-1]},
|
||||
"src_padding": {"type": "int64", "shape": [-1]},
|
||||
"src_length": {"type": "int64", "shape": [-1]},
|
||||
"prev_opt": {"type": "int64", "shape": [-1]},
|
||||
"target": {"type": "int64", "shape": [-1]},
|
||||
"tgt_padding": {"type": "int64", "shape": [-1]},
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Tokenizer."""
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
import subword_nmt.apply_bpe
|
||||
import sacremoses
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizer class.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_address=None, bpe_code_address=None,
|
||||
src_en='en', tgt_de='de', vocab_pad=8, isolator='@@'):
|
||||
"""
|
||||
Constructor for the Tokenizer class.
|
||||
|
||||
Args:
|
||||
vocab_address: vocabulary address.
|
||||
bpe_code_address: path to the file with bpe codes.
|
||||
vocab_pad: pads vocabulary to a multiple of 'vocab_pad' tokens.
|
||||
isolator: tokenization isolator.
|
||||
"""
|
||||
self.padding_index = 0
|
||||
self.unk_index = 1
|
||||
self.bos_index = 2
|
||||
self.eos_index = 3
|
||||
self.pad_word = '<pad>'
|
||||
self.unk_word = '<unk>'
|
||||
self.bos_word = '<s>'
|
||||
self.eos_word = r'<\s>'
|
||||
self.isolator = isolator
|
||||
self.init_bpe(bpe_code_address)
|
||||
self.vocab_establist(vocab_address, vocab_pad)
|
||||
self.sacremoses_tokenizer = sacremoses.MosesTokenizer(src_en)
|
||||
self.sacremoses_detokenizer = sacremoses.MosesDetokenizer(tgt_de)
|
||||
|
||||
def init_bpe(self, bpe_code_address):
|
||||
"""Init bpe."""
|
||||
if (bpe_code_address is not None) and os.path.exists(bpe_code_address):
|
||||
with open(bpe_code_address, 'r') as f1:
|
||||
self.bpe = subword_nmt.apply_bpe.BPE(f1)
|
||||
|
||||
def vocab_establist(self, vocab_address, vocab_pad):
|
||||
"""Establish vocabulary."""
|
||||
if (vocab_address is None) or (not os.path.exists(vocab_address)):
|
||||
return
|
||||
vocab_words = [self.pad_word, self.unk_word, self.bos_word, self.eos_word]
|
||||
with open(vocab_address) as f1:
|
||||
for sentence in f1:
|
||||
vocab_words.append(sentence.strip())
|
||||
vocab_size = len(vocab_words)
|
||||
padded_vocab_size = (vocab_size + vocab_pad - 1) // vocab_pad * vocab_pad
|
||||
for idx in range(0, padded_vocab_size - vocab_size):
|
||||
fil_token = f'filled{idx:04d}'
|
||||
vocab_words.append(fil_token)
|
||||
self.vocab_size = len(vocab_words)
|
||||
self.tok2idx = defaultdict(partial(int, self.unk_index))
|
||||
for idx, token in enumerate(vocab_words):
|
||||
self.tok2idx[token] = idx
|
||||
self.idx2tok = {}
|
||||
self.idx2tok = defaultdict(partial(str, ","))
|
||||
for token, idx in self.tok2idx.items():
|
||||
self.idx2tok[idx] = token
|
||||
|
||||
def tokenize(self, sentence):
|
||||
"""Tokenize sentence."""
|
||||
tokenized = self.sacremoses_tokenizer.tokenize(sentence, return_str=True)
|
||||
bpe = self.bpe.process_line(tokenized)
|
||||
sentence = bpe.strip().split()
|
||||
inputs = [self.tok2idx[i] for i in sentence]
|
||||
inputs = [self.bos_index] + inputs + [self.eos_index]
|
||||
return inputs
|
||||
|
||||
def detokenize(self, indexes, gap=' '):
|
||||
"""Detokenizes single sentence and removes token isolator characters."""
|
||||
reconstruction_bpe = gap.join([self.idx2tok[idx] for idx in indexes])
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.isolator + ' ', '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.isolator, '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.bos_word, '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.eos_word, '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.unk_word, '')
|
||||
reconstruction_bpe = reconstruction_bpe.replace(self.pad_word, '')
|
||||
reconstruction_bpe = reconstruction_bpe.strip()
|
||||
reconstruction_words = self.sacremoses_detokenizer.detokenize(reconstruction_bpe.split())
|
||||
return reconstruction_words
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GNMTv2 Init."""
|
||||
from config.config import GNMTConfig
|
||||
from .gnmt import GNMT
|
||||
from .attention import BahdanauAttention
|
||||
from .gnmt_for_train import GNMTTraining, LabelSmoothedCrossEntropyCriterion, \
|
||||
GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
|
||||
from .gnmt_for_infer import infer
|
||||
from .bleu_calculate import bleu_calculate
|
||||
|
||||
__all__ = [
|
||||
"infer",
|
||||
"GNMTTraining",
|
||||
"LabelSmoothedCrossEntropyCriterion",
|
||||
"GNMTTrainOneStepWithLossScaleCell",
|
||||
"GNMTNetworkWithLoss",
|
||||
"GNMT",
|
||||
"BahdanauAttention",
|
||||
"GNMTConfig",
|
||||
"bleu_calculate"
|
||||
]
|
|
@ -0,0 +1,201 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bahdanau attention block."""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore import nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import Uniform
|
||||
|
||||
INF = 65504.0
|
||||
|
||||
|
||||
class BahdanauAttention(nn.Cell):
|
||||
"""
|
||||
Constructor for the BahdanauAttention.
|
||||
|
||||
Args:
|
||||
is_training (bool): Whether to train.
|
||||
query_size (int): feature dimension for query.
|
||||
key_size (int): feature dimension for keys.
|
||||
num_units (int): internal feature dimension.
|
||||
normalize (bool): Whether to normalize.
|
||||
initializer_range: range for uniform initializer parameters.
|
||||
|
||||
Returns:
|
||||
Tensor, shape (N, T, D).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
is_training,
|
||||
query_size,
|
||||
key_size,
|
||||
num_units,
|
||||
normalize=False,
|
||||
initializer_range=0.1,
|
||||
compute_type=mstype.float16):
|
||||
super(BahdanauAttention, self).__init__()
|
||||
self.is_training = is_training
|
||||
self.mask = None
|
||||
self.query_size = query_size
|
||||
self.key_size = key_size
|
||||
self.normalize = normalize
|
||||
self.num_units = num_units
|
||||
self.linear_att = Parameter(Tensor(np.random.uniform(-initializer_range, initializer_range, size=[num_units]),
|
||||
dtype=mstype.float32), name='linear_att')
|
||||
if self.normalize:
|
||||
self.normalize_scalar = Parameter(Tensor(np.array([1.0 / num_units]), dtype=mstype.float32),
|
||||
name='normalize_scalar')
|
||||
self.normalize_bias = Parameter(Tensor(np.zeros(num_units), dtype=mstype.float32), name='normalize_bias')
|
||||
self.transpose = P.Transpose()
|
||||
self.transpose_orders = (1, 0, 2)
|
||||
self.shape_op = P.Shape()
|
||||
|
||||
self.linear_q = nn.Dense(query_size,
|
||||
num_units,
|
||||
has_bias=False,
|
||||
weight_init=Uniform(initializer_range)).to_float(compute_type)
|
||||
|
||||
self.linear_k = nn.Dense(key_size,
|
||||
num_units,
|
||||
has_bias=False,
|
||||
weight_init=Uniform(initializer_range)).to_float(compute_type)
|
||||
self.expand = P.ExpandDims()
|
||||
self.tile = P.Tile()
|
||||
|
||||
self.norm = nn.Norm(axis=-1)
|
||||
self.mul = P.Mul()
|
||||
self.matmul = P.MatMul()
|
||||
self.batchMatmul = P.BatchMatMul()
|
||||
self.tanh = nn.Tanh()
|
||||
|
||||
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
|
||||
self.softmax = nn.Softmax(axis=-1)
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, query, keys, attention_mask=None):
|
||||
"""
|
||||
Construct attention block.
|
||||
|
||||
Args:
|
||||
query (Tensor): Shape (t_q, N, D).
|
||||
keys (Tensor): Shape (t_k, N, D).
|
||||
attention_mask: Shape(N, t_k).
|
||||
Returns:
|
||||
Tensor, shape (N, t_q, D).
|
||||
"""
|
||||
|
||||
# (t_k, N, D) -> (N, t_k, D).
|
||||
keys = self.transpose(keys, self.transpose_orders)
|
||||
# (t_q, N, D) -> (N, t_q, D).
|
||||
query = self.transpose(query, self.transpose_orders)
|
||||
|
||||
query_shape = self.shape_op(query)
|
||||
b = query_shape[0]
|
||||
t_q = query_shape[1]
|
||||
t_k = self.shape_op(keys)[1]
|
||||
|
||||
# (N, t_q, D)
|
||||
query = self.reshape(query, (b * t_q, self.query_size))
|
||||
if self.is_training:
|
||||
query = self.cast(query, mstype.float16)
|
||||
processed_query = self.linear_q(query)
|
||||
if self.is_trining:
|
||||
processed_query = self.cast(processed_query, mstype.float32)
|
||||
processed_query = self.reshape(processed_query, (b, t_q, self.num_units))
|
||||
# (N, t_k, D)
|
||||
keys = self.reshape(keys, (b * t_k, self.key_size))
|
||||
if self.is_training:
|
||||
keys = self.cast(keys, mstype.float16)
|
||||
processed_key = self.linear_k(keys)
|
||||
if self.is_trining:
|
||||
processed_key = self.cast(processed_key, mstype.float32)
|
||||
processed_key = self.reshape(processed_key, (b, t_k, self.num_units))
|
||||
|
||||
# scores: (N , T_q, T_k)
|
||||
scores = self.calc_score(processed_query, processed_key)
|
||||
# attention_mask: (N, T_k)
|
||||
mask = attention_mask
|
||||
# [N, 1]
|
||||
if mask is not None:
|
||||
mask = 1.0 - mask
|
||||
mask = self.tile(self.expand(mask, 1), (1, t_q, 1))
|
||||
scores += mask * (-INF)
|
||||
# [b, t_q, t_k]
|
||||
scores_normalized = self.softmax(scores)
|
||||
|
||||
keys = self.reshape(keys, (b, t_k, self.key_size))
|
||||
if self.is_training:
|
||||
keys = self.cast(keys, mstype.float16)
|
||||
scores_normalized_fp16 = self.cast(scores_normalized, mstype.float16)
|
||||
else:
|
||||
scores_normalized_fp16 = scores_normalized
|
||||
|
||||
# (b, t_q, n)
|
||||
context_attention = self.batchMatmul(scores_normalized_fp16, keys)
|
||||
# [t_q,b,D]
|
||||
context_attention = self.transpose(context_attention, self.transpose_orders)
|
||||
if self.is_training:
|
||||
context_attention = self.cast(context_attention, mstype.float32)
|
||||
|
||||
return context_attention, scores_normalized
|
||||
|
||||
def calc_score(self, att_query, att_keys):
|
||||
"""
|
||||
Calculate Bahdanau score
|
||||
|
||||
Args:
|
||||
att_query: (N, T_q, D).
|
||||
att_keys: (N, T_k, D).
|
||||
|
||||
returns:
|
||||
scores: (N, T_q, T_k).
|
||||
"""
|
||||
b, t_k, n = self.shape_op(att_keys)
|
||||
t_q = self.shape_op(att_query)[1]
|
||||
# (b, t_q, t_k, n)
|
||||
att_query = self.tile(self.expand(att_query, 2), (1, 1, t_k, 1))
|
||||
att_keys = self.tile(self.expand(att_keys, 1), (1, t_q, 1, 1))
|
||||
# (b, t_q, t_k, n)
|
||||
sum_qk = att_query + att_keys
|
||||
|
||||
if self.normalize:
|
||||
# (b, t_q, t_k, n)
|
||||
sum_qk = sum_qk + self.normalize_bias
|
||||
linear_att = self.linear_att / self.norm(self.linear_att)
|
||||
linear_att = self.cast(linear_att, mstype.float32)
|
||||
linear_att = self.mul(linear_att, self.normalize_scalar)
|
||||
else:
|
||||
linear_att = self.linear_att
|
||||
|
||||
linear_att = self.expand(linear_att, -1)
|
||||
sum_qk = self.reshape(sum_qk, (-1, n))
|
||||
|
||||
tanh_sum_qk = self.tanh(sum_qk)
|
||||
if self.is_training:
|
||||
linear_att = self.cast(linear_att, mstype.float16)
|
||||
tanh_sum_qk = self.cast(tanh_sum_qk, mstype.float16)
|
||||
|
||||
out = self.matmul(tanh_sum_qk, linear_att)
|
||||
|
||||
# (b, t_q, t_k)
|
||||
out = self.reshape(out, (b, t_q, t_k))
|
||||
if self.is_training:
|
||||
out = self.cast(out, mstype.float32)
|
||||
return out
|
|
@ -0,0 +1,421 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Beam search decoder."""
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
INF = 65536.0
|
||||
|
||||
|
||||
class LengthPenalty(nn.Cell):
|
||||
"""
|
||||
Length penalty.
|
||||
|
||||
Args:
|
||||
weight (float): The length penalty weight.
|
||||
compute_type (mstype): Mindspore data type. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self, weight=1.0, compute_type=mstype.float32):
|
||||
super(LengthPenalty, self).__init__()
|
||||
self.weight = weight
|
||||
self.add = P.TensorAdd()
|
||||
self.pow = P.Pow()
|
||||
self.div = P.RealDiv()
|
||||
self.five = Tensor(5.0, mstype.float32)
|
||||
self.six = Tensor(6.0, mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, length_tensor):
|
||||
"""
|
||||
Process source sentence
|
||||
|
||||
Inputs:
|
||||
length_tensor (Tensor): the input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, after punishment of length.
|
||||
"""
|
||||
length_tensor = self.cast(length_tensor, mstype.float32)
|
||||
output = self.add(length_tensor, self.five)
|
||||
output = self.div(output, self.six)
|
||||
output = self.pow(output, self.weight)
|
||||
return output
|
||||
|
||||
|
||||
class TileBeam(nn.Cell):
|
||||
"""
|
||||
Beam Tile operation.
|
||||
|
||||
Args:
|
||||
beam_width (int): The Number of beam.
|
||||
compute_type (mstype): Mindspore data type. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self, beam_width, compute_type=mstype.float32):
|
||||
super(TileBeam, self).__init__()
|
||||
self.beam_width = beam_width
|
||||
|
||||
self.expand = P.ExpandDims()
|
||||
self.tile = P.Tile()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
"""
|
||||
Process source sentence
|
||||
|
||||
Inputs:
|
||||
input_tensor (Tensor): with shape (N, T, D).
|
||||
|
||||
Returns:
|
||||
Tensor, tiled tensor.
|
||||
"""
|
||||
shape = self.shape(input_tensor)
|
||||
# add an dim
|
||||
input_tensor = self.expand(input_tensor, 1)
|
||||
# get tile shape: [1, beam, ...]
|
||||
# shape = self.shape(input_tensor)
|
||||
tile_shape = (1,) + (self.beam_width,)
|
||||
for _ in range(len(shape) - 1):
|
||||
tile_shape = tile_shape + (1,)
|
||||
# tile
|
||||
output = self.tile(input_tensor, tile_shape)
|
||||
# reshape to [batch*beam, ...]
|
||||
out_shape = (shape[0] * self.beam_width,) + shape[1:]
|
||||
output = self.reshape(output, out_shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Mod(nn.Cell):
|
||||
"""
|
||||
Mod operation.
|
||||
|
||||
Args:
|
||||
compute_type (mstype): Mindspore data type. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
compute_type=mstype.float32):
|
||||
super(Mod, self).__init__()
|
||||
self.compute_type = compute_type
|
||||
|
||||
self.floor_div = P.FloorDiv()
|
||||
self.sub = P.Sub()
|
||||
self.multiply = P.Mul()
|
||||
|
||||
def construct(self, input_x, input_y):
|
||||
"""
|
||||
Get the remainder of input_x and input_y.
|
||||
|
||||
Inputs:
|
||||
input_x (Tensor): Divisor.
|
||||
input_y (Tensor): Dividend.
|
||||
|
||||
Returns:
|
||||
Tensor, remainder.
|
||||
"""
|
||||
x = self.floor_div(input_x, input_y)
|
||||
x = self.multiply(x, input_y)
|
||||
x = self.sub(input_x, x)
|
||||
return x
|
||||
|
||||
|
||||
class BeamSearchDecoder(nn.Cell):
|
||||
"""
|
||||
Beam search decoder.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input dataset.
|
||||
seq_length (int): Length of input sequence.
|
||||
vocab_size (int): The shape of each embedding vector.
|
||||
decoder (Cell): The GNMT decoder.
|
||||
beam_width (int): Beam width for beam search in inferring. Default: 4.
|
||||
decoder_layers_nums (int): The nums of decoder layers.
|
||||
length_penalty_weight (float): Penalty for sentence length. Default: 0.6.
|
||||
max_decode_length (int): Max decode length for inferring. Default: 64.
|
||||
sos_id (int): The index of start label <SOS>. Default: 1.
|
||||
eos_id (int): The index of end label <EOS>. Default: 2.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32.
|
||||
|
||||
Returns:
|
||||
Tensor, predictions output.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
seq_length,
|
||||
vocab_size,
|
||||
decoder,
|
||||
beam_width=4,
|
||||
decoder_layers_nums=4,
|
||||
length_penalty_weight=0.6,
|
||||
cov_penalty_factor=0.1,
|
||||
hidden_size=1024,
|
||||
max_decode_length=64,
|
||||
sos_id=2,
|
||||
eos_id=3,
|
||||
compute_type=mstype.float32):
|
||||
super(BeamSearchDecoder, self).__init__()
|
||||
|
||||
self.encoder_length = seq_length
|
||||
self.hidden_size = hidden_size
|
||||
self.batch_size = batch_size
|
||||
self.vocab_size = vocab_size
|
||||
self.beam_width = beam_width
|
||||
self.decoder_layers_nums = decoder_layers_nums
|
||||
self.length_penalty_weight = length_penalty_weight
|
||||
self.cov_penalty_factor = cov_penalty_factor
|
||||
self.max_decode_length = max_decode_length
|
||||
self.decoder = decoder
|
||||
|
||||
self.add = P.TensorAdd()
|
||||
self.expand = P.ExpandDims()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape_flat = (-1,)
|
||||
self.shape = P.Shape()
|
||||
|
||||
self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), mstype.float32)
|
||||
self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), mstype.float32)
|
||||
|
||||
self.select = P.Select()
|
||||
self.flat_shape = (batch_size, beam_width * vocab_size)
|
||||
self.topk = P.TopK(sorted=True)
|
||||
self.floor_div = P.FloorDiv()
|
||||
self.vocab_size_tensor = Tensor(self.vocab_size, mstype.int32)
|
||||
self.real_div = P.RealDiv()
|
||||
self.mod = Mod()
|
||||
self.equal = P.Equal()
|
||||
self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), mstype.int32)
|
||||
|
||||
beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1])
|
||||
self.beam_ids = Tensor(beam_ids, mstype.int32)
|
||||
|
||||
batch_ids = np.arange(batch_size * beam_width).reshape((batch_size, beam_width)) // beam_width
|
||||
self.batch_ids = Tensor(batch_ids, mstype.int32)
|
||||
|
||||
self.concat = P.Concat(axis=-1)
|
||||
self.gather_nd = P.GatherNd()
|
||||
|
||||
self.start = Tensor(0, dtype=mstype.int32)
|
||||
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
|
||||
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), mstype.int32)
|
||||
|
||||
init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
|
||||
self.init_scores = Tensor(init_scores, mstype.float32)
|
||||
self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=np.bool))
|
||||
self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
|
||||
|
||||
self.length_penalty = LengthPenalty(weight=length_penalty_weight)
|
||||
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
self.prob_concat = P.Concat(axis=1)
|
||||
self.cast = P.Cast()
|
||||
self.decoder_hidden_state = Tensor(np.zeros([self.decoder_layers_nums, 2,
|
||||
self.batch_size * self.beam_width,
|
||||
hidden_size]), mstype.float32)
|
||||
|
||||
self.zeros_scores = Tensor(np.zeros([batch_size, beam_width], dtype=np.float))
|
||||
self.active_index = Tensor(np.ones([batch_size, beam_width], dtype=np.int32))
|
||||
self.init_zeros = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
|
||||
self.init_ones = Tensor(np.ones([batch_size, beam_width], dtype=np.float32))
|
||||
|
||||
self.accu_attn_scores = Tensor(np.zeros([batch_size, beam_width, self.encoder_length], dtype=np.float32))
|
||||
|
||||
self.zeros = Tensor([0], mstype.int32)
|
||||
self.eos_tensor = Tensor(np.full([batch_size, beam_width, beam_width], eos_id), mstype.int32)
|
||||
|
||||
self.ones_3d = Tensor(np.full([batch_size, beam_width, self.encoder_length], 1), mstype.float32)
|
||||
self.neg_inf_3d = Tensor(np.full([batch_size, beam_width, self.encoder_length], -INF), mstype.float32)
|
||||
self.zeros_3d = Tensor(np.full([batch_size, beam_width, self.encoder_length], 0), mstype.float32)
|
||||
self.zeros_2d = Tensor(np.full([batch_size * beam_width, self.encoder_length], 0), mstype.int32)
|
||||
self.argmin = P.ArgMinWithValue(axis=1)
|
||||
self.reducesum = P.ReduceSum()
|
||||
self.div = P.Div()
|
||||
self.shape_op = P.Shape()
|
||||
self.mul = P.Mul()
|
||||
self.log = P.Log()
|
||||
self.less = P.Less()
|
||||
self.tile = P.Tile()
|
||||
self.noteq = P.Neg()
|
||||
self.zeroslike = P.ZerosLike()
|
||||
self.greater_equal = P.GreaterEqual()
|
||||
self.sub = P.Sub()
|
||||
|
||||
def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
|
||||
state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None,
|
||||
state_finished=None):
|
||||
"""
|
||||
Beam search one_step output.
|
||||
|
||||
Inputs:
|
||||
cur_input_ids (Tensor): with shape (batch_size * beam_width, 1).
|
||||
enc_states (Tensor): with shape (batch_size * beam_width, T, D).
|
||||
enc_attention_mask (Tensor): with shape (batch_size * beam_width, T).
|
||||
state_log_probs (Tensor): with shape (batch_size, beam_width).
|
||||
state_seq (Tensor): with shape (batch_size, beam_width, max_decoder_length).
|
||||
state_length (Tensor): with shape (batch_size, beam_width).
|
||||
idx (Tensor): with shape ().
|
||||
decoder_hidden_state (Tensor): with shape (decoder_layer_num, 2, batch_size * beam_width, D).
|
||||
accu_attn_scores (Tensor): with shape (batchsize, beam_width, seq_length).
|
||||
state_finished (Tensor): with shape (batch_size, beam_width).
|
||||
"""
|
||||
|
||||
# log_probs, [batch_size * beam_width, 1, V]
|
||||
log_probs, all_decoder_state, attn = self.decoder(cur_input_ids, enc_states, enc_attention_mask,
|
||||
decoder_hidden_state)
|
||||
# consider attention_scores
|
||||
attn = self.reshape(attn, (-1, self.beam_width, self.encoder_length))
|
||||
state_finished_attn = self.cast(state_finished, mstype.int32)
|
||||
attn_mask_0 = self.tile(self.expand(state_finished_attn, 2), (1, 1, self.encoder_length))
|
||||
attn_mask_0 = self.cast(attn_mask_0, mstype.bool_)
|
||||
attn_new = self.select(attn_mask_0, self.zeros_3d, attn)
|
||||
accu_attn_scores = self.add(accu_attn_scores, attn_new)
|
||||
|
||||
# log_probs: [batch_size, beam_width, V]
|
||||
log_probs = self.reshape(log_probs, (-1, self.beam_width, self.vocab_size))
|
||||
# select topk indices, [batch_size, beam_width, V]
|
||||
total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1))
|
||||
# mask finished beams, [batch_size, beam_width]
|
||||
# t-1 has finished
|
||||
mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor)
|
||||
# save the t-1 probability
|
||||
total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1))
|
||||
# [batch, beam*vocab]
|
||||
flat_scores = self.reshape(total_log_probs, (-1, self.beam_width * self.vocab_size))
|
||||
# select topk, [batch, beam]
|
||||
topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)
|
||||
|
||||
# convert to beam and word indices, [batch, beam]
|
||||
# beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
|
||||
# word_indices = self.mod(topk_indices, self.vocab_size_tensor)
|
||||
# ======================================================================
|
||||
# replace floor_div and mod op, since these two ops only support fp16 on
|
||||
# Ascend310, which will cause overflow.
|
||||
temp = topk_indices
|
||||
beam_indices = self.zeroslike(topk_indices)
|
||||
for _ in range(self.beam_width - 1):
|
||||
temp = self.sub(temp, self.vocab_size_tensor)
|
||||
res = self.cast(self.greater_equal(temp, 0), mstype.int32)
|
||||
beam_indices = beam_indices + res
|
||||
word_indices = topk_indices - beam_indices * self.vocab_size_tensor
|
||||
# ======================================================================
|
||||
|
||||
# mask finished indices, [batch, beam]
|
||||
beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
|
||||
word_indices = self.select(state_finished, self.eos_ids, word_indices)
|
||||
topk_scores = self.select(state_finished, state_log_probs, topk_scores)
|
||||
|
||||
# sort according to scores with -inf for finished beams, [batch, beam]
|
||||
# t ends
|
||||
tmp_log_probs = self.select(
|
||||
self.equal(word_indices, self.eos_ids),
|
||||
self.ninf_tensor,
|
||||
topk_scores)
|
||||
|
||||
_, tmp_indices = self.topk(tmp_log_probs, self.beam_width)
|
||||
# update, [batch_size, beam_width, 2]
|
||||
tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1)))
|
||||
# [batch_size, beam_width]
|
||||
beam_indices = self.gather_nd(beam_indices, tmp_gather_indices)
|
||||
word_indices = self.gather_nd(word_indices, tmp_gather_indices)
|
||||
topk_scores = self.gather_nd(topk_scores, tmp_gather_indices)
|
||||
|
||||
# gather indices for selecting alive beams
|
||||
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1)))
|
||||
|
||||
# length add 1 if not finished in the previous step, [batch_size, beam_width]
|
||||
length_add = self.add(state_length, self.one)
|
||||
state_length = self.select(state_finished, state_length, length_add)
|
||||
state_length = self.gather_nd(state_length, gather_indices)
|
||||
# concat seq
|
||||
seq = self.gather_nd(state_seq, gather_indices)
|
||||
# update accu_attn_scores
|
||||
accu_attn_scores = self.gather_nd(accu_attn_scores, gather_indices)
|
||||
# update all_decoder_state
|
||||
all_decoder_state = self.reshape(all_decoder_state,
|
||||
(self.decoder_layers_nums * 2, self.batch_size, self.beam_width,
|
||||
self.hidden_size))
|
||||
for i in range(self.decoder_layers_nums * 2):
|
||||
all_decoder_state[i, :, :, :] = self.gather_nd(all_decoder_state[i, :, :, :], gather_indices)
|
||||
all_decoder_state = self.reshape(all_decoder_state,
|
||||
(self.decoder_layers_nums, 2, self.batch_size * self.beam_width,
|
||||
self.hidden_size))
|
||||
|
||||
# update state_seq
|
||||
state_seq_new = self.cast(seq, mstype.float32)
|
||||
word_indices_fp32 = self.cast(word_indices, mstype.float32)
|
||||
state_seq_new[:, :, idx] = word_indices_fp32
|
||||
state_seq = self.cast(state_seq_new, mstype.int32)
|
||||
|
||||
cur_input_ids = self.reshape(word_indices, (-1, 1))
|
||||
state_log_probs = topk_scores
|
||||
state_finished = self.equal(word_indices, self.eos_ids)
|
||||
|
||||
return cur_input_ids, state_log_probs, state_seq, state_length, \
|
||||
all_decoder_state, accu_attn_scores, state_finished
|
||||
|
||||
def construct(self, enc_states, enc_attention_mask):
|
||||
"""
|
||||
Process source sentence
|
||||
|
||||
Inputs:
|
||||
enc_states (Tensor): Output of transformer encoder with shape (batch_size * beam_width, T, D).
|
||||
enc_attention_mask (Tensor): encoder attention mask with shape (batch_size * beam_width, T).
|
||||
|
||||
Returns:
|
||||
Tensor, predictions output.
|
||||
"""
|
||||
# beam search start
|
||||
cur_input_ids = self.start_ids
|
||||
state_log_probs = self.init_scores
|
||||
state_seq = self.init_seq
|
||||
state_finished = self.init_finished
|
||||
state_length = self.init_length
|
||||
decoder_hidden_state = self.decoder_hidden_state
|
||||
accu_attn_scores = self.accu_attn_scores
|
||||
|
||||
idx = self.start + 1
|
||||
ends = self.start + self.max_decode_length + 1
|
||||
while idx < ends:
|
||||
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
|
||||
state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
|
||||
state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores,
|
||||
state_finished)
|
||||
idx = idx + 1
|
||||
|
||||
# add length penalty scores
|
||||
penalty_len = self.length_penalty(state_length)
|
||||
# return penalty_len
|
||||
log_probs = self.real_div(state_log_probs, penalty_len)
|
||||
penalty_cov = C.clip_by_value(accu_attn_scores, 0.0, 1.0)
|
||||
penalty_cov = self.log(penalty_cov)
|
||||
penalty_less = self.less(penalty_cov, self.neg_inf_3d)
|
||||
penalty = self.select(penalty_less, self.zeros_3d, penalty_cov)
|
||||
penalty = self.reducesum(penalty, 2)
|
||||
log_probs = log_probs + penalty * self.cov_penalty_factor
|
||||
# sort according to scores
|
||||
_, top_beam_indices = self.topk(log_probs, self.beam_width)
|
||||
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
|
||||
# sort sequence and attention scores
|
||||
predicted_ids = self.gather_nd(state_seq, gather_indices)
|
||||
predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length]
|
||||
|
||||
return predicted_ids
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Calculate the blue scores"""
|
||||
import subprocess
|
||||
import numpy as np
|
||||
|
||||
from src.dataset.tokenizer import Tokenizer
|
||||
|
||||
|
||||
def load_result_data(result_npy_addr):
|
||||
# load the numpy to list.
|
||||
result = np.load(result_npy_addr, allow_pickle=True)
|
||||
return result
|
||||
|
||||
|
||||
def get_bleu_data(tokenizer: Tokenizer, result_npy_addr):
|
||||
"""
|
||||
Detokenizer the prediction.
|
||||
|
||||
Args:
|
||||
tokenizer (Tokenizer): tokenizer operations.
|
||||
result_npy_addr (string): Path to the predict file.
|
||||
|
||||
Returns:
|
||||
List, the predict text context.
|
||||
"""
|
||||
|
||||
result = load_result_data(result_npy_addr)
|
||||
prediction_list = []
|
||||
for _, info in enumerate(result):
|
||||
# prediction detokenize
|
||||
prediction = info["prediction"]
|
||||
prediction_str = tokenizer.detokenize(prediction)
|
||||
prediction_list.append(prediction_str)
|
||||
|
||||
return prediction_list
|
||||
|
||||
|
||||
def calculate_sacrebleu(predict_path, target_path):
|
||||
"""
|
||||
Calculate the BLEU scores.
|
||||
|
||||
Args:
|
||||
predict_path (string): Path to the predict file.
|
||||
target_path (string): Path to the target file.
|
||||
|
||||
Returns:
|
||||
Float32, bleu scores.
|
||||
"""
|
||||
|
||||
sacrebleu_params = '--score-only -lc --tokenize intl'
|
||||
sacrebleu = subprocess.run([f'sacrebleu --input {predict_path} \
|
||||
{target_path} {sacrebleu_params}'],
|
||||
stdout=subprocess.PIPE, shell=True)
|
||||
bleu_scores = round(float(sacrebleu.stdout.strip()), 2)
|
||||
return bleu_scores
|
||||
|
||||
|
||||
def bleu_calculate(tokenizer, result_npy_addr, target_addr=None):
|
||||
"""
|
||||
Calculate the BLEU scores.
|
||||
|
||||
Args:
|
||||
tokenizer (Tokenizer): tokenizer operations.
|
||||
result_npy_addr (string): Path to the predict file.
|
||||
target_addr (string): Path to the target file.
|
||||
|
||||
Returns:
|
||||
Float32, bleu scores.
|
||||
"""
|
||||
|
||||
prediction = get_bleu_data(tokenizer, result_npy_addr)
|
||||
print("predict:\n", prediction)
|
||||
|
||||
eval_path = './predict.txt'
|
||||
with open(eval_path, 'w') as eval_file:
|
||||
lines = [line + '\n' for line in prediction]
|
||||
eval_file.writelines(lines)
|
||||
reference_path = target_addr
|
||||
bleu_scores = calculate_sacrebleu(eval_path, reference_path)
|
||||
return bleu_scores
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Components of model."""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class SaturateCast(nn.Cell):
|
||||
"""Cast wrapper."""
|
||||
|
||||
def __init__(self, dst_type=mstype.float32):
|
||||
super(SaturateCast, self).__init__()
|
||||
self.cast = P.Cast()
|
||||
self.dst_type = dst_type
|
||||
|
||||
def construct(self, x):
|
||||
return self.cast(x, self.dst_type)
|
||||
|
||||
|
||||
class LayerNorm(nn.Cell):
|
||||
"""
|
||||
Do layer norm.
|
||||
|
||||
Args:
|
||||
in_channels (int): In channels number of layer norm.
|
||||
return_2d (bool): Whether return 2d tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, output.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels=None, return_2d=False):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.return_2d = return_2d
|
||||
self.layer_norm = nn.LayerNorm((in_channels,))
|
||||
self.cast = P.Cast()
|
||||
self.get_dtype = P.DType()
|
||||
self.reshape = P.Reshape()
|
||||
self.get_shape = P.Shape()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
"""Do layer norm."""
|
||||
shape = self.get_shape(input_tensor)
|
||||
batch_size = shape[0]
|
||||
max_len = shape[1]
|
||||
embed_dim = shape[2]
|
||||
|
||||
output = self.reshape(input_tensor, (-1, embed_dim))
|
||||
output = self.cast(output, mstype.float32)
|
||||
output = self.layer_norm(output)
|
||||
output = self.cast(output, self.get_dtype(input_tensor))
|
||||
if not self.return_2d:
|
||||
output = self.reshape(output, (batch_size, max_len, embed_dim))
|
||||
return output
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create attention block."""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
|
||||
from .attention import BahdanauAttention
|
||||
|
||||
|
||||
class RecurrentAttention(nn.Cell):
|
||||
"""
|
||||
Constructor for the RecurrentAttention.
|
||||
|
||||
Args:
|
||||
input_size: number of features in input tensor.
|
||||
context_size: number of features in output from encoder.
|
||||
hidden_size: internal hidden size.
|
||||
num_layers: number of layers in LSTM.
|
||||
dropout: probability of dropout (on input to LSTM layer).
|
||||
initializer_range: range for the uniform initializer.
|
||||
|
||||
Returns:
|
||||
Tensor, shape (N, T, D).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
rnn,
|
||||
is_training=True,
|
||||
input_size=1024,
|
||||
context_size=1024,
|
||||
hidden_size=1024,
|
||||
num_layers=1,
|
||||
dropout=0.2,
|
||||
initializer_range=0.1):
|
||||
super(RecurrentAttention, self).__init__()
|
||||
self.dropout = nn.Dropout(keep_prob=1.0 - dropout)
|
||||
self.rnn = rnn
|
||||
self.attn = BahdanauAttention(is_training=is_training,
|
||||
query_size=hidden_size,
|
||||
key_size=hidden_size,
|
||||
num_units=hidden_size,
|
||||
normalize=True,
|
||||
initializer_range=initializer_range,
|
||||
compute_type=mstype.float16)
|
||||
|
||||
def construct(self, decoder_embedding, context_key, attention_mask=None, rnn_init_state=None):
|
||||
# decoder_embedding: [t_q,N,D]
|
||||
# context: [t_k,N,D]
|
||||
# attention_mask: [N,t_k]
|
||||
# [t_q,N,D]
|
||||
decoder_embedding = self.dropout(decoder_embedding)
|
||||
rnn_outputs, rnn_state = self.rnn(decoder_embedding, rnn_init_state)
|
||||
# rnn_outputs:[t_q,b,D], attn_outputs:[t_q,b,D], scores:[b, t_q, t_k], rnn_state:tuple([2,b,D]).
|
||||
attn_outputs, scores = self.attn(query=rnn_outputs, keys=context_key, attention_mask=attention_mask)
|
||||
return rnn_outputs, attn_outputs, rnn_state, scores
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create attention paddings from input paddings."""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
class CreateAttentionPaddingsFromInputPaddings(nn.Cell):
|
||||
"""
|
||||
Create attention mask according to input mask.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Config class.
|
||||
|
||||
Returns:
|
||||
Tensor, shape of (N, T, T).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
is_training=True):
|
||||
super(CreateAttentionPaddingsFromInputPaddings, self).__init__()
|
||||
|
||||
self.is_training = is_training
|
||||
self.input_mask = None
|
||||
self.cast = P.Cast()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.batch_matmul = P.BatchMatMul()
|
||||
self.multiply = P.Mul()
|
||||
self.shape = P.Shape()
|
||||
# mask future positions
|
||||
ones = np.ones(shape=(config.batch_size, config.seq_length, config.seq_length))
|
||||
self.lower_triangle_mask = Tensor(np.tril(ones), dtype=mstype.float32)
|
||||
|
||||
def construct(self, input_mask, mask_future=False):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
input_mask (Tensor): Tensor mask vectors with shape (N, T).
|
||||
mask_future (bool): Whether mask future (for decoder training).
|
||||
|
||||
Returns:
|
||||
Tensor, shape of (N, T, T).
|
||||
"""
|
||||
input_shape = self.shape(input_mask)
|
||||
# Add this for infer as the seq_length will increase.
|
||||
shape_right = (input_shape[0], 1, input_shape[1])
|
||||
shape_left = input_shape + (1,)
|
||||
if self.is_training:
|
||||
input_mask = self.cast(input_mask, mstype.float16)
|
||||
mask_left = self.reshape(input_mask, shape_left)
|
||||
mask_right = self.reshape(input_mask, shape_right)
|
||||
|
||||
attention_mask = self.batch_matmul(mask_left, mask_right)
|
||||
if self.is_training:
|
||||
attention_mask = self.cast(attention_mask, mstype.float32)
|
||||
|
||||
if mask_future:
|
||||
attention_mask = self.multiply(attention_mask, self.lower_triangle_mask)
|
||||
|
||||
return attention_mask
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Decoder of GNMT."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common.initializer import Uniform
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from config.config import GNMTConfig
|
||||
from .dynamic_rnn import DynamicRNNNet
|
||||
from .create_attention import RecurrentAttention
|
||||
|
||||
|
||||
class GNMTDecoder(nn.Cell):
|
||||
"""
|
||||
Implements of Transformer decoder.
|
||||
|
||||
Args:
|
||||
attn_embed_dim (int): Dimensions of attention layer.
|
||||
decoder_layers (int): Decoder layers.
|
||||
num_attn_heads (int): Attention heads number.
|
||||
intermediate_size (int): Hidden size of FFN.
|
||||
attn_dropout_prob (float): Dropout rate in attention. Default: 0.1.
|
||||
initializer_range (float): Initial range. Default: 0.02.
|
||||
dropout_prob (float): Dropout rate between layers. Default: 0.1.
|
||||
hidden_act (str): Non-linear activation function in FFN. Default: "relu".
|
||||
compute_type (mstype): Mindspore data type. Default: mstype.float32.
|
||||
|
||||
Returns:
|
||||
Tensor, shape of (N, T', D).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: GNMTConfig,
|
||||
is_training: bool,
|
||||
use_one_hot_embeddings: bool = False,
|
||||
initializer_range=0.1,
|
||||
infer_beam_width=1,
|
||||
compute_type=mstype.float16):
|
||||
super(GNMTDecoder, self).__init__()
|
||||
|
||||
self.is_training = is_training
|
||||
self.attn_embed_dim = config.hidden_size
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.hidden_dropout_prob = config.hidden_dropout_prob
|
||||
self.vocab_size = config.vocab_size
|
||||
self.seq_length = config.max_decode_length
|
||||
# batchsize* beam_width for beam_search.
|
||||
self.batch_size = config.batch_size * infer_beam_width
|
||||
self.word_embed_dim = config.hidden_size
|
||||
self.transpose = P.Transpose()
|
||||
self.transpose_orders = (1, 0, 2)
|
||||
self.reshape = P.Reshape()
|
||||
self.concat = P.Concat(axis=-1)
|
||||
self.state_concat = P.Concat(axis=0)
|
||||
self.all_decoder_state = Tensor(np.zeros([self.num_layers, 2, self.batch_size, config.hidden_size]),
|
||||
mstype.float32)
|
||||
|
||||
decoder_layers = []
|
||||
for i in range(0, self.num_layers):
|
||||
if i == 0:
|
||||
# the inputs is [T,D,N]
|
||||
scaler = 1
|
||||
else:
|
||||
# the inputs is [T,D,2N]
|
||||
scaler = 2
|
||||
layer = DynamicRNNNet(seq_length=self.seq_length,
|
||||
batchsize=self.batch_size,
|
||||
word_embed_dim=scaler * self.word_embed_dim,
|
||||
hidden_size=self.word_embed_dim)
|
||||
decoder_layers.append(layer)
|
||||
self.decoder_layers = nn.CellList(decoder_layers)
|
||||
|
||||
self.att_rnn = RecurrentAttention(rnn=self.decoder_layers[0],
|
||||
is_training=is_training,
|
||||
input_size=self.word_embed_dim,
|
||||
context_size=self.attn_embed_dim,
|
||||
hidden_size=self.attn_embed_dim,
|
||||
num_layers=1,
|
||||
dropout=config.attention_dropout_prob)
|
||||
|
||||
self.dropout = nn.Dropout(keep_prob=1.0 - config.hidden_dropout_prob)
|
||||
|
||||
self.classifier = nn.Dense(config.hidden_size,
|
||||
config.vocab_size,
|
||||
has_bias=True,
|
||||
weight_init=Uniform(initializer_range),
|
||||
bias_init=Uniform(initializer_range)).to_float(compute_type)
|
||||
self.cast = P.Cast()
|
||||
self.shape_op = P.Shape()
|
||||
self.expand = P.ExpandDims()
|
||||
|
||||
def construct(self, tgt_embeddings, encoder_outputs, attention_mask=None,
|
||||
decoder_init_state=None):
|
||||
"""Decoder."""
|
||||
# tgt_embeddings: [T',N,D], encoder_outputs: [T,N,D], attention_mask: [N,T].
|
||||
query_shape = self.shape_op(tgt_embeddings)
|
||||
if decoder_init_state is None:
|
||||
hidden_state = self.all_decoder_state
|
||||
else:
|
||||
hidden_state = decoder_init_state
|
||||
# x:[t_q,b,D], attn:[t_q,b,D], scores:[b, t_q, t_k], state_0:[2,b,D].
|
||||
x, attn, state_0, scores = self.att_rnn(decoder_embedding=tgt_embeddings, context_key=encoder_outputs,
|
||||
attention_mask=attention_mask, rnn_init_state=hidden_state[0, :, :, :])
|
||||
x = self.concat((x, attn))
|
||||
x = self.dropout(x)
|
||||
decoder_outputs, state_1 = self.decoder_layers[1](x, hidden_state[1, :, :, :])
|
||||
|
||||
all_decoder_state = self.state_concat((self.expand(state_0, 0), self.expand(state_1, 0)))
|
||||
|
||||
for i in range(2, self.num_layers):
|
||||
residual = decoder_outputs
|
||||
decoder_outputs = self.concat((decoder_outputs, attn))
|
||||
|
||||
decoder_outputs = self.dropout(decoder_outputs)
|
||||
# 1st unidirectional layer. encoder_outputs: [T,N,D]
|
||||
decoder_outputs, decoder_state = self.decoder_layers[i](decoder_outputs, hidden_state[i, :, :, :])
|
||||
decoder_outputs += residual
|
||||
all_decoder_state = self.state_concat((all_decoder_state, self.expand(decoder_state, 0)))
|
||||
|
||||
# [m, batch_size * beam_width, D]
|
||||
decoder_outputs = self.reshape(decoder_outputs, (-1, self.word_embed_dim))
|
||||
if self.is_training:
|
||||
decoder_outputs = self.cast(decoder_outputs, mstype.float16)
|
||||
decoder_outputs = self.classifier(decoder_outputs)
|
||||
if self.is_training:
|
||||
decoder_outputs = self.cast(decoder_outputs, mstype.float32)
|
||||
# [m, batch_size * beam_width, V]
|
||||
decoder_outputs = self.reshape(decoder_outputs, (query_shape[0], query_shape[1], self.vocab_size))
|
||||
# all_decoder_state:[4,2,b,D]
|
||||
return decoder_outputs, all_decoder_state, scores
|
|
@ -0,0 +1,157 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Decoder for beam_search of GNMT."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from .embedding import EmbeddingLookup
|
||||
from .decoder import GNMTDecoder
|
||||
from .create_attn_padding import CreateAttentionPaddingsFromInputPaddings
|
||||
from .components import SaturateCast
|
||||
|
||||
|
||||
class PredLogProbs(nn.Cell):
|
||||
"""
|
||||
Get log probs.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input dataset.
|
||||
seq_length (int): The length of sequences.
|
||||
width (int): Number of parameters of a layer
|
||||
compute_type (int): Type of input type.
|
||||
dtype (int): Type of MindSpore output type.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
seq_length,
|
||||
width,
|
||||
compute_type=mstype.float32,
|
||||
dtype=mstype.float32):
|
||||
super(PredLogProbs, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.width = width
|
||||
self.compute_type = compute_type
|
||||
self.dtype = dtype
|
||||
self.log_softmax = nn.LogSoftmax(axis=-1)
|
||||
# self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, logits):
|
||||
"""
|
||||
Calculate the log_softmax.
|
||||
|
||||
Inputs:
|
||||
input_tensor (Tensor): A batch of sentences with shape (N, T).
|
||||
output_weights (Tensor): A batch of masks with shape (N, T).
|
||||
|
||||
Returns:
|
||||
Tensor, the prediction probability with shape (N, T').
|
||||
"""
|
||||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
||||
|
||||
class BeamDecoderStep(nn.Cell):
|
||||
"""
|
||||
Multi-layer transformer decoder step.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): The config of Transformer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
use_one_hot_embeddings,
|
||||
compute_type=mstype.float32):
|
||||
super(BeamDecoderStep, self).__init__(auto_prefix=True)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.word_embed_dim = config.hidden_size
|
||||
self.embedding_lookup = EmbeddingLookup(
|
||||
is_training=False,
|
||||
vocab_size=config.vocab_size,
|
||||
embed_dim=self.word_embed_dim,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
self.projection = PredLogProbs(
|
||||
batch_size=config.batch_size * config.beam_width,
|
||||
seq_length=1,
|
||||
width=config.vocab_size,
|
||||
compute_type=config.compute_type)
|
||||
|
||||
self.seq_length = config.max_decode_length
|
||||
self.decoder = GNMTDecoder(config,
|
||||
is_training=False,
|
||||
infer_beam_width=config.beam_width)
|
||||
|
||||
self.ones_like = P.OnesLike()
|
||||
self.shape = P.Shape()
|
||||
|
||||
self.create_att_paddings_from_input_paddings = CreateAttentionPaddingsFromInputPaddings(config,
|
||||
is_training=False)
|
||||
self.expand = P.ExpandDims()
|
||||
self.multiply = P.Mul()
|
||||
|
||||
ones = np.ones(shape=(config.max_decode_length, config.max_decode_length))
|
||||
self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
|
||||
|
||||
self.cast_compute_type = SaturateCast(dst_type=compute_type)
|
||||
|
||||
self.transpose = P.Transpose()
|
||||
self.transpose_orders = (1, 0, 2)
|
||||
|
||||
def construct(self, input_ids, enc_states, enc_attention_mask, decoder_hidden_state=None):
|
||||
"""
|
||||
Get log probs.
|
||||
|
||||
Args:
|
||||
input_ids: [batch_size * beam_width, m]
|
||||
enc_states: [batch_size * beam_width, T, D]
|
||||
enc_attention_mask: [batch_size * beam_width, T]
|
||||
decoder_hidden_state: [decoder_layers_nums, 2, batch_size * beam_width, hidden_size].
|
||||
|
||||
Returns:
|
||||
Tensor, the log_probs. [batch_size * beam_width, 1, vocabulary_size]
|
||||
"""
|
||||
|
||||
# process embedding. input_embedding: [batch_size * beam_width, m, D], embedding_tables: [V, D]
|
||||
input_embedding, _ = self.embedding_lookup(input_ids)
|
||||
input_embedding = self.cast_compute_type(input_embedding)
|
||||
|
||||
input_shape = self.shape(input_ids)
|
||||
input_len = input_shape[1]
|
||||
# [m, batch_size * beam_width, D]
|
||||
input_embedding = self.transpose(input_embedding, self.transpose_orders)
|
||||
enc_states = self.transpose(enc_states, self.transpose_orders)
|
||||
|
||||
# decoder_output: [m, batch_size*beam_width, V], scores:[b, t_q, t_k], all_decoder_state:[4,2,b*beam_width,D]
|
||||
decoder_output, all_decoder_state, scores = self.decoder(input_embedding, enc_states, enc_attention_mask,
|
||||
decoder_hidden_state)
|
||||
# [batch_size * beam_width, m, v]
|
||||
decoder_output = self.transpose(decoder_output, self.transpose_orders)
|
||||
|
||||
# take the last step, [batch_size * beam_width, 1, V]
|
||||
decoder_output = decoder_output[:, (input_len - 1):input_len, :]
|
||||
|
||||
# projection and log_prob
|
||||
log_probs = self.projection(decoder_output)
|
||||
|
||||
# [batch_size * beam_width, 1, vocabulary_size]
|
||||
return log_probs, all_decoder_state, scores
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DynamicRNN."""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
class DynamicRNNCell(nn.Cell):
|
||||
"""
|
||||
DynamicRNN Cell.
|
||||
|
||||
Args:
|
||||
num_setp (int): Lengths of setences.
|
||||
batch_size (int): Batch size.
|
||||
word_embed_dim (int): Input size.
|
||||
hidden_size (int): Hidden size .
|
||||
initializer_range (float): Initial range. Default: 0.02
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_setp=50,
|
||||
batch_size=128,
|
||||
word_embed_dim=1024,
|
||||
hidden_size=1024,
|
||||
initializer_range=0.1):
|
||||
super(DynamicRNNCell, self).__init__()
|
||||
self.rnn = P.DynamicRNN()
|
||||
self.num_step = num_setp
|
||||
self.batch_size = batch_size
|
||||
self.input_size = word_embed_dim
|
||||
self.hidden_size = hidden_size
|
||||
# w
|
||||
dynamicRNN_w = np.random.uniform(-initializer_range, initializer_range,
|
||||
size=[self.input_size + self.hidden_size, 4 * self.hidden_size])
|
||||
self.dynamicRNN_w = Parameter(Tensor(dynamicRNN_w, mstype.float32), name='weight')
|
||||
# b
|
||||
dynamicRNN_b = np.random.uniform(-initializer_range, initializer_range, size=[4 * self.hidden_size])
|
||||
self.dynamicRNN_b = Parameter(Tensor(dynamicRNN_b, mstype.float32), name='bias')
|
||||
|
||||
self.dynamicRNN_h = Tensor(np.zeros((1, self.batch_size, self.hidden_size)), mstype.float32)
|
||||
self.dynamicRNN_c = Tensor(np.zeros((1, self.batch_size, self.hidden_size)), mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x, init_h=None, init_c=None):
|
||||
w = self.cast(self.dynamicRNN_w, mstype.float16)
|
||||
b = self.cast(self.dynamicRNN_b, mstype.float16)
|
||||
if init_h is None or init_c is None:
|
||||
init_h = self.cast(self.dynamicRNN_h, mstype.float16)
|
||||
init_c = self.cast(self.dynamicRNN_c, mstype.float16)
|
||||
out = self.rnn(x, w, b, None, init_h, init_c)
|
||||
return out[0], out[1], out[2]
|
||||
|
||||
|
||||
class DynamicRNNNet(nn.Cell):
|
||||
"""
|
||||
DynamicRNN Network.
|
||||
|
||||
Args:
|
||||
seq_length (int): Lengths of setences.
|
||||
batchsize (int): Batch size.
|
||||
word_embed_dim (int): Input size.
|
||||
hidden_size (int): Hidden size.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
seq_length=80,
|
||||
batchsize=128,
|
||||
word_embed_dim=1024,
|
||||
hidden_size=1024):
|
||||
super(DynamicRNNNet, self).__init__()
|
||||
self.max_length = seq_length
|
||||
self.hidden_size = hidden_size
|
||||
self.cast = P.Cast()
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.get_shape = P.Shape()
|
||||
self.net = DynamicRNNCell(num_setp=seq_length,
|
||||
batch_size=batchsize,
|
||||
word_embed_dim=word_embed_dim,
|
||||
hidden_size=hidden_size)
|
||||
|
||||
def construct(self, inputs, init_state=None):
|
||||
"""DynamicRNN Network."""
|
||||
inputs = self.cast(inputs, mstype.float16)
|
||||
if init_state is not None:
|
||||
init_h = self.cast(init_state[0:1, :, :], mstype.float16)
|
||||
init_c = self.cast(init_state[-1:, :, :], mstype.float16)
|
||||
out, state_h, state_c = self.net(inputs, init_h, init_c)
|
||||
else:
|
||||
out, state_h, state_c = self.net(inputs)
|
||||
out = self.cast(out, mstype.float32)
|
||||
state = self.concat((state_h[-1:, :, :], state_c[-1:, :, :]))
|
||||
state = self.cast(state, mstype.float32)
|
||||
# out:[T,b,D], state:[2,b,D]
|
||||
return out, state
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Word embedding for gnmt."""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
class EmbeddingLookup(nn.Cell):
|
||||
"""
|
||||
Embeddings lookup table with a fixed dictionary and size.
|
||||
|
||||
Args:
|
||||
is_training (bool): Whether to train.
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
embed_dim (int): The size of word embedding.
|
||||
initializer_range (int): The initialize range of parameters.
|
||||
use_one_hot_embeddings (bool): Whether use one-hot embedding. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
is_training,
|
||||
vocab_size,
|
||||
embed_dim,
|
||||
initializer_range=0.1,
|
||||
use_one_hot_embeddings=False):
|
||||
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
self.is_training = is_training
|
||||
self.embedding_dim = embed_dim
|
||||
self.vocab_size = vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
|
||||
init_weight = np.random.normal(-initializer_range, initializer_range, size=[vocab_size, embed_dim])
|
||||
self.embedding_table = Parameter(Tensor(init_weight, mstype.float32),
|
||||
name='embedding_table')
|
||||
self.expand = P.ExpandDims()
|
||||
self.gather = P.GatherV2()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.get_shape = P.Shape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, input_ids):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
input_ids (Tensor): A batch of sentences with shape (N, T).
|
||||
|
||||
Returns:
|
||||
Tensor, word embeddings with shape (N, T, D)
|
||||
"""
|
||||
_shape = self.get_shape(input_ids) # (N, T).
|
||||
_batch_size = _shape[0]
|
||||
_max_len = _shape[1]
|
||||
if self.is_training:
|
||||
embedding_table = self.cast(self.embedding_table, mstype.float16)
|
||||
else:
|
||||
embedding_table = self.embedding_table
|
||||
|
||||
flat_ids = self.reshape(input_ids, (_batch_size * _max_len,))
|
||||
if self.use_one_hot_embeddings:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
if self.is_training:
|
||||
one_hot_ids = self.cast(one_hot_ids, mstype.float16)
|
||||
output_for_reshape = self.array_mul(
|
||||
one_hot_ids, embedding_table)
|
||||
else:
|
||||
output_for_reshape = self.gather(embedding_table, flat_ids, 0)
|
||||
|
||||
output = self.reshape(output_for_reshape, (_batch_size, _max_len, self.embedding_dim))
|
||||
if self.is_training:
|
||||
output = self.cast(output, mstype.float32)
|
||||
embedding_table = self.cast(embedding_table, mstype.float32)
|
||||
return output, embedding_table
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Encoder of GNMT."""
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from config.config import GNMTConfig
|
||||
from .dynamic_rnn import DynamicRNNNet
|
||||
|
||||
|
||||
class GNMTEncoder(nn.Cell):
|
||||
"""
|
||||
Implements of GNMT encoder.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Configuration of GNMT network.
|
||||
is_training (bool): Whether to train.
|
||||
compute_type (mstype): Mindspore data type.
|
||||
|
||||
Returns:
|
||||
Tensor, shape of (N, T, D).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: GNMTConfig,
|
||||
is_training: bool,
|
||||
compute_type=mstype.float32):
|
||||
super(GNMTEncoder, self).__init__()
|
||||
self.input_mask_from_dataset = config.input_mask_from_dataset
|
||||
self.max_positions = config.seq_length
|
||||
self.attn_embed_dim = config.hidden_size
|
||||
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.hidden_dropout_prob = config.hidden_dropout_prob
|
||||
self.vocab_size = config.vocab_size
|
||||
self.seq_length = config.seq_length
|
||||
self.batch_size = config.batch_size
|
||||
self.word_embed_dim = config.hidden_size
|
||||
|
||||
self.transpose = P.Transpose()
|
||||
self.transpose_orders = (1, 0, 2)
|
||||
self.reshape = P.Reshape()
|
||||
self.concat = P.Concat(axis=-1)
|
||||
encoder_layers = []
|
||||
for i in range(0, self.num_layers + 1):
|
||||
if i == 2:
|
||||
# the bidirectional layer's output is [T,D,2N]
|
||||
scaler = 2
|
||||
else:
|
||||
# the rest layer's output is [T,D,N]
|
||||
scaler = 1
|
||||
layer = DynamicRNNNet(seq_length=self.seq_length,
|
||||
batchsize=self.batch_size,
|
||||
word_embed_dim=scaler * self.word_embed_dim,
|
||||
hidden_size=self.word_embed_dim)
|
||||
encoder_layers.append(layer)
|
||||
self.encoder_layers = nn.CellList(encoder_layers)
|
||||
self.reverse_v2 = P.ReverseV2(axis=[0])
|
||||
self.dropout = nn.Dropout(keep_prob=1.0 - config.hidden_dropout_prob)
|
||||
|
||||
def construct(self, inputs, source_len, attention_mask=None):
|
||||
"""Encoder."""
|
||||
inputs = self.dropout(inputs)
|
||||
# bidirectional layer, fwd_encoder_outputs: [T,N,D]
|
||||
fwd_encoder_outputs, _ = self.encoder_layers[0](inputs)
|
||||
|
||||
# the input need reverse.
|
||||
inputs_r = self.reverse_v2(inputs)
|
||||
bak_encoder_outputs, _ = self.encoder_layers[1](inputs_r)
|
||||
# the result need reverse.
|
||||
bak_encoder_outputs = self.reverse_v2(bak_encoder_outputs)
|
||||
|
||||
# bi_encoder_outputs: [T,N,2D]
|
||||
bi_encoder_outputs = self.concat((fwd_encoder_outputs, bak_encoder_outputs))
|
||||
|
||||
# 1st unidirectional layer. encoder_outputs: [T,N,D]
|
||||
bi_encoder_outputs = self.dropout(bi_encoder_outputs)
|
||||
encoder_outputs, _ = self.encoder_layers[2](bi_encoder_outputs)
|
||||
# Build all the rest unidi layers of encoder
|
||||
for i in range(3, self.num_layers + 1):
|
||||
residual = encoder_outputs
|
||||
encoder_outputs = self.dropout(encoder_outputs)
|
||||
# [T,N,D] -> [T,N,D]
|
||||
encoder_outputs_o, _ = self.encoder_layers[i](encoder_outputs)
|
||||
encoder_outputs = encoder_outputs_o + residual
|
||||
|
||||
return encoder_outputs
|
|
@ -0,0 +1,166 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GNMTv2 network."""
|
||||
import copy
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from config.config import GNMTConfig
|
||||
from .embedding import EmbeddingLookup
|
||||
from .create_attn_padding import CreateAttentionPaddingsFromInputPaddings
|
||||
from .beam_search import BeamSearchDecoder, TileBeam
|
||||
from .encoder import GNMTEncoder
|
||||
from .decoder import GNMTDecoder
|
||||
from .decoder_beam_infer import BeamDecoderStep
|
||||
from .components import SaturateCast
|
||||
|
||||
|
||||
class GNMT(nn.Cell):
|
||||
"""
|
||||
GNMT with encoder and decoder.
|
||||
|
||||
In GNMT, we define T = src_max_len, T' = tgt_max_len.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Model config.
|
||||
is_training (bool): Whether is training.
|
||||
use_one_hot_embeddings (bool): Whether use one-hot embedding.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor], network outputs.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: GNMTConfig,
|
||||
is_training: bool = False,
|
||||
use_one_hot_embeddings: bool = False,
|
||||
use_positional_embedding: bool = True,
|
||||
compute_type=mstype.float32):
|
||||
super(GNMT, self).__init__()
|
||||
|
||||
self.input_mask_from_dataset = config.input_mask_from_dataset
|
||||
self.max_positions = config.seq_length
|
||||
self.attn_embed_dim = config.hidden_size
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
config.attention_dropout_prob = 0.0
|
||||
self.is_training = is_training
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.hidden_dropout_prob = config.hidden_dropout_prob
|
||||
self.vocab_size = config.vocab_size
|
||||
self.seq_length = config.seq_length
|
||||
self.batch_size = config.batch_size
|
||||
self.max_decode_length = config.max_decode_length
|
||||
self.word_embed_dim = config.hidden_size
|
||||
|
||||
self.beam_width = config.beam_width
|
||||
|
||||
self.transpose = P.Transpose()
|
||||
self.transpose_orders = (1, 0, 2)
|
||||
self.embedding_lookup = EmbeddingLookup(
|
||||
is_training=self.is_training,
|
||||
vocab_size=self.vocab_size,
|
||||
embed_dim=self.word_embed_dim,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
self.gnmt_encoder = GNMTEncoder(config, is_training)
|
||||
|
||||
if self.is_training:
|
||||
# use for train.
|
||||
self.gnmt_decoder = GNMTDecoder(config, is_training)
|
||||
else:
|
||||
# use for infer.
|
||||
self.expand = P.ExpandDims()
|
||||
self.multiply = P.Mul()
|
||||
self.reshape = P.Reshape()
|
||||
self.create_att_paddings_from_input_paddings = CreateAttentionPaddingsFromInputPaddings(config,
|
||||
is_training=False)
|
||||
self.tile_beam = TileBeam(beam_width=config.beam_width)
|
||||
self.cast_compute_type = SaturateCast(dst_type=compute_type)
|
||||
|
||||
beam_decoder_cell = BeamDecoderStep(config, use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
# link beam_search after decoder
|
||||
self.beam_decoder = BeamSearchDecoder(
|
||||
batch_size=config.batch_size,
|
||||
seq_length=config.seq_length,
|
||||
vocab_size=config.vocab_size,
|
||||
decoder=beam_decoder_cell,
|
||||
beam_width=config.beam_width,
|
||||
decoder_layers_nums=config.num_hidden_layers,
|
||||
length_penalty_weight=config.length_penalty_weight,
|
||||
hidden_size=config.hidden_size,
|
||||
max_decode_length=config.max_decode_length)
|
||||
self.beam_decoder.add_flags(loop_can_unroll=True)
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, source_ids, source_mask=None, source_len=None,
|
||||
target_ids=None):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
In this method, T = src_max_len, T' = tgt_max_len.
|
||||
|
||||
Args:
|
||||
source_ids (Tensor): Source sentences with shape (N, T).
|
||||
source_mask (Tensor): Source sentences padding mask with shape (N, T),
|
||||
where 0 indicates padding position.
|
||||
target_ids (Tensor): Target sentences with shape (N, T').
|
||||
target_mask (Tensor): Target sentences padding mask with shape (N, T'),
|
||||
where 0 indicates padding position.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor], network outputs.
|
||||
"""
|
||||
|
||||
# Process source sentences. src_embeddings:[N, T, D].
|
||||
src_embeddings, _ = self.embedding_lookup(source_ids)
|
||||
# T, N, D
|
||||
inputs = self.transpose(src_embeddings, self.transpose_orders)
|
||||
# encoder. encoder_outputs: [T, N, D]
|
||||
encoder_outputs = self.gnmt_encoder(inputs, source_len=source_len)
|
||||
|
||||
# decoder.
|
||||
if self.is_training:
|
||||
# training
|
||||
# process target input sentences. N, T, D
|
||||
tgt_embeddings, _ = self.embedding_lookup(target_ids)
|
||||
# T, N, D
|
||||
tgt_embeddings = self.transpose(tgt_embeddings, self.transpose_orders)
|
||||
# cell: [T,N,D].
|
||||
cell, _, _ = self.gnmt_decoder(tgt_embeddings,
|
||||
encoder_outputs,
|
||||
attention_mask=source_mask)
|
||||
# decoder_output: (N, T', V).
|
||||
decoder_outputs = self.transpose(cell, self.transpose_orders)
|
||||
out = decoder_outputs
|
||||
else:
|
||||
# infer
|
||||
# encoder_output: [T, N, D] -> [N, T, D].
|
||||
beam_encoder_output = self.transpose(encoder_outputs, self.transpose_orders)
|
||||
# bean search for encoder output, [N*beam_width, T, D]
|
||||
beam_encoder_output = self.tile_beam(beam_encoder_output)
|
||||
|
||||
# (N*beam_width, T)
|
||||
beam_enc_attention_pad = self.tile_beam(source_mask)
|
||||
|
||||
predicted_ids = self.beam_decoder(beam_encoder_output, beam_enc_attention_pad)
|
||||
predicted_ids = self.reshape(predicted_ids, (-1, self.max_decode_length))
|
||||
out = predicted_ids
|
||||
|
||||
return out
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Infer api."""
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import context, Parameter
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
|
||||
from src.dataset import load_dataset
|
||||
from .gnmt import GNMT
|
||||
from ..utils import zero_weight
|
||||
from ..utils.load_weights import load_infer_weights
|
||||
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend",
|
||||
reserve_class_name_in_scope=False)
|
||||
|
||||
|
||||
def get_weight_and_variable(model_path, params):
|
||||
print("model path is {}".format(model_path))
|
||||
ms_ckpt = load_checkpoint(model_path)
|
||||
with open("variable.txt", "w") as f:
|
||||
for msname in ms_ckpt:
|
||||
f.write(msname + "\n")
|
||||
with open("weights.txt", "w") as f:
|
||||
for param in params:
|
||||
name = param.name
|
||||
f.write(name + "\n")
|
||||
|
||||
|
||||
class GNMTInferCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of GNMT network infer.
|
||||
|
||||
Args:
|
||||
network (nn.Cell): GNMT model.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor], predicted_ids and predicted_probs.
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(GNMTInferCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self,
|
||||
source_ids,
|
||||
source_mask,
|
||||
source_len):
|
||||
"""Defines the computation performed."""
|
||||
|
||||
predicted_ids = self.network(source_ids,
|
||||
source_mask,
|
||||
source_len)
|
||||
|
||||
return predicted_ids
|
||||
|
||||
|
||||
def gnmt_infer(config, dataset):
|
||||
"""
|
||||
Run infer with GNMT.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Config.
|
||||
dataset (Dataset): Dataset.
|
||||
|
||||
Returns:
|
||||
List[Dict], prediction, each example has 4 keys, "source",
|
||||
"target", "prediction" and "prediction_prob".
|
||||
"""
|
||||
tfm_model = GNMT(config=config,
|
||||
is_training=False,
|
||||
use_one_hot_embeddings=False)
|
||||
|
||||
params = tfm_model.trainable_params()
|
||||
get_weight_and_variable(config.existed_ckpt, params)
|
||||
weights = load_infer_weights(config)
|
||||
|
||||
for param in params:
|
||||
value = param.data
|
||||
name = param.name
|
||||
if name not in weights:
|
||||
raise ValueError(f"{name} is not found in weights.")
|
||||
with open("weight_after_deal.txt", "a+") as f:
|
||||
weights_name = name
|
||||
f.write(weights_name)
|
||||
f.write("\n")
|
||||
if isinstance(value, Tensor):
|
||||
print(name, value.asnumpy().shape)
|
||||
if weights_name in weights:
|
||||
assert weights_name in weights
|
||||
if isinstance(weights[weights_name], Parameter):
|
||||
if param.data.dtype == "Float32":
|
||||
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
|
||||
elif param.data.dtype == "Float16":
|
||||
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
|
||||
|
||||
elif isinstance(weights[weights_name], Tensor):
|
||||
param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
|
||||
elif isinstance(weights[weights_name], np.ndarray):
|
||||
param.set_data(Tensor(weights[weights_name], config.dtype))
|
||||
else:
|
||||
param.set_data(weights[weights_name])
|
||||
else:
|
||||
print("weight not found in checkpoint: " + weights_name)
|
||||
param.set_data(zero_weight(value.asnumpy().shape))
|
||||
f.close()
|
||||
print(" | Load weights successfully.")
|
||||
tfm_infer = GNMTInferCell(tfm_model)
|
||||
model = Model(tfm_infer)
|
||||
|
||||
predictions = []
|
||||
source_sentences = []
|
||||
|
||||
shape = P.Shape()
|
||||
concat = P.Concat(axis=0)
|
||||
batch_index = 1
|
||||
pad_idx = 0
|
||||
sos_idx = 2
|
||||
eos_idx = 3
|
||||
source_ids_pad = Tensor(np.tile(np.array([[sos_idx, eos_idx] + [pad_idx] * (config.seq_length - 2)]),
|
||||
[config.batch_size, 1]), mstype.int32)
|
||||
source_mask_pad = Tensor(np.tile(np.array([[1, 1] + [0] * (config.seq_length - 2)]),
|
||||
[config.batch_size, 1]), mstype.int32)
|
||||
source_len_pad = Tensor(np.tile(np.array([[2]]), [config.batch_size, 1]), mstype.int32)
|
||||
for batch in dataset.create_dict_iterator():
|
||||
source_sentences.append(batch["source_eos_ids"].asnumpy())
|
||||
source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
|
||||
source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
|
||||
source_len = Tensor(batch["source_eos_length"], mstype.int32)
|
||||
|
||||
active_num = shape(source_ids)[0]
|
||||
if active_num < config.batch_size:
|
||||
source_ids = concat((source_ids, source_ids_pad[active_num:, :]))
|
||||
source_mask = concat((source_mask, source_mask_pad[active_num:, :]))
|
||||
source_len = concat((source_len, source_len_pad[active_num:, :]))
|
||||
|
||||
start_time = time.time()
|
||||
predicted_ids = model.predict(source_ids, source_mask, source_len)
|
||||
|
||||
print(f" | BatchIndex = {batch_index}, Batch size: {config.batch_size}, active_num={active_num}, "
|
||||
f"Time cost: {time.time() - start_time}.")
|
||||
if active_num < config.batch_size:
|
||||
predicted_ids = predicted_ids[:active_num, :]
|
||||
batch_index = batch_index + 1
|
||||
predictions.append(predicted_ids.asnumpy())
|
||||
|
||||
output = []
|
||||
for inputs, batch_out in zip(source_sentences, predictions):
|
||||
for i, _ in enumerate(batch_out):
|
||||
if batch_out.ndim == 3:
|
||||
batch_out = batch_out[:, 0]
|
||||
|
||||
example = {
|
||||
"source": inputs[i].tolist(),
|
||||
"prediction": batch_out[i].tolist()
|
||||
}
|
||||
output.append(example)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def infer(config):
|
||||
"""
|
||||
GNMT infer api.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Config.
|
||||
|
||||
Returns:
|
||||
list, result with
|
||||
"""
|
||||
eval_dataset = load_dataset(data_files=config.test_dataset,
|
||||
schema=config.dataset_schema,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=1,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
drop_remainder=False,
|
||||
is_translate=True,
|
||||
shuffle=False) if config.test_dataset else None
|
||||
prediction = gnmt_infer(config, eval_dataset)
|
||||
return prediction
|
|
@ -0,0 +1,345 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GNMT for training."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
||||
|
||||
from .gnmt import GNMT
|
||||
from .grad_clip import GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, ClipGradients
|
||||
|
||||
|
||||
class PredLogProbs(nn.Cell):
|
||||
"""
|
||||
Get log probs.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): The config of GNMT.
|
||||
|
||||
Returns:
|
||||
Tensor, log softmax output.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(PredLogProbs, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.log_softmax = nn.LogSoftmax(axis=-1)
|
||||
self.get_shape = P.Shape()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
input_tensor (Tensor): Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, log softmax output.
|
||||
"""
|
||||
shape = self.get_shape(input_tensor)
|
||||
logits = self.reshape(input_tensor, (shape[0] * shape[1], shape[2]))
|
||||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
||||
|
||||
class GNMTTraining(nn.Cell):
|
||||
"""
|
||||
GNMT training network.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): The config of GNMT.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
|
||||
|
||||
Returns:
|
||||
Tensor, prediction_scores.
|
||||
"""
|
||||
|
||||
def __init__(self, config, is_training, use_one_hot_embeddings):
|
||||
super(GNMTTraining, self).__init__()
|
||||
self.gnmt = GNMT(config, is_training, use_one_hot_embeddings)
|
||||
self.projection = PredLogProbs(config)
|
||||
|
||||
def construct(self, source_ids, source_mask, source_len, target_ids):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
source_ids (Tensor): Source sentence.
|
||||
source_mask (Tensor): Source padding mask.
|
||||
source_len (Tensor): Effective length of source sentence.
|
||||
target_ids (Tensor): Target sentence.
|
||||
|
||||
|
||||
Returns:
|
||||
Tensor, prediction_scores.
|
||||
"""
|
||||
decoder_outputs = self.gnmt(source_ids, source_mask, source_len, target_ids)
|
||||
prediction_scores = self.projection(decoder_outputs)
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class LabelSmoothedCrossEntropyCriterion(nn.Cell):
|
||||
"""
|
||||
Label Smoothed Cross-Entropy Criterion.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): The config of GNMT.
|
||||
|
||||
Returns:
|
||||
Tensor, final loss.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(LabelSmoothedCrossEntropyCriterion, self).__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.batch_size = config.batch_size
|
||||
self.smoothing = 0.1
|
||||
self.confidence = 0.9
|
||||
self.last_idx = (-1,)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reshape = P.Reshape()
|
||||
self.neg = P.Neg()
|
||||
self.cast = P.Cast()
|
||||
self.index_ids = Tensor(np.arange(config.batch_size * config.max_decode_length).reshape((-1, 1)), mstype.int32)
|
||||
self.gather_nd = P.GatherNd()
|
||||
self.expand = P.ExpandDims()
|
||||
self.concat = P.Concat(axis=-1)
|
||||
|
||||
def construct(self, prediction_scores, label_ids, label_weights):
|
||||
"""
|
||||
Construct network to calculate loss.
|
||||
|
||||
Args:
|
||||
prediction_scores (Tensor): Prediction scores. [batchsize, seq_len, vocab_size]
|
||||
label_ids (Tensor): Labels. [batchsize, seq_len]
|
||||
label_weights (Tensor): Mask tensor. [batchsize, seq_len]
|
||||
|
||||
Returns:
|
||||
Tensor, final loss.
|
||||
"""
|
||||
prediction_scores = self.reshape(prediction_scores, (-1, self.vocab_size))
|
||||
label_ids = self.reshape(label_ids, (-1, 1))
|
||||
label_weights = self.reshape(label_weights, (-1,))
|
||||
tmp_gather_indices = self.concat((self.index_ids, label_ids))
|
||||
nll_loss = self.neg(self.gather_nd(prediction_scores, tmp_gather_indices))
|
||||
nll_loss = label_weights * nll_loss
|
||||
smooth_loss = self.neg(self.reduce_mean(prediction_scores, self.last_idx))
|
||||
smooth_loss = label_weights * smooth_loss
|
||||
loss = self.reduce_sum(self.confidence * nll_loss + self.smoothing * smooth_loss, ())
|
||||
loss = loss / self.batch_size
|
||||
return loss
|
||||
|
||||
|
||||
class GNMTNetworkWithLoss(nn.Cell):
|
||||
"""
|
||||
Provide GNMT training loss through network.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of GNMT.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, the loss of the network.
|
||||
"""
|
||||
|
||||
def __init__(self, config, is_training, use_one_hot_embeddings=False):
|
||||
super(GNMTNetworkWithLoss, self).__init__()
|
||||
self.gnmt = GNMTTraining(config, is_training, use_one_hot_embeddings)
|
||||
self.loss = LabelSmoothedCrossEntropyCriterion(config)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self,
|
||||
source_ids,
|
||||
source_mask,
|
||||
source_len,
|
||||
target_ids,
|
||||
label_ids,
|
||||
label_weights):
|
||||
prediction_scores = self.gnmt(source_ids, source_mask, source_len, target_ids)
|
||||
total_loss = self.loss(prediction_scores, label_ids, label_weights)
|
||||
return self.cast(total_loss, mstype.float32)
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
|
||||
class GNMTTrainOneStepWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of GNMT network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network: Cell. The training network. Note that loss function should have
|
||||
been added.
|
||||
optimizer: Optimizer. Optimizer for updating the weights.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
|
||||
super(GNMTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.all_reduce = P.AllReduce()
|
||||
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
|
||||
name="loss_scale")
|
||||
self.add_flags(has_effect=True)
|
||||
|
||||
self.loss_scalar = P.ScalarSummary()
|
||||
|
||||
def construct(self,
|
||||
source_eos_ids,
|
||||
source_eos_mask,
|
||||
source_eos_length,
|
||||
target_sos_ids,
|
||||
target_eos_ids,
|
||||
target_eos_mask,
|
||||
sens=None):
|
||||
"""
|
||||
Construct network.
|
||||
|
||||
Args:
|
||||
source_eos_ids (Tensor): Source sentence.
|
||||
source_eos_mask (Tensor): Source padding mask.
|
||||
source_eos_length (Tensor): Effective length of source sentence.
|
||||
target_sos_ids (Tensor): Target sentence.
|
||||
target_eos_ids (Tensor): Prediction sentence.
|
||||
target_eos_mask (Tensor): Prediction padding mask.
|
||||
sens (Tensor): Loss sen.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.
|
||||
"""
|
||||
source_ids = source_eos_ids
|
||||
source_mask = source_eos_mask
|
||||
source_len = source_eos_length
|
||||
target_ids = target_sos_ids
|
||||
label_ids = target_eos_ids
|
||||
label_weights = target_eos_mask
|
||||
|
||||
weights = self.weights
|
||||
loss = self.network(source_ids,
|
||||
source_mask,
|
||||
source_len,
|
||||
target_ids,
|
||||
label_ids,
|
||||
label_weights)
|
||||
# Alloc status.
|
||||
init = self.alloc_status()
|
||||
# Clear overflow buffer.
|
||||
self.clear_before_grad(init)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
grads = self.grad(self.network, weights)(source_ids,
|
||||
source_mask,
|
||||
source_len,
|
||||
target_ids,
|
||||
label_ids,
|
||||
label_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
if self.reducer_flag:
|
||||
# Apply grad reducer on grads.
|
||||
grads = self.grad_reducer(grads)
|
||||
self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
|
||||
if self.is_distributed:
|
||||
# Sum overflow flag over devices.
|
||||
flag_reduce = self.all_reduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
|
||||
self.loss_scalar("loss", loss)
|
||||
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Gradient clip."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 5.0
|
||||
|
||||
|
||||
class ClipGradients(nn.Cell):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ClipGradients, self).__init__()
|
||||
self.clip_by_norm = nn.ClipByNorm()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self,
|
||||
grads,
|
||||
clip_type,
|
||||
clip_value):
|
||||
"""
|
||||
Construct gradient clip network.
|
||||
|
||||
Args:
|
||||
grads (list): List of gradient tuples.
|
||||
clip_type (Tensor): The way to clip, 'value' or 'norm'.
|
||||
clip_value (Tensor): Specifies how much to clip.
|
||||
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
if clip_type not in (0, 1):
|
||||
return grads
|
||||
|
||||
new_grads = ()
|
||||
for grad in grads:
|
||||
dt = self.dtype(grad)
|
||||
if clip_type == 0:
|
||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
new_grads = new_grads + (t,)
|
||||
|
||||
return new_grads
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils for gnmt model."""
|
||||
|
||||
from .lr_scheduler import square_root_schedule
|
||||
from .loss_monitor import LossCallBack
|
||||
from .initializer import zero_weight, one_weight, normal_weight, weight_variable
|
||||
|
||||
__all__ = [
|
||||
"square_root_schedule",
|
||||
"LossCallBack",
|
||||
"one_weight",
|
||||
"zero_weight",
|
||||
"normal_weight",
|
||||
"weight_variable"
|
||||
]
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Initializer."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
def _compute_fans(shape):
|
||||
"""
|
||||
Computes the number of input and output units for a weight shape.
|
||||
|
||||
Args:
|
||||
shape (tuple): Integer shape tuple or TF tensor shape.
|
||||
|
||||
Returns:
|
||||
tuple, integer scalars (fan_in, fan_out).
|
||||
"""
|
||||
if not shape:
|
||||
fan_in = fan_out = 1
|
||||
elif len(shape) == 1:
|
||||
fan_in = fan_out = shape[0]
|
||||
elif len(shape) == 2:
|
||||
fan_in = shape[0]
|
||||
fan_out = shape[1]
|
||||
else:
|
||||
# Assuming convolution kernels (2D, 3D, or more).
|
||||
# kernel shape: (..., input_depth, depth)
|
||||
receptive_field_size = 1
|
||||
for dim in shape[:-2]:
|
||||
receptive_field_size *= dim
|
||||
fan_in = shape[-2] * receptive_field_size
|
||||
fan_out = shape[-1] * receptive_field_size
|
||||
return int(fan_in), int(fan_out)
|
||||
|
||||
|
||||
def weight_variable(shape):
|
||||
"""
|
||||
Generate weight var.
|
||||
|
||||
Args:
|
||||
shape (tuple): Shape.
|
||||
|
||||
Returns:
|
||||
Tensor, var.
|
||||
"""
|
||||
# scale_shape = shape
|
||||
# fan_in, fan_out = _compute_fans(scale_shape)
|
||||
# scale = 1.0 / max(1., (fan_in + fan_out) / 2.)
|
||||
# limit = math.sqrt(3.0 * scale)
|
||||
limit = 0.1
|
||||
values = np.random.uniform(-limit, limit, shape)
|
||||
return values
|
||||
|
||||
|
||||
def one_weight(shape):
|
||||
"""
|
||||
Generate weight with ones.
|
||||
|
||||
Args:
|
||||
shape (tuple): Shape.
|
||||
|
||||
Returns:
|
||||
Tensor, var.
|
||||
"""
|
||||
ones = np.ones(shape).astype(np.float32)
|
||||
return Tensor(ones)
|
||||
|
||||
|
||||
def zero_weight(shape):
|
||||
"""
|
||||
Generate weight with zeros.
|
||||
|
||||
Args:
|
||||
shape (tuple): Shape.
|
||||
|
||||
Returns:
|
||||
Tensor, var.
|
||||
"""
|
||||
zeros = np.zeros(shape).astype(np.float32)
|
||||
return Tensor(zeros)
|
||||
|
||||
|
||||
def normal_weight(shape, num_units):
|
||||
"""
|
||||
Generate weight with normal dist.
|
||||
|
||||
Args:
|
||||
shape (tuple): Shape.
|
||||
num_units (int): Dimension.
|
||||
|
||||
Returns:
|
||||
Tensor, var.
|
||||
"""
|
||||
norm = np.random.normal(0.0, num_units ** -0.5, shape).astype(np.float32)
|
||||
return Tensor(norm)
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Weight loader."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
|
||||
|
||||
def load_infer_weights(config):
|
||||
"""
|
||||
Load weights from ckpt or npz.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Config.
|
||||
|
||||
Returns:
|
||||
dict, weights.
|
||||
"""
|
||||
model_path = config.existed_ckpt
|
||||
if model_path.endswith(".npz"):
|
||||
ms_ckpt = np.load(model_path)
|
||||
is_npz = True
|
||||
else:
|
||||
ms_ckpt = load_checkpoint(model_path)
|
||||
is_npz = False
|
||||
weights = {}
|
||||
with open("variable_after_deal.txt", "w") as f:
|
||||
for param_name in ms_ckpt:
|
||||
infer_name = param_name.replace("gnmt.gnmt.", "")
|
||||
if infer_name.startswith("embedding_lookup."):
|
||||
if is_npz:
|
||||
weights[infer_name] = ms_ckpt[param_name]
|
||||
else:
|
||||
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
||||
f.write(infer_name)
|
||||
f.write("\n")
|
||||
infer_name = "beam_decoder.decoder." + infer_name
|
||||
if is_npz:
|
||||
weights[infer_name] = ms_ckpt[param_name]
|
||||
else:
|
||||
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
||||
f.write(infer_name)
|
||||
f.write("\n")
|
||||
continue
|
||||
|
||||
elif not infer_name.startswith("gnmt_encoder"):
|
||||
if infer_name.startswith("gnmt_decoder."):
|
||||
infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
|
||||
infer_name = "beam_decoder.decoder." + infer_name
|
||||
|
||||
if is_npz:
|
||||
weights[infer_name] = ms_ckpt[param_name]
|
||||
else:
|
||||
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
||||
f.write(infer_name)
|
||||
f.write("\n")
|
||||
|
||||
f.close()
|
||||
return weights
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Loss monitor."""
|
||||
import time
|
||||
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.communication.management import get_rank
|
||||
|
||||
from config import GNMTConfig
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF terminating training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
time_stamp_init = False
|
||||
time_stamp_first = 0
|
||||
|
||||
def __init__(self, config: GNMTConfig, per_print_times: int = 1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self.config = config
|
||||
self._per_print_times = per_print_times
|
||||
|
||||
if not self.time_stamp_init:
|
||||
self.time_stamp_first = self._get_ms_timestamp()
|
||||
self.time_stamp_init = True
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""step end."""
|
||||
cb_params = run_context.original_args()
|
||||
file_name = "./loss.log"
|
||||
if self.config.modelarts:
|
||||
import os
|
||||
file_name = "/home/work/workspace/loss/loss_{}.log".format(os.getenv('DEVICE_ID'))
|
||||
with open(file_name, "a+") as f:
|
||||
time_stamp_current = self._get_ms_timestamp()
|
||||
f.write("time: {}, epoch: {}, step: {}, outputs: [loss: {}, overflow: {}, loss scale value: {} ].\n".format(
|
||||
time_stamp_current - self.time_stamp_first,
|
||||
cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs[0].asnumpy()),
|
||||
str(cb_params.net_outputs[1].asnumpy()),
|
||||
str(cb_params.net_outputs[2].asnumpy())
|
||||
))
|
||||
|
||||
if self.config.modelarts:
|
||||
from modelarts.data_util import upload_output
|
||||
rank_id = get_rank()
|
||||
if cb_params.cur_step_num % self.config.save_step == 1 \
|
||||
and cb_params.cur_step_num != 1 and rank_id in [0, 8]:
|
||||
upload_output("/home/work/workspace/loss", self.config.train_url)
|
||||
upload_output("/cache/ckpt_0", self.config.train_url)
|
||||
|
||||
@staticmethod
|
||||
def _get_ms_timestamp():
|
||||
t = time.time()
|
||||
return int(round(t * 1000))
|
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Learning scheduler."""
|
||||
from math import ceil
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def convert_float2int(values, total_steps):
|
||||
if isinstance(values, float):
|
||||
values = int(values * total_steps)
|
||||
return values
|
||||
|
||||
|
||||
def square_root_schedule(lr, update_num, decay_start_step,
|
||||
warmup_steps=2000,
|
||||
min_lr=1e-5):
|
||||
"""
|
||||
Decay the LR based on the ISR(inverse square root).
|
||||
|
||||
During warm-up::
|
||||
lrs = np.linspace(0, lr, warmup_steps)
|
||||
|
||||
After warm-up:
|
||||
decay_factor = lr * sqrt(warmup_steps)
|
||||
lr = decay_factor / sqrt(step) if step >= decay_start_step else lr
|
||||
|
||||
Args:
|
||||
lr (float): Init learning rate.
|
||||
update_num (int): Total steps.
|
||||
decay_start_step (int): Decay begins after `decay_start_step` steps.
|
||||
warmup_steps (int): Warm up steps.
|
||||
min_lr (float): Min learning rate.
|
||||
|
||||
Returns:
|
||||
np.ndarray, learning rate array.
|
||||
"""
|
||||
warmup_end_lr = lr
|
||||
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
|
||||
|
||||
# If warmup_init_lr > lr, then lr_step is negative.
|
||||
# Otherwise, it's positive.
|
||||
lr_step = (warmup_end_lr - warmup_init_lr) / warmup_steps
|
||||
decay_factor = lr * warmup_steps ** 0.5
|
||||
|
||||
lrs = np.empty(shape=update_num, dtype=np.float32)
|
||||
_start_step = 0
|
||||
if 0 < warmup_steps < update_num:
|
||||
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
|
||||
_start_step = warmup_steps
|
||||
|
||||
for step in range(_start_step, update_num):
|
||||
if step < warmup_steps:
|
||||
_lr = warmup_init_lr + step * lr_step
|
||||
elif step < decay_start_step:
|
||||
_lr = lr
|
||||
else:
|
||||
_lr = decay_factor * step ** -0.5
|
||||
if _lr < min_lr:
|
||||
_lr = min_lr
|
||||
lrs[step] = _lr
|
||||
|
||||
return lrs
|
||||
|
||||
|
||||
def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0):
|
||||
"""
|
||||
Implements of polynomial decay learning rate scheduler which cycles by default.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate.
|
||||
warmup_steps (int): Warmup steps.
|
||||
decay_steps (int): Decay steps.
|
||||
total_update_num (int): Total update steps.
|
||||
min_lr (float): Min learning.
|
||||
power (float): Power factor.
|
||||
|
||||
Returns:
|
||||
np.ndarray, learning rate of each step.
|
||||
"""
|
||||
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
|
||||
|
||||
if decay_steps <= 0:
|
||||
raise ValueError("`decay_steps` must larger than 1.")
|
||||
|
||||
_start_step = 0
|
||||
if 0 < warmup_steps < total_update_num:
|
||||
warmup_end_lr = lr
|
||||
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
|
||||
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
|
||||
_start_step = warmup_steps
|
||||
|
||||
decay_steps = decay_steps
|
||||
for step in range(_start_step, total_update_num):
|
||||
_step = step - _start_step # 2999
|
||||
ratio = ceil(_step / decay_steps) # 3
|
||||
ratio = 1 if ratio < 1 else ratio
|
||||
_decay_steps = decay_steps * ratio # 3000
|
||||
lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr
|
||||
|
||||
return lrs
|
||||
|
||||
|
||||
def Warmup_MultiStepLR_scheduler(base_lr=0.002, total_update_num=200, warmup_steps=200, remain_steps=1.0,
|
||||
decay_interval=-1, decay_steps=4, decay_factor=0.5):
|
||||
"""
|
||||
Implements of polynomial decay learning rate scheduler which cycles by default.
|
||||
|
||||
Args:
|
||||
base_lr (float): Initial learning rate.
|
||||
total_update_num (int): Total update steps.
|
||||
warmup_steps (int or float): Warmup steps.
|
||||
remain_steps (int or float): start decay at 'remain_steps' iteration
|
||||
decay_interval (int): interval between LR decay steps
|
||||
decay_steps (int): Decay steps.
|
||||
decay_factor (float): decay factor
|
||||
|
||||
Returns:
|
||||
np.ndarray, learning rate of each step.
|
||||
"""
|
||||
|
||||
if decay_steps <= 0:
|
||||
raise ValueError("`decay_steps` must larger than 1.")
|
||||
remain_steps = convert_float2int(remain_steps, total_update_num)
|
||||
warmup_steps = convert_float2int(warmup_steps, total_update_num)
|
||||
if warmup_steps > remain_steps:
|
||||
warmup_steps = remain_steps
|
||||
|
||||
if decay_interval < 0:
|
||||
decay_iterations = total_update_num - remain_steps
|
||||
decay_interval = decay_iterations // decay_steps
|
||||
decay_interval = max(decay_interval, 1)
|
||||
else:
|
||||
decay_interval = convert_float2int(decay_interval, total_update_num)
|
||||
|
||||
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
|
||||
_start_step = 0
|
||||
for last_epoch in range(_start_step, total_update_num):
|
||||
if last_epoch < warmup_steps:
|
||||
if warmup_steps != 0:
|
||||
warmup_factor = math.exp(math.log(0.01) / warmup_steps)
|
||||
else:
|
||||
warmup_factor = 1.0
|
||||
inv_decay = warmup_factor ** (warmup_steps - last_epoch)
|
||||
lrs[last_epoch] = base_lr * inv_decay
|
||||
elif last_epoch >= remain_steps:
|
||||
decay_iter = last_epoch - remain_steps
|
||||
num_decay_step = decay_iter // decay_interval + 1
|
||||
num_decay_step = min(num_decay_step, decay_steps)
|
||||
lrs[last_epoch] = base_lr * (decay_factor ** num_decay_step)
|
||||
else:
|
||||
lrs[last_epoch] = base_lr
|
||||
return lrs
|
|
@ -0,0 +1,423 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""adam"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.nn import Optimizer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
|
||||
_learning_rate_update_func = ['linear', 'cos', 'sin']
|
||||
|
||||
adam_opt = C.MultitypeFuncGraph("adam_opt")
|
||||
|
||||
|
||||
@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
|
||||
"""
|
||||
Update parameters.
|
||||
|
||||
Args:
|
||||
beta1 (Tensor): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
|
||||
beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
|
||||
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||
lr (Tensor): Learning rate.
|
||||
weight_decay_tensor (Tensor): Weight decay. Should be equal to or greater than 0.
|
||||
param (Tensor): Parameters.
|
||||
m (Tensor): m value of parameters.
|
||||
v (Tensor): v value of parameters.
|
||||
gradient (Tensor): Gradient of parameters.
|
||||
|
||||
Returns:
|
||||
Tensor, the new value of v after updating.
|
||||
"""
|
||||
op_mul = P.Mul()
|
||||
op_square = P.Square()
|
||||
op_sqrt = P.Sqrt()
|
||||
op_cast = P.Cast()
|
||||
op_reshape = P.Reshape()
|
||||
op_shape = P.Shape()
|
||||
|
||||
param_fp32 = op_cast(param, mstype.float32)
|
||||
m_fp32 = op_cast(m, mstype.float32)
|
||||
v_fp32 = op_cast(v, mstype.float32)
|
||||
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||
|
||||
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
|
||||
|
||||
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
|
||||
- beta2, op_square(gradient_fp32))
|
||||
|
||||
update = next_m / (op_sqrt(next_v) + eps)
|
||||
if decay_flag:
|
||||
update = update + op_mul(weight_decay_tensor, param_fp32)
|
||||
|
||||
update_with_lr = op_mul(lr, update)
|
||||
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
|
||||
|
||||
next_v = F.depend(next_v, F.assign(param, next_param))
|
||||
next_v = F.depend(next_v, F.assign(m, next_m))
|
||||
next_v = F.depend(next_v, F.assign(v, next_v))
|
||||
return next_v
|
||||
|
||||
|
||||
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||
# validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
# validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
# validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
# validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
|
||||
|
||||
def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_float_positive('learning_rate', learning_rate, prim_name)
|
||||
validator.check_float_legal_value('learning_rate', learning_rate, prim_name)
|
||||
validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
|
||||
validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
|
||||
validator.check_float_positive('power', power, prim_name)
|
||||
validator.check_float_legal_value('power', power, prim_name)
|
||||
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
|
||||
|
||||
|
||||
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor")
|
||||
def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1,
|
||||
moment2):
|
||||
"""Apply adam optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, gradient))
|
||||
return success
|
||||
|
||||
|
||||
class Adam(Optimizer):
|
||||
r"""
|
||||
Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
|
||||
|
||||
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
|
||||
|
||||
The updating formulas are as follows,
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
m = \beta_1 * m + (1 - \beta_1) * g \\
|
||||
v = \beta_2 * v + (1 - \beta_2) * g * g \\
|
||||
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
|
||||
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
|
||||
\end{array}
|
||||
|
||||
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
|
||||
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
|
||||
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
|
||||
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
|
||||
:math:`\epsilon` represents `eps`.
|
||||
|
||||
Note:
|
||||
The Adam optimizer supports separating parameter groups. Different parameter groups can set different
|
||||
`learning_rate` and `weight_decay`.
|
||||
|
||||
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
|
||||
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
|
||||
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
|
||||
|
||||
Args:
|
||||
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
|
||||
the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
|
||||
"lr" and "weight_decay" are the keys can be parsed.
|
||||
|
||||
- params: Required. The value should be a list of `Parameter`.
|
||||
|
||||
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
|
||||
If not, the `learning_rate` in the API will be used.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the API will be used.
|
||||
|
||||
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a Tensor
|
||||
but the dims of the Tensor is 0, use fixed learning rate.
|
||||
Other cases are not supported. Default: 1e-3.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
|
||||
0.9.
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
|
||||
0.999.
|
||||
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
|
||||
1e-8.
|
||||
use_locking (bool): Whether to enable a lock to protect updating variable tensors.
|
||||
If True, updating of the var, m, and v tensors will be protected by a lock.
|
||||
If False, the result is unpredictable. Default: False.
|
||||
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
||||
If True, updates the gradients using NAG.
|
||||
If False, updates the gradients without using NAG. Default: False.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
|
||||
1.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
Tensor[bool], the value is True.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> #1) All parameters use the same learning rate and weight decay
|
||||
>>> optim = nn.Adam(params=net.trainable_params())
|
||||
>>>
|
||||
>>> #2) Use parameter groups and set different values
|
||||
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
|
||||
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
|
||||
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
|
||||
>>> {'params': no_conv_params}]
|
||||
>>> opt = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
|
||||
>>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01
|
||||
>>> # the no_cov_params's parameters don't set learning and weight decay. So they will use a
|
||||
>>> # learning rate of 0.1 and a weight decay of 0.0.
|
||||
>>>
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
||||
"""
|
||||
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
|
||||
use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
|
||||
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
|
||||
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
|
||||
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
|
||||
# validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
|
||||
self.beta1 = Tensor(beta1, mstype.float32)
|
||||
self.beta2 = Tensor(beta2, mstype.float32)
|
||||
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
|
||||
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
|
||||
self.eps = eps
|
||||
|
||||
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
|
||||
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
|
||||
self.pow = P.Pow()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.one = Tensor(np.array([1.0]).astype(np.float32))
|
||||
self.realdiv = P.RealDiv()
|
||||
|
||||
self.lr_scalar = P.ScalarSummary()
|
||||
|
||||
def construct(self, gradients):
|
||||
"""Adam optimizer."""
|
||||
params = self.parameters
|
||||
moment1 = self.moment1
|
||||
moment2 = self.moment2
|
||||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
|
||||
self.lr_scalar("learning_rate", lr)
|
||||
|
||||
beta1_power = self.beta1_power * self.beta1
|
||||
self.beta1_power = beta1_power
|
||||
beta2_power = self.beta2_power * self.beta2
|
||||
self.beta2_power = beta2_power
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
|
||||
self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2)
|
||||
else:
|
||||
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
|
||||
self.beta2, self.eps, lr),
|
||||
gradients, params, moment1, moment2)
|
||||
return success
|
||||
|
||||
|
||||
class AdamWeightDecay(Optimizer):
|
||||
"""
|
||||
Implements Adam algorithm weight decay fix.
|
||||
|
||||
Args:
|
||||
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
|
||||
should be class mindspore.Parameter.
|
||||
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a Tensor
|
||||
but the dims of the Tensor is 0, use fixed learning rate.
|
||||
Other cases are not supported. Default: 1e-3.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
|
||||
Should be in range (0.0, 1.0).
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
|
||||
Should be in range (0.0, 1.0).
|
||||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
|
||||
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> optim = nn.AdamWeightDecay(params=net.trainable_params())
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
"""
|
||||
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(AdamWeightDecay, self).__init__(learning_rate, params)
|
||||
if self.is_group:
|
||||
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
|
||||
|
||||
self.params = self.parameters
|
||||
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
|
||||
self.decay_flag = tuple(decay_filter(x) for x in self.params)
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, gradients):
|
||||
"""Adam Weight Decay"""
|
||||
lr = self.get_lr()
|
||||
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
|
||||
return updated_velocity
|
||||
|
||||
|
||||
class AdamWeightDecayDynamicLR(Optimizer):
|
||||
"""
|
||||
Adam Weight Decay Dynamic Learning Rate (LR).
|
||||
|
||||
Args:
|
||||
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
|
||||
should be class mindspore.Parameter.
|
||||
decay_steps (int): The steps of the decay.
|
||||
learning_rate (float): A floating point value for the learning rate. Default: 0.001.
|
||||
end_learning_rate (float): A floating point value for the end learning rate. Default: 0.0001.
|
||||
power (float): Power. Default: 10.0.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
|
||||
Should be in range (0.0, 1.0).
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
|
||||
Should be in range (0.0, 1.0).
|
||||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
|
||||
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> optim = nn.AdamWeightDecayDynamicLR(params=net.trainable_params(), decay_steps=10)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
params,
|
||||
decay_steps,
|
||||
learning_rate=0.001,
|
||||
end_learning_rate=0.0001,
|
||||
power=10.0,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
eps=1e-6,
|
||||
weight_decay=0.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name,
|
||||
warmup_steps=0):
|
||||
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
|
||||
if self.is_group:
|
||||
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
_check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
|
||||
# turn them to scalar when me support scalar/tensor mix operations
|
||||
self.global_step = Parameter(initializer(0, [1]), name="global_step")
|
||||
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
|
||||
self.warmup_flag = False
|
||||
if warmup_steps > 0:
|
||||
self.warmup_flag = True
|
||||
self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32))
|
||||
self.end_learning_rate = Tensor(np.array([end_learning_rate]).astype(np.float32))
|
||||
self.diff_learning_rate = Tensor(np.array([learning_rate - end_learning_rate]).astype(np.float32))
|
||||
self.power = power
|
||||
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
|
||||
self.params = self.parameters
|
||||
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
|
||||
self.decay_flag = tuple(decay_filter(x) for x in self.params)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.min = P.Minimum()
|
||||
self.pow = P.Pow()
|
||||
self.greater = P.Greater()
|
||||
self.one = Tensor(np.array([1.0]).astype(np.float32))
|
||||
self.cast = P.Cast()
|
||||
self.start_learning_rate = Tensor(np.array([learning_rate]).astype(np.float32))
|
||||
|
||||
def construct(self, gradients):
|
||||
"""Adam Weight Decay Dynamic LR."""
|
||||
step = self.min(self.global_step, self.decay_steps)
|
||||
p = step / self.decay_steps
|
||||
lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
|
||||
if self.warmup_flag:
|
||||
warmup_percent = self.global_step / self.warmup_steps
|
||||
warmup_lr = self.start_learning_rate * warmup_percent
|
||||
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
|
||||
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
|
||||
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
|
||||
added_global_step = self.global_step + self.one
|
||||
F.control_depend(lr, added_global_step)
|
||||
self.global_step = added_global_step
|
||||
|
||||
return updated_velocity
|
|
@ -0,0 +1,360 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train api."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn import Momentum
|
||||
from mindspore.nn.optim import Lamb
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, SummaryCollector, TimeMonitor
|
||||
from mindspore import context, Parameter
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication import management as MultiAscend
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from config import GNMTConfig
|
||||
from src.dataset import load_dataset
|
||||
from src.gnmt_model import GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
|
||||
from src.utils import LossCallBack
|
||||
from src.utils import one_weight, weight_variable
|
||||
from src.utils.lr_scheduler import square_root_schedule, polynomial_decay_scheduler, Warmup_MultiStepLR_scheduler
|
||||
from src.utils.optimizer import Adam
|
||||
|
||||
parser = argparse.ArgumentParser(description='GNMT train entry point.')
|
||||
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
|
||||
|
||||
device_id = os.getenv('DEVICE_ID', None)
|
||||
if device_id is None:
|
||||
raise RuntimeError("`DEVICE_ID` can not be None.")
|
||||
|
||||
device_id = int(device_id)
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend",
|
||||
reserve_class_name_in_scope=True,
|
||||
device_id=device_id)
|
||||
|
||||
|
||||
def get_config(config):
|
||||
config = GNMTConfig.from_json_file(config)
|
||||
config.compute_type = mstype.float16
|
||||
config.dtype = mstype.float32
|
||||
return config
|
||||
|
||||
|
||||
def _train(model, config: GNMTConfig,
|
||||
pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
|
||||
callbacks: list = None):
|
||||
"""
|
||||
Train model.
|
||||
|
||||
Args:
|
||||
model (Model): MindSpore model instance.
|
||||
config (GNMTConfig): Config of mass model.
|
||||
pre_training_dataset (Dataset): Pre-training dataset.
|
||||
fine_tune_dataset (Dataset): Fine-tune dataset.
|
||||
test_dataset (Dataset): Test dataset.
|
||||
callbacks (list): A list of callbacks.
|
||||
"""
|
||||
callbacks = callbacks if callbacks else []
|
||||
|
||||
if pre_training_dataset is not None:
|
||||
print(" | Start pre-training job.")
|
||||
epoch_size = pre_training_dataset.get_repeat_count()
|
||||
print("epoch size ", epoch_size)
|
||||
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
|
||||
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
|
||||
model.train(config.epochs, pre_training_dataset,
|
||||
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)
|
||||
|
||||
if fine_tune_dataset is not None:
|
||||
print(" | Start fine-tuning job.")
|
||||
epoch_size = fine_tune_dataset.get_repeat_count()
|
||||
|
||||
model.train(config.epochs, fine_tune_dataset,
|
||||
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)
|
||||
|
||||
|
||||
def _load_checkpoint_to_net(config, network):
|
||||
"""load parameters to network from checkpoint."""
|
||||
if config.existed_ckpt:
|
||||
if config.existed_ckpt.endswith(".npz"):
|
||||
weights = np.load(config.existed_ckpt)
|
||||
else:
|
||||
weights = load_checkpoint(config.existed_ckpt)
|
||||
for param in network.trainable_params():
|
||||
weights_name = param.name
|
||||
if weights_name not in weights:
|
||||
raise ValueError(f"Param {weights_name} is not found in ckpt file.")
|
||||
|
||||
if isinstance(weights[weights_name], Parameter):
|
||||
param.set_data(weights[weights_name].data)
|
||||
elif isinstance(weights[weights_name], Tensor):
|
||||
param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
|
||||
elif isinstance(weights[weights_name], np.ndarray):
|
||||
param.set_data(Tensor(weights[weights_name], config.dtype))
|
||||
else:
|
||||
param.set_data(weights[weights_name])
|
||||
else:
|
||||
for param in network.trainable_params():
|
||||
name = param.name
|
||||
value = param.data
|
||||
if isinstance(value, Tensor):
|
||||
if name.endswith(".gamma"):
|
||||
param.set_data(one_weight(value.asnumpy().shape))
|
||||
elif name.endswith(".beta") or name.endswith(".bias"):
|
||||
# param.set_data(zero_weight(value.asnumpy().shape))
|
||||
if param.data.dtype == "Float32":
|
||||
param.set_data((weight_variable(value.asnumpy().shape).astype(np.float32)))
|
||||
elif param.data.dtype == "Float16":
|
||||
param.set_data((weight_variable(value.asnumpy().shape).astype(np.float16)))
|
||||
else:
|
||||
if param.data.dtype == "Float32":
|
||||
param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float32)))
|
||||
elif param.data.dtype == "Float16":
|
||||
param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float16)))
|
||||
|
||||
|
||||
def _get_lr(config, update_steps):
|
||||
"""generate learning rate."""
|
||||
if config.lr_scheduler == "isr":
|
||||
lr = Tensor(square_root_schedule(lr=config.lr,
|
||||
update_num=update_steps,
|
||||
decay_start_step=config.decay_start_step,
|
||||
warmup_steps=config.warmup_steps,
|
||||
min_lr=config.min_lr), dtype=mstype.float32)
|
||||
elif config.lr_scheduler == "poly":
|
||||
lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
|
||||
min_lr=config.min_lr,
|
||||
decay_steps=config.decay_steps,
|
||||
total_update_num=update_steps,
|
||||
warmup_steps=config.warmup_steps,
|
||||
power=config.lr_scheduler_power), dtype=mstype.float32)
|
||||
elif config.lr_scheduler == "WarmupMultiStepLR":
|
||||
lr = Tensor(Warmup_MultiStepLR_scheduler(base_lr=config.lr,
|
||||
total_update_num=update_steps,
|
||||
warmup_steps=config.warmup_steps,
|
||||
remain_steps=config.warmup_lr_remain_steps,
|
||||
decay_interval=config.warmup_lr_decay_interval,
|
||||
decay_steps=config.decay_steps,
|
||||
decay_factor=config.lr_scheduler_power), dtype=mstype.float32)
|
||||
else:
|
||||
lr = config.lr
|
||||
return lr
|
||||
|
||||
|
||||
def _get_optimizer(config, network, lr):
|
||||
"""get gnmt optimizer, support Adam, Lamb, Momentum."""
|
||||
if config.optimizer.lower() == "adam":
|
||||
optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
|
||||
elif config.optimizer.lower() == "lamb":
|
||||
optimizer = Lamb(network.trainable_params(), decay_steps=12000,
|
||||
start_learning_rate=config.lr, end_learning_rate=config.min_lr,
|
||||
power=10.0, warmup_steps=config.warmup_steps, weight_decay=0.01,
|
||||
eps=1e-6)
|
||||
elif config.optimizer.lower() == "momentum":
|
||||
optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
|
||||
else:
|
||||
raise ValueError(f"optimizer only support `adam` and `momentum` now.")
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def _build_training_pipeline(config: GNMTConfig,
|
||||
pre_training_dataset=None,
|
||||
fine_tune_dataset=None,
|
||||
test_dataset=None):
|
||||
"""
|
||||
Build training pipeline.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Config of mass model.
|
||||
pre_training_dataset (Dataset): Pre-training dataset.
|
||||
fine_tune_dataset (Dataset): Fine-tune dataset.
|
||||
test_dataset (Dataset): Test dataset.
|
||||
"""
|
||||
net_with_loss = GNMTNetworkWithLoss(config, is_training=True, use_one_hot_embeddings=True)
|
||||
net_with_loss.init_parameters_data()
|
||||
_load_checkpoint_to_net(config, net_with_loss)
|
||||
|
||||
dataset = pre_training_dataset if pre_training_dataset is not None \
|
||||
else fine_tune_dataset
|
||||
|
||||
if dataset is None:
|
||||
raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")
|
||||
|
||||
update_steps = config.epochs * dataset.get_dataset_size()
|
||||
|
||||
lr = _get_lr(config, update_steps)
|
||||
optimizer = _get_optimizer(config, net_with_loss, lr)
|
||||
|
||||
# Dynamic loss scale.
|
||||
scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
|
||||
scale_factor=config.loss_scale_factor,
|
||||
scale_window=config.scale_window)
|
||||
net_with_grads = GNMTTrainOneStepWithLossScaleCell(
|
||||
network=net_with_loss, optimizer=optimizer,
|
||||
scale_update_cell=scale_manager.get_update_cell()
|
||||
)
|
||||
net_with_grads.set_train(True)
|
||||
model = Model(net_with_grads)
|
||||
loss_monitor = LossCallBack(config)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
time_cb = TimeMonitor(data_size=dataset_size)
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
|
||||
keep_checkpoint_max=config.keep_ckpt_max)
|
||||
|
||||
rank_size = os.getenv('RANK_SIZE')
|
||||
callbacks = [time_cb, loss_monitor]
|
||||
if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
|
||||
ckpt_callback = ModelCheckpoint(
|
||||
prefix=config.ckpt_prefix,
|
||||
directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
|
||||
config=ckpt_config)
|
||||
callbacks.append(ckpt_callback)
|
||||
summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
|
||||
callbacks.append(summary_callback)
|
||||
|
||||
if rank_size is None or int(rank_size) == 1:
|
||||
ckpt_callback = ModelCheckpoint(
|
||||
prefix=config.ckpt_prefix,
|
||||
directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
|
||||
config=ckpt_config)
|
||||
callbacks.append(ckpt_callback)
|
||||
summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
|
||||
callbacks.append(summary_callback)
|
||||
|
||||
print(f" | ALL SET, PREPARE TO TRAIN.")
|
||||
_train(model=model, config=config,
|
||||
pre_training_dataset=pre_training_dataset,
|
||||
fine_tune_dataset=fine_tune_dataset,
|
||||
test_dataset=test_dataset,
|
||||
callbacks=callbacks)
|
||||
|
||||
|
||||
def _setup_parallel_env():
|
||||
context.reset_auto_parallel_context()
|
||||
MultiAscend.init()
|
||||
context.set_auto_parallel_context(
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
device_num=MultiAscend.get_group_size(),
|
||||
gradients_mean=True
|
||||
)
|
||||
|
||||
|
||||
def train_parallel(config: GNMTConfig):
|
||||
"""
|
||||
Train model with multi ascend chips.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Config for MASS model.
|
||||
"""
|
||||
_setup_parallel_env()
|
||||
print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.")
|
||||
|
||||
pre_train_dataset = load_dataset(
|
||||
data_files=config.pre_train_dataset,
|
||||
schema=config.dataset_schema,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epochs,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank()
|
||||
) if config.pre_train_dataset else None
|
||||
fine_tune_dataset = load_dataset(
|
||||
data_files=config.fine_tune_dataset, schema=config.dataset_schema,
|
||||
batch_size=config.batch_size, epoch_count=config.epochs,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank()
|
||||
) if config.fine_tune_dataset else None
|
||||
test_dataset = load_dataset(
|
||||
data_files=config.test_dataset, schema=config.dataset_schema,
|
||||
batch_size=config.batch_size, epoch_count=config.epochs,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank()
|
||||
) if config.test_dataset else None
|
||||
|
||||
_build_training_pipeline(config=config,
|
||||
pre_training_dataset=pre_train_dataset,
|
||||
fine_tune_dataset=fine_tune_dataset,
|
||||
test_dataset=test_dataset)
|
||||
|
||||
|
||||
def train_single(config: GNMTConfig):
|
||||
"""
|
||||
Train model on single device.
|
||||
|
||||
Args:
|
||||
config (GNMTConfig): Config for model.
|
||||
"""
|
||||
print(" | Starting training on single device.")
|
||||
|
||||
pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
|
||||
schema=config.dataset_schema,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epochs,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
|
||||
fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
|
||||
schema=config.dataset_schema,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epochs,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
|
||||
test_dataset = load_dataset(data_files=config.test_dataset,
|
||||
schema=config.dataset_schema,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epochs,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step) if config.test_dataset else None
|
||||
|
||||
_build_training_pipeline(config=config,
|
||||
pre_training_dataset=pre_train_dataset,
|
||||
fine_tune_dataset=fine_tune_dataset,
|
||||
test_dataset=test_dataset)
|
||||
|
||||
|
||||
def _check_args(config):
|
||||
if not os.path.exists(config):
|
||||
raise FileNotFoundError("`config` is not existed.")
|
||||
if not isinstance(config, str):
|
||||
raise ValueError("`config` must be type of str.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_rank_size = os.getenv('RANK_SIZE')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
_check_args(args.config)
|
||||
_config = get_config(args.config)
|
||||
|
||||
set_seed(_config.random_seed)
|
||||
|
||||
if _rank_size is not None and int(_rank_size) > 1:
|
||||
train_parallel(_config)
|
||||
else:
|
||||
train_single(_config)
|
Loading…
Reference in New Issue