diff --git a/mindspore/python/mindspore/dataset/vision/py_transforms_util.py b/mindspore/python/mindspore/dataset/vision/py_transforms_util.py index a286acd5682..21326bbd25f 100644 --- a/mindspore/python/mindspore/dataset/vision/py_transforms_util.py +++ b/mindspore/python/mindspore/dataset/vision/py_transforms_util.py @@ -964,6 +964,8 @@ def pad(img, padding, fill_value, padding_mode): image = ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value) image.putpalette(palette) return image + if isinstance(fill_value, tuple) and (img.mode == 'L' or img.mode == '1'): + fill_value = (fill_value[0],) return ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value) if img.mode == 'P': diff --git a/tests/ut/python/dataset/test_random_crop.py b/tests/ut/python/dataset/test_random_crop.py index 19e3d0cf684..f0a12708409 100644 --- a/tests/ut/python/dataset/test_random_crop.py +++ b/tests/ut/python/dataset/test_random_crop.py @@ -17,6 +17,7 @@ Testing RandomCrop op in DE """ import numpy as np import pytest +from PIL import Image import mindspore.dataset.transforms as ops import mindspore.dataset.vision as vision @@ -550,6 +551,18 @@ def test_random_crop_09(): assert error_msg in str(error_info.value) +def test_random_crop_10(): + """ + Feature: RandomCrop + Description: Test Py RandomCrop with grayscale/binary image + Expectation: The dataset is processed as expected + """ + path = "../data/dataset/apple.jpg" + image_list = [Image.open(path), Image.open(path).convert('1'), Image.open(path).convert('L')] + for image in image_list: + _ = vision.RandomCrop((28))(image) + + def test_random_crop_comp(plot=False): """ Feature: RandomCrop op @@ -667,6 +680,7 @@ if __name__ == "__main__": test_random_crop_07_py() test_random_crop_08_py() test_random_crop_09() + test_random_crop_10() test_random_crop_op_c(True) test_random_crop_op_py(True) test_random_crop_comp(True) diff --git a/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py b/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py index f71bb035f9b..f9733d470fc 100644 --- a/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py +++ b/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py @@ -67,7 +67,10 @@ def test_imagenet_to_mindrecord(fixture_file): for i in range(PARTITION_NUMBER): assert os.path.exists(file_name + str(i)) assert os.path.exists(file_name + str(i) + ".db") - read(file_name + "0") + read([file_name + "0", + file_name + "1", + file_name + "2", + file_name + "3"]) def test_imagenet_to_mindrecord_default_partition_number(fixture_file): """