cnn direction model

This commit is contained in:
avakh 2020-12-01 14:14:04 -05:00
parent f6450a614b
commit 830b8f3e93
10 changed files with 1592 additions and 0 deletions

View File

@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train resnet."""
import argparse
import os
import random
import numpy as np
from src.cnn_direction_model import CNNDirectionModel
from src.config import config1 as config
from src.dataset import create_dataset_eval
from mindspore import context
from mindspore import dataset as de
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
if __name__ == '__main__':
# init context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
# create dataset
dataset = create_dataset_eval(args_opt.dataset_path + "/ocr_eval_pos.mindrecord", config=config)
step_size = dataset.get_dataset_size()
print("step_size ", step_size)
# define net
net = CNNDirectionModel([3, 64, 48, 48, 64], [64, 48, 48, 64, 64], [256, 64], [64, 512])
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define loss, model
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="sum")
# define model
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy'})
# eval model
res = model.eval(dataset, dataset_sink_mode=False)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

View File

@ -0,0 +1,5 @@
mindspore
numpy
Pillow
python-opencv
scikit-image

View File

@ -0,0 +1,88 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ $# == 3 ]
then
PATH3=$(get_real_path $3)
fi
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]
then
echo "error: DATASET_PATH=$PATH2 is not a directory"
exit 1
fi
if [ $# == 3 ] && [ ! -f $PATH3 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
if [ $# == 2 ]
then
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
fi
if [ $# == 3 ]
then
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
fi
cd ..
done

View File

@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH]"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
echo "start evaluation for device $DEVICE_ID"
env > env.log
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 #&> log &
cd ..

View File

@ -0,0 +1,72 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ] && [ $# != 2 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=3
export RANK_ID=0
export RANK_SIZE=1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ $# == 2 ]
then
PATH2=$(get_real_path $2)
fi
if [ $# == 2 ] && [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 1 ]
then
python train.py --dataset_path=$PATH1 &> log &
fi
if [ $# == 2 ]
then
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
fi
cd ..

View File

@ -0,0 +1,264 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CNN direction model."""
import math
import mindspore.nn as nn
from mindspore.common.initializer import Uniform
from mindspore.ops import operations as P
class NetAddN(nn.Cell):
"""
Computes addition of all input tensors element-wise.
"""
def __init__(self):
super(NetAddN, self).__init__()
self.addN = P.AddN()
def construct(self, *z):
return self.addN(z)
class Conv(nn.Cell):
"""
A convolution layer
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
kernel (tuple): Size of the kernel. Default: (3, 3).
dilate (bool): If set to true a second convolution layer is added. Default: True.
act (string): The activation function. Default: 'relu'.
mp (int): Size of max pooling layer. Default: None.
Returns:
Tensor, output tensor.
Examples:
>>> Conv(3, 64)
"""
def __init__(self,
in_channel,
out_channel,
kernel=(3, 3),
dilate=True,
act='relu',
mp=None):
super(Conv, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.kernel = kernel
self.dilate = dilate
self.act = act
self.mp = mp
self.conv1 = nn.Conv2d(self.in_channel, self.out_channel, kernel_size=self.kernel, pad_mode="same",
weight_init='he_normal')
self.batch_norm1 = nn.BatchNorm2d(self.out_channel, eps=1e-3, momentum=0.99,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
if self.dilate:
self.dilate_relu = P.ReLU()
self.dilate_conv = nn.Conv2d(self.out_channel, self.out_channel, kernel_size=self.kernel,
dilation=(2, 2), pad_mode='same', weight_init='he_normal')
self.dilate_batch_norm = nn.BatchNorm2d(self.out_channel, eps=1e-3, momentum=0.99,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
self.dilate_add = NetAddN()
if self.act == 'relu':
self.act_layer = P.ReLU()
if self.mp is not None:
self.mp_layer = nn.MaxPool2d(kernel_size=self.mp, stride=self.mp, pad_mode='valid')
def construct(self, x):
out = self.conv1(x)
out = self.batch_norm1(out)
out1 = out
if self.dilate:
out = self.dilate_relu(out)
out = self.dilate_conv(out)
out = self.dilate_batch_norm(out)
out = self.dilate_add(out1, out)
if self.act == 'relu':
out = self.act_layer(out)
if self.mp is not None:
out = self.mp_layer(out)
return out
class Block(nn.Cell):
"""
A Block of convolution operations.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
Returns:
Tensor, output tensor.
Examples:
>>> Block(3, 64)
"""
def __init__(self,
in_channel,
out_channel):
super(Block, self).__init__()
self.conv1 = Conv(in_channel, out_channel, act='relu')
self.conv2 = Conv(out_channel, out_channel, act=None)
self.add = NetAddN()
self.relu = P.ReLU()
def construct(self, x):
y = self.conv1(x)
y = self.conv2(y)
out = self.add(x, y)
out = self.relu(out)
return out
class ResidualBlock(nn.Cell):
"""
A residual block.
Args:
block (Block) : The building block.
num_blocks (int): Number of blocks.
in_channel (int): Input channel.
out_channel (int): Output channel.
mp (int) : Size of the max pooling layer. Default: 2.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(Block, 1, 3, 64)
"""
def __init__(self,
block,
num_blocks,
in_channel,
out_channel,
mp=2):
super(ResidualBlock, self).__init__()
self.num_blocks = num_blocks
self.in_channel = in_channel
self.out_channel = out_channel
self.mp = mp
self.conv1 = Conv(self.in_channel, self.out_channel, kernel=(3, 3), dilate=False)
layers = []
for _ in range(self.num_blocks):
res_block = block(out_channel, out_channel)
layers.append(res_block)
self.layer = nn.SequentialCell(layers)
if mp is not None:
self.max_pool = nn.MaxPool2d(kernel_size=mp, stride=mp, pad_mode='valid')
def construct(self, x):
out = self.conv1(x)
out = self.layer(out)
if self.mp is not None:
out = self.max_pool(out)
return out
class CNNDirectionModel(nn.Cell):
"""
CNN direction model.
Args:
in_channels (list): List of the dimesnions of the input channels. The first element is the input dimension
of the first Conv layer, and the rest of the elements are the input dimensions of the residual blocks,
in order.
out_channels (list): List of the dimesnions of the output channels. The first element is the ourpur dimension
of the first Conv layer, and the rest of the elements are the output dimensions of the residual blocks, in order.
dense_layers (list): Dimensions of the dense layers, inorder.
image_size (list): Size of the input images.
num_classes (int): Number of classes. Default: 2 for binary classification.
Returns: Tensor, output tensor.
Examples:
>>> CNNDirectionModel([3, 64, 48, 48, 64], [64, 48, 48, 64, 64], [256, 64], [64, 512] )
"""
def __init__(self,
in_channels,
out_channels,
dense_layers,
image_size,
num_classes=2
):
super(CNNDirectionModel, self).__init__()
self.num_classes = num_classes
self.image_h = image_size[0]
self.image_w = image_size[1]
self.conv1 = Conv(in_channels[0], out_channels[0], kernel=(7, 7), dilate=False, mp=2)
self.residual_block1 = ResidualBlock(Block, 1, in_channels[1], out_channels[1])
self.residual_block2 = ResidualBlock(Block, 1, in_channels[2], out_channels[2])
self.residual_block3 = ResidualBlock(Block, 2, in_channels[3], out_channels[3])
self.residual_block4 = ResidualBlock(Block, 1, in_channels[4], out_channels[4])
# 5 previous layers have mp=2. Height and width of the image would become 1/32.
self.avg_pool = nn.AvgPool2d(kernel_size=(int(self.image_h / 32), int(self.image_w / 32)))
# sqrt(6 / (fan_in + fan_out))
scale = math.sqrt(6 / (out_channels[-1] + dense_layers[0]))
# weight_init='glorot_uniform'
self.dense1 = nn.Dense(out_channels[-1], dense_layers[0], weight_init=Uniform(scale=scale), activation='relu')
scale = math.sqrt(6 / (dense_layers[0] + dense_layers[1]))
self.dense2 = nn.Dense(dense_layers[0], dense_layers[1], weight_init=Uniform(scale=scale), activation='relu')
scale = math.sqrt(6 / (dense_layers[1] + num_classes))
self.dense3 = nn.Dense(dense_layers[1], num_classes, weight_init=Uniform(scale=scale), activation='softmax')
def construct(self, x):
out = self.conv1(x)
out = self.residual_block1(out)
out = self.residual_block2(out)
out = self.residual_block3(out)
out = self.residual_block4(out)
out = self.avg_pool(out)
out = P.Reshape()(out, (out.shape[0], out.shape[1]))
out = self.dense1(out)
out = self.dense2(out)
out = self.dense3(out)
return out

View File

@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config1 = ed({
"batch_size": 8,
"epoch_size": 5,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 10,
"keep_checkpoint_max": 20,
"save_checkpoint_path": "./",
"warmup_epochs": 5,
"lr_decay_mode": "poly",
"lr": 1e-4,
"work_nums": 4,
"im_size_w": 512,
"im_size_h": 64,
"pos_samples_size": 100,
"augment_severity": 0.1,
"augment_prob": 0.3
})

View File

@ -0,0 +1,246 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
from src.dataset_utils import lucky, noise_blur, noise_speckle, noise_gamma, noise_gaussian, noise_salt_pepper, \
shift_color, enhance_brightness, enhance_sharpness, enhance_contrast, enhance_color, gaussian_blur, \
randcrop, resize, rdistort, rgeometry, rotate_about_center, whole_rdistort, warp_perspective, random_contrast, \
unify_img_label
import cv2
import numpy as np
cv2.setNumThreads(0)
image_height = None
image_width = None
class Augmentor():
"""
Augment image with random noise and transformation
Controlled by severity level [0, 1]
Usage:
augmentor = Augmentor(severity=0.3,
prob=0.5,
enable_transform=True,
enable_crop=False)
image_new = augmentor.process(image)
"""
def __init__(self, severity, prob, enable_transform=True, enable_crop=False):
"""
severity: in [0, 1], from min to max level of noise/transformation
prob: in [0, 1], probability to apply each operator
enable_transform: enable all transformation operators
enable_crop: enable crop operator
"""
self.severity = np.clip(severity, 0, 1)
self.prob = np.clip(prob, 0, 1)
self.enable_transform = enable_transform
self.enable_crop = enable_crop
def add_noise(self, im):
"""randomly add noise to image"""
severity = self.severity
prob = self.prob
if lucky(prob):
im = noise_gamma(im, severity=severity)
if lucky(prob):
im = noise_blur(im, severity=severity)
if lucky(prob):
im = noise_gaussian(im, severity=severity)
if lucky(prob):
im = noise_salt_pepper(im, severity=severity)
if lucky(prob):
im = shift_color(im, severity=severity)
if lucky(prob):
im = gaussian_blur(im, severity=severity)
if lucky(prob):
im = noise_speckle(im, severity=severity)
if lucky(prob):
im = enhance_sharpness(im, severity=severity)
if lucky(prob):
im = enhance_contrast(im, severity=severity)
if lucky(prob):
im = enhance_brightness(im, severity=severity)
if lucky(prob):
im = enhance_color(im, severity=severity)
if lucky(prob):
im = random_contrast(im)
return im
def convert_color(self, im, cval):
if cval in ['median', 'md']:
cval = np.median(im, axis=(0, 1)).astype(int)
elif cval == 'mean':
cval = np.mean(im, axis=(0, 1)).astype(int)
if hasattr(cval, '__iter__'):
cval = [int(i) for i in cval]
else:
cval = int(cval)
return cval
def transform(self, im, cval=255, **kw):
"""According to the parameters initialized by the class, deform the incoming image"""
severity = self.severity
prob = self.prob
cval = self.convert_color(im, cval)
if lucky(prob):
# affine transform
im = rgeometry(im, severity=severity, cval=cval)
if lucky(prob):
im = rdistort(im, severity=severity, cval=cval)
if lucky(prob):
im = warp_perspective(im, severity=severity, cval=cval)
if lucky(prob):
im = resize(im, fx=kw.get('fx'), fy=kw.get('fy'), severity=severity)
if lucky(prob):
im = rotate_about_center(im, severity=severity, cval=cval)
if lucky(prob):
# the overall distortion of the image.
im = whole_rdistort(im, severity=severity)
if lucky(prob) and self.enable_crop:
# random crop
im = randcrop(im, severity=severity)
return im
def process(self, im, cval='median', **kw):
""" Execute code according to the effect of initial setting, and support variable parameters"""
if self.enable_transform:
im = self.transform(im, cval=cval, **kw)
im = self.add_noise(im)
return im
def rotate_and_set_neg(img, label):
label = label - 1
img_rotate = np.rot90(img)
img_rotate = np.rot90(img_rotate)
# return img_rotate, label
return img_rotate, np.array(label).astype(np.int32)
def rotate(img, label):
img_rotate = np.rot90(img)
img_rotate = np.rot90(img_rotate)
return img_rotate, label
def random_neg_with_rotate(img, label):
if lucky(0.5):
##50% of samples set to negative samples
label = label - 1
# rotate by 180 debgress
img_rotate = np.rot90(img)
img = np.rot90(img_rotate)
return img, np.array(label).astype(np.int32)
def transform_image(img, label):
data = np.array([img[...]], np.float32)
data = data / 127.5 - 1
return data.transpose((0, 3, 1, 2))[0], label
def create_dataset_train(mindrecord_file_pos, config):
"""
create a train dataset
Args:
mindrecord_file_pos(string): mindrecord file for positive samples.
config(dict): config of dataset.
Returns:
dataset
"""
rank_size = int(os.getenv("RANK_SIZE", '1'))
rank_id = int(os.getenv("RANK_ID", '0'))
decode = C.Decode()
ds = de.MindDataset(mindrecord_file_pos, columns_list=["image", "label"], num_parallel_workers=4,
num_shards=rank_size, shard_id=rank_id, shuffle=True)
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=8)
augmentor = Augmentor(config.augment_severity, config.augment_prob)
operation = augmentor.process
ds = ds.map(operations=operation, input_columns=["image"],
num_parallel_workers=1, python_multiprocessing=True)
##randomly augment half of samples to be negative samples
ds = ds.map(operations=[random_neg_with_rotate, unify_img_label, transform_image], input_columns=["image", "label"],
num_parallel_workers=8, python_multiprocessing=True)
##for training double the dataset to accoun for positive and negative
ds = ds.repeat(2)
# apply batch operations
ds = ds.batch(config.batch_size, drop_remainder=True)
return ds
def resize_image(img, label):
color_fill = 255
scale = image_height / img.shape[0]
img = cv2.resize(img, None, fx=scale, fy=scale)
if img.shape[1] > image_width:
img = img[:, 0:image_width]
else:
blank_img = np.zeros((image_height, image_width, 3), np.uint8)
# fill the image with white
blank_img.fill(color_fill)
blank_img[:image_height, :img.shape[1]] = img
img = blank_img
data = np.array([img[...]], np.float32)
data = data / 127.5 - 1
return data.transpose((0, 3, 1, 2))[0], label
def create_dataset_eval(mindrecord_file_pos, config):
"""
create an eval dataset
Args:
mindrecord_file_pos(string): mindrecord file for positive samples.
config(dict): config of dataset.
Returns:
dataset
"""
rank_size = int(os.getenv("RANK_SIZE", '1'))
rank_id = int(os.getenv("RANK_ID", '0'))
decode = C.Decode()
ds = de.MindDataset(mindrecord_file_pos, columns_list=["image", "label"], num_parallel_workers=1,
num_shards=rank_size, shard_id=rank_id, shuffle=False)
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=8)
global image_height
global image_width
image_height = config.im_size_h
image_width = config.im_size_w
ds = ds.map(operations=resize_image, input_columns=["image", "label"], num_parallel_workers=config.work_nums,
python_multiprocessing=False)
# apply batch operations
ds = ds.batch(1, drop_remainder=True)
return ds

View File

@ -0,0 +1,641 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import absolute_import, division, print_function, unicode_literals
from math import ceil, sin, pi
from random import choice, random
from random import randint, uniform
import cv2
import numpy as np
from numpy.random import randn
from PIL import ImageEnhance, Image
from scipy.ndimage import filters, interpolation
from scipy.ndimage.interpolation import map_coordinates
from skimage.transform import PiecewiseAffineTransform, warp
nprandint = np.random.randint
def lucky(p=0.3, rand_func=random):
""" return True with probability p """
return rand_func() < p
def rgeometry(im, eps=0.04, delta=0.8, cval=None, severity=1):
"""
affine transform
"""
if severity == 0:
return im
if cval is None:
cval = [0] * im.shape[2]
elif isinstance(cval, (float, int)):
cval = [cval] * im.shape[2]
severity = abs(severity)
eps = severity * eps
delta = severity * delta
m = np.array([[1 + eps * randn(), 0.0], [eps * randn(), 1.0 + eps * randn()]])
c = np.array(im.shape[:2]) * 0.5
d = c - np.dot(m, c) + np.array([randn() * delta, randn() * delta])
im = cv2.split(im)
im = [interpolation.affine_transform(i, m, offset=d, order=1, mode='constant', cval=cval[e])
for e, i in enumerate(im)]
im = cv2.merge(im)
return np.array(im)
def rdistort(im, distort=4.0, dsigma=10.0, cval=None, severity=1):
"""distort"""
if severity == 0:
return im
if cval is None:
cval = [0] * im.shape[2]
elif isinstance(cval, (float, int)):
cval = [cval] * im.shape[2]
severity = abs(severity)
distort = severity * distort
dsigma = dsigma * (1 - severity)
h, w = im.shape[:2]
hs, ws = randn(h, w), randn(h, w)
hs = filters.gaussian_filter(hs, dsigma)
ws = filters.gaussian_filter(ws, dsigma)
hs *= distort / np.abs(hs).max()
ws *= distort / np.abs(ws).max()
# When "ij" is passed in, the first array determines the column, the second array determines the row, by default,
# the first array determines the row, and the second array determines the column
ch, cw = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
coordinates = np.array([ch + hs, cw + ws])
im = cv2.split(im)
im = [map_coordinates(img, coordinates, order=1, cval=cval[i]) for i, img in enumerate(im)]
im = cv2.merge(im)
return np.array(im)
def reverse_color(im):
""" Pixel inversion """
return 255 - im
def resize(im, fx=None, fy=None, delta=0.3, severity=1):
""" scaling in the two directions of width fx and height fy,
If the zoom factor is not specified, the maximum change amount of 0.3 is randomly selected from 1 to 1"""
if fx is None:
fx = 1 + delta * severity * uniform(-1, 1)
if fy is None:
fy = 1 + delta * severity * uniform(-1, 1)
return np.array(cv2.resize(im, None, fx=fx, fy=fy))
def warp_perspective(im, theta=20, delta=10, cval=0, severity=1):
""" perspective mapping """
if severity == 0:
return im
if cval is None:
cval = [0] * im.shape[2]
elif isinstance(cval, (float, int)):
cval = [cval] * im.shape[2]
delta = delta * severity
rows, cols = im.shape[:2]
pts_im = np.float32([[0, 0], [cols, 0], [cols, rows], [0, rows]])
# Distort randomly and constrain the scope of change
pts_warp = pts_im + np.random.uniform(-1, 1, pts_im.shape) * theta * severity
pts_warp = np.maximum(pts_warp, delta) # Constrain the change to the part >=3
pts_warp[[1, 2], 0] = np.minimum(pts_warp[[1, 2], 0], pts_im[[1, 2], 0] - delta)
pts_warp[[2, 3], 1] = np.minimum(pts_warp[[2, 3], 1], pts_im[[2, 3], 1] - delta)
pts_warp = np.float32(pts_warp)
M = cv2.getPerspectiveTransform(pts_im, pts_warp)
res = np.array(cv2.warpPerspective(im, M, (cols, rows), borderValue=cval))
return res
def noise_salt_pepper(image, percentage=0.001, severity=1):
""" Salt and pepper noise, percentage represents the percentage of salt and pepper noise"""
percentage *= severity
amount = int(percentage * image.shape[0] * image.shape[1])
if amount == 0:
return image
_, _, deep = image.shape
# Salt mode
coords = [np.random.randint(0, i - 1, amount) for i in image.shape[:2]]
salt = nprandint(200, 255, amount)
salt = salt.repeat(deep, axis=0)
image[coords[0], coords[1], :] = salt.reshape(amount, deep)
# pepper mode
coords = [np.random.randint(0, i - 1, amount) for i in image.shape[:2]]
pepper = nprandint(0, 50, amount)
pepper = pepper.repeat(deep, axis=0)
image[coords[0], coords[1], :] = pepper.reshape(amount, deep)
return image
def noise_gaussian(im, sigma=20, severity=1):
""" add Gaussian noise"""
sigma = sigma * abs(severity)
return cvt_uint8(np.float32(im) + sigma * np.random.randn(*im.shape))
def noise_gamma(im, extend=30, severity=1):
""" add gamma noise """
s = int(extend * abs(severity))
n = np.random.gamma(shape=2, scale=s, size=im.shape)
n = n - np.mean(n)
im = cvt_uint8(np.float32(im) + n)
return im
def noise_speckle(img, extend=40, severity=1):
""" this creates larger 'blotches' of noise which look
more realistic than just adding gaussian noise """
severity = abs(severity) * extend
blur = filters.gaussian_filter(np.random.randn(*img.shape) * severity, 1)
return cvt_uint8(img + blur)
def noise_blur(im, severity=1):
"""add blur by shrinking an image and then enlarging to original size"""
severity = abs(severity)
f = 1 - 0.2 * severity
h, w = im.shape[:2]
hmin = 19.0
f = max(f, hmin / h)
im = cv2.resize(im, None, fx=f, fy=f)
return np.array(cv2.resize(im, (w, h)))
def add_noise(img):
"""combine noises in np array"""
img0 = img
if lucky(0.1):
img = noise_salt_pepper(img, uniform(0.3, 0.6))
if lucky(0.2):
img = noise_gaussian(img, uniform(0.3, 0.6))
if lucky(0.5):
img = noise_blur(img, uniform(0.3, 0.6))
if lucky(0.5):
img = noise_speckle(img, uniform(0.3, 0.6))
if lucky(0.3):
img = img // 2 + img0 // 2
return img
def gaussian_blur(im, sigma=1, kernel_size=None, severity=1):
"""Gaussian blur, if kernel_size is passed in, severity will be invalid"""
if kernel_size is None:
step = 11
kernel_size = int(step * severity)
if kernel_size < 3.0:
return im
if kernel_size % 2 == 0:
kernel_size -= 1
return np.array(cv2.GaussianBlur(im, (kernel_size, kernel_size), sigma))
def rotate_shrink(im, max_angle=6, severity=0.5, cval=255):
"""rotate about center, shrink to keep the same size without cropping image"""
max_angle = int(abs(severity) * max_angle)
angle = randint(-max_angle, max_angle)
h, w = im.shape[:2]
rangle = np.deg2rad(angle) # angle in radians
# now calculate new image width and height
nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
scale = min(w / nw, h / nh)
mat = cv2.getRotationMatrix2D((w // 2, h // 2), angle, scale)
im = cv2.warpAffine(im, mat, (w, h), borderValue=cval)
return np.array(im)
def rotate_about_center(im, angle=4, scale=1, b_mode=None, cval=None, severity=1):
"""For the rotation effect, it is recommended to make b_mode not equal to None for color images, so that the
filling will copy the edge pixel filling """
angle = severity * angle
if angle == 0:
return im
w = im.shape[1]
h = im.shape[0]
rangle = np.deg2rad(angle) # angle in radians
# now calculate new image width and height
nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) * scale
nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) * scale
# ask OpenCV for the rotation matrix
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
# calculate the move from the old center to the new center combined
# with the rotation
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
# the move only affects the translation, so update the translation
# part of the transform
rot_mat[0, 2] += rot_move[0]
rot_mat[1, 2] += rot_move[1]
if cval is None:
cval = [0] * im.shape[2]
elif isinstance(cval, (int, float)):
cval = [cval] * im.shape[2]
if b_mode is None:
src = cv2.warpAffine(im, rot_mat, (int(ceil(nw)), int(ceil(nh))), flags=cv2.INTER_LANCZOS4,
borderMode=cv2.BORDER_CONSTANT, borderValue=cval)
else:
src = cv2.warpAffine(im, rot_mat, (int(ceil(nw)), int(ceil(nh))), flags=cv2.INTER_LANCZOS4,
borderMode=cv2.BORDER_REPLICATE)
return np.array(src)
def randcrop(img, max_per=0.15, severity=1):
"""Random crop"""
perc = max_per * severity
rows, cols = img.shape[:2]
k = int(rows * cols * perc / (rows + cols))
roi = img[randint(0, k):rows - randint(0, k), randint(0, k):cols - randint(0, k)]
return np.array(roi)
def enhance_sharpness(img, r=None, severity=1):
"""
adjust the sharpness of an image. An
enhancement factor of 0.0 gives a blurred image, a factor of 1.0 gives the
original image, and a factor of 2.0 gives a sharpened image.
"""
if r is None:
severity = abs(severity)
r = uniform(1 - 0.5 * severity, 1) if lucky(0.5) else uniform(1, 1 + severity)
img = Image.fromarray(img)
img = np.array(ImageEnhance.Sharpness(img).enhance(r))
return img
def enhance_contrast(img, r=None, severity=1):
"""
control the contrast of an image, similar
to the contrast control on a TV set. An enhancement factor of 0.0
gives a solid grey image. A factor of 1.0 gives the original image.
"""
if r is None:
severity = abs(severity)
r = uniform(1 - 0.5 * severity, 1) if lucky(0.5) else uniform(1, 1 + severity)
img = Image.fromarray(img)
img = np.array(ImageEnhance.Contrast(img).enhance(r))
return img
def enhance_brightness(img, r=None, severity=1):
"""
control the brightness of an image. An
enhancement factor of 0.0 gives a black image. A factor of 1.0 gives the
original image.
"""
if r is None:
severity = abs(severity)
r = uniform(1 - 0.2 * severity, 1) if lucky(0.5) else uniform(1, 1 + severity * 0.5)
img = Image.fromarray(img)
img = np.array(ImageEnhance.Brightness(img).enhance(r))
return img
def enhance_color(img, r=None, severity=1):
"""
adjust the colour balance of an image, in
a manner similar to the controls on a colour TV set. An enhancement
factor of 0.0 gives a black and white image. A factor of 1.0 gives
the original image.
"""
if r is None:
severity = abs(severity)
r = uniform(1 - 0.5 * severity, 1) if lucky(0.5) else uniform(1, 1 + severity)
img = Image.fromarray(img)
img = np.array(ImageEnhance.Color(img).enhance(r))
return img
def enhance(img):
"""combine image enhancement in the Image type, reduce conversions to np array"""
if lucky(0.3):
img = enhance_sharpness(img)
if lucky(0.3):
img = enhance_contrast(img)
if lucky(0.3):
img = enhance_brightness(img)
return np.array(img)
def draw_line(im):
"""draw a line randomly"""
h, w = im.shape[:2]
p1 = (randint(0, w // 3), randint(0, h - 1)) # from left 1/3
p2 = (randint(w // 3 * 2, w - 1), randint(0, h - 1)) # to right 1/3
color = [randint(0, 255) for i in range(3)]
lw = lucky_choice((1, 2), (0.8, 0.2))
cv2.line(im, p1, p2, color, lw, cv2.LINE_AA)
return np.array(im)
def center_im(im_outter, im_inner, shrink=True, vertical='center'):
"""center an image in a container image. `im_outter` can be the shape of it"""
if not isinstance(im_outter, np.ndarray):
shape = tuple(im_outter)
if im_inner.ndim > len(shape):
shape += im_inner.shape[len(shape):]
im_outter = np.zeros(shape, np.uint8)
H, W = im_outter.shape[:2]
h, w = im_inner.shape[:2]
if h > H or w > W:
if shrink:
rate = min(H / h, W / w)
im_inner = cv2.resize(im_inner, rate)
im_inner = im_inner[:H, :W]
h, w = im_inner.shape[:2]
vertical = vertical.lower()
if vertical == 'center':
dh = (H - h) // 2
elif vertical == 'top':
dh = 0
elif vertical == 'bottom':
dh = H - h
im = im_outter.copy()
dw = (W - w) // 2
im[dh:dh + h, dw:dw + w] = im_inner
return np.array(im)
def enhance_light(img):
"""combine image enhancement in the Image type, reduce conversions to np array"""
if lucky(0.3):
img = enhance_sharpness(img, uniform(0.5, 1.5))
if lucky(0.3):
img = enhance_contrast(img, uniform(0.7, 1.3))
if lucky(0.3):
img = enhance_brightness(img, uniform(0.85, 1.15))
return np.array(img)
def gaussian2d(w, h):
"""The two-dimensional Gaussian distribution effect is actually an ellipse"""
h = h // 2
w = w // 2
x = np.arange(-w, w)
y = np.arange(-h, h)
x, y = np.meshgrid(x, y)
mean_x = np.mean(x)
mean_y = np.mean(y)
std_x = np.std(x)
std_y = np.std(y)
z = np.exp(
-((y - mean_y) ** 2 / (std_y ** 2) + (x - mean_x) ** 2 / (std_x ** 2)) / 2
)
z /= (np.sqrt(2 * np.pi) * std_y)
z *= 1 / (np.max(z) - np.min(z))
return z
def add_stain(img, theta=200, severity=0.5, bright_spot=False, iteration=1):
"""Generate black stains or white bright spots"""
for _ in range(0, iteration):
img = np.float32(img)
theta = theta * abs(severity)
cols_big, rows_big = img.shape[:2]
temp = min([cols_big, rows_big])
if temp < 80:
temp = 80
if temp > 300:
temp = 300
if not bright_spot:
gaussian_img = gaussian2d(randint(temp // 3, temp // 2), randint(temp // 3, temp // 2)) * theta
else:
gaussian_img = gaussian2d(randint(temp // 1.5, int(temp / 0.8)),
randint(temp // 1.5, int(temp / 0.8)))
cols_small, rows_small = gaussian_img.shape[:2]
tmp_min = int(min(cols_small, rows_small))
# 对椭圆效果做大幅度扭曲cval最好不要过大。
gaussian_img = rdistort(gaussian_img, randint(tmp_min // 10, tmp_min // 6), cval=0)
x1 = randint(0, rows_big - 5 if rows_big - 5 > 0 else 0)
y1 = randint(0, cols_big - 5 if cols_big - 5 > 0 else 0)
if y1 + cols_small > cols_big:
y2 = int(cols_big - 1)
else:
y2 = int(y1 + cols_small)
if x1 + rows_small > rows_big:
x2 = int(rows_big - 1)
else:
x2 = int(x1 + rows_small)
row, col = gaussian_img.shape
gaussian_img = gaussian_img.repeat(img.shape[2], axis=1)
gaussian_img = gaussian_img.reshape(row, col, img.shape[2])
gaussian_img = np.float32(gaussian_img[:(y2 - y1), :(x2 - x1)])
if not bright_spot:
img[y1:y2, x1:x2] -= gaussian_img
else:
temp1 = min([np.median(gaussian_img), 255 - np.mean(img[y1:y2, x1:x2])])
gaussian_img = np.clip(gaussian_img - temp1, 0, 255)
img[y1:y2, x1:x2] = np.clip(img[y1:y2, x1:x2] + gaussian_img, 0, 255)
img = cvt_uint8(img)
return np.array(img)
def shift_color(im, delta_max=10, severity=0.5):
"""randomly shift image color"""
if severity == 0:
return im
delta_max = int(delta_max * severity)
if isinstance(delta_max, tuple):
delta_min, delta_max = delta_max
else:
delta_min = -delta_max
im = np.float32(im)
delta = np.random.randint(delta_min, delta_max, (1, 1, im.shape[2]))
im += delta
return np.array(cvt_uint8(im))
def random_contrast(img, contrast_delta=0.3, bright_delta=0.1):
"""randomly change image contrast and brightness"""
if isinstance(contrast_delta, tuple):
contrast_delta_min, contrast_delta = contrast_delta
else:
contrast_delta_min = -contrast_delta
if isinstance(bright_delta, tuple):
bright_delta_min, bright_delta = bright_delta
else:
bright_delta_min = -bright_delta
fc = 1 + uniform(contrast_delta_min, contrast_delta)
fb = 1 + uniform(bright_delta_min, bright_delta)
im = img.astype(np.float32)
if img.ndim == 2:
im = im[:, :, None]
mn = im.mean(axis=(0, 1), keepdims=True)
im = (im - mn) * fc + mn * fb
im = im.clip(0, 255).astype(np.uint8)
return np.array(im)
def period_map(xi, times, extent):
if times < 1:
return None
times = float(times)
theta = randint(extent, extent + 10) * choice([1, -1])
def back(x):
if x < times / 2.0:
# Here only the effect of a sin function is achieved, and more effects can be added later.
return theta * sin(pi * (3 / 2.0 + x / times)) # Monotonically increasing
return theta * sin(pi * (1 / 2.0 + x / times))
xi = np.fabs(xi)
xi = xi % times
yi = np.array(list(map(back, xi)))
return yi
def whole_rdistort(im, severity=1, scop=40):
"""
Using the affine projection method in skimg,
Realize the picture through the corresponding coordinate projection
Specifies the distortion effect of the form. This function will normalize 0-1
"""
if severity == 0:
return im
theta = severity * scop
rows, cols = im.shape[:2]
colpoints = max(int(cols * severity * 0.05), 3)
rowpoints = max(int(rows * severity * 0.05), 3)
src_cols = np.linspace(0, cols, colpoints)
src_rows = np.linspace(0, rows, rowpoints)
src_rows, src_cols = np.meshgrid(src_rows, src_cols)
src = np.dstack([src_cols.flat, src_rows.flat])[0]
# The key location for wave distortion effect
dst_rows = src[:, 1] - period_map(np.linspace(0, 100, src.shape[0]), 50, 20)
# dst columns
dst_cols = src[:, 0] - np.sin(np.linspace(0, 3 * np.pi, src.shape[0])) * theta
dst = np.vstack([dst_cols, dst_rows]).T
tform = PiecewiseAffineTransform()
tform.estimate(src, dst)
image = warp(im, tform, mode='edge', output_shape=(rows, cols)) * 255
return np.array(cvt_uint8(image))
def lucky_choice(seq, ps=None, rand_func=random):
"""randomly choose an element from `seq` according to their probability distribution `ps`"""
if not seq:
return None
if ps is None:
return choice(seq)
cumps = np.cumsum(ps)
r = rand_func() * cumps[-1]
idx = (cumps < r).sum()
idx = min(idx, len(seq) - 1)
return seq[idx]
def cvt_uint8(im):
"""convert image type to `np.uint8`"""
if im.dtype == np.uint8:
return im
return np.round(im).clip(0, 255).astype(np.uint8)
def to_image(im):
"""convert `im` to `Image` type"""
if not isinstance(im, Image.Image):
if im.ndim == 3:
im = im[:, :, ::-1] # reverse channels: BGR in cv2 to RGB in Image
im = Image.fromarray(im)
return im
def to_array(im):
"""convert `im` to `np.array` type"""
if isinstance(im, Image.Image):
im = np.array(im)
if im.ndim == 3:
im = im[:, :, ::-1] # reverse channels: RGB in Image to BGR in cv2
return im
def unify_img(img, img_height=64, max_length=512, img_channel=3):
color_fill = 255
img_shape = img.shape
img_width = int(float(img_shape[1]) / img_shape[0] * img_height)
img = cv2.resize(img, (img_width, img_height))
if img_width > max_length:
img = img[:, 0:max_length]
else:
blank_img = np.zeros((img_height, max_length, img_channel), np.uint8)
# fill the image with white
blank_img.fill(color_fill)
blank_img[0:img_height, 0:img_width] = img
img = blank_img
return np.array(img)
def unify_img_label(img, label, img_height=64, max_length=512, min_length=192, img_channel=3):
color_fill = 255
img_shape = img.shape
img_width = int(float(img_shape[1]) / img_shape[0] * img_height)
img = cv2.resize(img, (img_width, img_height))
if img_width > max_length:
img = img[:, 0:max_length]
else:
blank_img = np.zeros((img_height, max_length, img_channel), np.uint8)
# fill the image with white
blank_img.fill(color_fill)
blank_img[0:img_height, 0:img_width] = img
img = blank_img
return np.array(img), label

View File

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train CNN direction model."""
import argparse
import os
import random
from src.cnn_direction_model import CNNDirectionModel
from src.config import config1 as config
from src.dataset import create_dataset_train
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore import dataset as de
from mindspore.communication.management import init
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.metrics import Accuracy
from mindspore.nn.optim.adam import Adam
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
args_opt = parser.parse_args()
random.seed(11)
np.random.seed(11)
de.config.set_seed(11)
ms.common.set_seed(11)
if __name__ == '__main__':
target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
# init context
device_id = int(os.getenv('DEVICE_ID', '0'))
rank_id = int(os.getenv('RANK_ID', '0'))
rank_size = int(os.getenv('RANK_SIZE', '1'))
run_distribute = rank_size > 1
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id, save_graphs=False)
print("train args: ", args_opt, "\ncfg: ", config,
"\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
if run_distribute:
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL)
init()
# create dataset
dataset = create_dataset_train(args_opt.dataset_path + "/ocr_pos.mindrecord0", config=config)
step_size = dataset.get_dataset_size()
# define net
net = CNNDirectionModel([3, 64, 48, 48, 64], [64, 48, 48, 64, 64], [256, 64], [64, 512])
# init weight
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)
lr = config.lr
lr = Tensor(lr, ms.float32)
# define opt
opt = Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)
# define loss, model
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="sum")
model = Model(net, loss_fn=loss, optimizer=opt, metrics={"Accuracy": Accuracy()})
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=2500,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# train model
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)