modify deeplabv3

unknown 2020-05-29 02:44:12 +08:00
parent 8aae0a18c7
commit ec7cbb9929
13 changed files with 919 additions and 0 deletions

View File

@ -0,0 +1,58 @@
#!/usr/bin/env python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation."""
import argparse
from mindspore import context
from mindspore import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.miou_precision import MiouPrecision
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config
parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
print(args_opt)
if __name__ == "__main__":
args_opt.crop_size = config.crop_size
args_opt.base_size = config.crop_size
eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval")
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
param_dict = load_checkpoint(args_opt.checkpoint_url)
load_param_into_net(net, param_dict)
mIou = MiouPrecision(config.seg_num_classes)
metrics = {'mIou': mIou}
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
model = Model(net, loss, metrics=metrics)
model.eval(eval_dataset)
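Typical invocation of this script (the config docstring further down refers to it as evaluation.py; the paths here are placeholders): python evaluation.py --data_url /path/to/dataset --checkpoint_url /path/to/deeplabv3.ckpt --device_id 0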

View File

@ -12,3 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init DeepLabv3."""
from .deeplabv3 import ASPP, DeepLabV3, Decoder, deeplabv3_resnet50
from . import backbone
from .backbone import *
__all__ = [
"ASPP", "DeepLabV3", "deeplabv3_resnet50", "Decoder"
]
__all__.extend(backbone.__all__)

View File

@ -0,0 +1,534 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet based DeepLab."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore._checkparam import twice
from mindspore.common.parameter import Parameter
def _conv_bn_relu(in_channel,
out_channel,
ksize,
stride=1,
padding=0,
dilation=1,
pad_mode="pad",
use_batch_statistics=False):
"""Get a conv2d -> batchnorm -> relu layer"""
return nn.SequentialCell(
[nn.Conv2d(in_channel,
out_channel,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics),
nn.ReLU()]
)
def _deep_conv_bn_relu(in_channel,
channel_multiplier,
ksize,
stride=1,
padding=0,
dilation=1,
pad_mode="pad",
use_batch_statistics=False):
"""Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer"""
return nn.SequentialCell(
[DepthwiseConv2dNative(in_channel,
channel_multiplier,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics),
nn.ReLU()]
)
def _stob_deep_conv_btos_bn_relu(in_channel,
channel_multiplier,
ksize,
space_to_batch_block_shape,
batch_to_space_block_shape,
paddings,
crops,
stride=1,
padding=0,
dilation=1,
pad_mode="pad",
use_batch_statistics=False):
"""Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer"""
return nn.SequentialCell(
[SpaceToBatch(space_to_batch_block_shape,paddings),
DepthwiseConv2dNative(in_channel,
channel_multiplier,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
BatchToSpace(batch_to_space_block_shape,crops),
nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics),
nn.ReLU()]
)
def _stob_conv_btos_bn_relu(in_channel,
out_channel,
ksize,
space_to_batch_block_shape,
batch_to_space_block_shape,
paddings,
crops,
stride=1,
padding=0,
dilation=1,
pad_mode="pad",
use_batch_statistics=False):
"""Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer"""
return nn.SequentialCell(
[SpaceToBatch(space_to_batch_block_shape,paddings),
nn.Conv2d(in_channel,
out_channel,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
BatchToSpace(batch_to_space_block_shape,crops),
nn.BatchNorm2d(out_channel,use_batch_statistics=use_batch_statistics),
nn.ReLU()]
)
def _make_layer(block,
in_channels,
out_channels,
num_blocks,
stride=1,
rate=1,
multi_grads=None,
output_stride=None,
g_current_stride=2,
g_rate=1):
"""Make layer for DeepLab-ResNet network."""
if multi_grads is None:
multi_grads = [1] * num_blocks
# (stride == 2, num_blocks == 4 --> strides == [1, 1, 1, 2])
strides = [1] * (num_blocks - 1) + [stride]
blocks = []
if output_stride is not None:
if output_stride % 4 != 0:
raise ValueError('The output_stride needs to be a multiple of 4.')
output_stride //= 4
for i_stride, _ in enumerate(strides):
if output_stride is not None and g_current_stride > output_stride:
raise ValueError('The target output_stride cannot be reached.')
if output_stride is not None and g_current_stride == output_stride:
b_rate = g_rate
b_stride = 1
g_rate *= strides[i_stride]
else:
b_rate = rate
b_stride = strides[i_stride]
g_current_stride *= strides[i_stride]
blocks.append(block(in_channels=in_channels,
out_channels=out_channels,
stride=b_stride,
rate=b_rate,
multi_grad=multi_grads[i_stride]))
in_channels = out_channels
layer = nn.SequentialCell(blocks)
return layer, g_current_stride, g_rate
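The output_stride bookkeeping above is easy to misread: once the accumulated stride reaches the target, blocks stop striding and the stride they would have applied is folded into the dilation rate instead. A standalone sketch of just that logic (plain Python with hypothetical layer configs; note _make_layer first divides the user-facing output_stride by 4):

def plan_strides(strides, rate, output_stride, current_stride, g_rate):
    """Return the (stride, rate) each block would get, as in _make_layer."""
    plan = []
    for s in strides:
        if current_stride == output_stride:
            plan.append((1, g_rate))  # at the target: stride 1, dilate instead
            g_rate *= s               # fold the skipped stride into the rate
        else:
            plan.append((s, rate))
            current_stride *= s
    return plan, current_stride, g_rate

state = (2, 1)  # (current_stride, g_rate) after the root block
for strides in ([1, 1, 2], [1, 1, 1, 2]):
    plan, cs, gr = plan_strides(strides, 1, 4, *state)
    state = (cs, gr)
    print(plan)
# [(1, 1), (1, 1), (2, 1)]          -- strides normally until the target is hit
# [(1, 1), (1, 1), (1, 1), (1, 1)]  -- then stride 1; g_rate grows to 2 for later layers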
class Subsample(nn.Cell):
"""
Subsample for DeepLab-ResNet.
Args:
factor (int): Sample factor.
Returns:
Tensor, the sub sampled tensor.
Examples:
>>> Subsample(2)
"""
def __init__(self, factor):
super(Subsample, self).__init__()
self.factor = factor
self.pool = nn.MaxPool2d(kernel_size=1,
stride=factor)
def construct(self, x):
if self.factor == 1:
return x
return self.pool(x)
class SpaceToBatch(nn.Cell):
def __init__(self, block_shape, paddings):
super(SpaceToBatch, self).__init__()
self.space_to_batch = P.SpaceToBatch(block_shape, paddings)
self.bs = block_shape
self.pd = paddings
def construct(self, x):
return self.space_to_batch(x)
class BatchToSpace(nn.Cell):
def __init__(self, block_shape, crops):
super(BatchToSpace, self).__init__()
self.batch_to_space = P.BatchToSpace(block_shape, crops)
self.bs = block_shape
self.cr = crops
def construct(self, x):
return self.batch_to_space(x)
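These two wrappers are how the network gets atrous convolution without a dilated kernel: space-to-batch moves each block x block grid offset into the batch dimension, a dense convolution then acts on subsampled grids (effectively dilated), and batch-to-space undoes the rearrangement. A rough numpy sketch of the space-to-batch shape bookkeeping only (illustrative; the batch ordering follows the TF-style convention):

import numpy as np

def space_to_batch_nchw(x, block, paddings):
    # x: (N, C, H, W); paddings: [[top, bottom], [left, right]]
    x = np.pad(x, ((0, 0), (0, 0), tuple(paddings[0]), tuple(paddings[1])))
    n, c, h, w = x.shape
    x = x.reshape(n, c, h // block, block, w // block, block)
    # block offsets become the outer batch dimension
    return x.transpose(3, 5, 0, 1, 2, 4).reshape(n * block * block, c, h // block, w // block)

x = np.arange(2 * 3 * 8 * 8, dtype=np.float32).reshape(2, 3, 8, 8)
print(space_to_batch_nchw(x, 2, [[0, 0], [0, 0]]).shape)  # (8, 3, 4, 4)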
class _DepthwiseConv2dNative(nn.Cell):
def __init__(self,
in_channels,
channel_multiplier,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
weight_init):
super(_DepthwiseConv2dNative, self).__init__()
self.in_channels = in_channels
self.channel_multiplier = channel_multiplier
self.kernel_size = kernel_size
self.stride = stride
self.pad_mode = pad_mode
self.padding = padding
self.dilation = dilation
self.group = group
if not (isinstance(in_channels, int) and in_channels > 0):
raise ValueError('Attr \'in_channels\' of \'DepthwiseConv2D\' Op passed '
+ str(in_channels) + ', should be an int greater than 0.')
if (not isinstance(kernel_size, tuple)) or len(kernel_size) != 2 or \
(not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
kernel_size[0] < 1 or kernel_size[1] < 1:
raise ValueError('Attr \'kernel_size\' of \'DepthwiseConv2D\' Op passed '
+ str(self.kernel_size) + ', should be an int or a tuple of two ints, each at least 1.')
self.weight = Parameter(initializer(weight_init, [1, in_channels // group, *kernel_size]),
name='weight')
def construct(self, *inputs):
"""Must be overridden by all subclasses."""
raise NotImplementedError
class DepthwiseConv2dNative(_DepthwiseConv2dNative):
def __init__(self,
in_channels,
channel_multiplier,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
weight_init='normal'):
kernel_size = twice(kernel_size)
super(DepthwiseConv2dNative, self).__init__(
in_channels,
channel_multiplier,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
weight_init)
self.depthwise_conv2d_native = P.DepthwiseConv2dNative(channel_multiplier=self.channel_multiplier,
kernel_size=self.kernel_size,
mode=3,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
def set_strategy(self, strategy):
self.depthwise_conv2d_native.set_strategy(strategy)
return self
def construct(self, x):
return self.depthwise_conv2d_native(x, self.weight)
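For intuition: a depthwise conv applies channel_multiplier independent filters to every input channel, so the output has in_channels * channel_multiplier channels. A minimal numpy sketch of that semantics ('valid' padding, stride 1; an illustration, not the Ascend kernel):

import numpy as np

def depthwise_conv2d(x, w):
    # x: (C, H, W); w: (C, M, kh, kw) -> out: (C * M, H - kh + 1, W - kw + 1)
    c, h, wd = x.shape
    _, m, kh, kw = w.shape
    out = np.zeros((c * m, h - kh + 1, wd - kw + 1), dtype=x.dtype)
    for ci in range(c):
        for mi in range(m):
            for i in range(out.shape[1]):
                for j in range(out.shape[2]):
                    out[ci * m + mi, i, j] = np.sum(x[ci, i:i + kh, j:j + kw] * w[ci, mi])
    return out

x = np.random.rand(3, 5, 5).astype(np.float32)
w = np.random.rand(3, 2, 3, 3).astype(np.float32)
print(depthwise_conv2d(x, w).shape)  # (6, 3, 3)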
class BottleneckV1(nn.Cell):
"""
ResNet V1 BottleneckV1 block definition.
Args:
in_channels (int): Input channel.
out_channels (int): Output channel.
stride (int): Stride size for the initial convolutional layer. Default: 1.
rate (int): Rate for convolution. Default: 1.
multi_grad (int): Employ a rate within network. Default: 1.
Returns:
Tensor, the ResNet unit's output.
Examples:
>>> BottleneckV1(3,256,stride=2)
"""
def __init__(self,
in_channels,
out_channels,
stride=1,
use_batch_statistics=False,
use_batch_to_stob_and_btos=False):
super(BottleneckV1, self).__init__()
expansion = 4
mid_channels = out_channels // expansion
self.conv_bn1 = _conv_bn_relu(in_channels,
mid_channels,
ksize=1,
stride=1,
use_batch_statistics=use_batch_statistics)
self.conv_bn2 = _conv_bn_relu(mid_channels,
mid_channels,
ksize=3,
stride=stride,
padding=1,
dilation=1,
use_batch_statistics=use_batch_statistics)
if use_batch_to_stob_and_btos:
self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels,
mid_channels,
ksize=3,
stride=stride,
padding=0,
dilation=1,
space_to_batch_block_shape=2,
batch_to_space_block_shape=2,
paddings=[[2, 3], [2, 3]],
crops=[[0, 1], [0, 1]],
pad_mode="valid",
use_batch_statistics=use_batch_statistics)
self.conv3 = nn.Conv2d(mid_channels,
out_channels,
kernel_size=1,
stride=1)
self.bn3 = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics)
if in_channels != out_channels:
conv = nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=stride)
bn = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics)
self.downsample = nn.SequentialCell([conv, bn])
else:
self.downsample = Subsample(stride)
self.add = P.TensorAdd()
self.relu = nn.ReLU()
def construct(self, x):
out = self.conv_bn1(x)
out = self.conv_bn2(out)
out = self.bn3(self.conv3(out))
out = self.add(out, self.downsample(x))
out = self.relu(out)
return out
class BottleneckV2(nn.Cell):
"""
ResNet V2 Bottleneck variance V2 block definition.
Args:
in_channels (int): Input channel.
out_channels (int): Output channel.
stride (int): Stride size for the initial convolutional layer. Default: 1.
Returns:
Tensor, the ResNet unit's output.
Examples:
>>> BottleneckV2(3,256,stride=2)
"""
def __init__(self,
in_channels,
out_channels,
stride=1,
use_batch_statistics=False,
use_batch_to_stob_and_btos=False,
dilation=1):
super(BottleneckV2, self).__init__()
expansion = 4
mid_channels = out_channels // expansion
self.conv_bn1 = _conv_bn_relu(in_channels,
mid_channels,
ksize=1,
stride=1,
use_batch_statistics=use_batch_statistics)
self.conv_bn2 = _conv_bn_relu(mid_channels,
mid_channels,
ksize=3,
stride=stride,
padding=1,
dilation=dilation,
use_batch_statistics=use_batch_statistics)
if use_batch_to_stob_and_btos:
self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels,
mid_channels,
ksize=3,
stride=stride,
padding=0,
dilation=1,
space_to_batch_block_shape=2,
batch_to_space_block_shape=2,
paddings=[[2, 3], [2, 3]],
crops=[[0, 1], [0, 1]],
pad_mode="valid",
use_batch_statistics=use_batch_statistics)
self.conv3 = nn.Conv2d(mid_channels,
out_channels,
kernel_size=1,
stride=1)
self.bn3 = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics)
if in_channels != out_channels:
conv = nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=stride)
bn = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics)
self.downsample = nn.SequentialCell([conv, bn])
else:
self.downsample = Subsample(stride)
self.add = P.TensorAdd()
self.relu = nn.ReLU()
def construct(self, x):
out = self.conv_bn1(x)
out = self.conv_bn2(out)
out = self.bn3(self.conv3(out))
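# As used in ResNetV1 below, in_channels == out_channels and stride == 1, so the shortcut is the identity and self.downsample is never applied.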
out = self.add(out, x)
out = self.relu(out)
return out
class BottleneckV3(nn.Cell):
"""
ResNet V1 Bottleneck variance V1 block definition.
Args:
in_channels (int): Input channel.
out_channels (int): Output channel.
stride (int): Stride size for the initial convolutional layer. Default: 1.
Returns:
Tensor, the ResNet unit's output.
Examples:
>>> BottleneckV3(3,256,stride=2)
"""
def __init__(self,
in_channels,
out_channels,
stride=1,
use_batch_statistics=False):
super(BottleneckV3, self).__init__()
expansion = 4
mid_channels = out_channels // expansion
self.conv_bn1 = _conv_bn_relu(in_channels,
mid_channels,
ksize=1,
stride=1,
use_batch_statistics=use_batch_statistics)
self.conv_bn2 = _conv_bn_relu(mid_channels,
mid_channels,
ksize=3,
stride=stride,
padding=1,
dilation=1,
use_batch_statistics=use_batch_statistics)
self.conv3 = nn.Conv2d(mid_channels,
out_channels,
kernel_size=1,
stride=1)
self.bn3 = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics)
if in_channels != out_channels:
conv = nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=stride)
bn = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics)
self.downsample = nn.SequentialCell([conv, bn])
else:
self.downsample = Subsample(stride)
self.add = P.TensorAdd()
self.relu = nn.ReLU()
def construct(self, x):
out = self.conv_bn1(x)
out = self.conv_bn2(out)
out = self.bn3(self.conv3(out))
out = self.add(out, self.downsample(x))
out = self.relu(out)
return out
class ResNetV1(nn.Cell):
"""
ResNet V1 for DeepLab.
Args:
fine_tune_batch_norm (bool): Whether batch norm uses batch statistics. Default: False.
Returns:
Tuple, output tensor tuple (c2, c5).
Examples:
>>> ResNetV1(False)
"""
def __init__(self, fine_tune_batch_norm=False):
super(ResNetV1, self).__init__()
self.layer_root = nn.SequentialCell(
[RootBlockBeta(fine_tune_batch_norm),
nn.MaxPool2d(kernel_size=(3, 3),
stride=(2, 2),
pad_mode='same')])
self.layer1_1 = BottleneckV1(128, 256, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer1_2 = BottleneckV2(256, 256, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer1_3 = BottleneckV3(256, 256, stride=2, use_batch_statistics=fine_tune_batch_norm)
self.layer2_1 = BottleneckV1(256, 512, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer2_2 = BottleneckV2(512, 512, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer2_3 = BottleneckV2(512, 512, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer2_4 = BottleneckV3(512, 512, stride=2, use_batch_statistics=fine_tune_batch_norm)
self.layer3_1 = BottleneckV1(512, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer3_2 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer3_3 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer3_4 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer3_5 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer3_6 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer4_1 = BottleneckV1(1024, 2048, stride=1, use_batch_to_stob_and_btos=True, use_batch_statistics=fine_tune_batch_norm)
self.layer4_2 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True, use_batch_statistics=fine_tune_batch_norm)
self.layer4_3 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True, use_batch_statistics=fine_tune_batch_norm)
def construct(self, x):
x = self.layer_root(x)
x = self.layer1_1(x)
c2 = self.layer1_2(x)
x = self.layer1_3(c2)
x = self.layer2_1(x)
x = self.layer2_2(x)
x = self.layer2_3(x)
x = self.layer2_4(x)
x = self.layer3_1(x)
x = self.layer3_2(x)
x = self.layer3_3(x)
x = self.layer3_4(x)
x = self.layer3_5(x)
x = self.layer3_6(x)
x = self.layer4_1(x)
x = self.layer4_2(x)
c5 = self.layer4_3(x)
return c2, c5
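As a sanity check on the strides above, the feature-map sizes for the default 513 x 513 crop can be traced with plain integer arithmetic (a sketch of the shape math only: 'same' pooling rounds up; the 'valid' root conv and the padded stride-2 convs use the usual floor formula):

def out_len(size, ksize, stride, pad):
    # conv/pool output length: floor((size + 2*pad - ksize) / stride) + 1
    return (size + 2 * pad - ksize) // stride + 1

size = out_len(513, 3, 2, 0)   # root conv1, "valid", stride 2   -> 256
size = -(-size // 2)           # 3x3 max pool, "same", stride 2  -> 128
c2 = size                      # layer1_2 output: output stride 4
size = out_len(size, 3, 2, 1)  # layer1_3, stride 2              -> 64
size = out_len(size, 3, 2, 1)  # layer2_4, stride 2              -> 32
c5 = size                      # layer3/4 dilate instead of striding: output stride 16
print(c2, c5)                  # 128 32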
class RootBlockBeta(nn.Cell):
"""
ResNet V1 beta root block definition.
Returns:
Tensor, the block unit's output.
Examples:
>>> RootBlockBeta()
"""
def __init__(self, fine_tune_batch_norm=False):
super(RootBlockBeta, self).__init__()
self.conv1 = _conv_bn_relu(3, 64, ksize=3, stride=2, padding=0, pad_mode="valid", use_batch_statistics=fine_tune_batch_norm)
self.conv2 = _conv_bn_relu(64, 64, ksize=3, stride=1, padding=0, pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
self.conv3 = _conv_bn_relu(64, 128, ksize=3, stride=1, padding=0, pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
def construct(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
return x

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and evaluation.py
"""
from easydict import EasyDict as ed
config = ed({
"learning_rate": 0.0014,
"weight_decay": 0.00005,
"momentum": 0.97,
"crop_size": 513,
"eval_scales": [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
"atrous_rates": None,
"image_pyramid": None,
"output_stride": 16,
"fine_tune_batch_norm": False,
"ignore_label": 255,
"decoder_output_stride": None,
"seg_num_classes": 21
})
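Since config is an EasyDict, entries read as attributes and can be overridden before building the network. A small usage sketch (the override values are illustrative, e.g. the standard DeepLab rates, not this commit's defaults):

from src.config import config

print(config.crop_size, config.seg_num_classes)  # 513 21
config.atrous_rates = [6, 12, 18]       # example override
config.decoder_output_stride = 4        # example override: enable the decoder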

View File

@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OhemLoss."""
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class OhemLoss(nn.Cell):
def __init__(self, num, ignore_label):
super(OhemLoss, self).__init__()
self.mul = P.Mul()
self.shape = P.Shape()
self.one_hot = nn.OneHot(-1, num, 1.0, 0.0)
self.num = num
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.select = P.Select()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.not_equal = P.NotEqual()
self.equal = P.Equal()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.fill = P.Fill()
self.transpose = P.Transpose()
self.ignore_label = ignore_label
self.loss_weight = 1.0
def construct(self, logits, labels):
logits = self.transpose(logits, (0, 2, 3, 1))
logits = self.reshape(logits, (-1, self.num))
labels = F.cast(labels, mstype.int32)
labels = self.reshape(labels, (-1,))
one_hot_labels = self.one_hot(labels)
losses = self.cross_entropy(logits, one_hot_labels)[0]
weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight
weighted_losses = self.mul(losses, weights)
loss = self.reduce_sum(weighted_losses, (0,))
zeros = self.fill(mstype.float32, self.shape(weights), 0.0)
ones = self.fill(mstype.float32, self.shape(weights), 1.0)
present = self.select(
self.equal(weights, zeros),
zeros,
ones)
present = self.reduce_sum(present, (0,))
zeros = self.fill(mstype.float32, self.shape(present), 0.0)
min_control = self.fill(mstype.float32, self.shape(present), 1.0)
present = self.select(
self.equal(present, zeros),
min_control,
present)
loss = loss / present
return loss
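Despite its name, this loss does no hard-example mining: construct computes per-pixel softmax cross entropy, zeroes out pixels labeled ignore_label, and divides the summed loss by the number of contributing pixels (clamped to at least 1). A numpy sketch of the same arithmetic, for checking intuition:

import numpy as np

def masked_ce(logits, labels, num_classes, ignore_label=255):
    # logits: (N, C, H, W); labels: (N, H, W) of int class ids
    logits = logits.transpose(0, 2, 3, 1).reshape(-1, num_classes)
    labels = labels.reshape(-1)
    z = logits - logits.max(axis=1, keepdims=True)           # stable softmax
    logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    safe = np.where(labels == ignore_label, 0, labels)       # any valid index
    ce = -logp[np.arange(labels.size), safe]
    w = (labels != ignore_label).astype(np.float32)
    return float((ce * w).sum() / max(w.sum(), 1.0))

logits = np.random.randn(1, 21, 4, 4).astype(np.float32)
labels = np.random.randint(0, 21, (1, 4, 4))
print(masked_ce(logits, labels, 21))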

View File

@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mIou."""
import numpy as np
from mindspore.nn.metrics.metric import Metric
def confuse_matrix(target, pred, n):
"""Build an n x n confusion matrix over the valid (0 <= target < n) pixels."""
k = (target >= 0) & (target < n)
return np.bincount(n * target[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)
def iou(hist):
"""Mean IoU over the classes present: diag / (row_sum + col_sum - diag)."""
denominator = hist.sum(1) + hist.sum(0) - np.diag(hist)
res = np.diag(hist) / np.where(denominator > 0, denominator, 1)
res = np.sum(res) / np.count_nonzero(denominator)
return res
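For example, with a 2-class confusion matrix (using the iou helper above and the file's numpy import):

hist = np.array([[8., 2.],
                 [1., 4.]])
# class 0: 8 / (10 + 9 - 8) = 8/11; class 1: 4 / (5 + 6 - 4) = 4/7
print(iou(hist))  # ~0.649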
class MiouPrecision(Metric):
def __init__(self, num_class=21):
super(MiouPrecision, self).__init__()
if not isinstance(num_class, int):
raise TypeError('num_class should be integer type, but got {}'.format(type(num_class)))
if num_class < 1:
raise ValueError('num_class must be at least 1, but got {}'.format(num_class))
self._num_class = num_class
self._mIoU = []
self.clear()
def clear(self):
self._hist = np.zeros((self._num_class, self._num_class))
self._mIoU = []
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
predict_in = self._convert_data(inputs[0])
label_in = self._convert_data(inputs[1])
if predict_in.shape[1] != self._num_class:
raise ValueError('Class number mismatch: the metric was built for {} classes, but the '
'prediction has {}.'.format(self._num_class, predict_in.shape[1]))
pred = np.argmax(predict_in, axis=1)
label = label_in
if len(label.flatten()) != len(pred.flatten()):
raise ValueError('Size mismatch: len(gt) = {:d}, len(pred) = {:d}.'.format(
len(label.flatten()), len(pred.flatten())))
self._hist = confuse_matrix(label.flatten(), pred.flatten(), self._num_class)
mIoUs = iou(self._hist)
self._mIoU.append(mIoUs)
def eval(self):
"""
Computes the mIoU categorical accuracy.
"""
mIoU = np.nanmean(self._mIoU)
print('mIoU = {}'.format(mIoU))
return mIoU

View File

@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
def _is_obs(url):
return url.startswith("obs://") or url.startswith("s3://")
def read(url, binary=False):
if _is_obs(url):
# TODO read cloud file.
return None
with open(url, "rb" if binary else "r") as f:
return f.read()
def walk(url):
if _is_obs(url):
# TODO walk cloud directory.
return None
return os.walk(url)
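A usage sketch (paths are placeholders); note both helpers currently return None for obs:// and s3:// URLs until the cloud branches are implemented, so callers should guard for that:

text = read("README.md")                 # local file: returns its contents
for root, dirs, files in walk("./src") or []:
    print(root, files)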

View File

@ -0,0 +1,99 @@
#!/usr/bin/env python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train."""
import argparse
from mindspore import context
from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum
from mindspore import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config
parser = argparse.ArgumentParser(description="Deeplabv3 training")
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
parser.add_argument('--max_checkpoint_num', type=int, default=5, help='Max checkpoint number.')
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
"default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
args_opt = parser.parse_args()
print(args_opt)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
class LossCallBack(Callback):
"""
Monitor the loss in training.
Note:
If per_print_times is 0, the loss is not printed.
Args:
per_print_times (int): Print the loss every per_print_times steps. Default: 1.
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0")
self._per_print_times = per_print_times
def step_end(self, run_context):
cb_params = run_context.original_args()
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)))
def model_fine_tune(flags, net, fix_weight_layer):
checkpoint_path = flags.checkpoint_url
if checkpoint_path is None:
return
param_dict = load_checkpoint(checkpoint_path)
load_param_into_net(net, param_dict)
for para in net.trainable_params():
if fix_weight_layer in para.name:
para.requires_grad = False
if __name__ == "__main__":
if args_opt.distribute == "true":
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
init()
args_opt.base_size = config.crop_size
args_opt.crop_size = config.crop_size
train_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="train")
dataset_size = train_dataset.get_dataset_size()
time_cb = TimeMonitor(data_size=dataset_size)
callback = [time_cb, LossCallBack()]
if args_opt.enable_save_ckpt == "true":
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
callback.append(ckpoint_cb)
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
net.set_train()
model_fine_tune(args_opt, net, 'layer')
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name
and 'bias' not in x.name, net.trainable_params()),
learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
model = Model(net, loss, opt)
model.train(args_opt.epoch_size, train_dataset, callback)
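Typical invocation (flags as defined above; paths are placeholders): python train.py --data_url /path/to/dataset --checkpoint_url /path/to/pretrained.ckpt --distribute false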