fix param check for Rotate and GaussianBlur

This commit is contained in:
Xiao Tianci 2021-06-30 14:48:55 +08:00
parent 0ac3cd3aef
commit 8cdd0d8d90
5 changed files with 49 additions and 54 deletions

View File

@ -632,8 +632,8 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Rotate: load image failed.");
}
if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) {
RETURN_STATUS_UNEXPECTED("Rotate: input tensor is not in shape of <H,W,C> or <H,W>.");
if (input_cv->Rank() != DEFAULT_IMAGE_RANK && input_cv->Rank() != MIN_IMAGE_DIMENSION) {
RETURN_STATUS_UNEXPECTED("Rotate: image shape is not <H,W,C> or <H,W>.");
}
cv::Mat input_img = input_cv->mat();
@ -641,8 +641,10 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
RETURN_STATUS_UNEXPECTED("Rotate: image is too large and center is not precise.");
}
// default to center of image
if (fx == -1 && fy == -1) {
if (fx == -1) {
fx = (input_img.cols - 1) / 2.0;
}
if (fy == -1) {
fy = (input_img.rows - 1) / 2.0;
}
cv::Mat output_img;
@ -1270,6 +1272,9 @@ Status GaussianBlur(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
int32_t kernel_y, float sigma_x, float sigma_y) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (input_cv->mat().data == nullptr) {
RETURN_STATUS_UNEXPECTED("GaussianBlur: load image failed.");
}
cv::Mat output_cv_mat;
cv::GaussianBlur(input_cv->mat(), output_cv_mat, cv::Size(kernel_x, kernel_y), static_cast<double>(sigma_x),
static_cast<double>(sigma_y));

View File

@ -47,9 +47,6 @@ RotateOp::RotateOp(float degrees, InterpolationMode resample, bool expand, float
Status RotateOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(
input->shape().Size() >= 2,
"Rotate: image shape " + std::to_string(input->shape().Size()) + " is not <H,W,C> or <H,W>.");
#ifndef ENABLE_ANDROID
return Rotate(input, output, center_x_, center_y_, degrees_, interpolation_, expand_, fill_r_, fill_g_, fill_b_);
#else

View File

@ -1201,13 +1201,20 @@ class RandomRotation(ImageTensorOperation):
@check_random_rotation
def __init__(self, degrees, resample=Inter.NEAREST, expand=False, center=None, fill_value=0):
if isinstance(degrees, numbers.Number):
if isinstance(degrees, (int, float)):
degrees = degrees % 360
if isinstance(degrees, (list, tuple)):
degrees = [degrees[0] % 360, degrees[1] % 360]
if degrees[0] > degrees[1]:
degrees[1] += 360
degrees = [-degrees, degrees]
elif isinstance(degrees, (list, tuple)):
if degrees[1] - degrees[0] >= 360:
degrees = [-180, 180]
else:
degrees = [degrees[0] % 360, degrees[1] % 360]
if degrees[0] > degrees[1]:
degrees[1] += 360
if center is None:
center = (-1, -1)
if isinstance(fill_value, int):
fill_value = tuple([fill_value] * 3)
self.degrees = degrees
self.resample = resample
self.expand = expand
@ -1215,14 +1222,8 @@ class RandomRotation(ImageTensorOperation):
self.fill_value = fill_value
def parse(self):
# pylint false positive
# pylint: disable=E1130
degrees = (-self.degrees, self.degrees) if isinstance(self.degrees, numbers.Number) else self.degrees
interpolation = DE_C_INTER_MODE[self.resample]
expand = self.expand
center = (-1, -1) if self.center is None else self.center
fill_value = tuple([self.fill_value] * 3) if isinstance(self.fill_value, int) else self.fill_value
return cde.RandomRotationOperation(degrees, interpolation, expand, center, fill_value)
return cde.RandomRotationOperation(self.degrees, DE_C_INTER_MODE[self.resample], self.expand, self.center,
self.fill_value)
class RandomSelectSubpolicy(ImageTensorOperation):
@ -1521,9 +1522,12 @@ class Rotate(ImageTensorOperation):
@check_rotate
def __init__(self, degrees, resample=Inter.NEAREST, expand=False, center=None, fill_value=0):
if isinstance(degrees, numbers.Number):
if isinstance(degrees, (int, float)):
degrees = degrees % 360
if center is None:
center = (-1, -1)
if isinstance(fill_value, int):
fill_value = tuple([fill_value] * 3)
self.degrees = degrees
self.resample = resample
self.expand = expand
@ -1531,14 +1535,8 @@ class Rotate(ImageTensorOperation):
self.fill_value = fill_value
def parse(self):
# pylint false positive
# pylint: disable=E1130
degrees = self.degrees
interpolation = DE_C_INTER_MODE[self.resample]
expand = self.expand
center = (-1, -1) if self.center is None else self.center
fill_value = tuple([self.fill_value] * 3) if isinstance(self.fill_value, int) else self.fill_value
return cde.RotateOperation(degrees, interpolation, expand, center, fill_value)
return cde.RotateOperation(self.degrees, DE_C_INTER_MODE[self.resample], self.expand, self.center,
self.fill_value)
class SoftDvppDecodeRandomCropResizeJpeg(ImageTensorOperation):

View File

@ -140,12 +140,12 @@ def check_padding(padding):
def check_degrees(degrees):
"""Check if the degrees is legal."""
type_check(degrees, (numbers.Number, list, tuple), "degrees")
if isinstance(degrees, numbers.Number):
type_check(degrees, (int, float, list, tuple), "degrees")
if isinstance(degrees, (int, float)):
check_pos_float32(degrees, "degrees")
elif isinstance(degrees, (list, tuple)):
if len(degrees) == 2:
type_check_list(degrees, (numbers.Number,), "degrees")
type_check_list(degrees, (int, float), "degrees")
for value in degrees:
check_float32(value, "degrees")
if degrees[0] > degrees[1]:
@ -420,6 +420,17 @@ def check_random_color_adjust(method):
return new_method
def check_resample_expand_center_fill_value_params(resample, expand, center, fill_value):
type_check(resample, (Inter,), "resample")
type_check(expand, (bool,), "expand")
if center is not None:
check_2tuple(center, "center")
for value in center:
type_check(value, (int, float), "center")
check_value(value, [-1, INT32_MAX], "center")
check_fill_value(fill_value)
def check_random_rotation(method):
"""Wrapper method to check the parameters of random rotation."""
@ -427,15 +438,7 @@ def check_random_rotation(method):
def new_method(self, *args, **kwargs):
[degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
check_degrees(degrees)
if resample is not None:
type_check(resample, (Inter,), "resample")
if expand is not None:
type_check(expand, (bool,), "expand")
if center is not None:
check_2tuple(center, "center")
if fill_value is not None:
check_fill_value(fill_value)
check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)
return method(self, *args, **kwargs)
@ -448,18 +451,9 @@ def check_rotate(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
type_check(degrees, (numbers.Number,), "degrees")
type_check(degrees, (float, int), "degrees")
check_float32(degrees, "degrees")
if resample is not None:
type_check(resample, (Inter,), "resample")
if expand is not None:
type_check(expand, (bool,), "expand")
if center is not None:
check_2tuple(center, "center")
if fill_value is not None:
check_fill_value(fill_value)
check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)
return method(self, *args, **kwargs)
@ -555,6 +549,7 @@ def check_rgb_to_bgr(method):
[is_hwc], _ = parse_user_args(method, *args, **kwargs)
type_check(is_hwc, (bool,), "is_hwc")
return method(self, *args, **kwargs)
return new_method

View File

@ -101,7 +101,7 @@ def test_rotate_exception():
_ = c_vision.Rotate("60")
except TypeError as e:
logger.info("Got an exception in Rotate: {}".format(str(e)))
assert "not of type [<class 'numbers.Number'>]" in str(e)
assert "not of type [<class 'float'>, <class 'int'>]" in str(e)
try:
_ = c_vision.Rotate(30, Inter.BICUBIC, False, (0, 0, 0))
except ValueError as e: