!19144 Add support of GPU to CRNN
Merge pull request !19144 from lear/dev2
This commit is contained in:
commit
9055537f62
|
@ -57,8 +57,8 @@ We provide `convert_ic03.py`, `convert_iiit5k.py`, `convert_svt.py` as exmples f
|
|||
|
||||
## [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor.
|
||||
- Hardware
|
||||
- Prepare hardware environment with Ascend processor or GPU.
|
||||
- Framework
|
||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||
- For more information, please check the resources below:
|
||||
|
@ -73,19 +73,32 @@ We provide `convert_ic03.py`, `convert_iiit5k.py`, `convert_svt.py` as exmples f
|
|||
|
||||
```shell
|
||||
# distribute training example in Ascend
|
||||
$ bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
$ bash scripts/run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] Ascend [RANK_TABLE_FILE]
|
||||
|
||||
# evaluation example in Ascend
|
||||
$ bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
|
||||
$ bash scripts/run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] Ascend
|
||||
|
||||
# standalone training example in Ascend
|
||||
$ bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
|
||||
$ bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] Ascend
|
||||
|
||||
# offline inference on Ascend310
|
||||
$ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [ANN_FILE_PATH] [DATASET] [DEVICE_ID]
|
||||
$ bash scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [ANN_FILE_PATH] [DATASET] [DEVICE_ID]
|
||||
|
||||
```
|
||||
|
||||
- Running on GPU
|
||||
|
||||
```shell
|
||||
# distribute training example in GPU
|
||||
$ bash scripts/run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] GPU
|
||||
|
||||
# evaluation example in GPU
|
||||
$ bash scripts/run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] GPU
|
||||
|
||||
# standalone training example in GPU
|
||||
$ bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] GPU
|
||||
```
|
||||
|
||||
DATASET_NAME is one of `ic03`, `ic13`, `svt`, `iiit5k`, `synth`.
|
||||
|
||||
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
|
||||
|
@ -123,25 +136,25 @@ crnn
|
|||
├── convert_svt.py # Convert the original SVT dataset
|
||||
├── requirements.txt # Requirements for this dataset
|
||||
├── scripts
|
||||
│ ├── run_distribute_train.sh # Launch distributed training in Ascend(8 pcs)
|
||||
│ ├── run_eval.sh # Launch evaluation
|
||||
│ └── run_standalone_train.sh # Launch standalone training(1 pcs)
|
||||
│ ├── run_distribute_train.sh # Launch distributed training in Ascend(8 pcs)
|
||||
│ ├── run_eval.sh # Launch evaluation
|
||||
│ └── run_standalone_train.sh # Launch standalone training(1 pcs)
|
||||
├── src
|
||||
│ ├── model_utils
|
||||
│ ├── config.py # Parameter config
|
||||
│ ├── moxing_adapter.py # modelarts device configuration
|
||||
│ └── device_adapter.py # Device Config
|
||||
│ └── local_adapter.py # local device config
|
||||
│ ├── crnn.py # crnn network definition
|
||||
│ ├── crnn_for_train.py # crnn network with grad, loss and gradient clip
|
||||
│ ├── dataset.py # Data preprocessing for training and evaluation
|
||||
│ ├── eval_callback.py
|
||||
│ ├── ic03_dataset.py # Data preprocessing for IC03
|
||||
│ ├── ic13_dataset.py # Data preprocessing for IC13
|
||||
│ ├── iiit5k_dataset.py # Data preprocessing for IIIT5K
|
||||
│ ├── loss.py # Ctcloss definition
|
||||
│ ├── metric.py # accuracy metric for crnn network
|
||||
│ └── svt_dataset.py # Data preprocessing for SVT
|
||||
│ ├── config.py # Parameter config
|
||||
│ ├── moxing_adapter.py # modelarts device configuration
|
||||
│ └── device_adapter.py # Device Config
|
||||
│ └── local_adapter.py # local device config
|
||||
│ ├── crnn.py # crnn network definition
|
||||
│ ├── crnn_for_train.py # crnn network with grad, loss and gradient clip
|
||||
│ ├── dataset.py # Data preprocessing for training and evaluation
|
||||
│ ├── eval_callback.py
|
||||
│ ├── ic03_dataset.py # Data preprocessing for IC03
|
||||
│ ├── ic13_dataset.py # Data preprocessing for IC13
|
||||
│ ├── iiit5k_dataset.py # Data preprocessing for IIIT5K
|
||||
│ ├── loss.py # Ctcloss definition
|
||||
│ ├── metric.py # accuracy metric for crnn network
|
||||
│ └── svt_dataset.py # Data preprocessing for SVT
|
||||
└── train.py # Training script
|
||||
├── eval.py # Evaluation Script
|
||||
├── default_config.yaml # config file
|
||||
|
@ -153,11 +166,11 @@ crnn
|
|||
#### Training Script Parameters
|
||||
|
||||
```shell
|
||||
# distributed training in Ascend
|
||||
Usage: bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
# distributed training
|
||||
Usage: bash scripts/run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] [RANK_TABLE_FILE](if Ascend)
|
||||
|
||||
# standalone training
|
||||
Usage: bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
|
||||
Usage: bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
|
||||
```
|
||||
|
||||
#### Parameters Configuration
|
||||
|
@ -195,18 +208,18 @@ max_text_length": 23, # max number of digits in each
|
|||
|
||||
### [Training](#contents)
|
||||
|
||||
- Run `run_standalone_train.sh` for non-distributed training of CRNN model, only support Ascend now.
|
||||
- Run `run_standalone_train.sh` for non-distributed training of CRNN model, support Ascend and GPU now.
|
||||
|
||||
``` bash
|
||||
bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)
|
||||
bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)
|
||||
```
|
||||
|
||||
#### [Distributed Training](#contents)
|
||||
|
||||
- Run `run_distribute_train.sh` for distributed training of CRNN model on Ascend.
|
||||
- Run `run_distribute_train.sh` for distributed training of CRNN model on Ascend or GPU
|
||||
|
||||
``` bash
|
||||
bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
bash scripts/run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] [RANK_TABLE_FILE](if Ascend)
|
||||
```
|
||||
|
||||
Check the `train_parallel0/log.txt` and you will get outputs as following:
|
||||
|
@ -276,7 +289,7 @@ Epoch time: 2743.688s, per step time: 0.097s
|
|||
- Run `run_eval.sh` for evaluation.
|
||||
|
||||
``` bash
|
||||
bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional)
|
||||
bash scripts/run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional)
|
||||
```
|
||||
|
||||
Check the `eval/log.txt` and you will get outputs as following:
|
||||
|
@ -352,37 +365,37 @@ result CRNNAccuracy is: 0.806666666666
|
|||
|
||||
#### [Training Performance](#contents)
|
||||
|
||||
| Parameters | Ascend 910 |
|
||||
| -------------------------- | --------------------------------------------------|
|
||||
| Model Version | v1.0 |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
|
||||
| uploaded Date | 12/15/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.1 |
|
||||
| Dataset | Synth |
|
||||
| Training Parameters | epoch=10, steps per epoch=14110, batch_size = 64 |
|
||||
| Optimizer | SGD |
|
||||
| Loss Function | CTCLoss |
|
||||
| outputs | probability |
|
||||
| Loss | 0.0029097411 |
|
||||
| Speed | 118ms/step(8pcs) |
|
||||
| Total time | 557 mins |
|
||||
| Parameters (M) | 83M (.ckpt file) |
|
||||
| Checkpoint for Fine tuning | 20.3M (.ckpt file) |
|
||||
| Parameters | Ascend 910 | Tesla V100 |
|
||||
| -------------------------- | --------------------------------------------------|---------------------------------------------------|
|
||||
| Model Version | v1.0 | v2.0 |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | Tesla V100; CPU 2.60GHz, 72cores; Memory 256G; OS Ubuntu 18.04.3 |
|
||||
| uploaded Date | 12/15/2020 (month/day/year) | 6/11/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.1 | 1.2.0 |
|
||||
| Dataset | Synth | Synth |
|
||||
| Training Parameters | epoch=10, steps per epoch=14110, batch_size = 64 | epoch=10, steps per epoch=14110, batch_size = 64 |
|
||||
| Optimizer | SGD | SGD |
|
||||
| Loss Function | CTCLoss | CTCLoss |
|
||||
| outputs | probability | probability |
|
||||
| Loss | 0.0029097411 | 0.0029097411 |
|
||||
| Speed | 118ms/step(8pcs) | 36ms/step(8pcs) |
|
||||
| Total time | 557 mins | 189 mins |
|
||||
| Parameters (M) | 83M (.ckpt file) | 96M |
|
||||
| Checkpoint for Fine tuning | 20.3M (.ckpt file) | |
|
||||
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/crnn) | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/crnn) |
|
||||
|
||||
#### [Evaluation Performance](#contents)
|
||||
|
||||
| Parameters | SVT | IIIT5K |
|
||||
| ------------------- | --------------------------- | --------------------------- |
|
||||
| Model Version | V1.0 | V1.0 |
|
||||
| Resource | Ascend 910; OS Euler2.8 | Ascend 910 |
|
||||
| Uploaded Date | 12/15/2020 (month/day/year) | 12/15/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.1 | 1.0.1 |
|
||||
| Dataset | SVT | IIIT5K |
|
||||
| batch_size | 1 | 1 |
|
||||
| outputs | ACC | ACC |
|
||||
| Accuracy | 80.8% | 79.7% |
|
||||
| Model for inference | 83M (.ckpt file) | 83M (.ckpt file) |
|
||||
| Parameters | SVT | IIIT5K | SVT | IIIT5K |
|
||||
| ------------------- | --------------------------- | --------------------------- | --------------------------- | --------------------------- |
|
||||
| Model Version | V1.0 | V1.0 | V2.0 | V2.0 |
|
||||
| Resource | Ascend 910; OS Euler2.8 | Ascend 910 | Tesla V100 | Tesla V100 |
|
||||
| Uploaded Date | 12/15/2020 (month/day/year) | 12/15/2020 (month/day/year) | 6/11/2021 (month/day/year) | 6/11/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.1 | 1.0.1 | 1.2.0 | 1.2.0 |
|
||||
| Dataset | SVT | IIIT5K | SVT | IIIT5K |
|
||||
| batch_size | 1 | 1 | 1 | 1 |
|
||||
| outputs | ACC | ACC | ACC | ACC |
|
||||
| Accuracy | 80.8% | 79.7% | 81.92% | 80.2% |
|
||||
| Model for inference | 83M (.ckpt file) | 83M (.ckpt file) | 96M (.ckpt file) | 96M (.ckpt file) |
|
||||
|
||||
## [Description of Random Situation](#contents)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ checkpoint_url: ""
|
|||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
device_target: "GPU"
|
||||
enable_profiling: False
|
||||
|
||||
# ======================================================================================
|
||||
|
@ -32,6 +32,7 @@ nesterov: True
|
|||
save_checkpoint: True
|
||||
save_checkpoint_steps: 1000
|
||||
keep_checkpoint_max: 30
|
||||
per_print_time: 100
|
||||
save_checkpoint_path: "./"
|
||||
class_num: 37
|
||||
input_size: 512
|
||||
|
@ -42,9 +43,10 @@ train_dataset_path: ""
|
|||
train_eval_dataset: "svt"
|
||||
train_eval_dataset_path: ""
|
||||
run_eval: False
|
||||
eval_all_saved_ckpts: False
|
||||
save_best_ckpt: True
|
||||
eval_start_epoch: 5
|
||||
eval_interval: 5
|
||||
eval_interval: 1
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
|
|
|
@ -51,7 +51,7 @@ def crnn_eval():
|
|||
loss = CTCLoss(max_sequence_length=config.num_step,
|
||||
max_label_length=max_text_length,
|
||||
batch_size=config.batch_size)
|
||||
net = crnn(config)
|
||||
net = crnn(config, full_precision=config.device_target == 'GPU')
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(config.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: sh scripts/run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
if [ $# != 4 ] && [ $# != 3 ] && [ $# != 6 ] && [ $# != 5 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] [RANK_TABLE_FILE](if Ascend)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -28,36 +28,59 @@ get_real_path() {
|
|||
}
|
||||
|
||||
DATASET_NAME=$1
|
||||
PATH1=$(get_real_path $2)
|
||||
PATH2=$(get_real_path $3)
|
||||
PLATFORM=$3
|
||||
|
||||
if [ ! -f $PATH1 ]; then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
if [ ! -d $PATH2 ]; then
|
||||
echo "error: DATASET_PATH=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ./*.py ./train_parallel$i
|
||||
cp -r scripts/ ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./*yaml ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
if [ "GPU" == $PLATFORM ]; then
|
||||
if [ -d "train" ]; then
|
||||
rm -rf ./train
|
||||
fi
|
||||
WORKDIR=./train_parallel
|
||||
rm -rf $WORKDIR
|
||||
mkdir $WORKDIR
|
||||
cp ./*.py $WORKDIR
|
||||
cp -r ./src $WORKDIR
|
||||
cp -r ./scripts $WORKDIR
|
||||
cp ./*yaml $WORKDIR
|
||||
cd $WORKDIR || exit
|
||||
echo "start distributed training with $DEVICE_NUM GPUs."
|
||||
env >env.log
|
||||
python train.py --train_dataset_path=$PATH2 --run_distribute=True --train_dataset=$DATASET_NAME > log.txt 2>&1 &
|
||||
mpirun --allow-run-as-root -n $DEVICE_NUM python train.py --train_dataset=$DATASET_NAME --train_dataset_path=$PATH2 --device_target=GPU --run_distribute=True > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
elif [ "Ascend" == $PLATFORM ]; then
|
||||
PATH1=$(get_real_path $4)
|
||||
if [ ! -f $PATH1 ]; then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ./*.py ./train_parallel$i
|
||||
cp -r scripts/ ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./*yaml ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --train_dataset_path=$PATH2 --run_distribute=True --train_dataset=$DATASET_NAME > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
else
|
||||
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
|
||||
fi
|
||||
|
|
|
@ -16,10 +16,9 @@
|
|||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
|
||||
def _bn(channel):
|
||||
|
@ -71,6 +70,27 @@ class VGG(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
class BidirectionalLSTM(nn.Cell):
|
||||
|
||||
def __init__(self, nIn, nHidden, nOut, batch_size):
|
||||
super(BidirectionalLSTM, self).__init__()
|
||||
|
||||
self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
|
||||
self.embedding = nn.Dense(in_channels=nHidden * 2, out_channels=nOut)
|
||||
self.h0 = Tensor(np.zeros([1 * 2, batch_size, nHidden]).astype(np.float32))
|
||||
self.c0 = Tensor(np.zeros([1 * 2, batch_size, nHidden]).astype(np.float32))
|
||||
|
||||
def construct(self, x):
|
||||
recurrent, _ = self.rnn(x, (self.h0, self.c0))
|
||||
T, b, h = P.Shape()(recurrent)
|
||||
t_rec = P.Reshape()(recurrent, (T * b, h,))
|
||||
|
||||
out = self.embedding(t_rec) # [T * b, nOut]
|
||||
out = P.Reshape()(out, (T, b, -1,))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CRNN(nn.Cell):
|
||||
"""
|
||||
Define a CRNN network which contains Bidirectional LSTM layers and vgg layer.
|
||||
|
@ -88,86 +108,21 @@ class CRNN(nn.Cell):
|
|||
self.hidden_size = config.hidden_size
|
||||
self.num_classes = config.class_num
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
k = (1 / self.hidden_size) ** 0.5
|
||||
self.rnn1 = P.DynamicRNN(forget_bias=0.0)
|
||||
self.rnn1_bw = P.DynamicRNN(forget_bias=0.0)
|
||||
self.rnn2 = P.DynamicRNN(forget_bias=0.0)
|
||||
self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)
|
||||
|
||||
w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
|
||||
self.w1 = Parameter(w1.astype(np.float32), name="w1")
|
||||
w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
|
||||
self.w2 = Parameter(w2.astype(np.float32), name="w2")
|
||||
w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
|
||||
self.w1_bw = Parameter(w1_bw.astype(np.float32), name="w1_bw")
|
||||
w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
|
||||
self.w2_bw = Parameter(w2_bw.astype(np.float32), name="w2_bw")
|
||||
|
||||
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
|
||||
self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2")
|
||||
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
|
||||
self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2_bw")
|
||||
|
||||
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
|
||||
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
|
||||
|
||||
self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)
|
||||
self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)
|
||||
|
||||
self.fc = nn.Dense(in_channels=self.hidden_size, out_channels=self.num_classes,
|
||||
weight_init=Tensor(self.fc_weight), bias_init=Tensor(self.fc_bias))
|
||||
self.fc.to_float(mstype.float32)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.concat = P.Concat()
|
||||
self.transpose = P.Transpose()
|
||||
self.squeeze = P.Squeeze(axis=0)
|
||||
self.vgg = VGG()
|
||||
self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)
|
||||
self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)
|
||||
self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)
|
||||
self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)
|
||||
self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)
|
||||
self.concat1 = P.Concat(axis=2)
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
self.rnn_dropout = nn.Dropout(0.9)
|
||||
self.use_dropout = config.use_dropout
|
||||
self.rnn = nn.SequentialCell([
|
||||
BidirectionalLSTM(self.input_size, self.hidden_size, self.hidden_size, self.batch_size),
|
||||
BidirectionalLSTM(self.hidden_size, self.hidden_size, self.num_classes, self.batch_size)])
|
||||
|
||||
def construct(self, x):
|
||||
x = self.vgg(x)
|
||||
|
||||
x = self.reshape(x, (self.batch_size, self.input_size, -1))
|
||||
x = self.transpose(x, (2, 0, 1))
|
||||
bw_x = self.reverse_seq1(x, self.seq_length)
|
||||
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
|
||||
y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
|
||||
y1_bw = self.reverse_seq2(y1_bw, self.seq_length)
|
||||
y1_out = self.concat1((y1, y1_bw))
|
||||
if self.use_dropout:
|
||||
y1_out = self.rnn_dropout(y1_out)
|
||||
|
||||
y2, _, _, _, _, _, _, _ = self.rnn2(y1_out, self.w2, self.b2, None, self.h2, self.c2)
|
||||
bw_y = self.reverse_seq3(y1_out, self.seq_length)
|
||||
y2_bw, _, _, _, _, _, _, _ = self.rnn2(bw_y, self.w2_bw, self.b2_bw, None, self.h2_bw, self.c2_bw)
|
||||
y2_bw = self.reverse_seq4(y2_bw, self.seq_length)
|
||||
y2_out = self.concat1((y2, y2_bw))
|
||||
if self.use_dropout:
|
||||
y2_out = self.dropout(y2_out)
|
||||
x = self.rnn(x)
|
||||
|
||||
output = ()
|
||||
for i in range(F.shape(y2_out)[0]):
|
||||
y2_after_fc = self.fc(self.squeeze(y2[i:i+1:1]))
|
||||
y2_after_fc = self.expand_dims(y2_after_fc, 0)
|
||||
output += (y2_after_fc,)
|
||||
output = self.concat(output)
|
||||
return output
|
||||
return x
|
||||
|
||||
|
||||
def crnn(config, full_precision=False):
|
||||
|
|
|
@ -28,6 +28,21 @@ from src.svt_dataset import SVTDataset
|
|||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
def check_image_is_valid(image):
|
||||
if image is None:
|
||||
return False
|
||||
|
||||
h, w, c = image.shape
|
||||
if h * w * c == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
letters = [letter for letter in config1.label_dict]
|
||||
|
||||
def text_to_labels(text):
|
||||
return list(map(lambda x: letters.index(x.lower()), text))
|
||||
|
||||
class CaptchaDataset:
|
||||
"""
|
||||
create train or evaluation dataset for crnn
|
||||
|
@ -61,24 +76,37 @@ class CaptchaDataset:
|
|||
self.max_text_length = config.max_text_length
|
||||
self.blank = config.blank
|
||||
self.class_num = config.class_num
|
||||
self.sample_num = len(self.img_names)
|
||||
self.batch_size = config.batch_size
|
||||
print("There are totally {} samples".format(self.sample_num))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
return self.sample_num
|
||||
|
||||
def __getitem__(self, item):
|
||||
img_name = self.img_list[item]
|
||||
im = Image.open(os.path.join(self.img_root_dir, img_name))
|
||||
try:
|
||||
im = Image.open(os.path.join(self.img_root_dir, img_name))
|
||||
except IOError:
|
||||
print("%s is a corrupted image" % img_name)
|
||||
return self[item + 1]
|
||||
im = im.convert("RGB")
|
||||
r, g, b = im.split()
|
||||
im = Image.merge("RGB", (b, g, r))
|
||||
image = np.array(im)
|
||||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
if not check_image_is_valid(image):
|
||||
print("%s is a corrupted image" % img_name)
|
||||
return self[item + 1]
|
||||
|
||||
text = self.img_names[img_name]
|
||||
label_unexpanded = text_to_labels(text)
|
||||
label = np.full(self.max_text_length, self.blank)
|
||||
if self.max_text_length < len(label_unexpanded):
|
||||
label_len = self.max_text_length
|
||||
else:
|
||||
label_len = len(label_unexpanded)
|
||||
for j in range(label_len):
|
||||
label[j] = label_unexpanded[j]
|
||||
return image, label
|
||||
|
||||
|
||||
|
|
|
@ -16,7 +16,8 @@
|
|||
|
||||
import os
|
||||
import stat
|
||||
from mindspore import save_checkpoint
|
||||
import glob
|
||||
from mindspore import save_checkpoint, load_checkpoint, load_param_into_net
|
||||
from mindspore import log as logger
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
|
@ -30,7 +31,7 @@ class EvalCallBack(Callback):
|
|||
interval (int): run evaluation interval, default is 1.
|
||||
eval_start_epoch (int): evaluation start epoch, default is 1.
|
||||
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
|
||||
besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
|
||||
best_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
|
||||
metrics_name (str): evaluation metrics name, default is `acc`.
|
||||
|
||||
Returns:
|
||||
|
@ -41,7 +42,7 @@ class EvalCallBack(Callback):
|
|||
"""
|
||||
|
||||
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
|
||||
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
|
||||
eval_all_saved_ckpts=False, ckpt_directory="./", best_ckpt_name="best.ckpt", metrics_name="acc"):
|
||||
super(EvalCallBack, self).__init__()
|
||||
self.eval_param_dict = eval_param_dict
|
||||
self.eval_function = eval_function
|
||||
|
@ -50,11 +51,14 @@ class EvalCallBack(Callback):
|
|||
raise ValueError("interval should >= 1.")
|
||||
self.interval = interval
|
||||
self.save_best_ckpt = save_best_ckpt
|
||||
self.eval_all_saved_ckpts = eval_all_saved_ckpts
|
||||
self.best_res = 0
|
||||
self.best_epoch = 0
|
||||
if not os.path.isdir(ckpt_directory):
|
||||
os.makedirs(ckpt_directory)
|
||||
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
|
||||
self.ckpt_directory = ckpt_directory
|
||||
self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
|
||||
self.last_ckpt_path = os.path.join(ckpt_directory, "last.ckpt")
|
||||
self.metrics_name = metrics_name
|
||||
|
||||
def remove_ckpoint_file(self, file_name):
|
||||
|
@ -72,20 +76,41 @@ class EvalCallBack(Callback):
|
|||
cb_params = run_context.original_args()
|
||||
cur_epoch = cb_params.cur_epoch_num
|
||||
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
|
||||
res = self.eval_function(self.eval_param_dict)
|
||||
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
|
||||
if res >= self.best_res:
|
||||
self.best_res = res
|
||||
self.best_epoch = cur_epoch
|
||||
print("update best result: {}".format(res), flush=True)
|
||||
if self.save_best_ckpt:
|
||||
if os.path.exists(self.bast_ckpt_path):
|
||||
self.remove_ckpoint_file(self.bast_ckpt_path)
|
||||
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
|
||||
print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
|
||||
if self.eval_all_saved_ckpts:
|
||||
ckpt_list = glob.glob(os.path.join(self.ckpt_directory, "crnn*.ckpt"))
|
||||
net = self.eval_param_dict["model"].train_network
|
||||
save_checkpoint(net, self.last_ckpt_path)
|
||||
for ckpt_path in ckpt_list:
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
res = self.eval_function(self.eval_param_dict)
|
||||
print("{}: {}".format(self.metrics_name, res), flush=True)
|
||||
if res >= self.best_res:
|
||||
self.best_epoch = cur_epoch
|
||||
self.best_res = res
|
||||
print("update best result: {}".format(res), flush=True)
|
||||
if os.path.exists(self.best_ckpt_path):
|
||||
self.remove_ckpoint_file(self.best_ckpt_path)
|
||||
if self.save_best_ckpt:
|
||||
save_checkpoint(net, self.best_ckpt_path)
|
||||
print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
|
||||
param_dict = load_checkpoint(self.last_ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
self.remove_ckpoint_file(self.last_ckpt_path)
|
||||
else:
|
||||
res = self.eval_function(self.eval_param_dict)
|
||||
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
|
||||
if res >= self.best_res:
|
||||
self.best_res = res
|
||||
self.best_epoch = cur_epoch
|
||||
print("update best result: {}".format(res), flush=True)
|
||||
if self.save_best_ckpt:
|
||||
if os.path.exists(self.best_ckpt_path):
|
||||
self.remove_ckpoint_file(self.best_ckpt_path)
|
||||
save_checkpoint(cb_params.train_network, self.best_ckpt_path)
|
||||
print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
|
||||
|
||||
def end(self, run_context):
|
||||
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
|
||||
self.best_res,
|
||||
self.best_epoch), flush=True)
|
||||
|
|
@ -59,6 +59,7 @@ class IC03Dataset:
|
|||
self.max_text_length = config.max_text_length
|
||||
self.blank = config.blank
|
||||
self.class_num = config.class_num
|
||||
self.label_dict = config.label_dict
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
|
@ -73,8 +74,8 @@ class IC03Dataset:
|
|||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
if c in self.label_dict:
|
||||
label.append(self.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
|
|
@ -58,6 +58,7 @@ class IC13Dataset:
|
|||
self.max_text_length = config.max_text_length
|
||||
self.blank = config.blank
|
||||
self.class_num = config.class_num
|
||||
self.label_dict = config.label_dict
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
def __getitem__(self, item):
|
||||
|
@ -70,8 +71,8 @@ class IC13Dataset:
|
|||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
if c in self.label_dict:
|
||||
label.append(self.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
|
|
@ -48,6 +48,7 @@ class IIIT5KDataset:
|
|||
self.max_text_length = config.max_text_length
|
||||
self.blank = config.blank
|
||||
self.class_num = config.class_num
|
||||
self.label_dict = config.label_dict
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
|
@ -62,8 +63,8 @@ class IIIT5KDataset:
|
|||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
if c in self.label_dict:
|
||||
label.append(self.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
|
|
@ -14,13 +14,13 @@
|
|||
# ============================================================================
|
||||
"""CTC Loss."""
|
||||
import numpy as np
|
||||
from mindspore.nn.loss.loss import Loss
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class CTCLoss(Loss):
|
||||
class CTCLoss(_Loss):
|
||||
"""
|
||||
CTCLoss definition
|
||||
|
||||
|
|
|
@ -22,12 +22,13 @@ class CRNNAccuracy(nn.Metric):
|
|||
Define accuracy metric for warpctc network.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, print_flag=True):
|
||||
super(CRNNAccuracy).__init__()
|
||||
self.config = config
|
||||
self._correct_num = 0
|
||||
self._total_num = 0
|
||||
self.blank = config.blank
|
||||
self.print_flag = print_flag
|
||||
|
||||
def clear(self):
|
||||
self._correct_num = 0
|
||||
|
@ -45,7 +46,8 @@ class CRNNAccuracy(nn.Metric):
|
|||
str_label = self._convert_labels(y)
|
||||
|
||||
for pred, label in zip(str_pred, str_label):
|
||||
print(pred, " :: ", label)
|
||||
if self.print_flag:
|
||||
print(pred, " :: ", label)
|
||||
edit_distance = Levenshtein.distance(pred, label)
|
||||
self._total_num += 1
|
||||
if edit_distance == 0:
|
||||
|
|
|
@ -46,6 +46,7 @@ class SVTDataset:
|
|||
self.max_text_length = config.max_text_length
|
||||
self.blank = config.blank
|
||||
self.class_num = config.class_num
|
||||
self.label_dict = config.label_dict
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
|
@ -60,8 +61,8 @@ class SVTDataset:
|
|||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in config.label_dict:
|
||||
label.append(config.label_dict.index(c))
|
||||
if c in self.label_dict:
|
||||
label.append(self.label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
|
|
@ -81,13 +81,14 @@ def train():
|
|||
batch_size=config.batch_size,
|
||||
num_shards=device_num, shard_id=rank, config=config)
|
||||
step_size = dataset.get_dataset_size()
|
||||
print("step_size:", step_size)
|
||||
# define lr
|
||||
lr_init = config.learning_rate
|
||||
lr = nn.dynamic_lr.cosine_decay_lr(0.0, lr_init, config.epoch_size * step_size, step_size, config.epoch_size)
|
||||
loss = CTCLoss(max_sequence_length=config.num_step,
|
||||
max_label_length=max_text_length,
|
||||
batch_size=config.batch_size)
|
||||
net = crnn(config)
|
||||
net = crnn(config, full_precision=config.device_target == 'GPU')
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
|
@ -95,9 +96,10 @@ def train():
|
|||
# define model
|
||||
model = Model(net_with_grads)
|
||||
# define callbacks
|
||||
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
|
||||
callbacks = [LossMonitor(per_print_times=config.per_print_time),
|
||||
TimeMonitor(data_size=step_size)]
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||
if config.run_eval:
|
||||
if config.run_eval and rank == 0:
|
||||
if config.train_eval_dataset_path is None or (not os.path.isdir(config.train_eval_dataset_path)):
|
||||
raise ValueError("{} is not a existing path.".format(config.train_eval_dataset_path))
|
||||
eval_dataset = create_dataset(name=config.train_eval_dataset,
|
||||
|
@ -105,19 +107,19 @@ def train():
|
|||
batch_size=config.batch_size,
|
||||
is_training=False,
|
||||
config=config)
|
||||
eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
|
||||
eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config, print_flag=False)})
|
||||
eval_param_dict = {"model": eval_model, "dataset": eval_dataset, "metrics_name": "CRNNAccuracy"}
|
||||
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
|
||||
eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
|
||||
ckpt_directory=save_ckpt_path, besk_ckpt_name="best_acc.ckpt",
|
||||
metrics_name="acc")
|
||||
ckpt_directory=save_ckpt_path, best_ckpt_name="best_acc.ckpt",
|
||||
eval_all_saved_ckpts=config.eval_all_saved_ckpts, metrics_name="acc")
|
||||
callbacks += [eval_cb]
|
||||
if config.save_checkpoint and rank == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
|
||||
callbacks.append(ckpt_cb)
|
||||
model.train(config.epoch_size, dataset, callbacks=callbacks)
|
||||
model.train(config.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=config.device_target == 'Ascend')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue