init for gru

fix for copyright

fix for create_dataset

fix for some bug

fix for parser_output

fix for parse output

add readme and fix for some bug

fix test2016 to test

delete space

fix some bug for create dataset

fix for comments and overflow

fix a wrong bug

fix for weight init
This commit is contained in:
qujianwei 2021-01-16 11:01:50 +08:00
parent dfa6daaa57
commit a37ad24136
22 changed files with 2060 additions and 0 deletions

View File

@ -0,0 +1,252 @@
![](https://www.mindspore.cn/static/img/logo.a3e472c9.png)
<!-- TOC -->
- [GRU](#gru)
- [Model Structure](#model-structure)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Dataset Preparation](#dataset-preparation)
- [Configuration File](#configuration-file)
- [Training Process](#training-process)
- [Inference Process](#inference-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Random Situation Description](#random-situation-description)
- [Others](#others)
- [ModelZoo HomePage](#modelzoo-homepage)
<!-- /TOC -->
# [GRU](#contents)
GRU (Gated Recurrent Unit) is a type of recurrent neural network, similar to the LSTM (Long Short-Term Memory). It was proposed by Kyunghyun Cho, Bart van Merrienboer et al. in the 2014 paper "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation", which introduces a novel model called the RNN Encoder-Decoder that consists of two recurrent neural networks (RNNs). To improve the translation task, we also refer to "Sequence to Sequence Learning with Neural Networks" and "Neural Machine Translation by Jointly Learning to Align and Translate".
## Paper
1.[Paper](https://arxiv.org/pdf/1406.1078.pdf): "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation", 2014, Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio
2.[Paper](https://arxiv.org/pdf/1409.3215.pdf): "Sequence to Sequence Learning with Neural Networks", 2014, Ilya Sutskever, Oriol Vinyals, Quoc V. Le
3.[Paper](https://arxiv.org/pdf/1409.0473.pdf): "Neural Machine Translation by Jointly Learning to Align and Translate", 2014, Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio
# [Model Structure](#contents)
The GRU model mainly consists of an Encoder and a Decoder. The Encoder is built from a bidirectional GRU cell, while the Decoder contains an attention module and a GRU cell. The input of the network is a sequence of words (a text or sentence), and the output is the probability of each word in the vocabulary; the word with the maximum probability is taken as the prediction. A minimal sketch of a single GRU step is shown below.
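The following is a minimal NumPy sketch of one GRU step, intended only to illustrate the gate equations; the network in this repository uses MindSpore's DynamicGRUV2 operator (see src/gru.py), not this code.

```python
import numpy as np

def gru_step(x, h, W_z, U_z, W_r, U_r, W_h, U_h):
    """One GRU time step: x is the input vector, h is the previous hidden state."""
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    z = sigmoid(W_z @ x + U_z @ h)               # update gate
    r = sigmoid(W_r @ x + U_r @ h)               # reset gate
    h_tilde = np.tanh(W_h @ x + U_h @ (r * h))   # candidate hidden state
    return (1.0 - z) * h + z * h_tilde           # new hidden state

# Toy shapes: embedding size 4, hidden size 3.
rng = np.random.default_rng(0)
x, h = rng.normal(size=4), np.zeros(3)
W_z, W_r, W_h = (rng.normal(size=(3, 4)) for _ in range(3))
U_z, U_r, U_h = (rng.normal(size=(3, 3)) for _ in range(3))
print(gru_step(x, h, W_z, U_z, W_r, U_r, W_h, U_h).shape)  # (3,)
```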
# [Dataset](#contents)
In this model, we use the Multi30K dataset for training and testing. The training set provides 29,000 sentence pairs, each containing a German sentence and its English translation. The test set provides 1,000 pairs of German and English sentences. We also provide a preprocessing script to tokenize the dataset and create the vocab files.
# [Environment Requirements](#content)
- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#content)
After dataset preparation, you can start training and evaluation as follows:
```bash
# run training example
cd ./scripts
sh run_standalone_train.sh [TRAIN_DATASET_PATH]
# run distributed training example
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TRAIN_DATASET_PATH]
# run evaluation example
sh run_eval.sh [CKPT_FILE] [DATASET_PATH]
```
# [Script Description](#content)
The GRU network scripts and code structure are as follows:
```text
├── gru
├── README.md // Introduction of GRU model.
├── src
| ├──gru.py // gru cell architecture.
│ ├──config.py // Configuration instance definition.
│ ├──create_data.py // Dataset preparation.
│ ├──dataset.py // Dataset loader to feed into model.
│ ├──gru_for_infer.py // GRU eval model architecture.
│ ├──gru_for_train.py // GRU train model architecture.
│ ├──loss.py // Loss architecture.
│ ├──lr_schedule.py // Learning rate scheduler.
│ ├──parse_output.py // Parse output file.
│ ├──preprocess.py // Dataset preprocess.
│ ├──seq2seq.py // Seq2seq architecture.
│ ├──tokenization.py // tokenization for the dataset.
│ ├──weight_init.py // Initialize weights in the net.
├── scripts
│ ├──create_dataset.sh // shell script to create the mindrecord dataset.
│ ├──parse_output.sh // shell script to parse the eval output file and calculate BLEU.
│ ├──preprocess.sh // shell script to preprocess the dataset.
│ ├──run_distributed_train.sh // shell script for distributed training on Ascend.
│ ├──run_eval.sh // shell script for standalone evaluation on Ascend.
│ ├──run_standalone_train.sh // shell script for standalone training on Ascend.
├── eval.py // Infer API entry.
├── requirements.txt // Requirements of third party package.
├── train.py // Train API entry.
```
## [Dataset Preparation](#content)
First, download the dataset from the official WMT16 site. After downloading the Multi30k dataset, you get the six files shown below, which should be placed in the same directory.
```text
train.de
train.en
val.de
val.en
test.de
test.en
```
Then, use scripts/preprocess.sh to tokenize the dataset files and build the vocab files.
```bash
bash preprocess.sh [DATASET_PATH]
```
After preprocessing, we get the dataset files with the ".tok" suffix and two vocab files named vocab.de and vocab.en.
Then use scripts/create_dataset.sh to convert the dataset into MindRecord format.
```bash
bash create_dataset.sh [DATASET_PATH] [OUTPUT_PATH]
```
Finally, we get multi30k_train_mindrecord_0 ~ multi30k_train_mindrecord_8 as the training dataset and multi30k_test_mindrecord as the test dataset; a quick way to sanity-check them is shown below.
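As a quick sanity check, the generated MindRecord files can be loaded with the data loader from src/dataset.py. This is only a sketch; it assumes the src package is importable and that the path below (a hypothetical example) points to one of the files produced above.

```python
from src.dataset import create_gru_dataset

ds = create_gru_dataset(batch_size=16,
                        dataset_path="/path/multi30k/mindrecord/multi30k_train_mindrecord_0")
print("number of batches:", ds.get_dataset_size())
for batch in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
    print(batch["source_ids"].shape, batch["target_ids"].shape, batch["teacher_force"].shape)
    break
```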
## [Configuration File](#content)
Parameters for both training and evaluation can be set in config.py. All datasets use the same parameter names; the values can be changed as needed (a sketch of overriding them in code follows the list below).
- Network Parameters
```text
"batch_size": 16, # batch size of input dataset.
"src_vocab_size": 8154, # source dataset vocabulary size.
"trg_vocab_size": 6113, # target dataset vocabulary size.
"encoder_embedding_size": 256, # encoder embedding size.
"decoder_embedding_size": 256, # decoder embedding size.
"hidden_size": 512, # hidden size of gru.
"max_length": 32, # max sentence length.
"num_epochs": 30, # total epoch.
"save_checkpoint": True, # whether save checkpoint file.
"ckpt_epoch": 1, # frequence to save checkpoint file.
"target_file": "target.txt", # the target file.
"output_file": "output.txt", # the output file.
"keep_checkpoint_max": 30, # the maximum number of checkpoint file.
"base_lr": 0.001, # init learning rate.
"warmup_step": 300, # warmup step.
"momentum": 0.9, # momentum in optimizer.
"init_loss_scale_value": 1024, # init scale sense.
'scale_factor': 2, # scale factor for dynamic loss scale.
'scale_window': 2000, # scale window for dynamic loss scale.
"warmup_ratio": 1/3.0, # warmup ratio.
"teacher_force_ratio": 0.5 # teacher force ratio.
```
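Because the parameters live in src/config.py as an EasyDict, they can also be adjusted in code before training; the snippet below is only a sketch and assumes the src package is importable.

```python
from src.config import config

config.num_epochs = 5     # e.g. a shorter smoke-test run
config.base_lr = 5e-4
print(config.batch_size, config.num_epochs, config.base_lr)
```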
## [Training Process](#content)
- Start task training on a single device and run the shell script
```bash
cd ./scripts
sh run_standalone_train.sh [DATASET_PATH]
```
- Run the script for distributed training of GRU. To train on multiple devices, run the following command in `scripts/`:
``` bash
cd ./scripts
sh run_distributed_train.sh [RANK_TABLE_PATH] [DATASET_PATH]
```
## [Inference Process](#content)
- Run the script for evaluation of GRU. The command is as follows.
``` bash
cd ./scripts
sh run_eval.sh [CKPT_FILE] [DATASET_PATH]
```
- After evaluation, we get eval/target.txt and eval/output.txt. Then we can use scripts/parse_output.sh to recover the translations.
``` bash
cp eval/*.txt ./
sh parse_output.sh target.txt output.txt /path/vocab.en
```
- After parsing the output, we get target.txt.forbleu and output.txt.forbleu. To calculate the BLEU score, you may use this [perl script](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) and run the following command (an NLTK alternative is sketched after it).
```bash
perl multi-bleu.perl target.txt.forbleu < output.txt.forbleu
```
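If Perl is not available, a rough BLEU estimate can be computed with NLTK instead. This is only a sketch and assumes the nltk package is installed; the smoothed score will not exactly match multi-bleu.perl.

```python
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

with open("target.txt.forbleu", encoding="utf-8") as f_ref, \
        open("output.txt.forbleu", encoding="utf-8") as f_hyp:
    references = [[line.split()] for line in f_ref]   # one reference per hypothesis
    hypotheses = [line.split() for line in f_hyp]

score = corpus_bleu(references, hypotheses,
                    smoothing_function=SmoothingFunction().method1)
print("BLEU: {:.2f}".format(score * 100))
```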
Note: The `DATASET_PATH` is the path to the MindRecord files, e.g. /dataset_path/*.mindrecord
# [Model Description](#content)
## [Performance](#content)
### Training Performance
| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource | Ascend 910 |
| uploaded Date | 01/18/2021 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Multi30k Dataset |
| Training Parameters | epoch=30, batch_size=16 |
| Optimizer | Adam |
| Loss Function | NLLLoss |
| outputs | probability |
| Speed | 50ms/step (1pcs) |
| Epoch Time | 13.4s (1pcs) |
| Loss | 2.5984 |
| Params (M) | 21 |
| Checkpoint for inference | 272M (.ckpt file) |
| Scripts | [gru](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gru) |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Resource | Ascend 910 |
| Uploaded Date | 01/18/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Multi30K |
| batch_size | 1 |
| outputs | label index |
| Accuracy | BLEU: 30.30 |
| Model for inference | 272M (.ckpt file) |
# [Random Situation Description](#content)
There is only one source of randomness:
- Initialization of some model weights.
Some seeds have already been set in train.py to avoid the randomness of weight initialization.
# [Others](#others)
This model has been validated in the Ascend environment only; it has not been validated on CPU or GPU.
# [ModelZoo HomePage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)

View File

@ -0,0 +1,83 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU evaluation script."""
import argparse
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import context
from src.dataset import create_gru_dataset
from src.seq2seq import Seq2Seq
from src.gru_for_infer import GRUInferCell
from src.config import config
def run_gru_eval():
"""
GRU evaluation.
"""
parser = argparse.ArgumentParser(description='GRU eval')
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented, default is Ascend")
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0')
parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1')
parser.add_argument('--ckpt_file', type=str, default="", help='ckpt file path')
parser.add_argument("--dataset_path", type=str, default="",
help="Dataset path.")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
device_id=args.device_id, save_graphs=False)
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
dataset_size = dataset.get_dataset_size()
print("dataset size is {}".format(dataset_size))
network = Seq2Seq(config, is_training=False)
network = GRUInferCell(network)
network.set_train(False)
if args.ckpt_file != "":
parameter_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(network, parameter_dict)
model = Model(network)
predictions = []
source_sents = []
target_sents = []
eval_text_len = 0
for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
source_sents.append(batch["source_ids"])
target_sents.append(batch["target_ids"])
source_ids = Tensor(batch["source_ids"], mstype.int32)
target_ids = Tensor(batch["target_ids"], mstype.int32)
predicted_ids = model.predict(source_ids, target_ids)
print("predicts is ", predicted_ids.asnumpy())
print("target_ids is ", target_ids)
predictions.append(predicted_ids.asnumpy())
eval_text_len = eval_text_len + 1
f_output = open(config.output_file, 'w')
f_target = open(config.target_file, "w")
for batch_out, true_sentence in zip(predictions, target_sents):
for i in range(config.eval_batch_size):
target_ids = [str(x) for x in true_sentence[i].tolist()]
f_target.write(" ".join(target_ids) + "\n")
token_ids = [str(x) for x in batch_out[i].tolist()]
f_output.write(" ".join(token_ids) + "\n")
f_output.close()
f_target.close()
if __name__ == "__main__":
run_gru_eval()

View File

@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh create_dataset.sh DATASET_PATH OUTPUT_PATH"
echo "for example: sh create_dataset.sh /path/multi30k/ /path/multi30k/mindrecord/"
echo "DATASET_PATH is the directory that contains the tokenized Multi30k files and the vocab files"
echo "It is better to use absolute path."
echo "=============================================================================================================="
ulimit -u unlimited
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not valid"
exit 1
fi
OUTPUT_PATH=$(get_real_path $2)
echo $OUTPUT_PATH
if [ ! -d $OUTPUT_PATH ]
then
echo "error: OUTPUT_PATH=$OUTPUT_PATH is not valid"
exit 1
fi
paste $DATASET_PATH/train.de.tok $DATASET_PATH/train.en.tok > $DATASET_PATH/train.all
python ../src/create_data.py --input_file $DATASET_PATH/train.all --num_splits 8 --src_vocab_file $DATASET_PATH/vocab.de --trg_vocab_file $DATASET_PATH/vocab.en --output_file $OUTPUT_PATH/multi30k_train_mindrecord --max_seq_length 32 --bucket [32]
paste $DATASET_PATH/test.de.tok $DATASET_PATH/test.en.tok > $DATASET_PATH/test.all
python ../src/create_data.py --input_file $DATASET_PATH/test.all --num_splits 1 --src_vocab_file $DATASET_PATH/vocab.de --trg_vocab_file $DATASET_PATH/vocab.en --output_file $OUTPUT_PATH/multi30k_test_mindrecord --max_seq_length 32 --bucket [32]

View File

@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh parse_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE"
echo "for example: sh parse_output.sh target.txt output.txt vocab.en"
echo "It is better to use absolute path."
echo "=============================================================================================================="
ref_data=$1
eval_output=$2
vocab_file=$3
cat $ref_data \
| python ../src/parse_output.py --vocab_file $vocab_file \
| sed 's/@@ //g' > ${ref_data}.forbleu
cat $eval_output \
| python ../src/parse_output.py --vocab_file $vocab_file \
| sed 's/@@ //g' > ${eval_output}.forbleu

View File

@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATASET_DIR=$1
python ../src/preprocess.py --dataset_path=$DATASET_DIR

View File

@ -0,0 +1,68 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
DATASET_PATH=$(get_real_path $2)
echo $DATASET_PATH
if [ ! -f $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$DATASET_PATH &> log &
cd ..
done

View File

@ -0,0 +1,58 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
echo "Usage: sh run_eval.sh [CKPT_FILE] [DATASET_PATH]"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
CKPT_FILE=$(get_real_path $1)
echo $CKPT_FILE
if [ ! -f $CKPT_FILE ]
then
echo "error: CKPT_FILE=$CKPT_FILE is not a file"
exit 1
fi
DATASET_PATH=$(get_real_path $2)
echo $DATASET_PATH
if [ ! -f $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
exit 1
fi
rm -rf ./eval
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
echo "start eval for device $DEVICE_ID"
env > env.log
python eval.py --ckpt_file=$CKPT_FILE --dataset_path=$DATASET_PATH &> log &
cd ..

View File

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 1 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
if [ ! -f $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
exit 1
fi
rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --dataset_path=$DATASET_PATH &> log &
cd ..

View File

@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU config"""
from easydict import EasyDict
config = EasyDict({
"batch_size": 16,
"eval_batch_size": 1,
"src_vocab_size": 8154,
"trg_vocab_size": 6113,
"encoder_embedding_size": 256,
"decoder_embedding_size": 256,
"hidden_size": 512,
"max_length": 32,
"num_epochs": 30,
"save_checkpoint": True,
"ckpt_epoch": 10,
"target_file": "target.txt",
"output_file": "output.txt",
"keep_checkpoint_max": 30,
"base_lr": 0.001,
"warmup_step": 300,
"momentum": 0.9,
"init_loss_scale_value": 1024,
'scale_factor': 2,
'scale_window': 2000,
"warmup_ratio": 1/3.0,
"teacher_force_ratio": 0.5
})

View File

@ -0,0 +1,196 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create training instances for GRU."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import collections
import logging
import numpy as np
import tokenization
from mindspore.mindrecord import FileWriter
class SampleInstance():
"""A single sample instance (sentence pair)."""
def __init__(self, source_tokens, target_tokens):
self.source_tokens = source_tokens
self.target_tokens = target_tokens
def __str__(self):
s = ""
s += "source_tokens: %s\n" % (" ".join(
[tokenization.convert_to_printable(x) for x in self.source_tokens]))
s += "target tokens: %s\n" % (" ".join(
[tokenization.convert_to_printable(x) for x in self.target_tokens]))
s += "\n"
return s
def __repr__(self):
return self.__str__()
def get_instance_features(instance, tokenizer_src, tokenizer_trg, max_seq_length, bucket):
"""Get features from `SampleInstance`s."""
def _find_bucket_length(source_tokens, target_tokens):
source_ids = tokenizer_src.convert_tokens_to_ids(source_tokens)
target_ids = tokenizer_trg.convert_tokens_to_ids(target_tokens)
num = max(len(source_ids), len(target_ids))
assert num <= bucket[-1]
for index in range(1, len(bucket)):
if bucket[index - 1] < num <= bucket[index]:
return bucket[index]
return bucket[0]
def _convert_ids_and_mask(tokenizer, input_tokens, seq_max_bucket_length):
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
input_mask = [1] * len(input_ids)
assert len(input_ids) <= max_seq_length
while len(input_ids) < seq_max_bucket_length:
input_ids.append(1)
input_mask.append(0)
assert len(input_ids) == seq_max_bucket_length
assert len(input_mask) == seq_max_bucket_length
return input_ids, input_mask
seq_max_bucket_length = _find_bucket_length(instance.source_tokens, instance.target_tokens)
source_ids, source_mask = _convert_ids_and_mask(tokenizer_src, instance.source_tokens, seq_max_bucket_length)
target_ids, target_mask = _convert_ids_and_mask(tokenizer_trg, instance.target_tokens, seq_max_bucket_length)
features = collections.OrderedDict()
features["source_ids"] = np.asarray(source_ids)
features["source_mask"] = np.asarray(source_mask)
features["target_ids"] = np.asarray(target_ids)
features["target_mask"] = np.asarray(target_mask)
return features, seq_max_bucket_length
def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len):
"""Creates `SampleInstance`s for a single sentence pair."""
EOS = "<eos>"
SOS = "<sos>"
if len(source_words) >= max_seq_length-1 or len(target_words) >= max_seq_length-1:
if clip_to_max_len:
source_words = source_words[:min(len(source_words), max_seq_length-2)]
target_words = target_words[:min(len(target_words), max_seq_length-2)]
else:
return None
source_tokens = [SOS] + source_words + [EOS]
target_tokens = [SOS] + target_words + [EOS]
instance = SampleInstance(
source_tokens=source_tokens,
target_tokens=target_tokens)
return instance
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, required=True,
help='Input raw text file (or comma-separated list of files).')
parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
parser.add_argument("--num_splits", type=int, default=16,
help='The MindRecord file will be split into this number of partitions.')
parser.add_argument("--src_vocab_file", type=str, required=True,
help='The source-language vocabulary file.')
parser.add_argument("--trg_vocab_file", type=str, required=True,
help='The target-language vocabulary file.')
parser.add_argument("--clip_to_max_len", type=ast.literal_eval, default=False,
help='clip sequences to maximum sequence length.')
parser.add_argument("--max_seq_length", type=int, default=32, help='Maximum sequence length.')
parser.add_argument("--bucket", type=ast.literal_eval, default=[32],
help='bucket sequence length')
args = parser.parse_args()
tokenizer_src = tokenization.WhiteSpaceTokenizer(vocab_file=args.src_vocab_file)
tokenizer_trg = tokenization.WhiteSpaceTokenizer(vocab_file=args.trg_vocab_file)
input_files = []
for input_pattern in args.input_file.split(","):
input_files.append(input_pattern)
logging.info("*** Read from input files ***")
output_file = args.output_file
logging.info("*** Write to output files ***")
logging.info(" %s", output_file)
total_written = 0
total_read = 0
feature_dict = {}
for i in args.bucket:
feature_dict[i] = []
for input_file in input_files:
logging.info("*** Reading from %s ***", input_file)
with open(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
total_read += 1
if total_read % 100000 == 0:
logging.info("Read %d ...", total_read)
if line.strip() == "":
continue
source_line, target_line = line.strip().split("\t")
source_tokens = tokenizer_src.tokenize(source_line)
target_tokens = tokenizer_trg.tokenize(target_line)
if len(source_tokens) >= args.max_seq_length or len(target_tokens) >= args.max_seq_length:
logging.info("ignore long sentence!")
continue
instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
clip_to_max_len=args.clip_to_max_len)
if instance is None:
continue
features, seq_max_bucket_length = get_instance_features(instance, tokenizer_src, tokenizer_trg,
args.max_seq_length, args.bucket)
for key in feature_dict:
if key == seq_max_bucket_length:
feature_dict[key].append(features)
if total_read <= 10:
logging.info("*** Example ***")
logging.info("source tokens: %s", " ".join(
[tokenization.convert_to_printable(x) for x in instance.source_tokens]))
logging.info("target tokens: %s", " ".join(
[tokenization.convert_to_printable(x) for x in instance.target_tokens]))
for feature_name in features.keys():
feature = features[feature_name]
logging.info("%s: %s", feature_name, feature)
for i in args.bucket:
if args.num_splits == 1:
output_file_name = output_file + '_' + str(i)
else:
output_file_name = output_file + '_' + str(i) + '_'
writer = FileWriter(output_file_name, args.num_splits)
data_schema = {"source_ids": {"type": "int64", "shape": [-1]},
"source_mask": {"type": "int64", "shape": [-1]},
"target_ids": {"type": "int64", "shape": [-1]},
"target_mask": {"type": "int64", "shape": [-1]}
}
writer.add_schema(data_schema, "gru")
features_ = feature_dict[i]
logging.info("Bucket length %d has %d samples, start writing...", i, len(features_))
for item in features_:
writer.write_raw_data([item])
total_written += 1
writer.commit()
logging.info("Wrote %d total instances", total_written)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

View File

@ -0,0 +1,48 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, will be used in train.py."""
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as deC
from src.config import config
import numpy as np
de.config.set_seed(1)
def random_teacher_force(source_ids, target_ids, target_mask):
teacher_force = np.random.random() < config.teacher_force_ratio
teacher_force_array = np.array([teacher_force], dtype=bool)
return source_ids, target_ids, teacher_force_array
def create_gru_dataset(epoch_count=1, batch_size=1, rank_size=1, rank_id=0, do_shuffle=True, dataset_path=None,
is_training=True):
"""create dataset"""
ds = de.MindDataset(dataset_path,
columns_list=["source_ids", "target_ids",
"target_mask"],
shuffle=do_shuffle, num_parallel_workers=10, num_shards=rank_size, shard_id=rank_id)
operations = random_teacher_force
ds = ds.map(operations=operations, input_columns=["source_ids", "target_ids", "target_mask"],
output_columns=["source_ids", "target_ids", "teacher_force"],
column_order=["source_ids", "target_ids", "teacher_force"])
type_cast_op = deC.TypeCast(mstype.int32)
type_cast_op_bool = deC.TypeCast(mstype.bool_)
ds = ds.map(operations=type_cast_op, input_columns="source_ids")
ds = ds.map(operations=type_cast_op, input_columns="target_ids")
ds = ds.map(operations=type_cast_op_bool, input_columns="teacher_force")
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(1)
return ds

View File

@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU cell"""
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from src.weight_init import gru_default_state
class BidirectionGRU(nn.Cell):
'''
BidirectionGRU model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(BidirectionGRU, self).__init__()
if is_training:
self.batch_size = config.batch_size
else:
self.batch_size = config.eval_batch_size
self.embedding_size = config.encoder_embedding_size
self.hidden_size = config.hidden_size
self.weight_i, self.weight_h, self.bias_i, self.bias_h, self.init_h = gru_default_state(self.batch_size,
self.embedding_size,
self.hidden_size)
self.weight_bw_i, self.weight_bw_h, self.bias_bw_i, self.bias_bw_h, self.init_bw_h = \
gru_default_state(self.batch_size, self.embedding_size, self.hidden_size)
self.reverse = P.ReverseV2(axis=[1])
self.concat = P.Concat(axis=2)
self.squeeze = P.Squeeze(axis=0)
self.rnn = P.DynamicGRUV2()
self.text_len = config.max_length
self.cast = P.Cast()
def construct(self, x):
'''
BidirectionGRU construction
Args:
x(Tensor): BidirectionGRU input
Returns:
output(Tensor): rnn output
hidden(Tensor): hidden state
'''
x = self.cast(x, mstype.float16)
y1, _, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, self.init_h)
bw_x = self.reverse(x)
y1_bw, _, _, _, _, _ = self.rnn(bw_x, self.weight_bw_i,
self.weight_bw_h, self.bias_bw_i, self.bias_bw_h, None, self.init_bw_h)
y1_bw = self.reverse(y1_bw)
output = self.concat((y1, y1_bw))
hidden = self.concat((y1[self.text_len-1:self.text_len:1, ::, ::],
y1_bw[self.text_len-1:self.text_len:1, ::, ::]))
hidden = self.squeeze(hidden)
return output, hidden
class GRU(nn.Cell):
'''
GRU model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(GRU, self).__init__()
if is_training:
self.batch_size = config.batch_size
else:
self.batch_size = config.eval_batch_size
self.embedding_size = config.encoder_embedding_size
self.hidden_size = config.hidden_size
self.weight_i, self.weight_h, self.bias_i, self.bias_h, self.init_h = \
gru_default_state(self.batch_size, self.embedding_size + self.hidden_size*2, self.hidden_size)
self.rnn = P.DynamicGRUV2()
self.cast = P.Cast()
def construct(self, x):
'''
GRU construction
Args:
x(Tensor): GRU input
Returns:
output(Tensor): rnn output
hidden(Tensor): hidden state
'''
x = self.cast(x, mstype.float16)
y1, h1, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, self.init_h)
return y1, h1

View File

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU Infer cell"""
import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from src.config import config
class GRUInferCell(nn.Cell):
'''
GRU infer construction
Args:
network: gru network
'''
def __init__(self, network):
super(GRUInferCell, self).__init__(auto_prefix=False)
self.network = network
self.argmax = P.ArgMaxWithValue(axis=2)
self.transpose = P.Transpose()
self.teacher_force = Tensor(np.zeros((config.eval_batch_size)), mstype.bool_)
def construct(self,
encoder_inputs,
decoder_inputs):
predict_probs = self.network(encoder_inputs, decoder_inputs, self.teacher_force)
predict_probs = self.transpose(predict_probs, (1, 0, 2))
predict_ids, _ = self.argmax(predict_probs)
return predict_ids

View File

@ -0,0 +1,243 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU train cell"""
from mindspore import Tensor, Parameter, ParameterTuple, context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import mindspore.common.dtype as mstype
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from src.config import config
from src.loss import NLLLoss
class GRUWithLossCell(nn.Cell):
"""
GRU network connect with loss function.
Args:
network: The training network.
Returns:
the output of loss function.
"""
def __init__(self, network):
super(GRUWithLossCell, self).__init__()
self.network = network
self.loss = NLLLoss()
self.logits_shape = (-1, config.src_vocab_size)
self.reshape = P.Reshape()
self.cast = P.Cast()
self.mean = P.ReduceMean()
self.text_len = config.max_length
self.split = P.Split(axis=0, output_num=config.max_length-1)
self.squeeze = P.Squeeze()
self.add = P.AddN()
self.transpose = P.Transpose()
self.shape = P.Shape()
def construct(self, encoder_inputs, decoder_inputs, teacher_force):
'''
GRU loss cell
Args:
encoder_inputs(Tensor): encoder inputs
decoder_inputs(Tensor): decoder inputs
teacher_force(Tensor): teacher force flag
Returns:
loss(scalar): loss output
'''
logits = self.network(encoder_inputs, decoder_inputs, teacher_force)
logits = self.cast(logits, mstype.float32)
loss_total = ()
decoder_targets = decoder_inputs
decoder_output = logits
for i in range(1, self.text_len):
loss = self.loss(self.squeeze(decoder_output[i-1:i:1, ::, ::]), decoder_targets[:, i])
loss_total += (loss,)
loss = self.add(loss_total) / self.text_len
return loss
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
class ClipGradients(nn.Cell):
"""
Clip gradients.
Args:
grads (list): List of gradient tuples.
clip_type (Tensor): The way to clip, 'value' or 'norm'.
clip_value (Tensor): Specifies how much to clip.
Returns:
List, a list of clipped_grad tuples.
"""
def __init__(self):
super(ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.cast = P.Cast()
self.dtype = P.DType()
def construct(self,
grads,
clip_type,
clip_value):
"""Defines the gradients clip."""
if clip_type not in (0, 1):
return grads
new_grads = ()
for grad in grads:
dt = self.dtype(grad)
if clip_type == 0:
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
self.cast(F.tuple_to_array((clip_value,)), dt))
else:
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
t = self.cast(t, dt)
new_grads = new_grads + (t,)
return new_grads
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)
class GRUTrainOneStepWithLossScaleCell(nn.Cell):
"""
Encapsulation class of GRU network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(GRUTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode not in ParallelMode.MODE_LIST:
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.clip_gradients = ClipGradients()
self.cast = P.Cast()
if context.get_context("device_target") == "GPU":
self.gpu_target = True
self.float_status = P.FloatStatus()
self.addn = P.AddN()
self.reshape = P.Reshape()
else:
self.gpu_target = False
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
@C.add_flags(has_effect=True)
def construct(self,
encoder_inputs,
decoder_inputs,
teacher_force,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(encoder_inputs,
decoder_inputs,
teacher_force)
init = False
if not self.gpu_target:
# alloc status
init = self.alloc_status()
# clear overflow buffer
self.clear_before_grad(init)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
grads = self.grad(self.network, weights)(encoder_inputs,
decoder_inputs,
teacher_force,
self.cast(scaling_sens,
mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
if not self.gpu_target:
self.get_status(init)
# sum overflow buffer elements, 0: not overflow, >0: overflow
flag_sum = self.reduce_sum(init, (0,))
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
# convert flag_sum to scalar
flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NLLLoss cell"""
import mindspore.ops.operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
class NLLLoss(_Loss):
'''
NLLLoss function
'''
def __init__(self, reduction='mean'):
super(NLLLoss, self).__init__(reduction)
self.one_hot = P.OneHot()
self.reduce_sum = P.ReduceSum()
def construct(self, logits, label):
label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
return self.get_loss(loss)

View File

@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for GRU"""
import math
def rsqrt_decay(warmup_steps, current_step):
return float(max([current_step, warmup_steps])) ** -0.5
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, total_steps):
decay_steps = total_steps - warmup_steps
linear_decay = (total_steps - current_step) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * current_step / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
learning_rate = decayed * base_lr
return learning_rate
def dynamic_lr(config, base_step):
"""dynamic learning rate generator"""
base_lr = config.base_lr
total_steps = int(base_step * config.num_epochs)
warmup_steps = int(config.warmup_step)
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr

View File

@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ids to tokens."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tokenization
# Explicitly set the encoding
sys.stdin = open(sys.stdin.fileno(), mode='r', encoding='utf-8', buffering=True)
sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=True)
def main():
parser = argparse.ArgumentParser(
description="Convert token ids in the eval output back to tokens.")
parser.add_argument("--vocab_file", type=str, default="", required=True, help="vocab file path.")
args = parser.parse_args()
tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)
for line in sys.stdin:
token_ids = [int(x) for x in line.strip().split()]
tokens = tokenizer.convert_ids_to_tokens(token_ids)
sent = " ".join(tokens)
sent = sent.split("<sos>")[-1]
sent = sent.split("<eos>")[0]
print(sent.strip())
if __name__ == "__main__":
main()

View File

@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''Dataset preprocess'''
import os
import argparse
from collections import Counter
from nltk.tokenize import word_tokenize
def create_tokenized_sentences(input_files, language):
'''
Create tokenized sentences files.
Args:
input_files: input files.
language: text language
'''
sentence = []
total_lines = open(input_files, "r").read().splitlines()
for line in total_lines:
line = line.strip('\r\n ')
line = line.lower()
tokenize_sentence = word_tokenize(line, language)
str_sentence = " ".join(tokenize_sentence)
sentence.append(str_sentence)
tokenize_file = input_files + ".tok"
f = open(tokenize_file, "w")
for line in sentence:
f.write(line)
f.write("\n")
f.close()
def get_dataset_vocab(text_file, vocab_file):
'''
Create dataset vocab files.
Args:
text_file: dataset text files.
vocab_file: vocab file
'''
counter = Counter()
text_lines = open(text_file, "r").read().splitlines()
for line in text_lines:
for word in line.strip('\r\n ').split(' '):
if word:
counter[word] += 1
vocab = open(vocab_file, "w")
basic_label = ["<unk>", "<pad>", "<sos>", "<eos>"]
for label in basic_label:
vocab.write(label + "\n")
for key, f in sorted(counter.items(), key=lambda x: x[1], reverse=True):
if f < 2:
continue
vocab.write(key + "\n")
vocab.close()
def MergeText(root_dir, file_list, output_file):
'''
Merge text files together.
Args:
root_dir: root dir
file_list: dataset files list.
output_file: output file after merge
'''
output_file = os.path.join(root_dir, output_file)
f_output = open(output_file, "w")
for file_name in file_list:
text_path = os.path.join(root_dir, file_name) + ".tok"
f = open(text_path)
f_output.write(f.read() + "\n")
f_output.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='gru_dataset')
parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.")
args = parser.parse_args()
dataset_path = args.dataset_path
src_file_list = ["train.de", "test.de", "val.de"]
dst_file_list = ["train.en", "test.en", "val.en"]
for file in src_file_list:
file_path = os.path.join(dataset_path, file)
create_tokenized_sentences(file_path, "english")
for file in dst_file_list:
file_path = os.path.join(dataset_path, file)
create_tokenized_sentences(file_path, "german")
src_all_file = "all.de.tok"
dst_all_file = "all.en.tok"
MergeText(dataset_path, src_file_list, src_all_file)
MergeText(dataset_path, dst_file_list, dst_all_file)
src_vocab = os.path.join(dataset_path, "vocab.de")
dst_vocab = os.path.join(dataset_path, "vocab.en")
get_dataset_vocab(os.path.join(dataset_path, src_all_file), src_vocab)
get_dataset_vocab(os.path.join(dataset_path, dst_all_file), dst_vocab)

View File

@ -0,0 +1,223 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Seq2Seq construction"""
import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from src.gru import BidirectionGRU, GRU
from src.weight_init import dense_default_state
class Attention(nn.Cell):
'''
Attention model
'''
def __init__(self, config):
super(Attention, self).__init__()
self.text_len = config.max_length
self.attn = nn.Dense(in_channels=config.hidden_size * 3,
out_channels=config.hidden_size).to_float(mstype.float16)
self.fc = nn.Dense(config.hidden_size, 1, has_bias=False).to_float(mstype.float16)
self.expandims = P.ExpandDims()
self.tanh = P.Tanh()
self.softmax = P.Softmax()
self.tile = P.Tile()
self.transpose = P.Transpose()
self.concat = P.Concat(axis=2)
self.squeeze = P.Squeeze(axis=2)
self.cast = P.Cast()
def construct(self, hidden, encoder_outputs):
'''
Attention construction
Args:
hidden(Tensor): hidden state
encoder_outputs(Tensor): the output of encoder
Returns:
Tensor, attention output
'''
hidden = self.expandims(hidden, 1)
hidden = self.tile(hidden, (1, self.text_len, 1))
encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
out = self.concat((hidden, encoder_outputs))
out = self.attn(out)
energy = self.tanh(out)
attention = self.fc(energy)
attention = self.squeeze(attention)
attention = self.cast(attention, mstype.float32)
attention = self.softmax(attention)
attention = self.cast(attention, mstype.float16)
return attention
class Encoder(nn.Cell):
'''
Encoder model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(Encoder, self).__init__()
self.hidden_size = config.hidden_size
self.vocab_size = config.src_vocab_size
self.embedding_size = config.encoder_embedding_size
self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
self.rnn = BidirectionGRU(config, is_training=is_training).to_float(mstype.float16)
self.fc = nn.Dense(2*self.hidden_size, self.hidden_size).to_float(mstype.float16)
self.shape = P.Shape()
self.transpose = P.Transpose()
self.p = P.Print()
self.cast = P.Cast()
self.text_len = config.max_length
self.squeeze = P.Squeeze(axis=0)
self.tanh = P.Tanh()
def construct(self, src):
'''
Encoder construction
Args:
src(Tensor): source sentences
Returns:
            output(Tensor): output sequence of the bidirectional GRU
            hidden(Tensor): final hidden state after the fc and tanh projection
'''
embedded = self.embedding(src)
embedded = self.transpose(embedded, (1, 0, 2))
embedded = self.cast(embedded, mstype.float16)
output, hidden = self.rnn(embedded)
hidden = self.fc(hidden)
hidden = self.tanh(hidden)
return output, hidden
class Decoder(nn.Cell):
'''
Decoder model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(Decoder, self).__init__()
self.hidden_size = config.hidden_size
self.vocab_size = config.trg_vocab_size
self.embedding_size = config.decoder_embedding_size
self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
self.rnn = GRU(config, is_training=is_training).to_float(mstype.float16)
self.text_len = config.max_length
self.shape = P.Shape()
self.transpose = P.Transpose()
self.p = P.Print()
self.cast = P.Cast()
self.concat = P.Concat(axis=2)
self.squeeze = P.Squeeze(axis=0)
self.expandims = P.ExpandDims()
self.log_softmax = P.LogSoftmax(axis=1)
weight, bias = dense_default_state(self.embedding_size+self.hidden_size*3, self.vocab_size)
self.fc = nn.Dense(self.embedding_size+self.hidden_size*3, self.vocab_size,
weight_init=weight, bias_init=bias).to_float(mstype.float16)
self.attention = Attention(config)
self.bmm = P.BatchMatMul()
        self.dropout = nn.Dropout(0.7)  # keep_prob=0.7, i.e. 30% of activations are dropped
def construct(self, inputs, hidden, encoder_outputs):
'''
Decoder construction
Args:
inputs(Tensor): decoder input
hidden(Tensor): hidden state
encoder_outputs(Tensor): encoder output
Returns:
            pred_prob(Tensor): predicted log-probabilities over the target vocabulary
hidden(Tensor): hidden state
'''
embedded = self.embedding(inputs)
embedded = self.transpose(embedded, (1, 0, 2))
embedded = self.cast(embedded, mstype.float16)
attn = self.attention(hidden, encoder_outputs)
attn = self.expandims(attn, 1)
encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
weight = self.bmm(attn, encoder_outputs)
weight = self.transpose(weight, (1, 0, 2))
emd_con = self.concat((embedded, weight))
output, hidden = self.rnn(emd_con)
out = self.concat((embedded, output, weight))
out = self.squeeze(out)
hidden = self.squeeze(hidden)
prediction = self.fc(out)
prediction = self.dropout(prediction)
        prediction = self.cast(prediction, mstype.float32)
pred_prob = self.log_softmax(prediction)
pred_prob = self.expandims(pred_prob, 0)
return pred_prob, hidden
class Seq2Seq(nn.Cell):
'''
Seq2Seq model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(Seq2Seq, self).__init__()
if is_training:
self.batch_size = config.batch_size
else:
self.batch_size = config.eval_batch_size
self.encoder = Encoder(config, is_training=is_training)
self.decoder = Decoder(config, is_training=is_training)
self.expandims = P.ExpandDims()
self.dropout = nn.Dropout()
self.shape = P.Shape()
self.concat = P.Concat(axis=0)
self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
self.squeeze = P.Squeeze(axis=0)
        self.sos = Tensor(np.ones((self.batch_size, 1))*2, mstype.int32)  # token id 2 is used as <sos> for every sample
self.select = P.Select()
self.text_len = config.max_length
def construct(self, encoder_inputs, decoder_inputs, teacher_force):
'''
Seq2Seq construction
Args:
encoder_inputs(Tensor): encoder input sentences
decoder_inputs(Tensor): decoder input sentences
teacher_force(Tensor): teacher force flag
Returns:
            outputs(Tensor): predicted log-probabilities for all decoding steps
'''
decoder_input = self.sos
encoder_output, hidden = self.encoder(encoder_inputs)
decoder_hidden = hidden
decoder_outputs = ()
for i in range(1, self.text_len):
decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_output)
decoder_outputs += (decoder_output,)
if self.training:
decoder_input_force = decoder_inputs[::, i:i+1]
decoder_input_top1, _ = self.argmax(self.squeeze(decoder_output))
decoder_input = self.select(teacher_force, decoder_input_force, decoder_input_top1)
else:
decoder_input, _ = self.argmax(self.squeeze(decoder_output))
outputs = self.concat(decoder_outputs)
return outputs
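# Rough shape sketch (illustrative; hidden_size, max_length, batch_size and the vocab
# sizes come from the config object, and BidirectionGRU/GRU are defined in src/gru.py):
#
#   encoder_inputs : (batch_size, max_length)                      int32 token ids
#   encoder_output : (max_length, batch_size, 2 * hidden_size)     bidirectional states
#   hidden         : (batch_size, hidden_size)                     after fc + tanh
#   outputs        : (max_length - 1, batch_size, trg_vocab_size)  per-step log-probs
#
# A minimal smoke test, assuming the usual project layout:
#
#   from src.config import config
#   net = Seq2Seq(config, is_training=False)
#   net.set_train(False)
#   ids = Tensor(np.zeros((config.eval_batch_size, config.max_length), np.int32))
#   log_probs = net(ids, ids, Tensor(False, mstype.bool_))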

@@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tokenization utilities."""
import sys
import collections
import unicodedata
def convert_to_printable(text):
"""
Converts `text` to a printable coding format.
"""
if sys.version_info[0] == 3:
if isinstance(text, str):
return text
if isinstance(text, bytes):
return text.decode("utf-8", "ignore")
raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
raise ValueError("Only supported when running on Python3.")
def convert_to_unicode(text):
"""
Converts `text` to Unicode format.
"""
if sys.version_info[0] == 3:
if isinstance(text, str):
return text
if isinstance(text, bytes):
return text.decode("utf-8", "ignore")
raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
if sys.version_info[0] == 2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
if isinstance(text, unicode):
return text
raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
raise ValueError("Only supported when running on Python2 or Python3.")
def load_vocab_file(vocab_file):
"""
Loads a vocabulary file and turns into a {token:id} dictionary.
"""
vocab_dict = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as vocab:
while True:
token = convert_to_unicode(vocab.readline())
if not token:
break
token = token.strip()
vocab_dict[token] = index
index += 1
return vocab_dict
def convert_by_vocab_dict(vocab_dict, items):
"""
Converts a sequence of [tokens|ids] according to the vocab dict.
"""
output = []
for item in items:
if item in vocab_dict:
output.append(vocab_dict[item])
else:
output.append(vocab_dict["<unk>"])
return output
class WhiteSpaceTokenizer():
"""
Whitespace tokenizer.
"""
def __init__(self, vocab_file):
self.vocab_dict = load_vocab_file(vocab_file)
self.inv_vocab_dict = {index: token for token, index in self.vocab_dict.items()}
def _is_whitespace_char(self, char):
"""
        Checks whether a character is whitespace ("\t", "\n" and "\r" are treated as whitespace here).
"""
if char in (" ", "\t", "\n", "\r"):
return True
uni = unicodedata.category(char)
if uni == "Zs":
return True
return False
def _is_control_char(self, char):
"""
Checks if it is a control character.
"""
if char in ("\t", "\n", "\r"):
return False
uni = unicodedata.category(char)
if uni in ("Cc", "Cf"):
return True
return False
def _clean_text(self, text):
"""
Remove invalid characters and cleanup whitespace.
"""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or self._is_control_char(char):
continue
if self._is_whitespace_char(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
def _whitespace_tokenize(self, text):
"""
Clean whitespace and split text into tokens.
"""
text = text.strip()
text = text.lower()
if text.endswith("."):
text = text.replace(".", " .")
if not text:
tokens = []
else:
tokens = text.split()
return tokens
def tokenize(self, text):
"""
Tokenizes text.
"""
text = convert_to_unicode(text)
text = self._clean_text(text)
tokens = self._whitespace_tokenize(text)
return tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab_dict(self.vocab_dict, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab_dict(self.inv_vocab_dict, ids)
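# Usage sketch (illustrative; vocab.en is one of the vocabulary files produced by the
# dataset-preparation script and must contain an "<unk>" entry, which is the fallback
# used by convert_by_vocab_dict for out-of-vocabulary tokens):
#
#   tokenizer = WhiteSpaceTokenizer("vocab.en")
#   tokens = tokenizer.tokenize("Two young men are playing football.")
#   ids = tokenizer.convert_tokens_to_ids(tokens)
#   restored = tokenizer.convert_ids_to_tokens(ids)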

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""weight init"""
import math
import numpy as np
from mindspore import Tensor, Parameter
def gru_default_state(batch_size, input_size, hidden_size, num_layers=1, bidirectional=False):
'''Weight init for gru cell'''
stdv = 1 / math.sqrt(hidden_size)
weight_i = Parameter(Tensor(
np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)), name='weight_i')
weight_h = Parameter(Tensor(
np.random.uniform(-stdv, stdv, (hidden_size, 3*hidden_size)).astype(np.float32)), name='weight_h')
bias_i = Parameter(Tensor(
np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), name='bias_i')
bias_h = Parameter(Tensor(
np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), name='bias_h')
init_h = Tensor(np.zeros((batch_size, hidden_size)).astype(np.float16))
return weight_i, weight_h, bias_i, bias_h, init_h
def dense_default_state(in_channel, out_channel):
'''Weight init for dense cell'''
stdv = 1 / math.sqrt(in_channel)
weight = Tensor(np.random.uniform(-stdv, stdv, (out_channel, in_channel)).astype(np.float32))
bias = Tensor(np.random.uniform(-stdv, stdv, (out_channel)).astype(np.float32))
return weight, bias
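# Usage sketch (illustrative numbers): the GRU cells in src/gru.py and the decoder's
# output layer are expected to consume these tensors directly.
#
#   w_i, w_h, b_i, b_h, h0 = gru_default_state(batch_size=16, input_size=256, hidden_size=512)
#   # w_i: (256, 1536)   w_h: (512, 1536)   b_i / b_h: (1536,)   h0: (16, 512), float16
#
#   fc_w, fc_b = dense_default_state(in_channel=1024, out_channel=4000)
#   # matches nn.Dense(1024, 4000, weight_init=fc_w, bias_init=fc_b)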

@@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train script"""
import os
import time
import argparse
import ast
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.communication.management import init
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn.optim import Adam
from src.config import config
from src.seq2seq import Seq2Seq
from src.gru_for_train import GRUWithLossCell, GRUTrainOneStepWithLossScaleCell
from src.dataset import create_gru_dataset
from src.lr_schedule import dynamic_lr
set_seed(1)
parser = argparse.ArgumentParser(description="GRU training")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path")
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
parser.add_argument('--outputs_dir', type=str, default='./', help='Output directory under which checkpoints are saved. Default: ./')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id, save_graphs=False)
def get_ms_timestamp():
t = time.time()
return int(round(t * 1000))
time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
"""
    Monitor the loss during training.

    The loss, the overflow flag and the current loss scale are printed every step
    and appended to ./loss_{rank_id}.log.

    Args:
        per_print_times (int): Interval (in steps) for printing the loss. Default: 1.
        rank_id (int): Rank id used to name the loss log file. Default: 0.
"""
def __init__(self, per_print_times=1, rank_id=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = get_ms_timestamp()
time_stamp_init = True
def step_end(self, run_context):
"""Monitor the loss in training."""
global time_stamp_first
time_stamp_current = get_ms_timestamp()
cb_params = run_context.original_args()
print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
cb_params.cur_epoch_num,
cb_params.cur_step_num,
str(cb_params.net_outputs)))
with open("./loss_{}.log".format(self.rank_id), "a+") as f:
f.write("time: {}, epoch: {}, step: {}, loss: {}, overflow: {}, loss_scale: {}".format(
time_stamp_current - time_stamp_first,
cb_params.cur_epoch_num,
cb_params.cur_step_num,
str(cb_params.net_outputs[0].asnumpy()),
str(cb_params.net_outputs[1].asnumpy()),
str(cb_params.net_outputs[2].asnumpy())))
f.write('\n')
if __name__ == '__main__':
if args.run_distribute:
rank = args.rank_id
device_num = args.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
rank = 0
device_num = 1
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size,
dataset_path=args.dataset_path, rank_size=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size()
print("dataset size is {}".format(dataset_size))
network = Seq2Seq(config)
network = GRUWithLossCell(network)
lr = dynamic_lr(config, dataset_size)
opt = Adam(network.trainable_params(), learning_rate=lr)
scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale_value,
scale_factor=config.scale_factor,
scale_window=config.scale_window)
update_cell = scale_manager.get_update_cell()
netwithgrads = GRUTrainOneStepWithLossScaleCell(network, opt, update_cell)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)
cb = [time_cb, loss_cb]
    # Save checkpoints
if config.save_checkpoint:
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_epoch*dataset_size,
keep_checkpoint_max=config.keep_checkpoint_max)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_'+str(args.rank_id)+'/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='{}'.format(args.rank_id))
cb += [ckpt_cb]
netwithgrads.set_train(True)
model = Model(netwithgrads)
model.train(config.num_epochs, dataset, callbacks=cb, dataset_sink_mode=True)
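# Launch sketch (illustrative; the script name and paths are examples, and the dataset
# path must point at whatever create_gru_dataset expects):
#
#   # single-device training
#   python train.py --dataset_path /path/to/prepared_dataset --device_id 0
#
#   # 8-device data-parallel training (device/rank environment prepared by the caller)
#   python train.py --run_distribute True --device_num 8 --rank_id $RANK_ID \
#       --dataset_path /path/to/prepared_dataset --outputs_dir ./outputs
#
# Checkpoints are written to <outputs_dir>/ckpt_<rank_id>/ and the per-rank loss log
# to ./loss_<rank_id>.log, as set up in LossCallBack and ModelCheckpoint above.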