forked from mindspore-Ecosystem/mindspore
!2348 fix image.CenterCrop.
Merge pull request !2348 from liuxiao/central_crop
This commit is contained in:
commit
c55b81e94f
|
@ -267,11 +267,9 @@ class PSNR(Cell):
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _check_input_3d_or_4d(input_shape, param_name, func_name):
|
def _raise_dims_rank_error(input_shape, param_name, func_name):
|
||||||
"""check input 3d or 4d"""
|
"""raise error if input is not 3d or 4d"""
|
||||||
if len(input_shape) != 3 and len(input_shape) != 4:
|
raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}")
|
||||||
raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _get_bbox(rank, shape, central_fraction):
|
def _get_bbox(rank, shape, central_fraction):
|
||||||
|
@ -281,6 +279,7 @@ def _get_bbox(rank, shape, central_fraction):
|
||||||
else:
|
else:
|
||||||
n, c, h, w = shape
|
n, c, h, w = shape
|
||||||
|
|
||||||
|
central_fraction = central_fraction.asnumpy()[0]
|
||||||
bbox_h_start = int((float(h) - float(h) * central_fraction) / 2)
|
bbox_h_start = int((float(h) - float(h) * central_fraction) / 2)
|
||||||
bbox_w_start = int((float(w) - float(w) * central_fraction) / 2)
|
bbox_w_start = int((float(w) - float(w) * central_fraction) / 2)
|
||||||
bbox_h_size = h - bbox_h_start * 2
|
bbox_h_size = h - bbox_h_start * 2
|
||||||
|
@ -319,16 +318,18 @@ class CentralCrop(Cell):
|
||||||
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
|
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
|
||||||
self.central_fraction = validator.check_number_range('central_fraction', central_fraction,
|
self.central_fraction = validator.check_number_range('central_fraction', central_fraction,
|
||||||
0.0, 1.0, Rel.INC_RIGHT, self.cls_name)
|
0.0, 1.0, Rel.INC_RIGHT, self.cls_name)
|
||||||
|
self.central_fraction_tensor = Tensor(np.array([central_fraction]).astype(np.float64))
|
||||||
self.slice = P.Slice()
|
self.slice = P.Slice()
|
||||||
|
|
||||||
def construct(self, image):
|
def construct(self, image):
|
||||||
image_shape = F.shape(image)
|
image_shape = F.shape(image)
|
||||||
rank = len(image_shape)
|
rank = len(image_shape)
|
||||||
_check_input_3d_or_4d(image_shape, "image", self.cls_name)
|
if not rank in (3, 4):
|
||||||
|
return _raise_dims_rank_error(image_shape, "image", self.cls_name)
|
||||||
if self.central_fraction == 1.0:
|
if self.central_fraction == 1.0:
|
||||||
return image
|
return image
|
||||||
|
|
||||||
bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction)
|
bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction_tensor)
|
||||||
image = self.slice(image, bbox_begin, bbox_size)
|
image = self.slice(image, bbox_begin, bbox_size)
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
Loading…
Reference in New Issue