forked from mindspore-Ecosystem/mindspore
!19160 Fix parameter check of Rotate and GaussianBlur API
Merge pull request !19160 from xiaotianci/fix_rotate_and_gaussian_blur
This commit is contained in:
commit
4a8225eaf4
|
@ -632,8 +632,8 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|||
if (!input_cv->mat().data) {
|
||||
RETURN_STATUS_UNEXPECTED("Rotate: load image failed.");
|
||||
}
|
||||
if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) {
|
||||
RETURN_STATUS_UNEXPECTED("Rotate: input tensor is not in shape of <H,W,C> or <H,W>.");
|
||||
if (input_cv->Rank() != DEFAULT_IMAGE_RANK && input_cv->Rank() != MIN_IMAGE_DIMENSION) {
|
||||
RETURN_STATUS_UNEXPECTED("Rotate: image shape is not <H,W,C> or <H,W>.");
|
||||
}
|
||||
|
||||
cv::Mat input_img = input_cv->mat();
|
||||
|
@ -641,8 +641,10 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
|||
RETURN_STATUS_UNEXPECTED("Rotate: image is too large and center is not precise.");
|
||||
}
|
||||
// default to center of image
|
||||
if (fx == -1 && fy == -1) {
|
||||
if (fx == -1) {
|
||||
fx = (input_img.cols - 1) / 2.0;
|
||||
}
|
||||
if (fy == -1) {
|
||||
fy = (input_img.rows - 1) / 2.0;
|
||||
}
|
||||
cv::Mat output_img;
|
||||
|
@ -1270,6 +1272,9 @@ Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
|
|||
int32_t kernel_y, float sigma_x, float sigma_y) {
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
if (input_cv->mat().data == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("GaussianBlur: load image failed.");
|
||||
}
|
||||
cv::Mat output_cv_mat;
|
||||
cv::GaussianBlur(input_cv->mat(), output_cv_mat, cv::Size(kernel_x, kernel_y), static_cast<double>(sigma_x),
|
||||
static_cast<double>(sigma_y));
|
||||
|
|
|
@ -47,9 +47,6 @@ RotateOp::RotateOp(float degrees, InterpolationMode resample, bool expand, float
|
|||
|
||||
Status RotateOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
input->shape().Size() >= 2,
|
||||
"Rotate: image shape " + std::to_string(input->shape().Size()) + " is not <H,W,C> or <H,W>.");
|
||||
#ifndef ENABLE_ANDROID
|
||||
return Rotate(input, output, center_x_, center_y_, degrees_, interpolation_, expand_, fill_r_, fill_g_, fill_b_);
|
||||
#else
|
||||
|
|
|
@ -1201,13 +1201,20 @@ class RandomRotation(ImageTensorOperation):
|
|||
|
||||
@check_random_rotation
|
||||
def __init__(self, degrees, resample=Inter.NEAREST, expand=False, center=None, fill_value=0):
|
||||
if isinstance(degrees, numbers.Number):
|
||||
if isinstance(degrees, (int, float)):
|
||||
degrees = degrees % 360
|
||||
if isinstance(degrees, (list, tuple)):
|
||||
degrees = [degrees[0] % 360, degrees[1] % 360]
|
||||
if degrees[0] > degrees[1]:
|
||||
degrees[1] += 360
|
||||
|
||||
degrees = [-degrees, degrees]
|
||||
elif isinstance(degrees, (list, tuple)):
|
||||
if degrees[1] - degrees[0] >= 360:
|
||||
degrees = [-180, 180]
|
||||
else:
|
||||
degrees = [degrees[0] % 360, degrees[1] % 360]
|
||||
if degrees[0] > degrees[1]:
|
||||
degrees[1] += 360
|
||||
if center is None:
|
||||
center = (-1, -1)
|
||||
if isinstance(fill_value, int):
|
||||
fill_value = tuple([fill_value] * 3)
|
||||
self.degrees = degrees
|
||||
self.resample = resample
|
||||
self.expand = expand
|
||||
|
@ -1215,14 +1222,8 @@ class RandomRotation(ImageTensorOperation):
|
|||
self.fill_value = fill_value
|
||||
|
||||
def parse(self):
|
||||
# pylint false positive
|
||||
# pylint: disable=E1130
|
||||
degrees = (-self.degrees, self.degrees) if isinstance(self.degrees, numbers.Number) else self.degrees
|
||||
interpolation = DE_C_INTER_MODE[self.resample]
|
||||
expand = self.expand
|
||||
center = (-1, -1) if self.center is None else self.center
|
||||
fill_value = tuple([self.fill_value] * 3) if isinstance(self.fill_value, int) else self.fill_value
|
||||
return cde.RandomRotationOperation(degrees, interpolation, expand, center, fill_value)
|
||||
return cde.RandomRotationOperation(self.degrees, DE_C_INTER_MODE[self.resample], self.expand, self.center,
|
||||
self.fill_value)
|
||||
|
||||
|
||||
class RandomSelectSubpolicy(ImageTensorOperation):
|
||||
|
@ -1521,9 +1522,12 @@ class Rotate(ImageTensorOperation):
|
|||
|
||||
@check_rotate
|
||||
def __init__(self, degrees, resample=Inter.NEAREST, expand=False, center=None, fill_value=0):
|
||||
if isinstance(degrees, numbers.Number):
|
||||
if isinstance(degrees, (int, float)):
|
||||
degrees = degrees % 360
|
||||
|
||||
if center is None:
|
||||
center = (-1, -1)
|
||||
if isinstance(fill_value, int):
|
||||
fill_value = tuple([fill_value] * 3)
|
||||
self.degrees = degrees
|
||||
self.resample = resample
|
||||
self.expand = expand
|
||||
|
@ -1531,14 +1535,8 @@ class Rotate(ImageTensorOperation):
|
|||
self.fill_value = fill_value
|
||||
|
||||
def parse(self):
|
||||
# pylint false positive
|
||||
# pylint: disable=E1130
|
||||
degrees = self.degrees
|
||||
interpolation = DE_C_INTER_MODE[self.resample]
|
||||
expand = self.expand
|
||||
center = (-1, -1) if self.center is None else self.center
|
||||
fill_value = tuple([self.fill_value] * 3) if isinstance(self.fill_value, int) else self.fill_value
|
||||
return cde.RotateOperation(degrees, interpolation, expand, center, fill_value)
|
||||
return cde.RotateOperation(self.degrees, DE_C_INTER_MODE[self.resample], self.expand, self.center,
|
||||
self.fill_value)
|
||||
|
||||
|
||||
class SoftDvppDecodeRandomCropResizeJpeg(ImageTensorOperation):
|
||||
|
|
|
@ -140,12 +140,12 @@ def check_padding(padding):
|
|||
|
||||
def check_degrees(degrees):
|
||||
"""Check if the degrees is legal."""
|
||||
type_check(degrees, (numbers.Number, list, tuple), "degrees")
|
||||
if isinstance(degrees, numbers.Number):
|
||||
type_check(degrees, (int, float, list, tuple), "degrees")
|
||||
if isinstance(degrees, (int, float)):
|
||||
check_pos_float32(degrees, "degrees")
|
||||
elif isinstance(degrees, (list, tuple)):
|
||||
if len(degrees) == 2:
|
||||
type_check_list(degrees, (numbers.Number,), "degrees")
|
||||
type_check_list(degrees, (int, float), "degrees")
|
||||
for value in degrees:
|
||||
check_float32(value, "degrees")
|
||||
if degrees[0] > degrees[1]:
|
||||
|
@ -420,6 +420,17 @@ def check_random_color_adjust(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_resample_expand_center_fill_value_params(resample, expand, center, fill_value):
|
||||
type_check(resample, (Inter,), "resample")
|
||||
type_check(expand, (bool,), "expand")
|
||||
if center is not None:
|
||||
check_2tuple(center, "center")
|
||||
for value in center:
|
||||
type_check(value, (int, float), "center")
|
||||
check_value(value, [-1, INT32_MAX], "center")
|
||||
check_fill_value(fill_value)
|
||||
|
||||
|
||||
def check_random_rotation(method):
|
||||
"""Wrapper method to check the parameters of random rotation."""
|
||||
|
||||
|
@ -427,15 +438,7 @@ def check_random_rotation(method):
|
|||
def new_method(self, *args, **kwargs):
|
||||
[degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_degrees(degrees)
|
||||
|
||||
if resample is not None:
|
||||
type_check(resample, (Inter,), "resample")
|
||||
if expand is not None:
|
||||
type_check(expand, (bool,), "expand")
|
||||
if center is not None:
|
||||
check_2tuple(center, "center")
|
||||
if fill_value is not None:
|
||||
check_fill_value(fill_value)
|
||||
check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
@ -448,18 +451,9 @@ def check_rotate(method):
|
|||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(degrees, (numbers.Number,), "degrees")
|
||||
type_check(degrees, (float, int), "degrees")
|
||||
check_float32(degrees, "degrees")
|
||||
|
||||
if resample is not None:
|
||||
type_check(resample, (Inter,), "resample")
|
||||
if expand is not None:
|
||||
type_check(expand, (bool,), "expand")
|
||||
if center is not None:
|
||||
check_2tuple(center, "center")
|
||||
if fill_value is not None:
|
||||
check_fill_value(fill_value)
|
||||
check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
@ -555,6 +549,7 @@ def check_rgb_to_bgr(method):
|
|||
[is_hwc], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(is_hwc, (bool,), "is_hwc")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
|
|
|
@ -101,7 +101,7 @@ def test_rotate_exception():
|
|||
_ = c_vision.Rotate("60")
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in Rotate: {}".format(str(e)))
|
||||
assert "not of type [<class 'numbers.Number'>]" in str(e)
|
||||
assert "not of type [<class 'float'>, <class 'int'>]" in str(e)
|
||||
try:
|
||||
_ = c_vision.Rotate(30, Inter.BICUBIC, False, (0, 0, 0))
|
||||
except ValueError as e:
|
||||
|
|
Loading…
Reference in New Issue