From 858f2b59b2d56b7bcf0aa5a376c3167714a304f8 Mon Sep 17 00:00:00 2001 From: liuxiao Date: Fri, 19 Jun 2020 14:57:42 +0800 Subject: [PATCH] fix image.CentralCrop --- mindspore/nn/layer/image.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index 0107f20e0cb..b23f20deb84 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -267,11 +267,9 @@ class PSNR(Cell): @constexpr -def _check_input_3d_or_4d(input_shape, param_name, func_name): - """check input 3d or 4d""" - if len(input_shape) != 3 and len(input_shape) != 4: - raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}") - return True +def _raise_dims_rank_error(input_shape, param_name, func_name): + """raise error if input is not 3d or 4d""" + raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}") @constexpr def _get_bbox(rank, shape, central_fraction): @@ -281,6 +279,7 @@ def _get_bbox(rank, shape, central_fraction): else: n, c, h, w = shape + central_fraction = central_fraction.asnumpy()[0] bbox_h_start = int((float(h) - float(h) * central_fraction) / 2) bbox_w_start = int((float(w) - float(w) * central_fraction) / 2) bbox_h_size = h - bbox_h_start * 2 @@ -319,16 +318,18 @@ class CentralCrop(Cell): validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) self.central_fraction = validator.check_number_range('central_fraction', central_fraction, 0.0, 1.0, Rel.INC_RIGHT, self.cls_name) + self.central_fraction_tensor = Tensor(np.array([central_fraction]).astype(np.float64)) self.slice = P.Slice() def construct(self, image): image_shape = F.shape(image) rank = len(image_shape) - _check_input_3d_or_4d(image_shape, "image", self.cls_name) + if not rank in (3, 4): + return _raise_dims_rank_error(image_shape, "image", self.cls_name) if self.central_fraction == 1.0: return image - bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction) + bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction_tensor) image = self.slice(image, bbox_begin, bbox_size) return image