forked from mindspore-Ecosystem/mindspore
!17086 crnn_seq2seq_ocr used on ModelArts.
From: @ZhengBina Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34
Commit: cb948586ca
@@ -48,18 +48,17 @@ For training and evaluation, we use the French Street Name Signs (FSNS) released

## [Quick Start](#contents)

- After the dataset is prepared, you may start running the training or the evaluation scripts as follows:

- Running on Ascend

```shell
# distributed training example on Ascend
$ bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
$ bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]

# evaluation example on Ascend
$ bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]
$ bash run_eval_ascend.sh [TEST_DATA_DIR] [CHECKPOINT_PATH]

# standalone training example on Ascend
$ bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
$ bash run_standalone_train.sh [TRAIN_DATA_DIR]
```

For distributed training, a hccl configuration file with JSON format needs to be created in advance.
@@ -67,6 +66,56 @@ For training and evaluation, we use the French Street Name Signs (FSNS) released

Please follow the instructions in the link below:
[hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

- Running on ModelArts

If you want to run on ModelArts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/); you can start training as follows.

- Training with 8 cards on ModelArts

```python
# (1) Upload the code folder to an S3 bucket.
# (2) Click "Create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/crnn_seq2seq_ocr" on the website UI interface.
# (4) Set the startup file to "/{path}/crnn_seq2seq_ocr/train.py" on the website UI interface.
# (5) Perform a or b.
#     a. Set parameters in /{path}/crnn_seq2seq_ocr/default_config.yaml.
#        1. Set "is_distributed=1"
#        2. Set "enable_modelarts=True"
#        3. Set "modelarts_dataset_unzip_name={filename}" if the data is uploaded in the form of a zip package.
#     b. Add the parameters on the website UI interface.
#        1. Add "is_distributed=1"
#        2. Add "enable_modelarts=True"
#        3. Add "modelarts_dataset_unzip_name={filename}" if the data is uploaded in the form of a zip package.
# (6) Upload the dataset or the zip package of the dataset to the S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" (there is only data or a zip package under this path).
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of 8 cards.
# (10) Create your job.
```
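For orientation, the sketch below condenses what the ModelArts pre-processing added in this PR's `train.py` does with `modelarts_dataset_unzip_name`: the uploaded zip is unpacked once under the local data cache before the `fsns.mindrecord` files are opened from that directory. The helper name and paths here are illustrative, not part of the repository.

```python
import os
import zipfile


def extract_modelarts_dataset(data_path, unzip_name):
    """Illustrative sketch: unpack <unzip_name>.zip under data_path once,
    then return the directory the mindrecord files are read from."""
    zip_file = os.path.join(data_path, unzip_name + ".zip")
    already_done = os.path.exists(os.path.join(data_path, unzip_name))
    if unzip_name and not already_done and zipfile.is_zipfile(zip_file):
        with zipfile.ZipFile(zip_file, "r") as fz:
            fz.extractall(data_path)
    return data_path
```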
- Evaluating with a single card on ModelArts

```python
# (1) Upload the code folder to an S3 bucket.
# (2) Click "Create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/crnn_seq2seq_ocr" on the website UI interface.
# (4) Set the startup file to "/{path}/crnn_seq2seq_ocr/eval.py" on the website UI interface.
# (5) Perform a or b.
#     a. Set parameters in /{path}/crnn_seq2seq_ocr/default_config.yaml.
#        1. Set "enable_modelarts=True"
#        2. Set "checkpoint_path={checkpoint_path}" ({checkpoint_path} is the path of the weight file to be evaluated, relative to eval.py; the weight file must be included in the code directory.)
#        3. Set "modelarts_dataset_unzip_name={filename}" if the data is uploaded in the form of a zip package.
#     b. Add the parameters on the website UI interface.
#        1. Add "enable_modelarts=True"
#        2. Add "checkpoint_path={checkpoint_path}" ({checkpoint_path} is the path of the weight file to be evaluated, relative to eval.py; the weight file must be included in the code directory.)
#        3. Add "modelarts_dataset_unzip_name={filename}" if the data is uploaded in the form of a zip package.
# (6) Upload the dataset or the zip package of the dataset to the S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" (there is only data or a zip package under this path).
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of a single card.
# (10) Create your job.
```
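To illustrate the note in step (5): in this PR's `eval.py` the `checkpoint_path` taken from the config is resolved relative to the directory containing `eval.py` itself before `load_checkpoint()` is called. A minimal sketch of that resolution (the helper name and the example checkpoint file name are hypothetical):

```python
import os


def resolve_checkpoint(relative_ckpt, script_file):
    """Sketch: join the directory of eval.py with the checkpoint_path
    value from default_config.yaml."""
    return os.path.join(os.path.abspath(os.path.dirname(script_file)), relative_ckpt)


# e.g. resolve_checkpoint("ckpt/crnn_seq2seq_ocr-20_xxx.ckpt", "/path/to/crnn_seq2seq_ocr/eval.py")
# -> "/path/to/crnn_seq2seq_ocr/ckpt/crnn_seq2seq_ocr-20_xxx.ckpt"
```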

## [Script Description](#contents)

### [Script and Sample Code](#contents)
@@ -79,9 +128,13 @@ crnn-seq2seq-ocr

│ ├── run_eval_ascend.sh # Launch Ascend evaluation
│ └── run_standalone_train.sh # Launch standalone training on Ascend (1 pcs)
├── src
│ ├── model_utils
│ │ ├── config.py # parsing parameter configuration file of "*.yaml"
│ │ ├── device_adapter.py # local or ModelArts training
│ │ ├── local_adapter.py # get related environment variables in local training
│ │ └── moxing_adapter.py # get related environment variables in ModelArts training
│ ├── attention_ocr.py # CRNN-Seq2Seq-OCR training wrapper
│ ├── cnn.py # VGG network
│ ├── config.py # Parameter configuration
│ ├── create_mindrecord_files.py # Create mindrecord files from images and ground truth
│ ├── dataset.py # Data preprocessing for training and evaluation
│ ├── gru.py # GRU cell wrapper

@@ -90,8 +143,9 @@ crnn-seq2seq-ocr

│ ├── seq2seq.py # CRNN-Seq2Seq-OCR model structure
│ └── utils.py # Utility functions for training and data pre-processing
│ ├── weight_init.py # weight initialization of LSTM and GRU
└── train.py # Training script
├── eval.py # Evaluation Script
├── general_chars.txt # general chars
└── train.py # Training script
```

### [Script Parameters](#contents)
@@ -100,10 +154,10 @@ crnn-seq2seq-ocr

```shell
# distributed training on Ascend
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]

# standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH]
Usage: bash run_standalone_train.sh [TRAIN_DATA_DIR]
```

#### Parameters Configuration
@@ -116,14 +170,14 @@ Parameters for both training and evaluation can be set in config.py.

## [Training Process](#contents)

- Set options in `config.py`, including learning rate and other network hyperparameters. Click [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about the dataset.
- Set options in `default_config.yaml`, including learning rate and other network hyperparameters. Click [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about the dataset.
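As a quick sanity check of the merged configuration (yaml defaults plus any command-line overrides), something like the following can be run from the repository root, assuming the package layout added in this PR; the key names are the ones defined in `default_config.yaml`.

```python
# Minimal sketch: inspect the effective configuration produced by
# src/model_utils/config.py (yaml defaults merged with CLI overrides).
from src.model_utils.config import config

print(config.lr, config.batch_size, config.num_epochs)
print(config.enable_modelarts, config.train_data_dir)
```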

### [Training](#contents)

- Run `run_standalone_train.sh` for non-distributed training of the CRNN-Seq2Seq-OCR model; only Ascend is supported now.

``` bash
bash run_standalone_train.sh [DATASET_PATH]
bash run_standalone_train.sh [TRAIN_DATA_DIR]
```

#### [Distributed Training](#contents)

@@ -131,7 +185,7 @@ bash run_standalone_train.sh [DATASET_PATH]

- Run `run_distribute_train.sh` for distributed training of the CRNN-Seq2Seq-OCR model on Ascend.

``` bash
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```

Check `train_parallel0/log.txt` and you will see output like the following:

@@ -149,7 +203,7 @@ epoch time: 1559886.096 ms, per step time: 382.231 ms

- Run `run_eval_ascend.sh` for evaluation on Ascend.

``` bash
bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]
bash run_eval_ascend.sh [TEST_DATA_DIR] [CHECKPOINT_PATH]
```

Check `eval/log` and you will see output like the following:
@@ -0,0 +1,89 @@

# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False

modelarts_dataset_unzip_name: None
# ==============================================================================
# train-related
is_distributed: 0
rank_id: 0
train_data_dir: ''
batch_size: 32
num_epochs: 20
keep_checkpoint_max: 20
# eval-related
eval_batch_size: 32
test_data_dir: ''
checkpoint_path: None
# logging-related
log_interval: 100
pre_checkpoint_path: ''
ckpt_path: "outputs/"
ckpt_interval: None
is_save_on_master: 0
# dataset-related
mindrecord_dir: ''
data_root: ''
annotation_file: ''
val_data_root: ''
val_annotation_file: ''
data_json: ''
go_shift: 1
characters_dictionary: {"pad_id": 0, "go_id": 1, "eos_id": 2, "unk_id": 3}
labels_not_use: ['%#<23>?%', '%#背景#%', '%#不识<E4B88D>?%', '#%不识<E4B88D>?#', '%#模糊#%', '%#模糊#%']
vocab_path: "./general_chars.txt"
# model-related
img_width: 512
img_height: 128
channel_size: 3
conv_out_dim: 384
encoder_hidden_size: 128
decoder_hidden_size: 128
decoder_output_size: 10000
dropout_p: 0.1
max_length: 64
attn_num_layers: 1
teacher_force_ratio: 0.5
# optimizer-related
lr: 0.0008
adam_beta1: 0.5
adam_beta2: 0.999
loss_scale: 1024

---

# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'

is_distributed: 'Distribute train or not, 1 for yes, 0 for no. Default: 0'
rank_id: "Local rank of distributed. Default: 0"
train_data_dir: "Train dataset directory."

log_interval: "Logging interval steps. Default: 100"
ckpt_path: "Checkpoint save location. Default: outputs/"
pre_checkpoint_path: "Checkpoint save location."
ckpt_interval: "Save checkpoint interval. Default: None"
is_save_on_master: "Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0"

test_data_dir: "Test Dataset path"
checkpoint_path: "Checkpoint of AttentionOCR (Default:None)."
@@ -19,7 +19,6 @@ CRNN-Seq2Seq-OCR Evaluation.

import os
import codecs
import argparse
import numpy as np

import mindspore.ops.operations as P

@@ -29,11 +28,13 @@ from mindspore.common import set_seed

from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config
from src.utils import initialize_vocabulary
from src.dataset import create_ocr_val_dataset
from src.attention_ocr import AttentionOCRInfer

from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id

set_seed(1)

@@ -75,30 +76,20 @@ def LCS_length(str1, str2):

    return lcs[len1 % 2][-1]
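Only the signature and the final `return lcs[len1 % 2][-1]` of `LCS_length` are visible in this hunk. For readers of the diff, here is a self-contained sketch of the two-row rolling LCS-length computation that such a return statement implies; it is an illustration consistent with the visible signature, not the repository's exact body.

```python
def lcs_length(str1, str2):
    """Longest-common-subsequence length with a two-row rolling DP table."""
    len1, len2 = len(str1), len(str2)
    lcs = [[0] * (len2 + 1) for _ in range(2)]
    for i in range(1, len1 + 1):
        cur, prev = i % 2, (i - 1) % 2
        for j in range(1, len2 + 1):
            if str1[i - 1] == str2[j - 1]:
                lcs[cur][j] = lcs[prev][j - 1] + 1
            else:
                lcs[cur][j] = max(lcs[prev][j], lcs[cur][j - 1])
    return lcs[len1 % 2][-1]
```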

if __name__ == '__main__':
parser = argparse.ArgumentParser(description="CRNN-Seq2Seq-OCR Evaluation")
parser.add_argument("--dataset_path", type=str, default="",
help="Test Dataset path")
parser.add_argument("--checkpoint_path", type=str, default=None,
help="Checkpoint of AttentionOCR (Default:None).")
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented, default is Ascend")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")

args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

@moxing_wrapper()
def run_eval():
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
prefix = "fsns.mindrecord"
mindrecord_dir = args.dataset_path
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if config.enable_modelarts:
mindrecord_file = os.path.join(config.data_path, prefix + "0")
else:
mindrecord_file = os.path.join(config.test_data_dir, prefix + "0")
print("mindrecord_file", mindrecord_file)
dataset = create_ocr_val_dataset(mindrecord_file, config.eval_batch_size)
data_loader = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
print("Dataset creation Done!")

#Network
# Network
network = AttentionOCRInfer(config.eval_batch_size,
int(config.img_width / 4),
config.encoder_hidden_size,
@@ -106,15 +97,16 @@ if __name__ == '__main__':

config.decoder_output_size,
config.max_length,
config.dropout_p)

ckpt = load_checkpoint(args.checkpoint_path)
checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.checkpoint_path)
ckpt = load_checkpoint(checkpoint_path)
load_param_into_net(network, ckpt)
network.set_train(False)
print("Checkpoint loading Done!")

vocab, rev_vocab = initialize_vocabulary(config.vocab_path)
eos_id = config.characters_dictionary.get("eos_id")
sos_id = config.characters_dictionary.get("go_id")
vocab_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.vocab_path)
_, rev_vocab = initialize_vocabulary(vocab_path)
eos_id = config.characters_dictionary.eos_id
sos_id = config.characters_dictionary.go_id

num_correct_char = 0
num_total_char = 0
@@ -125,20 +117,20 @@ if __name__ == '__main__':

incorrect_file = 'result_incorrect.txt'

with codecs.open(correct_file, 'w', encoding='utf-8') as fp_output_correct, \
codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:

for data in data_loader:
images = Tensor(data["image"])
decoder_inputs = Tensor(data["decoder_input"])
decoder_targets = Tensor(data["decoder_target"])
# decoder_targets = Tensor(data["decoder_target"])

decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size),
dtype=np.float16), mstype.float16)
decoder_input = Tensor((np.ones((config.eval_batch_size, 1))*sos_id).astype(np.int32))
decoder_input = Tensor((np.ones((config.eval_batch_size, 1)) * sos_id).astype(np.int32))
encoder_outputs = network.encoder(images)
batch_decoded_label = []

for di in range(decoder_inputs.shape[1]):
for _ in range(decoder_inputs.shape[1]):
decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs)
topi = P.Argmax()(decoder_output)
ni = P.ExpandDims()(topi, 1)

@@ -179,3 +171,5 @@ if __name__ == '__main__':

print('\nnum of total words = %d' % (num_total_word))
print('\ncharacter precision = %f' % (float(num_correct_char) / num_total_char))
print('\nAnnotation precision = %f' % (float(num_correct_word) / num_total_word))
if __name__ == '__main__':
run_eval()
@@ -16,7 +16,7 @@

if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
exit 1
fi
@@ -39,9 +39,9 @@ fi

PATH2=$(get_real_path $2)
echo $PATH2
if [ ! -f $PATH2 ]
if [ ! -d $PATH2 ]
then
echo "error: PRETRAINED_PATH=$PATH2 is not a file"
echo "error: TRAIN_DATA_DIR=$PATH2 is not a folder"
exit 1
fi
@@ -58,9 +58,11 @@ do

mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp -r ../src ./train_parallel$i
cp ../*.yaml ./train_parallel$i
cp ../*.txt ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --rank_id=$RANK_ID --is_distribute=1 --device_num=$DEVICE_NUM --mindrecord_file=$PATH2 &> log &
python train.py --is_distributed=1 --train_data_dir=$PATH2 &> log &
cd ..
done
@@ -16,7 +16,7 @@

if [ $# != 2 ]
then
echo "Usage: sh run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]"
echo "Usage: sh run_eval_ascend.sh [TEST_DATA_DIR] [CHECKPOINT_PATH]"
exit 1
fi
@@ -34,7 +34,7 @@ echo $PATH2

if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a folder"
echo "error: TEST_DATA_DIR=$PATH1 is not a folder"
exit 1
fi
@@ -56,10 +56,11 @@ fi

mkdir ./eval
cp ../*.py ./eval
cp ../*.txt ./eval
cp ../*.yaml ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start eval for device $DEVICE_ID"
python eval.py --device_target="Ascend" --device_id=$DEVICE_ID --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
python eval.py --device_target="Ascend" --test_data_dir=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..
@@ -16,7 +16,7 @@

if [ $# -ne 1 ]
then
echo "Usage: sh run_standalone_train_ascend.sh [DATASET_PATH]"
echo "Usage: sh run_standalone_train_ascend.sh [TRAIN_DATA_DIR]"
exit 1
fi
@@ -31,9 +31,9 @@ get_real_path(){

PATH1=$(get_real_path $1)
echo $PATH1

if [ ! -f $PATH1 ]
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a file"
echo "error: TRAIN_DATA_DIR=$PATH1 is not a folder"
exit 1
fi
@@ -50,9 +50,11 @@ fi

mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp ../*.yaml ./train
cp ../*.txt ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --mindrecord_file=$PATH1 --is_distributed=0 &> log &
python train.py --train_data_dir=$PATH1 --is_distributed=0 &> log &
cd ..
@@ -1,62 +0,0 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""Config parameters for CRNN-Seq2Seq-OCR model."""

from easydict import EasyDict as ed


config = ed({

    # dataset-related
    "mindrecord_dir": "",
    "data_root": "",
    "annotation_file": "",

    "val_data_root": "",
    "val_annotation_file": "",
    "data_json": "",

    "go_shift": 1,
    "characters_dictionary": {"pad_id": 0, "go_id": 1, "eos_id": 2, "unk_id": 3},
    "labels_not_use": [u'%#<23>?%', u'%#背景#%', u'%#不识<E4B88D>?%', u'#%不识<EFBFBD>?#', u'%#模糊#%', u'%#模糊#%'],
    "vocab_path": "./general_chars.txt",

    #model-related
    "img_width": 512,
    "img_height": 128,
    "channel_size": 3,
    "conv_out_dim": 384,
    "encoder_hidden_size": 128,
    "decoder_hidden_size": 128,
    "decoder_output_size": 10000,  # vocab_size is the decoder_output_size, characters_class+1, last 9999 is the space
    "dropout_p": 0.1,
    "max_length": 64,
    "attn_num_layers": 1,
    "teacher_force_ratio": 0.5,

    #optimizer-related
    "lr": 0.0008,
    "adam_beta1": 0.5,
    "adam_beta2": 0.999,
    "loss_scale": 1024,

    #train-related
    "batch_size": 32,
    "num_epochs": 20,
    "keep_checkpoint_max": 20,

    #eval-related
    "eval_batch_size": 32
})
@@ -19,7 +19,7 @@ import numpy as np

from mindspore.mindrecord import FileWriter

from config import config
from src.model_utils.config import config
from utils import initialize_vocabulary
@@ -24,7 +24,7 @@ import mindspore.dataset.vision.py_transforms as P

import mindspore.dataset.transforms.c_transforms as ops
import mindspore.common.dtype as mstype

from src.config import config
from src.model_utils.config import config


class AugmentationOps():
@@ -0,0 +1,130 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml


class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r', encoding='utf-8') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            print(cfg_helper)
        except:
            raise ValueError("Failed to parse yaml")
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)

config = get_config()

if __name__ == '__main__':
    print(config)
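As a usage note, the `---` separator in default_config.yaml is what `parse_yaml` splits on: the first yaml document supplies the defaults and the second the per-key help text, and every scalar key also becomes a `--<key>` command-line flag. A hedged sketch of exercising this from the repository root (file path as added in this PR):

```python
# Minimal sketch: load the two yaml documents and inspect what they provide.
from src.model_utils.config import parse_yaml

cfg, cfg_helper, cfg_choices = parse_yaml("default_config.yaml")
print(sorted(cfg.keys()))                # defaults; each key also becomes a --<key> CLI flag
print(cfg_helper.get("train_data_dir"))  # help text from the second yaml document
```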

@@ -0,0 +1,27 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Device adapter for ModelArts"""

from src.model_utils.config import config

if config.enable_modelarts:
    from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
@@ -0,0 +1,36 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Local adapter"""

import os

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
@@ -0,0 +1,123 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from src.model_utils.config import config

_global_sync_count = 0

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID')
    job_id = job_id if job_id != "" else "default"
    return job_id

def sync_data(from_path, to_path):
    """
    Download data from remote obs to local directory if the first url is remote url and the second one is local path
    Upload data from local directory to remote obs in contrast.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices at most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
            # print("os.mknod({}) success".format(sync_lock))
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            if config.enable_profiling:
                profiler = Profiler()

            run_func(*args, **kwargs)

            if config.enable_profiling:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
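As a usage note, train.py in this PR applies the decorator as `@moxing_wrapper(pre_process=modelarts_pre_process)` and eval.py as `@moxing_wrapper()`. A minimal, hypothetical entry point wrapped the same way would look like this (the function bodies are illustrative only):

```python
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper


def prepare_data():
    # hypothetical pre_process hook, e.g. unzip the dataset under config.data_path
    pass


@moxing_wrapper(pre_process=prepare_data)
def main():
    # when enable_modelarts is True, the OBS download (and prepare_data) has run
    # before this point, and config.output_path is uploaded back afterwards
    print("training would start here, outputs go to", config.output_path)


if __name__ == "__main__":
    main()
```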

@@ -16,10 +16,9 @@

CRNN-Seq2Seq-OCR train.

"""

import os
import argparse
import datetime
import time
import os

import mindspore.nn as nn
import mindspore.common.dtype as mstype
@@ -31,62 +30,78 @@ from mindspore import context

from mindspore.communication.management import init
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, LossMonitor, TimeMonitor

from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config
from src.dataset import create_ocr_train_dataset
from src.logger import get_logger
from src.attention_ocr import AttentionOCR, AttentionOCRWithLossCell, TrainingWrapper

from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num

set_seed(1)
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")

if config.modelarts_dataset_unzip_name:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)

def parse_args():
"""Parse train arguments."""
parser = argparse.ArgumentParser('mindspore CRNN-Seq2Seq-OCR training')
sync_lock = "/tmp/unzip_sync.lock"

# device related
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
# Each server contains 8 devices at most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass

# distributed related
parser.add_argument('--is_distributed', type=int, default=0,
help='Distribute train or not, 1 for yes, 0 for no. Default: 0')
parser.add_argument('--rank_id', type=int, default=0, help='Local rank of distributed. Default: 0')
parser.add_argument('--device_num', type=int, default=1, help='World size of device. Default: 1')
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)

#dataset related
parser.add_argument('--mindrecord_file', type=str, default='', help='Train dataset directory.')

# logging related
parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
parser.add_argument('--pre_checkpoint_path', type=str, default='', help='Checkpoint save location.')
parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None')

parser.add_argument('--is_save_on_master', type=int, default=0,
help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0')

args, _ = parser.parse_known_args()

# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

return args
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))

config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.ckpt_path)

@moxing_wrapper(pre_process=modelarts_pre_process)
def train():
"""Train function."""
args = parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())

if args.is_distributed:
rank = args.rank_id
device_num = args.device_num
if config.is_distributed:
rank = get_rank_id()
device_num = get_device_num()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
@@ -96,25 +111,31 @@ def train():

device_num = 1

# Logger
args.logger = get_logger(args.outputs_dir, rank)
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
config.outputs_dir = os.path.join(config.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
config.logger = get_logger(config.outputs_dir, rank)
config.rank_save_ckpt_flag = 0
if config.is_save_on_master:
if rank == 0:
args.rank_save_ckpt_flag = 1
config.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
config.rank_save_ckpt_flag = 1

# DATASET
dataset = create_ocr_train_dataset(args.mindrecord_file,
prefix = "fsns.mindrecord"
if config.enable_modelarts:
mindrecord_file = os.path.join(config.data_path, prefix + "0")
else:
mindrecord_file = os.path.join(config.train_data_dir, prefix + "0")
dataset = create_ocr_train_dataset(mindrecord_file,
config.batch_size,
rank_size=device_num,
rank_id=rank)
args.steps_per_epoch = dataset.get_dataset_size()
args.logger.info('Finish loading dataset')
config.steps_per_epoch = dataset.get_dataset_size()
config.logger.info('Finish loading dataset')

if not args.ckpt_interval:
args.ckpt_interval = args.steps_per_epoch
args.logger.save_args(args)
if not config.ckpt_interval:
config.ckpt_interval = config.steps_per_epoch
config.logger.save_args(config)

network = AttentionOCR(config.batch_size,
int(config.img_width / 4),
@@ -124,8 +145,10 @@ def train():

config.max_length,
config.dropout_p)

if args.pre_checkpoint_path:
param_dict = load_checkpoint(args.pre_checkpoint_path)
if config.pre_checkpoint_path:
config.pre_checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.pre_checkpoint_path
)
param_dict = load_checkpoint(config.pre_checkpoint_path)
load_param_into_net(network, param_dict)

network = AttentionOCRWithLossCell(network, config.max_length)
@@ -136,13 +159,13 @@ def train():

network = TrainingWrapper(network, opt, sens=config.loss_scale)

args.logger.info('Finished get network')
config.logger.info('Finished get network')

callback = [TimeMonitor(data_size=1), LossMonitor()]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch,
if config.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.steps_per_epoch,
keep_checkpoint_max=config.keep_checkpoint_max)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(rank) + '/')
save_ckpt_path = os.path.join(config.outputs_dir, 'checkpoints' + '/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix="crnn_seq2seq_ocr")
@@ -151,7 +174,7 @@ def train():

model = Model(network)
model.train(config.num_epochs, dataset, callbacks=callback, dataset_sink_mode=False)

args.logger.info('==========Training Done===============')
config.logger.info('==========Training Done===============')


if __name__ == "__main__":