pylint fix

This commit is contained in:
yangyongjie 2020-05-29 12:01:09 +08:00
parent ee2efd3589
commit 2065383e38
2 changed files with 54 additions and 51 deletions

View File

@ -22,6 +22,8 @@ from src.losses import OhemLoss
from src.miou_precision import MiouPrecision from src.miou_precision import MiouPrecision
from src.deeplabv3 import deeplabv3_resnet50 from src.deeplabv3 import deeplabv3_resnet50
from src.config import config from src.config import config
parser = argparse.ArgumentParser(description="Deeplabv3 evaluation") parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.') parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
@ -32,14 +34,16 @@ parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
print(args_opt) print(args_opt)
if __name__ == "__main__": if __name__ == "__main__":
args_opt.crop_size = config.crop_size args_opt.crop_size = config.crop_size
args_opt.base_size = config.crop_size args_opt.base_size = config.crop_size
eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval") eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval")
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size], net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
param_dict = load_checkpoint(args_opt.checkpoint_url) param_dict = load_checkpoint(args_opt.checkpoint_url)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
mIou = MiouPrecision(config.seg_num_classes) mIou = MiouPrecision(config.seg_num_classes)
@ -47,4 +51,3 @@ if __name__ == "__main__":
loss = OhemLoss(config.seg_num_classes, config.ignore_label) loss = OhemLoss(config.seg_num_classes, config.ignore_label)
model = Model(net, loss, metrics=metrics) model = Model(net, loss, metrics=metrics)
model.eval(eval_dataset) model.eval(eval_dataset)

View File

