!3102 [MD] Add additional parameter checks in RandomCropWithBBoxOp and RandomResizeWithBbox

Merge pull request !3102 from nhussain/random_crop_fixes
This commit is contained in:
mindspore-ci-bot 2020-07-17 22:33:50 +08:00 committed by Gitee
commit a84092e89d
9 changed files with 87 additions and 19 deletions

View File

@ -189,8 +189,10 @@ def type_check_list(args, types, arg_names):
Exception: when the type is not correct, otherwise nothing. Exception: when the type is not correct, otherwise nothing.
""" """
type_check(args, (list, tuple,), arg_names) type_check(args, (list, tuple,), arg_names)
if len(args) != len(arg_names): if len(args) != len(arg_names) and not isinstance(arg_names, str):
raise ValueError("List of arguments is not the same length as argument_names.") raise ValueError("List of arguments is not the same length as argument_names.")
if isinstance(arg_names, str):
arg_names = ["{0}[{1}]".format(arg_names, i) for i in range(len(args))]
for arg, arg_name in zip(args, arg_names): for arg, arg_name in zip(args, arg_names):
type_check(arg, types, arg_name) type_check(arg, types, arg_name)

View File

@ -686,8 +686,7 @@ def check_concat(method):
[ds], _ = parse_user_args(method, *args, **kwargs) [ds], _ = parse_user_args(method, *args, **kwargs)
type_check(ds, (list, datasets.Dataset), "datasets") type_check(ds, (list, datasets.Dataset), "datasets")
if isinstance(ds, list): if isinstance(ds, list):
dataset_names = ["dataset[{0}]".format(i) for i in range(len(ds)) if isinstance(ds, list)] type_check_list(ds, (datasets.Dataset,), "dataset")
type_check_list(ds, (datasets.Dataset,), dataset_names)
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method
@ -751,8 +750,7 @@ def check_add_column(method):
if shape is not None: if shape is not None:
type_check(shape, (list,), "shape") type_check(shape, (list,), "shape")
shape_names = ["shape[{0}]".format(i) for i in range(len(shape))] type_check_list(shape, (int,), "shape")
type_check_list(shape, (int,), shape_names)
return method(self, *args, **kwargs) return method(self, *args, **kwargs)

View File

@ -297,8 +297,7 @@ def check_from_dataset(method):
if columns is not None: if columns is not None:
if not isinstance(columns, list): if not isinstance(columns, list):
columns = [columns] columns = [columns]
col_names = ["col_{0}".format(i) for i in range(len(columns))] type_check_list(columns, (str,), "col")
type_check_list(columns, (str,), col_names)
if freq_range is not None: if freq_range is not None:
type_check(freq_range, (tuple,), "freq_range") type_check(freq_range, (tuple,), "freq_range")

View File

