forked from mindspore-Ecosystem/mindspore
!4580 add decode test case for padded
Merge pull request !4580 from guozhijian/add_test_case_for_padded
This commit is contained in:
commit
9d1a6c1a9d
|
@ -1,8 +1,12 @@
|
|||
from io import BytesIO
|
||||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.mindrecord import FileWriter
|
||||
import mindspore.dataset.transforms.vision.c_transforms as V_C
|
||||
from PIL import Image
|
||||
|
||||
FILES_NUM = 4
|
||||
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
|
||||
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
|
||||
|
@ -197,7 +201,7 @@ def test_raise_error():
|
|||
ds3.use_sampler(testsampler)
|
||||
assert excinfo.type == 'ValueError'
|
||||
|
||||
def test_imagefolden_padded():
|
||||
def test_imagefolder_padded():
|
||||
DATA_DIR = "../data/dataset/testPK/data"
|
||||
data = ds.ImageFolderDatasetV2(DATA_DIR)
|
||||
|
||||
|
@ -220,6 +224,32 @@ def test_imagefolden_padded():
|
|||
assert verify_list[8] == 1
|
||||
assert verify_list[9] == 6
|
||||
|
||||
def test_imagefolder_padded_with_decode():
|
||||
DATA_DIR = "../data/dataset/testPK/data"
|
||||
data = ds.ImageFolderDatasetV2(DATA_DIR)
|
||||
|
||||
white_io = BytesIO()
|
||||
Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
|
||||
padded_sample = {}
|
||||
padded_sample['image'] = np.array(bytearray(white_io), dtype='uint8')
|
||||
padded_sample['label'] = np.array(-1, np.int32)
|
||||
|
||||
white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
|
||||
data2 = ds.PaddedDataset(white_samples)
|
||||
data3 = data + data2
|
||||
|
||||
num_shards = 5
|
||||
count = 0
|
||||
for shard_id in range(num_shards):
|
||||
testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
|
||||
data3.use_sampler(testsampler)
|
||||
data3.map(input_columns="image", operations=V_C.Decode())
|
||||
for ele in data3.create_dict_iterator():
|
||||
print("label: {}".format(ele['label']))
|
||||
count += 1
|
||||
assert count == 48
|
||||
|
||||
|
||||
def test_more_shard_padded():
|
||||
result_list = []
|
||||
for i in range(8):
|
||||
|
|
Loading…
Reference in New Issue