forked from mindspore-Ecosystem/mindspore
init for gru
fix for copyright fix for create_dataset fix for some bug fix for parser_output fix for parse output add readme and fix for some bug fix test2016 to test delete space fix some bug for create dataset fix for comments and overflow fix a wrong bug fix for weight init
This commit is contained in: parent dfa6daaa57, commit a37ad24136
@ -0,0 +1,252 @@
|
|||
![](https://www.mindspore.cn/static/img/logo.a3e472c9.png)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [GRU](#gru)
|
||||
- [Model Structure](#model-structure)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Dataset Preparation](#dataset-preparation)
|
||||
- [Configuration File](#configuration-file)
|
||||
- [Training Process](#training-process)
|
||||
- [Inference Process](#inference-process)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Training Performance](#training-performance)
|
||||
- [Inference Performance](#inference-performance)
|
||||
- [Random Situation Description](#random-situation-description)
|
||||
- [Others](#others)
|
||||
- [ModelZoo HomePage](#modelzoo-homepage)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# [GRU](#contents)
|
||||
|
||||
GRU (Gated Recurrent Unit) is a kind of recurrent neural network, similar to LSTM (Long Short-Term Memory). It was proposed by Kyunghyun Cho, Bart van Merrienboer et al. in the 2014 paper "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation", which introduces a novel model called the RNN Encoder-Decoder, consisting of two recurrent neural networks (RNNs). To improve the translation quality, we also refer to "Sequence to Sequence Learning with Neural Networks" and "Neural Machine Translation by Jointly Learning to Align and Translate".
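For reference, the core GRU update can be summarized by its gate equations. The snippet below is a minimal NumPy sketch of a single GRU step, for illustration only; the network in this repository relies on MindSpore's `DynamicGRUV2` operator instead.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h_prev, W, U, b):
    """One GRU step; W, U, b each hold the (reset, update, candidate) parameters."""
    W_r, W_z, W_n = W          # input weights, each (input_size, hidden_size)
    U_r, U_z, U_n = U          # recurrent weights, each (hidden_size, hidden_size)
    b_r, b_z, b_n = b          # biases, each (hidden_size,)
    r = sigmoid(x @ W_r + h_prev @ U_r + b_r)          # reset gate
    z = sigmoid(x @ W_z + h_prev @ U_z + b_z)          # update gate
    n = np.tanh(x @ W_n + r * (h_prev @ U_n) + b_n)    # candidate hidden state
    return (1.0 - z) * n + z * h_prev                  # new hidden state
```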
|
||||
|
||||
## Paper
|
||||
|
||||
1. [Paper](https://arxiv.org/pdf/1406.1078.pdf): "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation", 2014, Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio
|
||||
|
||||
2. [Paper](https://arxiv.org/pdf/1409.3215.pdf): "Sequence to Sequence Learning with Neural Networks", 2014, Ilya Sutskever, Oriol Vinyals, Quoc V. Le
|
||||
|
||||
3. [Paper](https://arxiv.org/pdf/1409.0473.pdf): "Neural Machine Translation by Jointly Learning to Align and Translate", 2014, Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio
|
||||
|
||||
# [Model Structure](#contents)
|
||||
|
||||
The GRU model mainly consists of an encoder and a decoder. The encoder is built with a bidirectional GRU cell, and the decoder mainly contains an attention module and a GRU cell. The input of the network is a sequence of words (a text or sentence), and the output is the probability of each word in the vocabulary; the word with the maximum probability is chosen as the prediction.
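The sketch below illustrates, in simplified Python pseudocode, how the encoder, attention and decoder cooperate during greedy decoding. The names `encoder`, `decoder` and `attention` are placeholders for illustration and do not map one-to-one onto the classes in `src/`.

```python
def translate(encoder, decoder, attention, source_ids, sos_id, eos_id, max_length=32):
    """Greedy decoding sketch: encode the source once, then emit one token per step."""
    enc_outputs, hidden = encoder(source_ids)        # bidirectional GRU over the source
    token, result = sos_id, []
    for _ in range(max_length):
        context = attention(hidden, enc_outputs)     # attend over all encoder outputs
        probs, hidden = decoder(token, hidden, context)
        token = probs.argmax()                       # take the most probable word
        if token == eos_id:
            break
        result.append(token)
    return result
```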
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
In this model, we use the Multi30K dataset for training and testing. The training set provides 29,000 sentence pairs, each consisting of a German sentence and its English translation. The test set provides 1,000 German and English sentence pairs. We also provide preprocessing scripts to tokenize the dataset and create the vocabulary files.
|
||||
|
||||
# [Environment Requirements](#content)
|
||||
|
||||
- Hardware (Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
- Framework
|
||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Quick Start](#content)
|
||||
|
||||
After dataset preparation, you can start training and evaluation as follows:
|
||||
|
||||
```bash
|
||||
# run training example
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [TRAIN_DATASET_PATH]
|
||||
|
||||
# run distributed training example
|
||||
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TRAIN_DATASET_PATH]
|
||||
|
||||
# run evaluation example
|
||||
sh run_eval.sh [CKPT_FILE] [DATASET_PATH]
|
||||
```
|
||||
|
||||
# [Script Description](#content)
|
||||
|
||||
The GRU network scripts and code structure are as follows:
|
||||
|
||||
```text
|
||||
├── gru
|
||||
├── README.md // Introduction of GRU model.
|
||||
├── src
|
||||
| ├──gru.py // gru cell architecture.
|
||||
│ ├──config.py // Configuration instance definition.
|
||||
│ ├──create_data.py // Dataset preparation.
|
||||
│ ├──dataset.py // Dataset loader to feed into model.
|
||||
│ ├──gru_for_infer.py // GRU eval model architecture.
|
||||
│ ├──gru_for_train.py // GRU train model architecture.
|
||||
│ ├──loss.py // Loss architecture.
|
||||
│ ├──lr_schedule.py // Learning rate scheduler.
|
||||
│ ├──parse_output.py // Parse output file.
|
||||
│ ├──preprocess.py // Dataset preprocess.
|
||||
│ ├──seq2seq.py // Seq2seq architecture.
|
||||
│ ├──tokenization.py // tokenization for the dataset.
|
||||
│ ├──weight_init.py // Initialize weights in the net.
|
||||
├── scripts
|
||||
│ ├──create_dataset.sh // shell script for creating the dataset.
|
||||
│ ├──parse_output.sh // shell script for parsing the eval output file to calculate BLEU.
|
||||
│ ├──preprocess.sh // shell script for preprocessing the dataset.
|
||||
│ ├──run_distributed_train.sh // shell script for distributed training on Ascend.
|
||||
│ ├──run_eval.sh // shell script for standalone evaluation on Ascend.
|
||||
│ ├──run_standalone_train.sh // shell script for standalone training on Ascend.
|
||||
├── eval.py // Infer API entry.
|
||||
├── requirements.txt // Requirements of third party package.
|
||||
├── train.py // Train API entry.
|
||||
```
|
||||
|
||||
## [Dataset Preparation](#content)
|
||||
|
||||
First, we should download the dataset from the official WMT16 website. After downloading the Multi30K dataset, we get six dataset files, listed below. They should all be placed in the same directory.
|
||||
|
||||
```text
|
||||
train.de
|
||||
train.en
|
||||
val.de
|
||||
val.en
|
||||
test.de
|
||||
test.en
|
||||
```
|
||||
|
||||
Then, we can use scripts/preprocess.sh to tokenize the dataset files and build the vocabulary files.
|
||||
|
||||
```bash
|
||||
bash preprocess.sh [DATASET_PATH]
|
||||
```
|
||||
|
||||
After preprocessing, we will get dataset files with the ".tok" suffix and two vocabulary files, named vocab.de and vocab.en.
|
||||
Then, scripts/create_dataset.sh can be used to convert the dataset into MindRecord format.
|
||||
|
||||
```bash
|
||||
bash create_dataset.sh [DATASET_PATH] [OUTPUT_PATH]
|
||||
```
|
||||
|
||||
Finally, we will get multi30k_train_mindrecord_0 ~ multi30k_train_mindrecord_8 as our training dataset, and multi30k_test_mindrecord as our test dataset.
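As a quick sanity check, the generated MindRecord files can be loaded directly with MindSpore's `MindDataset`. This is a minimal sketch; adjust the file name to whatever create_data.py actually produced in your output directory.

```python
import mindspore.dataset as ds

# One shard of the generated training data (the exact file name may differ on your machine).
data = ds.MindDataset("multi30k_train_mindrecord_0",
                      columns_list=["source_ids", "target_ids", "target_mask"])
for item in data.create_dict_iterator(output_numpy=True, num_epochs=1):
    print(item["source_ids"].shape, item["target_ids"].shape)
    break
```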
|
||||
|
||||
## [Configuration File](#content)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py. All datasets use the same parameter names; the values can be changed according to your needs.
|
||||
|
||||
- Network Parameters
|
||||
|
||||
```text
|
||||
"batch_size": 16, # batch size of input dataset.
|
||||
"src_vocab_size": 8154, # source dataset vocabulary size.
|
||||
"trg_vocab_size": 6113, # target dataset vocabulary size.
|
||||
"encoder_embedding_size": 256, # encoder embedding size.
|
||||
"decoder_embedding_size": 256, # decoder embedding size.
|
||||
"hidden_size": 512, # hidden size of gru.
|
||||
"max_length": 32, # max sentence length.
|
||||
"num_epochs": 30, # total epoch.
|
||||
"save_checkpoint": True, # whether save checkpoint file.
|
||||
"ckpt_epoch": 1, # frequence to save checkpoint file.
|
||||
"target_file": "target.txt", # the target file.
|
||||
"output_file": "output.txt", # the output file.
|
||||
"keep_checkpoint_max": 30, # the maximum number of checkpoint file.
|
||||
"base_lr": 0.001, # init learning rate.
|
||||
"warmup_step": 300, # warmup step.
|
||||
"momentum": 0.9, # momentum in optimizer.
|
||||
"init_loss_scale_value": 1024, # init scale sense.
|
||||
'scale_factor': 2, # scale factor for dynamic loss scale.
|
||||
'scale_window': 2000, # scale window for dynamic loss scale.
|
||||
"warmup_ratio": 1/3.0, # warmup ratio.
|
||||
"teacher_force_ratio": 0.5 # teacher force ratio.
|
||||
```
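The parameters above are stored in an `EasyDict` (see src/config.py), so scripts read and override them as plain attributes. A minimal sketch:

```python
from src.config import config

print(config.batch_size, config.hidden_size)  # read values
config.num_epochs = 10                        # example override for a quick experiment
```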
|
||||
|
||||
## [Training Process](#content)
|
||||
|
||||
- Start training on a single device by running the shell script:
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [DATASET_PATH]
|
||||
```
|
||||
|
||||
- Run the script for distributed training of GRU on multiple devices by executing the following command in `scripts/`:
|
||||
|
||||
``` bash
|
||||
cd ./scripts
|
||||
sh run_distributed_train.sh [RANK_TABLE_PATH] [DATASET_PATH]
|
||||
```
|
||||
|
||||
## [Inference Process](#content)
|
||||
|
||||
- Run the script for evaluation of GRU with the command below:
|
||||
|
||||
``` bash
|
||||
cd ./scripts
|
||||
sh run_eval.sh [CKPT_FILE] [DATASET_PATH]
|
||||
```
|
||||
|
||||
- After evaluation, we will get eval/target.txt and eval/output.txt. Then we can use scripts/parse_output.sh to recover the translations:
|
||||
|
||||
``` bash
|
||||
cp eval/*.txt ./
|
||||
sh parse_output.sh target.txt output.txt /path/vocab.en
|
||||
```
|
||||
|
||||
- After parsing the output, we will get target.txt.forbleu and output.txt.forbleu. To calculate the BLEU score, you may use this [perl script](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) and run the following command:
|
||||
|
||||
```bash
|
||||
perl multi-bleu.perl target.txt.forbleu < output.txt.forbleu
|
||||
```
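If Perl is not available, a rough BLEU check can also be done in Python, for example with NLTK (which you may need to install separately). Note that NLTK's tokenization and smoothing differ from multi-bleu.perl, so the scores are only an approximation.

```python
from nltk.translate.bleu_score import corpus_bleu

with open("target.txt.forbleu", encoding="utf-8") as f_ref, \
     open("output.txt.forbleu", encoding="utf-8") as f_hyp:
    references = [[line.split()] for line in f_ref]  # one reference per sentence
    hypotheses = [line.split() for line in f_hyp]
print("BLEU: {:.2f}".format(corpus_bleu(references, hypotheses) * 100))
```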
|
||||
|
||||
Note: `DATASET_PATH` is the path to the MindRecord file(s), e.g. /dataset_path/*.mindrecord
|
||||
|
||||
# [Model Description](#content)
|
||||
|
||||
## [Performance](#content)
|
||||
|
||||
### Training Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | -------------------------------------------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 01/18/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | Multi30k Dataset |
|
||||
| Training Parameters | epoch=30, batch_size=16 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | NLLLoss |
|
||||
| outputs | probability |
|
||||
| Speed | 50ms/step (1pcs) |
|
||||
| Epoch Time | 13.4s (1pcs) |
|
||||
| Loss | 2.5984 |
|
||||
| Params (M) | 21 |
|
||||
| Checkpoint for inference | 272M (.ckpt file) |
|
||||
| Scripts | [gru](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gru) |
|
||||
|
||||
### Inference Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 01/18/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | Multi30K |
|
||||
| batch_size | 1 |
|
||||
| outputs | label index |
|
||||
| Accuracy | BLEU: 30.30 |
|
||||
| Model for inference | 272M (.ckpt file) |
|
||||
|
||||
# [Random Situation Description](#content)
|
||||
|
||||
There is only one random situation:
|
||||
|
||||
- Initialization of some model weights.
|
||||
|
||||
Some seeds have already been set in train.py to avoid the randomness of weight initialization.
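If you need fully reproducible runs, the global seeds can also be set explicitly before building the network (an illustrative sketch; train.py already applies its own seeds):

```python
from mindspore.common import set_seed
import mindspore.dataset as ds

set_seed(1)            # fixes weight initialization and other global randomness
ds.config.set_seed(1)  # fixes the dataset shuffling order
```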
|
||||
|
||||
# [Others](#others)
|
||||
|
||||
This model has been validated in the Ascend environment only; it has not been validated on CPU or GPU.
|
||||
|
||||
# [ModelZoo HomePage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Transformer evaluation script."""
|
||||
|
||||
import argparse
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import context
|
||||
from src.dataset import create_gru_dataset
|
||||
from src.seq2seq import Seq2Seq
|
||||
from src.gru_for_infer import GRUInferCell
|
||||
from src.config import config
|
||||
|
||||
def run_gru_eval():
|
||||
"""
|
||||
GRU evaluation.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='GRU eval')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
help="device where the code will be implemented, default is Ascend")
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1')
|
||||
parser.add_argument('--ckpt_file', type=str, default="", help='ckpt file path')
|
||||
parser.add_argument("--dataset_path", type=str, default="",
|
||||
help="Dataset path, default: f`sns.")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
|
||||
device_id=args.device_id, save_graphs=False)
|
||||
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
|
||||
dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("dataset size is {}".format(dataset_size))
|
||||
network = Seq2Seq(config, is_training=False)
|
||||
network = GRUInferCell(network)
|
||||
network.set_train(False)
|
||||
if args.ckpt_file != "":
|
||||
parameter_dict = load_checkpoint(args.ckpt_file)
|
||||
load_param_into_net(network, parameter_dict)
|
||||
model = Model(network)
|
||||
|
||||
predictions = []
|
||||
source_sents = []
|
||||
target_sents = []
|
||||
eval_text_len = 0
|
||||
for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
source_sents.append(batch["source_ids"])
|
||||
target_sents.append(batch["target_ids"])
|
||||
source_ids = Tensor(batch["source_ids"], mstype.int32)
|
||||
target_ids = Tensor(batch["target_ids"], mstype.int32)
|
||||
predicted_ids = model.predict(source_ids, target_ids)
|
||||
print("predicts is ", predicted_ids.asnumpy())
|
||||
print("target_ids is ", target_ids)
|
||||
predictions.append(predicted_ids.asnumpy())
|
||||
eval_text_len = eval_text_len + 1
|
||||
|
||||
f_output = open(config.output_file, 'w')
|
||||
f_target = open(config.target_file, "w")
|
||||
for batch_out, true_sentence in zip(predictions, target_sents):
|
||||
for i in range(config.eval_batch_size):
|
||||
target_ids = [str(x) for x in true_sentence[i].tolist()]
|
||||
f_target.write(" ".join(target_ids) + "\n")
|
||||
token_ids = [str(x) for x in batch_out[i].tolist()]
|
||||
f_output.write(" ".join(token_ids) + "\n")
|
||||
f_output.close()
|
||||
f_target.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_gru_eval()
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh create_dataset.sh DATASET_PATH OUTPUT_PATH"
|
||||
echo "for example: sh create_dataset.sh /path/multi30k/ /path/multi30k/mindrecord/"
|
||||
echo "DATASET_NAME including ag, dbpedia, and yelp_p"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
ulimit -u unlimited
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
echo $DATASET_PATH
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not valid"
|
||||
exit 1
|
||||
fi
|
||||
OUTPUT_PATH=$(get_real_path $2)
|
||||
echo $OUTPUT_PATH
|
||||
if [ ! -d $OUTPUT_PATH ]
|
||||
then
|
||||
echo "error: OUTPUT_PATH=$OUTPUT_PATH is not valid"
|
||||
exit 1
|
||||
fi
|
||||
paste $DATASET_PATH/train.de.tok $DATASET_PATH/train.en.tok > $DATASET_PATH/train.all
|
||||
python ../src/create_data.py --input_file $DATASET_PATH/train.all --num_splits 8 --src_vocab_file $DATASET_PATH/vocab.de --trg_vocab_file $DATASET_PATH/vocab.en --output_file $OUTPUT_PATH/multi30k_train_mindrecord --max_seq_length 32 --bucket [32]
|
||||
paste $DATASET_PATH/test.de.tok $DATASET_PATH/test.en.tok > $DATASET_PATH/test.all
|
||||
python ../src/create_data.py --input_file $DATASET_PATH/test.all --num_splits 1 --src_vocab_file $DATASET_PATH/vocab.de --trg_vocab_file $DATASET_PATH/vocab.en --output_file $OUTPUT_PATH/multi30k_test_mindrecord --max_seq_length 32 --bucket [32]
|
|
@ -0,0 +1,33 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE"
|
||||
echo "for example: sh parse_output.sh target.txt output.txt vocab.en"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
ref_data=$1
|
||||
eval_output=$2
|
||||
vocab_file=$3
|
||||
|
||||
cat $ref_data \
|
||||
| python ../src/parse_output.py --vocab_file $vocab_file \
|
||||
| sed 's/@@ //g' > ${ref_data}.forbleu
|
||||
|
||||
cat $eval_output \
|
||||
| python ../src/parse_output.py --vocab_file $vocab_file \
|
||||
| sed 's/@@ //g' > ${eval_output}.forbleu
|
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATASET_DIR=$1
|
||||
python ../src/preprocess.py --dataset_path=$DATASET_DIR
|
|
@ -0,0 +1,68 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATASET_PATH=$(get_real_path $2)
|
||||
echo $DATASET_PATH
|
||||
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$DATASET_PATH &> log &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,58 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [CKPT_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
CKPT_FILE=$(get_real_path $1)
|
||||
echo $CKPT_FILE
|
||||
if [ ! -f $CKPT_FILE ]
|
||||
then
|
||||
echo "error: CKPT_FILE=$CKPT_FILE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATASET_PATH=$(get_real_path $2)
|
||||
echo $DATASET_PATH
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
rm -rf ./eval
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
echo "start eval for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python eval.py --ckpt_file=$CKPT_FILE --dataset_path=$DATASET_PATH &> log &
|
||||
cd ..
|
|
@ -0,0 +1,51 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# -ne 1 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=4
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
echo $DATASET_PATH
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
rm -rf ./train
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --dataset_path=$DATASET_PATH &> log &
|
||||
cd ..
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GRU config"""
|
||||
from easydict import EasyDict
|
||||
|
||||
config = EasyDict({
|
||||
"batch_size": 16,
|
||||
"eval_batch_size": 1,
|
||||
"src_vocab_size": 8154,
|
||||
"trg_vocab_size": 6113,
|
||||
"encoder_embedding_size": 256,
|
||||
"decoder_embedding_size": 256,
|
||||
"hidden_size": 512,
|
||||
"max_length": 32,
|
||||
"num_epochs": 30,
|
||||
"save_checkpoint": True,
|
||||
"ckpt_epoch": 10,
|
||||
"target_file": "target.txt",
|
||||
"output_file": "output.txt",
|
||||
"keep_checkpoint_max": 30,
|
||||
"base_lr": 0.001,
|
||||
"warmup_step": 300,
|
||||
"momentum": 0.9,
|
||||
"init_loss_scale_value": 1024,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 2000,
|
||||
"warmup_ratio": 1/3.0,
|
||||
"teacher_force_ratio": 0.5
|
||||
})
|
|
@ -0,0 +1,196 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create training instances for Transformer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import collections
|
||||
import logging
|
||||
import numpy as np
|
||||
import tokenization
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
class SampleInstance():
|
||||
"""A single sample instance (sentence pair)."""
|
||||
|
||||
def __init__(self, source_tokens, target_tokens):
|
||||
self.source_tokens = source_tokens
|
||||
self.target_tokens = target_tokens
|
||||
|
||||
def __str__(self):
|
||||
s = ""
|
||||
s += "source_tokens: %s\n" % (" ".join(
|
||||
[tokenization.convert_to_printable(x) for x in self.source_tokens]))
|
||||
s += "target tokens: %s\n" % (" ".join(
|
||||
[tokenization.convert_to_printable(x) for x in self.target_tokens]))
|
||||
s += "\n"
|
||||
return s
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def get_instance_features(instance, tokenizer_src, tokenizer_trg, max_seq_length, bucket):
|
||||
"""Get features from `SampleInstance`s."""
|
||||
def _find_bucket_length(source_tokens, target_tokens):
|
||||
source_ids = tokenizer_src.convert_tokens_to_ids(source_tokens)
|
||||
target_ids = tokenizer_trg.convert_tokens_to_ids(target_tokens)
|
||||
num = max(len(source_ids), len(target_ids))
|
||||
assert num <= bucket[-1]
|
||||
for index in range(1, len(bucket)):
|
||||
if bucket[index - 1] < num <= bucket[index]:
|
||||
return bucket[index]
|
||||
return bucket[0]
|
||||
|
||||
def _convert_ids_and_mask(tokenizer, input_tokens, seq_max_bucket_length):
|
||||
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
assert len(input_ids) <= max_seq_length
|
||||
|
||||
while len(input_ids) < seq_max_bucket_length:
|
||||
input_ids.append(1)
|
||||
input_mask.append(0)
|
||||
|
||||
assert len(input_ids) == seq_max_bucket_length
|
||||
assert len(input_mask) == seq_max_bucket_length
|
||||
|
||||
return input_ids, input_mask
|
||||
|
||||
seq_max_bucket_length = _find_bucket_length(instance.source_tokens, instance.target_tokens)
|
||||
source_ids, source_mask = _convert_ids_and_mask(tokenizer_src, instance.source_tokens, seq_max_bucket_length)
|
||||
target_ids, target_mask = _convert_ids_and_mask(tokenizer_trg, instance.target_tokens, seq_max_bucket_length)
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["source_ids"] = np.asarray(source_ids)
|
||||
features["source_mask"] = np.asarray(source_mask)
|
||||
features["target_ids"] = np.asarray(target_ids)
|
||||
features["target_mask"] = np.asarray(target_mask)
|
||||
|
||||
return features, seq_max_bucket_length
|
||||
|
||||
def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len):
|
||||
"""Creates `SampleInstance`s for a single sentence pair."""
|
||||
EOS = "<eos>"
|
||||
SOS = "<sos>"
|
||||
|
||||
if len(source_words) >= max_seq_length-1 or len(target_words) >= max_seq_length-1:
|
||||
if clip_to_max_len:
|
||||
source_words = source_words[:min(len(source_words), max_seq_length-2)]
|
||||
target_words = target_words[:min(len(target_words), max_seq_length-2)]
|
||||
else:
|
||||
return None
|
||||
source_tokens = [SOS] + source_words + [EOS]
|
||||
target_tokens = [SOS] + target_words + [EOS]
|
||||
instance = SampleInstance(
|
||||
source_tokens=source_tokens,
|
||||
target_tokens=target_tokens)
|
||||
return instance
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_file", type=str, required=True,
|
||||
help='Input raw text file (or comma-separated list of files).')
|
||||
parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
|
||||
parser.add_argument("--num_splits", type=int, default=16,
|
||||
help='The MindRecord file will be split into this number of partitions.')
|
||||
parser.add_argument("--src_vocab_file", type=str, required=True,
|
||||
help='The source-language vocabulary file.')
|
||||
parser.add_argument("--trg_vocab_file", type=str, required=True,
|
||||
help='The target-language vocabulary file.')
|
||||
parser.add_argument("--clip_to_max_len", type=ast.literal_eval, default=False,
|
||||
help='clip sequences to maximum sequence length.')
|
||||
parser.add_argument("--max_seq_length", type=int, default=32, help='Maximum sequence length.')
|
||||
parser.add_argument("--bucket", type=ast.literal_eval, default=[32],
|
||||
help='bucket sequence length')
|
||||
args = parser.parse_args()
|
||||
tokenizer_src = tokenization.WhiteSpaceTokenizer(vocab_file=args.src_vocab_file)
|
||||
tokenizer_trg = tokenization.WhiteSpaceTokenizer(vocab_file=args.trg_vocab_file)
|
||||
input_files = []
|
||||
for input_pattern in args.input_file.split(","):
|
||||
input_files.append(input_pattern)
|
||||
logging.info("*** Read from input files ***")
|
||||
output_file = args.output_file
|
||||
logging.info("*** Write to output files ***")
|
||||
logging.info(" %s", output_file)
|
||||
total_written = 0
|
||||
total_read = 0
|
||||
feature_dict = {}
|
||||
for i in args.bucket:
|
||||
feature_dict[i] = []
|
||||
for input_file in input_files:
|
||||
logging.info("*** Reading from %s ***", input_file)
|
||||
with open(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
total_read += 1
|
||||
if total_read % 100000 == 0:
|
||||
logging.info("Read %d ...", total_read)
|
||||
if line.strip() == "":
|
||||
continue
|
||||
source_line, target_line = line.strip().split("\t")
|
||||
source_tokens = tokenizer_src.tokenize(source_line)
|
||||
target_tokens = tokenizer_trg.tokenize(target_line)
|
||||
if len(source_tokens) >= args.max_seq_length or len(target_tokens) >= args.max_seq_length:
|
||||
logging.info("ignore long sentence!")
|
||||
continue
|
||||
instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
|
||||
clip_to_max_len=args.clip_to_max_len)
|
||||
if instance is None:
|
||||
continue
|
||||
features, seq_max_bucket_length = get_instance_features(instance, tokenizer_src, tokenizer_trg,
|
||||
args.max_seq_length, args.bucket)
|
||||
for key in feature_dict:
|
||||
if key == seq_max_bucket_length:
|
||||
feature_dict[key].append(features)
|
||||
if total_read <= 10:
|
||||
logging.info("*** Example ***")
|
||||
logging.info("source tokens: %s", " ".join(
|
||||
[tokenization.convert_to_printable(x) for x in instance.source_tokens]))
|
||||
logging.info("target tokens: %s", " ".join(
|
||||
[tokenization.convert_to_printable(x) for x in instance.target_tokens]))
|
||||
|
||||
for feature_name in features.keys():
|
||||
feature = features[feature_name]
|
||||
logging.info("%s: %s", feature_name, feature)
|
||||
for i in args.bucket:
|
||||
if args.num_splits == 1:
|
||||
output_file_name = output_file + '_' + str(i)
|
||||
else:
|
||||
output_file_name = output_file + '_' + str(i) + '_'
|
||||
writer = FileWriter(output_file_name, args.num_splits)
|
||||
data_schema = {"source_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_mask": {"type": "int64", "shape": [-1]}
|
||||
}
|
||||
writer.add_schema(data_schema, "gru")
|
||||
features_ = feature_dict[i]
|
||||
logging.info("Bucket length %d has %d samples, start writing...", i, len(features_))
|
||||
for item in features_:
|
||||
writer.write_raw_data([item])
|
||||
total_written += 1
|
||||
writer.commit()
|
||||
logging.info("Wrote %d total instances", total_written)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Data operations, will be used in train.py."""
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
from src.config import config
|
||||
import numpy as np
|
||||
de.config.set_seed(1)
|
||||
|
||||
def random_teacher_force(source_ids, target_ids, target_mask):
"""Randomly decide whether this batch should use teacher forcing in the decoder."""
|
||||
|
||||
teacher_force = np.random.random() < config.teacher_force_ratio
|
||||
teacher_force_array = np.array([teacher_force], dtype=bool)
|
||||
return source_ids, target_ids, teacher_force_array
|
||||
|
||||
def create_gru_dataset(epoch_count=1, batch_size=1, rank_size=1, rank_id=0, do_shuffle=True, dataset_path=None,
|
||||
is_training=True):
|
||||
"""create dataset"""
|
||||
ds = de.MindDataset(dataset_path,
|
||||
columns_list=["source_ids", "target_ids",
|
||||
"target_mask"],
|
||||
shuffle=do_shuffle, num_parallel_workers=10, num_shards=rank_size, shard_id=rank_id)
|
||||
operations = random_teacher_force
|
||||
ds = ds.map(operations=operations, input_columns=["source_ids", "target_ids", "target_mask"],
|
||||
output_columns=["source_ids", "target_ids", "teacher_force"],
|
||||
column_order=["source_ids", "target_ids", "teacher_force"])
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
type_cast_op_bool = deC.TypeCast(mstype.bool_)
|
||||
ds = ds.map(operations=type_cast_op, input_columns="source_ids")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="target_ids")
|
||||
ds = ds.map(operations=type_cast_op_bool, input_columns="teacher_force")
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(1)
|
||||
return ds
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GRU cell"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.weight_init import gru_default_state
|
||||
|
||||
class BidirectionGRU(nn.Cell):
|
||||
'''
|
||||
BidirectionGRU model
|
||||
|
||||
Args:
|
||||
config: config of network
|
||||
'''
|
||||
def __init__(self, config, is_training=True):
|
||||
super(BidirectionGRU, self).__init__()
|
||||
if is_training:
|
||||
self.batch_size = config.batch_size
|
||||
else:
|
||||
self.batch_size = config.eval_batch_size
|
||||
self.embedding_size = config.encoder_embedding_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.weight_i, self.weight_h, self.bias_i, self.bias_h, self.init_h = gru_default_state(self.batch_size,
|
||||
self.embedding_size,
|
||||
self.hidden_size)
|
||||
self.weight_bw_i, self.weight_bw_h, self.bias_bw_i, self.bias_bw_h, self.init_bw_h = \
|
||||
gru_default_state(self.batch_size, self.embedding_size, self.hidden_size)
|
||||
self.reverse = P.ReverseV2(axis=[1])
|
||||
self.concat = P.Concat(axis=2)
|
||||
self.squeeze = P.Squeeze(axis=0)
|
||||
self.rnn = P.DynamicGRUV2()
|
||||
self.text_len = config.max_length
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
'''
|
||||
BidirectionGRU construction
|
||||
|
||||
Args:
|
||||
x(Tensor): BidirectionGRU input
|
||||
|
||||
Returns:
|
||||
output(Tensor): rnn output
|
||||
hidden(Tensor): hidden state
|
||||
'''
|
||||
x = self.cast(x, mstype.float16)
|
||||
y1, _, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, self.init_h)
|
||||
bw_x = self.reverse(x)
|
||||
y1_bw, _, _, _, _, _ = self.rnn(bw_x, self.weight_bw_i,
|
||||
self.weight_bw_h, self.bias_bw_i, self.bias_bw_h, None, self.init_bw_h)
|
||||
y1_bw = self.reverse(y1_bw)
|
||||
output = self.concat((y1, y1_bw))
|
||||
hidden = self.concat((y1[self.text_len-1:self.text_len:1, ::, ::],
|
||||
y1_bw[self.text_len-1:self.text_len:1, ::, ::]))
|
||||
hidden = self.squeeze(hidden)
|
||||
return output, hidden
|
||||
|
||||
class GRU(nn.Cell):
|
||||
'''
|
||||
GRU model
|
||||
|
||||
Args:
|
||||
config: config of network
|
||||
'''
|
||||
def __init__(self, config, is_training=True):
|
||||
super(GRU, self).__init__()
|
||||
if is_training:
|
||||
self.batch_size = config.batch_size
|
||||
else:
|
||||
self.batch_size = config.eval_batch_size
|
||||
self.embedding_size = config.encoder_embedding_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.weight_i, self.weight_h, self.bias_i, self.bias_h, self.init_h = \
|
||||
gru_default_state(self.batch_size, self.embedding_size + self.hidden_size*2, self.hidden_size)
|
||||
self.rnn = P.DynamicGRUV2()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
'''
|
||||
GRU construction
|
||||
|
||||
Args:
|
||||
x(Tensor): GRU input
|
||||
|
||||
Returns:
|
||||
output(Tensor): rnn output
|
||||
hidden(Tensor): hidden state
|
||||
'''
|
||||
x = self.cast(x, mstype.float16)
|
||||
y1, h1, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, self.init_h)
|
||||
return y1, h1
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GRU Infer cell"""
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.config import config
|
||||
|
||||
class GRUInferCell(nn.Cell):
|
||||
'''
|
||||
GRU inference construction
|
||||
|
||||
Args:
|
||||
network: gru network
|
||||
'''
|
||||
def __init__(self, network):
|
||||
super(GRUInferCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.argmax = P.ArgMaxWithValue(axis=2)
|
||||
self.transpose = P.Transpose()
|
||||
self.teacher_force = Tensor(np.zeros((config.eval_batch_size)), mstype.bool_)
|
||||
def construct(self,
|
||||
encoder_inputs,
|
||||
decoder_inputs):
|
||||
predict_probs = self.network(encoder_inputs, decoder_inputs, self.teacher_force)
|
||||
predict_probs = self.transpose(predict_probs, (1, 0, 2))
|
||||
predict_ids, _ = self.argmax(predict_probs)
|
||||
return predict_ids
|
|
@ -0,0 +1,243 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GRU train cell"""
|
||||
from mindspore import Tensor, Parameter, ParameterTuple, context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
||||
from src.config import config
|
||||
from src.loss import NLLLoss
|
||||
|
||||
class GRUWithLossCell(nn.Cell):
|
||||
"""
|
||||
GRU network connect with loss function.
|
||||
|
||||
Args:
|
||||
network: The training network.
|
||||
|
||||
Returns:
|
||||
the output of loss function.
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(GRUWithLossCell, self).__init__()
|
||||
self.network = network
|
||||
self.loss = NLLLoss()
|
||||
self.logits_shape = (-1, config.src_vocab_size)
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
self.mean = P.ReduceMean()
|
||||
self.text_len = config.max_length
|
||||
self.split = P.Split(axis=0, output_num=config.max_length-1)
|
||||
self.squeeze = P.Squeeze()
|
||||
self.add = P.AddN()
|
||||
self.transpose = P.Transpose()
|
||||
self.shape = P.Shape()
|
||||
def construct(self, encoder_inputs, decoder_inputs, teacher_force):
|
||||
'''
|
||||
GRU loss cell
|
||||
|
||||
Args:
|
||||
encoder_inputs(Tensor): encoder inputs
|
||||
decoder_inputs(Tensor): decoder inputs
|
||||
teacher_force(Tensor): teacher force flag
|
||||
|
||||
Returns:
|
||||
loss(scalar): loss output
|
||||
'''
|
||||
logits = self.network(encoder_inputs, decoder_inputs, teacher_force)
|
||||
logits = self.cast(logits, mstype.float32)
|
||||
loss_total = ()
|
||||
decoder_targets = decoder_inputs
|
||||
decoder_output = logits
|
||||
for i in range(1, self.text_len):
|
||||
loss = self.loss(self.squeeze(decoder_output[i-1:i:1, ::, ::]), decoder_targets[:, i])
|
||||
loss_total += (loss,)
|
||||
loss = self.add(loss_total) / self.text_len
|
||||
return loss
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
class ClipGradients(nn.Cell):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Args:
|
||||
grads (list): List of gradient tuples.
|
||||
clip_type (Tensor): The way to clip, 'value' or 'norm'.
|
||||
clip_value (Tensor): Specifies how much to clip.
|
||||
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ClipGradients, self).__init__()
|
||||
self.clip_by_norm = nn.ClipByNorm()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
def construct(self,
|
||||
grads,
|
||||
clip_type,
|
||||
clip_value):
|
||||
"""Defines the gradients clip."""
|
||||
if clip_type not in (0, 1):
|
||||
return grads
|
||||
new_grads = ()
|
||||
for grad in grads:
|
||||
dt = self.dtype(grad)
|
||||
if clip_type == 0:
|
||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
t = self.cast(t, dt)
|
||||
new_grads = new_grads + (t,)
|
||||
return new_grads
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
class GRUTrainOneStepWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of GRU network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(GRUTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
if context.get_context("device_target") == "GPU":
|
||||
self.gpu_target = True
|
||||
self.float_status = P.FloatStatus()
|
||||
self.addn = P.AddN()
|
||||
self.reshape = P.Reshape()
|
||||
else:
|
||||
self.gpu_target = False
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
@C.add_flags(has_effect=True)
|
||||
def construct(self,
|
||||
encoder_inputs,
|
||||
decoder_inputs,
|
||||
teacher_force,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
|
||||
weights = self.weights
|
||||
loss = self.network(encoder_inputs,
|
||||
decoder_inputs,
|
||||
teacher_force)
|
||||
init = False
|
||||
if not self.gpu_target:
|
||||
# alloc status
|
||||
init = self.alloc_status()
|
||||
# clear overflow buffer
|
||||
self.clear_before_grad(init)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
grads = self.grad(self.network, weights)(encoder_inputs,
|
||||
decoder_inputs,
|
||||
teacher_force,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
if not self.gpu_target:
|
||||
self.get_status(init)
|
||||
# sum overflow buffer elements, 0: not overflow, >0: overflow
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
# convert flag_sum to scalar
|
||||
flag_sum = self.reshape(flag_sum, (()))
|
||||
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""NLLLoss cell"""
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
class NLLLoss(_Loss):
|
||||
'''
|
||||
NLLLoss function
|
||||
'''
|
||||
def __init__(self, reduction='mean'):
|
||||
super(NLLLoss, self).__init__(reduction)
|
||||
self.one_hot = P.OneHot()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
|
||||
def construct(self, logits, label):
|
||||
label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
|
||||
loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
|
||||
return self.get_loss(loss)
|
|
@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate generator for GRU."""
import math


def rsqrt_decay(warmup_steps, current_step):
    """Inverse-square-root decay, held constant until warmup finishes."""
    return float(max([current_step, warmup_steps])) ** -0.5


def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    """Linearly increase the learning rate from init_lr to base_lr over warmup_steps."""
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    learning_rate = float(init_lr) + lr_inc * current_step
    return learning_rate


def a_cosine_learning_rate(current_step, base_lr, warmup_steps, total_steps):
    """Cosine-style decay applied after the warmup phase."""
    decay_steps = total_steps - warmup_steps
    linear_decay = (total_steps - current_step) / decay_steps
    cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * current_step / decay_steps))
    decayed = linear_decay * cosine_decay + 0.00001
    learning_rate = decayed * base_lr
    return learning_rate


def dynamic_lr(config, base_step):
    """Dynamic learning rate generator: linear warmup followed by cosine decay."""
    base_lr = config.base_lr
    total_steps = int(base_step * config.num_epochs)
    warmup_steps = int(config.warmup_step)
    lr = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
        else:
            lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
    return lr
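A quick way to sanity-check the schedule is to call `dynamic_lr` with a stand-in config. The attribute names below (`base_lr`, `num_epochs`, `warmup_step`, `warmup_ratio`) are exactly the ones read above, but the values are invented for illustration; the real ones live in `src/config.py`.

```python
from types import SimpleNamespace

# from src.lr_schedule import dynamic_lr  # if run outside this module

# stand-in for src.config.config; values are illustrative only
cfg = SimpleNamespace(base_lr=0.001, num_epochs=2, warmup_step=10, warmup_ratio=1 / 3.0)

steps_per_epoch = 50
lr = dynamic_lr(cfg, steps_per_epoch)
print(len(lr))                       # 100 = steps_per_epoch * num_epochs
print(lr[0] < lr[9] < cfg.base_lr)   # True: linear warmup ramps toward base_lr
print(lr[10] > lr[-1])               # True: the cosine part decays afterwards
```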
@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ids to tokens."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tokenization

# Explicitly set the encoding
sys.stdin = open(sys.stdin.fileno(), mode='r', encoding='utf-8', buffering=True)
sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=True)


def main():
    parser = argparse.ArgumentParser(
        description="Convert decoder output ids back to tokens.")
    parser.add_argument("--vocab_file", type=str, default="", required=True, help="vocab file path.")
    args = parser.parse_args()

    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)

    # Each input line is a sequence of token ids; keep only the text between <sos> and <eos>.
    for line in sys.stdin:
        token_ids = [int(x) for x in line.strip().split()]
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        sent = " ".join(tokens)
        sent = sent.split("<sos>")[-1]
        sent = sent.split("<eos>")[0]
        print(sent.strip())


if __name__ == "__main__":
    main()
@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''Dataset preprocess'''
import os
import argparse
from collections import Counter
from nltk.tokenize import word_tokenize


def create_tokenized_sentences(input_files, language):
    '''
    Create tokenized sentence files.

    Args:
        input_files: input file path.
        language: text language used by the nltk tokenizer.
    '''
    sentence = []
    total_lines = open(input_files, "r").read().splitlines()
    for line in total_lines:
        line = line.strip('\r\n ')
        line = line.lower()
        tokenize_sentence = word_tokenize(line, language)
        str_sentence = " ".join(tokenize_sentence)
        sentence.append(str_sentence)
    tokenize_file = input_files + ".tok"
    f = open(tokenize_file, "w")
    for line in sentence:
        f.write(line)
        f.write("\n")
    f.close()


def get_dataset_vocab(text_file, vocab_file):
    '''
    Create the dataset vocab file.

    Args:
        text_file: dataset text file.
        vocab_file: output vocab file.
    '''
    counter = Counter()
    text_lines = open(text_file, "r").read().splitlines()
    for line in text_lines:
        for word in line.strip('\r\n ').split(' '):
            if word:
                counter[word] += 1
    vocab = open(vocab_file, "w")
    basic_label = ["<unk>", "<pad>", "<sos>", "<eos>"]
    for label in basic_label:
        vocab.write(label + "\n")
    # keep only words that occur at least twice, ordered by frequency
    for key, freq in sorted(counter.items(), key=lambda x: x[1], reverse=True):
        if freq < 2:
            continue
        vocab.write(key + "\n")
    vocab.close()


def MergeText(root_dir, file_list, output_file):
    '''
    Merge tokenized text files together.

    Args:
        root_dir: root dir.
        file_list: dataset file list.
        output_file: output file after the merge.
    '''
    output_file = os.path.join(root_dir, output_file)
    f_output = open(output_file, "w")
    for file_name in file_list:
        text_path = os.path.join(root_dir, file_name) + ".tok"
        f = open(text_path)
        f_output.write(f.read() + "\n")
        f.close()
    f_output.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='gru_dataset')
    parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.")
    args = parser.parse_args()
    dataset_path = args.dataset_path
    src_file_list = ["train.de", "test.de", "val.de"]
    dst_file_list = ["train.en", "test.en", "val.en"]
    # tokenize each split with the tokenizer language matching the file's language
    for file in src_file_list:
        file_path = os.path.join(dataset_path, file)
        create_tokenized_sentences(file_path, "german")
    for file in dst_file_list:
        file_path = os.path.join(dataset_path, file)
        create_tokenized_sentences(file_path, "english")
    src_all_file = "all.de.tok"
    dst_all_file = "all.en.tok"
    MergeText(dataset_path, src_file_list, src_all_file)
    MergeText(dataset_path, dst_file_list, dst_all_file)
    src_vocab = os.path.join(dataset_path, "vocab.de")
    dst_vocab = os.path.join(dataset_path, "vocab.en")
    get_dataset_vocab(os.path.join(dataset_path, src_all_file), src_vocab)
    get_dataset_vocab(os.path.join(dataset_path, dst_all_file), dst_vocab)
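The vocabulary rule implemented above is: the four special tokens first, then every word that occurs at least twice, ordered by frequency. A small self-contained check of just that rule (no nltk or Multi30K files needed; the toy sentences are made up):

```python
from collections import Counter

lines = ["a cat sat", "a dog sat", "a bird"]
counter = Counter(w for line in lines for w in line.split() if w)

vocab = ["<unk>", "<pad>", "<sos>", "<eos>"]
vocab += [w for w, freq in sorted(counter.items(), key=lambda x: x[1], reverse=True) if freq >= 2]
print(vocab)  # ['<unk>', '<pad>', '<sos>', '<eos>', 'a', 'sat'] -- 'cat', 'dog', 'bird' are dropped
```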
@ -0,0 +1,223 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Seq2Seq construction"""
import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from src.gru import BidirectionGRU, GRU
from src.weight_init import dense_default_state


class Attention(nn.Cell):
    '''
    Attention model
    '''
    def __init__(self, config):
        super(Attention, self).__init__()
        self.text_len = config.max_length
        self.attn = nn.Dense(in_channels=config.hidden_size * 3,
                             out_channels=config.hidden_size).to_float(mstype.float16)
        self.fc = nn.Dense(config.hidden_size, 1, has_bias=False).to_float(mstype.float16)
        self.expandims = P.ExpandDims()
        self.tanh = P.Tanh()
        self.softmax = P.Softmax()
        self.tile = P.Tile()
        self.transpose = P.Transpose()
        self.concat = P.Concat(axis=2)
        self.squeeze = P.Squeeze(axis=2)
        self.cast = P.Cast()

    def construct(self, hidden, encoder_outputs):
        '''
        Attention construction

        Args:
            hidden(Tensor): hidden state
            encoder_outputs(Tensor): the output of encoder

        Returns:
            Tensor, attention weights over the encoder outputs
        '''
        hidden = self.expandims(hidden, 1)
        hidden = self.tile(hidden, (1, self.text_len, 1))
        encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
        out = self.concat((hidden, encoder_outputs))
        out = self.attn(out)
        energy = self.tanh(out)
        attention = self.fc(energy)
        attention = self.squeeze(attention)
        # softmax is computed in float32 for numerical stability
        attention = self.cast(attention, mstype.float32)
        attention = self.softmax(attention)
        attention = self.cast(attention, mstype.float16)
        return attention


class Encoder(nn.Cell):
    '''
    Encoder model

    Args:
        config: config of network
    '''
    def __init__(self, config, is_training=True):
        super(Encoder, self).__init__()
        self.hidden_size = config.hidden_size
        self.vocab_size = config.src_vocab_size
        self.embedding_size = config.encoder_embedding_size
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.rnn = BidirectionGRU(config, is_training=is_training).to_float(mstype.float16)
        self.fc = nn.Dense(2*self.hidden_size, self.hidden_size).to_float(mstype.float16)
        self.shape = P.Shape()
        self.transpose = P.Transpose()
        self.p = P.Print()
        self.cast = P.Cast()
        self.text_len = config.max_length
        self.squeeze = P.Squeeze(axis=0)
        self.tanh = P.Tanh()

    def construct(self, src):
        '''
        Encoder construction

        Args:
            src(Tensor): source sentences

        Returns:
            output(Tensor): output of rnn
            hidden(Tensor): output hidden state
        '''
        embedded = self.embedding(src)
        embedded = self.transpose(embedded, (1, 0, 2))
        embedded = self.cast(embedded, mstype.float16)
        output, hidden = self.rnn(embedded)
        hidden = self.fc(hidden)
        hidden = self.tanh(hidden)
        return output, hidden


class Decoder(nn.Cell):
    '''
    Decoder model

    Args:
        config: config of network
    '''
    def __init__(self, config, is_training=True):
        super(Decoder, self).__init__()
        self.hidden_size = config.hidden_size
        self.vocab_size = config.trg_vocab_size
        self.embedding_size = config.decoder_embedding_size
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.rnn = GRU(config, is_training=is_training).to_float(mstype.float16)
        self.text_len = config.max_length
        self.shape = P.Shape()
        self.transpose = P.Transpose()
        self.p = P.Print()
        self.cast = P.Cast()
        self.concat = P.Concat(axis=2)
        self.squeeze = P.Squeeze(axis=0)
        self.expandims = P.ExpandDims()
        self.log_softmax = P.LogSoftmax(axis=1)
        weight, bias = dense_default_state(self.embedding_size+self.hidden_size*3, self.vocab_size)
        self.fc = nn.Dense(self.embedding_size+self.hidden_size*3, self.vocab_size,
                           weight_init=weight, bias_init=bias).to_float(mstype.float16)
        self.attention = Attention(config)
        self.bmm = P.BatchMatMul()
        self.dropout = nn.Dropout(0.7)

    def construct(self, inputs, hidden, encoder_outputs):
        '''
        Decoder construction

        Args:
            inputs(Tensor): decoder input
            hidden(Tensor): hidden state
            encoder_outputs(Tensor): encoder output

        Returns:
            pred_prob(Tensor): decoder predicted log-probabilities
            hidden(Tensor): hidden state
        '''
        embedded = self.embedding(inputs)
        embedded = self.transpose(embedded, (1, 0, 2))
        embedded = self.cast(embedded, mstype.float16)
        attn = self.attention(hidden, encoder_outputs)
        attn = self.expandims(attn, 1)
        encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
        weight = self.bmm(attn, encoder_outputs)
        weight = self.transpose(weight, (1, 0, 2))
        emd_con = self.concat((embedded, weight))
        output, hidden = self.rnn(emd_con)
        out = self.concat((embedded, output, weight))
        out = self.squeeze(out)
        hidden = self.squeeze(hidden)
        prediction = self.fc(out)
        prediction = self.dropout(prediction)
        prediction = self.cast(prediction, mstype.float32)
        pred_prob = self.log_softmax(prediction)
        pred_prob = self.expandims(pred_prob, 0)
        return pred_prob, hidden


class Seq2Seq(nn.Cell):
    '''
    Seq2Seq model

    Args:
        config: config of network
    '''
    def __init__(self, config, is_training=True):
        super(Seq2Seq, self).__init__()
        if is_training:
            self.batch_size = config.batch_size
        else:
            self.batch_size = config.eval_batch_size
        self.encoder = Encoder(config, is_training=is_training)
        self.decoder = Decoder(config, is_training=is_training)
        self.expandims = P.ExpandDims()
        self.dropout = nn.Dropout()
        self.shape = P.Shape()
        self.concat = P.Concat(axis=0)
        self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
        self.squeeze = P.Squeeze(axis=0)
        # <sos> has id 2 in the vocab built by the preprocess script
        self.sos = Tensor(np.ones((self.batch_size, 1))*2, mstype.int32)
        self.select = P.Select()
        self.text_len = config.max_length

    def construct(self, encoder_inputs, decoder_inputs, teacher_force):
        '''
        Seq2Seq construction

        Args:
            encoder_inputs(Tensor): encoder input sentences
            decoder_inputs(Tensor): decoder input sentences
            teacher_force(Tensor): teacher force flag

        Returns:
            outputs(Tensor): predicted log-probabilities for every decoding step
        '''
        decoder_input = self.sos
        encoder_output, hidden = self.encoder(encoder_inputs)
        decoder_hidden = hidden
        decoder_outputs = ()
        for i in range(1, self.text_len):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_output)
            decoder_outputs += (decoder_output,)
            if self.training:
                # teacher forcing: feed the ground-truth token where teacher_force is set,
                # otherwise feed the model's own top-1 prediction
                decoder_input_force = decoder_inputs[:, i:i+1]
                decoder_input_top1, _ = self.argmax(self.squeeze(decoder_output))
                decoder_input = self.select(teacher_force, decoder_input_force, decoder_input_top1)
            else:
                decoder_input, _ = self.argmax(self.squeeze(decoder_output))
        outputs = self.concat(decoder_outputs)
        return outputs
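During training, the next decoder input is chosen by `P.Select`: the ground-truth token where `teacher_force` is true, the model's own argmax otherwise. A numpy sketch of that single selection step; the shapes and values are illustrative only, not taken from the real config.

```python
import numpy as np

# log-probabilities from one decoding step: batch of 3, vocab of 4
step_log_probs = np.log(np.array([[0.1, 0.6, 0.2, 0.1],
                                  [0.7, 0.1, 0.1, 0.1],
                                  [0.2, 0.2, 0.5, 0.1]]))
ground_truth = np.array([[3], [2], [1]])            # decoder_inputs[:, i:i+1]
teacher_force = np.array([[True], [False], [True]])

model_top1 = np.argmax(step_log_probs, axis=1).reshape(-1, 1)   # ArgMaxWithValue
next_input = np.where(teacher_force, ground_truth, model_top1)  # P.Select
print(next_input.ravel())  # [3 0 1] -> forced, model's own pick, forced
```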
@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tokenization utilities."""

import sys
import collections
import unicodedata


def convert_to_printable(text):
    """
    Converts `text` to a printable coding format.
    """
    if sys.version_info[0] == 3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
    raise ValueError("Only supported when running on Python3.")


def convert_to_unicode(text):
    """
    Converts `text` to Unicode format.
    """
    if sys.version_info[0] == 3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
    if sys.version_info[0] == 2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        if isinstance(text, unicode):
            return text
        raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
    raise ValueError("Only supported when running on Python2 or Python3.")


def load_vocab_file(vocab_file):
    """
    Loads a vocabulary file and turns it into a {token: id} dictionary.
    """
    vocab_dict = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r") as vocab:
        while True:
            token = convert_to_unicode(vocab.readline())
            if not token:
                break
            token = token.strip()
            vocab_dict[token] = index
            index += 1
    return vocab_dict


def convert_by_vocab_dict(vocab_dict, items):
    """
    Converts a sequence of [tokens|ids] according to the vocab dict.
    Unknown items are mapped to `<unk>`.
    """
    output = []
    for item in items:
        if item in vocab_dict:
            output.append(vocab_dict[item])
        else:
            output.append(vocab_dict["<unk>"])
    return output


class WhiteSpaceTokenizer():
    """
    Whitespace tokenizer.
    """
    def __init__(self, vocab_file):
        self.vocab_dict = load_vocab_file(vocab_file)
        self.inv_vocab_dict = {index: token for token, index in self.vocab_dict.items()}

    def _is_whitespace_char(self, char):
        """
        Checks if it is a whitespace character (tab, newline and carriage return count as whitespace here).
        """
        if char in (" ", "\t", "\n", "\r"):
            return True
        uni = unicodedata.category(char)
        if uni == "Zs":
            return True
        return False

    def _is_control_char(self, char):
        """
        Checks if it is a control character.
        """
        if char in ("\t", "\n", "\r"):
            return False
        uni = unicodedata.category(char)
        if uni in ("Cc", "Cf"):
            return True
        return False

    def _clean_text(self, text):
        """
        Removes invalid characters and cleans up whitespace.
        """
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or self._is_control_char(char):
                continue
            if self._is_whitespace_char(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _whitespace_tokenize(self, text):
        """
        Cleans whitespace and splits text into tokens.
        """
        text = text.strip()
        text = text.lower()
        # if the sentence ends with '.', split periods off as separate tokens
        if text.endswith("."):
            text = text.replace(".", " .")
        if not text:
            tokens = []
        else:
            tokens = text.split()
        return tokens

    def tokenize(self, text):
        """
        Tokenizes text.
        """
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        tokens = self._whitespace_tokenize(text)
        return tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab_dict(self.vocab_dict, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab_dict(self.inv_vocab_dict, ids)
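Since the tokenizer is pure Python, it can be exercised on its own. The snippet below writes a toy vocab file and round-trips a sentence; it assumes the module above is importable as `tokenization` (the name the id-to-token script imports), and the vocab contents are invented.

```python
import tempfile

from tokenization import WhiteSpaceTokenizer  # the module defined above

with tempfile.NamedTemporaryFile("w", suffix=".vocab", delete=False) as f:
    f.write("\n".join(["<unk>", "<pad>", "<sos>", "<eos>", "a", "cat", "."]) + "\n")
    vocab_path = f.name

tok = WhiteSpaceTokenizer(vocab_file=vocab_path)
tokens = tok.tokenize("A cat.")            # lower-cased, trailing '.' split off
ids = tok.convert_tokens_to_ids(tokens)
print(tokens, ids)                         # ['a', 'cat', '.'] [4, 5, 6]
print(tok.convert_ids_to_tokens(ids))      # ['a', 'cat', '.']
```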
@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""weight init"""
import math
import numpy as np
from mindspore import Tensor, Parameter


def gru_default_state(batch_size, input_size, hidden_size, num_layers=1, bidirectional=False):
    '''Weight init for gru cell'''
    stdv = 1 / math.sqrt(hidden_size)
    weight_i = Parameter(Tensor(
        np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)), name='weight_i')
    weight_h = Parameter(Tensor(
        np.random.uniform(-stdv, stdv, (hidden_size, 3*hidden_size)).astype(np.float32)), name='weight_h')
    bias_i = Parameter(Tensor(
        np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), name='bias_i')
    bias_h = Parameter(Tensor(
        np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), name='bias_h')
    init_h = Tensor(np.zeros((batch_size, hidden_size)).astype(np.float16))
    return weight_i, weight_h, bias_i, bias_h, init_h


def dense_default_state(in_channel, out_channel):
    '''Weight init for dense cell'''
    stdv = 1 / math.sqrt(in_channel)
    weight = Tensor(np.random.uniform(-stdv, stdv, (out_channel, in_channel)).astype(np.float32))
    bias = Tensor(np.random.uniform(-stdv, stdv, (out_channel)).astype(np.float32))
    return weight, bias
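Both helpers draw parameters from uniform(-1/sqrt(fan), 1/sqrt(fan)), the usual default for GRU and Dense layers. A numpy-only sketch of the same recipe, just to show the expected shapes and value range; the helper name and sizes are illustrative and no MindSpore is needed to run it.

```python
import math
import numpy as np

def toy_dense_state(in_channel, out_channel, seed=0):
    """Same uniform(-1/sqrt(in_channel), 1/sqrt(in_channel)) rule as dense_default_state."""
    rng = np.random.default_rng(seed)
    stdv = 1 / math.sqrt(in_channel)
    weight = rng.uniform(-stdv, stdv, (out_channel, in_channel)).astype(np.float32)
    bias = rng.uniform(-stdv, stdv, out_channel).astype(np.float32)
    return weight, bias

w, b = toy_dense_state(in_channel=8, out_channel=4)
print(w.shape, b.shape)                       # (4, 8) (4,)
print(np.abs(w).max() <= 1 / math.sqrt(8))    # True: values stay inside the init range
```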
@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train script"""
import os
import time
import argparse
import ast
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.communication.management import init
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn.optim import Adam
from src.config import config
from src.seq2seq import Seq2Seq
from src.gru_for_train import GRUWithLossCell, GRUTrainOneStepWithLossScaleCell
from src.dataset import create_gru_dataset
from src.lr_schedule import dynamic_lr
set_seed(1)

parser = argparse.ArgumentParser(description="GRU training")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distributed training, default: False.")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path")
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
parser.add_argument('--outputs_dir', type=str, default='./', help='Output directory for checkpoints and logs. Default: ./')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id, save_graphs=False)


def get_ms_timestamp():
    t = time.time()
    return int(round(t * 1000))


time_stamp_init = False
time_stamp_first = 0


class LossCallBack(Callback):
    """
    Monitor the loss during training.

    If the loss is NAN or INF, training should be terminated.

    Note:
        If per_print_times is 0, the loss is not printed.

    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
    """
    def __init__(self, per_print_times=1, rank_id=0):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self.rank_id = rank_id
        global time_stamp_init, time_stamp_first
        if not time_stamp_init:
            time_stamp_first = get_ms_timestamp()
            time_stamp_init = True

    def step_end(self, run_context):
        """Monitor the loss in training."""
        global time_stamp_first
        time_stamp_current = get_ms_timestamp()
        cb_params = run_context.original_args()
        print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
                                                                     cb_params.cur_epoch_num,
                                                                     cb_params.cur_step_num,
                                                                     str(cb_params.net_outputs)))
        # net_outputs is the (loss, overflow flag, loss scale) tuple returned by the train cell
        with open("./loss_{}.log".format(self.rank_id), "a+") as f:
            f.write("time: {}, epoch: {}, step: {}, loss: {}, overflow: {}, loss_scale: {}".format(
                time_stamp_current - time_stamp_first,
                cb_params.cur_epoch_num,
                cb_params.cur_step_num,
                str(cb_params.net_outputs[0].asnumpy()),
                str(cb_params.net_outputs[1].asnumpy()),
                str(cb_params.net_outputs[2].asnumpy())))
            f.write('\n')


if __name__ == '__main__':
    if args.run_distribute:
        rank = args.rank_id
        device_num = args.device_num
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        init()
    else:
        rank = 0
        device_num = 1
    dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size,
                                 dataset_path=args.dataset_path, rank_size=device_num, rank_id=rank)
    dataset_size = dataset.get_dataset_size()
    print("dataset size is {}".format(dataset_size))
    network = Seq2Seq(config)
    network = GRUWithLossCell(network)
    lr = dynamic_lr(config, dataset_size)
    opt = Adam(network.trainable_params(), learning_rate=lr)
    scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale_value,
                                            scale_factor=config.scale_factor,
                                            scale_window=config.scale_window)
    update_cell = scale_manager.get_update_cell()
    netwithgrads = GRUTrainOneStepWithLossScaleCell(network, opt, update_cell)

    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]
    # save checkpoints
    if config.save_checkpoint:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_epoch*dataset_size,
                                       keep_checkpoint_max=config.keep_checkpoint_max)
        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_'+str(args.rank_id)+'/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank_id))
        cb += [ckpt_cb]
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    model.train(config.num_epochs, dataset, callbacks=cb, dataset_sink_mode=True)