add dataset

This commit is contained in:
yangyongjie 2020-05-29 01:40:04 +08:00
parent ae04259442
commit a728b328e1
6 changed files with 897 additions and 0 deletions

View File

@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Adapter dataset."""
import fnmatch
import io
import os
import numpy as np
from PIL import Image
from ..utils import file_io
def get_raw_samples(data_url):
"""
Get dataset from raw data.
Args:
data_url (str): Dataset path.
Returns:
list, a file list.
"""
def _list_files(dir_path, pattern):
full_files = []
_, _, files = next(file_io.walk(dir_path))
for f in files:
if fnmatch.fnmatch(f.lower(), pattern.lower()):
full_files.append(os.path.join(dir_path, f))
return full_files
img_files = _list_files(os.path.join(data_url, "Images"), "*.jpg")
seg_files = _list_files(os.path.join(data_url, "SegmentationClassRaw"), "*.png")
files = []
for img_file in img_files:
_, file_name = os.path.split(img_file)
name, _ = os.path.splitext(file_name)
seg_file = os.path.join(data_url, "SegmentationClassRaw", ".".join([name, "png"]))
if seg_file in seg_files:
files.append([img_file, seg_file])
return files
def read_image(img_path):
"""
Read image from file.
Args:
img_path (str): image path.
"""
img = file_io.read(img_path.strip(), binary=True)
data = io.BytesIO(img)
img = Image.open(data)
return np.array(img)

View File

@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Random process dataset."""
import random
import numpy as np
from PIL import Image, ImageOps, ImageFilter
class Normalize(object):
"""Normalize a tensor image with mean and standard deviation.
Args:
mean (tuple): means for each channel.
std (tuple): standard deviations for each channel.
"""
def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
self.mean = mean
self.std = std
def __call__(self, img, mask):
img = np.array(img).astype(np.float32)
mask = np.array(mask).astype(np.float32)
return img, mask
class RandomHorizontalFlip(object):
"""Randomly decide whether to horizontal flip."""
def __call__(self, img, mask):
if random.random() < 0.5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
return img, mask
class RandomRotate(object):
"""
Randomly decide whether to rotate.
Args:
degree (float): The degree of rotate.
"""
def __init__(self, degree):
self.degree = degree
def __call__(self, img, mask):
rotate_degree = random.uniform(-1 * self.degree, self.degree)
img = img.rotate(rotate_degree, Image.BILINEAR)
mask = mask.rotate(rotate_degree, Image.NEAREST)
return img, mask
class RandomGaussianBlur(object):
"""Randomly decide whether to filter image with gaussian blur."""
def __call__(self, img, mask):
if random.random() < 0.5:
img = img.filter(ImageFilter.GaussianBlur(
radius=random.random()))
return img, mask
class RandomScaleCrop(object):
"""Randomly decide whether to scale and crop image."""
def __init__(self, base_size, crop_size, fill=0):
self.base_size = base_size
self.crop_size = crop_size
self.fill = fill
def __call__(self, img, mask):
# random scale (short edge)
short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
w, h = img.size
if h > w:
ow = short_size
oh = int(1.0 * h * ow / w)
else:
oh = short_size
ow = int(1.0 * w * oh / h)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# pad crop
if short_size < self.crop_size:
padh = self.crop_size - oh if oh < self.crop_size else 0
padw = self.crop_size - ow if ow < self.crop_size else 0
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
# random crop crop_size
w, h = img.size
x1 = random.randint(0, w - self.crop_size)
y1 = random.randint(0, h - self.crop_size)
img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
return img, mask
class FixScaleCrop(object):
"""Scale and crop image with fixing size."""
def __init__(self, crop_size):
self.crop_size = crop_size
def __call__(self, img, mask):
w, h = img.size
if w > h:
oh = self.crop_size
ow = int(1.0 * w * oh / h)
else:
ow = self.crop_size
oh = int(1.0 * h * ow / w)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# center crop
w, h = img.size
x1 = int(round((w - self.crop_size) / 2.))
y1 = int(round((h - self.crop_size) / 2.))
img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
return img, mask
class FixedResize(object):
"""Resize image with fixing size."""
def __init__(self, size):
self.size = (size, size)
def __call__(self, img, mask):
assert img.size == mask.size
img = img.resize(self.size, Image.BILINEAR)
mask = mask.resize(self.size, Image.NEAREST)
return img, mask

View File

