upload gnmt_v2

This commit is contained in:
gaojing 2020-11-06 03:18:06 -05:00
parent 64d078da79
commit 6ac5be72d9
42 changed files with 5117 additions and 0 deletions


@ -43,6 +43,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp)
- [BERT[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/bert/README.md)
- [TinyBERT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/tinybert/README.md)
- [GNMT V2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2/README.md)
- [LSTM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm/README.md)
- [MASS](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/mass/README.md)
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)


@ -0,0 +1,263 @@
![](https://www.mindspore.cn/static/img/logo.a3e472c9.png)
<!-- TOC -->
- [GNMT v2 For MindSpore](#gnmt-v2-for-mindspore)
- [Model Structure](#model-structure)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Platform](#platform)
- [Software](#software)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Dataset Preparation](#dataset-preparation)
- [Configuration File](#configuration-file)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Result](#result)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Practice](#practice)
- [Dataset Preprocessing](#dataset-preprocessing)
- [Training](#training-1)
- [Inference](#inference-1)
- [Random Situation Description](#random-situation-description)
- [Others](#others)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# GNMT v2 For MindSpore
The GNMT v2 model is similar to the model described in [Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation](https://arxiv.org/abs/1609.08144), which is mainly used for corpus translation.
# Model Structure
The GNMTv2 model mainly consists of an encoder, a decoder, and an attention mechanism, where the encoder and the decoder use a shared word embedding vector.
Encoder: consists of four long short-term memory (LSTM) layers. The first LSTM layer is bidirectional, while the other three layers are unidirectional.
Decoder: consists of four unidirectional LSTM layers and a fully connected classifier. The output embedding dimension of LSTM is 1024.
Attention mechanism: uses the standardized Bahdanau attention mechanism. First, the first layer output of the decoder is used as the input of the attention mechanism. Then, the computing result of the attention mechanism is connected to the input of the decoder LSTM, which is used as the input of the subsequent LSTM layer.
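To make the data flow concrete, the sketch below mimics the attention-to-decoder wiring described above with plain NumPy. All shapes, weights, and names are illustrative assumptions for exposition; they are not taken from the repository code.
```python
import numpy as np

rng = np.random.default_rng(0)
batch, t_k, hidden = 2, 7, 1024                 # illustrative sizes; hidden_size=1024 as in this README

encoder_outputs = rng.standard_normal((batch, t_k, hidden))   # keys from the 4-layer LSTM encoder
query = rng.standard_normal((batch, 1, hidden))               # output of the first decoder LSTM layer

# Additive (Bahdanau) score: v^T * tanh(W_q q + W_k k), evaluated at every source position.
w_q = 0.01 * rng.standard_normal((hidden, hidden))
w_k = 0.01 * rng.standard_normal((hidden, hidden))
v = 0.01 * rng.standard_normal(hidden)
scores = np.tanh(query @ w_q + encoder_outputs @ w_k) @ v     # (batch, t_k)
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

# Weighted sum of encoder outputs gives the context vector.
context = (weights[..., None] * encoder_outputs).sum(axis=1, keepdims=True)  # (batch, 1, hidden)

# The context is concatenated with the decoder input and fed to the remaining LSTM layers.
next_layer_input = np.concatenate([query, context], axis=-1)  # (batch, 1, 2 * hidden)
print(next_layer_input.shape)
```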
# Dataset
Note that you can run the scripts with the dataset mentioned in the original paper or with any dataset widely used in this domain. The following sections describe how to run the scripts using the datasets below.
- *WMT English-German* for training.
- *WMT newstest2014* for evaluation.
# Environment Requirements
## Platform
- Hardware (Ascend)
- Prepare the hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources for a trial.
- Framework
- Install [MindSpore](https://www.mindspore.cn/install/en).
- For more information, please check the resources below:
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/en/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/en/master/index.html)
## Software
```txt
numpy
sacrebleu==1.2.10
sacremoses==0.0.19
subword_nmt==0.3.7
```
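A quick way to confirm that these packages are importable in your environment is shown below; this is a convenience check only and is not part of the repository.
```python
# Print the version of each third-party package listed above, if the package exposes one.
import importlib

for pkg in ("numpy", "sacrebleu", "sacremoses", "subword_nmt"):
    module = importlib.import_module(pkg)
    print(pkg, "->", getattr(module, "__version__", "no __version__ attribute"))
```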
# Quick Start
After dataset preparation, you can start training and evaluation as follows:
```bash
# run training example
python train.py --config /home/workspace/gnmt_v2/config/config.json
# run distributed training example
cd ./scripts
sh run_distributed_train_ascend.sh
# run evaluation example
cd ./scripts
sh run_standalone_eval_ascend.sh
```
# Script Description
The GNMT network scripts and code structure are as follows:
```text
├── gnmt
├── README.md // Introduction of GNMTv2 model.
├── config
│ ├──__init__.py // User interface.
│ ├──config.py // Configuration instance definition.
│ ├──config.json // Configuration file for pre-train or finetune.
│ ├──config_test.json // Configuration file for test.
├── src
│ ├──__init__.py // User interface.
│ ├──dataset
│ ├──__init__.py // User interface.
│ ├──base.py // Base class of data loader.
│ ├──bi_data_loader.py // Bilingual data loader.
│ ├──load_dataset.py // Dataset loader to feed into model.
│ ├──schema.py // Define schema of mindrecord.
│ ├──tokenizer.py // Tokenizer class.
│ ├──gnmt_model
│ ├──__init__.py // User interface.
│ ├──attention.py // Bahdanau attention mechanism.
│ ├──beam_search.py // Beam search decoder for inferring.
│ ├──bleu_calculate.py // Calculate the BLEU score.
│ ├──components.py // Components.
│ ├──create_attention.py // Recurrent attention.
│ ├──create_attn_padding.py // Create attention paddings from input paddings.
│ ├──decoder.py // GNMT decoder component.
│ ├──decoder_beam_infer.py // GNMT decoder component for beam search.
│ ├──dynamic_rnn.py // DynamicRNN.
│ ├──embedding.py // Embedding component.
│ ├──encoder.py // GNMT encoder component.
│ ├──gnmt.py // GNMT model architecture.
│ ├──gnmt_for_infer.py // Use GNMT to infer.
│ ├──gnmt_for_train.py // Use GNMT to train.
│ ├──grad_clip.py // Gradient clip
│ ├──utils
│ ├──__init__.py // User interface.
│ ├──initializer.py // Parameters initializer.
│ ├──load_weights.py // Load weights from a checkpoint or NPZ file.
│ ├──loss_moniter.py // Callback for monitoring the loss during training.
│ ├──lr_scheduler.py // Learning rate scheduler.
│ ├──optimizer.py // Optimizer.
├── scripts
│ ├──run_distributed_train_ascend.sh // Shell script for distributed training on Ascend.
│ ├──run_standalone_eval_ascend.sh // Shell script for standalone evaluation on Ascend.
│ ├──run_standalone_train_ascend.sh // Shell script for standalone training on Ascend.
├── create_dataset.py // dataset preparation.
├── eval.py // Infer API entry.
├── requirements.txt // Requirements of third party package.
├── train.py // Train API entry.
```
## Dataset Preparation
You may use this [shell script](https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Translation/GNMT/scripts/wmt16_en_de.sh) to download and preprocess the WMT English-German dataset. After preprocessing, you should have the following files:
- train.tok.clean.bpe.32000.en
- train.tok.clean.bpe.32000.de
- vocab.bpe.32000
- bpe.32000
- newstest2014.en
- newstest2014.de
Convert the original data to TFRecord format for training and evaluation:
```bash
python create_dataset.py --src_folder /home/workspace/wmt16_de_en --output_folder /home/workspace/dataset_menu
```
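After the conversion finishes, you can sanity-check the generated TFRecord and schema files with the repository's own loader. The paths below assume the example command above; adjust them to your output folder and run the snippet from the `gnmt_v2/` directory so that `src` is importable.
```python
# Minimal check that the converted training data loads, using src/dataset/load_dataset.py
# from this repository. Paths are the example outputs of the command above.
from src.dataset import load_dataset

ds = load_dataset(
    data_files=["/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001"],
    schema="/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",
    batch_size=128,
    epoch_count=1,
    sink_mode=False,
)
print("Batches per epoch:", ds.get_dataset_size())
```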
## Configuration File
The JSON file in the `config/` directory is the template configuration file.
Almost all required options and parameters can be easily assigned, including the training platform, the dataset and model configuration, and the optimizer parameters. You can also enable optional features such as loss scaling and checkpointing by setting the corresponding options.
For more information about attributes, see the `config/config.py` file.
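If you want to inspect the effective configuration in Python, the `GNMTConfig` class from `config/config.py` can load the JSON template directly. The path below is an example.
```python
# Load the configuration template and print a few fields (run from the gnmt_v2 directory).
from config import GNMTConfig

cfg = GNMTConfig.from_json_file("/home/workspace/gnmt_v2/config/config.json")
print(cfg.batch_size, cfg.hidden_size, cfg.optimizer, cfg.lr)
print(cfg.pre_train_dataset)  # resolved to a list of dataset files by get_source_list()
```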
## Training Process
Model training uses the shell script `scripts/run_standalone_train_ascend.sh`, which sets the environment variables and runs the training script `train.py` under `gnmt_v2/`.
To train on a single device, run the following command in bash:
```bash
cd ./scripts
sh run_standalone_train_ascend.sh
```
To train on multiple devices, run the following command in bash from `scripts/`:
```bash
cd ./scripts
sh run_distributed_train_ascend.sh
```
Note: Ensure that the HCCL configuration file (`hccl_json`) is specified when running distributed training.
Currently, inconsecutive device IDs are not supported in `scripts/run_distributed_train_ascend.sh`; the device IDs in the `distribute_script/rank_table_8p.json` file must be consecutive and start from 0.
## Evaluation Process
Set options in `config/config_test.json`. Make sure `existed_ckpt`, `dataset_schema`, and `test_dataset` are set to your own paths.
Run `scripts/run_standalone_eval_ascend.sh` to run inference and convert the output token IDs into a BLEU score.
```bash
cd ./scripts
sh run_standalone_eval_ascend.sh
```
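The script above drives `eval.py`, which in turn uses `src/gnmt_model/bleu_calculate.py` to score the output. If you only want to understand how a BLEU score is obtained from detokenized text, the following standalone sketch with `sacrebleu` (pinned in `requirements.txt`) illustrates the idea; the file names are placeholders, not outputs of the repository scripts.
```python
# Standalone BLEU example with sacrebleu; hypothesis/reference file names are placeholders.
import sacrebleu

with open("hypotheses.detok.txt", encoding="utf-8") as hyp_file, \
        open("newstest2014.de", encoding="utf-8") as ref_file:
    hypotheses = [line.strip() for line in hyp_file]
    references = [line.strip() for line in ref_file]

bleu = sacrebleu.corpus_bleu(hypotheses, [references])
print(f"BLEU = {bleu.score:.2f}")
```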
# Model Description
## Performance
### Result
#### Training Performance
| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource | Ascend 910 |
| Uploaded Date | 11/06/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | WMT English-German |
| Training Parameters | epoch=6, batch_size=128 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| BLEU Score | 24.05 |
| Speed | 344ms/step (8pcs) |
| Loss | 63.35 |
| Params (M) | 613 |
| Checkpoint for inference | 1.8G (.ckpt file) |
| Scripts | [gnmt_v2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2) |
#### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Resource | Ascend 910 |
| Uploaded Date | 11/06/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | WMT newstest2014 |
| batch_size | 128 |
| outputs | BLEU score |
| Accuracy | BLEU= 24.05 |
## Practice
The overall workflow for performing a text translation task with GNMT v2 is as follows:
1. Download and extract the WMT16 corpus. For details, see the "Dataset" section above.
2. Dataset preprocessing.
3. Perform training.
4. Perform inference.
### Dataset Preprocessing
Preprocess the dataset by running the following command:
```bash
python create_dataset.py --src_folder /home/work_space/wmt16_de_en --output_folder /home/work_space/dataset_menu
```
### Training
Before training, configure the following options in the `config/config.json` file:
- Set `pre_train_dataset` and `dataset_schema` to the training dataset path and the schema file path.
- Select an optimizer (`momentum`, `adam`, or `lamb`).
- Specify `ckpt_prefix` and `ckpt_path` under `checkpoint_options` to control where model files are saved.
- Set other parameters, including dataset configuration and network configuration.
- If a pre-trained model exists, assign `existed_ckpt` to the path of the existing model during fine-tuning.
Run the shell script `run_standalone_train_ascend.sh`:
```bash
cd ./scripts
sh run_standalone_train_ascend.sh
```
### Inference
For inference using a trained model on multiple hardware platforms, such as GPU, Ascend 910, and Ascend 310, see [Network Migration](https://www.mindspore.cn/tutorial/en/master/advanced_use/network_migration.html).
For inference, configure the following options in the `config/config_test.json` file:
- Set `test_dataset` and `dataset_schema` to the inference dataset path and the schema file path.
- Set `existed_ckpt` to the path of the checkpoint file generated during training.
- Set other parameters, including dataset configuration and network configuration.
Run the shell script `run_standalone_eval_ascend.sh`:
```bash
cd ./scripts
sh run_standalone_eval_ascend.sh
```
# Random Situation Description
There are three random situations:
- Shuffle of the dataset.
- Initialization of some model weights.
- Dropout operations.
Some seeds have already been set in `train.py` to avoid the randomness of dataset shuffling and weight initialization, as sketched below. If you want to disable dropout, set the corresponding dropout probability parameters to 0 in `config/config.json`.
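The seeding typically looks like the following sketch. The exact calls in `train.py` are not shown in this excerpt, so treat this as an illustrative assumption rather than the repository's code.
```python
# Illustrative seeding for reproducibility; SEED mirrors "random_seed" in config/config.json.
import random

import numpy as np
import mindspore.dataset as ds

SEED = 50
random.seed(SEED)
np.random.seed(SEED)
ds.config.set_seed(SEED)  # controls dataset shuffling order in MindSpore
```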
# Others
This model has been validated in the Ascend environment only; it has not been validated on CPU or GPU.
# ModelZoo Homepage
[Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo)


@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GNMTv2 model configuration."""
from .config import GNMTConfig
__all__ = [
"GNMTConfig"
]


@ -0,0 +1,55 @@
{
"training_platform": {
"modelarts": false
},
"dataset_config": {
"random_seed": 50,
"epochs": 6,
"batch_size": 128,
"dataset_schema": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",
"pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001",
"fine_tune_dataset": null,
"test_dataset": null,
"valid_dataset": null,
"dataset_sink_mode": true,
"dataset_sink_step": 2
},
"model_config": {
"seq_length": 51,
"vocab_size": 32320,
"hidden_size": 1024,
"num_hidden_layers": 4,
"intermediate_size": 4096,
"hidden_dropout_prob": 0.2,
"attention_dropout_prob": 0.2,
"initializer_range": 0.1,
"label_smoothing": 0.1,
"beam_width": 2,
"length_penalty_weight": 0.6,
"max_decode_length": 50
},
"loss_scale_config": {
"init_loss_scale": 65536,
"loss_scale_factor": 2,
"scale_window": 1000
},
"learn_rate_config": {
"optimizer": "adam",
"lr": 2e-3,
"lr_scheduler": "WarmupMultiStepLR",
"lr_scheduler_power": 0.5,
"warmup_lr_remain_steps": 0.666,
"warmup_lr_decay_interval": -1,
"decay_steps": 4,
"decay_start_step": -1,
"warmup_steps": 200,
"min_lr": 1e-6
},
"checkpoint_options": {
"existed_ckpt": "",
"save_ckpt_steps": 3452,
"keep_ckpt_max": 6,
"ckpt_prefix": "gnmt",
"ckpt_path": "text_translation"
}
}


@ -0,0 +1,228 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Configuration class for GNMT."""
import os
import json
import copy
from typing import List
import mindspore.common.dtype as mstype
def _is_dataset_file(file: str):
return "tfrecord" in file.lower() or "mindrecord" in file.lower()
def _get_files_from_dir(folder: str):
_files = []
for file in os.listdir(folder):
if _is_dataset_file(file):
_files.append(os.path.join(folder, file))
return _files
def get_source_list(folder: str) -> List:
"""
Get file list from a folder.
Returns:
list, file list.
"""
_list = []
if not folder:
return _list
if os.path.isdir(folder):
_list = _get_files_from_dir(folder)
else:
if _is_dataset_file(folder):
_list.append(folder)
return _list
PARAM_NODES = {"dataset_config",
"training_platform",
"model_config",
"loss_scale_config",
"learn_rate_config",
"checkpoint_options"}
class GNMTConfig:
"""
Configuration for `GNMT`.
Args:
random_seed (int): Random seed.
batch_size (int): Batch size of input dataset.
epochs (int): Epoch number.
dataset_sink_mode (bool): Whether enable dataset sink mode.
dataset_sink_step (int): Dataset sink step.
lr_scheduler (str): Learning rate scheduler to use; only "ISR" is supported now.
lr (float): Initial learning rate.
min_lr (float): Minimum learning rate.
decay_start_step (int): Step to decay.
warmup_steps (int): Warm up steps.
dataset_schema (str): Path of dataset schema file.
pre_train_dataset (str): Path of pre-training dataset file or folder.
fine_tune_dataset (str): Path of fine-tune dataset file or folder.
test_dataset (str): Path of test dataset file or folder.
valid_dataset (str): Path of validation dataset file or folder.
ckpt_path (str): Checkpoints save path.
save_ckpt_steps (int): Interval of saving ckpt.
ckpt_prefix (str): Prefix of ckpt file.
keep_ckpt_max (int): Max ckpt files number.
seq_length (int): Length of input sequence. Default: 51.
vocab_size (int): Size of the vocabulary. Default: 32320.
hidden_size (int): Size of the embedding, attention and LSTM hidden dimensions. Default: 1024.
num_hidden_layers (int): Number of encoder and decoder layers. Default: 4.
intermediate_size (int): Size of intermediate layer in the Transformer
encoder/decoder cell. Default: 4096.
hidden_act (str): Activation function used in the Transformer encoder/decoder
cell. Default: "relu".
init_loss_scale (int): Initialized loss scale.
loss_scale_factor (int): Loss scale factor.
scale_window (int): Window size of loss scale.
beam_width (int): Beam width for beam search in inferring. Default: 5.
length_penalty_weight (float): Penalty for sentence length. Default: 1.0.
label_smoothing (float): Label smoothing setting. Default: 0.1.
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
dataset. Default: True.
save_graphs (bool): Whether to save graphs. Set it to True if MindInsight is needed.
dtype (mstype): Data type of the input. Default: mstype.float32.
max_decode_length (int): Max decode length for inferring. Default: 50.
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.2.
attention_dropout_prob (float): The dropout probability for the attention mechanism. Default: 0.2.
initializer_range (float): Initialization range for the weight initializers. Default: 0.1.
"""
def __init__(self,
modelarts=False, random_seed=74,
epochs=6, batch_size=64,
dataset_schema: str = None,
pre_train_dataset: str = None,
fine_tune_dataset: str = None,
test_dataset: str = None,
valid_dataset: str = None,
dataset_sink_mode=True, dataset_sink_step=1,
seq_length=51, vocab_size=32320, hidden_size=1024,
num_hidden_layers=4, intermediate_size=4096,
hidden_act="tanh",
hidden_dropout_prob=0.2, attention_dropout_prob=0.2,
initializer_range=0.1,
label_smoothing=0.1,
beam_width=5,
length_penalty_weight=1.0,
max_decode_length=50,
input_mask_from_dataset=False,
init_loss_scale=2 ** 10,
loss_scale_factor=2, scale_window=128,
lr_scheduler="", optimizer="adam",
lr=1e-4, min_lr=1e-6,
decay_steps=4, lr_scheduler_power=1,
warmup_lr_remain_steps=0.666, warmup_lr_decay_interval=-1,
decay_start_step=-1, warmup_steps=200,
existed_ckpt="", save_ckpt_steps=2000, keep_ckpt_max=20,
ckpt_prefix="gnmt", ckpt_path: str = None,
save_step=10000,
save_graphs=False,
dtype=mstype.float32):
self.save_graphs = save_graphs
self.random_seed = random_seed
self.modelarts = modelarts
self.save_step = save_step
self.dataset_schema = dataset_schema
self.pre_train_dataset = get_source_list(pre_train_dataset) # type: List[str]
self.fine_tune_dataset = get_source_list(fine_tune_dataset) # type: List[str]
self.valid_dataset = get_source_list(valid_dataset) # type: List[str]
self.test_dataset = get_source_list(test_dataset) # type: List[str]
if not isinstance(epochs, int) or epochs < 0:
raise ValueError("`epochs` must be a non-negative int.")
self.epochs = epochs
self.dataset_sink_mode = dataset_sink_mode
self.dataset_sink_step = dataset_sink_step
self.ckpt_path = ckpt_path
self.keep_ckpt_max = keep_ckpt_max
self.save_ckpt_steps = save_ckpt_steps
self.ckpt_prefix = ckpt_prefix
self.existed_ckpt = existed_ckpt
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_dropout_prob = attention_dropout_prob
self.initializer_range = initializer_range
self.label_smoothing = label_smoothing
self.beam_width = beam_width
self.length_penalty_weight = length_penalty_weight
self.max_decode_length = max_decode_length
self.input_mask_from_dataset = input_mask_from_dataset
self.compute_type = mstype.float16
self.dtype = dtype
self.scale_window = scale_window
self.loss_scale_factor = loss_scale_factor
self.init_loss_scale = init_loss_scale
self.optimizer = optimizer
self.lr = lr
self.lr_scheduler = lr_scheduler
self.min_lr = min_lr
self.lr_scheduler_power = lr_scheduler_power
self.warmup_lr_remain_steps = warmup_lr_remain_steps
self.warmup_lr_decay_interval = warmup_lr_decay_interval
self.decay_steps = decay_steps
self.decay_start_step = decay_start_step
self.warmup_steps = warmup_steps
self.train_url = ""
@classmethod
def from_dict(cls, json_object: dict):
"""Constructs a `TransformerConfig` from a Python dictionary of parameters."""
_params = {}
for node in PARAM_NODES:
for key in json_object[node]:
_params[key] = json_object[node][key]
return cls(**_params)
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `TransformerConfig` from a json file of parameters."""
with open(json_file, "r") as reader:
return cls.from_dict(json.load(reader))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


@ -0,0 +1,55 @@
{
"training_platform": {
"modelarts": false
},
"dataset_config": {
"random_seed": 50,
"epochs": 6,
"batch_size": 128,
"dataset_schema": "/home/workspace/dataset_menu/newstest2014.en.json",
"pre_train_dataset": null,
"fine_tune_dataset": null,
"test_dataset": "/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001",
"valid_dataset": null,
"dataset_sink_mode": true,
"dataset_sink_step": 2
},
"model_config": {
"seq_length": 107,
"vocab_size": 32320,
"hidden_size": 1024,
"num_hidden_layers": 4,
"intermediate_size": 4096,
"hidden_dropout_prob": 0.2,
"attention_dropout_prob": 0.2,
"initializer_range": 0.1,
"label_smoothing": 0.1,
"beam_width": 2,
"length_penalty_weight": 0.6,
"max_decode_length": 80
},
"loss_scale_config": {
"init_loss_scale": 8192,
"loss_scale_factor": 2,
"scale_window": 128
},
"learn_rate_config": {
"optimizer": "adam",
"lr": 2e-3,
"lr_scheduler": "WarmupMultiStepLR",
"lr_scheduler_power": 0.5,
"warmup_lr_remain_steps": 0.666,
"warmup_lr_decay_interval": -1,
"decay_steps": 4,
"decay_start_step": -1,
"warmup_steps": 200,
"min_lr": 1e-6
},
"checkpoint_options": {
"existed_ckpt": "/home/workspace/gnmt_v2/gnmt-6_3452.ckpt",
"save_ckpt_steps": 3452,
"keep_ckpt_max": 6,
"ckpt_prefix": "gnmt",
"ckpt_path": "text_translation"
}
}


@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create Dataset."""
import os
import argparse
from src.dataset.bi_data_loader import BiLingualDataLoader, TextDataLoader
from src.dataset.tokenizer import Tokenizer
parser = argparse.ArgumentParser(description='Generate dataset file.')
parser.add_argument("--src_folder", type=str, default="/home/workspace/wmt16_de_en", required=False,
help="Raw corpus folder.")
parser.add_argument("--output_folder", type=str, default="/home/workspace/dataset_menu",
required=False,
help="Dataset output path.")
if __name__ == '__main__':
args, _ = parser.parse_known_args()
if not os.path.exists(args.output_folder):
os.makedirs(args.output_folder)
dicts = []
train_src_file = "train.tok.clean.bpe.32000.en"
train_tgt_file = "train.tok.clean.bpe.32000.de"
test_src_file = "newstest2014.en"
test_tgt_file = "newstest2014.de"
vocab = args.src_folder + "/vocab.bpe.32000"
bpe_codes = args.src_folder + "/bpe.32000"
pad_vocab = 8
tokenizer = Tokenizer(vocab, bpe_codes, src_en='en', tgt_de='de', vocab_pad=pad_vocab)
test = TextDataLoader(
src_filepath=os.path.join(args.src_folder, test_src_file),
tokenizer=tokenizer,
source_max_sen_len=None,
schema_address=args.output_folder + "/" + test_src_file + ".json"
)
print(f" | It's writing, please wait a moment.")
test.write_to_tfrecord(
path=os.path.join(
args.output_folder,
os.path.basename(test_src_file) + ".tfrecord"
)
)
train = BiLingualDataLoader(
src_filepath=os.path.join(args.src_folder, train_src_file),
tgt_filepath=os.path.join(args.src_folder, train_tgt_file),
tokenizer=tokenizer,
source_max_sen_len=51,
target_max_sen_len=50,
schema_address=args.output_folder + "/" + train_src_file + ".json"
)
print(f" | It's writing, please wait a moment.")
train.write_to_tfrecord(
path=os.path.join(
args.output_folder,
os.path.basename(train_src_file) + ".tfrecord"
)
)
print(f" | Vocabulary size: {tokenizer.vocab_size}.")


@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation api."""
import argparse
import pickle
from mindspore.common import dtype as mstype
from config import GNMTConfig
from src.gnmt_model import infer
from src.gnmt_model.bleu_calculate import bleu_calculate
from src.dataset.tokenizer import Tokenizer
parser = argparse.ArgumentParser(description='gnmt')
parser.add_argument("--config", type=str, required=True,
help="model config json file path.")
parser.add_argument("--vocab", type=str, required=True,
help="Vocabulary to use.")
parser.add_argument("--bpe_codes", type=str, required=True,
help="bpe codes to use.")
parser.add_argument("--test_tgt", type=str, required=False,
default=None,
help="data file of the test target")
parser.add_argument("--output", type=str, required=False,
default="./output.npz",
help="result file path.")
def get_config(config):
config = GNMTConfig.from_json_file(config)
config.compute_type = mstype.float16
config.dtype = mstype.float32
return config
if __name__ == '__main__':
args, _ = parser.parse_known_args()
_config = get_config(args.config)
result = infer(_config)
with open(args.output, "wb") as f:
pickle.dump(result, f, 1)
result_npy_addr = args.output
vocab = args.vocab
bpe_codes = args.bpe_codes
test_tgt = args.test_tgt
tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')
scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
print(f"BLEU scores is :{scores}")


@ -0,0 +1,6 @@
nltk
jieba
numpy
subword-nmt==0.3.7
sacrebleu==1.2.10
sacremoses==0.0.19


@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_TABLE_FILE=/home/workspace/rank_table_8p.json
export MINDSPORE_HCCL_CONFIG_PATH=/home/workspace/rank_table_8p.json
echo $RANK_TABLE_FILE
export RANK_SIZE=8
for((i=0;i<=7;i++));
do
rm -rf ${current_exec_path}/device$i
mkdir ${current_exec_path}/device$i
cd ${current_exec_path}/device$i || exit
cp ../../*.py .
cp ../../*.sh .
cp -r ../../src .
cp -r ../../config .
export RANK_ID=$i
export DEVICE_ID=$i
python ../../train.py --config /home/workspace/gnmt_v2/config/config.json > log_gnmt_network${i}.log 2>&1 &
cd ${current_exec_path} || exit
done
cd ${current_exec_path} || exit


@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_NUM=1
export DEVICE_ID=5
export RANK_ID=0
export RANK_SIZE=1
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cp -r ../config ./eval
cd ./eval || exit
echo "start eval for device $DEVICE_ID"
env > env.log
python eval.py --config /home/workspace/gnmt_v2/config/config_test.json --vocab /home/workspace/wmt16_de_en/vocab.bpe.32000 --bpe_codes /home/workspace/wmt16_de_en/bpe.32000 --test_tgt /home/workspace/wmt16_de_en/newstest2014.de >log_infer.log 2>&1 &
cd ..


@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_NUM=1
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cp -r ../config ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --config /home/workspace/gnmt_v2/config/config.json > log_gnmt_network.log 2>&1 &
cd ..


@ -0,0 +1,29 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GNMTv2 Init."""
from .dataset import load_dataset
from .dataset import bi_data_loader
from .gnmt_model import GNMT, infer, GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
from .gnmt_model import LabelSmoothedCrossEntropyCriterion
__all__ = [
"load_dataset",
"bi_data_loader",
"GNMT",
"infer",
"GNMTNetworkWithLoss",
"GNMTTrainOneStepWithLossScaleCell",
"LabelSmoothedCrossEntropyCriterion"
]


@ -0,0 +1,25 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset Init."""
from .bi_data_loader import BiLingualDataLoader, TextDataLoader
from .load_dataset import load_dataset
from .tokenizer import Tokenizer
__all__ = [
"load_dataset",
"BiLingualDataLoader",
"TextDataLoader",
"Tokenizer"
]


@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base class of data loader."""
import os
import collections
import numpy as np
from mindspore.mindrecord import FileWriter
from .schema import SCHEMA
class DataLoader:
"""Data loader for dataset."""
_SCHEMA = SCHEMA
def __init__(self):
self._examples = []
def _load(self):
raise NotImplementedError
def padding(self, sen, padding_idx, need_sentence_len=None, dtype=np.int64):
"""Padding <pad> to sentence."""
if need_sentence_len is None:
return None
if sen.shape[0] > need_sentence_len:
return None
new_sen = np.array([padding_idx] * need_sentence_len, dtype=dtype)
new_sen[:sen.shape[0]] = sen[:]
return new_sen
def write_to_mindrecord(self, path, shard_num=1, desc=""):
"""
Write mindrecord file.
Args:
path (str): File path.
shard_num (int): Shard num.
desc (str): Description.
"""
if not os.path.isabs(path):
path = os.path.abspath(path)
writer = FileWriter(file_name=path, shard_num=shard_num)
writer.add_schema(self._SCHEMA, desc)
if not self._examples:
self._load()
writer.write_raw_data(self._examples)
writer.commit()
print(f"| Wrote to {path}.")
def write_to_tfrecord(self, path, shard_num=1):
"""
Write to tfrecord.
Args:
path (str): Output file path.
shard_num (int): Shard num.
"""
import tensorflow as tf
if not os.path.isabs(path):
path = os.path.abspath(path)
output_files = []
for i in range(shard_num):
output_file = path + "-%03d-of-%03d" % (i + 1, shard_num)
output_files.append(output_file)
# create writers
writers = []
for output_file in output_files:
writers.append(tf.io.TFRecordWriter(output_file))
if not self._examples:
self._load()
# create feature
features = collections.OrderedDict()
for example in self._examples:
for key in example:
features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=example[key].tolist()))
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
for writer in writers:
writer.write(tf_example.SerializeToString())
for writer in writers:
writer.close()
for p in output_files:
print(f" | Write to {p}.")
def _add_example(self, example):
self._examples.append(example)


@ -0,0 +1,233 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bilingual data loader."""
import numpy as np
from .base import DataLoader
from .tokenizer import Tokenizer
class BiLingualDataLoader(DataLoader):
"""Loader for bilingual data."""
def __init__(self,
src_filepath: str,
tgt_filepath: str,
tokenizer: Tokenizer,
min_sen_len=0,
source_max_sen_len=None,
target_max_sen_len=80,
schema_address=None):
super(BiLingualDataLoader, self).__init__()
self._src_filepath = src_filepath
self._tgt_filepath = tgt_filepath
self.tokenizer = tokenizer
self.min_sen_len = min_sen_len
self.source_max_sen_len = source_max_sen_len
self.target_max_sen_len = target_max_sen_len
self.schema_address = schema_address
def _load(self):
count = 0
if self.source_max_sen_len is None:
with open(self._src_filepath, "r") as _src_file:
print(f" | count the max_sen_len of corpus {self._src_filepath}.")
max_src = 0
for _, _pair in enumerate(_src_file):
src_tokens = [
int(self.tokenizer.tok2idx[t])
for t in _pair.strip().split(" ") if t
]
src_len = len(src_tokens)
if src_len > max_src:
max_src = src_len
self.source_max_sen_len = max_src + 2
if self.target_max_sen_len is None:
with open(self._tgt_filepath, "r") as _tgt_file:
print(f" | count the max_sen_len of corpus {self._tgt_filepath}.")
max_tgt = 0
for _, _pair in enumerate(_tgt_file):
src_tokens = [
int(self.tokenizer.tok2idx[t])
for t in _pair.strip().split(" ") if t
]
tgt_len = len(src_tokens)
if tgt_len > max_tgt:
max_tgt = tgt_len
self.target_max_sen_len = max_tgt + 1
with open(self._src_filepath, "r") as _src_file:
print(f" | Processing corpus {self._src_filepath}.")
print(f" | Processing corpus {self._tgt_filepath}.")
with open(self._tgt_filepath, "r") as _tgt_file:
for _, _pair in enumerate(zip(_src_file, _tgt_file)):
src_tokens = [
int(self.tokenizer.tok2idx[t])
for t in _pair[0].strip().split(" ") if t
]
tgt_tokens = [
int(self.tokenizer.tok2idx[t])
for t in _pair[1].strip().split(" ") if t
]
src_tokens.insert(0, self.tokenizer.bos_index)
src_tokens.append(self.tokenizer.eos_index)
tgt_tokens.insert(0, self.tokenizer.bos_index)
tgt_tokens.append(self.tokenizer.eos_index)
src_tokens = np.array(src_tokens)
tgt_tokens = np.array(tgt_tokens)
src_len = src_tokens.shape[0]
tgt_len = tgt_tokens.shape[0]
if (src_len > self.source_max_sen_len) or (src_len < self.min_sen_len) or (
tgt_len > (self.target_max_sen_len + 1)) or (tgt_len < self.min_sen_len):
print(f"+++++ delete! src_len={src_len}, tgt_len={tgt_len - 1}, "
f"source_max_sen_len={self.source_max_sen_len},"
f"target_max_sen_len={self.target_max_sen_len}")
continue
# encoder inputs
encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)
src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
for i in range(src_len):
src_padding[i] = 1
src_length = np.array([src_len], dtype=np.int64)
# decoder inputs
decoder_input = self.padding(tgt_tokens[:-1], self.tokenizer.padding_index, self.target_max_sen_len)
# decoder outputs
decoder_output = self.padding(tgt_tokens[1:], self.tokenizer.padding_index, self.target_max_sen_len)
tgt_padding = np.zeros(shape=self.target_max_sen_len + 1, dtype=np.int64)
for j in range(tgt_len):
tgt_padding[j] = 1
tgt_padding = tgt_padding[1:]
decoder_input = np.array(decoder_input, dtype=np.int64)
decoder_output = np.array(decoder_output, dtype=np.int64)
tgt_padding = np.array(tgt_padding, dtype=np.int64)
example = {
"src": encoder_input,
"src_padding": src_padding,
"src_length": src_length,
"prev_opt": decoder_input,
"target": decoder_output,
"tgt_padding": tgt_padding
}
self._add_example(example)
count += 1
print(f" | source padding_len = {self.source_max_sen_len}.")
print(f" | target padding_len = {self.target_max_sen_len}.")
print(f" | Total activate sen = {count}.")
print(f" | Total sen = {count}.")
if self.schema_address is not None:
provlist = [count, self.source_max_sen_len, self.source_max_sen_len, 1,
self.target_max_sen_len, self.target_max_sen_len, self.target_max_sen_len]
columns = ["src", "src_padding", "src_length", "prev_opt", "target", "tgt_padding"]
with open(self.schema_address, "w", encoding="utf-8") as f:
f.write("{\n")
f.write(' "datasetType":"TF",\n')
f.write(' "numRows":%s,\n' % provlist[0])
f.write(' "columns":{\n')
t = 1
for name in columns:
f.write(' "%s":{\n' % name)
f.write(' "type":"int64",\n')
f.write(' "rank":1,\n')
f.write(' "shape":[%s]\n' % provlist[t])
f.write(' }')
if t < len(columns):
f.write(',')
f.write('\n')
t += 1
f.write(' }\n}\n')
print(" | Write to " + self.schema_address)
class TextDataLoader(DataLoader):
"""Loader for text data."""
def __init__(self,
src_filepath: str,
tokenizer: Tokenizer,
min_sen_len=0,
source_max_sen_len=None,
schema_address=None):
super(TextDataLoader, self).__init__()
self._src_filepath = src_filepath
self.tokenizer = tokenizer
self.min_sen_len = min_sen_len
self.source_max_sen_len = source_max_sen_len
self.schema_address = schema_address
def _load(self):
count = 0
if self.source_max_sen_len is None:
with open(self._src_filepath, "r") as _src_file:
print(f" | count the max_sen_len of corpus {self._src_filepath}.")
max_src = 0
for _, _pair in enumerate(_src_file):
src_tokens = self.tokenizer.tokenize(_pair)
src_len = len(src_tokens)
if src_len > max_src:
max_src = src_len
self.source_max_sen_len = max_src
with open(self._src_filepath, "r") as _src_file:
print(f" | Processing corpus {self._src_filepath}.")
for _, _pair in enumerate(_src_file):
src_tokens = self.tokenizer.tokenize(_pair)
src_len = len(src_tokens)
src_tokens = np.array(src_tokens)
# encoder inputs
encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)
src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
for i in range(src_len):
src_padding[i] = 1
src_length = np.array([src_len], dtype=np.int64)
example = {
"src": encoder_input,
"src_padding": src_padding,
"src_length": src_length
}
self._add_example(example)
count += 1
print(f" | source padding_len = {self.source_max_sen_len}.")
print(f" | Total activate sen = {count}.")
print(f" | Total sen = {count}.")
if self.schema_address is not None:
provlist = [count, self.source_max_sen_len, self.source_max_sen_len, 1]
columns = ["src", "src_padding", "src_length"]
with open(self.schema_address, "w", encoding="utf-8") as f:
f.write("{\n")
f.write(' "datasetType":"TF",\n')
f.write(' "numRows":%s,\n' % provlist[0])
f.write(' "columns":{\n')
t = 1
for name in columns:
f.write(' "%s":{\n' % name)
f.write(' "type":"int64",\n')
f.write(' "rank":1,\n')
f.write(' "shape":[%s]\n' % provlist[t])
f.write(' }')
if t < len(columns):
f.write(',')
f.write('\n')
t += 1
f.write(' }\n}\n')
print(" | Write to " + self.schema_address)


@ -0,0 +1,147 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset loader to feed into model."""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC
def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
sink_mode=False, sink_step=1, rank_size=1, rank_id=0, shuffle=True,
drop_remainder=True, is_translate=False):
"""
Load dataset according to passed in params.
Args:
input_files (list): Data files.
schema_file (str): Schema file path.
batch_size (int): Batch size.
epoch_count (int): Epoch count.
sink_mode (bool): Whether enable sink mode.
sink_step (int): Step to sink.
rank_size (int): Rank size.
rank_id (int): Rank id.
shuffle (bool): Whether shuffle dataset.
drop_remainder (bool): Whether drop the last possibly incomplete batch.
is_translate (bool): Whether translate the text.
Returns:
Dataset, dataset instance.
"""
if not input_files:
raise FileNotFoundError("Require at least one dataset.")
if not (schema_file and
os.path.exists(schema_file)
and os.path.isfile(schema_file)
and os.path.basename(schema_file).endswith(".json")):
raise FileNotFoundError("`dataset_schema` must be a existed json file.")
if not isinstance(sink_mode, bool):
raise ValueError("`sink` must be type of bool.")
for datafile in input_files:
print(f" | Loading {datafile}.")
if not is_translate:
ds = de.TFRecordDataset(
input_files, schema_file,
columns_list=[
"src", "src_padding", "src_length",
"prev_opt",
"target", "tgt_padding"
],
shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
shard_equal_rows=True, num_parallel_workers=8)
ori_dataset_size = ds.get_dataset_size()
print(f" | Dataset size: {ori_dataset_size}.")
type_cast_op = deC.TypeCast(mstype.int32)
ds = ds.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="src_length", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="prev_opt", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="target", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="tgt_padding", operations=type_cast_op, num_parallel_workers=8)
ds = ds.rename(
input_columns=["src",
"src_padding",
"src_length",
"prev_opt",
"target",
"tgt_padding"],
output_columns=["source_eos_ids",
"source_eos_mask",
"source_eos_length",
"target_sos_ids",
"target_eos_ids",
"target_eos_mask"]
)
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
else:
ds = de.TFRecordDataset(
input_files, schema_file,
columns_list=[
"src", "src_padding", "src_length"
],
shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
shard_equal_rows=True, num_parallel_workers=8)
ori_dataset_size = ds.get_dataset_size()
print(f" | Dataset size: {ori_dataset_size}.")
type_cast_op = deC.TypeCast(mstype.int32)
ds = ds.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="src_length", operations=type_cast_op, num_parallel_workers=8)
ds = ds.rename(
input_columns=["src",
"src_padding",
"src_length"],
output_columns=["source_eos_ids",
"source_eos_mask",
"source_eos_length"]
)
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
return ds
def load_dataset(data_files: list, schema: str, batch_size: int, epoch_count: int, sink_mode: bool, sink_step: int = 1,
rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False):
"""
Load dataset.
Args:
data_files (list): Data files.
schema (str): Schema file path.
batch_size (int): Batch size.
epoch_count (int): Epoch count.
sink_mode (bool): Whether enable sink mode.
sink_step (int): Step to sink.
rank_size (int): Rank size.
rank_id (int): Rank id.
shuffle (bool): Whether shuffle dataset.
Returns:
Dataset, dataset instance.
"""
return _load_dataset(data_files, schema, batch_size, epoch_count, sink_mode,
sink_step, rank_size, rank_id, shuffle=shuffle,
drop_remainder=drop_remainder, is_translate=is_translate)


@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define schema of mindrecord."""
SCHEMA = {
"src": {"type": "int64", "shape": [-1]},
"src_padding": {"type": "int64", "shape": [-1]},
"src_length": {"type": "int64", "shape": [-1]},
"prev_opt": {"type": "int64", "shape": [-1]},
"target": {"type": "int64", "shape": [-1]},
"tgt_padding": {"type": "int64", "shape": [-1]},
}


@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tokenizer."""
import os
from collections import defaultdict
from functools import partial
import subword_nmt.apply_bpe
import sacremoses
class Tokenizer:
"""
Tokenizer class.
"""
def __init__(self, vocab_address=None, bpe_code_address=None,
src_en='en', tgt_de='de', vocab_pad=8, isolator='@@'):
"""
Constructor for the Tokenizer class.
Args:
vocab_address: vocabulary address.
bpe_code_address: path to the file with bpe codes.
vocab_pad: pads vocabulary to a multiple of 'vocab_pad' tokens.
isolator: tokenization isolator.
"""
self.padding_index = 0
self.unk_index = 1
self.bos_index = 2
self.eos_index = 3
self.pad_word = '<pad>'
self.unk_word = '<unk>'
self.bos_word = '<s>'
self.eos_word = r'<\s>'
self.isolator = isolator
self.init_bpe(bpe_code_address)
self.vocab_establist(vocab_address, vocab_pad)
self.sacremoses_tokenizer = sacremoses.MosesTokenizer(src_en)
self.sacremoses_detokenizer = sacremoses.MosesDetokenizer(tgt_de)
def init_bpe(self, bpe_code_address):
"""Init bpe."""
if (bpe_code_address is not None) and os.path.exists(bpe_code_address):
with open(bpe_code_address, 'r') as f1:
self.bpe = subword_nmt.apply_bpe.BPE(f1)
def vocab_establist(self, vocab_address, vocab_pad):
"""Establish vocabulary."""
if (vocab_address is None) or (not os.path.exists(vocab_address)):
return
vocab_words = [self.pad_word, self.unk_word, self.bos_word, self.eos_word]
with open(vocab_address) as f1:
for sentence in f1:
vocab_words.append(sentence.strip())
vocab_size = len(vocab_words)
padded_vocab_size = (vocab_size + vocab_pad - 1) // vocab_pad * vocab_pad
for idx in range(0, padded_vocab_size - vocab_size):
fil_token = f'filled{idx:04d}'
vocab_words.append(fil_token)
self.vocab_size = len(vocab_words)
self.tok2idx = defaultdict(partial(int, self.unk_index))
for idx, token in enumerate(vocab_words):
self.tok2idx[token] = idx
self.idx2tok = {}
self.idx2tok = defaultdict(partial(str, ","))
for token, idx in self.tok2idx.items():
self.idx2tok[idx] = token
def tokenize(self, sentence):
"""Tokenize sentence."""
tokenized = self.sacremoses_tokenizer.tokenize(sentence, return_str=True)
bpe = self.bpe.process_line(tokenized)
sentence = bpe.strip().split()
inputs = [self.tok2idx[i] for i in sentence]
inputs = [self.bos_index] + inputs + [self.eos_index]
return inputs
def detokenize(self, indexes, gap=' '):
"""Detokenizes single sentence and removes token isolator characters."""
reconstruction_bpe = gap.join([self.idx2tok[idx] for idx in indexes])
reconstruction_bpe = reconstruction_bpe.replace(self.isolator + ' ', '')
reconstruction_bpe = reconstruction_bpe.replace(self.isolator, '')
reconstruction_bpe = reconstruction_bpe.replace(self.bos_word, '')
reconstruction_bpe = reconstruction_bpe.replace(self.eos_word, '')
reconstruction_bpe = reconstruction_bpe.replace(self.unk_word, '')
reconstruction_bpe = reconstruction_bpe.replace(self.pad_word, '')
reconstruction_bpe = reconstruction_bpe.strip()
reconstruction_words = self.sacremoses_detokenizer.detokenize(reconstruction_bpe.split())
return reconstruction_words


@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GNMTv2 Init."""
from config.config import GNMTConfig
from .gnmt import GNMT
from .attention import BahdanauAttention
from .gnmt_for_train import GNMTTraining, LabelSmoothedCrossEntropyCriterion, \
GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
from .gnmt_for_infer import infer
from .bleu_calculate import bleu_calculate
__all__ = [
"infer",
"GNMTTraining",
"LabelSmoothedCrossEntropyCriterion",
"GNMTTrainOneStepWithLossScaleCell",
"GNMTNetworkWithLoss",
"GNMT",
"BahdanauAttention",
"GNMTConfig",
"bleu_calculate"
]


@ -0,0 +1,201 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bahdanau attention block."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.ops.operations as P
from mindspore import nn
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import Uniform
INF = 65504.0
class BahdanauAttention(nn.Cell):
"""
Constructor for the BahdanauAttention.
Args:
is_training (bool): Whether to train.
query_size (int): feature dimension for query.
key_size (int): feature dimension for keys.
num_units (int): internal feature dimension.
normalize (bool): Whether to normalize.
initializer_range: range for uniform initializer parameters.
Returns:
Tensor, shape (N, T, D).
"""
def __init__(self,
is_training,
query_size,
key_size,
num_units,
normalize=False,
initializer_range=0.1,
compute_type=mstype.float16):
super(BahdanauAttention, self).__init__()
self.is_training = is_training
self.mask = None
self.query_size = query_size
self.key_size = key_size
self.normalize = normalize
self.num_units = num_units
self.linear_att = Parameter(Tensor(np.random.uniform(-initializer_range, initializer_range, size=[num_units]),
dtype=mstype.float32), name='linear_att')
if self.normalize:
self.normalize_scalar = Parameter(Tensor(np.array([1.0 / num_units]), dtype=mstype.float32),
name='normalize_scalar')
self.normalize_bias = Parameter(Tensor(np.zeros(num_units), dtype=mstype.float32), name='normalize_bias')
self.transpose = P.Transpose()
self.transpose_orders = (1, 0, 2)
self.shape_op = P.Shape()
self.linear_q = nn.Dense(query_size,
num_units,
has_bias=False,
weight_init=Uniform(initializer_range)).to_float(compute_type)
self.linear_k = nn.Dense(key_size,
num_units,
has_bias=False,
weight_init=Uniform(initializer_range)).to_float(compute_type)
self.expand = P.ExpandDims()
self.tile = P.Tile()
self.norm = nn.Norm(axis=-1)
self.mul = P.Mul()
self.matmul = P.MatMul()
self.batchMatmul = P.BatchMatMul()
self.tanh = nn.Tanh()
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.softmax = nn.Softmax(axis=-1)
self.reshape = P.Reshape()
self.cast = P.Cast()
def construct(self, query, keys, attention_mask=None):
"""
Construct attention block.
Args:
query (Tensor): Shape (t_q, N, D).
keys (Tensor): Shape (t_k, N, D).
attention_mask: Shape(N, t_k).
Returns:
Tensor, shape (N, t_q, D).
"""
# (t_k, N, D) -> (N, t_k, D).
keys = self.transpose(keys, self.transpose_orders)
# (t_q, N, D) -> (N, t_q, D).
query = self.transpose(query, self.transpose_orders)
query_shape = self.shape_op(query)
b = query_shape[0]
t_q = query_shape[1]
t_k = self.shape_op(keys)[1]
# (N, t_q, D)
query = self.reshape(query, (b * t_q, self.query_size))
if self.is_training:
query = self.cast(query, mstype.float16)
processed_query = self.linear_q(query)
        if self.is_training:
processed_query = self.cast(processed_query, mstype.float32)
processed_query = self.reshape(processed_query, (b, t_q, self.num_units))
# (N, t_k, D)
keys = self.reshape(keys, (b * t_k, self.key_size))
if self.is_training:
keys = self.cast(keys, mstype.float16)
processed_key = self.linear_k(keys)
        if self.is_training:
processed_key = self.cast(processed_key, mstype.float32)
processed_key = self.reshape(processed_key, (b, t_k, self.num_units))
# scores: (N T_q T_k)
scores = self.calc_score(processed_query, processed_key)
# attention_mask: (N, T_k)
mask = attention_mask
        # expand mask to (N, 1, t_k) and tile to (N, t_q, t_k)
if mask is not None:
mask = 1.0 - mask
mask = self.tile(self.expand(mask, 1), (1, t_q, 1))
scores += mask * (-INF)
# [b, t_q, t_k]
scores_normalized = self.softmax(scores)
keys = self.reshape(keys, (b, t_k, self.key_size))
if self.is_training:
keys = self.cast(keys, mstype.float16)
scores_normalized_fp16 = self.cast(scores_normalized, mstype.float16)
else:
scores_normalized_fp16 = scores_normalized
# (b, t_q, n)
context_attention = self.batchMatmul(scores_normalized_fp16, keys)
# [t_q,b,D]
context_attention = self.transpose(context_attention, self.transpose_orders)
if self.is_training:
context_attention = self.cast(context_attention, mstype.float32)
return context_attention, scores_normalized
def calc_score(self, att_query, att_keys):
"""
        Calculate the Bahdanau score.
        Args:
            att_query: (N, T_q, D).
            att_keys: (N, T_k, D).
        Returns:
            scores: (N, T_q, T_k).
"""
b, t_k, n = self.shape_op(att_keys)
t_q = self.shape_op(att_query)[1]
# (b, t_q, t_k, n)
att_query = self.tile(self.expand(att_query, 2), (1, 1, t_k, 1))
att_keys = self.tile(self.expand(att_keys, 1), (1, t_q, 1, 1))
# (b, t_q, t_k, n)
sum_qk = att_query + att_keys
if self.normalize:
# (b, t_q, t_k, n)
sum_qk = sum_qk + self.normalize_bias
linear_att = self.linear_att / self.norm(self.linear_att)
linear_att = self.cast(linear_att, mstype.float32)
linear_att = self.mul(linear_att, self.normalize_scalar)
else:
linear_att = self.linear_att
linear_att = self.expand(linear_att, -1)
sum_qk = self.reshape(sum_qk, (-1, n))
tanh_sum_qk = self.tanh(sum_qk)
if self.is_training:
linear_att = self.cast(linear_att, mstype.float16)
tanh_sum_qk = self.cast(tanh_sum_qk, mstype.float16)
out = self.matmul(tanh_sum_qk, linear_att)
# (b, t_q, t_k)
out = self.reshape(out, (b, t_q, t_k))
if self.is_training:
out = self.cast(out, mstype.float32)
return out
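# --- Illustrative sketch (not part of the original file) ---
# A NumPy reference of the additive (Bahdanau) score computed in calc_score above,
# assuming already-projected queries and keys; numpy is imported as np at the top of
# this module. It only clarifies the broadcasting and is not used by the model.
def _bahdanau_score_reference(att_query, att_keys, linear_att):
    # att_query: (b, t_q, n), att_keys: (b, t_k, n), linear_att: (n,)
    sum_qk = att_query[:, :, None, :] + att_keys[:, None, :, :]  # (b, t_q, t_k, n)
    return np.tanh(sum_qk) @ linear_att                          # (b, t_q, t_k)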

View File

@@ -0,0 +1,421 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Beam search decoder."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
INF = 65536.0
class LengthPenalty(nn.Cell):
"""
Length penalty.
Args:
weight (float): The length penalty weight.
compute_type (mstype): Mindspore data type. Default: mstype.float32.
"""
def __init__(self, weight=1.0, compute_type=mstype.float32):
super(LengthPenalty, self).__init__()
self.weight = weight
self.add = P.TensorAdd()
self.pow = P.Pow()
self.div = P.RealDiv()
self.five = Tensor(5.0, mstype.float32)
self.six = Tensor(6.0, mstype.float32)
self.cast = P.Cast()
def construct(self, length_tensor):
"""
        Compute the length penalty.
Inputs:
length_tensor (Tensor): the input tensor.
Returns:
Tensor, after punishment of length.
"""
length_tensor = self.cast(length_tensor, mstype.float32)
output = self.add(length_tensor, self.five)
output = self.div(output, self.six)
output = self.pow(output, self.weight)
return output
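# --- Illustrative sketch (not part of the original file) ---
# LengthPenalty above implements the GNMT length penalty
# lp(Y) = ((5 + |Y|) / 6) ** weight; a plain-Python equivalent for a single length,
# for clarity only.
def _length_penalty_reference(length, weight=1.0):
    return ((5.0 + float(length)) / 6.0) ** weight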
class TileBeam(nn.Cell):
"""
Beam Tile operation.
Args:
        beam_width (int): The number of beams.
compute_type (mstype): Mindspore data type. Default: mstype.float32.
"""
def __init__(self, beam_width, compute_type=mstype.float32):
super(TileBeam, self).__init__()
self.beam_width = beam_width
self.expand = P.ExpandDims()
self.tile = P.Tile()
self.reshape = P.Reshape()
self.shape = P.Shape()
def construct(self, input_tensor):
"""
        Tile the input tensor along the beam dimension.
Inputs:
input_tensor (Tensor): with shape (N, T, D).
Returns:
Tensor, tiled tensor.
"""
shape = self.shape(input_tensor)
        # add a beam dimension
input_tensor = self.expand(input_tensor, 1)
# get tile shape: [1, beam, ...]
# shape = self.shape(input_tensor)
tile_shape = (1,) + (self.beam_width,)
for _ in range(len(shape) - 1):
tile_shape = tile_shape + (1,)
# tile
output = self.tile(input_tensor, tile_shape)
# reshape to [batch*beam, ...]
out_shape = (shape[0] * self.beam_width,) + shape[1:]
output = self.reshape(output, out_shape)
return output
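# --- Illustrative sketch (not part of the original file) ---
# TileBeam turns a (batch, ...) tensor into (batch * beam_width, ...) by repeating
# every example beam_width times in a row; a NumPy equivalent for reference only
# (np is imported at the top of this module).
def _tile_beam_reference(x, beam_width):
    return np.repeat(x, beam_width, axis=0)   # (batch, ...) -> (batch * beam_width, ...)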
class Mod(nn.Cell):
"""
Mod operation.
Args:
compute_type (mstype): Mindspore data type. Default: mstype.float32.
"""
def __init__(self,
compute_type=mstype.float32):
super(Mod, self).__init__()
self.compute_type = compute_type
self.floor_div = P.FloorDiv()
self.sub = P.Sub()
self.multiply = P.Mul()
def construct(self, input_x, input_y):
"""
Get the remainder of input_x and input_y.
Inputs:
            input_x (Tensor): Dividend.
            input_y (Tensor): Divisor.
Returns:
Tensor, remainder.
"""
x = self.floor_div(input_x, input_y)
x = self.multiply(x, input_y)
x = self.sub(input_x, x)
return x
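# --- Illustrative sketch (not part of the original file) ---
# Mod above computes x - floor(x / y) * y, i.e. the remainder of the division;
# for non-negative integers this matches Python's % operator.
def _mod_reference(x, y):
    return x - (x // y) * y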
class BeamSearchDecoder(nn.Cell):
"""
Beam search decoder.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): Length of input sequence.
        vocab_size (int): Size of the vocabulary.
decoder (Cell): The GNMT decoder.
beam_width (int): Beam width for beam search in inferring. Default: 4.
decoder_layers_nums (int): The nums of decoder layers.
length_penalty_weight (float): Penalty for sentence length. Default: 0.6.
max_decode_length (int): Max decode length for inferring. Default: 64.
        sos_id (int): The index of start label <SOS>. Default: 2.
        eos_id (int): The index of end label <EOS>. Default: 3.
compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32.
Returns:
Tensor, predictions output.
"""
def __init__(self,
batch_size,
seq_length,
vocab_size,
decoder,
beam_width=4,
decoder_layers_nums=4,
length_penalty_weight=0.6,
cov_penalty_factor=0.1,
hidden_size=1024,
max_decode_length=64,
sos_id=2,
eos_id=3,
compute_type=mstype.float32):
super(BeamSearchDecoder, self).__init__()
self.encoder_length = seq_length
self.hidden_size = hidden_size
self.batch_size = batch_size
self.vocab_size = vocab_size
self.beam_width = beam_width
self.decoder_layers_nums = decoder_layers_nums
self.length_penalty_weight = length_penalty_weight
self.cov_penalty_factor = cov_penalty_factor
self.max_decode_length = max_decode_length
self.decoder = decoder
self.add = P.TensorAdd()
self.expand = P.ExpandDims()
self.reshape = P.Reshape()
self.shape_flat = (-1,)
self.shape = P.Shape()
self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), mstype.float32)
self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), mstype.float32)
self.select = P.Select()
self.flat_shape = (batch_size, beam_width * vocab_size)
self.topk = P.TopK(sorted=True)
self.floor_div = P.FloorDiv()
self.vocab_size_tensor = Tensor(self.vocab_size, mstype.int32)
self.real_div = P.RealDiv()
self.mod = Mod()
self.equal = P.Equal()
self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), mstype.int32)
beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1])
self.beam_ids = Tensor(beam_ids, mstype.int32)
batch_ids = np.arange(batch_size * beam_width).reshape((batch_size, beam_width)) // beam_width
self.batch_ids = Tensor(batch_ids, mstype.int32)
self.concat = P.Concat(axis=-1)
self.gather_nd = P.GatherNd()
self.start = Tensor(0, dtype=mstype.int32)
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), mstype.int32)
init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
self.init_scores = Tensor(init_scores, mstype.float32)
        self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=bool))  # np.bool is a deprecated alias of bool
self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
self.length_penalty = LengthPenalty(weight=length_penalty_weight)
self.one = Tensor(1, mstype.int32)
self.prob_concat = P.Concat(axis=1)
self.cast = P.Cast()
self.decoder_hidden_state = Tensor(np.zeros([self.decoder_layers_nums, 2,
self.batch_size * self.beam_width,
hidden_size]), mstype.float32)
        self.zeros_scores = Tensor(np.zeros([batch_size, beam_width], dtype=float))  # np.float is a deprecated alias of float
self.active_index = Tensor(np.ones([batch_size, beam_width], dtype=np.int32))
self.init_zeros = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
self.init_ones = Tensor(np.ones([batch_size, beam_width], dtype=np.float32))
self.accu_attn_scores = Tensor(np.zeros([batch_size, beam_width, self.encoder_length], dtype=np.float32))
self.zeros = Tensor([0], mstype.int32)
self.eos_tensor = Tensor(np.full([batch_size, beam_width, beam_width], eos_id), mstype.int32)
self.ones_3d = Tensor(np.full([batch_size, beam_width, self.encoder_length], 1), mstype.float32)
self.neg_inf_3d = Tensor(np.full([batch_size, beam_width, self.encoder_length], -INF), mstype.float32)
self.zeros_3d = Tensor(np.full([batch_size, beam_width, self.encoder_length], 0), mstype.float32)
self.zeros_2d = Tensor(np.full([batch_size * beam_width, self.encoder_length], 0), mstype.int32)
self.argmin = P.ArgMinWithValue(axis=1)
self.reducesum = P.ReduceSum()
self.div = P.Div()
self.shape_op = P.Shape()
self.mul = P.Mul()
self.log = P.Log()
self.less = P.Less()
self.tile = P.Tile()
self.noteq = P.Neg()
self.zeroslike = P.ZerosLike()
self.greater_equal = P.GreaterEqual()
self.sub = P.Sub()
def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None,
state_finished=None):
"""
Beam search one_step output.
Inputs:
cur_input_ids (Tensor): with shape (batch_size * beam_width, 1).
enc_states (Tensor): with shape (batch_size * beam_width, T, D).
enc_attention_mask (Tensor): with shape (batch_size * beam_width, T).
state_log_probs (Tensor): with shape (batch_size, beam_width).
state_seq (Tensor): with shape (batch_size, beam_width, max_decoder_length).
state_length (Tensor): with shape (batch_size, beam_width).
idx (Tensor): with shape ().
decoder_hidden_state (Tensor): with shape (decoder_layer_num, 2, batch_size * beam_width, D).
accu_attn_scores (Tensor): with shape (batchsize, beam_width, seq_length).
state_finished (Tensor): with shape (batch_size, beam_width).
"""
# log_probs, [batch_size * beam_width, 1, V]
log_probs, all_decoder_state, attn = self.decoder(cur_input_ids, enc_states, enc_attention_mask,
decoder_hidden_state)
# consider attention_scores
attn = self.reshape(attn, (-1, self.beam_width, self.encoder_length))
state_finished_attn = self.cast(state_finished, mstype.int32)
attn_mask_0 = self.tile(self.expand(state_finished_attn, 2), (1, 1, self.encoder_length))
attn_mask_0 = self.cast(attn_mask_0, mstype.bool_)
attn_new = self.select(attn_mask_0, self.zeros_3d, attn)
accu_attn_scores = self.add(accu_attn_scores, attn_new)
# log_probs: [batch_size, beam_width, V]
log_probs = self.reshape(log_probs, (-1, self.beam_width, self.vocab_size))
# select topk indices, [batch_size, beam_width, V]
total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1))
# mask finished beams, [batch_size, beam_width]
# t-1 has finished
mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor)
# save the t-1 probability
total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1))
# [batch, beam*vocab]
flat_scores = self.reshape(total_log_probs, (-1, self.beam_width * self.vocab_size))
# select topk, [batch, beam]
topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)
# convert to beam and word indices, [batch, beam]
# beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
# word_indices = self.mod(topk_indices, self.vocab_size_tensor)
# ======================================================================
# replace floor_div and mod op, since these two ops only support fp16 on
# Ascend310, which will cause overflow.
temp = topk_indices
beam_indices = self.zeroslike(topk_indices)
for _ in range(self.beam_width - 1):
temp = self.sub(temp, self.vocab_size_tensor)
res = self.cast(self.greater_equal(temp, 0), mstype.int32)
beam_indices = beam_indices + res
word_indices = topk_indices - beam_indices * self.vocab_size_tensor
# ======================================================================
# mask finished indices, [batch, beam]
beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
word_indices = self.select(state_finished, self.eos_ids, word_indices)
topk_scores = self.select(state_finished, state_log_probs, topk_scores)
# sort according to scores with -inf for finished beams, [batch, beam]
# t ends
tmp_log_probs = self.select(
self.equal(word_indices, self.eos_ids),
self.ninf_tensor,
topk_scores)
_, tmp_indices = self.topk(tmp_log_probs, self.beam_width)
# update, [batch_size, beam_width, 2]
tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1)))
# [batch_size, beam_width]
beam_indices = self.gather_nd(beam_indices, tmp_gather_indices)
word_indices = self.gather_nd(word_indices, tmp_gather_indices)
topk_scores = self.gather_nd(topk_scores, tmp_gather_indices)
# gather indices for selecting alive beams
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1)))
# length add 1 if not finished in the previous step, [batch_size, beam_width]
length_add = self.add(state_length, self.one)
state_length = self.select(state_finished, state_length, length_add)
state_length = self.gather_nd(state_length, gather_indices)
# concat seq
seq = self.gather_nd(state_seq, gather_indices)
# update accu_attn_scores
accu_attn_scores = self.gather_nd(accu_attn_scores, gather_indices)
# update all_decoder_state
all_decoder_state = self.reshape(all_decoder_state,
(self.decoder_layers_nums * 2, self.batch_size, self.beam_width,
self.hidden_size))
for i in range(self.decoder_layers_nums * 2):
all_decoder_state[i, :, :, :] = self.gather_nd(all_decoder_state[i, :, :, :], gather_indices)
all_decoder_state = self.reshape(all_decoder_state,
(self.decoder_layers_nums, 2, self.batch_size * self.beam_width,
self.hidden_size))
# update state_seq
state_seq_new = self.cast(seq, mstype.float32)
word_indices_fp32 = self.cast(word_indices, mstype.float32)
state_seq_new[:, :, idx] = word_indices_fp32
state_seq = self.cast(state_seq_new, mstype.int32)
cur_input_ids = self.reshape(word_indices, (-1, 1))
state_log_probs = topk_scores
state_finished = self.equal(word_indices, self.eos_ids)
return cur_input_ids, state_log_probs, state_seq, state_length, \
all_decoder_state, accu_attn_scores, state_finished
def construct(self, enc_states, enc_attention_mask):
"""
        Run beam search decoding over the encoded source sentence.
Inputs:
enc_states (Tensor): Output of transformer encoder with shape (batch_size * beam_width, T, D).
enc_attention_mask (Tensor): encoder attention mask with shape (batch_size * beam_width, T).
Returns:
Tensor, predictions output.
"""
# beam search start
cur_input_ids = self.start_ids
state_log_probs = self.init_scores
state_seq = self.init_seq
state_finished = self.init_finished
state_length = self.init_length
decoder_hidden_state = self.decoder_hidden_state
accu_attn_scores = self.accu_attn_scores
idx = self.start + 1
ends = self.start + self.max_decode_length + 1
while idx < ends:
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores,
state_finished)
idx = idx + 1
# add length penalty scores
penalty_len = self.length_penalty(state_length)
# return penalty_len
log_probs = self.real_div(state_log_probs, penalty_len)
penalty_cov = C.clip_by_value(accu_attn_scores, 0.0, 1.0)
penalty_cov = self.log(penalty_cov)
penalty_less = self.less(penalty_cov, self.neg_inf_3d)
penalty = self.select(penalty_less, self.zeros_3d, penalty_cov)
penalty = self.reducesum(penalty, 2)
log_probs = log_probs + penalty * self.cov_penalty_factor
# sort according to scores
_, top_beam_indices = self.topk(log_probs, self.beam_width)
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
# sort sequence and attention scores
predicted_ids = self.gather_nd(state_seq, gather_indices)
predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length]
return predicted_ids
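# --- Illustrative sketch (not part of the original file) ---
# The subtract-and-count loop in one_step replaces FloorDiv/Mod (fp16-only on
# Ascend 310) when splitting the flat top-k indices into beam and word indices.
# A NumPy check of the equivalence, assuming 0 <= topk_indices < beam_width * vocab_size
# (np is imported at the top of this module).
def _split_topk_indices_reference(topk_indices, beam_width, vocab_size):
    temp = topk_indices.copy()
    beam_indices = np.zeros_like(topk_indices)
    for _ in range(beam_width - 1):
        temp = temp - vocab_size
        beam_indices = beam_indices + (temp >= 0).astype(np.int32)
    word_indices = topk_indices - beam_indices * vocab_size
    # identical to topk_indices // vocab_size and topk_indices % vocab_size
    return beam_indices, word_indices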

View File

@@ -0,0 +1,93 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Calculate the blue scores"""
import subprocess
import numpy as np
from src.dataset.tokenizer import Tokenizer
def load_result_data(result_npy_addr):
    # Load the pickled prediction results from the .npy file.
result = np.load(result_npy_addr, allow_pickle=True)
return result
def get_bleu_data(tokenizer: Tokenizer, result_npy_addr):
"""
    Detokenize the predictions.
Args:
tokenizer (Tokenizer): tokenizer operations.
result_npy_addr (string): Path to the predict file.
Returns:
List, the predict text context.
"""
result = load_result_data(result_npy_addr)
prediction_list = []
for _, info in enumerate(result):
# prediction detokenize
prediction = info["prediction"]
prediction_str = tokenizer.detokenize(prediction)
prediction_list.append(prediction_str)
return prediction_list
def calculate_sacrebleu(predict_path, target_path):
"""
Calculate the BLEU scores.
Args:
predict_path (string): Path to the predict file.
target_path (string): Path to the target file.
Returns:
Float32, bleu scores.
"""
sacrebleu_params = '--score-only -lc --tokenize intl'
sacrebleu = subprocess.run([f'sacrebleu --input {predict_path} \
{target_path} {sacrebleu_params}'],
stdout=subprocess.PIPE, shell=True)
bleu_scores = round(float(sacrebleu.stdout.strip()), 2)
return bleu_scores
def bleu_calculate(tokenizer, result_npy_addr, target_addr=None):
"""
Calculate the BLEU scores.
Args:
tokenizer (Tokenizer): tokenizer operations.
result_npy_addr (string): Path to the predict file.
target_addr (string): Path to the target file.
Returns:
Float32, bleu scores.
"""
prediction = get_bleu_data(tokenizer, result_npy_addr)
print("predict:\n", prediction)
eval_path = './predict.txt'
with open(eval_path, 'w') as eval_file:
lines = [line + '\n' for line in prediction]
eval_file.writelines(lines)
reference_path = target_addr
bleu_scores = calculate_sacrebleu(eval_path, reference_path)
return bleu_scores
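# --- Illustrative sketch (not part of the original file) ---
# get_bleu_data expects result_npy_addr to point at a pickled list of dicts that
# each carry a "prediction" key (token ids), as written by the infer step; the ids
# below are fabricated placeholders that only show the expected layout
# (np is imported at the top of this module).
def _example_result_file(path="./example_infer_output.npy"):
    example = np.array([{"source": [2, 15, 27, 3], "prediction": [2, 41, 8, 3]}],
                       dtype=object)
    np.save(path, example)          # read back above via np.load(..., allow_pickle=True)
    return path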

View File

@@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Components of model."""
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
class SaturateCast(nn.Cell):
"""Cast wrapper."""
def __init__(self, dst_type=mstype.float32):
super(SaturateCast, self).__init__()
self.cast = P.Cast()
self.dst_type = dst_type
def construct(self, x):
return self.cast(x, self.dst_type)
class LayerNorm(nn.Cell):
"""
Do layer norm.
Args:
in_channels (int): In channels number of layer norm.
return_2d (bool): Whether return 2d tensor.
Returns:
Tensor, output.
"""
def __init__(self, in_channels=None, return_2d=False):
super(LayerNorm, self).__init__()
self.return_2d = return_2d
self.layer_norm = nn.LayerNorm((in_channels,))
self.cast = P.Cast()
self.get_dtype = P.DType()
self.reshape = P.Reshape()
self.get_shape = P.Shape()
def construct(self, input_tensor):
"""Do layer norm."""
shape = self.get_shape(input_tensor)
batch_size = shape[0]
max_len = shape[1]
embed_dim = shape[2]
output = self.reshape(input_tensor, (-1, embed_dim))
output = self.cast(output, mstype.float32)
output = self.layer_norm(output)
output = self.cast(output, self.get_dtype(input_tensor))
if not self.return_2d:
output = self.reshape(output, (batch_size, max_len, embed_dim))
return output

View File

@@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create attention block."""
import mindspore.common.dtype as mstype
from mindspore import nn
from .attention import BahdanauAttention
class RecurrentAttention(nn.Cell):
"""
    Recurrent layer wrapped with Bahdanau attention.
    Args:
        rnn (nn.Cell): The recurrent (LSTM) layer wrapped with attention.
        is_training (bool): Whether to train. Default: True.
        input_size (int): Number of features in the input tensor.
        context_size (int): Number of features in the output from the encoder.
        hidden_size (int): Internal hidden size.
        num_layers (int): Number of layers in the LSTM.
        dropout (float): Probability of dropout (on input to the LSTM layer).
        initializer_range (float): Range for the uniform initializer.
    Returns:
        Tuple[Tensor], rnn outputs, attention outputs, rnn state and attention scores.
"""
def __init__(self,
rnn,
is_training=True,
input_size=1024,
context_size=1024,
hidden_size=1024,
num_layers=1,
dropout=0.2,
initializer_range=0.1):
super(RecurrentAttention, self).__init__()
self.dropout = nn.Dropout(keep_prob=1.0 - dropout)
self.rnn = rnn
self.attn = BahdanauAttention(is_training=is_training,
query_size=hidden_size,
key_size=hidden_size,
num_units=hidden_size,
normalize=True,
initializer_range=initializer_range,
compute_type=mstype.float16)
def construct(self, decoder_embedding, context_key, attention_mask=None, rnn_init_state=None):
# decoder_embedding: [t_q,N,D]
# context: [t_k,N,D]
# attention_mask: [N,t_k]
# [t_q,N,D]
decoder_embedding = self.dropout(decoder_embedding)
rnn_outputs, rnn_state = self.rnn(decoder_embedding, rnn_init_state)
# rnn_outputs:[t_q,b,D], attn_outputs:[t_q,b,D], scores:[b, t_q, t_k], rnn_state:tuple([2,b,D]).
attn_outputs, scores = self.attn(query=rnn_outputs, keys=context_key, attention_mask=attention_mask)
return rnn_outputs, attn_outputs, rnn_state, scores

View File

@@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create attention paddings from input paddings."""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
class CreateAttentionPaddingsFromInputPaddings(nn.Cell):
"""
Create attention mask according to input mask.
Args:
config (GNMTConfig): Config class.
Returns:
Tensor, shape of (N, T, T).
"""
def __init__(self,
config,
is_training=True):
super(CreateAttentionPaddingsFromInputPaddings, self).__init__()
self.is_training = is_training
self.input_mask = None
self.cast = P.Cast()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.batch_matmul = P.BatchMatMul()
self.multiply = P.Mul()
self.shape = P.Shape()
# mask future positions
ones = np.ones(shape=(config.batch_size, config.seq_length, config.seq_length))
self.lower_triangle_mask = Tensor(np.tril(ones), dtype=mstype.float32)
def construct(self, input_mask, mask_future=False):
"""
Construct network.
Args:
input_mask (Tensor): Tensor mask vectors with shape (N, T).
            mask_future (bool): Whether to mask future positions (for decoder training).
Returns:
Tensor, shape of (N, T, T).
"""
input_shape = self.shape(input_mask)
# Add this for infer as the seq_length will increase.
shape_right = (input_shape[0], 1, input_shape[1])
shape_left = input_shape + (1,)
if self.is_training:
input_mask = self.cast(input_mask, mstype.float16)
mask_left = self.reshape(input_mask, shape_left)
mask_right = self.reshape(input_mask, shape_right)
attention_mask = self.batch_matmul(mask_left, mask_right)
if self.is_training:
attention_mask = self.cast(attention_mask, mstype.float32)
if mask_future:
attention_mask = self.multiply(attention_mask, self.lower_triangle_mask)
return attention_mask
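# --- Illustrative sketch (not part of the original file) ---
# The cell above builds the (N, T, T) attention mask as the outer product of the
# (N, T) input mask with itself, optionally multiplied by a lower-triangular matrix
# for the causal case; a NumPy equivalent for reference only (np is imported above).
def _attention_padding_reference(input_mask, mask_future=False):
    mask = input_mask[:, :, None] * input_mask[:, None, :]        # (N, T, T)
    if mask_future:
        mask = mask * np.tril(np.ones(mask.shape[1:]))            # mask future positions
    return mask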

View File

@@ -0,0 +1,145 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Decoder of GNMT."""
import numpy as np
from mindspore.common.initializer import Uniform
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from config.config import GNMTConfig
from .dynamic_rnn import DynamicRNNNet
from .create_attention import RecurrentAttention
class GNMTDecoder(nn.Cell):
"""
    Implementation of the GNMT decoder.
    Args:
        config (GNMTConfig): Configuration of the GNMT network.
        is_training (bool): Whether to train.
        use_one_hot_embeddings (bool): Whether to use one-hot embeddings. Default: False.
        initializer_range (float): Initial range of the uniform initializer. Default: 0.1.
        infer_beam_width (int): Beam width used for inference. Default: 1.
        compute_type (mstype): Mindspore data type. Default: mstype.float16.
Returns:
Tensor, shape of (N, T', D).
"""
def __init__(self,
config: GNMTConfig,
is_training: bool,
use_one_hot_embeddings: bool = False,
initializer_range=0.1,
infer_beam_width=1,
compute_type=mstype.float16):
super(GNMTDecoder, self).__init__()
self.is_training = is_training
self.attn_embed_dim = config.hidden_size
self.num_layers = config.num_hidden_layers
self.hidden_dropout_prob = config.hidden_dropout_prob
self.vocab_size = config.vocab_size
self.seq_length = config.max_decode_length
# batchsize* beam_width for beam_search.
self.batch_size = config.batch_size * infer_beam_width
self.word_embed_dim = config.hidden_size
self.transpose = P.Transpose()
self.transpose_orders = (1, 0, 2)
self.reshape = P.Reshape()
self.concat = P.Concat(axis=-1)
self.state_concat = P.Concat(axis=0)
self.all_decoder_state = Tensor(np.zeros([self.num_layers, 2, self.batch_size, config.hidden_size]),
mstype.float32)
decoder_layers = []
for i in range(0, self.num_layers):
if i == 0:
                # the first layer consumes [T, N, D]
scaler = 1
else:
                # the other layers consume the concatenation with attention, [T, N, 2D]
scaler = 2
layer = DynamicRNNNet(seq_length=self.seq_length,
batchsize=self.batch_size,
word_embed_dim=scaler * self.word_embed_dim,
hidden_size=self.word_embed_dim)
decoder_layers.append(layer)
self.decoder_layers = nn.CellList(decoder_layers)
self.att_rnn = RecurrentAttention(rnn=self.decoder_layers[0],
is_training=is_training,
input_size=self.word_embed_dim,
context_size=self.attn_embed_dim,
hidden_size=self.attn_embed_dim,
num_layers=1,
dropout=config.attention_dropout_prob)
self.dropout = nn.Dropout(keep_prob=1.0 - config.hidden_dropout_prob)
self.classifier = nn.Dense(config.hidden_size,
config.vocab_size,
has_bias=True,
weight_init=Uniform(initializer_range),
bias_init=Uniform(initializer_range)).to_float(compute_type)
self.cast = P.Cast()
self.shape_op = P.Shape()
self.expand = P.ExpandDims()
def construct(self, tgt_embeddings, encoder_outputs, attention_mask=None,
decoder_init_state=None):
"""Decoder."""
# tgt_embeddings: [T',N,D], encoder_outputs: [T,N,D], attention_mask: [N,T].
query_shape = self.shape_op(tgt_embeddings)
if decoder_init_state is None:
hidden_state = self.all_decoder_state
else:
hidden_state = decoder_init_state
# x:[t_q,b,D], attn:[t_q,b,D], scores:[b, t_q, t_k], state_0:[2,b,D].
x, attn, state_0, scores = self.att_rnn(decoder_embedding=tgt_embeddings, context_key=encoder_outputs,
attention_mask=attention_mask, rnn_init_state=hidden_state[0, :, :, :])
x = self.concat((x, attn))
x = self.dropout(x)
decoder_outputs, state_1 = self.decoder_layers[1](x, hidden_state[1, :, :, :])
all_decoder_state = self.state_concat((self.expand(state_0, 0), self.expand(state_1, 0)))
for i in range(2, self.num_layers):
residual = decoder_outputs
decoder_outputs = self.concat((decoder_outputs, attn))
decoder_outputs = self.dropout(decoder_outputs)
# 1st unidirectional layer. encoder_outputs: [T,N,D]
decoder_outputs, decoder_state = self.decoder_layers[i](decoder_outputs, hidden_state[i, :, :, :])
decoder_outputs += residual
all_decoder_state = self.state_concat((all_decoder_state, self.expand(decoder_state, 0)))
# [m, batch_size * beam_width, D]
decoder_outputs = self.reshape(decoder_outputs, (-1, self.word_embed_dim))
if self.is_training:
decoder_outputs = self.cast(decoder_outputs, mstype.float16)
decoder_outputs = self.classifier(decoder_outputs)
if self.is_training:
decoder_outputs = self.cast(decoder_outputs, mstype.float32)
# [m, batch_size * beam_width, V]
decoder_outputs = self.reshape(decoder_outputs, (query_shape[0], query_shape[1], self.vocab_size))
# all_decoder_state:[4,2,b,D]
return decoder_outputs, all_decoder_state, scores

View File

@@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Decoder for beam_search of GNMT."""
import numpy as np
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .embedding import EmbeddingLookup
from .decoder import GNMTDecoder
from .create_attn_padding import CreateAttentionPaddingsFromInputPaddings
from .components import SaturateCast
class PredLogProbs(nn.Cell):
"""
Get log probs.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): The length of sequences.
        width (int): Width of the output logits (the vocabulary size).
        compute_type (mstype): Compute type.
        dtype (mstype): MindSpore output data type.
"""
def __init__(self,
batch_size,
seq_length,
width,
compute_type=mstype.float32,
dtype=mstype.float32):
super(PredLogProbs, self).__init__()
self.batch_size = batch_size
self.seq_length = seq_length
self.width = width
self.compute_type = compute_type
self.dtype = dtype
self.log_softmax = nn.LogSoftmax(axis=-1)
# self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width)
self.cast = P.Cast()
def construct(self, logits):
"""
Calculate the log_softmax.
Inputs:
            logits (Tensor): Decoder output logits with shape (N, 1, V).
        Returns:
            Tensor, the log probabilities with shape (N, 1, V).
"""
log_probs = self.log_softmax(logits)
return log_probs
class BeamDecoderStep(nn.Cell):
"""
    One-step GNMT decoder used by beam search.
    Args:
        config (GNMTConfig): The config of GNMT.
"""
def __init__(self,
config,
use_one_hot_embeddings,
compute_type=mstype.float32):
super(BeamDecoderStep, self).__init__(auto_prefix=True)
self.vocab_size = config.vocab_size
self.word_embed_dim = config.hidden_size
self.embedding_lookup = EmbeddingLookup(
is_training=False,
vocab_size=config.vocab_size,
embed_dim=self.word_embed_dim,
use_one_hot_embeddings=use_one_hot_embeddings)
self.projection = PredLogProbs(
batch_size=config.batch_size * config.beam_width,
seq_length=1,
width=config.vocab_size,
compute_type=config.compute_type)
self.seq_length = config.max_decode_length
self.decoder = GNMTDecoder(config,
is_training=False,
infer_beam_width=config.beam_width)
self.ones_like = P.OnesLike()
self.shape = P.Shape()
self.create_att_paddings_from_input_paddings = CreateAttentionPaddingsFromInputPaddings(config,
is_training=False)
self.expand = P.ExpandDims()
self.multiply = P.Mul()
ones = np.ones(shape=(config.max_decode_length, config.max_decode_length))
self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
self.cast_compute_type = SaturateCast(dst_type=compute_type)
self.transpose = P.Transpose()
self.transpose_orders = (1, 0, 2)
def construct(self, input_ids, enc_states, enc_attention_mask, decoder_hidden_state=None):
"""
Get log probs.
Args:
input_ids: [batch_size * beam_width, m]
enc_states: [batch_size * beam_width, T, D]
enc_attention_mask: [batch_size * beam_width, T]
decoder_hidden_state: [decoder_layers_nums, 2, batch_size * beam_width, hidden_size].
Returns:
Tensor, the log_probs. [batch_size * beam_width, 1, vocabulary_size]
"""
# process embedding. input_embedding: [batch_size * beam_width, m, D], embedding_tables: [V, D]
input_embedding, _ = self.embedding_lookup(input_ids)
input_embedding = self.cast_compute_type(input_embedding)
input_shape = self.shape(input_ids)
input_len = input_shape[1]
# [m, batch_size * beam_width, D]
input_embedding = self.transpose(input_embedding, self.transpose_orders)
enc_states = self.transpose(enc_states, self.transpose_orders)
# decoder_output: [m, batch_size*beam_width, V], scores:[b, t_q, t_k], all_decoder_state:[4,2,b*beam_width,D]
decoder_output, all_decoder_state, scores = self.decoder(input_embedding, enc_states, enc_attention_mask,
decoder_hidden_state)
# [batch_size * beam_width, m, v]
decoder_output = self.transpose(decoder_output, self.transpose_orders)
# take the last step, [batch_size * beam_width, 1, V]
decoder_output = decoder_output[:, (input_len - 1):input_len, :]
# projection and log_prob
log_probs = self.projection(decoder_output)
# [batch_size * beam_width, 1, vocabulary_size]
return log_probs, all_decoder_state, scores

View File

@@ -0,0 +1,111 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DynamicRNN."""
import numpy as np
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
class DynamicRNNCell(nn.Cell):
"""
DynamicRNN Cell.
Args:
        num_setp (int): Length of sentences (number of time steps).
        batch_size (int): Batch size.
        word_embed_dim (int): Input size.
        hidden_size (int): Hidden size.
        initializer_range (float): Initial range. Default: 0.1.
"""
def __init__(self,
num_setp=50,
batch_size=128,
word_embed_dim=1024,
hidden_size=1024,
initializer_range=0.1):
super(DynamicRNNCell, self).__init__()
self.rnn = P.DynamicRNN()
self.num_step = num_setp
self.batch_size = batch_size
self.input_size = word_embed_dim
self.hidden_size = hidden_size
# w
dynamicRNN_w = np.random.uniform(-initializer_range, initializer_range,
size=[self.input_size + self.hidden_size, 4 * self.hidden_size])
self.dynamicRNN_w = Parameter(Tensor(dynamicRNN_w, mstype.float32), name='weight')
# b
dynamicRNN_b = np.random.uniform(-initializer_range, initializer_range, size=[4 * self.hidden_size])
self.dynamicRNN_b = Parameter(Tensor(dynamicRNN_b, mstype.float32), name='bias')
self.dynamicRNN_h = Tensor(np.zeros((1, self.batch_size, self.hidden_size)), mstype.float32)
self.dynamicRNN_c = Tensor(np.zeros((1, self.batch_size, self.hidden_size)), mstype.float32)
self.cast = P.Cast()
def construct(self, x, init_h=None, init_c=None):
w = self.cast(self.dynamicRNN_w, mstype.float16)
b = self.cast(self.dynamicRNN_b, mstype.float16)
if init_h is None or init_c is None:
init_h = self.cast(self.dynamicRNN_h, mstype.float16)
init_c = self.cast(self.dynamicRNN_c, mstype.float16)
out = self.rnn(x, w, b, None, init_h, init_c)
return out[0], out[1], out[2]
class DynamicRNNNet(nn.Cell):
"""
DynamicRNN Network.
Args:
        seq_length (int): Length of sentences (number of time steps).
batchsize (int): Batch size.
word_embed_dim (int): Input size.
hidden_size (int): Hidden size.
"""
def __init__(self,
seq_length=80,
batchsize=128,
word_embed_dim=1024,
hidden_size=1024):
super(DynamicRNNNet, self).__init__()
self.max_length = seq_length
self.hidden_size = hidden_size
self.cast = P.Cast()
self.concat = P.Concat(axis=0)
self.get_shape = P.Shape()
self.net = DynamicRNNCell(num_setp=seq_length,
batch_size=batchsize,
word_embed_dim=word_embed_dim,
hidden_size=hidden_size)
def construct(self, inputs, init_state=None):
"""DynamicRNN Network."""
inputs = self.cast(inputs, mstype.float16)
if init_state is not None:
init_h = self.cast(init_state[0:1, :, :], mstype.float16)
init_c = self.cast(init_state[-1:, :, :], mstype.float16)
out, state_h, state_c = self.net(inputs, init_h, init_c)
else:
out, state_h, state_c = self.net(inputs)
out = self.cast(out, mstype.float32)
state = self.concat((state_h[-1:, :, :], state_c[-1:, :, :]))
state = self.cast(state, mstype.float32)
# out:[T,b,D], state:[2,b,D]
return out, state
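# --- Illustrative sketch (not part of the original file) ---
# DynamicRNN packs the LSTM weights into one [input_size + hidden_size, 4 * hidden_size]
# matrix, as built in DynamicRNNCell above. A single-step NumPy reference of that layout;
# the gate order (i, f, j, o) is assumed here for illustration only and may differ from
# the op's internal ordering (np is imported at the top of this module).
def _lstm_step_reference(x_t, h, c, w, b):
    def sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))
    gates = np.concatenate([x_t, h], axis=-1) @ w + b   # (N, 4 * hidden_size)
    i, f, j, o = np.split(gates, 4, axis=-1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(j)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new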

View File

@@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Word embedding for gnmt."""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
class EmbeddingLookup(nn.Cell):
"""
Embeddings lookup table with a fixed dictionary and size.
Args:
is_training (bool): Whether to train.
vocab_size (int): Size of the dictionary of embeddings.
embed_dim (int): The size of word embedding.
        initializer_range (float): The initialization range of the parameters.
use_one_hot_embeddings (bool): Whether use one-hot embedding. Default: False.
"""
def __init__(self,
is_training,
vocab_size,
embed_dim,
initializer_range=0.1,
use_one_hot_embeddings=False):
super(EmbeddingLookup, self).__init__()
self.is_training = is_training
self.embedding_dim = embed_dim
self.vocab_size = vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
init_weight = np.random.normal(-initializer_range, initializer_range, size=[vocab_size, embed_dim])
self.embedding_table = Parameter(Tensor(init_weight, mstype.float32),
name='embedding_table')
self.expand = P.ExpandDims()
self.gather = P.GatherV2()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.get_shape = P.Shape()
self.cast = P.Cast()
def construct(self, input_ids):
"""
Construct network.
Args:
input_ids (Tensor): A batch of sentences with shape (N, T).
Returns:
Tensor, word embeddings with shape (N, T, D)
"""
_shape = self.get_shape(input_ids) # (N, T).
_batch_size = _shape[0]
_max_len = _shape[1]
if self.is_training:
embedding_table = self.cast(self.embedding_table, mstype.float16)
else:
embedding_table = self.embedding_table
flat_ids = self.reshape(input_ids, (_batch_size * _max_len,))
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
if self.is_training:
one_hot_ids = self.cast(one_hot_ids, mstype.float16)
output_for_reshape = self.array_mul(
one_hot_ids, embedding_table)
else:
output_for_reshape = self.gather(embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, (_batch_size, _max_len, self.embedding_dim))
if self.is_training:
output = self.cast(output, mstype.float32)
embedding_table = self.cast(embedding_table, mstype.float32)
return output, embedding_table
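# --- Illustrative sketch (not part of the original file) ---
# The two lookup paths above are equivalent: multiplying one-hot ids by the
# embedding table selects the same rows as a gather. NumPy check, for reference only
# (np is imported at the top of this module).
def _embedding_paths_agree(vocab_size=6, embed_dim=4):
    table = np.random.randn(vocab_size, embed_dim).astype(np.float32)
    ids = np.array([1, 4, 4, 0])
    gathered = table[ids]                                          # gather path
    one_hot = np.eye(vocab_size, dtype=np.float32)[ids] @ table    # one-hot matmul path
    return np.allclose(gathered, one_hot)                          # True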

View File

@@ -0,0 +1,100 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Encoder of GNMT."""
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from config.config import GNMTConfig
from .dynamic_rnn import DynamicRNNNet
class GNMTEncoder(nn.Cell):
"""
    Implementation of the GNMT encoder.
Args:
config (GNMTConfig): Configuration of GNMT network.
is_training (bool): Whether to train.
compute_type (mstype): Mindspore data type.
Returns:
Tensor, shape of (N, T, D).
"""
def __init__(self,
config: GNMTConfig,
is_training: bool,
compute_type=mstype.float32):
super(GNMTEncoder, self).__init__()
self.input_mask_from_dataset = config.input_mask_from_dataset
self.max_positions = config.seq_length
self.attn_embed_dim = config.hidden_size
self.num_layers = config.num_hidden_layers
self.hidden_dropout_prob = config.hidden_dropout_prob
self.vocab_size = config.vocab_size
self.seq_length = config.seq_length
self.batch_size = config.batch_size
self.word_embed_dim = config.hidden_size
self.transpose = P.Transpose()
self.transpose_orders = (1, 0, 2)
self.reshape = P.Reshape()
self.concat = P.Concat(axis=-1)
encoder_layers = []
for i in range(0, self.num_layers + 1):
if i == 2:
                # this layer consumes the concatenated bidirectional output, [T, N, 2D]
scaler = 2
else:
                # the other layers consume [T, N, D]
scaler = 1
layer = DynamicRNNNet(seq_length=self.seq_length,
batchsize=self.batch_size,
word_embed_dim=scaler * self.word_embed_dim,
hidden_size=self.word_embed_dim)
encoder_layers.append(layer)
self.encoder_layers = nn.CellList(encoder_layers)
self.reverse_v2 = P.ReverseV2(axis=[0])
self.dropout = nn.Dropout(keep_prob=1.0 - config.hidden_dropout_prob)
def construct(self, inputs, source_len, attention_mask=None):
"""Encoder."""
inputs = self.dropout(inputs)
# bidirectional layer, fwd_encoder_outputs: [T,N,D]
fwd_encoder_outputs, _ = self.encoder_layers[0](inputs)
# the input need reverse.
inputs_r = self.reverse_v2(inputs)
bak_encoder_outputs, _ = self.encoder_layers[1](inputs_r)
# the result need reverse.
bak_encoder_outputs = self.reverse_v2(bak_encoder_outputs)
# bi_encoder_outputs: [T,N,2D]
bi_encoder_outputs = self.concat((fwd_encoder_outputs, bak_encoder_outputs))
# 1st unidirectional layer. encoder_outputs: [T,N,D]
bi_encoder_outputs = self.dropout(bi_encoder_outputs)
encoder_outputs, _ = self.encoder_layers[2](bi_encoder_outputs)
# Build all the rest unidi layers of encoder
for i in range(3, self.num_layers + 1):
residual = encoder_outputs
encoder_outputs = self.dropout(encoder_outputs)
# [T,N,D] -> [T,N,D]
encoder_outputs_o, _ = self.encoder_layers[i](encoder_outputs)
encoder_outputs = encoder_outputs_o + residual
return encoder_outputs
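# --- Illustrative sketch (not part of the original file) ---
# The backward direction of the first (bidirectional) block is emulated with a forward
# LSTM applied to the time-reversed input and reversed back afterwards; with time-major
# data that is a flip along axis 0 (what P.ReverseV2(axis=[0]) does).
def _reverse_time_axis(x):
    return x[::-1]   # (T, N, D) -> (T, N, D) with time steps in reverse order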

View File

@@ -0,0 +1,166 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GNMTv2 network."""
import copy
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from config.config import GNMTConfig
from .embedding import EmbeddingLookup
from .create_attn_padding import CreateAttentionPaddingsFromInputPaddings
from .beam_search import BeamSearchDecoder, TileBeam
from .encoder import GNMTEncoder
from .decoder import GNMTDecoder
from .decoder_beam_infer import BeamDecoderStep
from .components import SaturateCast
class GNMT(nn.Cell):
"""
GNMT with encoder and decoder.
In GNMT, we define T = src_max_len, T' = tgt_max_len.
Args:
config (GNMTConfig): Model config.
        is_training (bool): Whether the network is used for training.
        use_one_hot_embeddings (bool): Whether to use one-hot embeddings.
Returns:
Tuple[Tensor], network outputs.
"""
def __init__(self,
config: GNMTConfig,
is_training: bool = False,
use_one_hot_embeddings: bool = False,
use_positional_embedding: bool = True,
compute_type=mstype.float32):
super(GNMT, self).__init__()
self.input_mask_from_dataset = config.input_mask_from_dataset
self.max_positions = config.seq_length
self.attn_embed_dim = config.hidden_size
config = copy.deepcopy(config)
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_dropout_prob = 0.0
self.is_training = is_training
self.num_layers = config.num_hidden_layers
self.hidden_dropout_prob = config.hidden_dropout_prob
self.vocab_size = config.vocab_size
self.seq_length = config.seq_length
self.batch_size = config.batch_size
self.max_decode_length = config.max_decode_length
self.word_embed_dim = config.hidden_size
self.beam_width = config.beam_width
self.transpose = P.Transpose()
self.transpose_orders = (1, 0, 2)
self.embedding_lookup = EmbeddingLookup(
is_training=self.is_training,
vocab_size=self.vocab_size,
embed_dim=self.word_embed_dim,
use_one_hot_embeddings=use_one_hot_embeddings)
self.gnmt_encoder = GNMTEncoder(config, is_training)
if self.is_training:
# use for train.
self.gnmt_decoder = GNMTDecoder(config, is_training)
else:
# use for infer.
self.expand = P.ExpandDims()
self.multiply = P.Mul()
self.reshape = P.Reshape()
self.create_att_paddings_from_input_paddings = CreateAttentionPaddingsFromInputPaddings(config,
is_training=False)
self.tile_beam = TileBeam(beam_width=config.beam_width)
self.cast_compute_type = SaturateCast(dst_type=compute_type)
beam_decoder_cell = BeamDecoderStep(config, use_one_hot_embeddings=use_one_hot_embeddings)
# link beam_search after decoder
self.beam_decoder = BeamSearchDecoder(
batch_size=config.batch_size,
seq_length=config.seq_length,
vocab_size=config.vocab_size,
decoder=beam_decoder_cell,
beam_width=config.beam_width,
decoder_layers_nums=config.num_hidden_layers,
length_penalty_weight=config.length_penalty_weight,
hidden_size=config.hidden_size,
max_decode_length=config.max_decode_length)
self.beam_decoder.add_flags(loop_can_unroll=True)
self.shape = P.Shape()
def construct(self, source_ids, source_mask=None, source_len=None,
target_ids=None):
"""
Construct network.
In this method, T = src_max_len, T' = tgt_max_len.
Args:
            source_ids (Tensor): Source sentences with shape (N, T).
            source_mask (Tensor): Source sentences padding mask with shape (N, T),
                where 0 indicates padding position.
            source_len (Tensor): Valid lengths of the source sentences with shape (N, 1).
            target_ids (Tensor): Target sentences with shape (N, T'), only used for training.
Returns:
Tuple[Tensor], network outputs.
"""
# Process source sentences. src_embeddings:[N, T, D].
src_embeddings, _ = self.embedding_lookup(source_ids)
# T, N, D
inputs = self.transpose(src_embeddings, self.transpose_orders)
# encoder. encoder_outputs: [T, N, D]
encoder_outputs = self.gnmt_encoder(inputs, source_len=source_len)
# decoder.
if self.is_training:
# training
# process target input sentences. N, T, D
tgt_embeddings, _ = self.embedding_lookup(target_ids)
# T, N, D
tgt_embeddings = self.transpose(tgt_embeddings, self.transpose_orders)
# cell: [T,N,D].
cell, _, _ = self.gnmt_decoder(tgt_embeddings,
encoder_outputs,
attention_mask=source_mask)
# decoder_output: (N, T', V).
decoder_outputs = self.transpose(cell, self.transpose_orders)
out = decoder_outputs
else:
# infer
# encoder_output: [T, N, D] -> [N, T, D].
beam_encoder_output = self.transpose(encoder_outputs, self.transpose_orders)
            # tile encoder output for beam search, [N*beam_width, T, D]
beam_encoder_output = self.tile_beam(beam_encoder_output)
# (N*beam_width, T)
beam_enc_attention_pad = self.tile_beam(source_mask)
predicted_ids = self.beam_decoder(beam_encoder_output, beam_enc_attention_pad)
predicted_ids = self.reshape(predicted_ids, (-1, self.max_decode_length))
out = predicted_ids
return out

View File

@@ -0,0 +1,203 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Infer api."""
import time
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore import context, Parameter
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint
from src.dataset import load_dataset
from .gnmt import GNMT
from ..utils import zero_weight
from ..utils.load_weights import load_infer_weights
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend",
reserve_class_name_in_scope=False)
def get_weight_and_variable(model_path, params):
    """Dump checkpoint parameter names and network parameter names to text files."""
    print("model path is {}".format(model_path))
ms_ckpt = load_checkpoint(model_path)
with open("variable.txt", "w") as f:
for msname in ms_ckpt:
f.write(msname + "\n")
with open("weights.txt", "w") as f:
for param in params:
name = param.name
f.write(name + "\n")
class GNMTInferCell(nn.Cell):
"""
Encapsulation class of GNMT network infer.
Args:
network (nn.Cell): GNMT model.
Returns:
        Tensor, predicted_ids.
"""
def __init__(self, network):
super(GNMTInferCell, self).__init__(auto_prefix=False)
self.network = network
def construct(self,
source_ids,
source_mask,
source_len):
"""Defines the computation performed."""
predicted_ids = self.network(source_ids,
source_mask,
source_len)
return predicted_ids
def gnmt_infer(config, dataset):
"""
Run infer with GNMT.
Args:
config (GNMTConfig): Config.
dataset (Dataset): Dataset.
Returns:
        List[Dict], prediction results; each example has two keys, "source"
        and "prediction".
"""
tfm_model = GNMT(config=config,
is_training=False,
use_one_hot_embeddings=False)
params = tfm_model.trainable_params()
get_weight_and_variable(config.existed_ckpt, params)
weights = load_infer_weights(config)
for param in params:
value = param.data
name = param.name
if name not in weights:
raise ValueError(f"{name} is not found in weights.")
with open("weight_after_deal.txt", "a+") as f:
weights_name = name
f.write(weights_name)
f.write("\n")
if isinstance(value, Tensor):
print(name, value.asnumpy().shape)
if weights_name in weights:
assert weights_name in weights
if isinstance(weights[weights_name], Parameter):
if param.data.dtype == "Float32":
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
elif param.data.dtype == "Float16":
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
elif isinstance(weights[weights_name], Tensor):
param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
elif isinstance(weights[weights_name], np.ndarray):
param.set_data(Tensor(weights[weights_name], config.dtype))
else:
param.set_data(weights[weights_name])
else:
print("weight not found in checkpoint: " + weights_name)
param.set_data(zero_weight(value.asnumpy().shape))
print(" | Load weights successfully.")
tfm_infer = GNMTInferCell(tfm_model)
model = Model(tfm_infer)
predictions = []
source_sentences = []
shape = P.Shape()
concat = P.Concat(axis=0)
batch_index = 1
pad_idx = 0
sos_idx = 2
eos_idx = 3
source_ids_pad = Tensor(np.tile(np.array([[sos_idx, eos_idx] + [pad_idx] * (config.seq_length - 2)]),
[config.batch_size, 1]), mstype.int32)
source_mask_pad = Tensor(np.tile(np.array([[1, 1] + [0] * (config.seq_length - 2)]),
[config.batch_size, 1]), mstype.int32)
source_len_pad = Tensor(np.tile(np.array([[2]]), [config.batch_size, 1]), mstype.int32)
for batch in dataset.create_dict_iterator():
source_sentences.append(batch["source_eos_ids"].asnumpy())
source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
source_len = Tensor(batch["source_eos_length"], mstype.int32)
active_num = shape(source_ids)[0]
if active_num < config.batch_size:
source_ids = concat((source_ids, source_ids_pad[active_num:, :]))
source_mask = concat((source_mask, source_mask_pad[active_num:, :]))
source_len = concat((source_len, source_len_pad[active_num:, :]))
start_time = time.time()
predicted_ids = model.predict(source_ids, source_mask, source_len)
print(f" | BatchIndex = {batch_index}, Batch size: {config.batch_size}, active_num={active_num}, "
f"Time cost: {time.time() - start_time}.")
if active_num < config.batch_size:
predicted_ids = predicted_ids[:active_num, :]
batch_index = batch_index + 1
predictions.append(predicted_ids.asnumpy())
output = []
    for inputs, batch_out in zip(source_sentences, predictions):
        if batch_out.ndim == 3:
            batch_out = batch_out[:, 0]
        for i, _ in enumerate(batch_out):
            example = {
                "source": inputs[i].tolist(),
                "prediction": batch_out[i].tolist()
            }
            output.append(example)
return output
def infer(config):
"""
GNMT infer api.
Args:
config (GNMTConfig): Config.
Returns:
        list, translation results.
"""
eval_dataset = load_dataset(data_files=config.test_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=1,
sink_mode=config.dataset_sink_mode,
drop_remainder=False,
is_translate=True,
shuffle=False) if config.test_dataset else None
prediction = gnmt_infer(config, eval_dataset)
return prediction
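# Usage sketch (illustrative only; the json path below is an assumption, not a
# shipped default): drives the `infer` API above with a GNMTConfig loaded from a
# json file, mirroring how train.py builds its config.
def _infer_example(config_path="./config/config_test.json"):
    from config import GNMTConfig  # repo-level config package, as used in train.py
    cfg = GNMTConfig.from_json_file(config_path)
    cfg.compute_type = mstype.float16
    cfg.dtype = mstype.float32
    results = infer(cfg)
    # Each item carries raw subword ids; converting them back to text requires
    # the same vocabulary used during dataset preprocessing.
    for item in results[:3]:
        print("source ids:", item["source"])
        print("predicted ids:", item["prediction"])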

View File

@@ -0,0 +1,345 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GNMT for training."""
import numpy as np
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from .gnmt import GNMT
from .grad_clip import GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, ClipGradients
class PredLogProbs(nn.Cell):
"""
Get log probs.
Args:
config (GNMTConfig): The config of GNMT.
Returns:
Tensor, log softmax output.
"""
def __init__(self, config):
super(PredLogProbs, self).__init__()
self.reshape = P.Reshape()
self.log_softmax = nn.LogSoftmax(axis=-1)
self.get_shape = P.Shape()
def construct(self, input_tensor):
"""
Construct network.
Args:
input_tensor (Tensor): Tensor.
Returns:
Tensor, log softmax output.
"""
shape = self.get_shape(input_tensor)
logits = self.reshape(input_tensor, (shape[0] * shape[1], shape[2]))
log_probs = self.log_softmax(logits)
return log_probs
class GNMTTraining(nn.Cell):
"""
GNMT training network.
Args:
config (GNMTConfig): The config of GNMT.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
Returns:
Tensor, prediction_scores.
"""
def __init__(self, config, is_training, use_one_hot_embeddings):
super(GNMTTraining, self).__init__()
self.gnmt = GNMT(config, is_training, use_one_hot_embeddings)
self.projection = PredLogProbs(config)
def construct(self, source_ids, source_mask, source_len, target_ids):
"""
Construct network.
Args:
source_ids (Tensor): Source sentence.
source_mask (Tensor): Source padding mask.
source_len (Tensor): Effective length of source sentence.
target_ids (Tensor): Target sentence.
Returns:
Tensor, prediction_scores.
"""
decoder_outputs = self.gnmt(source_ids, source_mask, source_len, target_ids)
prediction_scores = self.projection(decoder_outputs)
return prediction_scores
class LabelSmoothedCrossEntropyCriterion(nn.Cell):
"""
Label Smoothed Cross-Entropy Criterion.
Args:
config (GNMTConfig): The config of GNMT.
Returns:
Tensor, final loss.
"""
def __init__(self, config):
super(LabelSmoothedCrossEntropyCriterion, self).__init__()
self.vocab_size = config.vocab_size
self.batch_size = config.batch_size
self.smoothing = 0.1
self.confidence = 0.9
self.last_idx = (-1,)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.neg = P.Neg()
self.cast = P.Cast()
self.index_ids = Tensor(np.arange(config.batch_size * config.max_decode_length).reshape((-1, 1)), mstype.int32)
self.gather_nd = P.GatherNd()
self.expand = P.ExpandDims()
self.concat = P.Concat(axis=-1)
def construct(self, prediction_scores, label_ids, label_weights):
"""
Construct network to calculate loss.
Args:
prediction_scores (Tensor): Prediction scores. [batchsize, seq_len, vocab_size]
label_ids (Tensor): Labels. [batchsize, seq_len]
label_weights (Tensor): Mask tensor. [batchsize, seq_len]
Returns:
Tensor, final loss.
"""
prediction_scores = self.reshape(prediction_scores, (-1, self.vocab_size))
label_ids = self.reshape(label_ids, (-1, 1))
label_weights = self.reshape(label_weights, (-1,))
tmp_gather_indices = self.concat((self.index_ids, label_ids))
nll_loss = self.neg(self.gather_nd(prediction_scores, tmp_gather_indices))
nll_loss = label_weights * nll_loss
smooth_loss = self.neg(self.reduce_mean(prediction_scores, self.last_idx))
smooth_loss = label_weights * smooth_loss
loss = self.reduce_sum(self.confidence * nll_loss + self.smoothing * smooth_loss, ())
loss = loss / self.batch_size
return loss
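# Reference sketch (NumPy, for readability only; not used by the training graph):
# the criterion above computes, per token,
#     confidence * NLL(label) + smoothing * mean_over_vocab(-log p),
# masks it with `label_weights`, sums over all tokens and divides by batch_size.
def _label_smoothed_nll_reference(log_probs, labels, weights, confidence=0.9, smoothing=0.1):
    """log_probs: [N, vocab] log-softmax values; labels and weights: [N]."""
    nll = -log_probs[np.arange(labels.shape[0]), labels] * weights
    smooth = -log_probs.mean(axis=-1) * weights
    return (confidence * nll + smoothing * smooth).sum()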
class GNMTNetworkWithLoss(nn.Cell):
"""
Provide GNMT training loss through network.
Args:
        config (GNMTConfig): The config of GNMT.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
Returns:
Tensor, the loss of the network.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
super(GNMTNetworkWithLoss, self).__init__()
self.gnmt = GNMTTraining(config, is_training, use_one_hot_embeddings)
self.loss = LabelSmoothedCrossEntropyCriterion(config)
self.cast = P.Cast()
def construct(self,
source_ids,
source_mask,
source_len,
target_ids,
label_ids,
label_weights):
prediction_scores = self.gnmt(source_ids, source_mask, source_len, target_ids)
total_loss = self.loss(prediction_scores, label_ids, label_weights)
return self.cast(total_loss, mstype.float32)
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)
class GNMTTrainOneStepWithLossScaleCell(nn.Cell):
"""
Encapsulation class of GNMT network training.
    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.
Args:
network: Cell. The training network. Note that loss function should have
been added.
optimizer: Optimizer. Optimizer for updating the weights.
Returns:
Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(GNMTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.all_reduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode not in ParallelMode.MODE_LIST:
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.clip_gradients = ClipGradients()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
self.loss_scalar = P.ScalarSummary()
def construct(self,
source_eos_ids,
source_eos_mask,
source_eos_length,
target_sos_ids,
target_eos_ids,
target_eos_mask,
sens=None):
"""
Construct network.
Args:
source_eos_ids (Tensor): Source sentence.
source_eos_mask (Tensor): Source padding mask.
source_eos_length (Tensor): Effective length of source sentence.
target_sos_ids (Tensor): Target sentence.
target_eos_ids (Tensor): Prediction sentence.
target_eos_mask (Tensor): Prediction padding mask.
            sens (Tensor): Loss scale sensitivity value. Default: None.
Returns:
Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.
"""
source_ids = source_eos_ids
source_mask = source_eos_mask
source_len = source_eos_length
target_ids = target_sos_ids
label_ids = target_eos_ids
label_weights = target_eos_mask
weights = self.weights
loss = self.network(source_ids,
source_mask,
source_len,
target_ids,
label_ids,
label_weights)
# Alloc status.
init = self.alloc_status()
# Clear overflow buffer.
self.clear_before_grad(init)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
grads = self.grad(self.network, weights)(source_ids,
source_mask,
source_len,
target_ids,
label_ids,
label_weights,
self.cast(scaling_sens,
mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
# Apply grad reducer on grads.
grads = self.grad_reducer(grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
# Sum overflow flag over devices.
flag_reduce = self.all_reduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
self.loss_scalar("loss", loss)
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)
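# Note on the one-step flow above: the NPU float-status buffer is allocated and
# cleared before the backward pass, gradients are taken w.r.t. the scaled
# sensitivity, unscaled via `grad_scale`, clipped, and (under data parallelism)
# reduced across devices; the overflow flag is then summed (and all-reduced in
# distributed mode) so that every rank agrees, and the optimizer update is skipped
# for the step whenever the dynamic loss-scale manager reports an overflow.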

View File

@@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gradient clip."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 5.0
class ClipGradients(nn.Cell):
"""
Clip gradients.
Returns:
List, a list of clipped_grad tuples.
"""
def __init__(self):
super(ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.cast = P.Cast()
self.dtype = P.DType()
def construct(self,
grads,
clip_type,
clip_value):
"""
Construct gradient clip network.
Args:
            grads (tuple): Tuple of gradient tensors.
            clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
            clip_value (float): Specifies how much to clip.
        Returns:
            tuple, a tuple of clipped gradient tensors.
"""
if clip_type not in (0, 1):
return grads
new_grads = ()
for grad in grads:
dt = self.dtype(grad)
if clip_type == 0:
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
self.cast(F.tuple_to_array((clip_value,)), dt))
else:
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
new_grads = new_grads + (t,)
return new_grads
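# Illustrative sketch (not called anywhere): expected effect of the cell above on
# a single made-up gradient when clipping by norm. Intended for standalone
# execution in PyNative mode.
def _clip_gradients_example():
    import numpy as np
    from mindspore import Tensor, context
    context.set_context(mode=context.PYNATIVE_MODE)
    grads = (Tensor(np.full((2, 2), 10.0, np.float32)),)
    clipped = ClipGradients()(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
    # The L2 norm (20.0) is scaled down to GRADIENT_CLIP_VALUE (5.0),
    # so every element should come back as 2.5.
    return clipped[0].asnumpy()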

View File

@@ -0,0 +1,28 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for gnmt model."""
from .lr_scheduler import square_root_schedule
from .loss_monitor import LossCallBack
from .initializer import zero_weight, one_weight, normal_weight, weight_variable
__all__ = [
"square_root_schedule",
"LossCallBack",
"one_weight",
"zero_weight",
"normal_weight",
"weight_variable"
]

View File

@@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initializer."""
import numpy as np
from mindspore import Tensor
def _compute_fans(shape):
"""
Computes the number of input and output units for a weight shape.
Args:
        shape (tuple): Integer shape tuple.
Returns:
tuple, integer scalars (fan_in, fan_out).
"""
if not shape:
fan_in = fan_out = 1
elif len(shape) == 1:
fan_in = fan_out = shape[0]
elif len(shape) == 2:
fan_in = shape[0]
fan_out = shape[1]
else:
# Assuming convolution kernels (2D, 3D, or more).
# kernel shape: (..., input_depth, depth)
receptive_field_size = 1
for dim in shape[:-2]:
receptive_field_size *= dim
fan_in = shape[-2] * receptive_field_size
fan_out = shape[-1] * receptive_field_size
return int(fan_in), int(fan_out)
def weight_variable(shape):
"""
    Generate a weight array from a uniform distribution.
    Args:
        shape (tuple): Shape.
    Returns:
        numpy.ndarray, initialized values.
"""
# scale_shape = shape
# fan_in, fan_out = _compute_fans(scale_shape)
# scale = 1.0 / max(1., (fan_in + fan_out) / 2.)
# limit = math.sqrt(3.0 * scale)
limit = 0.1
values = np.random.uniform(-limit, limit, shape)
return values
def one_weight(shape):
"""
Generate weight with ones.
Args:
shape (tuple): Shape.
Returns:
Tensor, var.
"""
ones = np.ones(shape).astype(np.float32)
return Tensor(ones)
def zero_weight(shape):
"""
Generate weight with zeros.
Args:
shape (tuple): Shape.
Returns:
Tensor, var.
"""
zeros = np.zeros(shape).astype(np.float32)
return Tensor(zeros)
def normal_weight(shape, num_units):
"""
Generate weight with normal dist.
Args:
shape (tuple): Shape.
num_units (int): Dimension.
Returns:
Tensor, var.
"""
norm = np.random.normal(0.0, num_units ** -0.5, shape).astype(np.float32)
return Tensor(norm)
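# Illustrative sketch: how these helpers are typically combined. The shape below
# is an arbitrary LSTM-style example, not a value taken from the GNMT config.
def _initializer_example():
    shape = (1024, 4096)
    fan_in, fan_out = _compute_fans(shape)             # -> (1024, 4096)
    uniform_np = weight_variable(shape)                # numpy array in [-0.1, 0.1)
    normal_ms = normal_weight(shape, num_units=1024)   # Tensor, std = 1024 ** -0.5
    return fan_in, fan_out, uniform_np.shape, normal_ms.shape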

View File

@@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Weight loader."""
import numpy as np
from mindspore.train.serialization import load_checkpoint
def load_infer_weights(config):
"""
Load weights from ckpt or npz.
Args:
config (GNMTConfig): Config.
Returns:
dict, weights.
"""
model_path = config.existed_ckpt
if model_path.endswith(".npz"):
ms_ckpt = np.load(model_path)
is_npz = True
else:
ms_ckpt = load_checkpoint(model_path)
is_npz = False
weights = {}
with open("variable_after_deal.txt", "w") as f:
for param_name in ms_ckpt:
infer_name = param_name.replace("gnmt.gnmt.", "")
if infer_name.startswith("embedding_lookup."):
if is_npz:
weights[infer_name] = ms_ckpt[param_name]
else:
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
f.write(infer_name)
f.write("\n")
infer_name = "beam_decoder.decoder." + infer_name
if is_npz:
weights[infer_name] = ms_ckpt[param_name]
else:
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
f.write(infer_name)
f.write("\n")
continue
elif not infer_name.startswith("gnmt_encoder"):
if infer_name.startswith("gnmt_decoder."):
infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
infer_name = "beam_decoder.decoder." + infer_name
if is_npz:
weights[infer_name] = ms_ckpt[param_name]
else:
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
f.write(infer_name)
f.write("\n")
return weights

View File

@@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss monitor."""
import time
from mindspore.train.callback import Callback
from mindspore.communication.management import get_rank
from config import GNMTConfig
class LossCallBack(Callback):
"""
Monitor the loss in training.
    If the loss is NAN or INF, training should be terminated.
    Note:
        If per_print_times is 0, the loss will not be printed.
    Args:
        config (GNMTConfig): The config of GNMT.
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
"""
time_stamp_init = False
time_stamp_first = 0
def __init__(self, config: GNMTConfig, per_print_times: int = 1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self.config = config
self._per_print_times = per_print_times
if not self.time_stamp_init:
self.time_stamp_first = self._get_ms_timestamp()
self.time_stamp_init = True
def step_end(self, run_context):
"""step end."""
cb_params = run_context.original_args()
file_name = "./loss.log"
if self.config.modelarts:
import os
file_name = "/home/work/workspace/loss/loss_{}.log".format(os.getenv('DEVICE_ID'))
with open(file_name, "a+") as f:
time_stamp_current = self._get_ms_timestamp()
f.write("time: {}, epoch: {}, step: {}, outputs: [loss: {}, overflow: {}, loss scale value: {} ].\n".format(
time_stamp_current - self.time_stamp_first,
cb_params.cur_epoch_num,
cb_params.cur_step_num,
str(cb_params.net_outputs[0].asnumpy()),
str(cb_params.net_outputs[1].asnumpy()),
str(cb_params.net_outputs[2].asnumpy())
))
if self.config.modelarts:
from modelarts.data_util import upload_output
rank_id = get_rank()
if cb_params.cur_step_num % self.config.save_step == 1 \
and cb_params.cur_step_num != 1 and rank_id in [0, 8]:
upload_output("/home/work/workspace/loss", self.config.train_url)
upload_output("/cache/ckpt_0", self.config.train_url)
@staticmethod
def _get_ms_timestamp():
t = time.time()
return int(round(t * 1000))

View File

@@ -0,0 +1,165 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning scheduler."""
from math import ceil
import math
import numpy as np
def convert_float2int(values, total_steps):
if isinstance(values, float):
values = int(values * total_steps)
return values
def square_root_schedule(lr, update_num, decay_start_step,
warmup_steps=2000,
min_lr=1e-5):
"""
    Decay the LR based on ISR (inverse square root).
During warm-up::
lrs = np.linspace(0, lr, warmup_steps)
After warm-up:
decay_factor = lr * sqrt(warmup_steps)
lr = decay_factor / sqrt(step) if step >= decay_start_step else lr
Args:
lr (float): Init learning rate.
update_num (int): Total steps.
decay_start_step (int): Decay begins after `decay_start_step` steps.
warmup_steps (int): Warm up steps.
min_lr (float): Min learning rate.
Returns:
np.ndarray, learning rate array.
"""
warmup_end_lr = lr
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
# If warmup_init_lr > lr, then lr_step is negative.
# Otherwise, it's positive.
lr_step = (warmup_end_lr - warmup_init_lr) / warmup_steps
decay_factor = lr * warmup_steps ** 0.5
lrs = np.empty(shape=update_num, dtype=np.float32)
_start_step = 0
if 0 < warmup_steps < update_num:
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
_start_step = warmup_steps
for step in range(_start_step, update_num):
if step < warmup_steps:
_lr = warmup_init_lr + step * lr_step
elif step < decay_start_step:
_lr = lr
else:
_lr = decay_factor * step ** -0.5
if _lr < min_lr:
_lr = min_lr
lrs[step] = _lr
return lrs
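# Worked example (illustrative numbers): square_root_schedule(lr=0.1, update_num=10,
# decay_start_step=6, warmup_steps=4, min_lr=1e-5) yields
#   steps 0-3: linear warm-up from 0.0 to 0.1
#   steps 4-5: held at 0.1
#   steps 6-9: 0.1 * sqrt(4) / sqrt(step), i.e. about 0.082 at step 6,
#              floored at min_lr.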
def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0):
"""
    Implementation of a polynomial decay learning rate scheduler, which cycles by default.
Args:
lr (float): Initial learning rate.
warmup_steps (int): Warmup steps.
decay_steps (int): Decay steps.
total_update_num (int): Total update steps.
        min_lr (float): Min learning rate.
power (float): Power factor.
Returns:
np.ndarray, learning rate of each step.
"""
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
if decay_steps <= 0:
raise ValueError("`decay_steps` must larger than 1.")
_start_step = 0
if 0 < warmup_steps < total_update_num:
warmup_end_lr = lr
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
_start_step = warmup_steps
    for step in range(_start_step, total_update_num):
        _step = step - _start_step
        ratio = ceil(_step / decay_steps)
        ratio = 1 if ratio < 1 else ratio
        _decay_steps = decay_steps * ratio
        lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr
return lrs
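# Worked example (illustrative numbers): polynomial_decay_scheduler(lr=0.1, min_lr=0.01,
# decay_steps=4, total_update_num=10, warmup_steps=2, power=1.0) warms up over steps 0-1,
# decays linearly from 0.1 down to 0.01 over steps 2-6, then "cycles": at step 7 the decay
# window doubles (decay_steps * ratio) and the LR jumps back up to about 0.044.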
def Warmup_MultiStepLR_scheduler(base_lr=0.002, total_update_num=200, warmup_steps=200, remain_steps=1.0,
decay_interval=-1, decay_steps=4, decay_factor=0.5):
"""
Implements of polynomial decay learning rate scheduler which cycles by default.
Args:
base_lr (float): Initial learning rate.
total_update_num (int): Total update steps.
warmup_steps (int or float): Warmup steps.
remain_steps (int or float): start decay at 'remain_steps' iteration
decay_interval (int): interval between LR decay steps
decay_steps (int): Decay steps.
decay_factor (float): decay factor
Returns:
np.ndarray, learning rate of each step.
"""
if decay_steps <= 0:
raise ValueError("`decay_steps` must larger than 1.")
remain_steps = convert_float2int(remain_steps, total_update_num)
warmup_steps = convert_float2int(warmup_steps, total_update_num)
if warmup_steps > remain_steps:
warmup_steps = remain_steps
if decay_interval < 0:
decay_iterations = total_update_num - remain_steps
decay_interval = decay_iterations // decay_steps
decay_interval = max(decay_interval, 1)
else:
decay_interval = convert_float2int(decay_interval, total_update_num)
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
_start_step = 0
for last_epoch in range(_start_step, total_update_num):
if last_epoch < warmup_steps:
if warmup_steps != 0:
warmup_factor = math.exp(math.log(0.01) / warmup_steps)
else:
warmup_factor = 1.0
inv_decay = warmup_factor ** (warmup_steps - last_epoch)
lrs[last_epoch] = base_lr * inv_decay
elif last_epoch >= remain_steps:
decay_iter = last_epoch - remain_steps
num_decay_step = decay_iter // decay_interval + 1
num_decay_step = min(num_decay_step, decay_steps)
lrs[last_epoch] = base_lr * (decay_factor ** num_decay_step)
else:
lrs[last_epoch] = base_lr
return lrs
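# Worked example (illustrative numbers): Warmup_MultiStepLR_scheduler(base_lr=0.002,
# total_update_num=10, warmup_steps=2, remain_steps=6, decay_interval=-1, decay_steps=4,
# decay_factor=0.5) gives an exponential warm-up (2e-5, 2e-4), holds base_lr=0.002 for
# steps 2-5, then halves it once per step from step 6 on: 1e-3, 5e-4, 2.5e-4, 1.25e-4.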

View File

@@ -0,0 +1,423 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adam"""
import numpy as np
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.nn import Optimizer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
_learning_rate_update_func = ['linear', 'cos', 'sin']
adam_opt = C.MultitypeFuncGraph("adam_opt")
@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
"""
Update parameters.
Args:
beta1 (Tensor): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
weight_decay_tensor (Tensor): Weight decay. Should be equal to or greater than 0.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Specifies whether to apply weight decay to this parameter.
Returns:
Tensor, the new value of v after updating.
"""
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))
update = next_m / (op_sqrt(next_v) + eps)
if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_v = F.depend(next_v, F.assign(param, next_param))
next_v = F.depend(next_v, F.assign(m, next_m))
next_v = F.depend(next_v, F.assign(v, next_v))
return next_v
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs."""
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
# validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
# validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
# validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
# validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):
"""Check the type of inputs."""
validator.check_float_positive('learning_rate', learning_rate, prim_name)
validator.check_float_legal_value('learning_rate', learning_rate, prim_name)
validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
validator.check_float_positive('power', power, prim_name)
validator.check_float_legal_value('power', power, prim_name)
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor")
def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1,
moment2):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient))
return success
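# Note: `adam_opt` has two registrations that are dispatched by argument signature.
# The tensor-only overload `_update_run_op` implements the AdamW-style Python update
# (with optional weight decay) used by AdamWeightDecay and AdamWeightDecayDynamicLR
# below, while `_run_opt_with_one_number` wraps the fused P.Adam primitive and is
# used by the Adam optimizer's construct().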
class Adam(Optimizer):
r"""
Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
The updating formulas are as follows,
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
\end{array}
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
:math:`\epsilon` represents `eps`.
Note:
The Adam optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr" and "weight_decay" are the keys can be parsed.
- params: Required. The value should be a list of `Parameter`.
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
0.9.
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
0.999.
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
1e-8.
use_locking (bool): Whether to enable a lock to protect updating variable tensors.
If True, updating of the var, m, and v tensors will be protected by a lock.
If False, the result is unpredictable. Default: False.
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
If True, updates the gradients using NAG.
If False, updates the gradients without using NAG. Default: False.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
1.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Tensor[bool], the value is True.
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.Adam(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
>>> {'params': no_conv_params}]
>>> opt = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01
>>> # the no_cov_params's parameters don't set learning and weight decay. So they will use a
>>> # learning rate of 0.1 and a weight decay of 0.0.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
# validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
self.beta1 = Tensor(beta1, mstype.float32)
self.beta2 = Tensor(beta2, mstype.float32)
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
self.eps = eps
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.Adam(use_locking, use_nesterov)
self.pow = P.Pow()
self.sqrt = P.Sqrt()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.realdiv = P.RealDiv()
self.lr_scalar = P.ScalarSummary()
def construct(self, gradients):
"""Adam optimizer."""
params = self.parameters
moment1 = self.moment1
moment2 = self.moment2
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
self.lr_scalar("learning_rate", lr)
beta1_power = self.beta1_power * self.beta1
self.beta1_power = beta1_power
beta2_power = self.beta2_power * self.beta2
self.beta2_power = beta2_power
if self.is_group_lr:
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
self.beta2, self.eps),
lr, gradients, params, moment1, moment2)
else:
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
self.beta2, self.eps, lr),
gradients, params, moment1, moment2)
return success
class AdamWeightDecay(Optimizer):
"""
    Implements the Adam algorithm with weight decay fix (AdamW).
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be class mindspore.Parameter.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'beta' not in x.name and 'gamma' not in x.name.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.AdamWeightDecay(params=net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(AdamWeightDecay, self).__init__(learning_rate, params)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
self.params = self.parameters
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
self.decay_flag = tuple(decay_filter(x) for x in self.params)
self.hyper_map = C.HyperMap()
def construct(self, gradients):
"""Adam Weight Decay"""
lr = self.get_lr()
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
return updated_velocity
class AdamWeightDecayDynamicLR(Optimizer):
"""
Adam Weight Decay Dynamic Learning Rate (LR).
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be class mindspore.Parameter.
decay_steps (int): The steps of the decay.
learning_rate (float): A floating point value for the learning rate. Default: 0.001.
end_learning_rate (float): A floating point value for the end learning rate. Default: 0.0001.
power (float): Power. Default: 10.0.
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'beta' not in x.name and 'gamma' not in x.name.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.AdamWeightDecayDynamicLR(params=net.trainable_params(), decay_steps=10)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
"""
def __init__(self,
params,
decay_steps,
learning_rate=0.001,
end_learning_rate=0.0001,
power=10.0,
beta1=0.9,
beta2=0.999,
eps=1e-6,
weight_decay=0.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name,
warmup_steps=0):
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
_check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
# turn them to scalar when me support scalar/tensor mix operations
self.global_step = Parameter(initializer(0, [1]), name="global_step")
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32))
self.end_learning_rate = Tensor(np.array([end_learning_rate]).astype(np.float32))
self.diff_learning_rate = Tensor(np.array([learning_rate - end_learning_rate]).astype(np.float32))
self.power = power
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
self.params = self.parameters
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
self.decay_flag = tuple(decay_filter(x) for x in self.params)
self.hyper_map = C.HyperMap()
self.min = P.Minimum()
self.pow = P.Pow()
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
self.start_learning_rate = Tensor(np.array([learning_rate]).astype(np.float32))
def construct(self, gradients):
"""Adam Weight Decay Dynamic LR."""
step = self.min(self.global_step, self.decay_steps)
p = step / self.decay_steps
lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
if self.warmup_flag:
warmup_percent = self.global_step / self.warmup_steps
warmup_lr = self.start_learning_rate * warmup_percent
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
added_global_step = self.global_step + self.one
F.control_depend(lr, added_global_step)
self.global_step = added_global_step
return updated_velocity

View File

@@ -0,0 +1,360 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train api."""
import os
import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn import Momentum
from mindspore.nn.optim import Lamb
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, SummaryCollector, TimeMonitor
from mindspore import context, Parameter
from mindspore.context import ParallelMode
from mindspore.communication import management as MultiAscend
from mindspore.train.serialization import load_checkpoint
from mindspore.common import set_seed
from config import GNMTConfig
from src.dataset import load_dataset
from src.gnmt_model import GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
from src.utils import LossCallBack
from src.utils import one_weight, weight_variable
from src.utils.lr_scheduler import square_root_schedule, polynomial_decay_scheduler, Warmup_MultiStepLR_scheduler
from src.utils.optimizer import Adam
parser = argparse.ArgumentParser(description='GNMT train entry point.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
device_id = os.getenv('DEVICE_ID', None)
if device_id is None:
raise RuntimeError("`DEVICE_ID` can not be None.")
device_id = int(device_id)
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend",
reserve_class_name_in_scope=True,
device_id=device_id)
def get_config(config):
config = GNMTConfig.from_json_file(config)
config.compute_type = mstype.float16
config.dtype = mstype.float32
return config
def _train(model, config: GNMTConfig,
pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
callbacks: list = None):
"""
Train model.
Args:
model (Model): MindSpore model instance.
        config (GNMTConfig): Config of the GNMT model.
pre_training_dataset (Dataset): Pre-training dataset.
fine_tune_dataset (Dataset): Fine-tune dataset.
test_dataset (Dataset): Test dataset.
callbacks (list): A list of callbacks.
"""
callbacks = callbacks if callbacks else []
if pre_training_dataset is not None:
print(" | Start pre-training job.")
epoch_size = pre_training_dataset.get_repeat_count()
print("epoch size ", epoch_size)
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
model.train(config.epochs, pre_training_dataset,
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)
if fine_tune_dataset is not None:
print(" | Start fine-tuning job.")
epoch_size = fine_tune_dataset.get_repeat_count()
model.train(config.epochs, fine_tune_dataset,
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)
def _load_checkpoint_to_net(config, network):
"""load parameters to network from checkpoint."""
if config.existed_ckpt:
if config.existed_ckpt.endswith(".npz"):
weights = np.load(config.existed_ckpt)
else:
weights = load_checkpoint(config.existed_ckpt)
for param in network.trainable_params():
weights_name = param.name
if weights_name not in weights:
raise ValueError(f"Param {weights_name} is not found in ckpt file.")
if isinstance(weights[weights_name], Parameter):
param.set_data(weights[weights_name].data)
elif isinstance(weights[weights_name], Tensor):
param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
elif isinstance(weights[weights_name], np.ndarray):
param.set_data(Tensor(weights[weights_name], config.dtype))
else:
param.set_data(weights[weights_name])
else:
for param in network.trainable_params():
name = param.name
value = param.data
if isinstance(value, Tensor):
if name.endswith(".gamma"):
param.set_data(one_weight(value.asnumpy().shape))
elif name.endswith(".beta") or name.endswith(".bias"):
# param.set_data(zero_weight(value.asnumpy().shape))
if param.data.dtype == "Float32":
param.set_data((weight_variable(value.asnumpy().shape).astype(np.float32)))
elif param.data.dtype == "Float16":
param.set_data((weight_variable(value.asnumpy().shape).astype(np.float16)))
else:
if param.data.dtype == "Float32":
param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float32)))
elif param.data.dtype == "Float16":
param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float16)))
def _get_lr(config, update_steps):
"""generate learning rate."""
if config.lr_scheduler == "isr":
lr = Tensor(square_root_schedule(lr=config.lr,
update_num=update_steps,
decay_start_step=config.decay_start_step,
warmup_steps=config.warmup_steps,
min_lr=config.min_lr), dtype=mstype.float32)
elif config.lr_scheduler == "poly":
lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
min_lr=config.min_lr,
decay_steps=config.decay_steps,
total_update_num=update_steps,
warmup_steps=config.warmup_steps,
power=config.lr_scheduler_power), dtype=mstype.float32)
elif config.lr_scheduler == "WarmupMultiStepLR":
lr = Tensor(Warmup_MultiStepLR_scheduler(base_lr=config.lr,
total_update_num=update_steps,
warmup_steps=config.warmup_steps,
remain_steps=config.warmup_lr_remain_steps,
decay_interval=config.warmup_lr_decay_interval,
decay_steps=config.decay_steps,
decay_factor=config.lr_scheduler_power), dtype=mstype.float32)
else:
lr = config.lr
return lr
def _get_optimizer(config, network, lr):
"""get gnmt optimizer, support Adam, Lamb, Momentum."""
if config.optimizer.lower() == "adam":
optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
elif config.optimizer.lower() == "lamb":
optimizer = Lamb(network.trainable_params(), decay_steps=12000,
start_learning_rate=config.lr, end_learning_rate=config.min_lr,
power=10.0, warmup_steps=config.warmup_steps, weight_decay=0.01,
eps=1e-6)
elif config.optimizer.lower() == "momentum":
optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
else:
raise ValueError(f"optimizer only support `adam` and `momentum` now.")
return optimizer
def _build_training_pipeline(config: GNMTConfig,
pre_training_dataset=None,
fine_tune_dataset=None,
test_dataset=None):
"""
Build training pipeline.
Args:
        config (GNMTConfig): Config of the GNMT model.
pre_training_dataset (Dataset): Pre-training dataset.
fine_tune_dataset (Dataset): Fine-tune dataset.
test_dataset (Dataset): Test dataset.
"""
net_with_loss = GNMTNetworkWithLoss(config, is_training=True, use_one_hot_embeddings=True)
net_with_loss.init_parameters_data()
_load_checkpoint_to_net(config, net_with_loss)
dataset = pre_training_dataset if pre_training_dataset is not None \
else fine_tune_dataset
if dataset is None:
raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")
update_steps = config.epochs * dataset.get_dataset_size()
lr = _get_lr(config, update_steps)
optimizer = _get_optimizer(config, net_with_loss, lr)
# Dynamic loss scale.
scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
scale_factor=config.loss_scale_factor,
scale_window=config.scale_window)
net_with_grads = GNMTTrainOneStepWithLossScaleCell(
network=net_with_loss, optimizer=optimizer,
scale_update_cell=scale_manager.get_update_cell()
)
net_with_grads.set_train(True)
model = Model(net_with_grads)
loss_monitor = LossCallBack(config)
dataset_size = dataset.get_dataset_size()
time_cb = TimeMonitor(data_size=dataset_size)
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
keep_checkpoint_max=config.keep_ckpt_max)
rank_size = os.getenv('RANK_SIZE')
callbacks = [time_cb, loss_monitor]
if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
ckpt_callback = ModelCheckpoint(
prefix=config.ckpt_prefix,
directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
config=ckpt_config)
callbacks.append(ckpt_callback)
summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
callbacks.append(summary_callback)
if rank_size is None or int(rank_size) == 1:
ckpt_callback = ModelCheckpoint(
prefix=config.ckpt_prefix,
directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
config=ckpt_config)
callbacks.append(ckpt_callback)
summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
callbacks.append(summary_callback)
print(f" | ALL SET, PREPARE TO TRAIN.")
_train(model=model, config=config,
pre_training_dataset=pre_training_dataset,
fine_tune_dataset=fine_tune_dataset,
test_dataset=test_dataset,
callbacks=callbacks)
def _setup_parallel_env():
context.reset_auto_parallel_context()
MultiAscend.init()
context.set_auto_parallel_context(
parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=MultiAscend.get_group_size(),
gradients_mean=True
)
def train_parallel(config: GNMTConfig):
"""
Train model with multi ascend chips.
Args:
        config (GNMTConfig): Config for the GNMT model.
"""
_setup_parallel_env()
print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.")
pre_train_dataset = load_dataset(
data_files=config.pre_train_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank()
) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset(
data_files=config.fine_tune_dataset, schema=config.dataset_schema,
batch_size=config.batch_size, epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank()
) if config.fine_tune_dataset else None
test_dataset = load_dataset(
data_files=config.test_dataset, schema=config.dataset_schema,
batch_size=config.batch_size, epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank()
) if config.test_dataset else None
_build_training_pipeline(config=config,
pre_training_dataset=pre_train_dataset,
fine_tune_dataset=fine_tune_dataset,
test_dataset=test_dataset)
def train_single(config: GNMTConfig):
"""
Train model on single device.
Args:
config (GNMTConfig): Config for model.
"""
print(" | Starting training on single device.")
pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
test_dataset = load_dataset(data_files=config.test_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.test_dataset else None
_build_training_pipeline(config=config,
pre_training_dataset=pre_train_dataset,
fine_tune_dataset=fine_tune_dataset,
test_dataset=test_dataset)
def _check_args(config):
    if not isinstance(config, str):
        raise ValueError("`config` must be type of str.")
    if not os.path.exists(config):
        raise FileNotFoundError("`config` does not exist.")
if __name__ == '__main__':
_rank_size = os.getenv('RANK_SIZE')
args, _ = parser.parse_known_args()
_check_args(args.config)
_config = get_config(args.config)
set_seed(_config.random_seed)
if _rank_size is not None and int(_rank_size) > 1:
train_parallel(_config)
else:
train_single(_config)
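# Launch sketch (single device): the entry point above requires the DEVICE_ID
# environment variable and a --config json path, e.g.
#     DEVICE_ID=0 python train.py --config=./config/config.json
# The json file name here is an assumption; point it at the config shipped with
# this model. When RANK_SIZE > 1 is set, training switches to train_parallel().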