@ -93,31 +93,30 @@ def _stob_deep_conv_btos_bn_relu(in_channel,
def _stob_conv_btos_bn_relu(in_channel, def _stob_conv_btos_bn_relu(in_channel,
out_channel, out_channel,
ksize, ksize,
space_to_batch_block_shape, space_to_batch_block_shape,
batch_to_space_block_shape, batch_to_space_block_shape,
paddings, paddings,
crops, crops,
stride=1, stride=1,
padding=0, padding=0,
dilation=1, dilation=1,
pad_mode="pad", pad_mode="pad",
use_batch_statistics=False): use_batch_statistics=False):
"""Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer""" """Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer"""
return nn.SequentialCell( return nn.SequentialCell([SpaceToBatch(space_to_batch_block_shape, paddings),
[SpaceToBatch(space_to_batch_block_shape,paddings), nn.Conv2d(in_channel,
nn.Conv2d(in_channel, out_channel,
out_channel, kernel_size=ksize,
kernel_size=ksize, stride=stride,
stride=stride, padding=padding,
padding=padding, dilation=dilation,
dilation=dilation, pad_mode=pad_mode),
pad_mode=pad_mode), BatchToSpace(batch_to_space_block_shape, crops),
BatchToSpace(batch_to_space_block_shape,crops), nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics),
nn.BatchNorm2d(out_channel,use_batch_statistics=use_batch_statistics), nn.ReLU()]
nn.ReLU()] )
)
def _make_layer(block, def _make_layer(block,
@ -206,6 +205,7 @@ class BatchToSpace(nn.Cell):
class _DepthwiseConv2dNative(nn.Cell): class _DepthwiseConv2dNative(nn.Cell):
"""Depthwise Conv2D Cell."""
def __init__(self, def __init__(self,
in_channels, in_channels,
channel_multiplier, channel_multiplier,
@ -242,6 +242,7 @@ class _DepthwiseConv2dNative(nn.Cell):
class DepthwiseConv2dNative(_DepthwiseConv2dNative): class DepthwiseConv2dNative(_DepthwiseConv2dNative):
"""Depthwise Conv2D Cell."""
def __init__(self, def __init__(self,
in_channels, in_channels,
channel_multiplier, channel_multiplier,
@ -315,31 +316,31 @@ class BottleneckV1(nn.Cell):
padding=1, padding=1,
dilation=1, dilation=1,
use_batch_statistics=use_batch_statistics) use_batch_statistics=use_batch_statistics)
if use_batch_to_stob_and_btos == True: if use_batch_to_stob_and_btos:
self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels, self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels,
mid_channels, mid_channels,
ksize=3, ksize=3,
stride=stride, stride=stride,
padding=0, padding=0,
dilation=1, dilation=1,
space_to_batch_block_shape = 2, space_to_batch_block_shape=2,
batch_to_space_block_shape = 2, batch_to_space_block_shape=2,
paddings =[[2, 3], [2, 3]], paddings=[[2, 3], [2, 3]],
crops =[[0, 1], [0, 1]], crops=[[0, 1], [0, 1]],
pad_mode="valid", pad_mode="valid",
use_batch_statistics=use_batch_statistics) use_batch_statistics=use_batch_statistics)
self.conv3 = nn.Conv2d(mid_channels, self.conv3 = nn.Conv2d(mid_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=1) stride=1)
self.bn3 = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
if in_channels != out_channels: if in_channels != out_channels:
conv = nn.Conv2d(in_channels, conv = nn.Conv2d(in_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=stride) stride=stride)
bn = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
self.downsample = nn.SequentialCell([conv, bn]) self.downsample = nn.SequentialCell([conv, bn])
else: else:
self.downsample = Subsample(stride) self.downsample = Subsample(stride)
@ -397,23 +398,23 @@ class BottleneckV2(nn.Cell):
stride=stride, stride=stride,
padding=0, padding=0,
dilation=1, dilation=1,
space_to_batch_block_shape = 2, space_to_batch_block_shape=2,
batch_to_space_block_shape = 2, batch_to_space_block_shape=2,
paddings =[[2, 3], [2, 3]], paddings=[[2, 3], [2, 3]],
crops =[[0, 1], [0, 1]], crops=[[0, 1], [0, 1]],
pad_mode="valid", pad_mode="valid",
use_batch_statistics=use_batch_statistics) use_batch_statistics=use_batch_statistics)
self.conv3 = nn.Conv2d(mid_channels, self.conv3 = nn.Conv2d(mid_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=1) stride=1)
self.bn3 = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
if in_channels != out_channels: if in_channels != out_channels:
conv = nn.Conv2d(in_channels, conv = nn.Conv2d(in_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=stride) stride=stride)
bn = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
self.downsample = nn.SequentialCell([conv, bn]) self.downsample = nn.SequentialCell([conv, bn])
else: else:
self.downsample = Subsample(stride) self.downsample = Subsample(stride)
@ -465,14 +466,14 @@ class BottleneckV3(nn.Cell):
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=1) stride=1)
self.bn3 = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
if in_channels != out_channels: if in_channels != out_channels:
conv = nn.Conv2d(in_channels, conv = nn.Conv2d(in_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=stride) stride=stride)
bn = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
self.downsample = nn.SequentialCell([conv, bn]) self.downsample = nn.SequentialCell([conv, bn])
else: else:
self.downsample = Subsample(stride) self.downsample = Subsample(stride)
@ -502,9 +503,8 @@ class ResNetV1(nn.Cell):
super(ResNetV1, self).__init__() super(ResNetV1, self).__init__()
self.layer_root = nn.SequentialCell( self.layer_root = nn.SequentialCell(
[RootBlockBeta(fine_tune_batch_norm), [RootBlockBeta(fine_tune_batch_norm),
nn.MaxPool2d(kernel_size=(3,3), nn.MaxPool2d(kernel_size=(3, 3),
stride=(2,2), stride=(2, 2),
#padding=1,
pad_mode='same')]) pad_mode='same')])
self.layer1_1 = BottleneckV1(128, 256, stride=1, use_batch_statistics=fine_tune_batch_norm) self.layer1_1 = BottleneckV1(128, 256, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer1_2 = BottleneckV2(256, 256, stride=1, use_batch_statistics=fine_tune_batch_norm) self.layer1_2 = BottleneckV2(256, 256, stride=1, use_batch_statistics=fine_tune_batch_norm)
@ -519,7 +519,7 @@ class ResNetV1(nn.Cell):
self.layer3_4 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) self.layer3_4 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer3_5 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) self.layer3_5 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer3_6 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) self.layer3_6 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer4_1 = BottleneckV1(1024, 2048, stride=1, use_batch_to_stob_and_btos=True, self.layer4_1 = BottleneckV1(1024, 2048, stride=1, use_batch_to_stob_and_btos=True,
use_batch_statistics=fine_tune_batch_norm) use_batch_statistics=fine_tune_batch_norm)
self.layer4_2 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True, self.layer4_2 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True,
@ -542,7 +542,7 @@ class ResNetV1(nn.Cell):
x = self.layer3_4(x) x = self.layer3_4(x)
x = self.layer3_5(x) x = self.layer3_5(x)
x = self.layer3_6(x) x = self.layer3_6(x)
x = self.layer4_1(x) x = self.layer4_1(x)
x = self.layer4_2(x) x = self.layer4_2(x)
c5 = self.layer4_3(x) c5 = self.layer4_3(x)