@ -78,6 +78,8 @@ def check_fill_value(fill_value):
def check_padding(padding): def check_padding(padding):
"""Parsing the padding arguments and check if it is legal.""" """Parsing the padding arguments and check if it is legal."""
type_check(padding, (tuple, list, numbers.Number), "padding") type_check(padding, (tuple, list, numbers.Number), "padding")
if isinstance(padding, numbers.Number):
check_value(padding, (0, INT32_MAX), "padding")
if isinstance(padding, (tuple, list)): if isinstance(padding, (tuple, list)):
if len(padding) not in (2, 4): if len(padding) not in (2, 4):
raise ValueError("The size of the padding list or tuple should be 2 or 4.") raise ValueError("The size of the padding list or tuple should be 2 or 4.")
@ -163,10 +165,13 @@ def check_random_resize_crop(method):
check_crop_size(size) check_crop_size(size)
if scale is not None: if scale is not None:
type_check(scale, (tuple,), "scale")
type_check_list(scale, (float, int), "scale")
check_range(scale, [0, FLOAT_MAX_INTEGER]) check_range(scale, [0, FLOAT_MAX_INTEGER])
if ratio is not None: if ratio is not None:
type_check(ratio, (tuple,), "ratio")
type_check_list(ratio, (float, int), "ratio")
check_range(ratio, [0, FLOAT_MAX_INTEGER]) check_range(ratio, [0, FLOAT_MAX_INTEGER])
check_positive(ratio[0], "ratio[0]")
if interpolation is not None: if interpolation is not None:
type_check(interpolation, (Inter,), "interpolation") type_check(interpolation, (Inter,), "interpolation")
if max_attempts is not None: if max_attempts is not None:
@ -450,8 +455,7 @@ def check_random_affine(method):
if translate is not None: if translate is not None:
if type_check(translate, (list, tuple), "translate"): if type_check(translate, (list, tuple), "translate"):
translate_names = ["translate_{0}".format(i) for i in range(len(translate))] type_check_list(translate, (int, float), "translate")
type_check_list(translate, (int, float), translate_names)
if len(translate) != 2: if len(translate) != 2:
raise TypeError("translate should be a list or tuple of length 2.") raise TypeError("translate should be a list or tuple of length 2.")
for i, t in enumerate(translate): for i, t in enumerate(translate):
@ -508,8 +512,7 @@ def check_uniform_augment_cpp(method):
if num_ops > len(operations): if num_ops > len(operations):
raise ValueError("num_ops is greater than operations list size") raise ValueError("num_ops is greater than operations list size")
tensor_ops = ["tensor_op_{0}".format(i) for i in range(len(operations))] type_check_list(operations, (TensorOp,), "tensor_ops")
type_check_list(operations, (TensorOp,), tensor_ops)
return method(self, *args, **kwargs) return method(self, *args, **kwargs)

View File

@ -134,7 +134,7 @@ def test_from_dataset_exceptions():
test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.") test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.")
test_config("text", (2, 3), 1.2345, test_config("text", (2, 3), 1.2345,
"Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)") "Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)")
test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)") test_config(23, (2, 3), 1.2345, "Argument col[0] with value 23 is not of type (<class 'str'>,)")
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)") test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")
test_config("text", (2, 3), 0, "top_k must be greater than 0") test_config("text", (2, 3), 0, "top_k must be greater than 0")
test_config([123], (2, 3), -1, "top_k must be greater than 0") test_config([123], (2, 3), -1, "top_k must be greater than 0")

View File

@ -332,11 +332,37 @@ def test_random_crop_and_resize_comp(plot=False):
image_c_cropped.append(c_image) image_c_cropped.append(c_image)
image_py_cropped.append(py_image) image_py_cropped.append(py_image)
mse = diff_mse(c_image, py_image) mse = diff_mse(c_image, py_image)
assert mse < 0.02 # rounding error assert mse < 0.02 # rounding error
if plot: if plot:
visualize_list(image_c_cropped, image_py_cropped, visualize_mode=2) visualize_list(image_c_cropped, image_py_cropped, visualize_mode=2)
def test_random_crop_and_resize_06():
"""
Test RandomCropAndResize with c_transforms: invalid values for scale,
expected to raise ValueError
"""
logger.info("test_random_crop_and_resize_05_c")
# Generate dataset
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode()
try:
random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), scale="", ratio=(1, 0.5))
data = data.map(input_columns=["image"], operations=decode_op)
data.map(input_columns=["image"], operations=random_crop_and_resize_op)
except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Argument scale with value \"\" is not of type (<class 'tuple'>,)" in str(e)
try:
random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), scale=(1, "2"), ratio=(1, 0.5))
data = data.map(input_columns=["image"], operations=decode_op)
data.map(input_columns=["image"], operations=random_crop_and_resize_op)
except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Argument scale[1] with value 2 is not of type (<class 'float'>, <class 'int'>)." in str(e)
if __name__ == "__main__": if __name__ == "__main__":
test_random_crop_and_resize_op_c(True) test_random_crop_and_resize_op_c(True)
test_random_crop_and_resize_op_py(True) test_random_crop_and_resize_op_py(True)
@ -347,4 +373,5 @@ if __name__ == "__main__":
test_random_crop_and_resize_04_py() test_random_crop_and_resize_04_py()
test_random_crop_and_resize_05_c() test_random_crop_and_resize_05_c()
test_random_crop_and_resize_05_py() test_random_crop_and_resize_05_py()
test_random_crop_and_resize_06()
test_random_crop_and_resize_comp(True) test_random_crop_and_resize_comp(True)

