merge crnn

This commit is contained in:
maijianqiang 2021-06-19 18:11:13 +08:00
parent 6fb3981170
commit 100fa004e9
20 changed files with 619 additions and 182 deletions

View File

@ -127,7 +127,11 @@ crnn
│   ├── run_eval.sh # Launch evaluation
│   └── run_standalone_train.sh # Launch standalone training(1 pcs)
├── src
│   ├── config.py # Parameter configuration
│ ├── model_utils
│   ├── config.py # Parameter config
│   ├── moxing_adapter.py # modelarts device configuration
│   └── device_adapter.py # Device Config
│   └── local_adapter.py # local device config
│   ├── crnn.py # crnn network definition
│   ├── crnn_for_train.py # crnn network with grad, loss and gradient clip
│   ├── dataset.py # Data preprocessing for training and evaluation
@ -140,6 +144,8 @@ crnn
│   └── svt_dataset.py # Data preprocessing for SVT
└── train.py # Training script
├── eval.py # Evaluation Script
├── default_config.yaml # config file
```
### [Script Parameters](#contents)
@ -156,7 +162,7 @@ Usage: bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
#### Parameters Configuration
Parameters for both training and evaluation can be set in config.py.
Parameters for both training and evaluation can be set in default_config.yaml.
```shell
max_text_length": 23, # max number of digits in each
@ -210,6 +216,59 @@ epoch: 10 step: 14110, loss is 0.0029097411
Epoch time: 2743.688s, per step time: 0.097s
```
- running on ModelArts
- If you want to train the model on modelarts, you can refer to the [official guidance document] of modelarts (https://support.huaweicloud.com/modelarts/)
```python
# Example of using distributed training dpn on modelarts :
# Data set storage method
# ├── crnn_dataset # dataset dir
# ├──train # train dir
# ├── mnt # train dataset dir
# ├── pred_trained # pred_train
# ├── eval # eval dir
# ├── IIIT5K-Word_V3.0 # eval dataset dir
# ├── checkpoint # checkpoint dir
# ├── svt # checkpoint dir
# (1) Choose either a (modify yaml file parameters) or b (modelArts create training job to modify parameters) 。
# a. set "enable_modelarts=True"
# set "run_distribute=True"
# set "save_checkpoint_path=/cache/train/checkpoint"
# set "train_dataset_path=/cache/data/mnt/ramdisk/max/90kDICT32px"
#
# b. add "enable_modelarts=True" Parameters are on the interface of modearts。
# Set the parameters required by method a on the modelarts interface
# Note: The path parameter does not need to be quoted
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
# (3) Set the code path on the modelarts interface "/path/crnn"。
# (4) Set the model's startup file on the modelarts interface "train.py" 。
# (5) Set the data path of the model on the modelarts interface ".../crnn_dataset/train"(choices crnn_dataset/train Folder path) ,
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
# (6) start trainning the model。
# Example of using model inference on modelarts
# (1) Place the trained model to the corresponding position of the bucket。
# (2) chocie a or b。
# a.set "enable_modelarts=True"
# set "eval_dataset=svt" or eval_dataset=iiit5k
# set "eval_dataset_path=/cache/data/svt/converted/img/" or eval_dataset_path=/cache/data/IIIT5K-Word_V3/IIIT5K/
# set "CHECKPOINT_PATH=/cache/data/checkpoint/checkpoint file name"
# b. Add "enable_modelarts=True" parameter on the interface of modearts。
# Set the parameters required by method a on the modelarts interface
# Note: The path parameter does not need to be quoted
# (3) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
# (4) Set the code path on the modelarts interface "/path/crnn"。
# (5) Set the model's startup file on the modelarts interface "eval.py" 。
# (6) Set the data path of the model on the modelarts interface ".../crnn_dataset/eval"(choices crnn/eval Folder path) ,
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
# (7) Start model inference。
```
## [Evaluation Process](#contents)
### [Evaluation](#contents)
@ -241,6 +300,27 @@ python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [
The ckpt_file parameter is required,
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
- Export MindIR on Modelarts
```Modelarts
Export MindIR example on ModelArts
Data storage method is the same as training
# (1) Choose either a (modify yaml file parameters) or b (modelArts create training job to modify parameters)。
# a. set "enable_modelarts=True"
# set "file_name=/cache/train/crnn"
# set "file_format=MINDIR"
# set "ckpt_file=/cache/data/checkpoint file name"
# b. Add "enable_modelarts=True" parameter on the interface of modearts。
# Set the parameters required by method a on the modelarts interface
# Note: The path parameter does not need to be quoted
# (2)Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
# (3) Set the code path on the modelarts interface "/path/crnn"。
# (4) Set the model's startup file on the modelarts interface "export.py" 。
# (5) Set the data path of the model on the modelarts interface ".../crnn_dataset/eval/checkpoint"(choices crnn_dataset/eval/checkpoint Folder path) ,
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
```
### Infer on Ascend310
Before performing inference, the mindir file must bu exported by export script on the 910 environment. We only provide an example of inference using MINDIR model.

View File

@ -0,0 +1,94 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
enable_modelarts: False
# url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False
# ======================================================================================
# common options
run_distribute: False
model: "lowercase"
# ======================================================================================
# Training options
label_dict: "abcdefghijklmnopqrstuvwxyz0123456789"
train_dataset: "synth"
max_text_length: 23
image_width: 100
image_height: 32
batch_size: 64
epoch_size: 10
hidden_size: 256
learning_rate: 0.02
momentum: 0.95
nesterov: True
save_checkpoint: True
save_checkpoint_steps: 1000
keep_checkpoint_max: 30
save_checkpoint_path: "./"
class_num: 37
input_size: 512
num_step: 24
use_dropout: True
blank: 36
train_dataset_path: ""
train_eval_dataset: "svt"
train_eval_dataset_path: ""
run_eval: False
save_best_ckpt: True
eval_start_epoch: 5
eval_interval: 5
# ======================================================================================
# Eval options
eval_dataset: "svt"
eval_dataset_path: ""
checkpoint_path: ""
# ======================================================================================
# export options
device_id: 0
ckpt_file: ""
file_name: "crnn"
file_format: "MINDIR"
# ======================================================================================
#postprocess
ann_file: True
result_path: True
dataset: "ic03"
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of input data"
output_pah: "The location of the output file"
device_target: "device id of GPU or Ascend. (Default: None)"
enable_profiling: "Whether enable profiling while training default: False"
file_name: "CNN&CTC output air name"
file_format: "choices [AIR, MINDIR]"
ckpt_file: "Checkpoint file path."
run_distribute: "Run distribute, default is false."
train_dataset_path: "train Dataset path, default is None"
model: "Model type, default is lowercase"
train_dataset: "choices [synth, ic03, ic13, svt, iiit5k]"
train_eval_dataset: "choices [synth, ic03, ic13, svt, iiit5k]"
train_eval_dataset_path: "Dataset path, default is None"
run_eval: "Run evaluation when training, default is False."
save_best_ckpt: "Save best checkpoint when run_eval is True, default is True."
eval_start_epoch: "Evaluation start epoch when run_eval is True, default is 5."
eval_interval: "Evaluation interval when run_eval is True, default is 5."
eval_dataset_path: "eval Dataset, default is None."
checkpoint_path: "checkpoint file path, default is None"
ann_file: "ann file."
result_path: "image file path."
dataset: "choices=['ic03', 'ic13', 'svt', 'iiit5k']"

View File

@ -13,60 +13,55 @@
# limitations under the License.
# ============================================================================
"""Warpctc evaluation"""
import os
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import crnn
from src.metric import CRNNAccuracy
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
set_seed(1)
parser = argparse.ArgumentParser(description="CRNN eval")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.add_argument('--model', type=str, default='lowcase', help="Model type, default is uppercase")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
args_opt = parser.parse_args()
if args_opt.model == 'lowcase':
from src.config import config1 as config
else:
from src.config import config2 as config
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if __name__ == '__main__':
@moxing_wrapper(pre_process=None)
def crnn_eval():
if config.device_target == 'Ascend':
device_id = get_device_id()
context.set_context(device_id=device_id)
config.batch_size = 1
max_text_length = config.max_text_length
input_size = config.input_size
# input_size = config.input_size
# create dataset
dataset = create_dataset(name=args_opt.dataset,
dataset_path=args_opt.dataset_path,
dataset = create_dataset(name=config.eval_dataset,
dataset_path=config.eval_dataset_path,
batch_size=config.batch_size,
is_training=False,
config=config)
step_size = dataset.get_dataset_size()
# step_size = dataset.get_dataset_size()
loss = CTCLoss(max_sequence_length=config.num_step,
max_label_length=max_text_length,
batch_size=config.batch_size)
net = crnn(config)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
param_dict = load_checkpoint(config.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define model
model = Model(net, loss_fn=loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
# start evaluation
res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
res = model.eval(dataset, dataset_sink_mode=config.device_target == 'Ascend')
print("result:", res, flush=True)
if __name__ == '__main__':
crnn_eval()

View File

@ -14,34 +14,34 @@
# ============================================================================
""" export model for CRNN """
import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor, context, load_checkpoint, export
from src.crnn import crnn
from src.config import config1 as config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
parser = argparse.ArgumentParser(description="CRNN_export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="crnn", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == "__main__":
def modelarts_pre_process():
pass
@moxing_wrapper(pre_process=modelarts_pre_process)
def model_export():
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
config.batch_size = 1
net = crnn(config)
load_checkpoint(args.ckpt_file, net=net)
load_checkpoint(config.ckpt_file, net=net)
net.set_train(False)
input_data = Tensor(np.zeros([1, 3, config.image_height, config.image_width]), ms.float32)
export(net, input_data, file_name=args.file_name, file_format=args.file_format)
export(net, input_data, file_name=config.file_name, file_format=config.file_format)
if __name__ == '__main__':
model_export()

View File

@ -14,17 +14,10 @@
# ============================================================================
"""post process for 310 inference"""
import os
import argparse
import numpy as np
from src.metric import CRNNAccuracy
from src.config import config1 as config
from src.model_utils.config import config
parser = argparse.ArgumentParser(description="yolov3_darknet53 inference")
parser.add_argument("--ann_file", type=str, required=True, help="ann file.")
parser.add_argument("--result_path", type=str, required=True, help="image file path.")
parser.add_argument("--dataset", type=str, default="ic03", choices=['ic03', 'ic13', 'svt', 'iiit5k'])
args = parser.parse_args()
def read_annotation(ann_file):
file = open(ann_file)
@ -37,6 +30,7 @@ def read_annotation(ann_file):
return ann
def read_ic13_annotation(ann_file):
file = open(ann_file)
@ -48,6 +42,7 @@ def read_ic13_annotation(ann_file):
return ann
def read_svt_annotation(ann_file):
file = open(ann_file)
@ -59,17 +54,18 @@ def read_svt_annotation(ann_file):
return ann
def get_eval_result(result_path, ann_file):
"""
Calculate accuracy according to the annotation file and result file.
"""
metrics = CRNNAccuracy(config)
if args.dataset == "ic03" or args.dataset == "iiit5k":
if config.dataset == "ic03" or config.dataset == "iiit5k":
ann = read_annotation(ann_file)
elif args.dataset == "ic13":
elif config.dataset == "ic13":
ann = read_ic13_annotation(ann_file)
elif args.dataset == "svt":
elif config.dataset == "svt":
ann = read_svt_annotation(ann_file)
for img_name, label in ann.items():
@ -80,5 +76,6 @@ def get_eval_result(result_path, ann_file):
print("result CRNNAccuracy is: ", metrics.eval())
metrics.clear()
if __name__ == '__main__':
get_eval_result(args.result_path, args.ann_file)
get_eval_result(config.result_path, config.ann_file)

View File

@ -15,7 +15,7 @@
# ============================================================================
if [ $# != 3 ]; then
echo "Usage: sh run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
echo "Usage: sh scripts/run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi
@ -51,12 +51,13 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cp ./*.py ./train_parallel$i
cp -r scripts/ ./train_parallel$i
cp -r ./src ./train_parallel$i
cp ./*yaml ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute --dataset=$DATASET_NAME > log.txt 2>&1 &
python train.py --train_dataset_path=$PATH2 --run_distribute=True --train_dataset=$DATASET_NAME > log.txt 2>&1 &
cd ..
done

View File

@ -58,12 +58,14 @@ run_ascend() {
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cp ./*.py ./eval
cp -r ./src ./eval
cp -r ./scripts ./eval
cp ./*yaml ./eval
cd ./eval || exit
env >env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset=$DATASET_NAME --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
python eval.py --eval_dataset=$DATASET_NAME --eval_dataset_path=$1 --checkpoint_path=$2 --device_target=Ascend> log.txt 2>&1 &
cd ..
}
@ -72,15 +74,16 @@ run_gpu() {
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cp ./*.py ./eval
cp -r ./src ./eval
cp -r ./scripts ./eval
cp ./*yaml ./eval
cd ./eval || exit
env >env.log
python eval.py --dataset=$DATASET_NAME \
--dataset_path=$1 \
python eval.py --eval_dataset=$DATASET_NAME \
--eval_dataset_path=$1 \
--checkpoint_path=$2 \
--platform=GPU \
--dataset=$DATASET_NAME > log.txt 2>&1 &
--device_target=GPU > log.txt 2>&1 &
cd ..
}

View File

@ -15,7 +15,7 @@
# ============================================================================
if [ $# != 3 ] && [ $# != 2 ]; then
echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)"
echo "Usage: sh scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)"
exit 1
fi
@ -49,13 +49,13 @@ run_ascend() {
echo "start training for device $DEVICE_ID"
env >env.log
python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
python train.py --train_dataset=$DATASET_NAME --train_dataset_path=$1 --device_target=Ascend > log.txt 2>&1 &
cd ..
}
run_gpu() {
env >env.log
python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
python train.py --train_dataset=$DATASET_NAME --train_dataset_path=$1 --device_target=GPU > log.txt 2>&1 &
cd ..
}
@ -63,9 +63,12 @@ if [ -d "train" ]; then
rm -rf ./train
fi
WORKDIR=./train${DEVICE_ID}
rm -rf $WORKDIR
mkdir $WORKDIR
cp ../*.py $WORKDIR
cp -r ../src $WORKDIR
cp ./*.py $WORKDIR
cp -r ./src $WORKDIR
cp -r ./scripts $WORKDIR
cp ./*yaml $WORKDIR
cd $WORKDIR || exit
if [ "Ascend" == $PLATFORM ]; then

View File

@ -1,42 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict
label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"
# use for low case number
config1 = EasyDict({
"max_text_length": 23,
"image_width": 100,
"image_height": 32,
"batch_size": 64,
"epoch_size": 10,
"hidden_size": 256,
"learning_rate": 0.02,
"momentum": 0.95,
"nesterov": True,
"save_checkpoint": True,
"save_checkpoint_steps": 1000,
"keep_checkpoint_max": 30,
"save_checkpoint_path": "./",
"class_num": 37,
"input_size": 512,
"num_step": 24,
"use_dropout": True,
"blank": 36
})

View File

@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vc
from src.config import config1, label_dict
from src.model_utils.config import config as config1
from src.ic03_dataset import IC03Dataset
from src.ic13_dataset import IC13Dataset
from src.iiit5k_dataset import IIIT5KDataset
@ -75,8 +75,8 @@ class CaptchaDataset:
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
if c in config.label_dict:
label.append(config.label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label

View File

@ -17,7 +17,7 @@
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
from src.model_utils.config import config as config1
ImageFile.LOAD_TRUNCATED_IMAGES = True
@ -48,7 +48,7 @@ class IC03Dataset:
if filter_by_dict:
flag = True
for c in label:
if c not in label_dict:
if c not in config.label_dict:
flag = False
break
if not flag:
@ -73,8 +73,8 @@ class IC03Dataset:
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
if c in config.label_dict:
label.append(config.label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label

View File

@ -17,7 +17,7 @@
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
from src.model_utils.config import config as config1
ImageFile.LOAD_TRUNCATED_IMAGES = True
@ -47,7 +47,7 @@ class IC13Dataset:
if filter_by_dict:
flag = True
for c in label:
if c not in label_dict:
if c not in config.label_dict:
flag = False
break
if not flag:
@ -70,8 +70,8 @@ class IC13Dataset:
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
if c in config.label_dict:
label.append(config.label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label

View File

@ -17,7 +17,7 @@
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
from src.model_utils.config import config as config1
ImageFile.LOAD_TRUNCATED_IMAGES = True
class IIIT5KDataset:
@ -62,8 +62,8 @@ class IIIT5KDataset:
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
if c in config.label_dict:
label.append(config.label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label

View File

@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License Version 2.0(the "License");
# you may not use this file except in compliance with the License.
# you may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0#
#
# Unless required by applicable law or agreed to in writing software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ====================================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
_config_path = '../../default_config.yaml'
class Config:
"""
Configuration namespace. Convert dictionary to members
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
"""
Parse command line arguments to the configuration according to the default yaml
Args:
parser: Parent parser
cfg: Base configuration
helper: Helper description
cfg_path: Path to the default yaml config
"""
parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]',
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file
Args:
yaml_path: Path to the yaml config
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError('At most 3 docs (config description for help, choices) are supported in config yaml')
print(cfg_helper)
except:
raise ValueError('Failed to parse yaml')
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments
Args:
args: command line arguments
cfg: Base configuration
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments
"""
parser = argparse.ArgumentParser(description='default name', add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, _config_path),
help='Config file path')
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License Version 2.0(the "License");
# you may not use this file except in compliance with the License.
# you may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0#
#
# Unless required by applicable law or agreed to in writing software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ====================================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License Version 2.0(the "License");
# you may not use this file except in compliance with the License.
# you may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0#
#
# Unless required by applicable law or agreed to in writing software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ====================================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return 'Local Job'

View File

@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License Version 2.0(the "License");
# you may not use this file except in compliance with the License.
# you may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0#
#
# Unless required by applicable law or agreed to in writing software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ====================================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from .config import config
_global_syn_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local
Uploca data from local directory to remote obs in contrast
"""
import moxing as mox
import time
global _global_syn_count
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
_global_syn_count += 1
# Each server contains 8 devices as most
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print('from path: ', from_path)
print('to path: ', to_path)
mox.file.copy_parallel(from_path, to_path)
print('===finished data synchronization===')
try:
os.mknod(sync_lock)
except IOError:
pass
print('===save flag===')
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print('Finish sync data from {} to {}'.format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print('Dataset downloaded: ', os.listdir(config.data_path))
if config.checkpoint_url:
if not os.path.exists(config.load_path):
# os.makedirs(config.load_path)
print('=' * 20 + 'makedirs')
if os.path.isdir(config.load_path):
print('=' * 20 + 'makedirs success')
else:
print('=' * 20 + 'makedirs fail')
sync_data(config.checkpoint_url, config.load_path)
print('Preload downloaded: ', os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print('Workspace downloaded: ', os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print('Start to copy output directory')
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -17,7 +17,7 @@
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
from src.model_utils.config import config as config1
ImageFile.LOAD_TRUNCATED_IMAGES = True
class SVTDataset:
@ -60,8 +60,8 @@ class SVTDataset:
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
if c in config.label_dict:
label.append(config.label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label

View File

@ -14,8 +14,6 @@
# ============================================================================
"""crnn training"""
import os
import argparse
import ast
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
@ -24,43 +22,20 @@ from mindspore.context import ParallelMode
from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_group_size, get_rank
from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import crnn
from src.crnn_for_train import TrainOneStepCellWithGradClip
from src.metric import CRNNAccuracy
from src.eval_callback import EvalCallBack
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.device_adapter import get_rank_id, get_device_num, get_device_id
set_seed(1)
parser = argparse.ArgumentParser(description="crnn training")
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend'],
help='Running platform, only support Ascend now. Default is Ascend.')
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
parser.add_argument('--eval_dataset', type=str, default='svt', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
parser.add_argument('--eval_dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
help="Run evaluation when training, default is False.")
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
help="Save best checkpoint when run_eval is True, default is True.")
parser.add_argument("--eval_start_epoch", type=int, default=5,
help="Evaluation start epoch when run_eval is True, default is 5.")
parser.add_argument("--eval_interval", type=int, default=5,
help="Evaluation interval when run_eval is True, default is 5.")
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
if args_opt.model == 'lowercase':
from src.config import config1 as config
else:
from src.config import config2 as config
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
def apply_eval(eval_param):
evaluation_model = eval_param["model"]
@ -69,17 +44,27 @@ def apply_eval(eval_param):
res = evaluation_model.eval(eval_ds)
return res[metrics_name]
if __name__ == '__main__':
lr_scale = 1
if args_opt.run_distribute:
if args_opt.platform == 'Ascend':
def modelarts_pre_process():
pass
@moxing_wrapper(pre_process=modelarts_pre_process)
def train():
if config.device_target == 'Ascend':
device_id = get_device_id()
context.set_context(device_id=device_id)
# lr_scale = 1
if config.run_distribute:
if config.device_target == 'Ascend':
init()
lr_scale = 1
device_num = int(os.environ.get("RANK_SIZE"))
rank = int(os.environ.get("RANK_ID"))
# lr_scale = 1
device_num = get_device_num()
rank = get_rank_id()
else:
init()
lr_scale = 1
# lr_scale = 1
device_num = get_group_size()
rank = get_rank()
context.reset_auto_parallel_context()
@ -92,7 +77,8 @@ if __name__ == '__main__':
max_text_length = config.max_text_length
# create dataset
dataset = create_dataset(name=args_opt.dataset, dataset_path=args_opt.dataset_path, batch_size=config.batch_size,
dataset = create_dataset(name=config.train_dataset, dataset_path=config.train_dataset_path,
batch_size=config.batch_size,
num_shards=device_num, shard_id=rank, config=config)
step_size = dataset.get_dataset_size()
# define lr
@ -111,18 +97,18 @@ if __name__ == '__main__':
# define callbacks
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
if args_opt.run_eval:
if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)):
raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path))
eval_dataset = create_dataset(name=args_opt.eval_dataset,
dataset_path=args_opt.eval_dataset_path,
if config.run_eval:
if config.train_eval_dataset_path is None or (not os.path.isdir(config.train_eval_dataset_path)):
raise ValueError("{} is not a existing path.".format(config.train_eval_dataset_path))
eval_dataset = create_dataset(name=config.train_eval_dataset,
dataset_path=config.train_eval_dataset_path,
batch_size=config.batch_size,
is_training=False,
config=config)
eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
eval_param_dict = {"model": eval_model, "dataset": eval_dataset, "metrics_name": "CRNNAccuracy"}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True,
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
ckpt_directory=save_ckpt_path, besk_ckpt_name="best_acc.ckpt",
metrics_name="acc")
callbacks += [eval_cb]
@ -132,3 +118,7 @@ if __name__ == '__main__':
ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
callbacks.append(ckpt_cb)
model.train(config.epoch_size, dataset, callbacks=callbacks)
if __name__ == '__main__':
train()