!42349 fix: generator dataset with multi-process use .skip error
Merge pull request !42349 from guozhijian/fix_generator_dataset_split
This commit is contained in: commit d007a090ae
@@ -175,8 +175,10 @@ class SamplerFn:

    def __init__(self, dataset, num_worker, multi_process, max_rowsize):
        self.workers = []
        self.dataset = dataset
        self.num_worker = num_worker
        self.multi_process = multi_process
        self.max_rowsize = max_rowsize
        self.need_join = False
        self.ppid = os.getpid()
        self.pids = []
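The diff does not show where the new `ppid` and `pids` fields are consumed, but recording the creating process's pid at construction is the usual guard for multi-process cleanup: only the process whose `os.getpid()` matches `ppid` should join or terminate the workers. A minimal sketch of that pattern, using a hypothetical `_Cleanup` class rather than MindSpore's actual code:

import os


class _Cleanup:
    """Hedged sketch of a parent-only cleanup guard; not MindSpore's code."""

    def __init__(self):
        self.ppid = os.getpid()  # pid of the process that created the workers
        self.workers = []        # live worker process handles

    def release(self):
        # Only the creating process joins its workers; a forked child sees a
        # different os.getpid() and must not touch the parent's handles.
        if os.getpid() != self.ppid:
            return
        for worker in self.workers:
            worker.join()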
@@ -342,6 +344,9 @@ class SamplerFn:
        except TypeError:
            pass

    def __deepcopy__(self, memodict, exclude=()):
        self.__init__(self.dataset, self.num_worker, self.multi_process, self.max_rowsize)


def _subprocess_handle(eof, signum, frame):
    threading.Thread(target=eof.set()).start()
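Per the hunk header (`-342,6 +344,9`), the three added lines are the `__deepcopy__` override. This is the core of the fix: operations that deep-copy the pipeline, such as `split()` on a multi-process GeneratorDataset (the commit title ties the failure to `.skip`, which `split` builds on), no longer duplicate live worker-process state, because each copy re-runs `__init__` and starts with fresh, empty worker bookkeeping. A minimal sketch of the idea with a hypothetical `_Workers` class; note the committed code re-initializes the instance in place, whereas this sketch returns a new object:

import copy
import os


class _Workers:
    """Hypothetical stand-in for SamplerFn: owns per-process worker state."""

    def __init__(self, num_worker):
        self.num_worker = num_worker
        self.workers = []        # live process handles; must not be shared by copies
        self.ppid = os.getpid()
        self.pids = []

    def __deepcopy__(self, memodict):
        # Return a freshly initialized instance instead of copying process
        # handles; the copy starts with empty worker state.
        return _Workers(self.num_worker)


original = _Workers(num_worker=2)
original.workers.append("fake-process-handle")
clone = copy.deepcopy(original)
assert clone.workers == []  # the clone did not inherit the parent's workers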
@@ -15,8 +15,10 @@
import numpy as np
import pytest

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms


DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@@ -281,6 +283,55 @@ def test_skip_exception_2():
    assert "Input count is not within the required interval" in str(e.value)


def test_skip_with_generator_dataset_multi_process():
    """
    Feature: Skip op
    Description: Test skip op when using GeneratorDataset(..., num_parallel_workers=2, ...)
    Expectation: The dataset is processed successfully
    """
    # construct data and label
    data1 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
    data2 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
    data3 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
    data4 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)

    label = [1, 2, 3, 4]

    # load the data and label by NumpySlicesDataset
    dataset = ds.NumpySlicesDataset(([data1, data2, data3, data4], label), ["data", "label"], num_parallel_workers=2)

    dataset_train, dataset_val = dataset.split([0.5, 0.5])

    # apply the transform to data
    dataset_train = dataset_train.map(operations=vision.RandomCrop(size=(250, 250)), input_columns="data")

    # apply the transform to label
    dataset_train = dataset_train.map(operations=transforms.TypeCast(ms.int32), input_columns="label")

    # batch
    dataset_train = dataset_train.batch(batch_size=2)

    # create train iterator
    epochs = 2
    ds_iter = dataset_train.create_dict_iterator(output_numpy=True, num_epochs=epochs)
    count = 0
    for _ in range(epochs):
        for item in ds_iter:
            print("item: {}".format(item), flush=True)
            count += 1
    assert count == 2

    # create val iterator
    epochs = 2
    ds_iter = dataset_val.create_dict_iterator(output_numpy=True, num_epochs=epochs)
    count = 0
    for _ in range(epochs):
        for item in ds_iter:
            print("item: {}".format(item), flush=True)
            count += 1
    assert count == 4


if __name__ == "__main__":
    test_tf_skip()
@ -296,3 +347,4 @@ if __name__ == "__main__":
|
|||
test_skip_filter_2()
|
||||
test_skip_exception_1()
|
||||
test_skip_exception_2()
|
||||
test_skip_with_generator_dataset_multi_process()
|
||||
|
|
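A condensed reproducer for the fixed scenario, using the same public APIs as the test (`NumpySlicesDataset` is a wrapper over the multi-process GeneratorDataset path; the data shapes here are illustrative, not from the commit). Before this fix, iterating the two branches of `split()` on a multi-process source could fail because each deep-copied branch still referenced the parent's worker processes:

import numpy as np
import mindspore.dataset as ds

data = [np.ones((2, 2), dtype=np.uint8) * i for i in range(4)]
label = [0, 1, 2, 3]

# Multi-process source: num_parallel_workers=2 spawns worker processes.
dataset = ds.NumpySlicesDataset((data, label), ["data", "label"], num_parallel_workers=2)

# split() deep-copies the pipeline into two branches; with the __deepcopy__
# override each branch gets freshly initialized worker state.
train, val = dataset.split([0.5, 0.5])

assert sum(1 for _ in train.create_dict_iterator(output_numpy=True, num_epochs=1)) == 2
assert sum(1 for _ in val.create_dict_iterator(output_numpy=True, num_epochs=1)) == 2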