From a5f9b8f92ee5f4ed0604d8d3286b674ce1d3ada5 Mon Sep 17 00:00:00 2001 From: Mahdi Date: Fri, 21 Aug 2020 14:32:17 -0400 Subject: [PATCH] Added fix for MixUpBatch and CutMixBatch and for RandomAffine updated c color op descriptions --- .../ccsrc/minddata/dataset/api/transforms.cc | 4 +- .../dataset/kernels/image/cutmix_batch_op.cc | 17 ++-- .../dataset/kernels/image/mixup_batch_op.cc | 17 ++-- .../dataset/transforms/vision/c_transforms.py | 4 +- .../transforms/vision/py_transforms.py | 35 +------- .../transforms/vision/py_transforms_util.py | 34 ++++++-- .../dataset/transforms/vision/validators.py | 2 + tests/ut/cpp/dataset/c_api_transforms_test.cc | 54 +++++++++++- .../ut/python/dataset/test_cutmix_batch_op.py | 70 +++++++++++++++- tests/ut/python/dataset/test_mixup_op.py | 84 ++++++++++++++++++- tests/ut/python/dataset/test_random_affine.py | 20 +++++ 11 files changed, 285 insertions(+), 56 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index f339855dd81..4a92193df8a 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -382,7 +382,7 @@ CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format, : image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {} bool CutMixBatchOperation::ValidateParams() { - if (alpha_ < 0) { + if (alpha_ <= 0) { MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative."; return false; } @@ -434,7 +434,7 @@ std::shared_ptr HwcToChwOperation::Build() { return std::make_sharedGetItemAt(&first_value, {i, j})); - RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j})); - RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value)); + if (input.at(1)->type().IsSignedInt()) { + int64_t first_value, second_value; + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j})); + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j})); + RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value)); + } else { + uint64_t first_value, second_value; + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j})); + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j})); + RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value)); + } } } } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc index c3195a43f40..fb303c8794a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc @@ -38,7 +38,7 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) { // Check inputs if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) { - RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch"); + RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling MixUpBatch"); } if (label_shape.size() != 2) { RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch"); @@ -68,10 +68,17 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) { RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), &out_labels, DataType("float32"))); for (int64_t i = 0; i < label_shape[0]; i++) { for (int64_t j = 0; j < label_shape[1]; j++) { - uint64_t first_value, second_value; - RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j})); - RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j})); - RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value)); + if (input.at(1)->type().IsSignedInt()) { + int64_t first_value, second_value; + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j})); + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j})); + RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value)); + } else { + uint64_t first_value, second_value; + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j})); + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j})); + RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value)); + } } } diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 327abf9d170..7faf65c2857 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -231,7 +231,7 @@ class Normalize(cde.NormalizeOp): class RandomAffine(cde.RandomAffineOp): """ - Apply Random affine transformation to the input PIL image. + Apply Random affine transformation to the input image. Args: degrees (int or float or sequence): Range of the rotation degrees. @@ -681,12 +681,12 @@ class CenterCrop(cde.CenterCropOp): class RandomColor(cde.RandomColorOp): """ Adjust the color of the input image by a fixed or random degree. + This operation works only with 3-channel color images. Args: degrees (sequence): Range of random color adjustment degrees. It should be in (min, max) format. If min=max, then it is a single fixed magnitude operation (default=(0.1,1.9)). - Works with 3-channel color images. """ @check_positive_degrees diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py index 2ddf6bbc10a..f66050949de 100644 --- a/mindspore/dataset/transforms/vision/py_transforms.py +++ b/mindspore/dataset/transforms/vision/py_transforms.py @@ -1169,39 +1169,12 @@ class RandomAffine: Returns: img (PIL Image), Randomly affine transformed image. """ - # rotation - angle = random.uniform(self.degrees[0], self.degrees[1]) - - # translation - if self.translate is not None: - max_dx = self.translate[0] * img.size[0] - max_dy = self.translate[1] * img.size[1] - translations = (np.round(random.uniform(-max_dx, max_dx)), - np.round(random.uniform(-max_dy, max_dy))) - else: - translations = (0, 0) - - # scale - if self.scale_ranges is not None: - scale = random.uniform(self.scale_ranges[0], self.scale_ranges[1]) - else: - scale = 1.0 - - # shear - if self.shear is not None: - if len(self.shear) == 2: - shear = [random.uniform(self.shear[0], self.shear[1]), 0.] - elif len(self.shear) == 4: - shear = [random.uniform(self.shear[0], self.shear[1]), - random.uniform(self.shear[2], self.shear[3])] - else: - shear = 0.0 return util.random_affine(img, - angle, - translations, - scale, - shear, + self.degrees, + self.translate, + self.scale_ranges, + self.shear, self.resample, self.fill_value) diff --git a/mindspore/dataset/transforms/vision/py_transforms_util.py b/mindspore/dataset/transforms/vision/py_transforms_util.py index 89aa230039e..0acc07f6ede 100644 --- a/mindspore/dataset/transforms/vision/py_transforms_util.py +++ b/mindspore/dataset/transforms/vision/py_transforms_util.py @@ -1153,6 +1153,34 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0 if not is_pil(img): raise ValueError("Input image should be a Pillow image.") + # rotation + angle = random.uniform(angle[0], angle[1]) + + # translation + if translations is not None: + max_dx = translations[0] * img.size[0] + max_dy = translations[1] * img.size[1] + translations = (np.round(random.uniform(-max_dx, max_dx)), + np.round(random.uniform(-max_dy, max_dy))) + else: + translations = (0, 0) + + # scale + if scale is not None: + scale = random.uniform(scale[0], scale[1]) + else: + scale = 1.0 + + # shear + if shear is not None: + if len(shear) == 2: + shear = [random.uniform(shear[0], shear[1]), 0.] + elif len(shear) == 4: + shear = [random.uniform(shear[0], shear[1]), + random.uniform(shear[2], shear[3])] + else: + shear = 0.0 + output_size = img.size center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5) @@ -1416,7 +1444,6 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc): def random_color(img, degrees): - """ Adjust the color of the input PIL image by a random degree. @@ -1437,7 +1464,6 @@ def random_color(img, degrees): def random_sharpness(img, degrees): - """ Adjust the sharpness of the input PIL image by a random degree. @@ -1458,7 +1484,6 @@ def random_sharpness(img, degrees): def auto_contrast(img, cutoff, ignore): - """ Automatically maximize the contrast of the input PIL image. @@ -1479,7 +1504,6 @@ def auto_contrast(img, cutoff, ignore): def invert_color(img): - """ Invert colors of input PIL image. @@ -1498,7 +1522,6 @@ def invert_color(img): def equalize(img): - """ Equalize the histogram of input PIL image. @@ -1517,7 +1540,6 @@ def equalize(img): def uniform_augment(img, transforms, num_ops): - """ Uniformly select and apply a number of transforms sequentially from a list of transforms. Randomly assigns a probability to each transform for diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index f0eaba17936..df13ef69dcb 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -45,6 +45,7 @@ def check_cut_mix_batch_c(method): [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs) type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format") check_pos_float32(alpha) + check_positive(alpha, "alpha") check_value(prob, [0, 1], "prob") return method(self, *args, **kwargs) @@ -68,6 +69,7 @@ def check_mix_up_batch_c(method): @wraps(method) def new_method(self, *args, **kwargs): [alpha], _ = parse_user_args(method, *args, **kwargs) + check_positive(alpha, "alpha") check_pos_float32(alpha) return method(self, *args, **kwargs) diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index 9b68c9780ba..fdffbfd1f12 100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -191,11 +191,37 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) { ds = ds->Map({one_hot_op},{"label"}); EXPECT_NE(ds, nullptr); - std::shared_ptr cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5); + std::shared_ptr cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, + 1, -0.5); EXPECT_EQ(cutmix_batch_op, nullptr); } +TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) { + // Must fail because alpha can't be zero + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 5; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr one_hot_op = vision::OneHot(10); + EXPECT_NE(one_hot_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({one_hot_op},{"label"}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, + 0.0, 0.5); + EXPECT_EQ(cutmix_batch_op, nullptr); +} + TEST_F(MindDataTestPipeline, TestCutOut) { // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; @@ -365,6 +391,30 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) { EXPECT_EQ(mixup_batch_op, nullptr); } +TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) { + // This should fail because alpha can't be zero + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 5; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr one_hot_op = vision::OneHot(10); + EXPECT_NE(one_hot_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({one_hot_op}, {"label"}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr mixup_batch_op = vision::MixUpBatch(0.0); + EXPECT_EQ(mixup_batch_op, nullptr); +} + TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; @@ -384,7 +434,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { ds = ds->Map({one_hot_op}, {"label"}); EXPECT_NE(ds, nullptr); - std::shared_ptr mixup_batch_op = vision::MixUpBatch(0.5); + std::shared_ptr mixup_batch_op = vision::MixUpBatch(2.0); EXPECT_NE(mixup_batch_op, nullptr); // Create a Map operation on ds diff --git a/tests/ut/python/dataset/test_cutmix_batch_op.py b/tests/ut/python/dataset/test_cutmix_batch_op.py index d283c27d137..f00115ef570 100644 --- a/tests/ut/python/dataset/test_cutmix_batch_op.py +++ b/tests/ut/python/dataset/test_cutmix_batch_op.py @@ -26,6 +26,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se config_get_set_num_parallel_workers DATA_DIR = "../data/dataset/testCifar10Data" +DATA_DIR2 = "../data/dataset/testImageNetData2/train/" GENERATE_GOLDEN = False @@ -114,6 +115,53 @@ def test_cutmix_batch_success2(plot=False): logger.info("MSE= {}".format(str(np.mean(mse)))) +def test_cutmix_batch_success3(plot=False): + """ + Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDatasetV2 + """ + logger.info("test_cutmix_batch_success3") + + ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False) + decode_op = vision.Decode() + ds_original = ds_original.map(input_columns=["image"], operations=[decode_op]) + ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True) + + images_original = None + for idx, (image, _) in enumerate(ds_original): + if idx == 0: + images_original = image + else: + images_original = np.append(images_original, image, axis=0) + + # CutMix Images + data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False) + + decode_op = vision.Decode() + data1 = data1.map(input_columns=["image"], operations=[decode_op]) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + + cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) + data1 = data1.batch(4, pad_info={}, drop_remainder=True) + data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) + + images_cutmix = None + for idx, (image, _) in enumerate(data1): + if idx == 0: + images_cutmix = image + else: + images_cutmix = np.append(images_cutmix, image, axis=0) + if plot: + visualize_list(images_original, images_cutmix) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_cutmix[i], images_original[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + def test_cutmix_batch_nhwc_md5(): """ Test CutMixBatch on a batch of HWC images with MD5: @@ -185,7 +233,7 @@ def test_cutmix_batch_fail1(): images_cutmix = image else: images_cutmix = np.append(images_cutmix, image, axis=0) - error_message = "You must batch before calling CutMixBatch" + error_message = "You must make sure images are HWC or CHW and batch " assert error_message in str(error.value) @@ -322,9 +370,28 @@ def test_cutmix_batch_fail7(): assert error_message in str(error.value) +def test_cutmix_batch_fail8(): + """ + Test CutMixBatch Fail 8 + We expect this to fail because alpha is zero + """ + logger.info("test_cutmix_batch_fail8") + + # CutMixBatch Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + with pytest.raises(ValueError) as error: + vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0) + error_message = "Input is not within the required interval" + assert error_message in str(error.value) + + if __name__ == "__main__": test_cutmix_batch_success1(plot=True) test_cutmix_batch_success2(plot=True) + test_cutmix_batch_success3(plot=True) test_cutmix_batch_nchw_md5() test_cutmix_batch_nhwc_md5() test_cutmix_batch_fail1() @@ -334,3 +401,4 @@ if __name__ == "__main__": test_cutmix_batch_fail5() test_cutmix_batch_fail6() test_cutmix_batch_fail7() + test_cutmix_batch_fail8() diff --git a/tests/ut/python/dataset/test_mixup_op.py b/tests/ut/python/dataset/test_mixup_op.py index 9641a642fe7..84e9c02e2ff 100644 --- a/tests/ut/python/dataset/test_mixup_op.py +++ b/tests/ut/python/dataset/test_mixup_op.py @@ -25,6 +25,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se config_get_set_num_parallel_workers DATA_DIR = "../data/dataset/testCifar10Data" +DATA_DIR2 = "../data/dataset/testImageNetData2/train/" GENERATE_GOLDEN = False @@ -71,11 +72,59 @@ def test_mixup_batch_success1(plot=False): def test_mixup_batch_success2(plot=False): + """ + Test MixUpBatch op with specified alpha parameter on ImageFolderDatasetV2 + """ + logger.info("test_mixup_batch_success2") + + # Original Images + ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False) + decode_op = vision.Decode() + ds_original = ds_original.map(input_columns=["image"], operations=[decode_op]) + ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True) + + images_original = None + for idx, (image, _) in enumerate(ds_original): + if idx == 0: + images_original = image + else: + images_original = np.append(images_original, image, axis=0) + + # MixUp Images + data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False) + + decode_op = vision.Decode() + data1 = data1.map(input_columns=["image"], operations=[decode_op]) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + + mixup_batch_op = vision.MixUpBatch(2.0) + data1 = data1.batch(4, pad_info={}, drop_remainder=True) + data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op) + + images_mixup = None + for idx, (image, _) in enumerate(data1): + if idx == 0: + images_mixup = image + else: + images_mixup = np.append(images_mixup, image, axis=0) + if plot: + visualize_list(images_original, images_mixup) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_mixup[i], images_original[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + +def test_mixup_batch_success3(plot=False): """ Test MixUpBatch op without specified alpha parameter. Alpha parameter will be selected by default in this case """ - logger.info("test_mixup_batch_success2") + logger.info("test_mixup_batch_success3") # Original Images ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) @@ -169,7 +218,7 @@ def test_mixup_batch_fail1(): images_mixup = image else: images_mixup = np.append(images_mixup, image, axis=0) - error_message = "You must batch before calling MixUp" + error_message = "You must make sure images are HWC or CHW and batch" assert error_message in str(error.value) @@ -207,6 +256,7 @@ def test_mixup_batch_fail3(): Test MixUpBatch op We expect this to fail because label column is not passed to mixup_batch """ + logger.info("test_mixup_batch_fail3") # Original Images ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) ds_original = ds_original.batch(5, drop_remainder=True) @@ -237,11 +287,41 @@ def test_mixup_batch_fail3(): error_message = "Both images and labels columns are required" assert error_message in str(error.value) +def test_mixup_batch_fail4(): + """ + Test MixUpBatch Fail 2 + We expect this to fail because alpha is zero + """ + logger.info("test_mixup_batch_fail4") + + # Original Images + ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + ds_original = ds_original.batch(5) + + images_original = np.array([]) + for idx, (image, _) in enumerate(ds_original): + if idx == 0: + images_original = image + else: + images_original = np.append(images_original, image, axis=0) + + # MixUp Images + data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) + + one_hot_op = data_trans.OneHot(num_classes=10) + data1 = data1.map(input_columns=["label"], operations=one_hot_op) + with pytest.raises(ValueError) as error: + vision.MixUpBatch(0.0) + error_message = "Input is not within the required interval" + assert error_message in str(error.value) + if __name__ == "__main__": test_mixup_batch_success1(plot=True) test_mixup_batch_success2(plot=True) + test_mixup_batch_success3(plot=True) test_mixup_batch_md5() test_mixup_batch_fail1() test_mixup_batch_fail2() test_mixup_batch_fail3() + test_mixup_batch_fail4() diff --git a/tests/ut/python/dataset/test_random_affine.py b/tests/ut/python/dataset/test_random_affine.py index 87ee26c5b04..f2f1c42e99b 100644 --- a/tests/ut/python/dataset/test_random_affine.py +++ b/tests/ut/python/dataset/test_random_affine.py @@ -27,6 +27,7 @@ GENERATE_GOLDEN = False DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" +MNIST_DATA_DIR = "../data/dataset/testMnistData" def test_random_affine_op(plot=False): @@ -155,6 +156,24 @@ def test_random_affine_c_md5(): ds.config.set_num_parallel_workers((original_num_parallel_workers)) +def test_random_affine_py_exception_non_pil_images(): + """ + Test RandomAffine: input img is ndarray and not PIL, expected to raise TypeError + """ + logger.info("test_random_affine_exception_negative_degrees") + dataset = ds.MnistDataset(MNIST_DATA_DIR, num_parallel_workers=3) + try: + transform = py_vision.ComposeOp([py_vision.ToTensor(), + py_vision.RandomAffine(degrees=(15, 15))]) + dataset = dataset.map(input_columns=["image"], operations=transform(), num_parallel_workers=3, + python_multiprocessing=True) + for _ in dataset.create_dict_iterator(): + break + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Pillow image" in str(e) + + def test_random_affine_exception_negative_degrees(): """ Test RandomAffine: input degrees in negative, expected to raise ValueError @@ -289,6 +308,7 @@ if __name__ == "__main__": test_random_affine_op_c(plot=True) test_random_affine_md5() test_random_affine_c_md5() + test_random_affine_py_exception_non_pil_images() test_random_affine_exception_negative_degrees() test_random_affine_exception_translation_range() test_random_affine_exception_scale_value()