forked from mindspore-Ecosystem/mindspore
modify gru network for clould
This commit is contained in:
parent
9dca0b9feb
commit
7371e9b28e
|
@ -1,53 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting
|
||||
"""
|
||||
import argparse
|
||||
|
||||
|
||||
def parser_args():
|
||||
"""Config for BGCF"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--dataset", type=str, default="Beauty", help="choose which dataset")
|
||||
parser.add_argument("-dpath", "--datapath", type=str, default="./scripts/data_mr", help="minddata path")
|
||||
parser.add_argument("-de", "--device", type=str, default='0', help="device id")
|
||||
parser.add_argument('--Ks', type=list, default=[5, 10, 20, 100], help="top K")
|
||||
parser.add_argument('-w', '--workers', type=int, default=8, help="number of process to generate data")
|
||||
parser.add_argument("-ckpt", "--ckptpath", type=str, default="./ckpts", help="checkpoint path")
|
||||
|
||||
parser.add_argument("-eps", "--epsilon", type=float, default=1e-8, help="optimizer parameter")
|
||||
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3, help="learning rate")
|
||||
parser.add_argument("-l2", "--l2", type=float, default=0.03, help="l2 coefficient")
|
||||
parser.add_argument("-act", "--activation", type=str, default='tanh', choices=['relu', 'tanh'],
|
||||
help="activation function")
|
||||
parser.add_argument("-ndrop", "--neighbor_dropout", type=list, default=[0.0, 0.2, 0.3],
|
||||
help="dropout ratio for different aggregation layer")
|
||||
parser.add_argument("-log", "--log_name", type=str, default='test', help="log name")
|
||||
|
||||
parser.add_argument("-e", "--num_epoch", type=int, default=600, help="epoch sizes for training")
|
||||
parser.add_argument('-input', '--input_dim', type=int, default=64, choices=[64, 128],
|
||||
help="user and item embedding dimension")
|
||||
parser.add_argument("-b", "--batch_pairs", type=int, default=5000, help="batch size")
|
||||
parser.add_argument('--eval_interval', type=int, default=20, help="evaluation interval")
|
||||
|
||||
parser.add_argument("-neg", "--num_neg", type=int, default=10, help="negative sampling rate ")
|
||||
parser.add_argument("-g1", "--raw_neighs", type=int, default=40, help="num of sampling neighbors in raw graph")
|
||||
parser.add_argument("-g2", "--gnew_neighs", type=int, default=20, help="num of sampling neighbors in sample graph")
|
||||
parser.add_argument("-emb", "--embedded_dimension", type=int, default=64, help="output embedding dim")
|
||||
parser.add_argument('--dist_reg', type=float, default=0.003, help="distance loss coefficient")
|
||||
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='device target')
|
||||
return parser.parse_args()
|
|
@ -76,19 +76,73 @@ nltk.download()
|
|||
|
||||
# [Quick Start](#content)
|
||||
|
||||
After dataset preparation, you can start training and evaluation as follows:
|
||||
- Running on local with Ascend
|
||||
|
||||
```bash
|
||||
# run training example
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [TRAIN_DATASET_PATH]
|
||||
After dataset preparation, you can start training and evaluation as follows:
|
||||
|
||||
# run distributed training example
|
||||
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TRAIN_DATASET_PATH]
|
||||
```bash
|
||||
# run training example
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [TRAIN_DATASET_PATH]
|
||||
|
||||
# run evaluation example
|
||||
sh run_eval.sh [CKPT_FILE] [DATASET_PATH]
|
||||
```
|
||||
# run distributed training example
|
||||
sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TRAIN_DATASET_PATH]
|
||||
|
||||
# run evaluation example
|
||||
sh run_eval.sh [CKPT_FILE] [DATASET_PATH]
|
||||
```
|
||||
|
||||
- Running on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)
|
||||
|
||||
```python
|
||||
# Train 8p on ModelArts
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "run_distribute=True" on default_config.yaml file.
|
||||
# Set "dataset_path='/cache/data/mindrecord/multi30k_train_mindrecord_32_0'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "run_distribute=True" on the website UI interface.
|
||||
# Add "dataset_path=/cache/data/mindrecord/multi30k_train_mindrecord_32_0" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||
# (3) Set the code directory to "/path/gru" on the website UI interface.
|
||||
# (4) Set the startup file to "train.py" on the website UI interface.
|
||||
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (6) Create your job.
|
||||
#
|
||||
# Train 1p on ModelArts
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "dataset_path='/cache/data/mindrecord/multi30k_train_mindrecord_32_0'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "dataset_path=/cache/data/mindrecord/multi30k_train_mindrecord_32_0" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||
# (3) Set the code directory to "/path/gru" on the website UI interface.
|
||||
# (4) Set the startup file to "train.py" on the website UI interface.
|
||||
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (6) Create your job.
|
||||
#
|
||||
# Eval 1p on ModelArts
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "ckpt_file='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file.
|
||||
# Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_config.yaml file.
|
||||
# Set "dataset_path='/cache/data/mindrecord/multi30k_train_mindrecord_32_0'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "ckpt_file=/cache/checkpoint_path/model.ckpt" on the website UI interface.
|
||||
# Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface.
|
||||
# Add "dataset_path=/cache/data/mindrecord/multi30k_train_mindrecord_32" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||
# (3) Set the code directory to "/path/gru" on the website UI interface.
|
||||
# (4) Set the startup file to "eval.py" on the website UI interface.
|
||||
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (6) Create your job.
|
||||
```
|
||||
|
||||
# [Script Description](#content)
|
||||
|
||||
|
@ -97,9 +151,14 @@ The GRU network script and code result are as follows:
|
|||
```text
|
||||
├── gru
|
||||
├── README.md // Introduction of GRU model.
|
||||
├── model_utils
|
||||
│ ├──__init__.py // module init file
|
||||
│ ├──config.py // Parse arguments
|
||||
│ ├──device_adapter.py // Device adapter for ModelArts
|
||||
│ ├──local_adapter.py // Local adapter
|
||||
│ ├──moxing_adapter.py // Moxing adapter for ModelArts
|
||||
├── src
|
||||
| ├──gru.py // gru cell architecture.
|
||||
│ ├──config.py // Configuration instance definition.
|
||||
│ ├──create_data.py // Dataset preparation.
|
||||
│ ├──dataset.py // Dataset loader to feed into model.
|
||||
│ ├──gru_for_infer.py // GRU eval model architecture.
|
||||
|
@ -118,6 +177,10 @@ The GRU network script and code result are as follows:
|
|||
│ ├──run_distributed_train.sh // shell script for distributed train on ascend.
|
||||
│ ├──run_eval.sh // shell script for standalone eval on ascend.
|
||||
│ ├──run_standalone_train.sh // shell script for standalone eval on ascend.
|
||||
├── default_config.yaml // Configurations
|
||||
├── postprocess.py // GRU postprocess script.
|
||||
├── preprocess.py // GRU preprocess script.
|
||||
├── export.py // Export API entry.
|
||||
├── eval.py // Infer API entry.
|
||||
├── requirements.txt // Requirements of third party package.
|
||||
├── train.py // Train API entry.
|
||||
|
@ -213,22 +276,49 @@ Parameters for both training and evaluation can be set in config.py. All the dat
|
|||
sh parse_output.sh target.txt output.txt /path/vocab.en
|
||||
```
|
||||
|
||||
Extra: We recommend doing this locally, but you can also do it on modelarts by running a python script with the following command "os.system("sh parse_output.sh target.txt output.txt /path/vocab.en")".
|
||||
|
||||
- After parse output, we will get target.txt.forbleu and output.txt.forbleu.To calculate BLEU score, you may use this [perl script](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) and run following command to get the BLEU score.
|
||||
|
||||
```bash
|
||||
perl multi-bleu.perl target.txt.forbleu < output.txt.forbleu
|
||||
```
|
||||
|
||||
Extra: We recommend doing this locally, but you can also do it on modelarts by running a python script with the following command "os.system("perl multi-bleu.perl target.txt.forbleu < output.txt.forbleu")".
|
||||
|
||||
Note: The `DATASET_PATH` is path to mindrecord. eg. train: /dataset_path/multi30k_train_mindrecord_0 eval: /dataset_path/multi30k_test_mindrecord
|
||||
|
||||
## [Export MindIR](#contents)
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
- Export on local
|
||||
|
||||
The ckpt_file parameter is required,
|
||||
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
|
||||
```python
|
||||
# The ckpt_file parameter is required, `EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
- Export on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start as follows)
|
||||
|
||||
```python
|
||||
# Eval 1p on ModelArts
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "ckpt_file='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file.
|
||||
# Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_config.yaml file.
|
||||
# Set "file_name='./gru'" on default_config.yaml file.
|
||||
# Set "file_format='MINDIR'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "ckpt_file='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
|
||||
# Add "checkpoint_url='s3://dir_to_trained_ckpt/'" on the website UI interface.
|
||||
# Add "file_name='./gru'" on the website UI interface.
|
||||
# Add "file_format='MINDIR'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Set the code directory to "/path/gru" on the website UI interface.
|
||||
# (3) Set the startup file to "export.py" on the website UI interface.
|
||||
# (4) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (5) Create your job.
|
||||
```
|
||||
|
||||
## [Inference Process](#contents)
|
||||
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
need_modelarts_dataset_unzip: False
|
||||
modelarts_dataset_unzip_name: ""
|
||||
|
||||
# ==============================================================================
|
||||
# options
|
||||
batch_size: 16
|
||||
eval_batch_size: 1
|
||||
src_vocab_size: 8154
|
||||
trg_vocab_size: 6113
|
||||
encoder_embedding_size: 256
|
||||
decoder_embedding_size: 256
|
||||
hidden_size: 512
|
||||
max_length: 32
|
||||
num_epochs: 30
|
||||
save_checkpoint: True
|
||||
ckpt_epoch: 10
|
||||
target_file: "target.txt"
|
||||
output_file: "output.txt"
|
||||
keep_checkpoint_max: 30
|
||||
base_lr: 0.001
|
||||
warmup_step: 300
|
||||
momentum: 0.9
|
||||
init_loss_scale_value: 1024
|
||||
scale_factor: 2
|
||||
scale_window: 2000
|
||||
warmup_ratio: 0.333333
|
||||
teacher_force_ratio: 0.5
|
||||
|
||||
run_distribute: False
|
||||
dataset_path: ""
|
||||
pre_trained: ""
|
||||
ckpt_path: "outputs/"
|
||||
outputs_dir: "./"
|
||||
ckpt_file: ""
|
||||
|
||||
# export option
|
||||
file_name: "gru"
|
||||
file_format: "MINDIR"
|
||||
|
||||
# postprocess option
|
||||
label_dir: ""
|
||||
result_dir: "./result_Files"
|
||||
|
||||
# preprocess option
|
||||
device_num: 1
|
||||
result_path: "./preprocess_Result/"
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: 'Target device type'
|
||||
|
||||
run_distribute: "Run distribute, default: false."
|
||||
dataset_path: "Dataset path"
|
||||
pre_trained: "Pretrained file path."
|
||||
ckpt_path: "Checkpoint save location. Default: outputs/"
|
||||
outputs_dir: "Checkpoint save location. Default: outputs/"
|
||||
ckpt_file: "ckpt file path"
|
||||
# export option
|
||||
file_name: "output file name."
|
||||
file_format: "file format. choices in ['AIR', 'MINDIR']"
|
||||
# postprocess option
|
||||
label_dir: "label data dir"
|
||||
result_dir: "infer result Files"
|
||||
# preprocess option
|
||||
device_num: "Use device nums, default is 1"
|
||||
result_path: "result path"
|
|
@ -14,46 +14,97 @@
|
|||
# ============================================================================
|
||||
"""Transformer evaluation script."""
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import context
|
||||
|
||||
from src.dataset import create_gru_dataset
|
||||
from src.seq2seq import Seq2Seq
|
||||
from src.gru_for_infer import GRUInferCell
|
||||
from src.config import config
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if config.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
|
||||
config.output_file = os.path.join(config.output_path, config.output_file)
|
||||
config.target_file = os.path.join(config.output_path, config.target_file)
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_gru_eval():
|
||||
"""
|
||||
Transformer evaluation.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='GRU eval')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
help="device where the code will be implemented, default is Ascend")
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1')
|
||||
parser.add_argument('--ckpt_file', type=str, default="", help='ckpt file path')
|
||||
parser.add_argument("--dataset_path", type=str, default="",
|
||||
help="Dataset path, default: f`sns.")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
|
||||
device_id=args.device_id, save_graphs=False)
|
||||
mindrecord_file = args.dataset_path
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, reserve_class_name_in_scope=False,
|
||||
device_id=get_device_id(), save_graphs=False)
|
||||
mindrecord_file = config.dataset_path
|
||||
if not os.path.exists(mindrecord_file):
|
||||
print("dataset file {} not exists, please check!".format(mindrecord_file))
|
||||
raise ValueError(mindrecord_file)
|
||||
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
|
||||
dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
|
||||
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size,
|
||||
dataset_path=mindrecord_file, rank_size=get_device_num(), rank_id=0,
|
||||
do_shuffle=False, is_training=False)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("dataset size is {}".format(dataset_size))
|
||||
network = Seq2Seq(config, is_training=False)
|
||||
network = GRUInferCell(network)
|
||||
network.set_train(False)
|
||||
if args.ckpt_file != "":
|
||||
parameter_dict = load_checkpoint(args.ckpt_file)
|
||||
if config.ckpt_file != "":
|
||||
parameter_dict = load_checkpoint(config.ckpt_file)
|
||||
load_param_into_net(network, parameter_dict)
|
||||
model = Model(network)
|
||||
|
||||
|
|
|
@ -13,35 +13,39 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export script."""
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from src.seq2seq import Seq2Seq
|
||||
from src.gru_for_infer import GRUInferCell
|
||||
from src.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='export')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
help="device where the code will be implemented, default is Ascend")
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0')
|
||||
parser.add_argument('--file_name', type=str, default="gru", help='output file name.')
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format.")
|
||||
parser.add_argument('--ckpt_file', type=str, required=True, help='ckpt file path')
|
||||
args = parser.parse_args()
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
|
||||
device_id=args.device_id, save_graphs=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
config.file_name = os.path.join(config.output_path, config.file_name)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_export():
|
||||
"""run export."""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, reserve_class_name_in_scope=False,
|
||||
device_id=get_device_id(), save_graphs=False)
|
||||
network = Seq2Seq(config, is_training=False)
|
||||
network = GRUInferCell(network)
|
||||
network.set_train(False)
|
||||
if args.ckpt_file != "":
|
||||
parameter_dict = load_checkpoint(args.ckpt_file)
|
||||
if config.ckpt_file != "":
|
||||
parameter_dict = load_checkpoint(config.ckpt_file)
|
||||
load_param_into_net(network, parameter_dict)
|
||||
|
||||
source_ids = Tensor(np.random.uniform(0.0, 1e5, size=[config.eval_batch_size, config.max_length]).astype(np.int32))
|
||||
target_ids = Tensor(np.random.uniform(0.0, 1e5, size=[config.eval_batch_size, config.max_length]).astype(np.int32))
|
||||
export(network, source_ids, target_ids, file_name=args.file_name, file_format=args.file_format)
|
||||
export(network, source_ids, target_ids, file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_export()
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
# Run the main function
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -18,24 +18,17 @@ postprocess script.
|
|||
'''
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from src.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description="postprocess")
|
||||
parser.add_argument("--label_dir", type=str, default="", help="label data dir")
|
||||
parser.add_argument("--result_dir", type=str, default="./result_Files", help="infer result Files")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
from model_utils.config import config
|
||||
|
||||
if __name__ == "__main__":
|
||||
file_name = os.listdir(args.label_dir)
|
||||
file_name = os.listdir(config.label_dir)
|
||||
predictions = []
|
||||
target_sents = []
|
||||
for f in file_name:
|
||||
target_ids = np.fromfile(os.path.join(args.label_dir, f), np.int32)
|
||||
target_ids = np.fromfile(os.path.join(config.label_dir, f), np.int32)
|
||||
target_sents.append(target_ids.reshape(config.eval_batch_size, config.max_length))
|
||||
predicted_ids = np.fromfile(os.path.join(args.result_dir, f.split('.')[0] + '_0.bin'), np.int32)
|
||||
predicted_ids = np.fromfile(os.path.join(config.result_dir, f.split('.')[0] + '_0.bin'), np.int32)
|
||||
predictions.append(predicted_ids.reshape(config.eval_batch_size, config.max_length - 1))
|
||||
|
||||
f_output = open(config.output_file, 'w')
|
||||
|
|
|
@ -14,27 +14,19 @@
|
|||
# ============================================================================
|
||||
"""GRU preprocess script."""
|
||||
import os
|
||||
import argparse
|
||||
from src.dataset import create_gru_dataset
|
||||
from src.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='GRU preprocess')
|
||||
parser.add_argument("--dataset_path", type=str, default="",
|
||||
help="Dataset path, default: f`sns.")
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1')
|
||||
parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
|
||||
args = parser.parse_args()
|
||||
from model_utils.config import config
|
||||
|
||||
if __name__ == "__main__":
|
||||
mindrecord_file = args.dataset_path
|
||||
mindrecord_file = config.dataset_path
|
||||
if not os.path.exists(mindrecord_file):
|
||||
print("dataset file {} not exists, please check!".format(mindrecord_file))
|
||||
raise ValueError(mindrecord_file)
|
||||
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
|
||||
dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
|
||||
dataset_path=mindrecord_file, rank_size=config.device_num, rank_id=0, do_shuffle=False, is_training=False)
|
||||
|
||||
source_ids_path = os.path.join(args.result_path, "00_data")
|
||||
target_ids_path = os.path.join(args.result_path, "01_data")
|
||||
source_ids_path = os.path.join(config.result_path, "00_data")
|
||||
target_ids_path = os.path.join(config.result_path, "01_data")
|
||||
os.makedirs(source_ids_path)
|
||||
os.makedirs(target_ids_path)
|
||||
|
||||
|
|
|
@ -58,11 +58,13 @@ do
|
|||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp ../*.yaml ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp -r ../model_utils ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$DATASET_PATH &> log &
|
||||
python train.py --run_distribute=True --dataset_path=$DATASET_PATH &> log &
|
||||
cd ..
|
||||
done
|
|
@ -49,8 +49,10 @@ fi
|
|||
rm -rf ./eval
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp ../*.yaml ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../model_utils ./eval
|
||||
cd ./eval || exit
|
||||
echo "start eval for device $DEVICE_ID"
|
||||
env > env.log
|
||||
|
|
|
@ -42,10 +42,12 @@ fi
|
|||
rm -rf ./train
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp ../*.yaml ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../model_utils ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --dataset_path=$DATASET_PATH &> log &
|
||||
python train.py --dataset_path=$DATASET_PATH &> log &
|
||||
cd ..
|
||||
|
|
|
@ -1,41 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GRU config"""
|
||||
from easydict import EasyDict
|
||||
|
||||
config = EasyDict({
|
||||
"batch_size": 16,
|
||||
"eval_batch_size": 1,
|
||||
"src_vocab_size": 8154,
|
||||
"trg_vocab_size": 6113,
|
||||
"encoder_embedding_size": 256,
|
||||
"decoder_embedding_size": 256,
|
||||
"hidden_size": 512,
|
||||
"max_length": 32,
|
||||
"num_epochs": 30,
|
||||
"save_checkpoint": True,
|
||||
"ckpt_epoch": 10,
|
||||
"target_file": "target.txt",
|
||||
"output_file": "output.txt",
|
||||
"keep_checkpoint_max": 30,
|
||||
"base_lr": 0.001,
|
||||
"warmup_step": 300,
|
||||
"momentum": 0.9,
|
||||
"init_loss_scale_value": 1024,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 2000,
|
||||
"warmup_ratio": 1/3.0,
|
||||
"teacher_force_ratio": 0.5
|
||||
})
|
|
@ -18,7 +18,7 @@ import numpy as np
|
|||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
from src.config import config
|
||||
from model_utils.config import config
|
||||
|
||||
de.config.set_seed(1)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore import Tensor
|
|||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.config import config
|
||||
from model_utils.config import config
|
||||
|
||||
class GRUInferCell(nn.Cell):
|
||||
'''
|
||||
|
|
|
@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.communication.management import get_group_size
|
||||
from src.config import config
|
||||
from model_utils.config import config
|
||||
from src.loss import NLLLoss
|
||||
|
||||
class GRUWithLossCell(nn.Cell):
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
"""train script"""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import ast
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
|
@ -25,31 +23,25 @@ from mindspore.train import Model
|
|||
from mindspore.common import set_seed
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
from mindspore.nn.optim import Adam
|
||||
from src.config import config
|
||||
|
||||
from src.seq2seq import Seq2Seq
|
||||
from src.gru_for_train import GRUWithLossCell, GRUTrainOneStepWithLossScaleCell
|
||||
from src.dataset import create_gru_dataset
|
||||
from src.lr_schedule import dynamic_lr
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_rank_id, get_device_id, get_device_num
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="GRU training")
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
|
||||
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path")
|
||||
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained file path.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
|
||||
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
|
||||
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
|
||||
parser.add_argument('--outputs_dir', type=str, default='./', help='Checkpoint save location. Default: outputs/')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id, save_graphs=False)
|
||||
|
||||
def get_ms_timestamp():
|
||||
t = time.time()
|
||||
return int(round(t * 1000))
|
||||
time_stamp_init = False
|
||||
time_stamp_first = 0
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
@ -89,17 +81,72 @@ class LossCallBack(Callback):
|
|||
str(cb_params.net_outputs[2].asnumpy())))
|
||||
f.write('\n')
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args.run_distribute:
|
||||
rank = args.rank_id
|
||||
device_num = args.device_num
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if config.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
|
||||
config.outputs_dir = os.path.join(config.output_path, config.outputs_dir)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_train():
|
||||
"""run train."""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id(), save_graphs=False)
|
||||
rank = get_rank_id()
|
||||
device_num = get_device_num()
|
||||
if config.run_distribute:
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
mindrecord_file = args.dataset_path
|
||||
mindrecord_file = config.dataset_path
|
||||
if not os.path.exists(mindrecord_file):
|
||||
print("dataset file {} not exists, please check!".format(mindrecord_file))
|
||||
raise ValueError(mindrecord_file)
|
||||
|
@ -120,15 +167,18 @@ if __name__ == '__main__':
|
|||
time_cb = TimeMonitor(data_size=dataset_size)
|
||||
loss_cb = LossCallBack(rank_id=rank)
|
||||
cb = [time_cb, loss_cb]
|
||||
#Save Checkpoint
|
||||
# Save Checkpoint
|
||||
if config.save_checkpoint:
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_epoch*dataset_size,
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_epoch * dataset_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_'+str(args.rank_id)+'/')
|
||||
save_ckpt_path = os.path.join(config.outputs_dir, 'ckpt_' + str(get_rank_id()) + '/')
|
||||
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||
directory=save_ckpt_path,
|
||||
prefix='{}'.format(args.rank_id))
|
||||
prefix='{}'.format(get_rank_id()))
|
||||
cb += [ckpt_cb]
|
||||
netwithgrads.set_train(True)
|
||||
model = Model(netwithgrads)
|
||||
model.train(config.num_epochs, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_train()
|
||||
|
|
Loading…
Reference in New Issue