@ -0,0 +1,457 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeepLabv3."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
class ASPPSampleBlock(nn.Cell):
"""ASPP sample block."""
def __init__(self, feature_shape, scale_size,output_stride):
super(ASPPSampleBlock, self).__init__()
sample_h = (feature_shape[0] * scale_size + 1) / output_stride + 1
sample_w = (feature_shape[1] * scale_size + 1) / output_stride + 1
self.sample = P.ResizeBilinear((int(sample_h),int(sample_w)),align_corners=True)
def construct(self, x):
return self.sample(x)
class ASPP(nn.Cell):
"""
ASPP model for DeepLabv3.
Args:
channel (int): Input channel.
depth (int): Output channel.
feature_shape (list): The shape of feature,[h,w].
scale_sizes (list): Input scales for multi-scale feature extraction.
atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
output_stride (int): 'The ratio of input to output spatial resolution.'
fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
Returns:
Tensor, output tensor.
Examples:
>>> ASPP(channel=2048,256,[14,14],[1],[6],16)
"""
def __init__(self, channel, depth, feature_shape, scale_sizes,
atrous_rates, output_stride, fine_tune_batch_norm=False):
super(ASPP, self).__init__()
self.aspp0 = _conv_bn_relu(channel,
depth,
ksize=1,
stride=1,
use_batch_statistics=fine_tune_batch_norm)
self.atrous_rates = []
if atrous_rates is not None:
self.atrous_rates = atrous_rates
self.aspp_pointwise = _conv_bn_relu(channel,
depth,
ksize=1,
stride=1,
use_batch_statistics=fine_tune_batch_norm)
self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel,
channel_multiplier=1,
kernel_size=3,
stride=1,
dilation=1,
pad_mode="valid")
self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm)
self.aspp_depth_relu = nn.ReLU()
self.aspp_depths = []
self.aspp_depth_spacetobatchs = []
self.aspp_depth_batchtospaces = []
for scale_size in scale_sizes:
aspp_scale_depth_size = np.ceil((feature_shape[0]*scale_size)/16)
if atrous_rates is None:
break
for i in range(len(atrous_rates)):
padding = 0
for j in range(100):
padded_size = atrous_rates[i] * j
if padded_size >= aspp_scale_depth_size + 2 * atrous_rates[i]:
padding = padded_size - aspp_scale_depth_size - 2 * atrous_rates[i]
break
paddings = [[atrous_rates[i], atrous_rates[i] + int(padding)],
[atrous_rates[i], atrous_rates[i] + int(padding)]]
self.aspp_depth_spacetobatch = SpaceToBatch(atrous_rates[i],paddings)
self.aspp_depth_spacetobatchs.append(self.aspp_depth_spacetobatch)
crops =[[0, int(padding)], [0, int(padding)]]
self.aspp_depth_batchtospace = BatchToSpace(atrous_rates[i],crops)
self.aspp_depth_batchtospaces.append(self.aspp_depth_batchtospace)
self.aspp_depths = nn.CellList(self.aspp_depths)
self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs)
self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces)
self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]),int(feature_shape[1])))
self.global_poolings = []
for scale_size in scale_sizes:
pooling_h = np.ceil((feature_shape[0]*scale_size)/output_stride)
pooling_w = np.ceil((feature_shape[0]*scale_size)/output_stride)
self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w))))
self.global_poolings = nn.CellList(self.global_poolings)
self.conv_bn = _conv_bn_relu(channel,
depth,
ksize=1,
stride=1,
use_batch_statistics=fine_tune_batch_norm)
self.samples = []
for scale_size in scale_sizes:
self.samples.append(ASPPSampleBlock(feature_shape,scale_size,output_stride))
self.samples = nn.CellList(self.samples)
self.feature_shape = feature_shape
self.concat = P.Concat(axis=1)
def construct(self, x, scale_index=0):
aspp0 = self.aspp0(x)
aspp1 = self.global_poolings[scale_index](x)
aspp1 = self.conv_bn(aspp1)
aspp1 = self.samples[scale_index](aspp1)
output = self.concat((aspp1,aspp0))
for i in range(len(self.atrous_rates)):
aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x)
aspp_i = self.aspp_depth_depthwiseconv(aspp_i)
aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i)
aspp_i = self.aspp_depth_bn(aspp_i)
aspp_i = self.aspp_depth_relu(aspp_i)
aspp_i = self.aspp_pointwise(aspp_i)
output = self.concat((output,aspp_i))
return output
class DecoderSampleBlock(nn.Cell):
"""Decoder sample block."""
def __init__(self,feature_shape,scale_size=1.0,decoder_output_stride=4):
super(DecoderSampleBlock, self).__init__()
sample_h = (feature_shape[0] * scale_size + 1) / decoder_output_stride + 1
sample_w = (feature_shape[1] * scale_size + 1) / decoder_output_stride + 1
self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)
def construct(self, x):
return self.sample(x)
class Decoder(nn.Cell):
"""
Decode module for DeepLabv3.
Args:
low_level_channel (int): Low level input channel
channel (int): Input channel.
depth (int): Output channel.
feature_shape (list): 'Input image shape, [N,C,H,W].'
scale_sizes (list): 'Input scales for multi-scale feature extraction.'
decoder_output_stride (int): 'The ratio of input to output spatial resolution'
fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
Returns:
Tensor, output tensor.
Examples:
>>> Decoder(256, 100, [56,56])
"""
def __init__(self,
low_level_channel,
channel,
depth,
feature_shape,
scale_sizes,
decoder_output_stride,
fine_tune_batch_norm):
super(Decoder, self).__init__()
self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1,
pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
self.decoder_depth0 = _deep_conv_bn_relu(channel + 48,
channel_multiplier=1,
ksize=3,
stride=1,
pad_mode="same",
dilation=1,
use_batch_statistics=fine_tune_batch_norm)
self.decoder_pointwise0 = _conv_bn_relu(channel + 48,
depth,
ksize=1,
stride=1,
use_batch_statistics=fine_tune_batch_norm)
self.decoder_depth1 = _deep_conv_bn_relu(depth,
channel_multiplier=1,
ksize=3,
stride=1,
pad_mode="same",
dilation=1,
use_batch_statistics=fine_tune_batch_norm)
self.decoder_pointwise1 = _conv_bn_relu(depth,
depth,
ksize=1,
stride=1,
use_batch_statistics=fine_tune_batch_norm)
self.depth = depth
self.concat = P.Concat(axis=1)
self.samples = []
for scale_size in scale_sizes:
self.samples.append(DecoderSampleBlock(feature_shape,scale_size,decoder_output_stride))
self.samples = nn.CellList(self.samples)
def construct(self, x, low_level_feature, scale_index):
low_level_feature = self.feature_projection(low_level_feature)
low_level_feature = self.samples[scale_index](low_level_feature)
x = self.samples[scale_index](x)
output = self.concat((x, low_level_feature))
output = self.decoder_depth0(output)
output = self.decoder_pointwise0(output)
output = self.decoder_depth1(output)
output = self.decoder_pointwise1(output)
return output
class SingleDeepLabV3(nn.Cell):
"""
DeepLabv3 Network.
Args:
num_classes (int): Class number.
feature_shape (list): Input image shape, [N,C,H,W].
backbone (Cell): Backbone Network.
channel (int): Resnet output channel.
depth (int): ASPP block depth.
scale_sizes (list): Input scales for multi-scale feature extraction.
atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
decoder_output_stride (int): 'The ratio of input to output spatial resolution'
output_stride (int): 'The ratio of input to output spatial resolution.'
fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
Returns:
Tensor, output tensor.
Examples:
>>> SingleDeepLabV3(num_classes=10,
>>> feature_shape=[1,3,224,224],
>>> backbone=resnet50_dl(),
>>> channel=2048,
>>> depth=256)
>>> scale_sizes=[1.0])
>>> atrous_rates=[6])
>>> decoder_output_stride=4)
>>> output_stride=16)
"""
def __init__(self,
num_classes,
feature_shape,
backbone,
channel,
depth,
scale_sizes,
atrous_rates,
decoder_output_stride,
output_stride,
fine_tune_batch_norm=False):
super(SingleDeepLabV3, self).__init__()
self.num_classes = num_classes
self.channel = channel
self.depth = depth
self.scale_sizes = []
for scale_size in np.sort(scale_sizes):
self.scale_sizes.append(scale_size)
self.net = backbone
self.aspp = ASPP(channel=self.channel,
depth=self.depth,
feature_shape=[feature_shape[2],
feature_shape[3]],
scale_sizes=self.scale_sizes,
atrous_rates=atrous_rates,
output_stride=output_stride,
fine_tune_batch_norm=fine_tune_batch_norm)
self.aspp.add_flags(loop_can_unroll=True)
atrous_rates_len = 0
if atrous_rates is not None:
atrous_rates_len = len(atrous_rates)
self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth,
ksize=1,
stride=1,
use_batch_statistics=fine_tune_batch_norm)
self.fc2 = nn.Conv2d(depth,
num_classes,
kernel_size=1,
stride=1,
has_bias=True)
self.upsample = P.ResizeBilinear((int(feature_shape[2]),
int(feature_shape[3])),
align_corners=True)
self.samples = []
for scale_size in self.scale_sizes:
self.samples.append(SampleBlock(feature_shape, scale_size))
self.samples = nn.CellList(self.samples)
self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]), float(feature_shape[2]),
float(feature_shape[3])]
self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
self.dropout = nn.Dropout(keep_prob=0.9)
self.shape = P.Shape()
self.decoder_output_stride = decoder_output_stride
if decoder_output_stride is not None:
self.decoder = Decoder(low_level_channel=depth,
channel=depth,
depth=depth,
feature_shape=[feature_shape[2],
feature_shape[3]],
scale_sizes=self.scale_sizes,
decoder_output_stride=decoder_output_stride,
fine_tune_batch_norm=fine_tune_batch_norm)
def construct(self, x, scale_index=0):
x = (2.0 / 255.0) * x - 1.0
x = self.pad(x)
low_level_feature, feature_map = self.net(x)
for scale_size in self.scale_sizes:
if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2:
output = self.aspp(feature_map, scale_index)
output = self.fc1(output)
if self.decoder_output_stride is not None:
output = self.decoder(output, low_level_feature, scale_index)
output = self.fc2(output)
output = self.samples[scale_index](output)
return output
scale_index += 1
return feature_map
class SampleBlock(nn.Cell):
"""Sample block."""
def __init__(self,
feature_shape,
scale_size=1.0):
super(SampleBlock, self).__init__()
sample_h = np.ceil(float(feature_shape[2]) * scale_size)
sample_w = np.ceil(float(feature_shape[3]) * scale_size)
self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)
def construct(self, x):
return self.sample(x)
class DeepLabV3(nn.Cell):
"""DeepLabV3 model."""
def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates,
decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid):
super(DeepLabV3, self).__init__()
self.infer_scale_sizes = []
if infer_scale_sizes is not None:
self.infer_scale_sizes = infer_scale_sizes
self.infer_scale_sizes = infer_scale_sizes
if image_pyramid is None:
image_pyramid = [1.0]
self.image_pyramid = image_pyramid
scale_sizes = []
for i in range(len(image_pyramid)):
scale_sizes.append(image_pyramid[i])
for i in range(len(infer_scale_sizes)):
scale_sizes.append(infer_scale_sizes[i])
self.samples = []
for scale_size in scale_sizes:
self.samples.append(SampleBlock(feature_shape, scale_size))
self.samples = nn.CellList(self.samples)
self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes,
feature_shape=feature_shape,
backbone=resnet50_dl(fine_tune_batch_norm),
channel=channel,
depth=depth,
scale_sizes=scale_sizes,
atrous_rates=atrous_rates,
decoder_output_stride=decoder_output_stride,
output_stride=output_stride,
fine_tune_batch_norm=fine_tune_batch_norm)
self.softmax = P.Softmax(axis=1)
self.concat = P.Concat(axis=2)
self.expand_dims = P.ExpandDims()
self.reduce_mean = P.ReduceMean()
self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
int(feature_shape[3])),
align_corners=True)
def construct(self, x):
logits = ()
if self.training:
if len(self.image_pyramid) >= 1:
if self.image_pyramid[0] == 1:
logits = self.deeplabv3(x)
else:
x1 = self.samples[0](x)
logits = self.deeplabv3(x1)
logits = self.sample_common(logits)
logits = self.expand_dims(logits, 2)
for i in range(len(self.image_pyramid) - 1):
x_i = self.samples[i + 1](x)
logits_i = self.deeplabv3(x_i)
logits_i = self.sample_common(logits_i)
logits_i = self.expand_dims(logits_i, 2)
logits = self.concat((logits, logits_i))
logits = self.reduce_mean(logits, 2)
return logits
if len(self.infer_scale_sizes) >= 1:
infer_index = len(self.image_pyramid)
x1 = self.samples[infer_index](x)
logits = self.deeplabv3(x1)
logits = self.sample_common(logits)
logits = self.softmax(logits)
logits = self.expand_dims(logits, 2)
for i in range(len(self.infer_scale_sizes) - 1):
x_i = self.samples[i + 1 + infer_index](x)
logits_i = self.deeplabv3(x_i)
logits_i = self.sample_common(logits_i)
logits_i = self.softmax(logits_i)
logits_i = self.expand_dims(logits_i, 2)
logits = self.concat((logits, logits_i))
logits = self.reduce_mean(logits, 2)
return logits
def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid,
infer_scale_sizes, atrous_rates=None, decoder_output_stride=None,
output_stride=16, fine_tune_batch_norm=False):
"""
ResNet50 based DeepLabv3 network.
Args:
num_classes (int): Class number.
feature_shape (list): Input image shape, [N,C,H,W].
image_pyramid (list): Input scales for multi-scale feature extraction.
atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
infer_scale_sizes (list): 'The scales to resize images for inference.
decoder_output_stride (int): 'The ratio of input to output spatial resolution'
output_stride (int): 'The ratio of input to output spatial resolution.'
fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
Returns:
Cell, cell instance of ResNet50 based DeepLabv3 neural network.
Examples:
>>> deeplabv3_resnet50(100, [1,3,224,224],[1.0],[1.0])
"""
return DeepLabV3(num_classes=num_classes,
feature_shape=feature_shape,
backbone=resnet50_dl(fine_tune_batch_norm),
channel=2048,
depth=256,
infer_scale_sizes=infer_scale_sizes,
atrous_rates=atrous_rates,
decoder_output_stride=decoder_output_stride,
output_stride=output_stride,
fine_tune_batch_norm=fine_tune_batch_norm,
image_pyramid=image_pyramid)

