!17086 crnn_seq2seq_ocr used on ModelArts.

From: @ZhengBina
Reviewed-by: @c_34,@wuxuejian
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-05-29 16:49:21 +08:00 committed by Gitee
commit cb948586ca
14 changed files with 594 additions and 175 deletions

View File

@ -48,18 +48,17 @@ For training and evaluation, we use the French Street Name Signs (FSNS) released
## [Quick Start](#contents)
- After the dataset is prepared, you may start running the training or the evaluation scripts as follows:
- Running on Ascend
```shell
# distribute training example in Ascend
$ bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
$ bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
# evaluation example in Ascend
$ bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]
$ bash run_eval_ascend.sh [TEST_DATA_DIR] [CHECKPOINT_PATH]
# standalone training example in Ascend
$ bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
$ bash run_standalone_train.sh [TRAIN_DATA_DIR]
```
For distributed training, an HCCL configuration file in JSON format needs to be created in advance.
@ -67,6 +66,56 @@ For training and evaluation, we use the French Street Name Signs (FSNS) released
Please follow the instructions in the link below:
[hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
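For reference, a rank table for a single 8-device server can be generated with the hccl_tools script; the invocation below is a sketch based on that tool's documented usage (paths may differ in your checkout):
```shell
# Sketch: generate a rank table JSON covering devices 0-7 of the local server.
# Run from model_zoo/utils/hccl_tools; the generated file can then be passed
# as [RANK_TABLE_FILE] to run_distribute_train.sh.
python hccl_tools.py --device_num "[0,8)"
```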
- Running on ModelArts
If you want to run on ModelArts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/); then you can start training as follows.
- Training with 8 cards on ModelArts
```python
# (1) Upload the code folder to S3 bucket.
# (2) Click to "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/crnn_seq2seq_ocr" on the website UI interface.
# (4) Set the startup file to "/{path}/crnn_seq2seq_ocr/train.py" on the website UI interface.
# (5) Perform a or b.
# a. Set parameters in /{path}/crnn_seq2seq_ocr/default_config.yaml (a yaml sketch follows this block).
# 1. Set "is_distributed=1"
# 2. Set "enable_modelarts=True"
# 3. Set "modelarts_dataset_unzip_name={filename}" if the dataset is uploaded as a zip package.
# b. Add parameters on the website UI interface.
# 1. Add "is_distributed=1"
# 2. Add "enable_modelarts=True"
# 3. Add "modelarts_dataset_unzip_name={filename}" if the dataset is uploaded as a zip package.
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" (this path must contain only the dataset or its zip package).
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of 8 cards.
# (10) Create your job.
```
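As a concrete illustration of step (5a), the relevant entries of `default_config.yaml` (the file added later in this commit) would look like the sketch below; the archive name "FSNS" is a hypothetical placeholder:
```yaml
# Hedged sketch of default_config.yaml settings for 8-card ModelArts training.
enable_modelarts: True
is_distributed: 1
modelarts_dataset_unzip_name: "FSNS"  # only if the dataset is uploaded as FSNS.zip
```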
- Evaluating with a single card on ModelArts
```python
# (1) Upload the code folder to S3 bucket.
# (2) Click to "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/crnn_seq2seq_ocr" on the website UI interface.
# (4) Set the startup file to "/{path}/crnn_seq2seq_ocr/eval.py" on the website UI interface.
# (5) Perform a or b.
# a. Set parameters in /{path}/crnn_seq2seq_ocr/default_config.yaml.
# 1. Set "enable_modelarts=True"
# 2. Set "checkpoint_path={checkpoint_path}" ({checkpoint_path} is the path of the checkpoint to be evaluated, relative to 'eval.py'; the checkpoint file must be included in the code directory. See the path-resolution sketch after this block.)
# 3. Set "modelarts_dataset_unzip_name={filename}" if the dataset is uploaded as a zip package.
# b. Add parameters on the website UI interface.
# 1. Add "enable_modelarts=True"
# 2. Add "checkpoint_path={checkpoint_path}" ({checkpoint_path} is the path of the checkpoint to be evaluated, relative to 'eval.py'; the checkpoint file must be included in the code directory.)
# 3. Add "modelarts_dataset_unzip_name={filename}" if the dataset is uploaded as a zip package.
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" (this path must contain only the dataset or its zip package).
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of a single card.
# (10) Create your job.
```
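Note that `eval.py` (see its diff below) resolves `checkpoint_path` relative to its own directory, which is why the checkpoint must be packaged inside the uploaded code folder. A minimal sketch of that resolution, using a hypothetical checkpoint name:
```python
import os

# "ckpt/crnn_seq2seq_ocr-20_1559.ckpt" is a hypothetical relative path; eval.py
# joins whatever you configure onto the directory that contains eval.py itself.
checkpoint_path = "ckpt/crnn_seq2seq_ocr-20_1559.ckpt"
resolved = os.path.join(os.path.abspath(os.path.dirname(__file__)), checkpoint_path)
print(resolved)  # absolute path that load_checkpoint() will receive
```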
## [Script Description](#contents)
### [Script and Sample Code](#contents)
@ -79,9 +128,13 @@ crnn-seq2seq-ocr
│   ├── run_eval_ascend.sh # Launch Ascend evaluation
│   └── run_standalone_train.sh # Launch standalone training on Ascend(1 pcs)
├── src
│   ├── model_utils
│   │   ├── config.py # parse the "*.yaml" parameter configuration file
│   │   ├── device_adapter.py # select local or ModelArts adapters
│   │   ├── local_adapter.py # get related environment variables in local training
│   │   └── moxing_adapter.py # get related environment variables in ModelArts training
│   ├── attention_ocr.py # CRNN-Seq2Seq-OCR training wrapper
│   ├── cnn.py # VGG network
│   ├── config.py # Parameter configuration
│   ├── create_mindrecord_files.py # Create mindrecord files from images and ground truth
│   ├── dataset.py # Data preprocessing for training and evaluation
│   ├── gru.py # GRU cell wrapper
@ -90,8 +143,9 @@ crnn-seq2seq-ocr
│   ├── seq2seq.py # CRNN-Seq2Seq-OCR model structure
│   ├── utils.py # Utility functions for training and data pre-processing
│   └── weight_init.py # weight initialization of LSTM and GRU
├── eval.py # Evaluation Script
├── general_chars.txt # general chars
└── train.py # Training script
```
### [Script Parameters](#contents)
@ -100,10 +154,10 @@ crnn-seq2seq-ocr
```shell
# distributed training on Ascend
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
# standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH]
Usage: bash run_standalone_train.sh [TRAIN_DATA_DIR]
```
#### Parameters Configuration
@ -116,14 +170,14 @@ Parameters for both training and evaluation can be set in config.py.
## [Training Process](#contents)
- Set options in `config.py`, including learning rate and other network hyperparameters. See the [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about the dataset.
- Set options in `default_config.yaml`, including learning rate and other network hyperparameters. See the [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about the dataset.
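For example, the optimizer-related entries shown in the `default_config.yaml` added by this commit can be tuned directly:
```yaml
# Excerpt from default_config.yaml: learning-rate and optimizer settings.
lr: 0.0008
adam_beta1: 0.5
adam_beta2: 0.999
loss_scale: 1024
```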
### [Training](#contents)
- Run `run_standalone_train.sh` for non-distributed training of the CRNN-Seq2Seq-OCR model; only Ascend is supported for now.
``` bash
bash run_standalone_train.sh [DATASET_PATH]
bash run_standalone_train.sh [TRAIN_DATA_DIR]
```
#### [Distributed Training](#contents)
@ -131,7 +185,7 @@ bash run_standalone_train.sh [DATASET_PATH]
- Run `run_distribute_train.sh` for distributed training of CRNN-Seq2Seq-OCR model on Ascend.
``` bash
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```
Check `train_parallel0/log.txt` and you will see output like the following:
@ -149,7 +203,7 @@ epoch time: 1559886.096 ms, per step time: 382.231 ms
- Run `run_eval_ascend.sh` for evaluation on Ascend.
``` bash
bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]
bash run_eval_ascend.sh [TEST_DATA_DIR] [CHECKPOINT_PATH]
```
Check `eval/log` and you will see output like the following:

View File

@ -0,0 +1,89 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False
modelarts_dataset_unzip_name: None
# ==============================================================================
#train-related
is_distributed: 0
rank_id: 0
train_data_dir: ''
batch_size: 32
num_epochs: 20
keep_checkpoint_max: 20
#eval-related
eval_batch_size: 32
test_data_dir: ''
checkpoint_path: None
# logging-related
log_interval: 100
pre_checkpoint_path: ''
ckpt_path: "outputs/"
ckpt_interval: None
is_save_on_master: 0
# dataset-related
mindrecord_dir: ''
data_root: ''
annotation_file: ''
val_data_root: ''
val_annotation_file: ''
data_json: ''
go_shift: 1
characters_dictionary: {"pad_id": 0, "go_id": 1, "eos_id": 2, "unk_id": 3}
labels_not_use: ['%#?#%', '%#背景#%', '%#不识别#%', '#%不识别%#', '%#模糊#%', '%#模糊#%']
vocab_path: "./general_chars.txt"
# model-related
img_width: 512
img_height: 128
channel_size: 3
conv_out_dim: 384
encoder_hidden_size: 128
decoder_hidden_size: 128
decoder_output_size: 10000
dropout_p: 0.1
max_length: 64
attn_num_layers: 1
teacher_force_ratio: 0.5
#optimizer-related
lr: 0.0008
adam_beta1: 0.5
adam_beta2: 0.999
loss_scale: 1024
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
is_distributed: 'Distribute train or not, 1 for yes, 0 for no. Default: 0'
rank_id: "Local rank of distributed. Default: 0"
train_data_dir: "Train dataset directory."
log_interval: "Logging interval steps. Default: 100"
ckpt_path: "Checkpoint save location. Default: outputs/"
pre_checkpoint_path: "Checkpoint save location."
ckpt_interval: "Save checkpoint interval. Default: None"
is_save_on_master: "Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0"
test_data_dir: "Test Dataset path"
checkpoint_path: "Checkpoint of AttentionOCR (Default:None)."

View File

@ -19,7 +19,6 @@ CRNN-Seq2Seq-OCR Evaluation.
import os
import codecs
import argparse
import numpy as np
import mindspore.ops.operations as P
@ -29,11 +28,13 @@ from mindspore.common import set_seed
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config
from src.utils import initialize_vocabulary
from src.dataset import create_ocr_val_dataset
from src.attention_ocr import AttentionOCRInfer
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id
set_seed(1)
@ -75,30 +76,20 @@ def LCS_length(str1, str2):
return lcs[len1 % 2][-1]
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="CRNN-Seq2Seq-OCR Evaluation")
parser.add_argument("--dataset_path", type=str, default="",
help="Test Dataset path")
parser.add_argument("--checkpoint_path", type=str, default=None,
help="Checkpoint of AttentionOCR (Default:None).")
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented, default is Ascend")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
@moxing_wrapper()
def run_eval():
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
prefix = "fsns.mindrecord"
mindrecord_dir = args.dataset_path
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if config.enable_modelarts:
mindrecord_file = os.path.join(config.data_path, prefix + "0")
else:
mindrecord_file = os.path.join(config.test_data_dir, prefix + "0")
print("mindrecord_file", mindrecord_file)
dataset = create_ocr_val_dataset(mindrecord_file, config.eval_batch_size)
data_loader = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
print("Dataset creation Done!")
#Network
# Network
network = AttentionOCRInfer(config.eval_batch_size,
int(config.img_width / 4),
config.encoder_hidden_size,
@ -106,15 +97,16 @@ if __name__ == '__main__':
config.decoder_output_size,
config.max_length,
config.dropout_p)
ckpt = load_checkpoint(args.checkpoint_path)
checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.checkpoint_path)
ckpt = load_checkpoint(checkpoint_path)
load_param_into_net(network, ckpt)
network.set_train(False)
print("Checkpoint loading Done!")
vocab, rev_vocab = initialize_vocabulary(config.vocab_path)
eos_id = config.characters_dictionary.get("eos_id")
sos_id = config.characters_dictionary.get("go_id")
vocab_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.vocab_path)
_, rev_vocab = initialize_vocabulary(vocab_path)
eos_id = config.characters_dictionary.eos_id
sos_id = config.characters_dictionary.go_id
num_correct_char = 0
num_total_char = 0
@ -125,20 +117,20 @@ if __name__ == '__main__':
incorrect_file = 'result_incorrect.txt'
with codecs.open(correct_file, 'w', encoding='utf-8') as fp_output_correct, \
codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:
codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:
for data in data_loader:
images = Tensor(data["image"])
decoder_inputs = Tensor(data["decoder_input"])
decoder_targets = Tensor(data["decoder_target"])
# decoder_targets = Tensor(data["decoder_target"])
decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size),
dtype=np.float16), mstype.float16)
decoder_input = Tensor((np.ones((config.eval_batch_size, 1))*sos_id).astype(np.int32))
decoder_input = Tensor((np.ones((config.eval_batch_size, 1)) * sos_id).astype(np.int32))
encoder_outputs = network.encoder(images)
batch_decoded_label = []
for di in range(decoder_inputs.shape[1]):
for _ in range(decoder_inputs.shape[1]):
decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs)
topi = P.Argmax()(decoder_output)
ni = P.ExpandDims()(topi, 1)
@ -179,3 +171,5 @@ if __name__ == '__main__':
print('\nnum of total words = %d' % (num_total_word))
print('\ncharacter precision = %f' % (float(num_correct_char) / num_total_char))
print('\nAnnotation precision = %f' % (float(num_correct_word) / num_total_word))
if __name__ == '__main__':
run_eval()

View File

@ -16,7 +16,7 @@
if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
exit 1
fi
@ -39,9 +39,9 @@ fi
PATH2=$(get_real_path $2)
echo $PATH2
if [ ! -f $PATH2 ]
if [ ! -d $PATH2 ]
then
echo "error: PRETRAINED_PATH=$PATH2 is not a file"
echo "error: TRAIN_DATA_DIR=$PATH2 is not a folder"
exit 1
fi
@ -58,9 +58,11 @@ do
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp -r ../src ./train_parallel$i
cp ../*.yaml ./train_parallel$i
cp ../*.txt ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --rank_id=$RANK_ID --is_distribute=1 --device_num=$DEVICE_NUM --mindrecord_file=$PATH2 &> log &
python train.py --is_distribute=1 --train_data_dir=$PATH2 &> log &
cd ..
done

View File

@ -16,7 +16,7 @@
if [ $# != 2 ]
then
echo "Usage: sh run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]"
echo "Usage: sh run_eval_ascend.sh [TEST_DATA_DIR] [CHECKPOINT_PATH]"
exit 1
fi
@ -34,7 +34,7 @@ echo $PATH2
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a folder"
echo "error: TEST_DATA_DIR=$PATH1 is not a folder"
exit 1
fi
@ -56,10 +56,11 @@ fi
mkdir ./eval
cp ../*.py ./eval
cp ../*.txt ./eval
cp ../*.yaml ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start eval for device $DEVICE_ID"
python eval.py --device_target="Ascend" --device_id=$DEVICE_ID --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
python eval.py --device_target="Ascend" --test_data_dir=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

View File

@ -16,7 +16,7 @@
if [ $# -ne 1 ]
then
echo "Usage: sh run_standalone_train_ascend.sh [DATASET_PATH]"
echo "Usage: sh run_standalone_train_ascend.sh [TRAIN_DATA_DIR]"
exit 1
fi
@ -31,9 +31,9 @@ get_real_path(){
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a file"
echo "error: TRAIN_DATA_DIR=$PATH1 is not a folder"
exit 1
fi
@ -50,9 +50,11 @@ fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp ../*.yaml ./train
cp ../*.txt ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --mindrecord_file=$PATH1 --is_distributed=0 &> log &
python train.py --train_data_dir=$PATH1 --is_distributed=0 &> log &
cd ..

View File

@ -1,62 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""Config parameters for CRNN-Seq2Seq-OCR model."""
from easydict import EasyDict as ed
config = ed({
# dataset-related
"mindrecord_dir": "",
"data_root": "",
"annotation_file": "",
"val_data_root": "",
"val_annotation_file": "",
"data_json": "",
"go_shift": 1,
"characters_dictionary": {"pad_id": 0, "go_id": 1, "eos_id": 2, "unk_id": 3},
"labels_not_use": [u'%#<23>?%', u'%#背景#%', u'%#不识<E4B88D>?%', u'#%不识<EFBFBD>?#', u'%#模糊#%', u'%#模糊#%'],
"vocab_path": "./general_chars.txt",
#model-related
"img_width": 512,
"img_height": 128,
"channel_size": 3,
"conv_out_dim": 384,
"encoder_hidden_size": 128,
"decoder_hidden_size": 128,
"decoder_output_size": 10000, # vocab_size is the decoder_output_size, characters_class+1, last 9999 is the space
"dropout_p": 0.1,
"max_length": 64,
"attn_num_layers": 1,
"teacher_force_ratio": 0.5,
#optimizer-related
"lr": 0.0008,
"adam_beta1": 0.5,
"adam_beta2": 0.999,
"loss_scale": 1024,
#train-related
"batch_size": 32,
"num_epochs": 20,
"keep_checkpoint_max": 20,
#eval-related
"eval_batch_size": 32
})

View File

@ -19,7 +19,7 @@ import numpy as np
from mindspore.mindrecord import FileWriter
from config import config
from src.model_utils.config import config
from utils import initialize_vocabulary

View File

@ -24,7 +24,7 @@ import mindspore.dataset.vision.py_transforms as P
import mindspore.dataset.transforms.c_transforms as ops
import mindspore.common.dtype as mstype
from src.config import config
from src.model_utils.config import config
class AugmentationOps():

View File

@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
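# Illustrative example (hypothetical values): nested dicts become attributes,
# so Config({"opt": {"beta1": 0.5}}).opt.beta1 evaluates to 0.5.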
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r', encoding='utf-8') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()
if __name__ == '__main__':
print(config)
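Because `parse_cli_to_yaml` registers one command-line flag per scalar key in the yaml, every entry of `default_config.yaml` can be overridden from the CLI. A usage sketch, with hypothetical paths and values:
```shell
# Override yaml defaults from the command line; flag names mirror the yaml keys.
python train.py --train_data_dir=/path/to/mindrecord --lr=0.001 --num_epochs=10
```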

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from src.model_utils.config import config
if config.enable_modelarts:
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from src.model_utils.config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote OBS to a local directory when the first path is a remote OBS URL and the second is a local path;
upload data from a local directory to remote OBS in the opposite case.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
# print("os.mknod({}) success".format(sync_lock))
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -16,10 +16,9 @@
CRNN-Seq2Seq-OCR train.
"""
import os
import argparse
import datetime
import time
import os
import mindspore.nn as nn
import mindspore.common.dtype as mstype
@ -31,62 +30,78 @@ from mindspore import context
from mindspore.communication.management import init
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config
from src.dataset import create_ocr_train_dataset
from src.logger import get_logger
from src.attention_ocr import AttentionOCR, AttentionOCRWithLossCell, TrainingWrapper
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
set_seed(1)
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.modelarts_dataset_unzip_name:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
def parse_args():
"""Parse train arguments."""
parser = argparse.ArgumentParser('mindspore CRNN-Seq2Seq-OCR training')
sync_lock = "/tmp/unzip_sync.lock"
# device related
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
# distributed related
parser.add_argument('--is_distributed', type=int, default=0,
help='Distribute train or not, 1 for yes, 0 for no. Default: 0')
parser.add_argument('--rank_id', type=int, default=0, help='Local rank of distributed. Default: 0')
parser.add_argument('--device_num', type=int, default=1, help='World size of device. Default: 1')
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
#dataset related
parser.add_argument('--mindrecord_file', type=str, default='', help='Train dataset directory.')
# logging related
parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
parser.add_argument('--pre_checkpoint_path', type=str, default='', help='Checkpoint save location.')
parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None')
parser.add_argument('--is_save_on_master', type=int, default=0,
help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0')
args, _ = parser.parse_known_args()
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
return args
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.ckpt_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def train():
"""Train function."""
args = parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
if args.is_distributed:
rank = args.rank_id
device_num = args.device_num
if config.is_distributed:
rank = get_rank_id()
device_num = get_device_num()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
@ -96,25 +111,31 @@ def train():
device_num = 1
# Logger
args.logger = get_logger(args.outputs_dir, rank)
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
config.outputs_dir = os.path.join(config.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
config.logger = get_logger(config.outputs_dir, rank)
config.rank_save_ckpt_flag = 0
if config.is_save_on_master:
if rank == 0:
args.rank_save_ckpt_flag = 1
config.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
config.rank_save_ckpt_flag = 1
# DATASET
dataset = create_ocr_train_dataset(args.mindrecord_file,
prefix = "fsns.mindrecord"
if config.enable_modelarts:
mindrecord_file = os.path.join(config.data_path, prefix + "0")
else:
mindrecord_file = os.path.join(config.train_data_dir, prefix + "0")
dataset = create_ocr_train_dataset(mindrecord_file,
config.batch_size,
rank_size=device_num,
rank_id=rank)
args.steps_per_epoch = dataset.get_dataset_size()
args.logger.info('Finish loading dataset')
config.steps_per_epoch = dataset.get_dataset_size()
config.logger.info('Finish loading dataset')
if not args.ckpt_interval:
args.ckpt_interval = args.steps_per_epoch
args.logger.save_args(args)
if not config.ckpt_interval:
config.ckpt_interval = config.steps_per_epoch
config.logger.save_args(config)
network = AttentionOCR(config.batch_size,
int(config.img_width / 4),
@ -124,8 +145,10 @@ def train():
config.max_length,
config.dropout_p)
if args.pre_checkpoint_path:
param_dict = load_checkpoint(args.pre_checkpoint_path)
if config.pre_checkpoint_path:
config.pre_checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.pre_checkpoint_path)
param_dict = load_checkpoint(config.pre_checkpoint_path)
load_param_into_net(network, param_dict)
network = AttentionOCRWithLossCell(network, config.max_length)
@ -136,13 +159,13 @@ def train():
network = TrainingWrapper(network, opt, sens=config.loss_scale)
args.logger.info('Finished get network')
config.logger.info('Finished get network')
callback = [TimeMonitor(data_size=1), LossMonitor()]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch,
if config.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.steps_per_epoch,
keep_checkpoint_max=config.keep_checkpoint_max)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(rank) + '/')
save_ckpt_path = os.path.join(config.outputs_dir, 'checkpoints' + '/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix="crnn_seq2seq_ocr")
@ -151,7 +174,7 @@ def train():
model = Model(network)
model.train(config.num_epochs, dataset, callbacks=callback, dataset_sink_mode=False)
args.logger.info('==========Training Done===============')
config.logger.info('==========Training Done===============')
if __name__ == "__main__":