[MD]Enhance Python Multiprocessing test coverage

This commit is contained in:
Cathy Wong 2022-02-08 14:15:19 -05:00
parent e1c2c4c268
commit befe8f4837
4 changed files with 370 additions and 42 deletions
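For context, the changes below exercise the Map op's python_multiprocessing=True path together with max_rowsize (the shared-memory row limit) and the ds.config shared-memory/prefetch settings. A minimal sketch of that pattern, using only calls that already appear in the tests themselves (the data, worker count, and row size here are arbitrary and not taken from the commit):

import numpy as np
import mindspore.dataset as ds

def pyfunc(x):
    # Trivial Python op; runs in a child process when python_multiprocessing=True
    return x + 1

# Shrink the connector queue to reduce shared memory needed, as the tests do for CI
prefetch_original = ds.config.get_prefetch_size()
ds.config.set_prefetch_size(1)

data = ds.NumpySlicesDataset(list(range(100)), shuffle=False)
data = data.map(operations=pyfunc, num_parallel_workers=2,
                python_multiprocessing=True, max_rowsize=1)

for i, row in enumerate(data.create_tuple_iterator(num_epochs=1, output_numpy=True)):
    np.testing.assert_equal(row[0], i + 1)

ds.config.set_prefetch_size(prefetch_original)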

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ import numpy as np
import mindspore.dataset as ds
CIFAR10_DIR = "../data/dataset/testCifar10Data"
# This UT tests the following cases
@ -109,8 +110,7 @@ def test_batch_padding_05():
def batch_padding_performance_3d():
cifar10_dir = "../data/dataset/testCifar10Data"
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
data1 = ds.Cifar10Dataset(CIFAR10_DIR, shuffle=False) # shape = [32,32,3]
data1 = data1.repeat(24)
pad_info = {"image": ([36, 36, 3], 0)}
# pad_info = None
@ -124,8 +124,7 @@ def batch_padding_performance_3d():
def batch_padding_performance_1d():
cifar10_dir = "../data/dataset/testCifar10Data"
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
data1 = ds.Cifar10Dataset(CIFAR10_DIR, shuffle=False) # shape = [32,32,3]
data1 = data1.repeat(24)
data1 = data1.map(operations=(lambda x: x.reshape(-1)), input_columns="image")
pad_info = {"image": ([3888], 0)} # 3888 =36*36*3
@ -140,8 +139,7 @@ def batch_padding_performance_1d():
def batch_pyfunc_padding_3d():
cifar10_dir = "../data/dataset/testCifar10Data"
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
data1 = ds.Cifar10Dataset(CIFAR10_DIR, shuffle=False) # shape = [32,32,3]
data1 = data1.repeat(24)
# pad_info = {"image": ([36, 36, 3], 0)}
data1 = data1.map(operations=(lambda x: np.pad(x, ((0, 4), (0, 4), (0, 0)))), input_columns="image",
@ -156,8 +154,7 @@ def batch_pyfunc_padding_3d():
def batch_pyfunc_padding_1d():
cifar10_dir = "../data/dataset/testCifar10Data"
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
data1 = ds.Cifar10Dataset(CIFAR10_DIR, shuffle=False) # shape = [32,32,3]
data1 = data1.repeat(24)
data1 = data1.map(operations=(lambda x: x.reshape(-1)), input_columns="image")
data1 = data1.map(operations=(lambda x: np.pad(x, (0, 816))), input_columns="image", python_multiprocessing=False)
@ -170,29 +167,36 @@ def batch_pyfunc_padding_1d():
# print(res)
# This function runs pad_batch and numpy.pad, then compares the results
def pad_map_config(my_num_workers=None, py_multiproc=False, my_max_rowsize=16):
data1 = ds.Cifar10Dataset(CIFAR10_DIR, shuffle=False, num_samples=1000) # shape = [32,32,3]
data1 = data1.map(operations=(lambda x: x.reshape(-1)), input_columns="image",
num_parallel_workers=my_num_workers, python_multiprocessing=py_multiproc,
max_rowsize=my_max_rowsize) # reshape to 1d
data1 = data1.map(operations=(lambda x: np.pad(x, (0, 816))), input_columns="image",
num_parallel_workers=my_num_workers, python_multiprocessing=py_multiproc,
max_rowsize=my_max_rowsize)
data1 = data1.batch(batch_size=25, drop_remainder=True)
res = []
for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
res.append(data["image"])
return res
def pad_batch_config():
data2 = ds.Cifar10Dataset(CIFAR10_DIR, shuffle=False, num_samples=1000) # shape = [32,32,3]
data2 = data2.map(operations=(lambda x: x.reshape(-1)), input_columns="image") # reshape to 1d
data2 = data2.batch(batch_size=25, drop_remainder=True, pad_info={"image": ([3888], 0)})
res = []
for data in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
res.append(data["image"])
return res
def test_pad_via_map():
cifar10_dir = "../data/dataset/testCifar10Data"
def pad_map_config():
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False, num_samples=1000) # shape = [32,32,3]
data1 = data1.map(operations=(lambda x: x.reshape(-1)), input_columns="image") # reshape to 1d
data1 = data1.map(operations=(lambda x: np.pad(x, (0, 816))), input_columns="image")
data1 = data1.batch(batch_size=25, drop_remainder=True)
res = []
for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
res.append(data["image"])
return res
def pad_batch_config():
data2 = ds.Cifar10Dataset(cifar10_dir, shuffle=False, num_samples=1000) # shape = [32,32,3]
data2 = data2.map(operations=(lambda x: x.reshape(-1)), input_columns="image") # reshape to 1d
data2 = data2.batch(batch_size=25, drop_remainder=True, pad_info={"image": ([3888], 0)})
res = []
for data in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
res.append(data["image"])
return res
"""
Feature: Batch Padding
Description: Compare results for pad_batch versus numpy.pad
Expectation: pad_batch and numpy.pad results are the same
"""
res_from_map = pad_map_config()
res_from_batch = pad_batch_config()
assert len(res_from_map) == len(res_from_batch)
@ -200,6 +204,21 @@ def test_pad_via_map():
np.testing.assert_array_equal(res_from_map[i], res_from_batch[i])
def test_pad_via_map_multiproc():
"""
Feature: Batch Padding
Description: Compare results for pad_batch versus numpy.pad, with multiprocessing for map
Expectation: pad_batch and numpy.pad results are the same
"""
# Note: Reduce shared memory needed (for CI) by using small num_parallel_workers and max_rowsize values
res_from_map = pad_map_config(2, True, 1)
res_from_batch = pad_batch_config()
assert len(res_from_map) == len(res_from_batch)
for i, _ in enumerate(res_from_map):
np.testing.assert_array_equal(res_from_map[i], res_from_batch[i])
if __name__ == '__main__':
test_batch_padding_01()
test_batch_padding_02()
@ -211,3 +230,4 @@ if __name__ == '__main__':
# batch_pyfunc_padding_3d()
# batch_pyfunc_padding_1d()
test_pad_via_map()
test_pad_via_map_multiproc()

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -208,6 +208,7 @@ def test_case_7():
ds.config.set_enable_shared_mem(mem_original)
def test_case_8():
"""
Test PyFunc
@ -241,6 +242,7 @@ def test_case_8():
ds.config.set_enable_shared_mem(mem_original)
def test_case_9():
"""
Test PyFunc
@ -266,6 +268,7 @@ def test_case_9():
ds.config.set_enable_shared_mem(mem_original)
def test_case_10():
"""
Test PyFunc
@ -293,6 +296,7 @@ def test_case_10():
ds.config.set_enable_shared_mem(mem_original)
def test_pyfunc_implicit_compose():
"""
Test Implicit Compose with pyfunc
@ -326,7 +330,7 @@ def test_pyfunc_exception():
# and cause core dump and blocking in this UT. Add cleanup() here to fix it.
it._cleanup() # pylint: disable=W0212
def pyfunc(x):
def pyfunc():
raise Exception("Pyfunc Throw")
with pytest.raises(RuntimeError) as info:
@ -339,12 +343,21 @@ def test_pyfunc_exception():
assert "Pyfunc Throw" in str(info.value)
def skip_test_pyfunc_exception_multiprocess():
def test_pyfunc_exception_multiprocess():
"""
Feature: PyFunc in Map op
Description: Test python_multiprocessing=True with exception in child pyfunc process
Expectation: Exception is received and test ends gracefully
"""
logger.info("Test Multiprocess PyFunc Exception Throw: lambda x : raise Exception()")
def pyfunc(x):
def pyfunc():
raise Exception("MP Pyfunc Throw")
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
with pytest.raises(RuntimeError) as info:
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
@ -354,6 +367,8 @@ def skip_test_pyfunc_exception_multiprocess():
pass
assert "MP Pyfunc Throw" in str(info.value)
ds.config.set_enable_shared_mem(mem_original)
def test_func_with_yield_manifest_dataset_01():
def pass_func(_):
@ -382,6 +397,7 @@ def test_func_mixed_with_ops():
Description: num_parallel_workers will be decreased to 1
Expectation: success
"""
def generator_func():
for i in range(1, 5):
yield (np.ones(shape=[2, i]),)
@ -417,6 +433,6 @@ if __name__ == "__main__":
test_case_10()
test_pyfunc_implicit_compose()
test_pyfunc_exception()
skip_test_pyfunc_exception_multiprocess()
test_pyfunc_exception_multiprocess()
test_func_with_yield_manifest_dataset_01()
test_func_mixed_with_ops()

View File

@ -0,0 +1,269 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Python Multiprocessing with Python functions/ops
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.vision.py_transforms as py_vision
from util import visualize_list
MNIST_DATA_DIR = "../data/dataset/testMnistData"
TF_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
TF_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
PYFUNCMAP_DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"]
PYFUNCMAP_SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json"
def test_pyfunc_multiproc_shrmem():
"""
Feature: PyFunc in Map op
Description: Test python_multiprocessing=True with shared memory enabled
Expectation: Data results are correct
"""
def pyfunc(x):
return x
# Confirm shared memory optimization is enabled by default
mem_original = ds.config.get_enable_shared_mem()
assert mem_original
# Reduce memory needed by reducing queue size
prefetch_original = ds.config.get_prefetch_size()
ds.config.set_prefetch_size(1)
max_elements = 2000
np_data = list(range(0, max_elements))
data1 = ds.NumpySlicesDataset(np_data, shuffle=False)
data1 = data1.map(pyfunc, num_parallel_workers=8, python_multiprocessing=True, max_rowsize=1)
for i, data in enumerate(data1):
np.testing.assert_equal(data[0].asnumpy(), np_data[i])
assert data1.get_dataset_size() == max_elements
ds.config.set_prefetch_size(prefetch_original)
def create_dataset_pyop_multiproc(num_parallel_workers=None, max_rowsize=16, batch_size=32, repeat_size=1,
num_samples=None):
"""
Create dataset with Python ops list and python_multiprocessing=True for Map op
"""
# Define dataset
data1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=num_samples)
data1 = data1.map(operations=[py_vision.ToType(np.int32)], input_columns="label",
num_parallel_workers=num_parallel_workers,
python_multiprocessing=True, max_rowsize=max_rowsize)
# Set up transforms list, which includes Python ops
transforms_list = [
py_vision.ToTensor(),
lambda x: x,
py_vision.HWC2CHW(),
py_vision.RandomErasing(0.9, value='random'),
py_vision.Cutout(4, 2),
lambda y: y
]
compose_op = py_transforms.Compose(transforms_list)
data1 = data1.map(operations=compose_op, input_columns="image", num_parallel_workers=num_parallel_workers,
python_multiprocessing=True, max_rowsize=max_rowsize)
# Apply Dataset Ops
buffer_size = 10000
data1 = data1.shuffle(buffer_size=buffer_size)
data1 = data1.batch(batch_size, drop_remainder=True)
data1 = data1.repeat(repeat_size)
return data1
def test_pyfunc_multiproc_noshrmem():
"""
Feature: Python Multiprocessing
Description: Test Map op with python_multiprocessing=True
Expectation: Number of return data rows is correct
"""
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
mydata1 = create_dataset_pyop_multiproc(num_parallel_workers=12, repeat_size=2)
mycount1 = 0
for _ in mydata1.create_dict_iterator(num_epochs=1):
mycount1 += 1
assert mycount1 == 624
ds.config.set_enable_shared_mem(mem_original)
def test_pyfunc_multiproc_max_rowsize_small():
"""
Feature: Python Multiprocessing
Description: Test Map op with python_multiprocessing=True and max_rowsize=1 (less than default of 16)
Expectation: Number of return data rows is correct
"""
# Reduce memory needed by reducing queue size
prefetch_original = ds.config.get_prefetch_size()
ds.config.set_prefetch_size(1)
mydata1 = create_dataset_pyop_multiproc(num_parallel_workers=2, max_rowsize=1, num_samples=500)
mycount1 = 0
for _ in mydata1.create_dict_iterator(num_epochs=1):
mycount1 += 1
assert mycount1 == 15
ds.config.set_prefetch_size(prefetch_original)
def test_pyfunc_multiproc_max_rowsize_large():
"""
Feature: Python Multiprocessing
Description: Test Map op with python_multiprocessing=True and max_rowsize=20 (more than default of 16)
Expectation: Number of return data rows is correct
"""
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
mydata1 = create_dataset_pyop_multiproc(num_parallel_workers=4, max_rowsize=20, num_samples=500)
mycount1 = 0
for _ in mydata1.create_dict_iterator(num_epochs=1):
mycount1 += 1
assert mycount1 == 15
ds.config.set_enable_shared_mem(mem_original)
def test_pyfunc_multiproc_basic_pipeline(plot=False):
"""
Feature: Python Multiprocessing
Description: Test Map op with python_multiprocessing=True in a basic pipeline with Py ops
Expectation: Images in plots from the 2 pipelines are visually fine
"""
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
# Define map operations
transforms_list = [py_vision.CenterCrop(64), py_vision.RandomRotation(30)]
transforms1 = [
py_vision.Decode(),
py_transforms.RandomChoice(transforms_list),
py_vision.ToTensor()
]
transform1 = py_transforms.Compose(transforms1)
transforms2 = [
py_vision.Decode(),
py_vision.ToTensor()
]
transform2 = py_transforms.Compose(transforms2)
# First dataset
data1 = ds.TFRecordDataset(TF_DATA_DIR, TF_SCHEMA_DIR, columns_list=["image"], shuffle=False)
data1 = data1.map(operations=transform1, input_columns=["image"], num_parallel_workers=2,
python_multiprocessing=True, max_rowsize=1)
# Second dataset
data2 = ds.TFRecordDataset(TF_DATA_DIR, TF_SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(operations=transform2, input_columns=["image"])
image_choice = []
image_original = []
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
image_choice.append(image1)
image_original.append(image2)
if plot:
visualize_list(image_original, image_choice)
ds.config.set_enable_shared_mem(mem_original)
def test_pyfunc_multiproc_child_exception():
"""
Feature: Python Multiprocessing
Description: Test Map op with python_multiprocessing=True with Python op encountering exception
Expectation: Exception is correctly processed
"""
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
# Define map operations
# Note: crop size [5000, 5000] > image size [4032, 2268]
transforms_list = [py_vision.RandomCrop(5000)]
transforms = [
py_vision.Decode(),
py_transforms.RandomChoice(transforms_list),
py_vision.ToTensor()
]
transform = py_transforms.Compose(transforms)
# Generate dataset
data = ds.TFRecordDataset(TF_DATA_DIR, TF_SCHEMA_DIR, columns_list=["image"], shuffle=False)
data = data.map(operations=transform, input_columns=["image"], python_multiprocessing=True)
# Note: Expect an error to be raised since the RandomCrop crop size is greater than the image size
with pytest.raises(RuntimeError) as info:
data.create_dict_iterator(num_epochs=1).__next__()
assert "Crop size" in str(info.value)
ds.config.set_enable_shared_mem(mem_original)
def test_pyfunc_multiproc_mainproc_exception():
"""
Feature: PyFunc in Map op
Description: Test python_multiprocessing=True with exception in main process
Expectation: Exception is received and test ends gracefully
"""
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
# Apply dataset operations
data1 = ds.TFRecordDataset(PYFUNCMAP_DATA_DIR, PYFUNCMAP_SCHEMA_DIR, shuffle=False)
data1 = data1.map(operations=(lambda x: x + x), input_columns="col0", output_columns="out",
python_multiprocessing=True)
with pytest.raises(ZeroDivisionError) as info:
i = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
i = i + 4
if i > 8:
# Cause division by zero error
_ = i / 0
assert "division by zero" in str(info.value)
ds.config.set_enable_shared_mem(mem_original)
if __name__ == '__main__':
test_pyfunc_multiproc_shrmem()
test_pyfunc_multiproc_noshrmem()
test_pyfunc_multiproc_max_rowsize_small()
test_pyfunc_multiproc_max_rowsize_large()
test_pyfunc_multiproc_basic_pipeline(plot=True)
test_pyfunc_multiproc_child_exception()
test_pyfunc_multiproc_mainproc_exception()

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -55,9 +55,16 @@ def test_batch_corner_cases():
assert len(tst4) == 4, "\nATTENTION BATCH FAILED\n"
# each sub-test in this function is tested twice with exact parameter except that the second test passes each row
# to a pyfunc which makes a deep copy of the row
def test_variable_size_batch():
"""
Feature: Batch
Description: Test batch variations with repeat and with/without per_batch_map.
Each sub-test is run with the same parameters, except that
- the second test uses per_batch_map, which passes each row to a pyfunc that makes a deep copy of the row
- the third test (if it exists) uses per_batch_map and python multiprocessing
Expectation: Results are the same, independent of per_batch_map or python_multiprocessing settings
"""
def check_res(arr1, arr2):
for ind, _ in enumerate(arr1):
if not np.array_equal(arr1[ind], np.array(arr2[ind])):
@ -108,6 +115,18 @@ def test_variable_size_batch():
res.append(item["num"])
return res
# same as test_batch_repeat_with_copy_map except with python multiprocessing enabled
def test_batch_repeat_with_copy_map_multiproc(gen_num, r, drop, func, num_workers, my_maxrowsize):
res = []
data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"], num_parallel_workers=num_workers,
python_multiprocessing=True, max_rowsize=my_maxrowsize) \
.batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy,
num_parallel_workers=num_workers, python_multiprocessing=True,
max_rowsize=my_maxrowsize).repeat(r)
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
res.append(item["num"])
return res
tst1, tst2, tst3, tst4, tst5, tst6, tst7 = [], [], [], [], [], [], []
# no repeat, simple var size, based on batch_num
@ -140,6 +159,10 @@ def test_variable_size_batch():
assert check_res(tst7, [[[0]], [[1]], [[2]], [[3]], [[0], [1]], [[2], [3]], [[0], [1], [2]], [[3]],
[[0], [1], [2], [3]]]), "\nATTENTION VAR BATCH FAILED\n" + str(tst7)
assert check_res(tst7, test_batch_repeat_with_copy_map(4, 4, False, add_one_by_epoch)), "\nMAP FAILED\n"
assert check_res(tst7, test_batch_repeat_with_copy_map_multiproc(
4, 4, False, add_one_by_epoch, 4, 1)), "\nMULTIPROC1 MAP FAILED\n"
assert check_res(tst7, test_batch_repeat_with_copy_map_multiproc(
4, 4, False, add_one_by_epoch, 2, 2)), "\nMULTIPROC2 MAP FAILED\n"
def test_basic_batch_map():
@ -369,11 +392,11 @@ def test_multi_col_map():
# test exceptions
assert "output_columns with value 233 is not of type" in batch_map_config(2, 2, split_col, ["col2"], 233)
assert "column_order with value 233 is not of type" in batch_map_config(2, 2, split_col, ["col2"], ["col1"], 233)
assert "columns that are not involved in 'per_batch_map' should not be in output_columns"\
assert "columns that are not involved in 'per_batch_map' should not be in output_columns" \
in batch_map_config(2, 2, split_col, ["col2"], ["col1"])
assert "the number of columns returned in 'per_batch_map' function should be 3"\
assert "the number of columns returned in 'per_batch_map' function should be 3" \
in batch_map_config(2, 2, split_col, ["col2"], ["col3", "col4", "col5"])
assert "'col-1' of 'input_columns' doesn't exist"\
assert "'col-1' of 'input_columns' doesn't exist" \
in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"])
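
As a closing reference, a condensed sketch of the per_batch_map-under-multiprocessing pattern that the new batch sub-test above adds (illustrative only; the generator size, batch size, and row size below are made up):

import numpy as np
import mindspore.dataset as ds

def gen(num):
    for i in range(num):
        yield (np.array([i]),)

def simple_copy(col_list, batch_info):
    # per_batch_map receives the batched rows of each input column plus BatchInfo;
    # here it just deep-copies the rows of the single "num" column
    return ([np.copy(arr) for arr in col_list],)

data = ds.GeneratorDataset((lambda: gen(4)), ["num"], num_parallel_workers=2,
                           python_multiprocessing=True, max_rowsize=1)
data = data.batch(batch_size=2, drop_remainder=False, input_columns=["num"],
                  per_batch_map=simple_copy, num_parallel_workers=2,
                  python_multiprocessing=True, max_rowsize=1)

for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    print(item["num"].shape)  # (2, 1) for each of the two batches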