View File

@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Process Dataset."""
import abc
import os
import time
from .utils.adapter import get_manifest_samples, get_raw_samples, read_image
class BaseDataset(object):
"""
Create dataset.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def __init__(self, data_url, usage):
self.data_url = data_url
self.usage = usage
self.cur_index = 0
self.samples = []
_s_time = time.time()
self._load_samples()
_e_time = time.time()
print(f"load samples success~, time cost = {_e_time - _s_time}")
def __getitem__(self, item):
sample = self.samples[item]
return self._next_data(sample)
def __len__(self):
return len(self.samples)
@staticmethod
def _next_data(sample):
image_path = sample[0]
mask_image_path = sample[1]
image = read_image(image_path)
mask_image = read_image(mask_image_path)
return [image, mask_image]
@abc.abstractmethod
def _load_samples(self):
pass
class HwVocManifestDataset(BaseDataset):
"""
Create dataset with manifest data.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def __init__(self, data_url, usage="train"):
super().__init__(data_url, usage)
def _load_samples(self):
try:
self.samples = get_manifest_samples(self.data_url, self.usage)
except Exception as e:
print("load HwVocManifestDataset samples failed!!!")
raise e
class HwVocRawDataset(BaseDataset):
"""
Create dataset with raw data.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def __init__(self, data_url, usage="train"):
super().__init__(data_url, usage)
def _load_samples(self):
try:
self.samples = get_raw_samples(os.path.join(self.data_url, self.usage))
except Exception as e:
print("load HwVocRawDataset failed!!!")
raise e

