forked from mindspore-Ecosystem/mindspore
!1637 Add DeepLabV3 network
Merge pull request !1637 from z00378171/deeplabv3
This commit is contained in: commit eec2678ebe
@@ -0,0 +1,66 @@ README.md
# Deeplab-V3 Example

## Description

This is an example of training DeepLabv3 with the PASCAL VOC 2012 dataset in MindSpore.

## Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the PASCAL VOC 2012 dataset for training.

> Notes:
> If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.

## Running the Example

### Training

- Set options in `config.py`.

- Run `run_standalone_train.sh` for non-distributed training.

``` bash
sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_DIR
```

- Run `run_distribute_train.sh` for distributed training.

``` bash
sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH
```

### Evaluation

Set options in `config.py`, and make sure `--data_url` and `--checkpoint_url` point to your own dataset and checkpoint paths.

- Run `run_eval.sh` for evaluation.

``` bash
sh scripts/run_eval.sh DEVICE_ID DATA_DIR
```

## Options and Parameters

Options of the Deeplab-V3 model and settings for training are defined in `config.py`; parameters for the dataset and network are passed on the command line.

### Options:

```
config.py:
    learning_rate               Learning rate, default is 0.0014.
    weight_decay                Weight decay, default is 5e-5.
    momentum                    Momentum, default is 0.97.
    crop_size                   Image crop size [height, width] during training, default is 513.
    eval_scales                 The scales to resize images for evaluation, default is [0.5, 0.75, 1.0, 1.25, 1.5, 1.75].
    output_stride               The ratio of input to output spatial resolution, default is 16.
    ignore_label                Ignore label value, default is 255.
    seg_num_classes             Number of semantic classes, including the background class if it exists
                                (foreground classes + 1 background class for PASCAL VOC 2012), default is 21.
    fine_tune_batch_norm        Whether to fine-tune the batch norm parameters, default is False.
    atrous_rates                Atrous rates for atrous spatial pyramid pooling, default is None.
    decoder_output_stride       The ratio of input to output spatial resolution when employing a decoder
                                to refine segmentation results, default is None.
    image_pyramid               Input scales for multi-scale feature extraction, default is None.
```

### Parameters:

```
Parameters for dataset and network:
    distribute                  Run distributed training, default is false.
    epoch_size                  Epoch size, default is 6.
    batch_size                  Batch size of the input dataset: N, default is 2.
    data_url                    Train/evaluation data url, required.
    checkpoint_url              Checkpoint path, default is None.
    enable_save_ckpt            Enable checkpoint saving, default is true.
    save_checkpoint_steps       Save checkpoint steps, default is 1000.
    save_checkpoint_num         Number of checkpoints to keep, default is 1.
```
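For reference, these options are consumed directly when the network is constructed. The sketch below mirrors the call in `evaluation.py` later in this change; `batch_size` stands in for the `--batch_size` command-line parameter.

```python
# Sketch: how the config.py options feed the network constructor
# (mirrors the call in evaluation.py).
from src.config import config
from src.deeplabv3 import deeplabv3_resnet50

batch_size = 2  # the --batch_size command-line parameter
net = deeplabv3_resnet50(config.seg_num_classes,
                         [batch_size, 3, config.crop_size, config.crop_size],
                         infer_scale_sizes=config.eval_scales,
                         atrous_rates=config.atrous_rates,
                         decoder_output_stride=config.decoder_output_stride,
                         output_stride=config.output_stride,
                         fine_tune_batch_norm=config.fine_tune_batch_norm,
                         image_pyramid=config.image_pyramid)
```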
@@ -0,0 +1,53 @@ evaluation.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation."""
import argparse
from mindspore import context
from mindspore import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.miou_precision import MiouPrecision
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config


parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
parser.add_argument('--data_url', required=True, default=None, help='Evaluation data url')
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')

args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
print(args_opt)


if __name__ == "__main__":
    args_opt.crop_size = config.crop_size
    args_opt.base_size = config.crop_size
    eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval")
    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
                             infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
                             decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
    param_dict = load_checkpoint(args_opt.checkpoint_url)
    load_param_into_net(net, param_dict)
    mIou = MiouPrecision(config.seg_num_classes)
    metrics = {'mIou': mIou}
    loss = OhemLoss(config.seg_num_classes, config.ignore_label)
    model = Model(net, loss, metrics=metrics)
    model.eval(eval_dataset)
@@ -0,0 +1,66 @@ scripts/run_distribute_train.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: bash run_distribute_train.sh 8 40 /path/VOC2012/ /path/hccl.json"
echo "It is better to use absolute paths."
echo "=============================================================================================================="

EPOCH_SIZE=$2
DATA_DIR=$3

export MINDSPORE_HCCL_CONFIG_PATH=$4
export RANK_TABLE_FILE=$4
export RANK_SIZE=$1
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical cores:" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
    start=`expr $i \* $avg_core_per_rank`
    export DEVICE_ID=$i
    export RANK_ID=$i
    export DEPLOY_MODE=0
    export GE_USE_STATIC_MEMORY=1
    end=`expr $start \+ $core_gap`
    cmdopt=$start"-"$end

    rm -rf LOG$i
    mkdir ./LOG$i
    cp *.py ./LOG$i
    cd ./LOG$i || exit
    echo "start training for rank $i, device $DEVICE_ID"
    mkdir -p ms_log
    CUR_DIR=`pwd`
    export GLOG_log_dir=${CUR_DIR}/ms_log
    export GLOG_logtostderr=0
    env > env.log
    taskset -c $cmdopt python ../train.py \
        --distribute="true" \
        --epoch_size=$EPOCH_SIZE \
        --device_id=$DEVICE_ID \
        --enable_save_ckpt="true" \
        --checkpoint_url="" \
        --save_checkpoint_steps=10000 \
        --save_checkpoint_num=1 \
        --data_url=$DATA_DIR > log.txt 2>&1 &
    cd ../
done
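The loop above derives a disjoint block of CPU cores per rank and pins each training process to it with `taskset`. The same arithmetic in Python, as a worked example (96 logical cores and 8 ranks are illustrative numbers, not values from the script):

```python
# Worked example of run_distribute_train.sh's core-binding arithmetic.
cores = 96         # from `cat /proc/cpuinfo | grep "processor" | wc -l`
rank_size = 8      # RANK_SIZE
avg_core_per_rank = cores // rank_size   # 12
core_gap = avg_core_per_rank - 1         # 11
for i in range(rank_size):
    start = i * avg_core_per_rank
    end = start + core_gap
    print(f"rank {i}: taskset -c {start}-{end}")   # rank 0: 0-11, rank 1: 12-23, ...
```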
@@ -0,0 +1,32 @@ scripts/run_eval.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh DEVICE_ID DATA_DIR"
echo "for example: bash run_eval.sh 0 /path/VOC2012/"
echo "=============================================================================================================="

DEVICE_ID=$1
DATA_DIR=$2

mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
# NOTE: point --checkpoint_url at the checkpoint file you want to evaluate.
python evaluation.py \
    --device_id=$DEVICE_ID \
    --checkpoint_url="" \
    --data_url=$DATA_DIR > log.txt 2>&1 &
@@ -0,0 +1,38 @@ scripts/run_standalone_train.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_DIR"
echo "for example: bash run_standalone_train.sh 0 40 /path/VOC2012/"
echo "=============================================================================================================="

DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3

mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python train.py \
    --distribute="false" \
    --epoch_size=$EPOCH_SIZE \
    --device_id=$DEVICE_ID \
    --enable_save_ckpt="true" \
    --checkpoint_url="" \
    --save_checkpoint_steps=10000 \
    --save_checkpoint_num=1 \
    --data_url=$DATA_DIR > log.txt 2>&1 &
@@ -0,0 +1,23 @@ src/__init__.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init DeepLabv3."""
from . import backbone  # imported as a module so backbone.__all__ is resolvable below
from .deeplabv3 import ASPP, DeepLabV3, deeplabv3_resnet50
from .backbone import *

__all__ = [
    "ASPP", "DeepLabV3", "deeplabv3_resnet50"
]

__all__.extend(backbone.__all__)
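With the exports above (and `__all__` extended with the backbone's names), callers can import the public API from `src` directly; a minimal sketch, assuming the repository root is on the import path as it is when the scripts run from the model directory:

```python
# Sketch: the public interface re-exported by src/__init__.py.
from src import ASPP, DeepLabV3, deeplabv3_resnet50  # defined in src.deeplabv3
from src import resnet50_dl                          # re-exported from src.backbone
```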
@@ -0,0 +1,21 @@ src/backbone/__init__.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init backbone."""
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
    RootBlockBeta, resnet50_dl

__all__ = [
    "Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta", "resnet50_dl"
]
@@ -0,0 +1,577 @@ src/backbone/resnet_deeplab.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet based DeepLab."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore._checkparam import twice
from mindspore.common.parameter import Parameter


def _conv_bn_relu(in_channel,
                  out_channel,
                  ksize,
                  stride=1,
                  padding=0,
                  dilation=1,
                  pad_mode="pad",
                  use_batch_statistics=False):
    """Get a conv2d -> batchnorm -> relu layer."""
    return nn.SequentialCell(
        [nn.Conv2d(in_channel,
                   out_channel,
                   kernel_size=ksize,
                   stride=stride,
                   padding=padding,
                   dilation=dilation,
                   pad_mode=pad_mode),
         nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics),
         nn.ReLU()]
    )


def _deep_conv_bn_relu(in_channel,
                       channel_multiplier,
                       ksize,
                       stride=1,
                       padding=0,
                       dilation=1,
                       pad_mode="pad",
                       use_batch_statistics=False):
    """Get a depthwise conv2d -> batchnorm -> relu layer."""
    return nn.SequentialCell(
        [DepthwiseConv2dNative(in_channel,
                               channel_multiplier,
                               kernel_size=ksize,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               pad_mode=pad_mode),
         nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics),
         nn.ReLU()]
    )


def _stob_deep_conv_btos_bn_relu(in_channel,
                                 channel_multiplier,
                                 ksize,
                                 space_to_batch_block_shape,
                                 batch_to_space_block_shape,
                                 paddings,
                                 crops,
                                 stride=1,
                                 padding=0,
                                 dilation=1,
                                 pad_mode="pad",
                                 use_batch_statistics=False):
    """Get a spacetobatch -> depthwise conv2d -> batchtospace -> batchnorm -> relu layer."""
    return nn.SequentialCell(
        [SpaceToBatch(space_to_batch_block_shape, paddings),
         DepthwiseConv2dNative(in_channel,
                               channel_multiplier,
                               kernel_size=ksize,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               pad_mode=pad_mode),
         BatchToSpace(batch_to_space_block_shape, crops),
         nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics),
         nn.ReLU()]
    )


def _stob_conv_btos_bn_relu(in_channel,
                            out_channel,
                            ksize,
                            space_to_batch_block_shape,
                            batch_to_space_block_shape,
                            paddings,
                            crops,
                            stride=1,
                            padding=0,
                            dilation=1,
                            pad_mode="pad",
                            use_batch_statistics=False):
    """Get a spacetobatch -> conv2d -> batchtospace -> batchnorm -> relu layer."""
    return nn.SequentialCell([SpaceToBatch(space_to_batch_block_shape, paddings),
                              nn.Conv2d(in_channel,
                                        out_channel,
                                        kernel_size=ksize,
                                        stride=stride,
                                        padding=padding,
                                        dilation=dilation,
                                        pad_mode=pad_mode),
                              BatchToSpace(batch_to_space_block_shape, crops),
                              nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics),
                              nn.ReLU()]
                             )


def _make_layer(block,
                in_channels,
                out_channels,
                num_blocks,
                stride=1,
                rate=1,
                multi_grads=None,
                output_stride=None,
                g_current_stride=2,
                g_rate=1):
    """Make layer for DeepLab-ResNet network."""
    if multi_grads is None:
        multi_grads = [1] * num_blocks
    # (stride == 2, num_blocks == 4 --> strides == [1, 1, 1, 2])
    strides = [1] * (num_blocks - 1) + [stride]
    blocks = []
    if output_stride is not None:
        if output_stride % 4 != 0:
            raise ValueError('The output_stride needs to be a multiple of 4.')
        output_stride //= 4
    for i_stride, _ in enumerate(strides):
        if output_stride is not None and g_current_stride > output_stride:
            raise ValueError('The target output_stride cannot be reached.')
        if output_stride is not None and g_current_stride == output_stride:
            b_rate = g_rate
            b_stride = 1
            g_rate *= strides[i_stride]
        else:
            b_rate = rate
            b_stride = strides[i_stride]
            g_current_stride *= strides[i_stride]
        blocks.append(block(in_channels=in_channels,
                            out_channels=out_channels,
                            stride=b_stride,
                            rate=b_rate,
                            multi_grad=multi_grads[i_stride]))
        in_channels = out_channels
    layer = nn.SequentialCell(blocks)
    return layer, g_current_stride, g_rate


class Subsample(nn.Cell):
    """
    Subsample for DeepLab-ResNet.

    Args:
        factor (int): Sample factor.

    Returns:
        Tensor, the subsampled tensor.

    Examples:
        >>> Subsample(2)
    """
    def __init__(self, factor):
        super(Subsample, self).__init__()
        self.factor = factor
        self.pool = nn.MaxPool2d(kernel_size=1,
                                 stride=factor)

    def construct(self, x):
        if self.factor == 1:
            return x
        return self.pool(x)


class SpaceToBatch(nn.Cell):
    """Wrap the SpaceToBatch operator as a cell."""
    def __init__(self, block_shape, paddings):
        super(SpaceToBatch, self).__init__()
        self.space_to_batch = P.SpaceToBatch(block_shape, paddings)
        self.bs = block_shape
        self.pd = paddings

    def construct(self, x):
        return self.space_to_batch(x)


class BatchToSpace(nn.Cell):
    """Wrap the BatchToSpace operator as a cell."""
    def __init__(self, block_shape, crops):
        super(BatchToSpace, self).__init__()
        self.batch_to_space = P.BatchToSpace(block_shape, crops)
        self.bs = block_shape
        self.cr = crops

    def construct(self, x):
        return self.batch_to_space(x)


class _DepthwiseConv2dNative(nn.Cell):
    """Depthwise Conv2D Cell."""
    def __init__(self,
                 in_channels,
                 channel_multiplier,
                 kernel_size,
                 stride,
                 pad_mode,
                 padding,
                 dilation,
                 group,
                 weight_init):
        super(_DepthwiseConv2dNative, self).__init__()
        self.in_channels = in_channels
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = dilation
        self.group = group
        if not (isinstance(in_channels, int) and in_channels > 0):
            raise ValueError('Attr \'in_channels\' of \'DepthwiseConv2D\' Op passed '
                             + str(in_channels) + ', should be an int and greater than 0.')
        if (not isinstance(kernel_size, tuple)) or len(kernel_size) != 2 or \
                (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
                kernel_size[0] < 1 or kernel_size[1] < 1:
            raise ValueError('Attr \'kernel_size\' of \'DepthwiseConv2D\' Op passed '
                             + str(self.kernel_size) + ', should be an int or a tuple and equal to or greater than 1.')
        self.weight = Parameter(initializer(weight_init, [1, in_channels // group, *kernel_size]),
                                name='weight')

    def construct(self, *inputs):
        """Must be overridden by all subclasses."""
        raise NotImplementedError


class DepthwiseConv2dNative(_DepthwiseConv2dNative):
    """Depthwise Conv2D Cell."""
    def __init__(self,
                 in_channels,
                 channel_multiplier,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 weight_init='normal'):
        kernel_size = twice(kernel_size)
        super(DepthwiseConv2dNative, self).__init__(
            in_channels,
            channel_multiplier,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            weight_init)
        self.depthwise_conv2d_native = P.DepthwiseConv2dNative(channel_multiplier=self.channel_multiplier,
                                                               kernel_size=self.kernel_size,
                                                               mode=3,
                                                               pad_mode=self.pad_mode,
                                                               pad=self.padding,
                                                               stride=self.stride,
                                                               dilation=self.dilation,
                                                               group=self.group)

    def set_strategy(self, strategy):
        self.depthwise_conv2d_native.set_strategy(strategy)
        return self

    def construct(self, x):
        return self.depthwise_conv2d_native(x, self.weight)


class BottleneckV1(nn.Cell):
    """
    ResNet V1 BottleneckV1 block definition.

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        stride (int): Stride size for the initial convolutional layer. Default: 1.
        rate (int): Rate for convolution. Default: 1.
        multi_grad (int): Employ a rate within network. Default: 1.

    Returns:
        Tensor, the ResNet unit's output.

    Examples:
        >>> BottleneckV1(3, 256, stride=2)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 use_batch_statistics=False,
                 use_batch_to_stob_and_btos=False):
        super(BottleneckV1, self).__init__()
        expansion = 4
        mid_channels = out_channels // expansion
        self.conv_bn1 = _conv_bn_relu(in_channels,
                                      mid_channels,
                                      ksize=1,
                                      stride=1,
                                      use_batch_statistics=use_batch_statistics)
        self.conv_bn2 = _conv_bn_relu(mid_channels,
                                      mid_channels,
                                      ksize=3,
                                      stride=stride,
                                      padding=1,
                                      dilation=1,
                                      use_batch_statistics=use_batch_statistics)
        if use_batch_to_stob_and_btos:
            self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels,
                                                    mid_channels,
                                                    ksize=3,
                                                    stride=stride,
                                                    padding=0,
                                                    dilation=1,
                                                    space_to_batch_block_shape=2,
                                                    batch_to_space_block_shape=2,
                                                    paddings=[[2, 3], [2, 3]],
                                                    crops=[[0, 1], [0, 1]],
                                                    pad_mode="valid",
                                                    use_batch_statistics=use_batch_statistics)

        self.conv3 = nn.Conv2d(mid_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
        if in_channels != out_channels:
            conv = nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=1,
                             stride=stride)
            bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
            self.downsample = nn.SequentialCell([conv, bn])
        else:
            self.downsample = Subsample(stride)
        self.add = P.TensorAdd()
        self.relu = nn.ReLU()

    def construct(self, x):
        out = self.conv_bn1(x)
        out = self.conv_bn2(out)
        out = self.bn3(self.conv3(out))
        out = self.add(out, self.downsample(x))
        out = self.relu(out)
        return out


class BottleneckV2(nn.Cell):
    """
    ResNet V1 bottleneck variant V2 block definition (identity shortcut).

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        stride (int): Stride size for the initial convolutional layer. Default: 1.

    Returns:
        Tensor, the ResNet unit's output.

    Examples:
        >>> BottleneckV2(3, 256, stride=2)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 use_batch_statistics=False,
                 use_batch_to_stob_and_btos=False,
                 dilation=1):
        super(BottleneckV2, self).__init__()
        expansion = 4
        mid_channels = out_channels // expansion
        self.conv_bn1 = _conv_bn_relu(in_channels,
                                      mid_channels,
                                      ksize=1,
                                      stride=1,
                                      use_batch_statistics=use_batch_statistics)
        self.conv_bn2 = _conv_bn_relu(mid_channels,
                                      mid_channels,
                                      ksize=3,
                                      stride=stride,
                                      padding=1,
                                      dilation=dilation,
                                      use_batch_statistics=use_batch_statistics)
        if use_batch_to_stob_and_btos:
            self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels,
                                                    mid_channels,
                                                    ksize=3,
                                                    stride=stride,
                                                    padding=0,
                                                    dilation=1,
                                                    space_to_batch_block_shape=2,
                                                    batch_to_space_block_shape=2,
                                                    paddings=[[2, 3], [2, 3]],
                                                    crops=[[0, 1], [0, 1]],
                                                    pad_mode="valid",
                                                    use_batch_statistics=use_batch_statistics)
        self.conv3 = nn.Conv2d(mid_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
        # NOTE: construct() uses an identity shortcut; self.downsample is kept for
        # structural parity with BottleneckV1/V3 but is not applied in construct().
        if in_channels != out_channels:
            conv = nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=1,
                             stride=stride)
            bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
            self.downsample = nn.SequentialCell([conv, bn])
        else:
            self.downsample = Subsample(stride)
        self.add = P.TensorAdd()
        self.relu = nn.ReLU()

    def construct(self, x):
        out = self.conv_bn1(x)
        out = self.conv_bn2(out)
        out = self.bn3(self.conv3(out))
        out = self.add(out, x)
        out = self.relu(out)
        return out


class BottleneckV3(nn.Cell):
    """
    ResNet V1 bottleneck variant V3 block definition.

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        stride (int): Stride size for the initial convolutional layer. Default: 1.

    Returns:
        Tensor, the ResNet unit's output.

    Examples:
        >>> BottleneckV3(3, 256, stride=2)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 use_batch_statistics=False):
        super(BottleneckV3, self).__init__()
        expansion = 4
        mid_channels = out_channels // expansion
        self.conv_bn1 = _conv_bn_relu(in_channels,
                                      mid_channels,
                                      ksize=1,
                                      stride=1,
                                      use_batch_statistics=use_batch_statistics)
        self.conv_bn2 = _conv_bn_relu(mid_channels,
                                      mid_channels,
                                      ksize=3,
                                      stride=stride,
                                      padding=1,
                                      dilation=1,
                                      use_batch_statistics=use_batch_statistics)
        self.conv3 = nn.Conv2d(mid_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)

        if in_channels != out_channels:
            conv = nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=1,
                             stride=stride)
            bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
            self.downsample = nn.SequentialCell([conv, bn])
        else:
            self.downsample = Subsample(stride)
        self.add = P.TensorAdd()
        self.relu = nn.ReLU()

    def construct(self, x):
        out = self.conv_bn1(x)
        out = self.conv_bn2(out)
        out = self.bn3(self.conv3(out))
        out = self.add(out, self.downsample(x))
        out = self.relu(out)
        return out


class ResNetV1(nn.Cell):
    """
    ResNet V1 for DeepLab.

    Returns:
        Tuple, output tensor tuple, (c2, c5).

    Examples:
        >>> ResNetV1(False)
    """
    def __init__(self, fine_tune_batch_norm=False):
        super(ResNetV1, self).__init__()
        self.layer_root = nn.SequentialCell(
            [RootBlockBeta(fine_tune_batch_norm),
             nn.MaxPool2d(kernel_size=(3, 3),
                          stride=(2, 2),
                          pad_mode='same')])
        self.layer1_1 = BottleneckV1(128, 256, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer1_2 = BottleneckV2(256, 256, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer1_3 = BottleneckV3(256, 256, stride=2, use_batch_statistics=fine_tune_batch_norm)
        self.layer2_1 = BottleneckV1(256, 512, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer2_2 = BottleneckV2(512, 512, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer2_3 = BottleneckV2(512, 512, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer2_4 = BottleneckV3(512, 512, stride=2, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_1 = BottleneckV1(512, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_2 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_3 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_4 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_5 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_6 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)

        self.layer4_1 = BottleneckV1(1024, 2048, stride=1, use_batch_to_stob_and_btos=True,
                                     use_batch_statistics=fine_tune_batch_norm)
        self.layer4_2 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True,
                                     use_batch_statistics=fine_tune_batch_norm)
        self.layer4_3 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True,
                                     use_batch_statistics=fine_tune_batch_norm)

    def construct(self, x):
        x = self.layer_root(x)
        x = self.layer1_1(x)
        c2 = self.layer1_2(x)
        x = self.layer1_3(c2)
        x = self.layer2_1(x)
        x = self.layer2_2(x)
        x = self.layer2_3(x)
        x = self.layer2_4(x)
        x = self.layer3_1(x)
        x = self.layer3_2(x)
        x = self.layer3_3(x)
        x = self.layer3_4(x)
        x = self.layer3_5(x)
        x = self.layer3_6(x)

        x = self.layer4_1(x)
        x = self.layer4_2(x)
        c5 = self.layer4_3(x)
        return c2, c5


class RootBlockBeta(nn.Cell):
    """
    ResNet V1 beta root block definition.

    Returns:
        Tensor, the block unit's output.

    Examples:
        >>> RootBlockBeta()
    """
    def __init__(self, fine_tune_batch_norm=False):
        super(RootBlockBeta, self).__init__()
        self.conv1 = _conv_bn_relu(3, 64, ksize=3, stride=2, padding=0, pad_mode="valid",
                                   use_batch_statistics=fine_tune_batch_norm)
        self.conv2 = _conv_bn_relu(64, 64, ksize=3, stride=1, padding=0, pad_mode="same",
                                   use_batch_statistics=fine_tune_batch_norm)
        self.conv3 = _conv_bn_relu(64, 128, ksize=3, stride=1, padding=0, pad_mode="same",
                                   use_batch_statistics=fine_tune_batch_norm)

    def construct(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x


def resnet50_dl(fine_tune_batch_norm=False):
    return ResNetV1(fine_tune_batch_norm)
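The backbone exposes two taps: c2 (the low-level features later consumed by the decoder) and c5 (the final features consumed by ASPP). A minimal smoke-test sketch, assuming a configured MindSpore device; 515 is the 513 crop_size from config.py plus the 1-pixel padding SingleDeepLabV3 applies on each side:

```python
# Smoke-test sketch for the ResNet backbone (assumes MindSpore is set up).
import numpy as np
from mindspore import Tensor
from src.backbone.resnet_deeplab import resnet50_dl

net = resnet50_dl(fine_tune_batch_norm=False)
x = Tensor(np.random.uniform(size=(1, 3, 515, 515)).astype(np.float32))
c2, c5 = net(x)       # low-level and final feature maps
print(c2.shape, c5.shape)
```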
@@ -0,0 +1,33 @@ src/config.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and evaluation.py
"""
from easydict import EasyDict as ed

config = ed({
    "learning_rate": 0.0014,
    "weight_decay": 0.00005,
    "momentum": 0.97,
    "crop_size": 513,
    "eval_scales": [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
    "atrous_rates": None,
    "image_pyramid": None,
    "output_stride": 16,
    "fine_tune_batch_norm": False,
    "ignore_label": 255,
    "decoder_output_stride": None,
    "seg_num_classes": 21
})
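Note that the defaults above disable the ASPP atrous branches and the decoder (atrous_rates and decoder_output_stride are None). Because config is an EasyDict, entries can be overridden in code before the network is built; the values below are the settings commonly used for DeepLabv3 at output stride 16, shown as an illustration rather than repository defaults:

```python
# Sketch: overriding config entries in code before building the network.
from src.config import config

config.atrous_rates = [6, 12, 18]   # common ASPP rates at output_stride 16 (assumption, not a repo default)
config.decoder_output_stride = 4    # enables the decoder refinement branch
config.fine_tune_batch_norm = True  # update batch-norm statistics during fine-tuning
```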
@@ -0,0 +1,457 @@ src/deeplabv3.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeepLabv3."""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
    DepthwiseConv2dNative, SpaceToBatch, BatchToSpace


class ASPPSampleBlock(nn.Cell):
    """ASPP sample block."""
    def __init__(self, feature_shape, scale_size, output_stride):
        super(ASPPSampleBlock, self).__init__()
        sample_h = (feature_shape[0] * scale_size + 1) / output_stride + 1
        sample_w = (feature_shape[1] * scale_size + 1) / output_stride + 1
        self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)

    def construct(self, x):
        return self.sample(x)


class ASPP(nn.Cell):
    """
    ASPP model for DeepLabv3.

    Args:
        channel (int): Input channel.
        depth (int): Output channel.
        feature_shape (list): The shape of the feature map, [h, w].
        scale_sizes (list): Input scales for multi-scale feature extraction.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Whether to fine-tune the batch norm parameters.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ASPP(2048, 256, [14, 14], [1.0], [6], 16)
    """
    def __init__(self, channel, depth, feature_shape, scale_sizes,
                 atrous_rates, output_stride, fine_tune_batch_norm=False):
        super(ASPP, self).__init__()
        self.aspp0 = _conv_bn_relu(channel,
                                   depth,
                                   ksize=1,
                                   stride=1,
                                   use_batch_statistics=fine_tune_batch_norm)
        self.atrous_rates = []
        if atrous_rates is not None:
            self.atrous_rates = atrous_rates
        self.aspp_pointwise = _conv_bn_relu(channel,
                                            depth,
                                            ksize=1,
                                            stride=1,
                                            use_batch_statistics=fine_tune_batch_norm)
        self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel,
                                                              channel_multiplier=1,
                                                              kernel_size=3,
                                                              stride=1,
                                                              dilation=1,
                                                              pad_mode="valid")
        self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm)
        self.aspp_depth_relu = nn.ReLU()
        self.aspp_depths = []
        self.aspp_depth_spacetobatchs = []
        self.aspp_depth_batchtospaces = []

        for scale_size in scale_sizes:
            # NOTE: 16 here matches the default output_stride.
            aspp_scale_depth_size = np.ceil((feature_shape[0] * scale_size) / 16)
            if atrous_rates is None:
                break
            for rate in atrous_rates:
                padding = 0
                for j in range(100):
                    padded_size = rate * j
                    if padded_size >= aspp_scale_depth_size + 2 * rate:
                        padding = padded_size - aspp_scale_depth_size - 2 * rate
                        break
                paddings = [[rate, rate + int(padding)],
                            [rate, rate + int(padding)]]
                self.aspp_depth_spacetobatch = SpaceToBatch(rate, paddings)
                self.aspp_depth_spacetobatchs.append(self.aspp_depth_spacetobatch)
                crops = [[0, int(padding)], [0, int(padding)]]
                self.aspp_depth_batchtospace = BatchToSpace(rate, crops)
                self.aspp_depth_batchtospaces.append(self.aspp_depth_batchtospace)
        self.aspp_depths = nn.CellList(self.aspp_depths)
        self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs)
        self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces)

        self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1])))
        self.global_poolings = []
        for scale_size in scale_sizes:
            pooling_h = np.ceil((feature_shape[0] * scale_size) / output_stride)
            pooling_w = np.ceil((feature_shape[1] * scale_size) / output_stride)
            self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w))))
        self.global_poolings = nn.CellList(self.global_poolings)
        self.conv_bn = _conv_bn_relu(channel,
                                     depth,
                                     ksize=1,
                                     stride=1,
                                     use_batch_statistics=fine_tune_batch_norm)
        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(ASPPSampleBlock(feature_shape, scale_size, output_stride))
        self.samples = nn.CellList(self.samples)
        self.feature_shape = feature_shape
        self.concat = P.Concat(axis=1)

    def construct(self, x, scale_index=0):
        aspp0 = self.aspp0(x)
        aspp1 = self.global_poolings[scale_index](x)
        aspp1 = self.conv_bn(aspp1)
        aspp1 = self.samples[scale_index](aspp1)
        output = self.concat((aspp1, aspp0))

        for i in range(len(self.atrous_rates)):
            aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x)
            aspp_i = self.aspp_depth_depthwiseconv(aspp_i)
            aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i)
            aspp_i = self.aspp_depth_bn(aspp_i)
            aspp_i = self.aspp_depth_relu(aspp_i)
            aspp_i = self.aspp_pointwise(aspp_i)
            output = self.concat((output, aspp_i))
        return output


class DecoderSampleBlock(nn.Cell):
    """Decoder sample block."""
    def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4):
        super(DecoderSampleBlock, self).__init__()
        sample_h = (feature_shape[0] * scale_size + 1) / decoder_output_stride + 1
        sample_w = (feature_shape[1] * scale_size + 1) / decoder_output_stride + 1
        self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)

    def construct(self, x):
        return self.sample(x)


class Decoder(nn.Cell):
    """
    Decoder module for DeepLabv3.

    Args:
        low_level_channel (int): Low-level input channel.
        channel (int): Input channel.
        depth (int): Output channel.
        feature_shape (list): The shape of the input feature map, [H, W].
        scale_sizes (list): Input scales for multi-scale feature extraction.
        decoder_output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Whether to fine-tune the batch norm parameters.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> Decoder(256, 256, 256, [56, 56], [1.0], 4, False)
    """
    def __init__(self,
                 low_level_channel,
                 channel,
                 depth,
                 feature_shape,
                 scale_sizes,
                 decoder_output_stride,
                 fine_tune_batch_norm):
        super(Decoder, self).__init__()
        self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1,
                                                pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
        self.decoder_depth0 = _deep_conv_bn_relu(channel + 48,
                                                 channel_multiplier=1,
                                                 ksize=3,
                                                 stride=1,
                                                 pad_mode="same",
                                                 dilation=1,
                                                 use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise0 = _conv_bn_relu(channel + 48,
                                                depth,
                                                ksize=1,
                                                stride=1,
                                                use_batch_statistics=fine_tune_batch_norm)
        self.decoder_depth1 = _deep_conv_bn_relu(depth,
                                                 channel_multiplier=1,
                                                 ksize=3,
                                                 stride=1,
                                                 pad_mode="same",
                                                 dilation=1,
                                                 use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise1 = _conv_bn_relu(depth,
                                                depth,
                                                ksize=1,
                                                stride=1,
                                                use_batch_statistics=fine_tune_batch_norm)
        self.depth = depth
        self.concat = P.Concat(axis=1)
        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride))
        self.samples = nn.CellList(self.samples)

    def construct(self, x, low_level_feature, scale_index):
        low_level_feature = self.feature_projection(low_level_feature)
        low_level_feature = self.samples[scale_index](low_level_feature)
        x = self.samples[scale_index](x)
        output = self.concat((x, low_level_feature))
        output = self.decoder_depth0(output)
        output = self.decoder_pointwise0(output)
        output = self.decoder_depth1(output)
        output = self.decoder_pointwise1(output)
        return output


class SingleDeepLabV3(nn.Cell):
    """
    DeepLabv3 Network.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N, C, H, W].
        backbone (Cell): Backbone network.
        channel (int): Resnet output channel.
        depth (int): ASPP block depth.
        scale_sizes (list): Input scales for multi-scale feature extraction.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        decoder_output_stride (int): The ratio of input to output spatial resolution for the decoder.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Whether to fine-tune the batch norm parameters.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> SingleDeepLabV3(num_classes=10,
        ...                 feature_shape=[1, 3, 224, 224],
        ...                 backbone=resnet50_dl(),
        ...                 channel=2048,
        ...                 depth=256,
        ...                 scale_sizes=[1.0],
        ...                 atrous_rates=[6],
        ...                 decoder_output_stride=4,
        ...                 output_stride=16)
    """

    def __init__(self,
                 num_classes,
                 feature_shape,
                 backbone,
                 channel,
                 depth,
                 scale_sizes,
                 atrous_rates,
                 decoder_output_stride,
                 output_stride,
                 fine_tune_batch_norm=False):
        super(SingleDeepLabV3, self).__init__()
        self.num_classes = num_classes
        self.channel = channel
        self.depth = depth
        self.scale_sizes = []
        for scale_size in np.sort(scale_sizes):
            self.scale_sizes.append(scale_size)
        self.net = backbone
        self.aspp = ASPP(channel=self.channel,
                         depth=self.depth,
                         feature_shape=[feature_shape[2],
                                        feature_shape[3]],
                         scale_sizes=self.scale_sizes,
                         atrous_rates=atrous_rates,
                         output_stride=output_stride,
                         fine_tune_batch_norm=fine_tune_batch_norm)
        self.aspp.add_flags(loop_can_unroll=True)
        atrous_rates_len = 0
        if atrous_rates is not None:
            atrous_rates_len = len(atrous_rates)
        self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth,
                                 ksize=1,
                                 stride=1,
                                 use_batch_statistics=fine_tune_batch_norm)
        self.fc2 = nn.Conv2d(depth,
                             num_classes,
                             kernel_size=1,
                             stride=1,
                             has_bias=True)
        self.upsample = P.ResizeBilinear((int(feature_shape[2]),
                                          int(feature_shape[3])),
                                         align_corners=True)
        self.samples = []
        for scale_size in self.scale_sizes:
            self.samples.append(SampleBlock(feature_shape, scale_size))
        self.samples = nn.CellList(self.samples)
        self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]), float(feature_shape[2]),
                              float(feature_shape[3])]

        self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
        self.dropout = nn.Dropout(keep_prob=0.9)
        self.shape = P.Shape()
        self.decoder_output_stride = decoder_output_stride
        if decoder_output_stride is not None:
            self.decoder = Decoder(low_level_channel=depth,
                                   channel=depth,
                                   depth=depth,
                                   feature_shape=[feature_shape[2],
                                                  feature_shape[3]],
                                   scale_sizes=self.scale_sizes,
                                   decoder_output_stride=decoder_output_stride,
                                   fine_tune_batch_norm=fine_tune_batch_norm)

    def construct(self, x, scale_index=0):
        x = (2.0 / 255.0) * x - 1.0
        x = self.pad(x)
        low_level_feature, feature_map = self.net(x)
        for scale_size in self.scale_sizes:
            if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2:
                output = self.aspp(feature_map, scale_index)
                output = self.fc1(output)
                if self.decoder_output_stride is not None:
                    output = self.decoder(output, low_level_feature, scale_index)
                output = self.fc2(output)
                output = self.samples[scale_index](output)
                return output
            scale_index += 1
        return feature_map


class SampleBlock(nn.Cell):
    """Sample block."""
    def __init__(self,
                 feature_shape,
                 scale_size=1.0):
        super(SampleBlock, self).__init__()
        sample_h = np.ceil(float(feature_shape[2]) * scale_size)
        sample_w = np.ceil(float(feature_shape[3]) * scale_size)
        self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)

    def construct(self, x):
        return self.sample(x)


class DeepLabV3(nn.Cell):
    """DeepLabV3 model."""
    def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates,
                 decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid):
        super(DeepLabV3, self).__init__()
        self.infer_scale_sizes = []
        if infer_scale_sizes is not None:
            self.infer_scale_sizes = infer_scale_sizes

        if image_pyramid is None:
            image_pyramid = [1.0]

        self.image_pyramid = image_pyramid
        scale_sizes = []
        for pyramid in image_pyramid:
            scale_sizes.append(pyramid)
        for scale in self.infer_scale_sizes:
            scale_sizes.append(scale)
        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(SampleBlock(feature_shape, scale_size))
        self.samples = nn.CellList(self.samples)
        self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes,
                                         feature_shape=feature_shape,
                                         backbone=resnet50_dl(fine_tune_batch_norm),
                                         channel=channel,
                                         depth=depth,
                                         scale_sizes=scale_sizes,
                                         atrous_rates=atrous_rates,
                                         decoder_output_stride=decoder_output_stride,
                                         output_stride=output_stride,
                                         fine_tune_batch_norm=fine_tune_batch_norm)
        self.softmax = P.Softmax(axis=1)
        self.concat = P.Concat(axis=2)
        self.expand_dims = P.ExpandDims()
        self.reduce_mean = P.ReduceMean()
        self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
                                               int(feature_shape[3])),
                                              align_corners=True)

    def construct(self, x):
        logits = ()
        if self.training:
            if len(self.image_pyramid) >= 1:
                if self.image_pyramid[0] == 1:
                    logits = self.deeplabv3(x)
                else:
                    x1 = self.samples[0](x)
                    logits = self.deeplabv3(x1)
                    logits = self.sample_common(logits)
                logits = self.expand_dims(logits, 2)
                for i in range(len(self.image_pyramid) - 1):
                    x_i = self.samples[i + 1](x)
                    logits_i = self.deeplabv3(x_i)
                    logits_i = self.sample_common(logits_i)
                    logits_i = self.expand_dims(logits_i, 2)
                    logits = self.concat((logits, logits_i))
            logits = self.reduce_mean(logits, 2)
            return logits
        if len(self.infer_scale_sizes) >= 1:
            infer_index = len(self.image_pyramid)
            x1 = self.samples[infer_index](x)
            logits = self.deeplabv3(x1)
            logits = self.sample_common(logits)
            logits = self.softmax(logits)
            logits = self.expand_dims(logits, 2)
            for i in range(len(self.infer_scale_sizes) - 1):
                x_i = self.samples[i + 1 + infer_index](x)
                logits_i = self.deeplabv3(x_i)
                logits_i = self.sample_common(logits_i)
                logits_i = self.softmax(logits_i)
                logits_i = self.expand_dims(logits_i, 2)
                logits = self.concat((logits, logits_i))
            logits = self.reduce_mean(logits, 2)
        return logits


def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid,
                       infer_scale_sizes, atrous_rates=None, decoder_output_stride=None,
                       output_stride=16, fine_tune_batch_norm=False):
    """
    ResNet50 based DeepLabv3 network.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N, C, H, W].
        image_pyramid (list): Input scales for multi-scale feature extraction.
        infer_scale_sizes (list): The scales to resize images for inference.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        decoder_output_stride (int): The ratio of input to output spatial resolution for the decoder.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Whether to fine-tune the batch norm parameters.

    Returns:
        Cell, cell instance of ResNet50 based DeepLabv3 neural network.

    Examples:
        >>> deeplabv3_resnet50(100, [1, 3, 224, 224], [1.0], [1.0])
    """
    return DeepLabV3(num_classes=num_classes,
                     feature_shape=feature_shape,
                     backbone=resnet50_dl(fine_tune_batch_norm),
                     channel=2048,
                     depth=256,
                     infer_scale_sizes=infer_scale_sizes,
                     atrous_rates=atrous_rates,
                     decoder_output_stride=decoder_output_stride,
                     output_stride=output_stride,
                     fine_tune_batch_norm=fine_tune_batch_norm,
                     image_pyramid=image_pyramid)
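DeepLabV3.construct branches on self.training: during training it averages raw logits over the image_pyramid scales, while in inference it averages softmax probabilities over infer_scale_sizes. A minimal sketch of driving both paths (shapes follow the config.py defaults):

```python
# Sketch: exercising both construct() branches of DeepLabV3.
import numpy as np
from mindspore import Tensor
from src.config import config
from src.deeplabv3 import deeplabv3_resnet50

net = deeplabv3_resnet50(config.seg_num_classes,
                         [2, 3, config.crop_size, config.crop_size],
                         image_pyramid=config.image_pyramid,
                         infer_scale_sizes=config.eval_scales)
x = Tensor(np.random.uniform(size=(2, 3, config.crop_size, config.crop_size)).astype(np.float32))

net.set_train(True)    # training branch: logits averaged over image_pyramid
train_logits = net(x)

net.set_train(False)   # inference branch: softmax averaged over eval_scales
probs = net(x)
```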
@@ -0,0 +1,84 @@ src/ei_dataset.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Process Dataset."""
import abc
import os
import time

from .utils.adapter import get_raw_samples, read_image


class BaseDataset:
    """
    Create dataset.

    Args:
        data_url (str): The path of data.
        usage (str): Which split to use, "train" or "eval" (default="train").

    Returns:
        Dataset.
    """
    def __init__(self, data_url, usage):
        self.data_url = data_url
        self.usage = usage
        self.cur_index = 0
        self.samples = []
        _s_time = time.time()
        self._load_samples()
        _e_time = time.time()
        print(f"load samples succeeded, time cost = {_e_time - _s_time}")

    def __getitem__(self, item):
        sample = self.samples[item]
        return self._next_data(sample)

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _next_data(sample):
        image_path = sample[0]
        mask_image_path = sample[1]

        image = read_image(image_path)
        mask_image = read_image(mask_image_path)
        return [image, mask_image]

    @abc.abstractmethod
    def _load_samples(self):
        pass


class HwVocRawDataset(BaseDataset):
    """
    Create dataset with raw data.

    Args:
        data_url (str): The path of data.
        usage (str): Which split to use, "train" or "eval" (default="train").

    Returns:
        Dataset.
    """
    def __init__(self, data_url, usage="train"):
        super().__init__(data_url, usage)

    def _load_samples(self):
        try:
            self.samples = get_raw_samples(os.path.join(self.data_url, self.usage))
        except Exception as e:
            print("load HwVocRawDataset failed!!!")
            raise e
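BaseDataset leaves _load_samples abstract: a subclass only needs to fill self.samples with [image_path, mask_path] pairs, which _next_data then reads with read_image. A hypothetical subclass for a flat directory layout, purely as an illustration:

```python
# Hypothetical BaseDataset subclass for a flat folder of JPEG images with
# same-named PNG masks (illustration only; not part of this change).
import glob
import os

class FolderDataset(BaseDataset):
    def _load_samples(self):
        image_dir = os.path.join(self.data_url, self.usage)
        for image_path in sorted(glob.glob(os.path.join(image_dir, "*.jpg"))):
            mask_path = image_path[:-len(".jpg")] + ".png"
            if os.path.exists(mask_path):
                self.samples.append([image_path, mask_path])
```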
@@ -0,0 +1,63 @@ src/losses.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OhemLoss."""
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F


class OhemLoss(nn.Cell):
    """Ohem loss cell."""
    def __init__(self, num, ignore_label):
        super(OhemLoss, self).__init__()
        self.mul = P.Mul()
        self.shape = P.Shape()
        self.one_hot = nn.OneHot(-1, num, 1.0, 0.0)
        self.squeeze = P.Squeeze()
        self.num = num
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.not_equal = P.NotEqual()
        self.equal = P.Equal()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.fill = P.Fill()
        self.transpose = P.Transpose()
        self.ignore_label = ignore_label
        self.loss_weight = 1.0

    def construct(self, logits, labels):
        # Flatten (N, C, H, W) logits to (N*H*W, C) and labels to (N*H*W,).
        logits = self.transpose(logits, (0, 2, 3, 1))
        logits = self.reshape(logits, (-1, self.num))
        labels = F.cast(labels, mstype.int32)
        labels = self.reshape(labels, (-1,))
        one_hot_labels = self.one_hot(labels)
        losses = self.cross_entropy(logits, one_hot_labels)[0]
        # Zero out the loss of pixels marked with the ignore label.
        weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight
        weighted_losses = self.mul(losses, weights)
        loss = self.reduce_sum(weighted_losses, (0,))
        # Count the pixels that actually contribute to the loss.
        zeros = self.fill(mstype.float32, self.shape(weights), 0.0)
        ones = self.fill(mstype.float32, self.shape(weights), 1.0)
        present = self.select(self.equal(weights, zeros), zeros, ones)
        present = self.reduce_sum(present, (0,))

        # Guard against division by zero when every pixel is ignored.
        zeros = self.fill(mstype.float32, self.shape(present), 0.0)
        min_control = self.fill(mstype.float32, self.shape(present), 1.0)
        present = self.select(self.equal(present, zeros), min_control, present)
        loss = loss / present
        return loss
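
# A minimal sketch of how this cell is wired up; shapes are assumptions that
# follow the transpose/reshape in construct above:
#
#     loss_fn = OhemLoss(num=21, ignore_label=255)
#     # logits: (N, 21, H, W) network output; labels: (N, H, W) class indices
#     loss = loss_fn(logits, labels)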
@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C

from .ei_dataset import HwVocRawDataset
from .utils import custom_transforms as tr


class DataTransform:
    """Transform dataset for DeepLabV3."""

    def __init__(self, args, usage):
        self.args = args
        self.usage = usage

    def __call__(self, image, label):
        if self.usage == "train":
            return self._train(image, label)
        if self.usage == "eval":
            return self._eval(image, label)
        return None

    def _train(self, image, label):
        """
        Process training data.

        Args:
            image (list): Image data.
            label (list): Dataset label.
        """
        image = Image.fromarray(image)
        label = Image.fromarray(label)

        rsc_tr = tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size)
        image, label = rsc_tr(image, label)

        rhf_tr = tr.RandomHorizontalFlip()
        image, label = rhf_tr(image, label)

        nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        image, label = nor_tr(image, label)

        return image, label

    def _eval(self, image, label):
        """
        Process eval data.

        Args:
            image (list): Image data.
            label (list): Dataset label.
        """
        image = Image.fromarray(image)
        label = Image.fromarray(label)

        fsc_tr = tr.FixScaleCrop(crop_size=self.args.crop_size)
        image, label = fsc_tr(image, label)

        nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        image, label = nor_tr(image, label)

        return image, label


def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
    """
    Create Dataset for DeepLabV3.

    Args:
        args (dict): Train parameters.
        data_url (str): Dataset path.
        epoch_num (int): Epoch of dataset (default=1).
        batch_size (int): Batch size of dataset (default=1).
        usage (str): Whether the dataset is used for train or eval (default='train').

    Returns:
        Dataset.
    """
    # Create the iterable raw dataset.
    dataset = HwVocRawDataset(data_url, usage=usage)
    dataset_len = len(dataset)

    # Wrap it with GeneratorDataset.
    dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
    dataset.set_dataset_size(dataset_len)
    dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))

    channelswap_op = C.HWC2CHW()
    dataset = dataset.map(input_columns="image", operations=channelswap_op)

    # 1464 samples / batch_size 8 = 183 batches per epoch;
    # epoch_num is the number of repeats, e.g. 3658 steps / 183 ≈ 20 epochs.
    if usage == "train":
        dataset = dataset.shuffle(1464)
    dataset = dataset.batch(batch_size, drop_remainder=(usage == "train"))
    dataset = dataset.repeat(count=epoch_num)
    dataset.map_model = 4

    return dataset
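
# A hedged usage sketch: args only needs base_size and crop_size attributes,
# which train.py fills in from config before calling create_dataset.
#
#     args_opt.base_size = config.crop_size
#     args_opt.crop_size = config.crop_size
#     ds = create_dataset(args_opt, "/path/to/voc2012", epoch_num=6, batch_size=2, usage="train")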
@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mIoU."""
import numpy as np
from mindspore.nn.metrics.metric import Metric


def confuse_matrix(target, pred, n):
    # Keep only pixels whose target label is a valid class in [0, n).
    k = (target >= 0) & (target < n)
    # Encode (target, pred) pairs as n * target + pred and histogram them
    # into an n x n confusion matrix.
    return np.bincount(n * target[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)


def iou(hist):
    # Per-class IoU = TP / (TP + FP + FN), averaged over classes that appear.
    denominator = hist.sum(1) + hist.sum(0) - np.diag(hist)
    res = np.diag(hist) / np.where(denominator > 0, denominator, 1)
    res = np.sum(res) / np.count_nonzero(denominator)
    return res


class MiouPrecision(Metric):
    """Calculate mIoU precision."""
    def __init__(self, num_class=21):
        super(MiouPrecision, self).__init__()
        if not isinstance(num_class, int):
            raise TypeError('num_class should be integer type, but got {}'.format(type(num_class)))
        if num_class < 1:
            raise ValueError('num_class must be at least 1, but got {}'.format(num_class))
        self._num_class = num_class
        self._mIoU = []
        self.clear()

    def clear(self):
        self._hist = np.zeros((self._num_class, self._num_class))
        self._mIoU = []

    def update(self, *inputs):
        if len(inputs) != 2:
            raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
        predict_in = self._convert_data(inputs[0])
        label_in = self._convert_data(inputs[1])
        if predict_in.shape[1] != self._num_class:
            raise ValueError('Class number does not match: expected {} classes, but the prediction '
                             'contains {} classes'.format(self._num_class, predict_in.shape[1]))
        pred = np.argmax(predict_in, axis=1)
        label = label_in
        if len(label.flatten()) != len(pred.flatten()):
            raise ValueError('Shape mismatch: len(gt) = {:d}, but len(pred) = {:d}'
                             .format(len(label.flatten()), len(pred.flatten())))
        self._hist = confuse_matrix(label.flatten(), pred.flatten(), self._num_class)
        mIoUs = iou(self._hist)
        self._mIoU.append(mIoUs)

    def eval(self):
        """
        Compute the mIoU categorical accuracy.
        """
        mIoU = np.nanmean(self._mIoU)
        print('mIoU = {}'.format(mIoU))
        return mIoU
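
# A small self-contained check of the metric on random data (purely
# illustrative, not part of the training pipeline):
#
#     import numpy as np
#     metric = MiouPrecision(num_class=21)
#     y_pred = np.random.rand(2, 21, 4, 4)      # (N, num_class, H, W) scores
#     y = np.random.randint(0, 21, (2, 4, 4))   # (N, H, W) labels
#     metric.update(y_pred, y)
#     print(metric.eval())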
@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Adapter dataset."""
import fnmatch
import io
import os

import numpy as np
from PIL import Image

from ..utils import file_io


def get_raw_samples(data_url):
    """
    Get dataset from raw data.

    Args:
        data_url (str): Dataset path.

    Returns:
        list, a file list.
    """
    def _list_files(dir_path, pattern):
        full_files = []
        _, _, files = next(file_io.walk(dir_path))
        for f in files:
            if fnmatch.fnmatch(f.lower(), pattern.lower()):
                full_files.append(os.path.join(dir_path, f))
        return full_files

    img_files = _list_files(os.path.join(data_url, "Images"), "*.jpg")
    seg_files = _list_files(os.path.join(data_url, "SegmentationClassRaw"), "*.png")

    # Pair every image with the segmentation mask that shares its base name.
    files = []
    for img_file in img_files:
        _, file_name = os.path.split(img_file)
        name, _ = os.path.splitext(file_name)
        seg_file = os.path.join(data_url, "SegmentationClassRaw", ".".join([name, "png"]))
        if seg_file in seg_files:
            files.append([img_file, seg_file])
    return files


def read_image(img_path):
    """
    Read an image from a file.

    Args:
        img_path (str): Image path.
    """
    img = file_io.read(img_path.strip(), binary=True)
    data = io.BytesIO(img)
    img = Image.open(data)
    return np.array(img)
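
# Expected directory layout, inferred from the patterns in get_raw_samples
# above:
#
#     <data_url>/
#         Images/                  *.jpg input images
#         SegmentationClassRaw/    *.png masks with matching base names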
@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Random process dataset."""
import random

import numpy as np
from PIL import Image, ImageOps, ImageFilter


class Normalize:
    """Normalize a tensor image with mean and standard deviation.

    Args:
        mean (tuple): Means for each channel.
        std (tuple): Standard deviations for each channel.
    """

    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, img, mask):
        img = np.array(img).astype(np.float32)
        mask = np.array(mask).astype(np.float32)
        # Scale to [0, 1], then standardize with the channel statistics,
        # as the docstring describes.
        img /= 255.0
        img -= self.mean
        img /= self.std

        return img, mask


class RandomHorizontalFlip:
    """Randomly decide whether to flip horizontally."""
    def __call__(self, img, mask):
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return img, mask


class RandomRotate:
    """
    Randomly decide whether to rotate.

    Args:
        degree (float): The maximum degree of rotation.
    """
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, img, mask):
        rotate_degree = random.uniform(-1 * self.degree, self.degree)
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)

        return img, mask


class RandomGaussianBlur:
    """Randomly decide whether to filter the image with Gaussian blur."""
    def __call__(self, img, mask):
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))

        return img, mask


class RandomScaleCrop:
    """Randomly scale and crop the image."""
    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, img, mask):
        # Random scale (short edge).
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # Pad if the scaled image is smaller than the crop size.
        if short_size < self.crop_size:
            padh = self.crop_size - oh if oh < self.crop_size else 0
            padw = self.crop_size - ow if ow < self.crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
        # Random crop to crop_size.
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        return img, mask


class FixScaleCrop:
    """Scale the short edge and center-crop to a fixed size."""
    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, img, mask):
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # Center crop.
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        return img, mask


class FixedResize:
    """Resize the image to a fixed size."""
    def __init__(self, size):
        self.size = (size, size)

    def __call__(self, img, mask):
        assert img.size == mask.size

        img = img.resize(self.size, Image.BILINEAR)
        mask = mask.resize(self.size, Image.NEAREST)
        return img, mask
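
# A short sketch of how these transforms compose on paired PIL images,
# mirroring DataTransform._train in the dataset module (513 is the default
# crop_size from config):
#
#     img, mask = RandomScaleCrop(base_size=513, crop_size=513)(img, mask)
#     img, mask = RandomHorizontalFlip()(img, mask)
#     img, mask = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(img, mask)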
@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""File operation module."""
import os


def _is_obs(url):
    return url.startswith("obs://") or url.startswith("s3://")


def read(url, binary=False):
    if _is_obs(url):
        # TODO: read cloud file.
        return None

    with open(url, "rb" if binary else "r") as f:
        return f.read()


def walk(url):
    if _is_obs(url):
        # TODO: read cloud file.
        return None
    return os.walk(url)
@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train."""
import argparse
from mindspore import context
from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum
from mindspore import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config

parser = argparse.ArgumentParser(description="Deeplabv3 training")
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
parser.add_argument('--epoch_size', type=int, default=6, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
args_opt = parser.parse_args()
print(args_opt)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)


class LossCallBack(Callback):
    """
    Monitor the loss in training.

    Note:
        If per_print_times is 0, the loss is not printed.

    Args:
        per_print_times (int): Print the loss every per_print_times steps. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("per_print_times must be an int and >= 0")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                           str(cb_params.net_outputs)))


def model_fine_tune(flags, train_net, fix_weight_layer):
    """Load a pretrained checkpoint and freeze layers whose name contains fix_weight_layer."""
    checkpoint_path = flags.checkpoint_url
    if checkpoint_path is None:
        return
    param_dict = load_checkpoint(checkpoint_path)
    load_param_into_net(train_net, param_dict)
    for para in train_net.trainable_params():
        if fix_weight_layer in para.name:
            para.requires_grad = False


if __name__ == "__main__":
    if args_opt.distribute == "true":
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
        init()
    args_opt.base_size = config.crop_size
    args_opt.crop_size = config.crop_size
    train_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size,
                                   usage="train")
    dataset_size = train_dataset.get_dataset_size()
    time_cb = TimeMonitor(data_size=dataset_size)
    callback = [time_cb, LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
        callback.append(ckpoint_cb)
    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
                             infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
                             decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
    net.set_train()
    model_fine_tune(args_opt, net, 'layer')
    loss = OhemLoss(config.seg_num_classes, config.ignore_label)
    # Optimize only the weights, excluding BatchNorm beta/gamma, depthwise, and bias parameters.
    opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name
                          and 'bias' not in x.name, net.trainable_params()),
                   learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
    model = Model(net, loss, opt)
    model.train(args_opt.epoch_size, train_dataset, callback)
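
# Example invocation (the path is a placeholder); the flags match the argparse
# definitions above:
#
#     python train.py --device_id 0 --epoch_size 6 --batch_size 2 --data_url /path/to/voc2012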