forked from mindspore-Ecosystem/mindspore
!9825 [MD] fix pad, normalize in c_trasnform
From: @luoyang42 Reviewed-by: @heleiwang,@liucunwei Signed-off-by: @liucunwei
This commit is contained in:
commit
ad6507d88a
|
@ -30,6 +30,11 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
|
||||||
}))
|
}))
|
||||||
.def("__call__", [](Execute &self, std::shared_ptr<Tensor> in) {
|
.def("__call__", [](Execute &self, std::shared_ptr<Tensor> in) {
|
||||||
std::shared_ptr<Tensor> out = self(in);
|
std::shared_ptr<Tensor> out = self(in);
|
||||||
|
if (out == nullptr) {
|
||||||
|
THROW_IF_ERROR([]() {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Failed to execute op in eager mode, please check ERROR log above.");
|
||||||
|
}());
|
||||||
|
}
|
||||||
return out;
|
return out;
|
||||||
});
|
});
|
||||||
}));
|
}));
|
||||||
|
|
|
@ -73,7 +73,7 @@ Status Flip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, int
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in flip op.");
|
RETURN_STATUS_UNEXPECTED("Error in flip op: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor, the input data is null");
|
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor, the input data is null");
|
||||||
|
@ -118,7 +118,7 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in image resize.");
|
RETURN_STATUS_UNEXPECTED("Error in image resize: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ Status DecodeCv(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in image Decode");
|
RETURN_STATUS_UNEXPECTED("Error in image Decode: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -350,7 +350,7 @@ Status Rescale(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *ou
|
||||||
input_image.convertTo(output_cv->mat(), CV_32F, rescale, shift);
|
input_image.convertTo(output_cv->mat(), CV_32F, rescale, shift);
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in image rescale");
|
RETURN_STATUS_UNEXPECTED("Error in image rescale: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -382,7 +382,7 @@ Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Unexpected error in crop.");
|
RETURN_STATUS_UNEXPECTED("Unexpected error in crop: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -417,7 +417,7 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output)
|
||||||
*output = std::move(output_cv);
|
*output = std::move(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Unexpected error in ChannelSwap.");
|
RETURN_STATUS_UNEXPECTED("Unexpected error in ChannelSwap: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -503,7 +503,7 @@ Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Unexpected error in ChangeMode.");
|
RETURN_STATUS_UNEXPECTED("Unexpected error in ChangeMode: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -537,7 +537,7 @@ Status CropAndResize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
|
||||||
*output = std::static_pointer_cast<Tensor>(cvt_out);
|
*output = std::static_pointer_cast<Tensor>(cvt_out);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Unexpected error in CropAndResize.");
|
RETURN_STATUS_UNEXPECTED("Unexpected error in CropAndResize: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -584,7 +584,7 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
||||||
}
|
}
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in image rotation");
|
RETURN_STATUS_UNEXPECTED("Error in image rotation: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -626,7 +626,7 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Unexpected error in Normalize");
|
RETURN_STATUS_UNEXPECTED("Unexpected error in Normalize: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -646,7 +646,7 @@ Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
|
||||||
output_cv->mat() = input_img * alpha;
|
output_cv->mat() = input_img * alpha;
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in adjust brightness");
|
RETURN_STATUS_UNEXPECTED("Error in adjust brightness: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -673,7 +673,7 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens
|
||||||
output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha;
|
output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha;
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in adjust contrast");
|
RETURN_STATUS_UNEXPECTED("Error in adjust contrast: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -751,7 +751,7 @@ Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
|
||||||
(*output) = std::static_pointer_cast<Tensor>(output_cv);
|
(*output) = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
(*output)->Reshape(input->shape());
|
(*output)->Reshape(input->shape());
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in auto contrast");
|
RETURN_STATUS_UNEXPECTED("Error in auto contrast: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -776,7 +776,7 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
|
||||||
output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha;
|
output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha;
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in adjust saturation");
|
RETURN_STATUS_UNEXPECTED("Error in adjust saturation: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -812,7 +812,7 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
||||||
cv::cvtColor(output_img, output_cv->mat(), CV_HSV2RGB_FULL);
|
cv::cvtColor(output_img, output_cv->mat(), CV_HSV2RGB_FULL);
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in adjust hue");
|
RETURN_STATUS_UNEXPECTED("Error in adjust hue: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -853,7 +853,7 @@ Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
||||||
(*output) = std::static_pointer_cast<Tensor>(output_cv);
|
(*output) = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
(*output)->Reshape(input->shape());
|
(*output)->Reshape(input->shape());
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in equalize.");
|
RETURN_STATUS_UNEXPECTED("Error in equalize: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -915,7 +915,7 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp
|
||||||
*output = std::static_pointer_cast<Tensor>(input);
|
*output = std::static_pointer_cast<Tensor>(input);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Error in erasing");
|
RETURN_STATUS_UNEXPECTED("Error in erasing: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -943,7 +943,7 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Unexpected error in pad");
|
RETURN_STATUS_UNEXPECTED("Unexpected error in pad: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -962,7 +962,7 @@ Status RgbaToRgb(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Unexpected error in RgbaToRgb.");
|
RETURN_STATUS_UNEXPECTED("Unexpected error in RgbaToRgb: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -981,7 +981,7 @@ Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
||||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (const cv::Exception &e) {
|
} catch (const cv::Exception &e) {
|
||||||
RETURN_STATUS_UNEXPECTED("Unexpected error in RgbaToBgr.");
|
RETURN_STATUS_UNEXPECTED("Unexpected error in RgbaToBgr: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -112,6 +112,14 @@ def check_pos_int64(value, arg_name=""):
|
||||||
check_value(value, [UINT64_MIN, INT64_MAX])
|
check_value(value, [UINT64_MIN, INT64_MAX])
|
||||||
|
|
||||||
|
|
||||||
|
def check_float32(value, arg_name=""):
|
||||||
|
check_value(value, [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER], arg_name)
|
||||||
|
|
||||||
|
|
||||||
|
def check_float64(value, arg_name=""):
|
||||||
|
check_value(value, [DOUBLE_MIN_INTEGER, DOUBLE_MAX_INTEGER], arg_name)
|
||||||
|
|
||||||
|
|
||||||
def check_pos_float32(value, arg_name=""):
|
def check_pos_float32(value, arg_name=""):
|
||||||
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name)
|
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name)
|
||||||
|
|
||||||
|
|
|
@ -263,7 +263,7 @@ class Normalize(cde.NormalizeOp):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mean (sequence): List or tuple of mean values for each channel, with respect to channel order.
|
mean (sequence): List or tuple of mean values for each channel, with respect to channel order.
|
||||||
The mean values must be in range (0.0, 255.0].
|
The mean values must be in range [0.0, 255.0].
|
||||||
std (sequence): List or tuple of standard deviations for each channel, with respect to channel order.
|
std (sequence): List or tuple of standard deviations for each channel, with respect to channel order.
|
||||||
The standard deviation values must be in range (0.0, 255.0].
|
The standard deviation values must be in range (0.0, 255.0].
|
||||||
|
|
||||||
|
@ -278,6 +278,10 @@ class Normalize(cde.NormalizeOp):
|
||||||
|
|
||||||
@check_normalize_c
|
@check_normalize_c
|
||||||
def __init__(self, mean, std):
|
def __init__(self, mean, std):
|
||||||
|
if len(mean) == 1:
|
||||||
|
mean = [mean[0]] * 3
|
||||||
|
if len(std) == 1:
|
||||||
|
std = [std[0]] * 3
|
||||||
self.mean = mean
|
self.mean = mean
|
||||||
self.std = std
|
self.std = std
|
||||||
super().__init__(*mean, *std)
|
super().__init__(*mean, *std)
|
||||||
|
@ -1193,9 +1197,11 @@ class Pad(cde.PadOp):
|
||||||
with the first value and (right and bottom) with the second value.
|
with the first value and (right and bottom) with the second value.
|
||||||
If 4 values are provided as a list or tuple,
|
If 4 values are provided as a list or tuple,
|
||||||
it pads the left, top, right and bottom respectively.
|
it pads the left, top, right and bottom respectively.
|
||||||
fill_value (Union[int, tuple], optional): The pixel intensity of the borders if
|
fill_value (Union[int, tuple], optional): The pixel intensity of the borders, only valid for
|
||||||
the padding_mode is Border.CONSTANT (default=0). If it is a 3-tuple, it is used to
|
padding_mode Border.CONSTANT (default=0).
|
||||||
fill R, G, B channels respectively.
|
If it is an integer, it is used for all RGB channels.
|
||||||
|
If it is a 3-tuple, it is used to fill R, G, B channels respectively.
|
||||||
|
The fill_value values must be in range [0, 255].
|
||||||
padding_mode (Border mode, optional): The method of padding (default=Border.CONSTANT). Can be any of
|
padding_mode (Border mode, optional): The method of padding (default=Border.CONSTANT). Can be any of
|
||||||
[Border.CONSTANT, Border.EDGE, Border.REFLECT, Border.SYMMETRIC].
|
[Border.CONSTANT, Border.EDGE, Border.REFLECT, Border.SYMMETRIC].
|
||||||
|
|
||||||
|
|
|
@ -199,7 +199,7 @@ class Normalize:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mean (sequence): List or tuple of mean values for each channel, with respect to channel order.
|
mean (sequence): List or tuple of mean values for each channel, with respect to channel order.
|
||||||
The mean values must be in the range (0.0, 1.0].
|
The mean values must be in the range [0.0, 1.0].
|
||||||
std (sequence): List or tuple of standard deviations for each channel, w.r.t. channel order.
|
std (sequence): List or tuple of standard deviations for each channel, w.r.t. channel order.
|
||||||
The standard deviation values must be in the range (0.0, 1.0].
|
The standard deviation values must be in the range (0.0, 1.0].
|
||||||
|
|
||||||
|
@ -783,8 +783,9 @@ class Pad:
|
||||||
with the first value and the right and bottom with the second value.
|
with the first value and the right and bottom with the second value.
|
||||||
If 4 values are provided as a list or tuple,
|
If 4 values are provided as a list or tuple,
|
||||||
pad the left, top, right and bottom respectively.
|
pad the left, top, right and bottom respectively.
|
||||||
fill_value (Union[int, tuple], optional): Filling value for the pixel intensity
|
fill_value (Union[int, tuple], optional): The pixel intensity of the borders, only valid for
|
||||||
of the borders if the padding_mode is Border.CONSTANT (Default=0).
|
padding_mode Border.CONSTANT (default=0).
|
||||||
|
If it is an integer, it is used for all RGB channels.
|
||||||
If it is a 3-tuple, it is used to fill R, G, B channels respectively.
|
If it is a 3-tuple, it is used to fill R, G, B channels respectively.
|
||||||
padding_mode (Border mode, optional): The method of padding (default=Border.CONSTANT).
|
padding_mode (Border mode, optional): The method of padding (default=Border.CONSTANT).
|
||||||
It can be any of [Border.CONSTANT, Border.EDGE, Border.REFLECT, Border.SYMMETRIC].
|
It can be any of [Border.CONSTANT, Border.EDGE, Border.REFLECT, Border.SYMMETRIC].
|
||||||
|
|
|
@ -20,7 +20,7 @@ import numpy as np
|
||||||
from mindspore._c_dataengine import TensorOp
|
from mindspore._c_dataengine import TensorOp
|
||||||
|
|
||||||
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
|
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
|
||||||
check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
|
check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
|
||||||
check_tensor_op, UINT8_MAX, check_value_normalize_std
|
check_tensor_op, UINT8_MAX, check_value_normalize_std
|
||||||
from .utils import Inter, Border, ImageBatchFormat
|
from .utils import Inter, Border, ImageBatchFormat
|
||||||
|
|
||||||
|
@ -78,12 +78,14 @@ def check_mix_up_batch_c(method):
|
||||||
|
|
||||||
|
|
||||||
def check_normalize_c_param(mean, std):
|
def check_normalize_c_param(mean, std):
|
||||||
|
type_check(mean, (list, tuple), "mean")
|
||||||
|
type_check(std, (list, tuple), "std")
|
||||||
if len(mean) != len(std):
|
if len(mean) != len(std):
|
||||||
raise ValueError("Length of mean and std must be equal.")
|
raise ValueError("Length of mean and std must be equal.")
|
||||||
for mean_value in mean:
|
for mean_value in mean:
|
||||||
check_pos_float32(mean_value)
|
check_value(mean_value, [0, 255], "mean_value")
|
||||||
for std_value in std:
|
for std_value in std:
|
||||||
check_pos_float32(std_value)
|
check_value_normalize_std(std_value, [0, 255], "std_value")
|
||||||
|
|
||||||
|
|
||||||
def check_normalize_py_param(mean, std):
|
def check_normalize_py_param(mean, std):
|
||||||
|
@ -548,8 +550,10 @@ def check_rescale(method):
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
[rescale, shift], _ = parse_user_args(method, *args, **kwargs)
|
[rescale, shift], _ = parse_user_args(method, *args, **kwargs)
|
||||||
check_pos_float32(rescale)
|
type_check(rescale, (numbers.Number,), "rescale")
|
||||||
type_check(shift, (numbers.Number,), "shift")
|
type_check(shift, (numbers.Number,), "shift")
|
||||||
|
check_float32(rescale)
|
||||||
|
check_float32(shift)
|
||||||
|
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
|
@ -92,6 +92,7 @@ def test_eager_exceptions():
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
assert "Input should be NumPy or PIL image" in str(e)
|
assert "Input should be NumPy or PIL image" in str(e)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_eager_resize()
|
test_eager_resize()
|
||||||
test_eager_rescale()
|
test_eager_rescale()
|
||||||
|
|
|
@ -245,6 +245,24 @@ def test_normalize_exception_unequal_size_c():
|
||||||
assert str(e) == "Length of mean and std must be equal."
|
assert str(e) == "Length of mean and std must be equal."
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_exception_out_of_range_c():
|
||||||
|
"""
|
||||||
|
Test Normalize in c transformation: mean, std out of range
|
||||||
|
expected to raise ValueError
|
||||||
|
"""
|
||||||
|
logger.info("test_normalize_exception_out_of_range_c")
|
||||||
|
try:
|
||||||
|
_ = c_vision.Normalize([256, 250, 125], [50, 75, 75])
|
||||||
|
except ValueError as e:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
|
assert "not within the required interval" in str(e)
|
||||||
|
try:
|
||||||
|
_ = c_vision.Normalize([255, 250, 125], [0, 75, 75])
|
||||||
|
except ValueError as e:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
|
assert "not within the required interval" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_exception_unequal_size_py():
|
def test_normalize_exception_unequal_size_py():
|
||||||
"""
|
"""
|
||||||
Test Normalize in python transformation: len(mean) != len(std)
|
Test Normalize in python transformation: len(mean) != len(std)
|
||||||
|
|
Loading…
Reference in New Issue