View File

@ -0,0 +1,104 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C
from .ei_dataset import HwVocManifestDataset, HwVocRawDataset
from .utils import custom_transforms as tr
class DataTransform(object):
"""Transform dataset for DeepLabV3."""
def __init__(self, args, usage):
self.args = args
self.usage = usage
def __call__(self, image, label):
if "train" == self.usage:
return self._train(image, label)
elif "eval" == self.usage:
return self._eval(image, label)
def _train(self, image, label):
image = Image.fromarray(image)
label = Image.fromarray(label)
rsc_tr = tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size)
image, label = rsc_tr(image, label)
rhf_tr = tr.RandomHorizontalFlip()
image, label = rhf_tr(image, label)
nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
image, label = nor_tr(image, label)
return image, label
def _eval(self, image, label):
image = Image.fromarray(image)
label = Image.fromarray(label)
fsc_tr = tr.FixScaleCrop(crop_size=self.args.crop_size)
image, label = fsc_tr(image, label)
nor_tr = tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
image, label = nor_tr(image, label)
return image, label
def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
"""
Create Dataset for DeepLabV3.
Args:
args (dict): Train parameters.
data_url (str): Dataset path.
epoch_num (int): Epoch of dataset (default=1).
batch_size (int): Batch size of dataset (default=1).
usage (str): Whether is use to train or eval (default='train').
Returns:
Dataset.
"""
# create iter dataset
if data_url.endswith(".manifest"):
dataset = HwVocManifestDataset(data_url, usage=usage)
else:
dataset = HwVocRawDataset(data_url, usage=usage)
dataset_len = len(dataset)
# wrapped with GeneratorDataset
dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
dataset.set_dataset_size(dataset_len)
dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))
channelswap_op = C.HWC2CHW()
dataset = dataset.map(input_columns="image", operations=channelswap_op)
# 1464 samples / batch_size 8 = 183 batches
# epoch_num is num of steps
# 3658 steps / 183 = 20 epochs
if usage == "train":
dataset = dataset.shuffle(1464)
dataset = dataset.batch(batch_size, drop_remainder=(usage == usage))
dataset = dataset.repeat(count=epoch_num)
dataset.map_model = 4
dataset.__loop_size__ = 1
return dataset