adding cpu mode for crnn
This commit is contained in:
parent
b58f8efa54
commit
032d2626cb
|
@ -64,7 +64,7 @@ We provide `convert_ic03.py`, `convert_iiit5k.py`, `convert_svt.py` as exmples f
|
|||
## [Environment Requirements](#contents)
|
||||
|
||||
- Hardware
|
||||
- Prepare hardware environment with Ascend processor or GPU.
|
||||
- Prepare hardware environment with Ascend, GPU or CPU processor.
|
||||
- Framework
|
||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||
- For more information, please check the resources below:
|
||||
|
@ -105,6 +105,16 @@ We provide `convert_ic03.py`, `convert_iiit5k.py`, `convert_svt.py` as exmples f
|
|||
$ bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] GPU
|
||||
```
|
||||
|
||||
- Running on CPU
|
||||
|
||||
```shell
|
||||
# standalone training example in CPU
|
||||
$ bash scripts/run_standalone_train_cpu.sh [DATASET_NAME] [DATASET_PATH]
|
||||
|
||||
# evaluation example in CPU
|
||||
$ bash scripts/run_eval_cpu.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
DATASET_NAME is one of `ic03`, `ic13`, `svt`, `iiit5k`, `synth`.
|
||||
|
||||
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
|
||||
|
@ -142,9 +152,11 @@ crnn
|
|||
├── convert_svt.py # Convert the original SVT dataset
|
||||
├── requirements.txt # Requirements for this dataset
|
||||
├── scripts
|
||||
│ ├── run_distribute_train.sh # Launch distributed training in Ascend(8 pcs)
|
||||
│ ├── run_eval.sh # Launch evaluation
|
||||
│ └── run_standalone_train.sh # Launch standalone training(1 pcs)
|
||||
│ ├── run_standalone_train_cpu.sh # Launch standalone training in CPU
|
||||
│ ├── run_eval_cpu.sh # Launch evaluation in CPU
|
||||
│ ├── run_distribute_train.sh # Launch distributed training in Ascend or GPU(8 pcs)
|
||||
│ ├── run_eval.sh # Launch evaluation in Ascend or GPU
|
||||
│ └── run_standalone_train.sh # Launch standalone training in Ascend or GPU(1 pcs)
|
||||
├── src
|
||||
│ ├── model_utils
|
||||
│ ├── config.py # Parameter config
|
||||
|
@ -172,11 +184,14 @@ crnn
|
|||
#### Training Script Parameters
|
||||
|
||||
```shell
|
||||
# distributed training
|
||||
# distributed training in Ascend or GPU
|
||||
Usage: bash scripts/run_distribute_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] [RANK_TABLE_FILE](if Ascend)
|
||||
|
||||
# standalone training
|
||||
# standalone training in Ascend or GPU
|
||||
Usage: bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
|
||||
|
||||
# standalone training in CPU
|
||||
Usage: bash scripts/run_standalone_train_cpu.sh [DATASET_NAME] [DATASET_PATH]
|
||||
```
|
||||
|
||||
#### Parameters Configuration
|
||||
|
@ -220,6 +235,12 @@ max_text_length": 23, # max number of digits in each
|
|||
bash scripts/run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)
|
||||
```
|
||||
|
||||
- Or run `run_standalone_train_cpu.sh` for non-distributed training of CRNN model in CPU.
|
||||
|
||||
``` bash
|
||||
bash scripts/run_standalone_train_cpu.sh [DATASET_NAME] [DATASET_PATH]
|
||||
```
|
||||
|
||||
#### [Distributed Training](#contents)
|
||||
|
||||
- Run `run_distribute_train.sh` for distributed training of CRNN model on Ascend or GPU
|
||||
|
@ -296,6 +317,8 @@ Epoch time: 2743.688s, per step time: 0.097s
|
|||
|
||||
``` bash
|
||||
bash scripts/run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional)
|
||||
|
||||
bash scripts/run_eval_cpu.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
Check the `eval/log.txt` and you will get outputs as following:
|
||||
|
@ -313,7 +336,7 @@ You can add `run_eval` to start shell and set it True.You need also add `eval_da
|
|||
### [Export MindIR](#contents)
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT] --device_target [DEVICE_TARGET] --model_version [MODEL_VERSION](required for cpu)
|
||||
```
|
||||
|
||||
The ckpt_file parameter is required,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-21 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -46,7 +46,7 @@ def mat_to_list(mat_file):
|
|||
for elem in testdata:
|
||||
img_name = elem[0]
|
||||
label = elem[1]
|
||||
ann = img_name+',' +label
|
||||
ann = img_name[0] + ',' + label[0]
|
||||
ann_output.append(ann)
|
||||
return ann_output
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-21 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -51,7 +51,7 @@ def crnn_eval():
|
|||
loss = CTCLoss(max_sequence_length=config.num_step,
|
||||
max_label_length=max_text_length,
|
||||
batch_size=config.batch_size)
|
||||
net = crnn(config, full_precision=config.device_target == 'GPU')
|
||||
net = crnn(config, full_precision=config.device_target != 'Ascend')
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(config.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
@ -62,6 +62,5 @@ def crnn_eval():
|
|||
res = model.eval(dataset, dataset_sink_mode=config.device_target == 'Ascend')
|
||||
print("result:", res, flush=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
crnn_eval()
|
||||
|
|
|
@ -23,6 +23,7 @@ from src.model_utils.moxing_adapter import moxing_wrapper
|
|||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
|
||||
|
||||
def modelarts_pre_process():
|
||||
config.file_name = os.path.join(config.output_path, config.file_name)
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: bash run_eval_cpu.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_NAME=$1
|
||||
PATH1=$(get_real_path $2)
|
||||
PATH2=$(get_real_path $3)
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]; then
|
||||
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -d "eval" ]; then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ./*.py ./eval
|
||||
cp -r ./src ./eval
|
||||
cp ./*yaml ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
python eval.py --eval_dataset=$DATASET_NAME \
|
||||
--eval_dataset_path=$PATH1 \
|
||||
--checkpoint_path=$PATH2 \
|
||||
--model_version="V2" \
|
||||
--device_target=CPU > log.txt 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: bash scripts/run_standalone_train_cpu.sh [DATASET_NAME] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_NAME=$1
|
||||
PATH1=$(get_real_path $2)
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -d "train" ]; then
|
||||
rm -rf ./train
|
||||
fi
|
||||
WORKDIR=./train_cpu
|
||||
rm -rf $WORKDIR
|
||||
mkdir $WORKDIR
|
||||
cp ./*.py $WORKDIR
|
||||
cp -r ./src $WORKDIR
|
||||
cp ./*yaml $WORKDIR
|
||||
cd $WORKDIR || exit
|
||||
|
||||
env >env.log
|
||||
python train.py --train_dataset=$DATASET_NAME \
|
||||
--train_dataset_path=$PATH1 \
|
||||
--device_target=CPU \
|
||||
--model_version="V2" > log.txt 2>&1 &
|
||||
cd ..
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-21 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,14 +14,17 @@
|
|||
# ============================================================================
|
||||
"""Warpctc network definition."""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import Parameter
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
def _bn(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0, moving_mean_init=0,
|
||||
|
@ -42,6 +45,46 @@ class Conv(nn.Cell):
|
|||
out = self.Relu(out)
|
||||
return out
|
||||
|
||||
class LSTMCPU(nn.Cell):
|
||||
"""Stacked LSTM (Long Short-Term Memory) layers for CPU"""
|
||||
def __init__(self,
|
||||
input_size,
|
||||
hidden_size,
|
||||
num_layers=1,
|
||||
has_bias=True,
|
||||
dropout=0,
|
||||
bidirectional=False):
|
||||
super(LSTMCPU, self).__init__()
|
||||
self.transpose = P.Transpose()
|
||||
self.num_layers = num_layers
|
||||
self.bidirectional = bidirectional
|
||||
self.dropout = dropout
|
||||
self.lstm = P.LSTM(input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
has_bias=has_bias,
|
||||
bidirectional=bidirectional,
|
||||
dropout=float(dropout))
|
||||
weight_size = 0
|
||||
gate_size = 4 * hidden_size
|
||||
stdv = 1 / math.sqrt(hidden_size)
|
||||
num_directions = 2 if bidirectional else 1
|
||||
|
||||
for layer in range(num_layers):
|
||||
input_layer_size = input_size if layer == 0 else hidden_size * num_directions
|
||||
increment_size = gate_size * input_layer_size
|
||||
increment_size += gate_size * hidden_size
|
||||
if has_bias:
|
||||
increment_size += gate_size
|
||||
weight_size += increment_size * num_directions
|
||||
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
|
||||
self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
|
||||
|
||||
def construct(self, x, hx):
|
||||
h, c = hx
|
||||
x, h, c, _, _ = self.lstm(x, h, c, self.weight)
|
||||
return x, (h, c)
|
||||
|
||||
class VGG(nn.Cell):
|
||||
"""VGG Network structure"""
|
||||
def __init__(self, is_training=True):
|
||||
|
@ -76,8 +119,10 @@ class BidirectionalLSTM(nn.Cell):
|
|||
|
||||
def __init__(self, nIn, nHidden, nOut, batch_size):
|
||||
super(BidirectionalLSTM, self).__init__()
|
||||
|
||||
self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
|
||||
if context.get_context("device_target") == "CPU":
|
||||
self.rnn = LSTMCPU(nIn, nHidden, bidirectional=True)
|
||||
else:
|
||||
self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
|
||||
self.embedding = nn.Dense(in_channels=nHidden * 2, out_channels=nOut)
|
||||
self.h0 = Tensor(np.zeros([1 * 2, batch_size, nHidden]).astype(np.float32))
|
||||
self.c0 = Tensor(np.zeros([1 * 2, batch_size, nHidden]).astype(np.float32))
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-21 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -88,7 +88,7 @@ def train():
|
|||
loss = CTCLoss(max_sequence_length=config.num_step,
|
||||
max_label_length=max_text_length,
|
||||
batch_size=config.batch_size)
|
||||
net = crnn(config, full_precision=config.device_target == 'GPU')
|
||||
net = crnn(config, full_precision=config.device_target != 'Ascend')
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
|
|
Loading…
Reference in New Issue