!42349 fix: error when a multi-process GeneratorDataset uses .skip

Merge pull request !42349 from guozhijian/fix_generator_dataset_split
i-robot 2022-09-20 10:58:23 +00:00 committed by Gitee
commit d007a090ae
2 changed files with 57 additions and 0 deletions


@@ -175,8 +175,10 @@ class SamplerFn:
    def __init__(self, dataset, num_worker, multi_process, max_rowsize):
        self.workers = []
        self.dataset = dataset
        self.num_worker = num_worker
        self.multi_process = multi_process
        self.max_rowsize = max_rowsize
        self.need_join = False
        self.ppid = os.getpid()
        self.pids = []
@@ -342,6 +344,9 @@ class SamplerFn:
            except TypeError:
                pass

    def __deepcopy__(self, memodict, exclude=()):
        self.__init__(self.dataset, self.num_worker, self.multi_process, self.max_rowsize)


def _subprocess_handle(eof, signum, frame):
    threading.Thread(target=eof.set()).start()
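
Why the __deepcopy__ override: Dataset.split() deep-copies the dataset tree, and a SamplerFn may hold live worker processes and parent-process state that cannot be safely deep-copied, so the fix re-initializes a fresh, worker-less instance instead of copying that state. Below is a minimal sketch of the same pattern outside MindSpore (WorkerPool and its fields are illustrative, not MindSpore APIs; note that copy.deepcopy uses the return value of __deepcopy__, so the sketch returns the fresh instance explicitly):

import copy
import os


class WorkerPool:
    """Holds OS-level state (worker handles, parent pid) that must not be deep-copied."""

    def __init__(self, source, num_worker):
        self.source = source
        self.num_worker = num_worker
        self.workers = []          # live process handles would accumulate here
        self.ppid = os.getpid()    # the parent pid is meaningless in a copied object

    def __deepcopy__(self, memodict):
        # Re-initialize from constructor arguments instead of copying live state,
        # mirroring the SamplerFn.__deepcopy__ added above.
        new = type(self).__new__(type(self))
        new.__init__(copy.deepcopy(self.source, memodict), self.num_worker)
        return new


pool = WorkerPool(source=[1, 2, 3], num_worker=2)
clone = copy.deepcopy(pool)
assert clone.workers == [] and clone.source == pool.source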


@@ -15,8 +15,10 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@@ -281,6 +283,55 @@ def test_skip_exception_2():
    assert "Input count is not within the required interval" in str(e.value)


def test_skip_with_generator_dataset_multi_process():
    """
    Feature: Skip op
    Description: Test skip op when using GeneratorDataset(..., num_parallel_workers=2, ...)
    Expectation: Output is equal to the expected output
    """
    # construct data and label
    data1 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
    data2 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
    data3 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
    data4 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
    label = [1, 2, 3, 4]

    # load the data and label by NumpySlicesDataset
    dataset = ds.NumpySlicesDataset(([data1, data2, data3, data4], label), ["data", "label"],
                                    num_parallel_workers=2)
    dataset_train, dataset_val = dataset.split([0.5, 0.5])

    # apply the transform to data
    dataset_train = dataset_train.map(operations=vision.RandomCrop(size=(250, 250)), input_columns="data")

    # apply the transform to label
    dataset_train = dataset_train.map(operations=transforms.TypeCast(ms.int32), input_columns="label")

    # batch
    dataset_train = dataset_train.batch(batch_size=2)

    # create train iterator: 2 samples / batch_size 2 = 1 batch per epoch, 2 epochs -> 2 items
    epochs = 2
    ds_iter = dataset_train.create_dict_iterator(output_numpy=True, num_epochs=epochs)
    count = 0
    for _ in range(epochs):
        for item in ds_iter:
            print("item: {}".format(item), flush=True)
            count += 1
    assert count == 2

    # create val iterator: 2 unbatched samples per epoch, 2 epochs -> 4 items
    epochs = 2
    ds_iter = dataset_val.create_dict_iterator(output_numpy=True, num_epochs=epochs)
    count = 0
    for _ in range(epochs):
        for item in ds_iter:
            print("item: {}".format(item), flush=True)
            count += 1
    assert count == 4

if __name__ == "__main__":
    test_tf_skip()

@@ -296,3 +347,4 @@ if __name__ == "__main__":
    test_skip_filter_2()
    test_skip_exception_1()
    test_skip_exception_2()
    test_skip_with_generator_dataset_multi_process()
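
For context on what the test protects against: deep-copying an object that owns synchronization or process handles fails outright in plain Python, which is likely the class of error the .split/.skip path hit before this fix. A tiny standard-library repro (BrokenPool is a hypothetical stand-in, not MindSpore code):

import copy
import threading


class BrokenPool:
    """A naive holder of per-worker IPC state; deepcopy chokes on the lock."""

    def __init__(self):
        self.lock = threading.Lock()   # stand-in for worker synchronization state


try:
    copy.deepcopy(BrokenPool())
except TypeError as err:
    # e.g. "cannot pickle '_thread.lock' object"
    print("deepcopy failed:", err)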