From 8bb4449fa872b590ec230020197b3cec6f987d23 Mon Sep 17 00:00:00 2001 From: jonyguo Date: Mon, 17 Aug 2020 11:50:58 +0800 Subject: [PATCH] add testcase for padded dataset with decode op --- tests/ut/python/dataset/test_paddeddataset.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/ut/python/dataset/test_paddeddataset.py b/tests/ut/python/dataset/test_paddeddataset.py index df87599d59b..a2a24d03dfc 100644 --- a/tests/ut/python/dataset/test_paddeddataset.py +++ b/tests/ut/python/dataset/test_paddeddataset.py @@ -1,8 +1,12 @@ +from io import BytesIO import os import numpy as np import pytest import mindspore.dataset as ds from mindspore.mindrecord import FileWriter +import mindspore.dataset.transforms.vision.c_transforms as V_C +from PIL import Image + FILES_NUM = 4 CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord" CV_DIR_NAME = "../data/mindrecord/testImageNetData" @@ -197,7 +201,7 @@ def test_raise_error(): ds3.use_sampler(testsampler) assert excinfo.type == 'ValueError' -def test_imagefolden_padded(): +def test_imagefolder_padded(): DATA_DIR = "../data/dataset/testPK/data" data = ds.ImageFolderDatasetV2(DATA_DIR) @@ -220,6 +224,32 @@ def test_imagefolden_padded(): assert verify_list[8] == 1 assert verify_list[9] == 6 +def test_imagefolder_padded_with_decode(): + DATA_DIR = "../data/dataset/testPK/data" + data = ds.ImageFolderDatasetV2(DATA_DIR) + + white_io = BytesIO() + Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG') + padded_sample = {} + padded_sample['image'] = np.array(bytearray(white_io), dtype='uint8') + padded_sample['label'] = np.array(-1, np.int32) + + white_samples = [padded_sample, padded_sample, padded_sample, padded_sample] + data2 = ds.PaddedDataset(white_samples) + data3 = data + data2 + + num_shards = 5 + count = 0 + for shard_id in range(num_shards): + testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None) + data3.use_sampler(testsampler) + data3.map(input_columns="image", operations=V_C.Decode()) + for ele in data3.create_dict_iterator(): + print("label: {}".format(ele['label'])) + count += 1 + assert count == 48 + + def test_more_shard_padded(): result_list = [] for i in range(8):