forked from mindspore-Ecosystem/mindspore
merge crnn
This commit is contained in:
parent
6fb3981170
commit
100fa004e9
|
@ -127,7 +127,11 @@ crnn
|
|||
│ ├── run_eval.sh # Launch evaluation
|
||||
│ └── run_standalone_train.sh # Launch standalone training(1 pcs)
|
||||
├── src
|
||||
│ ├── config.py # Parameter configuration
|
||||
│ ├── model_utils
|
||||
│ ├── config.py # Parameter config
|
||||
│ ├── moxing_adapter.py # modelarts device configuration
|
||||
│ └── device_adapter.py # Device Config
|
||||
│ └── local_adapter.py # local device config
|
||||
│ ├── crnn.py # crnn network definition
|
||||
│ ├── crnn_for_train.py # crnn network with grad, loss and gradient clip
|
||||
│ ├── dataset.py # Data preprocessing for training and evaluation
|
||||
|
@ -140,6 +144,8 @@ crnn
|
|||
│ └── svt_dataset.py # Data preprocessing for SVT
|
||||
└── train.py # Training script
|
||||
├── eval.py # Evaluation Script
|
||||
├── default_config.yaml # config file
|
||||
|
||||
```
|
||||
|
||||
### [Script Parameters](#contents)
|
||||
|
@ -156,7 +162,7 @@ Usage: bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
|
|||
|
||||
#### Parameters Configuration
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py.
|
||||
Parameters for both training and evaluation can be set in default_config.yaml.
|
||||
|
||||
```shell
|
||||
max_text_length": 23, # max number of digits in each
|
||||
|
@ -210,6 +216,59 @@ epoch: 10 step: 14110, loss is 0.0029097411
|
|||
Epoch time: 2743.688s, per step time: 0.097s
|
||||
```
|
||||
|
||||
- running on ModelArts
|
||||
- If you want to train the model on modelarts, you can refer to the [official guidance document] of modelarts (https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```python
|
||||
# Example of using distributed training dpn on modelarts :
|
||||
# Data set storage method
|
||||
|
||||
# ├── crnn_dataset # dataset dir
|
||||
# ├──train # train dir
|
||||
# ├── mnt # train dataset dir
|
||||
# ├── pred_trained # pred_train
|
||||
# ├── eval # eval dir
|
||||
# ├── IIIT5K-Word_V3.0 # eval dataset dir
|
||||
# ├── checkpoint # checkpoint dir
|
||||
# ├── svt # checkpoint dir
|
||||
|
||||
# (1) Choose either a (modify yaml file parameters) or b (modelArts create training job to modify parameters) 。
|
||||
# a. set "enable_modelarts=True"
|
||||
# set "run_distribute=True"
|
||||
# set "save_checkpoint_path=/cache/train/checkpoint"
|
||||
# set "train_dataset_path=/cache/data/mnt/ramdisk/max/90kDICT32px"
|
||||
#
|
||||
# b. add "enable_modelarts=True" Parameters are on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
|
||||
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) Set the code path on the modelarts interface "/path/crnn"。
|
||||
# (4) Set the model's startup file on the modelarts interface "train.py" 。
|
||||
# (5) Set the data path of the model on the modelarts interface ".../crnn_dataset/train"(choices crnn_dataset/train Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
# (6) start trainning the model。
|
||||
|
||||
# Example of using model inference on modelarts
|
||||
# (1) Place the trained model to the corresponding position of the bucket。
|
||||
# (2) chocie a or b。
|
||||
# a.set "enable_modelarts=True"
|
||||
# set "eval_dataset=svt" or eval_dataset=iiit5k
|
||||
# set "eval_dataset_path=/cache/data/svt/converted/img/" or eval_dataset_path=/cache/data/IIIT5K-Word_V3/IIIT5K/
|
||||
# set "CHECKPOINT_PATH=/cache/data/checkpoint/checkpoint file name"
|
||||
|
||||
# b. Add "enable_modelarts=True" parameter on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
|
||||
# (3) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (4) Set the code path on the modelarts interface "/path/crnn"。
|
||||
# (5) Set the model's startup file on the modelarts interface "eval.py" 。
|
||||
# (6) Set the data path of the model on the modelarts interface ".../crnn_dataset/eval"(choices crnn/eval Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
# (7) Start model inference。
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### [Evaluation](#contents)
|
||||
|
@ -241,6 +300,27 @@ python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [
|
|||
The ckpt_file parameter is required,
|
||||
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
|
||||
|
||||
- Export MindIR on Modelarts
|
||||
|
||||
```Modelarts
|
||||
Export MindIR example on ModelArts
|
||||
Data storage method is the same as training
|
||||
# (1) Choose either a (modify yaml file parameters) or b (modelArts create training job to modify parameters)。
|
||||
# a. set "enable_modelarts=True"
|
||||
# set "file_name=/cache/train/crnn"
|
||||
# set "file_format=MINDIR"
|
||||
# set "ckpt_file=/cache/data/checkpoint file name"
|
||||
|
||||
# b. Add "enable_modelarts=True" parameter on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
# (2)Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) Set the code path on the modelarts interface "/path/crnn"。
|
||||
# (4) Set the model's startup file on the modelarts interface "export.py" 。
|
||||
# (5) Set the data path of the model on the modelarts interface ".../crnn_dataset/eval/checkpoint"(choices crnn_dataset/eval/checkpoint Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
```
|
||||
|
||||
### Infer on Ascend310
|
||||
|
||||
Before performing inference, the mindir file must bu exported by export script on the 910 environment. We only provide an example of inference using MINDIR model.
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
# ======================================================================================
|
||||
# common options
|
||||
run_distribute: False
|
||||
model: "lowercase"
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
label_dict: "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
train_dataset: "synth"
|
||||
max_text_length: 23
|
||||
image_width: 100
|
||||
image_height: 32
|
||||
batch_size: 64
|
||||
epoch_size: 10
|
||||
hidden_size: 256
|
||||
learning_rate: 0.02
|
||||
momentum: 0.95
|
||||
nesterov: True
|
||||
save_checkpoint: True
|
||||
save_checkpoint_steps: 1000
|
||||
keep_checkpoint_max: 30
|
||||
save_checkpoint_path: "./"
|
||||
class_num: 37
|
||||
input_size: 512
|
||||
num_step: 24
|
||||
use_dropout: True
|
||||
blank: 36
|
||||
train_dataset_path: ""
|
||||
train_eval_dataset: "svt"
|
||||
train_eval_dataset_path: ""
|
||||
run_eval: False
|
||||
save_best_ckpt: True
|
||||
eval_start_epoch: 5
|
||||
eval_interval: 5
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
eval_dataset: "svt"
|
||||
eval_dataset_path: ""
|
||||
checkpoint_path: ""
|
||||
|
||||
# ======================================================================================
|
||||
# export options
|
||||
device_id: 0
|
||||
ckpt_file: ""
|
||||
file_name: "crnn"
|
||||
file_format: "MINDIR"
|
||||
|
||||
# ======================================================================================
|
||||
#postprocess
|
||||
ann_file: True
|
||||
result_path: True
|
||||
dataset: "ic03"
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of input data"
|
||||
output_pah: "The location of the output file"
|
||||
device_target: "device id of GPU or Ascend. (Default: None)"
|
||||
enable_profiling: "Whether enable profiling while training default: False"
|
||||
file_name: "CNN&CTC output air name"
|
||||
file_format: "choices [AIR, MINDIR]"
|
||||
ckpt_file: "Checkpoint file path."
|
||||
run_distribute: "Run distribute, default is false."
|
||||
train_dataset_path: "train Dataset path, default is None"
|
||||
model: "Model type, default is lowercase"
|
||||
train_dataset: "choices [synth, ic03, ic13, svt, iiit5k]"
|
||||
train_eval_dataset: "choices [synth, ic03, ic13, svt, iiit5k]"
|
||||
train_eval_dataset_path: "Dataset path, default is None"
|
||||
run_eval: "Run evaluation when training, default is False."
|
||||
save_best_ckpt: "Save best checkpoint when run_eval is True, default is True."
|
||||
eval_start_epoch: "Evaluation start epoch when run_eval is True, default is 5."
|
||||
eval_interval: "Evaluation interval when run_eval is True, default is 5."
|
||||
eval_dataset_path: "eval Dataset, default is None."
|
||||
checkpoint_path: "checkpoint file path, default is None"
|
||||
ann_file: "ann file."
|
||||
result_path: "image file path."
|
||||
dataset: "choices=['ic03', 'ic13', 'svt', 'iiit5k']"
|
|
@ -13,60 +13,55 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Warpctc evaluation"""
|
||||
import os
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.loss import CTCLoss
|
||||
from src.dataset import create_dataset
|
||||
from src.crnn import crnn
|
||||
from src.metric import CRNNAccuracy
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="CRNN eval")
|
||||
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
|
||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
|
||||
parser.add_argument('--model', type=str, default='lowcase', help="Model type, default is uppercase")
|
||||
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.model == 'lowcase':
|
||||
from src.config import config1 as config
|
||||
else:
|
||||
from src.config import config2 as config
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
@moxing_wrapper(pre_process=None)
|
||||
def crnn_eval():
|
||||
if config.device_target == 'Ascend':
|
||||
device_id = get_device_id()
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
config.batch_size = 1
|
||||
max_text_length = config.max_text_length
|
||||
input_size = config.input_size
|
||||
# input_size = config.input_size
|
||||
# create dataset
|
||||
dataset = create_dataset(name=args_opt.dataset,
|
||||
dataset_path=args_opt.dataset_path,
|
||||
dataset = create_dataset(name=config.eval_dataset,
|
||||
dataset_path=config.eval_dataset_path,
|
||||
batch_size=config.batch_size,
|
||||
is_training=False,
|
||||
config=config)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# step_size = dataset.get_dataset_size()
|
||||
loss = CTCLoss(max_sequence_length=config.num_step,
|
||||
max_label_length=max_text_length,
|
||||
batch_size=config.batch_size)
|
||||
net = crnn(config)
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
param_dict = load_checkpoint(config.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
|
||||
# start evaluation
|
||||
res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
|
||||
res = model.eval(dataset, dataset_sink_mode=config.device_target == 'Ascend')
|
||||
print("result:", res, flush=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
crnn_eval()
|
||||
|
|
|
@ -14,34 +14,34 @@
|
|||
# ============================================================================
|
||||
|
||||
""" export model for CRNN """
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, context, load_checkpoint, export
|
||||
|
||||
from src.crnn import crnn
|
||||
from src.config import config1 as config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
parser = argparse.ArgumentParser(description="CRNN_export")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="crnn", help="output file name.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
args = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def model_export():
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
config.batch_size = 1
|
||||
net = crnn(config)
|
||||
|
||||
load_checkpoint(args.ckpt_file, net=net)
|
||||
load_checkpoint(config.ckpt_file, net=net)
|
||||
net.set_train(False)
|
||||
|
||||
input_data = Tensor(np.zeros([1, 3, config.image_height, config.image_width]), ms.float32)
|
||||
|
||||
export(net, input_data, file_name=args.file_name, file_format=args.file_format)
|
||||
export(net, input_data, file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_export()
|
||||
|
|
|
@ -14,17 +14,10 @@
|
|||
# ============================================================================
|
||||
"""post process for 310 inference"""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from src.metric import CRNNAccuracy
|
||||
from src.config import config1 as config
|
||||
from src.model_utils.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description="yolov3_darknet53 inference")
|
||||
parser.add_argument("--ann_file", type=str, required=True, help="ann file.")
|
||||
parser.add_argument("--result_path", type=str, required=True, help="image file path.")
|
||||
parser.add_argument("--dataset", type=str, default="ic03", choices=['ic03', 'ic13', 'svt', 'iiit5k'])
|
||||
args = parser.parse_args()
|
||||
|
||||
def read_annotation(ann_file):
|
||||
file = open(ann_file)
|
||||
|
@ -37,6 +30,7 @@ def read_annotation(ann_file):
|
|||
|
||||
return ann
|
||||
|
||||
|
||||
def read_ic13_annotation(ann_file):
|
||||
file = open(ann_file)
|
||||
|
||||
|
@ -48,6 +42,7 @@ def read_ic13_annotation(ann_file):
|
|||
|
||||
return ann
|
||||
|
||||
|
||||
def read_svt_annotation(ann_file):
|
||||
file = open(ann_file)
|
||||
|
||||
|
@ -59,17 +54,18 @@ def read_svt_annotation(ann_file):
|
|||
|
||||
return ann
|
||||
|
||||
|
||||
def get_eval_result(result_path, ann_file):
|
||||
"""
|
||||
Calculate accuracy according to the annotation file and result file.
|
||||
"""
|
||||
metrics = CRNNAccuracy(config)
|
||||
|
||||
if args.dataset == "ic03" or args.dataset == "iiit5k":
|
||||
if config.dataset == "ic03" or config.dataset == "iiit5k":
|
||||
ann = read_annotation(ann_file)
|
||||
elif args.dataset == "ic13":
|
||||
elif config.dataset == "ic13":
|
||||
ann = read_ic13_annotation(ann_file)
|
||||
elif args.dataset == "svt":
|
||||
elif config.dataset == "svt":
|
||||
ann = read_svt_annotation(ann_file)
|
||||
|
||||
for img_name, label in ann.items():
|
||||
|
@ -80,5 +76,6 @@ def get_eval_result(result_path, ann_file):
|
|||
print("result CRNNAccuracy is: ", metrics.eval())
|
||||
metrics.clear()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
get_eval_result(args.result_path, args.ann_file)
|
||||
get_eval_result(config.result_path, config.ann_file)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
echo "Usage: sh scripts/run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -51,12 +51,13 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
|||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp ./*.py ./train_parallel$i
|
||||
cp -r scripts/ ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./*yaml ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute --dataset=$DATASET_NAME > log.txt 2>&1 &
|
||||
python train.py --train_dataset_path=$PATH2 --run_distribute=True --train_dataset=$DATASET_NAME > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
|
|
|
@ -58,12 +58,14 @@ run_ascend() {
|
|||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cp ./*.py ./eval
|
||||
cp -r ./src ./eval
|
||||
cp -r ./scripts ./eval
|
||||
cp ./*yaml ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --dataset=$DATASET_NAME --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
|
||||
python eval.py --eval_dataset=$DATASET_NAME --eval_dataset_path=$1 --checkpoint_path=$2 --device_target=Ascend> log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
|
@ -72,15 +74,16 @@ run_gpu() {
|
|||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cp ./*.py ./eval
|
||||
cp -r ./src ./eval
|
||||
cp -r ./scripts ./eval
|
||||
cp ./*yaml ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
python eval.py --dataset=$DATASET_NAME \
|
||||
--dataset_path=$1 \
|
||||
python eval.py --eval_dataset=$DATASET_NAME \
|
||||
--eval_dataset_path=$1 \
|
||||
--checkpoint_path=$2 \
|
||||
--platform=GPU \
|
||||
--dataset=$DATASET_NAME > log.txt 2>&1 &
|
||||
--device_target=GPU > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ] && [ $# != 2 ]; then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)"
|
||||
echo "Usage: sh scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -49,13 +49,13 @@ run_ascend() {
|
|||
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
|
||||
python train.py --train_dataset=$DATASET_NAME --train_dataset_path=$1 --device_target=Ascend > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
run_gpu() {
|
||||
env >env.log
|
||||
python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
|
||||
python train.py --train_dataset=$DATASET_NAME --train_dataset_path=$1 --device_target=GPU > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
|
@ -63,9 +63,12 @@ if [ -d "train" ]; then
|
|||
rm -rf ./train
|
||||
fi
|
||||
WORKDIR=./train${DEVICE_ID}
|
||||
rm -rf $WORKDIR
|
||||
mkdir $WORKDIR
|
||||
cp ../*.py $WORKDIR
|
||||
cp -r ../src $WORKDIR
|
||||
cp ./*.py $WORKDIR
|
||||
cp -r ./src $WORKDIR
|
||||
cp -r ./scripts $WORKDIR
|
||||
cp ./*yaml $WORKDIR
|
||||
cd $WORKDIR || exit
|
||||
|
||||
if [ "Ascend" == $PLATFORM ]; then
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Network parameters."""
|
||||
from easydict import EasyDict
|
||||
|
||||
|
||||
label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
|
||||
# use for low case number
|
||||
config1 = EasyDict({
|
||||
"max_text_length": 23,
|
||||
"image_width": 100,
|
||||
"image_height": 32,
|
||||
"batch_size": 64,
|
||||
"epoch_size": 10,
|
||||
"hidden_size": 256,
|
||||
"learning_rate": 0.02,
|
||||
"momentum": 0.95,
|
||||
"nesterov": True,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_steps": 1000,
|
||||
"keep_checkpoint_max": 30,
|
||||
"save_checkpoint_path": "./",
|
||||
"class_num": 37,
|
||||
"input_size": 512,
|
||||
"num_step": 24,
|
||||
"use_dropout": True,
|
||||
"blank": 36
|
||||
})
|
|
@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
|
|||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset.vision.c_transforms as vc
|
||||
from src.config import config1, label_dict
|
||||
from src.model_utils.config import config as config1
|
||||
from src.ic03_dataset import IC03Dataset
|
||||
from src.ic13_dataset import IC13Dataset
|
||||
from src.iiit5k_dataset import IIIT5KDataset
|
||||
|
@ -75,8 +75,8 @@ class CaptchaDataset:
|
|||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in label_dict:
|
||||
label.append(label_dict.index(c))
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import os
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFile
|
||||
from src.config import config1, label_dict
|
||||
from src.model_utils.config import config as config1
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
|
@ -48,7 +48,7 @@ class IC03Dataset:
|
|||
if filter_by_dict:
|
||||
flag = True
|
||||
for c in label:
|
||||
if c not in label_dict:
|
||||
if c not in config.label_dict:
|
||||
flag = False
|
||||
break
|
||||
if not flag:
|
||||
|
@ -73,8 +73,8 @@ class IC03Dataset:
|
|||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in label_dict:
|
||||
label.append(label_dict.index(c))
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import os
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFile
|
||||
from src.config import config1, label_dict
|
||||
from src.model_utils.config import config as config1
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
|
@ -47,7 +47,7 @@ class IC13Dataset:
|
|||
if filter_by_dict:
|
||||
flag = True
|
||||
for c in label:
|
||||
if c not in label_dict:
|
||||
if c not in config.label_dict:
|
||||
flag = False
|
||||
break
|
||||
if not flag:
|
||||
|
@ -70,8 +70,8 @@ class IC13Dataset:
|
|||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in label_dict:
|
||||
label.append(label_dict.index(c))
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import os
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFile
|
||||
from src.config import config1, label_dict
|
||||
from src.model_utils.config import config as config1
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
class IIIT5KDataset:
|
||||
|
@ -62,8 +62,8 @@ class IIIT5KDataset:
|
|||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in label_dict:
|
||||
label.append(label_dict.index(c))
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
|
||||
_config_path = '../../default_config.yaml'
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml
|
||||
|
||||
Args:
|
||||
parser: Parent parser
|
||||
cfg: Base configuration
|
||||
helper: Helper description
|
||||
cfg_path: Path to the default yaml config
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]',
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError('At most 3 docs (config description for help, choices) are supported in config yaml')
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError('Failed to parse yaml')
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments
|
||||
|
||||
Args:
|
||||
args: command line arguments
|
||||
cfg: Base configuration
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='default name', add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, _config_path),
|
||||
help='Config file path')
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return 'Local Job'
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
|
||||
_global_syn_count = 0
|
||||
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local
|
||||
Uploca data from local directory to remote obs in contrast
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_syn_count
|
||||
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
|
||||
_global_syn_count += 1
|
||||
|
||||
# Each server contains 8 devices as most
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print('from path: ', from_path)
|
||||
print('to path: ', to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print('===finished data synchronization===')
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print('===save flag===')
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
print('Finish sync data from {} to {}'.format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print('Dataset downloaded: ', os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
if not os.path.exists(config.load_path):
|
||||
# os.makedirs(config.load_path)
|
||||
print('=' * 20 + 'makedirs')
|
||||
if os.path.isdir(config.load_path):
|
||||
print('=' * 20 + 'makedirs success')
|
||||
else:
|
||||
print('=' * 20 + 'makedirs fail')
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print('Preload downloaded: ', os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print('Workspace downloaded: ', os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print('Start to copy output directory')
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -17,7 +17,7 @@
|
|||
import os
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFile
|
||||
from src.config import config1, label_dict
|
||||
from src.model_utils.config import config as config1
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
class SVTDataset:
|
||||
|
@ -60,8 +60,8 @@ class SVTDataset:
|
|||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in label_dict:
|
||||
label.append(label_dict.index(c))
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
|
|
@ -14,8 +14,6 @@
|
|||
# ============================================================================
|
||||
"""crnn training"""
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
|
@ -24,43 +22,20 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.nn.wrap import WithLossCell
|
||||
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
|
||||
from src.loss import CTCLoss
|
||||
from src.dataset import create_dataset
|
||||
from src.crnn import crnn
|
||||
from src.crnn_for_train import TrainOneStepCellWithGradClip
|
||||
from src.metric import CRNNAccuracy
|
||||
from src.eval_callback import EvalCallBack
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_rank_id, get_device_num, get_device_id
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="crnn training")
|
||||
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
|
||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend'],
|
||||
help='Running platform, only support Ascend now. Default is Ascend.')
|
||||
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase")
|
||||
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
|
||||
parser.add_argument('--eval_dataset', type=str, default='svt', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
|
||||
parser.add_argument('--eval_dataset_path', type=str, default=None, help='Dataset path, default is None')
|
||||
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
|
||||
help="Run evaluation when training, default is False.")
|
||||
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
|
||||
help="Save best checkpoint when run_eval is True, default is True.")
|
||||
parser.add_argument("--eval_start_epoch", type=int, default=5,
|
||||
help="Evaluation start epoch when run_eval is True, default is 5.")
|
||||
parser.add_argument("--eval_interval", type=int, default=5,
|
||||
help="Evaluation interval when run_eval is True, default is 5.")
|
||||
parser.set_defaults(run_distribute=False)
|
||||
args_opt = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
|
||||
|
||||
if args_opt.model == 'lowercase':
|
||||
from src.config import config1 as config
|
||||
else:
|
||||
from src.config import config2 as config
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
def apply_eval(eval_param):
|
||||
evaluation_model = eval_param["model"]
|
||||
|
@ -69,17 +44,27 @@ def apply_eval(eval_param):
|
|||
res = evaluation_model.eval(eval_ds)
|
||||
return res[metrics_name]
|
||||
|
||||
if __name__ == '__main__':
|
||||
lr_scale = 1
|
||||
if args_opt.run_distribute:
|
||||
if args_opt.platform == 'Ascend':
|
||||
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train():
|
||||
if config.device_target == 'Ascend':
|
||||
device_id = get_device_id()
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
# lr_scale = 1
|
||||
if config.run_distribute:
|
||||
if config.device_target == 'Ascend':
|
||||
init()
|
||||
lr_scale = 1
|
||||
device_num = int(os.environ.get("RANK_SIZE"))
|
||||
rank = int(os.environ.get("RANK_ID"))
|
||||
# lr_scale = 1
|
||||
device_num = get_device_num()
|
||||
rank = get_rank_id()
|
||||
else:
|
||||
init()
|
||||
lr_scale = 1
|
||||
# lr_scale = 1
|
||||
device_num = get_group_size()
|
||||
rank = get_rank()
|
||||
context.reset_auto_parallel_context()
|
||||
|
@ -92,7 +77,8 @@ if __name__ == '__main__':
|
|||
|
||||
max_text_length = config.max_text_length
|
||||
# create dataset
|
||||
dataset = create_dataset(name=args_opt.dataset, dataset_path=args_opt.dataset_path, batch_size=config.batch_size,
|
||||
dataset = create_dataset(name=config.train_dataset, dataset_path=config.train_dataset_path,
|
||||
batch_size=config.batch_size,
|
||||
num_shards=device_num, shard_id=rank, config=config)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# define lr
|
||||
|
@ -111,18 +97,18 @@ if __name__ == '__main__':
|
|||
# define callbacks
|
||||
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||
if args_opt.run_eval:
|
||||
if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)):
|
||||
raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path))
|
||||
eval_dataset = create_dataset(name=args_opt.eval_dataset,
|
||||
dataset_path=args_opt.eval_dataset_path,
|
||||
if config.run_eval:
|
||||
if config.train_eval_dataset_path is None or (not os.path.isdir(config.train_eval_dataset_path)):
|
||||
raise ValueError("{} is not a existing path.".format(config.train_eval_dataset_path))
|
||||
eval_dataset = create_dataset(name=config.train_eval_dataset,
|
||||
dataset_path=config.train_eval_dataset_path,
|
||||
batch_size=config.batch_size,
|
||||
is_training=False,
|
||||
config=config)
|
||||
eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
|
||||
eval_param_dict = {"model": eval_model, "dataset": eval_dataset, "metrics_name": "CRNNAccuracy"}
|
||||
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
|
||||
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True,
|
||||
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
|
||||
eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
|
||||
ckpt_directory=save_ckpt_path, besk_ckpt_name="best_acc.ckpt",
|
||||
metrics_name="acc")
|
||||
callbacks += [eval_cb]
|
||||
|
@ -132,3 +118,7 @@ if __name__ == '__main__':
|
|||
ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
|
||||
callbacks.append(ckpt_cb)
|
||||
model.train(config.epoch_size, dataset, callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
|
|
Loading…
Reference in New Issue