View File

@ -178,13 +178,15 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"], output_columns=["image", "annotation"],
columns_order=["image", "annotation"], columns_order=["image", "annotation"],
operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) operations=[lambda img, bboxes: (
img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))])
# Test Op added to list of Operations here # Test Op added to list of Operations here
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"], output_columns=["image", "annotation"],
columns_order=["image", "annotation"], columns_order=["image", "annotation"],
operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) operations=[lambda img, bboxes: (
img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op])
unaugSamp, augSamp = [], [] unaugSamp, augSamp = [], []
@ -239,6 +241,29 @@ def test_random_crop_with_bbox_op_bad_c():
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
def test_random_crop_with_bbox_op_negative_padding():
"""
Test RandomCropWithBBox Op on invalid constructor parameters, expected to raise ValueError
"""
logger.info("test_random_crop_with_bbox_op_invalid_c")
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
try:
test_op = c_vision.RandomCropWithBBox([512, 512], padding=-1)
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
for _ in dataVoc2.create_dict_iterator():
break
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input padding is not within the required interval of (0 to 2147483647)." in str(err)
if __name__ == "__main__": if __name__ == "__main__":
test_random_crop_with_bbox_op_c(plot_vis=True) test_random_crop_with_bbox_op_c(plot_vis=True)
test_random_crop_with_bbox_op_coco_c(plot_vis=True) test_random_crop_with_bbox_op_coco_c(plot_vis=True)
@ -247,3 +272,4 @@ if __name__ == "__main__":
test_random_crop_with_bbox_op_edge_c(plot_vis=True) test_random_crop_with_bbox_op_edge_c(plot_vis=True)
test_random_crop_with_bbox_op_invalid_c() test_random_crop_with_bbox_op_invalid_c()
test_random_crop_with_bbox_op_bad_c() test_random_crop_with_bbox_op_bad_c()
test_random_crop_with_bbox_op_negative_padding()

View File

@ -16,9 +16,10 @@
Testing the resize with bounding boxes op in DE Testing the resize with bounding boxes op in DE
""" """
import numpy as np import numpy as np
import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger from mindspore import log as logger
from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
save_and_check_md5 save_and_check_md5
@ -172,6 +173,18 @@ def test_resize_with_bbox_op_bad_c():
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
def test_resize_with_bbox_op_params_outside_of_interpolation_dict():
"""
Test passing in a invalid key for interpolation
"""
logger.info("test_resize_with_bbox_op_params_outside_of_interpolation_dict")
size = (500, 500)
more_para = None
with pytest.raises(KeyError, match="None"):
c_vision.ResizeWithBBox(size, more_para)
if __name__ == "__main__": if __name__ == "__main__":
test_resize_with_bbox_op_voc_c(plot_vis=False) test_resize_with_bbox_op_voc_c(plot_vis=False)
test_resize_with_bbox_op_coco_c(plot_vis=False) test_resize_with_bbox_op_coco_c(plot_vis=False)

View File

@ -166,10 +166,10 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
F.Invert()] F.Invert()]
with pytest.raises(TypeError) as e: with pytest.raises(TypeError) as e:
_ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "Argument tensor_op_5 with value" \ assert "Argument tensor_ops[5] with value" \
" <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e.value) " <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e.value)
assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e.value) assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e.value)