!19144 Add support of GPU to CRNN

Merge pull request !19144 from lear/dev2
i-robot 2021-07-02 02:05:35 +00:00 committed by Gitee
commit 9055537f62
14 changed files with 256 additions and 202 deletions

model_zoo/official/cv/crnn/README.md

@ -57,8 +57,8 @@ We provide `convert_ic03.py`, `convert_iiit5k.py`, `convert_svt.py` as examples f
## [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend processor.
- Hardware
- Prepare hardware environment with Ascend processor or GPU.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below
@ -73,19 +73,32 @@ We provide `convert_ic03.py`, `convert_iiit5k.py`, `convert_svt.py` as examples f
```shell
# distribute training example in Ascend
$ bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
$ bash scripts/run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] Ascend [RANK_TABLE_FILE]
# evaluation example in Ascend
$ bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
$ bash scripts/run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] Ascend
# standalone training example in Ascend
$ bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
$ bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] Ascend
# offline inference on Ascend310
$ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [ANN_FILE_PATH] [DATASET] [DEVICE_ID]
$ bash scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [ANN_FILE_PATH] [DATASET] [DEVICE_ID]
```
- Running on GPU
```shell
# distribute training example in GPU
$ bash scripts/run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] GPU
# evaluation example in GPU
$ bash scripts/run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] GPU
# standalone training example in GPU
$ bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] GPU
```
DATASET_NAME is one of `ic03`, `ic13`, `svt`, `iiit5k`, `synth`.
For distributed training on Ascend, an HCCL configuration file in JSON format (the rank table) needs to be created in advance.
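For reference, a minimal sketch of what such a rank table can look like for a single machine with 8 devices — all host and device IPs below are placeholders, and in practice the file is generated with the hccl_tools helper shipped with MindSpore's model zoo rather than written by hand:

```python
# Sketch only: a minimal single-server, 8-device HCCL rank table.
# All IP addresses are placeholders; generate the real file with hccl_tools.py.
import json

rank_table = {
    "version": "1.0",
    "server_count": "1",
    "server_list": [{
        "server_id": "10.0.0.1",  # placeholder host address
        "device": [{"device_id": str(i),
                    "device_ip": "192.168.100.{}".format(101 + i),  # placeholder NIC address
                    "rank_id": str(i)} for i in range(8)],
        "host_nic_ip": "reserve",
    }],
    "status": "completed",
}

with open("hccl_8p.json", "w") as f:
    json.dump(rank_table, f, indent=4)
```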
@ -123,25 +136,25 @@ crnn
├── convert_svt.py # Convert the original SVT dataset
├── requirements.txt # Requirements for this dataset
├── scripts
│   ├── run_distribute_train.sh # Launch distributed training in Ascend(8 pcs)
│   ├── run_eval.sh # Launch evaluation
│   └── run_standalone_train.sh # Launch standalone training(1 pcs)
│   ├── run_distribute_train.sh # Launch distributed training in Ascend(8 pcs)
│   ├── run_eval.sh # Launch evaluation
│   └── run_standalone_train.sh # Launch standalone training(1 pcs)
├── src
│   ├── model_utils
│   │   ├── config.py # Parameter config
│   │   ├── moxing_adapter.py # modelarts device configuration
│   │   ├── device_adapter.py # Device Config
│   │   └── local_adapter.py # local device config
│   ├── crnn.py # crnn network definition
│   ├── crnn_for_train.py # crnn network with grad, loss and gradient clip
│   ├── dataset.py # Data preprocessing for training and evaluation
│   ├── eval_callback.py
│   ├── ic03_dataset.py # Data preprocessing for IC03
│   ├── ic13_dataset.py # Data preprocessing for IC13
│   ├── iiit5k_dataset.py # Data preprocessing for IIIT5K
│   ├── loss.py # Ctcloss definition
│   ├── metric.py # accuracy metric for crnn network
│   └── svt_dataset.py # Data preprocessing for SVT
│   │   ├── config.py # Parameter config
│   │   ├── moxing_adapter.py # modelarts device configuration
│   │   ├── device_adapter.py # Device Config
│   │   └── local_adapter.py # local device config
│   ├── crnn.py # crnn network definition
│   ├── crnn_for_train.py # crnn network with grad, loss and gradient clip
│   ├── dataset.py # Data preprocessing for training and evaluation
│   ├── eval_callback.py
│   ├── ic03_dataset.py # Data preprocessing for IC03
│   ├── ic13_dataset.py # Data preprocessing for IC13
│   ├── iiit5k_dataset.py # Data preprocessing for IIIT5K
│   ├── loss.py # Ctcloss definition
│   ├── metric.py # accuracy metric for crnn network
│   └── svt_dataset.py # Data preprocessing for SVT
├── train.py # Training script
├── eval.py # Evaluation Script
├── default_config.yaml # config file
@ -153,11 +166,11 @@ crnn
#### Training Script Parameters
```shell
# distributed training in Ascend
Usage: bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
# distributed training
Usage: bash scripts/run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] [RANK_TABLE_FILE](if Ascend)
# standalone training
Usage: bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
Usage: bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
```
#### Parameters Configuration
@ -195,18 +208,18 @@ "max_text_length": 23, # max number of digits in each
### [Training](#contents)
- Run `run_standalone_train.sh` for non-distributed training of the CRNN model; only Ascend is supported for now.
- Run `run_standalone_train.sh` for non-distributed training of the CRNN model; Ascend and GPU are supported now.
``` bash
bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)
bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)
```
#### [Distributed Training](#contents)
- Run `run_distribute_train.sh` for distributed training of the CRNN model on Ascend.
- Run `run_distribute_train.sh` for distributed training of the CRNN model on Ascend or GPU.
``` bash
bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
bash scripts/run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] [RANK_TABLE_FILE](if Ascend)
```
Check `train_parallel0/log.txt` and you will see output like the following:
@ -276,7 +289,7 @@ Epoch time: 2743.688s, per step time: 0.097s
- Run `run_eval.sh` for evaluation.
``` bash
bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional)
bash scripts/run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional)
```
Check `eval/log.txt` and you will see output like the following:
@ -352,37 +365,37 @@ result CRNNAccuracy is: 0.806666666666
#### [Training Performance](#contents)
| Parameters | Ascend 910 |
| -------------------------- | --------------------------------------------------|
| Model Version | v1.0 |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
| uploaded Date | 12/15/2020 (month/day/year) |
| MindSpore Version | 1.0.1 |
| Dataset | Synth |
| Training Parameters | epoch=10, steps per epoch=14110, batch_size = 64 |
| Optimizer | SGD |
| Loss Function | CTCLoss |
| outputs | probability |
| Loss | 0.0029097411 |
| Speed | 118ms/step(8pcs) |
| Total time | 557 mins |
| Parameters (M) | 83M (.ckpt file) |
| Checkpoint for Fine tuning | 20.3M (.ckpt file) |
| Parameters | Ascend 910 | Tesla V100 |
| -------------------------- | --------------------------------------------------|---------------------------------------------------|
| Model Version | v1.0 | v2.0 |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | Tesla V100; CPU 2.60GHz, 72cores; Memory 256G; OS Ubuntu 18.04.3 |
| uploaded Date | 12/15/2020 (month/day/year) | 6/11/2021 (month/day/year) |
| MindSpore Version | 1.0.1 | 1.2.0 |
| Dataset | Synth | Synth |
| Training Parameters | epoch=10, steps per epoch=14110, batch_size = 64 | epoch=10, steps per epoch=14110, batch_size = 64 |
| Optimizer | SGD | SGD |
| Loss Function | CTCLoss | CTCLoss |
| outputs | probability | probability |
| Loss | 0.0029097411 | 0.0029097411 |
| Speed | 118ms/step(8pcs) | 36ms/step(8pcs) |
| Total time | 557 mins | 189 mins |
| Parameters (M) | 83M (.ckpt file) | 96M |
| Checkpoint for Fine tuning | 20.3M (.ckpt file) | |
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/crnn) | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/crnn) |
#### [Evaluation Performance](#contents)
| Parameters | SVT | IIIT5K |
| ------------------- | --------------------------- | --------------------------- |
| Model Version | V1.0 | V1.0 |
| Resource | Ascend 910; OS Euler2.8 | Ascend 910 |
| Uploaded Date | 12/15/2020 (month/day/year) | 12/15/2020 (month/day/year) |
| MindSpore Version | 1.0.1 | 1.0.1 |
| Dataset | SVT | IIIT5K |
| batch_size | 1 | 1 |
| outputs | ACC | ACC |
| Accuracy | 80.8% | 79.7% |
| Model for inference | 83M (.ckpt file) | 83M (.ckpt file) |
| Parameters | SVT (Ascend) | IIIT5K (Ascend) | SVT (GPU) | IIIT5K (GPU) |
| ------------------- | --------------------------- | --------------------------- | --------------------------- | --------------------------- |
| Model Version | V1.0 | V1.0 | V2.0 | V2.0 |
| Resource | Ascend 910; OS Euler2.8 | Ascend 910 | Tesla V100 | Tesla V100 |
| Uploaded Date | 12/15/2020 (month/day/year) | 12/15/2020 (month/day/year) | 6/11/2021 (month/day/year) | 6/11/2021 (month/day/year) |
| MindSpore Version | 1.0.1 | 1.0.1 | 1.2.0 | 1.2.0 |
| Dataset | SVT | IIIT5K | SVT | IIIT5K |
| batch_size | 1 | 1 | 1 | 1 |
| outputs | ACC | ACC | ACC | ACC |
| Accuracy | 80.8% | 79.7% | 81.92% | 80.2% |
| Model for inference | 83M (.ckpt file) | 83M (.ckpt file) | 96M (.ckpt file) | 96M (.ckpt file) |
## [Description of Random Situation](#contents)

model_zoo/official/cv/crnn/default_config.yaml

@ -8,7 +8,7 @@ checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
device_target: "GPU"
enable_profiling: False
# ======================================================================================
@ -32,6 +32,7 @@ nesterov: True
save_checkpoint: True
save_checkpoint_steps: 1000
keep_checkpoint_max: 30
per_print_time: 100
save_checkpoint_path: "./"
class_num: 37
input_size: 512
@ -42,9 +43,10 @@ train_dataset_path: ""
train_eval_dataset: "svt"
train_eval_dataset_path: ""
run_eval: False
eval_all_saved_ckpts: False
save_best_ckpt: True
eval_start_epoch: 5
eval_interval: 5
eval_interval: 1
# ======================================================================================
# Eval options

model_zoo/official/cv/crnn/eval.py

@ -51,7 +51,7 @@ def crnn_eval():
    loss = CTCLoss(max_sequence_length=config.num_step,
                   max_label_length=max_text_length,
                   batch_size=config.batch_size)
    net = crnn(config)
    net = crnn(config, full_precision=config.device_target == 'GPU')
    # load checkpoint
    param_dict = load_checkpoint(config.checkpoint_path)
    load_param_into_net(net, param_dict)

model_zoo/official/cv/crnn/scripts/run_distribute_train.sh

@ -14,8 +14,8 @@
# limitations under the License.
# ============================================================================
if [ $# != 3 ]; then
    echo "Usage: sh scripts/run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
if [ $# != 4 ] && [ $# != 3 ] && [ $# != 6 ] && [ $# != 5 ]; then
    echo "Usage: sh run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] [RANK_TABLE_FILE](if Ascend)"
    exit 1
fi
@ -28,36 +28,59 @@ get_real_path() {
}
DATASET_NAME=$1
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)
PLATFORM=$3
if [ ! -f $PATH1 ]; then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
PATH2=$(get_real_path $2)
if [ ! -d $PATH2 ]; then
    echo "error: DATASET_PATH=$PATH2 is not a directory"
    exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ./*.py ./train_parallel$i
    cp -r scripts/ ./train_parallel$i
    cp -r ./src ./train_parallel$i
    cp ./*yaml ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
if [ "GPU" == $PLATFORM ]; then
    if [ -d "train" ]; then
        rm -rf ./train
    fi
    WORKDIR=./train_parallel
    rm -rf $WORKDIR
    mkdir $WORKDIR
    cp ./*.py $WORKDIR
    cp -r ./src $WORKDIR
    cp -r ./scripts $WORKDIR
    cp ./*yaml $WORKDIR
    cd $WORKDIR || exit
    echo "start distributed training with $DEVICE_NUM GPUs."
    env >env.log
    python train.py --train_dataset_path=$PATH2 --run_distribute=True --train_dataset=$DATASET_NAME > log.txt 2>&1 &
    mpirun --allow-run-as-root -n $DEVICE_NUM python train.py --train_dataset=$DATASET_NAME --train_dataset_path=$PATH2 --device_target=GPU --run_distribute=True > log.txt 2>&1 &
    cd ..
done
elif [ "Ascend" == $PLATFORM ]; then
    PATH1=$(get_real_path $4)
    if [ ! -f $PATH1 ]; then
        echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
        exit 1
    fi
    ulimit -u unlimited
    export RANK_TABLE_FILE=$PATH1
    for ((i = 0; i < ${DEVICE_NUM}; i++)); do
        export DEVICE_ID=$i
        export RANK_ID=$i
        rm -rf ./train_parallel$i
        mkdir ./train_parallel$i
        cp ./*.py ./train_parallel$i
        cp -r scripts/ ./train_parallel$i
        cp -r ./src ./train_parallel$i
        cp ./*yaml ./train_parallel$i
        cd ./train_parallel$i || exit
        echo "start training for rank $RANK_ID, device $DEVICE_ID"
        env >env.log
        python train.py --train_dataset_path=$PATH2 --run_distribute=True --train_dataset=$DATASET_NAME > log.txt 2>&1 &
        cd ..
    done
else
    echo "error: PLATFORM=$PLATFORM is not supported, only Ascend and GPU are supported."
fi
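Note the split launch strategy: the Ascend branch forks one `train.py` per device and addresses cards through `DEVICE_ID`/`RANK_TABLE_FILE`, while the GPU branch starts all processes at once with `mpirun` and leaves rank assignment to OpenMPI. A minimal sketch of the device setup this implies inside `train.py` (the actual logic lives in a part of that file this commit does not touch):

```python
# Sketch, assuming the standard MindSpore GPU distributed setup:
# each mpirun-spawned process initializes NCCL and discovers its own rank.
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
init()                          # rank/world size come from the mpirun environment
rank = get_rank()               # 0 .. DEVICE_NUM-1, unique per process
device_num = get_group_size()   # matches mpirun -n
```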

model_zoo/official/cv/crnn/src/crnn.py

@ -16,10 +16,9 @@
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.initializer import TruncatedNormal
def _bn(channel):
@ -71,6 +70,27 @@ class VGG(nn.Cell):
        return x
class BidirectionalLSTM(nn.Cell):
    def __init__(self, nIn, nHidden, nOut, batch_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Dense(in_channels=nHidden * 2, out_channels=nOut)
        self.h0 = Tensor(np.zeros([1 * 2, batch_size, nHidden]).astype(np.float32))
        self.c0 = Tensor(np.zeros([1 * 2, batch_size, nHidden]).astype(np.float32))

    def construct(self, x):
        recurrent, _ = self.rnn(x, (self.h0, self.c0))
        T, b, h = P.Shape()(recurrent)
        t_rec = P.Reshape()(recurrent, (T * b, h,))
        out = self.embedding(t_rec)  # [T * b, nOut]
        out = P.Reshape()(out, (T, b, -1,))
        return out
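As a quick shape check for the block above (a sketch; it assumes the class is imported from this module): the LSTM is bidirectional, so its raw output feature size is `2 * nHidden`, which the shared `Dense` layer then projects down to `nOut` for every time step:

```python
# Sketch: shape flow through BidirectionalLSTM (sizes are arbitrary examples).
import numpy as np
from mindspore import Tensor
from src.crnn import BidirectionalLSTM  # assumed import path for this module

rnn = BidirectionalLSTM(nIn=512, nHidden=256, nOut=256, batch_size=4)
x = Tensor(np.zeros((24, 4, 512), np.float32))  # [T, b, nIn] = [num_step, batch, features]
y = rnn(x)
print(y.shape)  # (24, 4, 256), i.e. [T, b, nOut]
```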
class CRNN(nn.Cell):
    """
    Define a CRNN network which contains Bidirectional LSTM layers and vgg layer.
@ -88,86 +108,21 @@ class CRNN(nn.Cell):
        self.hidden_size = config.hidden_size
        self.num_classes = config.class_num
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        k = (1 / self.hidden_size) ** 0.5
        self.rnn1 = P.DynamicRNN(forget_bias=0.0)
        self.rnn1_bw = P.DynamicRNN(forget_bias=0.0)
        self.rnn2 = P.DynamicRNN(forget_bias=0.0)
        self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)
        w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
        self.w1 = Parameter(w1.astype(np.float32), name="w1")
        w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
        self.w2 = Parameter(w2.astype(np.float32), name="w2")
        w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
        self.w1_bw = Parameter(w1_bw.astype(np.float32), name="w1_bw")
        w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
        self.w2_bw = Parameter(w2_bw.astype(np.float32), name="w2_bw")
        self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
        self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2")
        self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
        self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2_bw")
        self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)
        self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)
        self.fc = nn.Dense(in_channels=self.hidden_size, out_channels=self.num_classes,
                           weight_init=Tensor(self.fc_weight), bias_init=Tensor(self.fc_bias))
        self.fc.to_float(mstype.float32)
        self.expand_dims = P.ExpandDims()
        self.concat = P.Concat()
        self.transpose = P.Transpose()
        self.squeeze = P.Squeeze(axis=0)
        self.vgg = VGG()
        self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)
        self.concat1 = P.Concat(axis=2)
        self.dropout = nn.Dropout(0.5)
        self.rnn_dropout = nn.Dropout(0.9)
        self.use_dropout = config.use_dropout
        self.rnn = nn.SequentialCell([
            BidirectionalLSTM(self.input_size, self.hidden_size, self.hidden_size, self.batch_size),
            BidirectionalLSTM(self.hidden_size, self.hidden_size, self.num_classes, self.batch_size)])
    def construct(self, x):
        x = self.vgg(x)
        x = self.reshape(x, (self.batch_size, self.input_size, -1))
        x = self.transpose(x, (2, 0, 1))
        bw_x = self.reverse_seq1(x, self.seq_length)
        y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
        y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
        y1_bw = self.reverse_seq2(y1_bw, self.seq_length)
        y1_out = self.concat1((y1, y1_bw))
        if self.use_dropout:
            y1_out = self.rnn_dropout(y1_out)
        y2, _, _, _, _, _, _, _ = self.rnn2(y1_out, self.w2, self.b2, None, self.h2, self.c2)
        bw_y = self.reverse_seq3(y1_out, self.seq_length)
        y2_bw, _, _, _, _, _, _, _ = self.rnn2(bw_y, self.w2_bw, self.b2_bw, None, self.h2_bw, self.c2_bw)
        y2_bw = self.reverse_seq4(y2_bw, self.seq_length)
        y2_out = self.concat1((y2, y2_bw))
        if self.use_dropout:
            y2_out = self.dropout(y2_out)
        x = self.rnn(x)
        output = ()
        for i in range(F.shape(y2_out)[0]):
            y2_after_fc = self.fc(self.squeeze(y2[i:i+1:1]))
            y2_after_fc = self.expand_dims(y2_after_fc, 0)
            output += (y2_after_fc,)
        output = self.concat(output)
        return output
        return x
def crnn(config, full_precision=False):
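The hunk cuts off at the factory's new signature, so its body is not shown here. A hypothetical sketch of what a `full_precision` switch of this shape could do — illustration only, not the committed implementation; as seen in eval.py and train.py above, callers pass `full_precision=True` on GPU (where float16 LSTM kernels are not the default) and `False` on Ascend:

```python
# Hypothetical body (the diff truncates before it); illustration only.
def crnn(config, full_precision=False):
    net = CRNN(config)
    if not full_precision:
        net.to_float(mstype.float16)  # mixed precision, as used on Ascend
    return net
```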

model_zoo/official/cv/crnn/src/dataset.py

@ -28,6 +28,21 @@ from src.svt_dataset import SVTDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True
def check_image_is_valid(image):
    if image is None:
        return False
    h, w, c = image.shape
    if h * w * c == 0:
        return False
    return True

letters = [letter for letter in config1.label_dict]

def text_to_labels(text):
    return list(map(lambda x: letters.index(x.lower()), text))
class CaptchaDataset:
    """
    create train or evaluation dataset for crnn
@ -61,24 +76,37 @@ class CaptchaDataset:
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num
        self.sample_num = len(self.img_names)
        self.batch_size = config.batch_size
        print("There are totally {} samples".format(self.sample_num))

    def __len__(self):
        return len(self.img_names)
        return self.sample_num
    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        try:
            im = Image.open(os.path.join(self.img_root_dir, img_name))
        except IOError:
            print("%s is a corrupted image" % img_name)
            return self[item + 1]
        im = im.convert("RGB")
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in config.label_dict:
                label.append(config.label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        if not check_image_is_valid(image):
            print("%s is a corrupted image" % img_name)
            return self[item + 1]
        text = self.img_names[img_name]
        label_unexpanded = text_to_labels(text)
        label = np.full(self.max_text_length, self.blank)
        if self.max_text_length < len(label_unexpanded):
            label_len = self.max_text_length
        else:
            label_len = len(label_unexpanded)
        for j in range(label_len):
            label[j] = label_unexpanded[j]
        return image, label
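A standalone worked example of the new encoding (assuming the usual CRNN alphabet of `a-z` followed by `0-9`, with `blank = 36` and `max_text_length = 23` as in the default config): labels are lowercased, mapped to indices, and right-padded with the blank index.

```python
# Sketch of the label padding above, outside the dataset class.
import numpy as np

letters = list("abcdefghijklmnopqrstuvwxyz0123456789")  # assumed label_dict
blank, max_text_length = 36, 23

def text_to_labels(text):
    return [letters.index(ch.lower()) for ch in text]

label_unexpanded = text_to_labels("Foo1")      # [5, 14, 14, 27]
label = np.full(max_text_length, blank)        # everything starts as blank
label_len = min(len(label_unexpanded), max_text_length)
label[:label_len] = label_unexpanded[:label_len]
print(label[:6])                               # [ 5 14 14 27 36 36]
```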

model_zoo/official/cv/crnn/src/eval_callback.py

@ -16,7 +16,8 @@
import os
import stat
from mindspore import save_checkpoint
import glob
from mindspore import save_checkpoint, load_checkpoint, load_param_into_net
from mindspore import log as logger
from mindspore.train.callback import Callback
@ -30,7 +31,7 @@ class EvalCallBack(Callback):
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
        best_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name, default is `acc`.

    Returns:
@ -41,7 +42,7 @@ class EvalCallBack(Callback):
"""
    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
                 eval_all_saved_ckpts=False, ckpt_directory="./", best_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
@ -50,11 +51,14 @@ class EvalCallBack(Callback):
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.eval_all_saved_ckpts = eval_all_saved_ckpts
        self.best_res = 0
        self.best_epoch = 0
        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.ckpt_directory = ckpt_directory
        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
        self.last_ckpt_path = os.path.join(ckpt_directory, "last.ckpt")
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
@ -72,20 +76,41 @@ class EvalCallBack(Callback):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            res = self.eval_function(self.eval_param_dict)
            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                print("update best result: {}".format(res), flush=True)
                if self.save_best_ckpt:
                    if os.path.exists(self.bast_ckpt_path):
                        self.remove_ckpoint_file(self.bast_ckpt_path)
                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
                    print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
            if self.eval_all_saved_ckpts:
                ckpt_list = glob.glob(os.path.join(self.ckpt_directory, "crnn*.ckpt"))
                net = self.eval_param_dict["model"].train_network
                save_checkpoint(net, self.last_ckpt_path)
                for ckpt_path in ckpt_list:
                    param_dict = load_checkpoint(ckpt_path)
                    load_param_into_net(net, param_dict)
                    res = self.eval_function(self.eval_param_dict)
                    print("{}: {}".format(self.metrics_name, res), flush=True)
                    if res >= self.best_res:
                        self.best_epoch = cur_epoch
                        self.best_res = res
                        print("update best result: {}".format(res), flush=True)
                        if os.path.exists(self.best_ckpt_path):
                            self.remove_ckpoint_file(self.best_ckpt_path)
                        if self.save_best_ckpt:
                            save_checkpoint(net, self.best_ckpt_path)
                            print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
                param_dict = load_checkpoint(self.last_ckpt_path)
                load_param_into_net(net, param_dict)
                self.remove_ckpoint_file(self.last_ckpt_path)
            else:
                res = self.eval_function(self.eval_param_dict)
                print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
                if res >= self.best_res:
                    self.best_res = res
                    self.best_epoch = cur_epoch
                    print("update best result: {}".format(res), flush=True)
                    if self.save_best_ckpt:
                        if os.path.exists(self.best_ckpt_path):
                            self.remove_ckpoint_file(self.best_ckpt_path)
                        save_checkpoint(cb_params.train_network, self.best_ckpt_path)
                        print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)

    def end(self, run_context):
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
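`eval_function` and `eval_param_dict` follow a small contract: the dict carries `model`, `dataset`, and `metrics_name` (train.py, later in this diff, builds exactly that), and the function returns a single scalar. A sketch of an `apply_eval` satisfying the contract — the real helper lives in train.py and is not part of this hunk:

```python
# Sketch of the eval_function contract EvalCallBack expects.
def apply_eval(eval_param_dict):
    eval_model = eval_param_dict["model"]           # a mindspore.train.Model
    eval_ds = eval_param_dict["dataset"]
    metrics_name = eval_param_dict["metrics_name"]  # e.g. "CRNNAccuracy"
    res = eval_model.eval(eval_ds, dataset_sink_mode=False)
    return res[metrics_name]
```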

model_zoo/official/cv/crnn/src/ic03_dataset.py

@ -59,6 +59,7 @@ class IC03Dataset:
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num
        self.label_dict = config.label_dict

    def __len__(self):
        return len(self.img_names)
@ -73,8 +74,8 @@ class IC03Dataset:
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in config.label_dict:
                label.append(config.label_dict.index(c))
            if c in self.label_dict:
                label.append(self.label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label

model_zoo/official/cv/crnn/src/ic13_dataset.py

@ -58,6 +58,7 @@ class IC13Dataset:
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num
        self.label_dict = config.label_dict

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
@ -70,8 +71,8 @@ class IC13Dataset:
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in config.label_dict:
                label.append(config.label_dict.index(c))
            if c in self.label_dict:
                label.append(self.label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label

model_zoo/official/cv/crnn/src/iiit5k_dataset.py

@ -48,6 +48,7 @@ class IIIT5KDataset:
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num
        self.label_dict = config.label_dict

    def __len__(self):
        return len(self.img_names)
@ -62,8 +63,8 @@ class IIIT5KDataset:
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in config.label_dict:
                label.append(config.label_dict.index(c))
            if c in self.label_dict:
                label.append(self.label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label

model_zoo/official/cv/crnn/src/loss.py

@ -14,13 +14,13 @@
# ============================================================================
"""CTC Loss."""
import numpy as np
from mindspore.nn.loss.loss import Loss
from mindspore.nn.loss.loss import _Loss
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
class CTCLoss(Loss):
class CTCLoss(_Loss):
    """
    CTCLoss definition

model_zoo/official/cv/crnn/src/metric.py

@ -22,12 +22,13 @@ class CRNNAccuracy(nn.Metric):
    Define accuracy metric for warpctc network.
    """
    def __init__(self, config):
    def __init__(self, config, print_flag=True):
        super(CRNNAccuracy).__init__()
        self.config = config
        self._correct_num = 0
        self._total_num = 0
        self.blank = config.blank
        self.print_flag = print_flag
    def clear(self):
        self._correct_num = 0
@ -45,7 +46,8 @@ class CRNNAccuracy(nn.Metric):
        str_label = self._convert_labels(y)
        for pred, label in zip(str_pred, str_label):
            print(pred, " :: ", label)
            if self.print_flag:
                print(pred, " :: ", label)
            edit_distance = Levenshtein.distance(pred, label)
            self._total_num += 1
            if edit_distance == 0:

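The accuracy rule here is strict: a prediction counts as correct only when its Levenshtein distance to the label is zero, i.e. an exact string match; the new `print_flag` merely silences the per-sample logging during in-training evaluation. A two-line illustration (assuming the `python-Levenshtein` package this module imports as `Levenshtein`):

```python
# Only exact matches increment _correct_num.
import Levenshtein

print(Levenshtein.distance("mindspore", "mindspore"))  # 0 -> counted correct
print(Levenshtein.distance("mindspore", "mindspre"))   # 1 -> counted wrong
```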
model_zoo/official/cv/crnn/src/svt_dataset.py

@ -46,6 +46,7 @@ class SVTDataset:
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num
        self.label_dict = config.label_dict

    def __len__(self):
        return len(self.img_names)
@ -60,8 +61,8 @@ class SVTDataset:
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in config.label_dict:
                label.append(config.label_dict.index(c))
            if c in self.label_dict:
                label.append(self.label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label

model_zoo/official/cv/crnn/train.py

@ -81,13 +81,14 @@ def train():
                                 batch_size=config.batch_size,
                                 num_shards=device_num, shard_id=rank, config=config)
    step_size = dataset.get_dataset_size()
    print("step_size:", step_size)
    # define lr
    lr_init = config.learning_rate
    lr = nn.dynamic_lr.cosine_decay_lr(0.0, lr_init, config.epoch_size * step_size, step_size, config.epoch_size)
    loss = CTCLoss(max_sequence_length=config.num_step,
                   max_label_length=max_text_length,
                   batch_size=config.batch_size)
    net = crnn(config)
    net = crnn(config, full_precision=config.device_target == 'GPU')
    opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)
    net_with_loss = WithLossCell(net, loss)
@ -95,9 +96,10 @@ def train():
    # define model
    model = Model(net_with_grads)
    # define callbacks
    callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
    callbacks = [LossMonitor(per_print_times=config.per_print_time),
                 TimeMonitor(data_size=step_size)]
    save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
    if config.run_eval:
    if config.run_eval and rank == 0:
        if config.train_eval_dataset_path is None or (not os.path.isdir(config.train_eval_dataset_path)):
            raise ValueError("{} is not a existing path.".format(config.train_eval_dataset_path))
        eval_dataset = create_dataset(name=config.train_eval_dataset,
@ -105,19 +107,19 @@ def train():
                                      batch_size=config.batch_size,
                                      is_training=False,
                                      config=config)
        eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
        eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config, print_flag=False)})
        eval_param_dict = {"model": eval_model, "dataset": eval_dataset, "metrics_name": "CRNNAccuracy"}
        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
                               eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
                               ckpt_directory=save_ckpt_path, besk_ckpt_name="best_acc.ckpt",
                               metrics_name="acc")
                               ckpt_directory=save_ckpt_path, best_ckpt_name="best_acc.ckpt",
                               eval_all_saved_ckpts=config.eval_all_saved_ckpts, metrics_name="acc")
        callbacks += [eval_cb]
    if config.save_checkpoint and rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
        callbacks.append(ckpt_cb)
    model.train(config.epoch_size, dataset, callbacks=callbacks)
    model.train(config.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=config.device_target == 'Ascend')
if __name__ == '__main__':