forked from mindspore-Ecosystem/mindspore
!5105 Check input image type for random posterize
Merge pull request !5105 from luoyang/c-api
This commit is contained in:
commit
600263ccfe
|
@ -40,6 +40,8 @@ Status PosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
|
||||||
}
|
}
|
||||||
cv::Mat in_image = input_cv->mat();
|
cv::Mat in_image = input_cv->mat();
|
||||||
cv::Mat output_img;
|
cv::Mat output_img;
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(in_image.depth() == CV_8U || in_image.depth() == CV_8S,
|
||||||
|
"Input image data type can not be float, but got " + input->type().ToString());
|
||||||
cv::LUT(in_image, lut_vector, output_img);
|
cv::LUT(in_image, lut_vector, output_img);
|
||||||
std::shared_ptr<CVTensor> result_tensor;
|
std::shared_ptr<CVTensor> result_tensor;
|
||||||
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor));
|
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor));
|
||||||
|
|
|
@ -142,8 +142,29 @@ def test_random_posterize_exception_bit():
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2."
|
assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2."
|
||||||
|
|
||||||
|
def test_rescale_with_random_posterize():
|
||||||
|
"""
|
||||||
|
Test RandomPosterize: only support CV_8S/CV_8U
|
||||||
|
"""
|
||||||
|
logger.info("test_rescale_with_random_posterize")
|
||||||
|
|
||||||
|
DATA_DIR_10 = "../data/dataset/testCifar10Data"
|
||||||
|
dataset = ds.Cifar10Dataset(DATA_DIR_10)
|
||||||
|
|
||||||
|
rescale_op = c_vision.Rescale((1.0 / 255.0), 0.0)
|
||||||
|
dataset = dataset.map(input_columns=["image"], operations=rescale_op)
|
||||||
|
|
||||||
|
random_posterize_op = c_vision.RandomPosterize((4, 8))
|
||||||
|
dataset = dataset.map(input_columns=["image"], operations=random_posterize_op, num_parallel_workers=1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = dataset.output_shapes()
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
|
assert "Input image data type can not be float" in str(e)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
skip_test_random_posterize_op_c(plot=True)
|
skip_test_random_posterize_op_c(plot=True)
|
||||||
skip_test_random_posterize_op_fixed_point_c(plot=True)
|
skip_test_random_posterize_op_fixed_point_c(plot=True)
|
||||||
test_random_posterize_exception_bit()
|
test_random_posterize_exception_bit()
|
||||||
|
test_rescale_with_random_posterize()
|
||||||
|
|
Loading…
Reference in New Issue