From aa73abc2f7da30989c621c1698ad8e181957f677 Mon Sep 17 00:00:00 2001
From: liuxiao
Date: Sat, 13 Jun 2020 18:13:05 +0800
Subject: [PATCH] Add image.CentralCrop

---
 mindspore/nn/layer/image.py             | 74 ++++++++++++++++++++++++-
 tests/ut/python/nn/test_central_crop.py | 74 +++++++++++++++++++++++++
 2 files changed, 147 insertions(+), 1 deletion(-)
 create mode 100644 tests/ut/python/nn/test_central_crop.py

diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py
index 7d8eef4d6f4..4ab42895777 100644
--- a/mindspore/nn/layer/image.py
+++ b/mindspore/nn/layer/image.py
@@ -23,7 +23,7 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from ..cell import Cell
 
-__all__ = ['ImageGradients', 'SSIM', 'PSNR']
+__all__ = ['ImageGradients', 'SSIM', 'PSNR', 'CentralCrop']
 
 class ImageGradients(Cell):
     r"""
@@ -259,3 +259,75 @@ class PSNR(Cell):
         psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)
 
         return psnr
+
+
+@constexpr
+def _check_input_3d_or_4d(input_shape, param_name, func_name):
+    """Check that the input shape is 3-D or 4-D, raise ValueError otherwise."""
+    if len(input_shape) not in (3, 4):
+        raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}")
+    return True
+
+
+@constexpr
+def _get_bbox(rank, shape, central_fraction):
+    """Get the begin and size of the slice that keeps the central region of H and W."""
+    h, w = shape[-2], shape[-1]
+
+    # Margin removed from the top and left edges, truncated towards zero.
+    bbox_h_start = int((float(h) - float(h) * central_fraction) / 2)
+    bbox_w_start = int((float(w) - float(w) * central_fraction) / 2)
+    # The same margin is removed from the opposite edges, so the box stays centred.
+    bbox_h_size = h - bbox_h_start * 2
+    bbox_w_size = w - bbox_w_start * 2
+
+    # Leading axes (C, or N and C) are kept whole: begin at 0 and take their full extent.
+    bbox_begin = (0,) * (rank - 2) + (bbox_h_start, bbox_w_start)
+    bbox_size = tuple(shape[:-2]) + (bbox_h_size, bbox_w_size)
+
+    return bbox_begin, bbox_size
+
+
+class CentralCrop(Cell):
+    """
+    Crop the central region of the images with the central_fraction.
+
+    Args:
+        central_fraction (float): Fraction of size to crop. It must be float and in range (0.0, 1.0].
+
+    Inputs:
+        - **image** (Tensor) - A 3-D tensor of shape [C, H, W], or a 4-D tensor of shape [N, C, H, W].
+
+    Outputs:
+        Tensor, 3-D or 4-D float tensor, according to the input.
+
+    Raises:
+        TypeError: If `central_fraction` is not a float.
+        ValueError: If `central_fraction` is not in range (0.0, 1.0].
+        ValueError: If `image` is not a 3-D or 4-D tensor.
+
+    Examples:
+        >>> net = nn.CentralCrop(central_fraction=0.5)
+        >>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
+        >>> output = net(image)
+    """
+
+    def __init__(self, central_fraction):
+        super(CentralCrop, self).__init__()
+        validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
+        self.central_fraction = validator.check_number_range('central_fraction', central_fraction,
+                                                             0.0, 1.0, Rel.INC_RIGHT, self.cls_name)
+        self.slice = P.Slice()
+
+    def construct(self, image):
+        image_shape = F.shape(image)
+        rank = len(image_shape)
+        _check_input_3d_or_4d(image_shape, "image", self.cls_name)
+        # A fraction of 1.0 keeps the whole image, so no slice is needed.
+        if self.central_fraction == 1.0:
+            return image
+
+        bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction)
+        image = self.slice(image, bbox_begin, bbox_size)
+
+        return image
diff --git a/tests/ut/python/nn/test_central_crop.py b/tests/ut/python/nn/test_central_crop.py
new file mode 100644
index 00000000000..dc9f438f952
--- /dev/null
+++ b/tests/ut/python/nn/test_central_crop.py
@@ -0,0 +1,74 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Unit tests for nn.CentralCrop: graph compilation on 3-D/4-D images and rejection of invalid arguments.
+"""
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.common.api import _executor
+
+
+class CentralCropNet(nn.Cell):  # minimal network that only forwards its input through nn.CentralCrop
+    def __init__(self, central_fraction):
+        super(CentralCropNet, self).__init__()
+        self.net = nn.CentralCrop(central_fraction)
+
+    def construct(self, image):
+        return self.net(image)
+
+
+def test_compile_3d_central_crop():
+    central_fraction = 0.2
+    net = CentralCropNet(central_fraction)
+    image = Tensor(np.random.random((3, 16, 16)), mstype.float32)  # [C, H, W]
+    _executor.compile(net, image)
+
+
+def test_compile_4d_central_crop():
+    central_fraction = 0.5
+    net = CentralCropNet(central_fraction)
+    image = Tensor(np.random.random((8, 3, 16, 16)), mstype.float32)  # [N, C, H, W]
+    _executor.compile(net, image)
+
+
+def test_central_fraction_bool():
+    central_fraction = True  # a bool is not a float, so the type check must reject it
+    with pytest.raises(TypeError):
+        _ = CentralCropNet(central_fraction)
+
+
+def test_central_crop_central_fraction_negative():
+    central_fraction = -1.0  # below the valid range (0.0, 1.0]
+    with pytest.raises(ValueError):
+        _ = CentralCropNet(central_fraction)
+
+
+def test_central_fraction_zero():
+    central_fraction = 0.0  # the lower bound of the range is exclusive
+    with pytest.raises(ValueError):
+        _ = CentralCropNet(central_fraction)
+
+
+def test_central_crop_invalid_5d_input():
+    invalid_shape = (8, 3, 16, 16, 1)  # rank 5; only 3-D and 4-D images are accepted
+    invalid_image = Tensor(np.random.random(invalid_shape))
+
+    net = CentralCropNet(central_fraction=0.5)
+    with pytest.raises(ValueError):  # the rank check is expected to fire while the graph is compiled
+        _executor.compile(net, invalid_image)