modify gnmt_v2 net for cloud

Merge pull request  from zhanghuiyao/gnmt_v2_clould_new
This commit is contained in:
i-robot 2021-07-02 07:47:44 +00:00 committed by Gitee
commit fb6b56ccd6
31 changed files with 888 additions and 513 deletions

View File

@ -74,20 +74,89 @@ The process of GNMTv2 performing the text translation task is as follows:
After dataset preparation, you can start training and evaluation as follows:
- running on Ascend

```bash
# run training example
cd ./scripts
sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET

# run distributed training example
cd ./scripts
sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET

# run evaluation example
cd ./scripts
sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
  VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
```
- ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)
```bash
# Train 1p/8p on ModelArts with Ascend
# (1) Add "config_path=/path_to_code/default_config.yaml" on the website UI interface.
# (2) Perform a or b.
# a. Set "enable_modelarts=True" on default_config.yaml file.
# Set "pre_train_dataset='/cache/data/wmt16_de_en/train.tok.clean.bpe.32000.en.mindrecord'" on default_config.yaml file.
# Set other parameters on default_config.yaml file you need.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "pre_train_dataset=/cache/data/wmt16_de_en/train.tok.clean.bpe.32000.en.mindrecord" on the website UI interface.
# Add other parameters on the website UI interface.
# (3) Upload a zip dataset to S3 bucket. (you could also upload the original dataset.)
# (4) Set the code directory to "/path/gnmt_v2" on the website UI interface.
# (5) Set the startup file to "train.py" on the website UI interface.
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (7) Create your job.
#
# Eval 1p on ModelArts with Ascend
# (1) Add "config_path=/path_to_code/default_test_config.yaml" on the website UI interface.
# (2) Perform a or b.
# a. Set "enable_modelarts=True" on default_test_config.yaml file.
# Set "pre_train_dataset='/cache/data/wmt16_de_en/train.tok.clean.bpe.32000.en.mindrecord'" on default_test_config.yaml file.
# Set "test_dataset='/cache/data/wmt16_de_en/newstest2014.en.mindrecord'" on default_test_config.yaml file.
# Set "vocab='/cache/data/wmt16_de_en/vocab.bpe.32000'" on default_test_config.yaml file.
# Set "bpe_codes='/cache/data/wmt16_de_en/bpe.32000'" on default_test_config.yaml file.
# Set "test_tgt='/cache/data/wmt16_de_en/newstest2014.de'" on default_test_config.yaml file.
# Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_test_config.yaml file.
# Set "existed_ckpt='/cache/checkpoint_path/model.ckpt'" on default_test_config.yaml file.
# Set other parameters on default_test_config.yaml file you need.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "pre_train_dataset=/cache/data/wmt16_de_en/train.tok.clean.bpe.32000.en.mindrecord" on the website UI interface.
# Add "test_dataset=/cache/data/wmt16_de_en/newstest2014.en.mindrecord" on the website UI interface.
# Add "vocab=/cache/data/wmt16_de_en/vocab.bpe.32000" on the website UI interface.
# Add "bpe_codes=/cache/data/wmt16_de_en/bpe.32000" on the website UI interface.
# Add "test_tgt=/cache/data/wmt16_de_en/newstest2014.de" on the website UI interface.
# Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface.
# Add "existed_ckpt=/cache/checkpoint_path/model.ckpt" on the website UI interface.
# Add other parameters on the website UI interface.
# (3) Upload a zip dataset to S3 bucket. (you could also upload the original dataset.)
# (4) Set the code directory to "/path/gnmt_v2" on the website UI interface.
# (5) Set the startup file to "eval.py" on the website UI interface.
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (7) Create your job.
#
# Export 1p on ModelArts with Ascend
# (1) Add "config_path=/path_to_code/default_test_config.yaml" on the website UI interface.
# (2) Perform a or b.
# a. Set "enable_modelarts=True" on default_test_config.yaml file.
# Set "vocab_file='/cache/data/wmt16_de_en/vocab.bpe.32000'" on default_test_config.yaml file.
# Set "bpe_codes='/cache/data/wmt16_de_en/bpe.32000'" on default_test_config.yaml file.
# Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_test_config.yaml file.
# Set "existed_ckpt='/cache/checkpoint_path/model.ckpt'" on default_test_config.yaml file.
# Set other parameters on default_test_config.yaml file you need.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "vocab_file=/cache/data/wmt16_de_en/vocab.bpe.32000" on the website UI interface.
# Add "bpe_codes=/cache/data/wmt16_de_en/bpe.32000" on the website UI interface.
# Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface.
# Add "existed_ckpt=/cache/checkpoint_path/model.ckpt" on the website UI interface.
# Add other parameters on the website UI interface.
# (3) Upload a zip dataset to S3 bucket. (you could also upload the original dataset.)
# (4) Set the code directory to "/path/gnmt_v2" on the website UI interface.
# (5) Set the startup file to "export.py" on the website UI interface.
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (7) Create your job.
```
# [Script Description](#contents)
@ -96,11 +165,12 @@ The GNMT network script and code result are as follows:
```text
├── gnmt
├── README.md // Introduction of GNMTv2 model.
├── config
│ ├──__init__.py // User interface.
│ ├──config.py // Configuration instance definition.
│ ├──config.json // Configuration file for pre-train or finetune.
│ ├──config_test.json // Configuration file for test.
├── model_utils
│ ├──__init__.py // module init file
│ ├──config.py // Parse arguments
│ ├──device_adapter.py // Device adapter for ModelArts
│ ├──local_adapter.py // Local adapter
│ ├──moxing_adapter.py // Moxing adapter for ModelArts
├── src
│ ├──__init__.py // User interface.
│ ├──dataset
@ -138,10 +208,13 @@ The GNMT network script and code result are as follows:
│ ├──run_distributed_train_ascend.sh // Shell script for distributed train on ascend.
│ ├──run_standalone_eval_ascend.sh // Shell script for standalone eval on ascend.
│ ├──run_standalone_train_ascend.sh // Shell script for standalone train on ascend.
├── default_config.yaml // Configurations for train
├── default_test_config.yaml // Configurations for eval
├── create_dataset.py // Dataset preparation.
├── eval.py // Infer API entry.
├── export.py // Export checkpoint file into air models.
├── mindspore_hub_conf.py // Hub config.
├── pip-requirements.txt // Requirements of third party package for modelarts.
├── requirements.txt // Requirements of third party package.
├── train.py // Train API entry.
```
@ -165,7 +238,7 @@ You may use this [shell script](https://github.com/NVIDIA/DeepLearningExamples/b
## Configuration File
The JSON file in the `config/` directory is the template configuration file.
The YAML file `./default_config.yaml` is the template configuration file.
Almost all required options and parameters can be easily assigned, including the training platform, model configuration, and optimizer parameters.
- config for GNMTv2
@ -185,11 +258,11 @@ Almost all required options and parameters can be easily assigned, including the
'existed_ckpt': "" # the absolute full path to save the checkpoint file
```
For more configuration details, please refer the script `config/config.py` file.
For more configuration details, please refer to the `./default_config.yaml` file.
## Training Process
For a pre-trained model, configure the following options in the `config/config.json` file:
For a pre-trained model, configure the following options in the `./default_config.yaml` file:
- Select an optimizer ('momentum/adam/lamb' is available).
- Specify `ckpt_prefix` and `ckpt_path` to set where checkpoint files are saved (see the sketch below).
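Below is a minimal, illustrative sketch (not part of the original scripts) of how these options surface at run time: after the YAML and command line are merged by `model_utils/config.py`, they are plain attributes on the shared `config` object, with names taken from `default_config.yaml`.

```python
# Illustrative sketch only: reading the optimizer / checkpoint options above.
from model_utils.config import config

assert config.optimizer in ("momentum", "adam", "lamb")
print("checkpoint prefix:", config.ckpt_prefix)    # e.g. "gnmt"
print("checkpoint dir:", config.ckpt_path)         # e.g. "text_translation"
print("keep at most", config.keep_ckpt_max, "checkpoint files")
```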
@ -219,7 +292,7 @@ Currently, inconsecutive device IDs are not supported in `scripts/run_distribute
## Inference Process
For inference, use a trained model on multiple hardware platforms, such as Ascend 910.
Set options in `config/config_test.json`.
Set options in `./default_test_config.yaml`.
Run the shell script `scripts/run_standalone_eval_ascend.sh` to process the output token ids to get the BLEU scores.

View File

@ -1,49 +0,0 @@
{
"dataset_config": {
"random_seed": 50,
"epochs": 6,
"batch_size": 128,
"pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord",
"fine_tune_dataset": null,
"valid_dataset": null,
"dataset_sink_mode": true
},
"model_config": {
"seq_length": 51,
"vocab_size": 32320,
"hidden_size": 1024,
"num_hidden_layers": 4,
"intermediate_size": 4096,
"hidden_dropout_prob": 0.2,
"attention_dropout_prob": 0.2,
"initializer_range": 0.1,
"label_smoothing": 0.1,
"beam_width": 2,
"length_penalty_weight": 0.6,
"max_decode_length": 50
},
"loss_scale_config": {
"init_loss_scale": 65536,
"loss_scale_factor": 2,
"scale_window": 1000
},
"learn_rate_config": {
"optimizer": "adam",
"lr": 2e-3,
"lr_scheduler": "WarmupMultiStepLR",
"lr_scheduler_power": 0.5,
"warmup_lr_remain_steps": 0.666,
"warmup_lr_decay_interval": -1,
"decay_steps": 4,
"decay_start_step": -1,
"warmup_steps": 200,
"min_lr": 1e-6
},
"checkpoint_options": {
"existed_ckpt": "",
"save_ckpt_steps": 3452,
"keep_ckpt_max": 6,
"ckpt_prefix": "gnmt",
"ckpt_path": "text_translation"
}
}

View File

@ -1,232 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Configuration class for GNMT."""
import os
import json
import copy
from typing import List
import mindspore.common.dtype as mstype
def _is_dataset_file(file: str):
return "tfrecord" in file.lower() or "mindrecord" in file.lower()
def _get_files_from_dir(folder: str):
_files = []
for file in os.listdir(folder):
if _is_dataset_file(file):
_files.append(os.path.join(folder, file))
return _files
def get_source_list(folder: str) -> List:
"""
Get file list from a folder.
Returns:
list, file list.
"""
_list = []
if not folder:
return _list
if os.path.isdir(folder):
_list = _get_files_from_dir(folder)
else:
if _is_dataset_file(folder):
_list.append(folder)
return _list
PARAM_NODES = {"dataset_config",
"model_config",
"loss_scale_config",
"learn_rate_config",
"checkpoint_options"}
class GNMTConfig:
"""
Configuration for `GNMT`.
Args:
random_seed (int): Random seed, it can be changed.
epochs (int): Epoch number.
batch_size (int): Batch size of input dataset.
pre_train_dataset (str): Path of pre-training dataset file or folder.
fine_tune_dataset (str): Path of fine-tune dataset file or folder.
test_dataset (str): Path of test dataset file or folder.
valid_dataset (str): Path of validation dataset file or folder.
dataset_sink_mode (bool): Whether enable dataset sink mode.
seq_length (int): Length of input sequence.
vocab_size (int): The shape of each embedding vector.
hidden_size (int): Size of embedding, attention, dim.
num_hidden_layers (int): Encoder, Decoder layers.
intermediate_size (int): Size of intermediate layer in the Transformer
encoder/decoder cell.
hidden_act (str): Activation function used in the Transformer encoder/decoder
cell.
hidden_dropout_prob (float): The dropout probability for hidden outputs.
attention_dropout_prob (float): The dropout probability for Attention module.
initializer_range (float): Initialization value of TruncatedNormal.
label_smoothing (float): Label smoothing setting.
beam_width (int): Beam width for beam search in inferring.
length_penalty_weight (float): Penalty for sentence length.
max_decode_length (int): Max decode length for inferring.
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
dataset.
init_loss_scale (int): Initialized loss scale.
loss_scale_factor (int): Loss scale factor.
scale_window (int): Window size of loss scale.
lr_scheduler (str): Learning rate scheduler. Please see the Note as follow.
optimizer (str): Optimizer for training, e.g. Adam, Lamb, momentum. Default: Adam.
lr (float): Initial learning rate.
min_lr (float): Minimum learning rate.
decay_steps (int): Decay steps.
lr_scheduler_power(float): A value used to calculate decayed learning rate.
warmup_lr_remain_steps (int or float): Start decay at 'remain_steps' iteration.
warmup_lr_decay_interval (int):interval between LR decay steps.
decay_start_step (int): Step to decay.
warmup_steps (int): Warm up steps.
existed_ckpt (str): Using existed checkpoint to keep training or not.
save_ckpt_steps (int): Interval of saving ckpt.
keep_ckpt_max (int): Max ckpt files number.
ckpt_prefix (str): Prefix of ckpt file.
ckpt_path (str): Checkpoints save path.
save_graphs (bool): Whether to save graphs, please set to True if mindinsight
is wanted.
dtype (mstype): Data type of the input.
Note:
There are three types of learning rate scheduler, square root scheduler, polynomial
decay scheduler and warmup multistep learning rate scheduler.
In square root scheduler, the following parameters can be used, lr, decay_start_step,
warmup_steps and min_lr.
In polynomial decay scheduler, the following parameters can be used, lr, min_lr, decay_steps,
warmup_steps, lr_scheduler_power.
In warmmup multistep learning rate scheduler, the following parameters can be used, lr, warmup_steps,
warmup_lr_remain_steps, warmup_lr_decay_interval, decay_steps, lr_scheduler_power.
"""
def __init__(self,
random_seed=50,
epochs=6, batch_size=128,
pre_train_dataset: str = None,
fine_tune_dataset: str = None,
test_dataset: str = None,
valid_dataset: str = None,
dataset_sink_mode=True,
seq_length=51, vocab_size=32320, hidden_size=1024,
num_hidden_layers=4, intermediate_size=4096,
hidden_act="tanh",
hidden_dropout_prob=0.2, attention_dropout_prob=0.2,
initializer_range=0.1,
label_smoothing=0.1,
beam_width=2,
length_penalty_weight=0.6,
max_decode_length=50,
input_mask_from_dataset=False,
init_loss_scale=65536,
loss_scale_factor=2, scale_window=1000,
lr_scheduler="WarmupMultiStepLR",
optimizer="adam",
lr=2e-3, min_lr=1e-6,
decay_steps=4, lr_scheduler_power=0.5,
warmup_lr_remain_steps=0.666, warmup_lr_decay_interval=-1,
decay_start_step=-1, warmup_steps=200,
existed_ckpt="", save_ckpt_steps=3452, keep_ckpt_max=6,
ckpt_prefix="gnmt", ckpt_path: str = None,
save_graphs=False,
dtype=mstype.float32):
self.save_graphs = save_graphs
self.random_seed = random_seed
self.pre_train_dataset = get_source_list(pre_train_dataset) # type: List[str]
self.fine_tune_dataset = get_source_list(fine_tune_dataset) # type: List[str]
self.valid_dataset = get_source_list(valid_dataset) # type: List[str]
self.test_dataset = get_source_list(test_dataset) # type: List[str]
if not isinstance(epochs, int) and epochs < 0:
raise ValueError("`epoch` must be type of int.")
self.epochs = epochs
self.dataset_sink_mode = dataset_sink_mode
self.ckpt_path = ckpt_path
self.keep_ckpt_max = keep_ckpt_max
self.save_ckpt_steps = save_ckpt_steps
self.ckpt_prefix = ckpt_prefix
self.existed_ckpt = existed_ckpt
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_dropout_prob = attention_dropout_prob
self.initializer_range = initializer_range
self.label_smoothing = label_smoothing
self.beam_width = beam_width
self.length_penalty_weight = length_penalty_weight
self.max_decode_length = max_decode_length
self.input_mask_from_dataset = input_mask_from_dataset
self.compute_type = mstype.float16
self.dtype = dtype
self.scale_window = scale_window
self.loss_scale_factor = loss_scale_factor
self.init_loss_scale = init_loss_scale
self.optimizer = optimizer
self.lr = lr
self.lr_scheduler = lr_scheduler
self.min_lr = min_lr
self.lr_scheduler_power = lr_scheduler_power
self.warmup_lr_remain_steps = warmup_lr_remain_steps
self.warmup_lr_decay_interval = warmup_lr_decay_interval
self.decay_steps = decay_steps
self.decay_start_step = decay_start_step
self.warmup_steps = warmup_steps
@classmethod
def from_dict(cls, json_object: dict):
"""Constructs a `TransformerConfig` from a Python dictionary of parameters."""
_params = {}
for node in PARAM_NODES:
for key in json_object[node]:
_params[key] = json_object[node][key]
return cls(**_params)
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `TransformerConfig` from a json file of parameters."""
with open(json_file, "r") as reader:
return cls.from_dict(json.load(reader))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

View File

@ -1,50 +0,0 @@
{
"dataset_config": {
"random_seed": 50,
"epochs": 6,
"batch_size": 128,
"pre_train_dataset": null,
"fine_tune_dataset": null,
"test_dataset": "/home/workspace/dataset_menu/newstest2014.en.mindrecord",
"valid_dataset": null,
"dataset_sink_mode": true
},
"model_config": {
"seq_length": 107,
"vocab_size": 32320,
"hidden_size": 1024,
"num_hidden_layers": 4,
"intermediate_size": 4096,
"hidden_dropout_prob": 0.2,
"attention_dropout_prob": 0.2,
"initializer_range": 0.1,
"label_smoothing": 0.1,
"beam_width": 2,
"length_penalty_weight": 0.6,
"max_decode_length": 80
},
"loss_scale_config": {
"init_loss_scale": 65536,
"loss_scale_factor": 2,
"scale_window": 1000
},
"learn_rate_config": {
"optimizer": "adam",
"lr": 2e-3,
"lr_scheduler": "WarmupMultiStepLR",
"lr_scheduler_power": 0.5,
"warmup_lr_remain_steps": 0.666,
"warmup_lr_decay_interval": -1,
"decay_steps": 4,
"decay_start_step": -1,
"warmup_steps": 200,
"min_lr": 1e-6
},
"checkpoint_options": {
"existed_ckpt": "/home/workspace/gnmt_v2/gnmt-6_3452.ckpt",
"save_ckpt_steps": 3452,
"keep_ckpt_max": 6,
"ckpt_prefix": "gnmt",
"ckpt_path": "text_translation"
}
}

View File

@ -0,0 +1,85 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""
# ==============================================================================
# dataset_config
random_seed: 50
epochs: 6
batch_size: 128
pre_train_dataset: "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord"
fine_tune_dataset: ""
test_dataset: ""
valid_dataset: ""
dataset_sink_mode: true
input_mask_from_dataset: False
# model_config
seq_length: 51
vocab_size: 32320
hidden_size: 1024
num_hidden_layers: 4
intermediate_size: 4096
hidden_dropout_prob: 0.2
attention_dropout_prob: 0.2
initializer_range: 0.1
label_smoothing: 0.1
beam_width: 2
length_penalty_weight: 0.6
max_decode_length: 50
# loss_scale_config
init_loss_scale: 65536
loss_scale_factor: 2
scale_window: 1000
# learn_rate_config
optimizer: "adam"
lr: 0.002 # 2e-3
lr_scheduler: "WarmupMultiStepLR"
lr_scheduler_power: 0.5
warmup_lr_remain_steps: 0.666
warmup_lr_decay_interval: -1
decay_steps: 4
decay_start_step: -1
warmup_steps: 200
min_lr: 0.000001 #1e-6
# checkpoint_options
existed_ckpt: ""
save_ckpt_steps: 3452
keep_ckpt_max: 6
ckpt_prefix: "gnmt"
ckpt_path: "text_translation"
# export option
file_name: "gnmt_v2"
file_format: "AIR"
vocab_file: ""
bpe_codes: ""
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
file_name: "output file name."
file_format: "file format, choices in ['AIR', 'ONNX', 'MINDIR']"
infer_config: "gnmt_v2 config file"
vocab_file: "vocabulary file to use."
bpe_codes: "bpe codes to use."

View File

@ -0,0 +1,94 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""
# ==============================================================================
# dataset_config
random_seed: 50
epochs: 6
batch_size: 128
pre_train_dataset: ""
fine_tune_dataset: ""
test_dataset: "/home/workspace/dataset_menu/newstest2014.en.mindrecord"
valid_dataset: ""
dataset_sink_mode: true
input_mask_from_dataset: False
# model_config
seq_length: 107
vocab_size: 32320
hidden_size: 1024
num_hidden_layers: 4
intermediate_size: 4096
hidden_dropout_prob: 0.2
attention_dropout_prob: 0.2
initializer_range: 0.1
label_smoothing: 0.1
beam_width: 2
length_penalty_weight: 0.6
max_decode_length: 80
# loss_scale_config
init_loss_scale: 65536
loss_scale_factor: 2
scale_window: 1000
# learn_rate_config
optimizer: "adam"
lr: 0.002 # 2e-3
lr_scheduler: "WarmupMultiStepLR"
lr_scheduler_power: 0.5
warmup_lr_remain_steps: 0.666
warmup_lr_decay_interval: -1
decay_steps: 4
decay_start_step: -1
warmup_steps: 200
min_lr: 0.000001 # 1e-6
# checkpoint_options
existed_ckpt: "/home/workspace/gnmt_v2/gnmt-6_3452.ckpt"
save_ckpt_steps: 3452
keep_ckpt_max: 6
ckpt_prefix: "gnmt"
ckpt_path: "text_translation"
# eval option
bpe_codes: ""
test_tgt: ""
vocab: ""
output: "./output.npz"
# export option
file_name: "gnmt_v2"
file_format: "AIR"
vocab_file: ""
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
# eval option
bpe_codes: "bpe codes to use."
test_tgt: "data file of the test target"
output: "result file path."
file_name: "output file name."
file_format: "file format, choices in ['AIR', 'ONNX', 'MINDIR']"
infer_config: "gnmt_v2 config file"
vocab_file: "vocabulary file to use."

View File

@ -13,65 +13,87 @@
# limitations under the License.
# ============================================================================
"""Evaluation api."""
import argparse
import pickle
import os
import time
from mindspore.common import dtype as mstype
from config import GNMTConfig
from src.gnmt_model import infer
from src.gnmt_model.bleu_calculate import bleu_calculate
from src.dataset.tokenizer import Tokenizer
from src.utils.get_config import get_config
parser = argparse.ArgumentParser(description='gnmt')
parser.add_argument("--config", type=str, required=True,
help="model config json file path.")
parser.add_argument("--test_dataset", type=str, required=True,
help="test dataset address.")
parser.add_argument("--existed_ckpt", type=str, required=True,
help="existed checkpoint address.")
parser.add_argument("--vocab", type=str, required=True,
help="Vocabulary to use.")
parser.add_argument("--bpe_codes", type=str, required=True,
help="bpe codes to use.")
parser.add_argument("--test_tgt", type=str, required=True,
default=None,
help="data file of the test target")
parser.add_argument("--output", type=str, required=False,
default="./output.npz",
help="result file path.")
from model_utils.config import config as default_config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, default_config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if default_config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(default_config.data_path, default_config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(default_config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
def get_config(config):
config = GNMTConfig.from_json_file(config)
config.compute_type = mstype.float16
config.dtype = mstype.float32
return config
def _check_args(config):
if not os.path.exists(config):
raise FileNotFoundError("`config` is not existed.")
if not isinstance(config, str):
raise ValueError("`config` must be type of str.")
if __name__ == '__main__':
args, _ = parser.parse_known_args()
_check_args(args.config)
_config = get_config(args.config)
_config.test_dataset = args.test_dataset
_config.existed_ckpt = args.existed_ckpt
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
'''run eval.'''
_config = get_config(default_config)
result = infer(_config)
with open(args.output, "wb") as f:
with open(_config.output, "wb") as f:
pickle.dump(result, f, 1)
result_npy_addr = args.output
vocab = args.vocab
bpe_codes = args.bpe_codes
test_tgt = args.test_tgt
result_npy_addr = _config.output
vocab = _config.vocab
bpe_codes = _config.bpe_codes
test_tgt = _config.test_tgt
tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')
scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
print(f"BLEU scores: {scores}")
if __name__ == '__main__':
run_eval()

View File

@ -13,48 +13,85 @@
# limitations under the License.
# ============================================================================
"""export checkpoint file into air models"""
import argparse
import os
import time
import numpy as np
from mindspore import Tensor, context, Parameter
from mindspore.common import dtype as mstype
from mindspore.train.serialization import export
from config import GNMTConfig
from src.gnmt_model.gnmt import GNMT
from src.gnmt_model.gnmt_for_infer import GNMTInferCell
from src.utils import zero_weight
from src.utils.load_weights import load_infer_weights
from src.utils.get_config import get_config
parser = argparse.ArgumentParser(description="gnmt_v2 export")
parser.add_argument("--file_name", type=str, default="gnmt_v2", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument('--infer_config', type=str, required=True, help='gnmt_v2 config file')
parser.add_argument("--existed_ckpt", type=str, required=True, help="existed checkpoint address.")
parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file')
parser.add_argument("--bpe_codes", type=str, required=True, help="bpe codes to use.")
args = parser.parse_args()
from model_utils.config import config as default_config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend",
reserve_class_name_in_scope=False)
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, default_config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if default_config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(default_config.data_path, default_config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(default_config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
default_config.file_name = os.path.join(default_config.output_path, default_config.file_name)
def get_config(config_file):
tfm_config = GNMTConfig.from_json_file(config_file)
tfm_config.compute_type = mstype.float16
tfm_config.dtype = mstype.float32
return tfm_config
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_export():
'''run export.'''
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend",
reserve_class_name_in_scope=False)
if __name__ == '__main__':
config = get_config(args.infer_config)
config.existed_ckpt = args.existed_ckpt
vocab = args.vocab_file
bpe_codes = args.bpe_codes
config = get_config(default_config)
tfm_model = GNMT(config=config,
is_training=False,
@ -94,4 +131,8 @@ if __name__ == '__main__':
source_ids = Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32))
source_mask = Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32))
export(tfm_infer, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format)
export(tfm_infer, source_ids, source_mask, file_name=config.file_name, file_format=config.file_format)
if __name__ == '__main__':
run_export()

View File

@ -13,26 +13,15 @@
# limitations under the License.
# ============================================================================
"""hub config."""
import os
import mindspore.common.dtype as mstype
from config import GNMTConfig
from src.gnmt_model import GNMTNetworkWithLoss, GNMT
from src.utils.get_config import get_config
def get_config(config):
config = GNMTConfig.from_json_file(config)
config.compute_type = mstype.float16
config.dtype = mstype.float32
return config
from model_utils.config import config as default_config
def create_network(name, *args, **kwargs):
"""create gnmt network."""
config = get_config(default_config)
if name == "gnmt":
default_config_path = os.path.join(os.path.split(os.path.realpath(__file__))[0], "config/config.json")
config_path = kwargs.get("config", default_config_path)
config = get_config(config_path)
is_training = kwargs.get("is_training", False)
if is_training:
return GNMTNetworkWithLoss(config, is_training=is_training, *args)

View File

@ -0,0 +1,126 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()
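As a usage note (not part of this file): importing `config` runs `get_config()`, so the merged YAML/CLI settings are available module-wide. A minimal sketch, assuming the repository's `default_config.yaml`; any top-level YAML key can be overridden from the command line:

```python
# Minimal sketch: every top-level key of default_config.yaml becomes a CLI flag
# (booleans are parsed with ast.literal_eval, so pass True/False), e.g.
#   python this_script.py --batch_size=64 --dataset_sink_mode=False
from model_utils.config import config

print(config)                    # Config.__str__ pretty-prints all settings
print(config.batch_size)         # 128 unless overridden on the command line
print(config.dataset_sink_mode)  # bool parsed via ast.literal_eval
```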

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GNMTv2 model configuration."""
from .config import GNMTConfig
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"GNMTConfig"
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
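For reference, a small illustrative sketch (not part of the commit) of how callers use this shim: the import below works both locally and on ModelArts, because `__init__.py` re-exports whichever backend `config.enable_modelarts` selects.

```python
# Illustrative only: callers never import local_adapter / moxing_adapter directly.
from model_utils.device_adapter import get_device_id, get_device_num, get_rank_id

print(get_device_id())   # DEVICE_ID env var (default 0) when running locally
print(get_device_num())  # RANK_SIZE env var (default 1) when running locally
print(get_rank_id())     # RANK_ID env var (default 0) when running locally
```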

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,116 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
# Run the main function
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
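A sketch of the intended usage pattern (the same one the updated train.py and eval.py follow); the function names `prepare_data` and `main` below are hypothetical:

```python
# Hypothetical entry point showing how moxing_wrapper is meant to be applied.
from model_utils.moxing_adapter import moxing_wrapper

def prepare_data():
    # Optional hook: runs after data_url/checkpoint_url have been copied to
    # /cache, and before the wrapped function (e.g. unzip under config.data_path).
    pass

@moxing_wrapper(pre_process=prepare_data)
def main():
    # Training / evaluation logic. On ModelArts, inputs are already under
    # config.data_path, and anything written to config.output_path is copied
    # back to train_url when this function returns.
    pass

if __name__ == '__main__':
    main()
```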

View File

@ -0,0 +1,5 @@
numpy
pyyaml
subword-nmt==0.3.7
sacrebleu==1.4.14
sacremoses==0.0.35

View File

@ -1,4 +1,5 @@
numpy
pyyaml
subword-nmt==0.3.7
sacrebleu==1.4.14
sacremoses==0.0.35

View File

@ -43,12 +43,15 @@ do
mkdir ${current_exec_path}/device$i
cd ${current_exec_path}/device$i || exit
cp ../../*.py .
cp ../../*.yaml .
cp -r ../../src .
cp -r ../../config .
cp -r ../../model_utils .
export RANK_ID=$i
export DEVICE_ID=$i
config_path="${current_exec_path}/device${i}/default_config.yaml"
echo "config path is : ${config_path}"
python ../../train.py \
--config=${current_exec_path}/device${i}/config/config.json \
--config_path=$config_path \
--pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network${i}.log 2>&1 &
cd ${current_exec_path} || exit
done

View File

@ -46,13 +46,18 @@ then
fi
mkdir ./eval
cp ../*.py ./eval
cp ../*.yaml ./eval
cp -r ../src ./eval
cp -r ../config ./eval
cp -r ../model_utils ./eval
cd ./eval || exit
echo "start for evaluation"
env > env.log
config_path="${current_exec_path}/eval/default_test_config.yaml"
echo "config path is : ${config_path}"
python eval.py \
--config=${current_exec_path}/eval/config/config_test.json \
--config_path=$config_path \
--test_dataset=$TEST_DATASET \
--existed_ckpt=$EXISTED_CKPT_PATH \
--vocab=$VOCAB_ADDR \

View File

@ -35,12 +35,17 @@ then
fi
mkdir ./train
cp ../*.py ./train
cp ../*.yaml ./train
cp -r ../src ./train
cp -r ../config ./train
cp -r ../model_utils ./train
cd ./train || exit
echo "start for training"
env > env.log
config_path="${current_exec_path}/train/default_config.yaml"
echo "config path is : ${config_path}"
python train.py \
--config=${current_exec_path}/train/config/config.json \
--config_path=$config_path \
--pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network.log 2>&1 &
cd ..

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""GNMTv2 Init."""
from config.config import GNMTConfig
from .gnmt import GNMT
from .attention import BahdanauAttention
from .gnmt_for_train import GNMTTraining, LabelSmoothedCrossEntropyCriterion, \
@ -29,6 +28,5 @@ __all__ = [
"GNMTNetworkWithLoss",
"GNMT",
"BahdanauAttention",
"GNMTConfig",
"bleu_calculate"
]

View File

@ -26,7 +26,7 @@ class CreateAttentionPaddingsFromInputPaddings(nn.Cell):
Create attention mask according to input mask.
Args:
config (GNMTConfig): Config class.
config: Config class.
Returns:
Tensor, shape of (N, T, T).

View File

@ -20,7 +20,6 @@ from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from config.config import GNMTConfig
from .dynamic_rnn import DynamicRNNNet
from .create_attention import RecurrentAttention
@ -45,7 +44,7 @@ class GNMTDecoder(nn.Cell):
"""
def __init__(self,
config: GNMTConfig,
config,
is_training: bool,
use_one_hot_embeddings: bool = False,
initializer_range=0.1,

View File

@ -72,7 +72,7 @@ class BeamDecoderStep(nn.Cell):
Multi-layer transformer decoder step.
Args:
config (GNMTConfig): The config of Transformer.
config: The config of Transformer.
"""
def __init__(self,

View File

@ -17,7 +17,6 @@ from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from config.config import GNMTConfig
from .dynamic_rnn import DynamicRNNNet
@ -26,7 +25,7 @@ class GNMTEncoder(nn.Cell):
Implements of GNMT encoder.
Args:
config (GNMTConfig): Configuration of GNMT network.
config: Configuration of GNMT network.
is_training (bool): Whether to train.
compute_type (mstype): Mindspore data type.
@ -35,7 +34,7 @@ class GNMTEncoder(nn.Cell):
"""
def __init__(self,
config: GNMTConfig,
config,
is_training: bool,
compute_type=mstype.float32):
super(GNMTEncoder, self).__init__()

View File

@ -19,7 +19,6 @@ from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from config.config import GNMTConfig
from .embedding import EmbeddingLookup
from .create_attn_padding import CreateAttentionPaddingsFromInputPaddings
from .beam_search import BeamSearchDecoder, TileBeam
@ -36,7 +35,7 @@ class GNMT(nn.Cell):
In GNMT, we define T = src_max_len, T' = tgt_max_len.
Args:
config (GNMTConfig): Model config.
config: Model config.
is_training (bool): Whether is training.
use_one_hot_embeddings (bool): Whether use one-hot embedding.
@ -45,7 +44,7 @@ class GNMT(nn.Cell):
"""
def __init__(self,
config: GNMTConfig,
config,
is_training: bool = False,
use_one_hot_embeddings: bool = False,
use_positional_embedding: bool = True,

View File

@ -67,7 +67,7 @@ def gnmt_infer(config, dataset):
Run infer with GNMT.
Args:
config (GNMTConfig): Config.
config: Config.
dataset (Dataset): Dataset.
Returns:
@ -161,7 +161,7 @@ def infer(config):
GNMT infer api.
Args:
config (GNMTConfig): Config.
config: Config.
Returns:
list, result with

View File

@ -34,7 +34,7 @@ class PredLogProbs(nn.Cell):
Get log probs.
Args:
config (GNMTConfig): The config of GNMT.
config: The config of GNMT.
Returns:
Tensor, log softmax output.
@ -67,7 +67,7 @@ class GNMTTraining(nn.Cell):
GNMT training network.
Args:
config (GNMTConfig): The config of GNMT.
config: The config of GNMT.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
@ -102,7 +102,7 @@ class LabelSmoothedCrossEntropyCriterion(nn.Cell):
Label Smoothed Cross-Entropy Criterion.
Args:
config (GNMTConfig): The config of GNMT.
config: The config of GNMT.
Returns:
Tensor, final loss.

View File

@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Get Config."""
import os
from typing import List
import mindspore.common.dtype as mstype
def _is_dataset_file(file: str):
return "tfrecord" in file.lower() or "mindrecord" in file.lower()
def _get_files_from_dir(folder: str):
_files = []
for file in os.listdir(folder):
if _is_dataset_file(file):
_files.append(os.path.join(folder, file))
return _files
def get_source_list(folder: str) -> List:
"""
Get file list from a folder.
Returns:
list, file list.
"""
_list = []
if not folder:
return _list
if os.path.isdir(folder):
_list = _get_files_from_dir(folder)
else:
if _is_dataset_file(folder):
_list.append(folder)
return _list
def get_config(config):
'''get config.'''
config.pre_train_dataset = None if config.pre_train_dataset == "" else config.pre_train_dataset
config.fine_tune_dataset = None if config.fine_tune_dataset == "" else config.fine_tune_dataset
config.valid_dataset = None if config.valid_dataset == "" else config.valid_dataset
config.test_dataset = None if config.test_dataset == "" else config.test_dataset
if hasattr(config, 'test_tgt'):
config.test_tgt = None if config.test_tgt == "" else config.test_tgt
config.pre_train_dataset = get_source_list(config.pre_train_dataset)
config.fine_tune_dataset = get_source_list(config.fine_tune_dataset)
config.valid_dataset = get_source_list(config.valid_dataset)
config.test_dataset = get_source_list(config.test_dataset)
if not isinstance(config.epochs, int) or config.epochs < 0:
raise ValueError("`epochs` must be a non-negative int.")
config.compute_type = mstype.float16
config.dtype = mstype.float32
return config
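For context, a brief sketch of the intended call pattern (mirroring what eval.py in this commit does): the raw YAML-backed config from `model_utils.config` is passed through `get_config` to turn empty-string dataset fields into file lists and to attach the compute dtypes.

```python
# Illustrative only: normalize the YAML config before building the network.
from model_utils.config import config as default_config
from src.utils.get_config import get_config

cfg = get_config(default_config)
print(cfg.pre_train_dataset)  # list of mindrecord/tfrecord paths ([] if unset)
print(cfg.compute_type)       # mstype.float16
print(cfg.dtype)              # mstype.float32
```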

View File

@ -24,7 +24,7 @@ def load_infer_weights(config):
Load weights from ckpt or npz.
Args:
config (GNMTConfig): Config.
config: Config.
Returns:
dict, weights.

View File

@ -16,7 +16,6 @@
import time
from mindspore.train.callback import Callback
from config import GNMTConfig
class LossCallBack(Callback):
@ -34,7 +33,7 @@ class LossCallBack(Callback):
time_stamp_init = False
time_stamp_first = 0
def __init__(self, config: GNMTConfig, per_print_times: int = 1):
def __init__(self, config, per_print_times: int = 1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")

View File

@ -14,7 +14,7 @@
# ============================================================================
"""Train api."""
import os
import argparse
import time
import numpy as np
import mindspore.common.dtype as mstype
@ -30,39 +30,19 @@ from mindspore.communication import management as MultiAscend
from mindspore.train.serialization import load_checkpoint
from mindspore.common import set_seed
from config import GNMTConfig
from src.dataset import load_dataset
from src.gnmt_model import GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
from src.utils import LossCallBack
from src.utils import one_weight, weight_variable
from src.utils.lr_scheduler import square_root_schedule, polynomial_decay_scheduler, Warmup_MultiStepLR_scheduler
from src.utils.optimizer import Adam
from src.utils.get_config import get_config
parser = argparse.ArgumentParser(description='GNMT train entry point.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
from model_utils.config import config as default_config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num
device_id = os.getenv('DEVICE_ID', None)
if device_id is None:
raise RuntimeError("`DEVICE_ID` can not be None.")
device_id = int(device_id)
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend",
reserve_class_name_in_scope=True,
device_id=device_id)
def get_config(config):
config = GNMTConfig.from_json_file(config)
config.compute_type = mstype.float16
config.dtype = mstype.float32
return config
def _train(model, config: GNMTConfig,
def _train(model, config,
pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
callbacks: list = None):
"""
@ -70,7 +50,7 @@ def _train(model, config: GNMTConfig,
Args:
model (Model): MindSpore model instance.
config (GNMTConfig): Config of mass model.
config: Config of mass model.
pre_training_dataset (Dataset): Pre-training dataset.
fine_tune_dataset (Dataset): Fine-tune dataset.
test_dataset (Dataset): Test dataset.
@ -177,7 +157,7 @@ def _get_optimizer(config, network, lr):
return optimizer
def _build_training_pipeline(config: GNMTConfig,
def _build_training_pipeline(config,
pre_training_dataset=None,
fine_tune_dataset=None,
test_dataset=None):
@ -185,7 +165,7 @@ def _build_training_pipeline(config: GNMTConfig,
Build training pipeline.
Args:
config (GNMTConfig): Config of mass model.
config: Config of mass model.
pre_training_dataset (Dataset): Pre-training dataset.
fine_tune_dataset (Dataset): Fine-tune dataset.
test_dataset (Dataset): Test dataset.
@ -259,12 +239,12 @@ def _setup_parallel_env():
)
def train_parallel(config: GNMTConfig):
def train_parallel(config):
"""
Train model with multi ascend chips.
Args:
config (GNMTConfig): Config for MASS model.
config: Config for MASS model.
"""
_setup_parallel_env()
print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.")
@ -297,12 +277,12 @@ def train_parallel(config: GNMTConfig):
test_dataset=test_dataset)
def train_single(config: GNMTConfig):
def train_single(config):
"""
Train model on single device.
Args:
config (GNMTConfig): Config for model.
config: Config for model.
"""
print(" | Starting training on single device.")
@ -322,22 +302,79 @@ def train_single(config: GNMTConfig):
test_dataset=test_dataset)
def _check_args(config):
if not os.path.exists(config):
raise FileNotFoundError("`config` is not existed.")
if not isinstance(config, str):
raise ValueError("`config` must be type of str.")
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, default_config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if default_config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(default_config.data_path, default_config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(default_config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
default_config.ckpt_path = os.path.join(default_config.output_path, default_config.ckpt_path)
if __name__ == '__main__':
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
'''run train.'''
device_id = os.getenv('DEVICE_ID', None)
if device_id is None:
raise RuntimeError("`DEVICE_ID` can not be None.")
device_id = int(device_id)
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend",
reserve_class_name_in_scope=True, device_id=device_id)
_rank_size = os.getenv('RANK_SIZE')
args, _ = parser.parse_known_args()
_check_args(args.config)
_config = get_config(args.config)
_config.pre_train_dataset = args.pre_train_dataset
_config = get_config(default_config)
_config.pre_train_dataset = default_config.pre_train_dataset
set_seed(_config.random_seed)
if _rank_size is not None and int(_rank_size) > 1:
train_parallel(_config)
else:
train_single(_config)
if __name__ == '__main__':
run_train()