forked from mindspore-Ecosystem/mindspore
!6301 [MD] Combine c++ and python ops in map
Merge pull request !6301 from nhussain/c_py_compose
This commit is contained in:
commit
2f3add4acd
|
@ -37,6 +37,9 @@ from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp
|
|||
from mindspore._c_expression import typing
|
||||
|
||||
from mindspore import log as logger
|
||||
|
||||
import mindspore.dataset.transforms.py_transforms as py_transforms
|
||||
|
||||
from . import samplers
|
||||
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup
|
||||
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
|
||||
|
@ -411,7 +414,7 @@ class Dataset:
|
|||
return dataset
|
||||
|
||||
@check_map
|
||||
def map(self, operations=None, input_columns=None, output_columns=None, column_order=None,
|
||||
def map(self, operations, input_columns=None, output_columns=None, column_order=None,
|
||||
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
|
||||
"""
|
||||
Apply each operation in operations to this dataset.
|
||||
|
@ -432,7 +435,7 @@ class Dataset:
|
|||
Args:
|
||||
operations (Union[list[TensorOp], list[functions]]): List of operations to be
|
||||
applied on the dataset. Operations are applied in the order they appear in this list.
|
||||
input_columns (list[str]): List of the names of the columns that will be passed to
|
||||
input_columns (list[str], optional): List of the names of the columns that will be passed to
|
||||
the first operation as input. The size of this list must match the number of
|
||||
input columns expected by the first operator. (default=None, the first
|
||||
operation will be passed however many columns that is required, starting from
|
||||
|
@ -2061,8 +2064,25 @@ class MapDataset(DatasetOp):
|
|||
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.children.append(input_dataset)
|
||||
if operations is not None and not isinstance(operations, list):
|
||||
operations = [operations]
|
||||
if operations is not None:
|
||||
if not isinstance(operations, list):
|
||||
operations = [operations]
|
||||
elif isinstance(operations, list) and len(operations) > 1:
|
||||
# wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations
|
||||
new_ops, start_ind, end_ind = [], 0, 0
|
||||
for i, op in enumerate(operations):
|
||||
if not callable(op):
|
||||
# reset counts
|
||||
if start_ind != end_ind:
|
||||
new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
|
||||
new_ops.append(op)
|
||||
start_ind, end_ind = i + 1, i + 1
|
||||
else:
|
||||
end_ind += 1
|
||||
# do additional check in case the last operation is a Python operation
|
||||
if start_ind != end_ind:
|
||||
new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
|
||||
operations = new_ops
|
||||
self.operations = operations
|
||||
if input_columns is not None and not isinstance(input_columns, list):
|
||||
input_columns = [input_columns]
|
||||
|
|
|
@ -86,6 +86,38 @@ class Compose:
|
|||
>>> py_vision.RandomErasing()])
|
||||
>>> # apply the transform to the dataset through dataset.map()
|
||||
>>> dataset = dataset.map(operations=transform, input_columns="image")
|
||||
>>>
|
||||
>>> # Compose can also be invoked implicitly, by just passing in a list of ops
|
||||
>>> # the above example then becomes:
|
||||
>>> transform_list = [py_vision.Decode(),
|
||||
>>> py_vision.RandomHorizontalFlip(0.5),
|
||||
>>> py_vision.ToTensor(),
|
||||
>>> py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
|
||||
>>> py_vision.RandomErasing()]
|
||||
>>>
|
||||
>>> # apply the transform to the dataset through dataset.map()
|
||||
>>> dataset = dataset.map(operations=transform_list, input_columns="image")
|
||||
>>>
|
||||
>>> # Certain C++ and Python ops can be combined, but not all of them
|
||||
>>> # An example of combined operations
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> import mindspore.dataset.transforms.c_transforms as c_transforms
|
||||
>>> import mindspore.dataset.vision.c_transforms as c_vision
|
||||
>>>
|
||||
>>> data = ds.NumpySlicesDataset(arr, column_names=["cols"], shuffle=False)
|
||||
>>> transformed_list = [py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)]
|
||||
>>> data = data.map(operations=transformed_list, input_columns=["cols"])
|
||||
>>>
|
||||
>>> # Here is an example of mixing vision ops
|
||||
>>> data_dir = "/path/to/imagefolder_directory"
|
||||
>>> data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
>>> input_columns = ["column_names"]
|
||||
>>> data1 = data1.map(operations=op_list, input_columns=input_columns)
|
||||
>>> op_list=[c_vision.Decode(),
|
||||
>>> c_vision.Resize((224, 244)),
|
||||
>>> py_vision.ToPIL(),
|
||||
>>> np.array, # need to convert PIL image to a NumPy array to pass it to C++ operation
|
||||
>>> c_vision.Resize((24, 24))]
|
||||
"""
|
||||
|
||||
@check_compose_list
|
||||
|
@ -93,14 +125,14 @@ class Compose:
|
|||
self.transforms = transforms
|
||||
|
||||
@check_compose_call
|
||||
def __call__(self, img):
|
||||
def __call__(self, *args):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Returns:
|
||||
lambda function, Lambda function that takes in an img to apply transformations on.
|
||||
lambda function, Lambda function that takes in args to apply transformations on.
|
||||
"""
|
||||
return util.compose(img, self.transforms)
|
||||
return util.compose(self.transforms, *args)
|
||||
|
||||
|
||||
class RandomApply:
|
||||
|
|
|
@ -21,7 +21,17 @@ import numpy as np
|
|||
from ..core.py_util_helpers import is_numpy
|
||||
|
||||
|
||||
def compose(img, transforms):
|
||||
def all_numpy(args):
|
||||
""" for multi-input lambdas"""
|
||||
if isinstance(args, tuple):
|
||||
for value in args:
|
||||
if not is_numpy(value):
|
||||
return False
|
||||
return True
|
||||
return is_numpy(args)
|
||||
|
||||
|
||||
def compose(transforms, *args):
|
||||
"""
|
||||
Compose a list of transforms and apply on the image.
|
||||
|
||||
|
@ -32,13 +42,15 @@ def compose(img, transforms):
|
|||
Returns:
|
||||
img (numpy.ndarray), An augmented image in Numpy ndarray.
|
||||
"""
|
||||
if is_numpy(img):
|
||||
if all_numpy(args):
|
||||
for transform in transforms:
|
||||
img = transform(img)
|
||||
if is_numpy(img):
|
||||
return img
|
||||
raise TypeError('img should be Numpy ndarray. Got {}. Append ToTensor() to transforms'.format(type(img)))
|
||||
raise TypeError('img should be Numpy ndarray. Got {}.'.format(type(img)))
|
||||
args = transform(*args)
|
||||
args = (args,) if not isinstance(args, tuple) else args
|
||||
|
||||
if all_numpy(args):
|
||||
return args
|
||||
raise TypeError('args should be Numpy ndarray. Got {}. Append ToTensor() to transforms'.format(type(args)))
|
||||
raise TypeError('args should be Numpy ndarray. Got {}.'.format(type(args)))
|
||||
|
||||
|
||||
def one_hot_encoding(label, num_classes, epsilon):
|
||||
|
|
|
@ -213,6 +213,9 @@ def check_compose_list(method):
|
|||
type_check(transforms, (list,), transforms)
|
||||
if not transforms:
|
||||
raise ValueError("transforms list is empty.")
|
||||
for i, transfrom in enumerate(transforms):
|
||||
if not callable(transfrom):
|
||||
raise ValueError("transforms[{}] is not callable.".format(i))
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -225,11 +228,10 @@ def check_compose_call(method):
|
|||
def new_method(self, *args, **kwargs):
|
||||
sig = inspect.signature(method)
|
||||
ba = sig.bind_partial(method, *args, **kwargs)
|
||||
img = ba.arguments.get("img")
|
||||
img = ba.arguments.get("args")
|
||||
if img is None:
|
||||
raise TypeError(
|
||||
"Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])()).")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -243,6 +245,10 @@ def check_random_apply(method):
|
|||
[transforms, prob], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(transforms, (list,), "transforms")
|
||||
|
||||
for i, transfrom in enumerate(transforms):
|
||||
if not callable(transfrom):
|
||||
raise ValueError("transforms[{}] is not callable.".format(i))
|
||||
|
||||
if prob is not None:
|
||||
type_check(prob, (float, int,), "prob")
|
||||
check_value(prob, [0., 1.], "prob")
|
||||
|
@ -260,7 +266,9 @@ def check_transforms_list(method):
|
|||
[transforms], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(transforms, (list,), "transforms")
|
||||
|
||||
for i, transfrom in enumerate(transforms):
|
||||
if not callable(transfrom):
|
||||
raise ValueError("transforms[{}] is not callable.".format(i))
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -622,9 +622,9 @@ class FiveCrop:
|
|||
>>> from mindspore.dataset.transforms.py_transforms import Compose
|
||||
>>>
|
||||
>>> Compose([py_vision.Decode(),
|
||||
>>> py_vision.FiveCrop(size),
|
||||
>>> py_vision.FiveCrop(size=200),
|
||||
>>> # 4D stack of 5 images
|
||||
>>> lambda images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
|
||||
>>> lambda *images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
|
||||
"""
|
||||
|
||||
@check_crop
|
||||
|
@ -662,9 +662,9 @@ class TenCrop:
|
|||
>>> from mindspore.dataset.transforms.py_transforms import Compose
|
||||
>>>
|
||||
>>> Compose([py_vision.Decode(),
|
||||
>>> py_vision.TenCrop(size),
|
||||
>>> py_vision.TenCrop(size=200),
|
||||
>>> # 4D stack of 10 images
|
||||
>>> lambda images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
|
||||
>>> lambda *images: numpy.stack([py_vision.ToTensor()(image) for image in images])])
|
||||
"""
|
||||
|
||||
@check_ten_crop
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -13,11 +13,19 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as ops
|
||||
import mindspore.dataset.transforms.py_transforms as py_ops
|
||||
import mindspore.dataset.transforms.c_transforms as c_transforms
|
||||
import mindspore.dataset.transforms.py_transforms as py_transforms
|
||||
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
|
||||
from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
|
||||
def test_compose():
|
||||
|
@ -38,35 +46,294 @@ def test_compose():
|
|||
return str(e)
|
||||
|
||||
# Test simple compose with only 1 op, this would generate a warning
|
||||
assert test_config([[1, 0], [3, 4]], ops.Compose([ops.Fill(2)])) == [[2, 2], [2, 2]]
|
||||
assert test_config([[1, 0], [3, 4]], c_transforms.Compose([c_transforms.Fill(2)])) == [[2, 2], [2, 2]]
|
||||
|
||||
# Test 1 column -> 2 columns -> 1 -> 2 -> 1
|
||||
assert test_config([[1, 0]],
|
||||
ops.Compose([ops.Duplicate(), ops.Concatenate(), ops.Duplicate(), ops.Concatenate()])) \
|
||||
c_transforms.Compose(
|
||||
[c_transforms.Duplicate(), c_transforms.Concatenate(), c_transforms.Duplicate(),
|
||||
c_transforms.Concatenate()])) \
|
||||
== [[1, 0] * 4]
|
||||
# Test one Python transform followed by a C transform. Type after OneHot is a float (mixed use-case)
|
||||
assert test_config([1, 0], ops.Compose([py_ops.OneHotOp(2), ops.TypeCast(mstype.int32)])) == [[[0, 1]], [[1, 0]]]
|
||||
|
||||
# Test one Python transform followed by a C++ transform. Type after OneHot is a float (mixed use-case)
|
||||
assert test_config([1, 0],
|
||||
c_transforms.Compose([py_transforms.OneHotOp(2), c_transforms.TypeCast(mstype.int32)])) \
|
||||
== [[[0, 1]], [[1, 0]]]
|
||||
|
||||
# Test exceptions.
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
ops.Compose([1, ops.TypeCast(mstype.int32)])
|
||||
c_transforms.Compose([1, c_transforms.TypeCast(mstype.int32)])
|
||||
assert "op_list[0] is not a c_transform op (TensorOp) nor a callable pyfunc." in str(error_info.value)
|
||||
|
||||
# Test empty op list
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
test_config([1, 0], ops.Compose([]))
|
||||
test_config([1, 0], c_transforms.Compose([]))
|
||||
assert "op_list can not be empty." in str(error_info.value)
|
||||
|
||||
# Test Python compose op
|
||||
assert test_config([1, 0], py_ops.Compose([py_ops.OneHotOp(2)])) == [[[0, 1]], [[1, 0]]]
|
||||
assert test_config([1, 0], py_ops.Compose([py_ops.OneHotOp(2), (lambda x: x + x)])) == [[[0, 2]], [[2, 0]]]
|
||||
assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2)])) == [[[0, 1]], [[1, 0]]]
|
||||
assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2), (lambda x: x + x)])) == [[[0, 2]],
|
||||
[[2, 0]]]
|
||||
|
||||
# Test nested Python compose op
|
||||
assert test_config([1, 0],
|
||||
py_ops.Compose([py_ops.Compose([py_ops.OneHotOp(2)]), (lambda x: x + x)])) \
|
||||
py_transforms.Compose([py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)])) \
|
||||
== [[[0, 2]], [[2, 0]]]
|
||||
|
||||
# Test passing a list of Python ops without Compose wrapper
|
||||
assert test_config([1, 0],
|
||||
[py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)]) \
|
||||
== [[[0, 2]], [[2, 0]]]
|
||||
assert test_config([1, 0], [py_transforms.OneHotOp(2), (lambda x: x + x)]) == [[[0, 2]], [[2, 0]]]
|
||||
|
||||
# Test a non callable function
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
py_transforms.Compose([1])
|
||||
assert "transforms[0] is not callable." in str(error_info.value)
|
||||
|
||||
# Test empty Python op list
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
test_config([1, 0], py_transforms.Compose([]))
|
||||
assert "transforms list is empty." in str(error_info.value)
|
||||
|
||||
# Pass in extra brackets
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
py_ops.Compose([(lambda x: x + x)])()
|
||||
py_transforms.Compose([(lambda x: x + x)])()
|
||||
assert "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])())." in str(
|
||||
error_info.value)
|
||||
|
||||
|
||||
def test_lambdas():
|
||||
"""
|
||||
Test Multi Column Python Compose Op
|
||||
"""
|
||||
ds.config.set_seed(0)
|
||||
|
||||
def test_config(arr, input_columns, output_cols, op_list):
|
||||
data = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
|
||||
data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
|
||||
column_order=output_cols)
|
||||
res = []
|
||||
for i in data.create_dict_iterator(output_numpy=True):
|
||||
for col_name in output_cols:
|
||||
res.append(i[col_name].tolist())
|
||||
return res
|
||||
|
||||
arr = ([[1]], [[3]])
|
||||
|
||||
assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([(lambda x, y: x)])) == [[1]]
|
||||
assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([lambda x, y: x, lambda x: x])) == [[1]]
|
||||
assert test_config(arr, ["col0", "col1"], ["a", "b"],
|
||||
py_transforms.Compose([lambda x, y: x, lambda x: (x, x * 2)])) == \
|
||||
[[1], [2]]
|
||||
assert test_config(arr, ["col0", "col1"], ["a", "b"],
|
||||
[lambda x, y: (x, x + y), lambda x, y: (x, y * 2)]) == [[1], [8]]
|
||||
|
||||
|
||||
def test_c_py_compose_transforms_module():
|
||||
"""
|
||||
Test combining Python and C++ transforms
|
||||
"""
|
||||
ds.config.set_seed(0)
|
||||
|
||||
def test_config(arr, input_columns, output_cols, op_list):
|
||||
data = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
|
||||
data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
|
||||
column_order=output_cols)
|
||||
res = []
|
||||
for i in data.create_dict_iterator(output_numpy=True):
|
||||
for col_name in output_cols:
|
||||
res.append(i[col_name].tolist())
|
||||
return res
|
||||
|
||||
arr = [1, 0]
|
||||
assert test_config(arr, ["cols"], ["cols"],
|
||||
[py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)]) == \
|
||||
[[[False, True]],
|
||||
[[True, False]]]
|
||||
assert test_config(arr, ["cols"], ["cols"],
|
||||
[py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1)]) \
|
||||
== [[[1, 1]], [[1, 1]]]
|
||||
assert test_config(arr, ["cols"], ["cols"],
|
||||
[py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1), (lambda x: x + x)]) \
|
||||
== [[[2, 2]], [[2, 2]]]
|
||||
assert test_config([[1, 3]], ["cols"], ["cols"],
|
||||
[c_transforms.PadEnd([3], -1), (lambda x: x + x)]) \
|
||||
== [[2, 6, -2]]
|
||||
|
||||
arr = ([[1]], [[3]])
|
||||
assert test_config(arr, ["col0", "col1"], ["a"], [(lambda x, y: x + y), c_transforms.PadEnd([2], -1)]) == [[4, -1]]
|
||||
|
||||
|
||||
def test_c_py_compose_vision_module(plot=False, run_golden=True):
|
||||
"""
|
||||
Test combining Python and C++ vision transforms
|
||||
"""
|
||||
original_seed = config_get_set_seed(10)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
def test_config(plot, file_name, op_list):
|
||||
data_dir = "../data/dataset/testImageNetData/train/"
|
||||
data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
data1 = data1.map(operations=op_list, input_columns=["image"])
|
||||
data2 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
data2 = data2.map(operations=c_vision.Decode(), input_columns=["image"])
|
||||
original_images = []
|
||||
transformed_images = []
|
||||
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
transformed_images.append(item["image"])
|
||||
for item in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
original_images.append(item["image"])
|
||||
|
||||
if run_golden:
|
||||
# Compare with expected md5 from images
|
||||
save_and_check_md5(data1, file_name, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
if plot:
|
||||
visualize_list(original_images, transformed_images)
|
||||
|
||||
test_config(op_list=[c_vision.Decode(),
|
||||
py_vision.ToPIL(),
|
||||
py_vision.Resize((224, 224)),
|
||||
np.array],
|
||||
plot=plot, file_name="compose_c_py_1.npz")
|
||||
|
||||
test_config(op_list=[c_vision.Decode(),
|
||||
c_vision.Resize((224, 244)),
|
||||
py_vision.ToPIL(),
|
||||
np.array,
|
||||
c_vision.Resize((24, 24))],
|
||||
plot=plot, file_name="compose_c_py_2.npz")
|
||||
|
||||
test_config(op_list=[py_vision.Decode(),
|
||||
py_vision.Resize((224, 224)),
|
||||
np.array,
|
||||
c_vision.RandomColor()],
|
||||
plot=plot, file_name="compose_c_py_3.npz")
|
||||
|
||||
# Restore configuration
|
||||
ds.config.set_seed(original_seed)
|
||||
ds.config.set_num_parallel_workers((original_num_parallel_workers))
|
||||
|
||||
|
||||
def test_py_transforms_with_c_vision():
|
||||
"""
|
||||
These examples will fail, as py_transforms.Random(Apply/Choice/Order) expect callable functions
|
||||
"""
|
||||
|
||||
ds.config.set_seed(0)
|
||||
|
||||
def test_config(op_list):
|
||||
data_dir = "../data/dataset/testImageNetData/train/"
|
||||
data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
data = data.map(operations=op_list)
|
||||
res = []
|
||||
for i in data.create_dict_iterator(output_numpy=True):
|
||||
for col_name in output_cols:
|
||||
res.append(i[col_name].tolist())
|
||||
return res
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
test_config(py_transforms.RandomApply([c_vision.Resize(200)]))
|
||||
assert "transforms[0] is not callable." in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
test_config(py_transforms.RandomChoice([c_vision.Resize(200)]))
|
||||
assert "transforms[0] is not callable." in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
test_config(py_transforms.RandomOrder([np.array, c_vision.Resize(200)]))
|
||||
assert "transforms[1] is not callable." in str(error_info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as error_info:
|
||||
test_config([py_transforms.OneHotOp(20, 0.1)])
|
||||
assert "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()" in str(
|
||||
error_info.value)
|
||||
|
||||
|
||||
def test_py_vision_with_c_transforms():
|
||||
"""
|
||||
Test combining Python vision operations with C++ transforms operations
|
||||
"""
|
||||
|
||||
ds.config.set_seed(0)
|
||||
|
||||
def test_config(op_list):
|
||||
data_dir = "../data/dataset/testImageNetData/train/"
|
||||
data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
data1 = data1.map(operations=op_list, input_columns=["image"])
|
||||
transformed_images = []
|
||||
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
transformed_images.append(item["image"])
|
||||
return transformed_images
|
||||
|
||||
# Test with Mask Op
|
||||
output_arr = test_config([py_vision.Decode(),
|
||||
py_vision.CenterCrop((2)), np.array,
|
||||
c_transforms.Mask(c_transforms.Relational.GE, 100)])
|
||||
|
||||
exp_arr = [np.array([[[True, False, False],
|
||||
[True, False, False]],
|
||||
[[True, False, False],
|
||||
[True, False, False]]]),
|
||||
np.array([[[True, False, False],
|
||||
[True, False, False]],
|
||||
[[True, False, False],
|
||||
[True, False, False]]])]
|
||||
|
||||
for exp_a, output in zip(exp_arr, output_arr):
|
||||
np.testing.assert_array_equal(exp_a, output)
|
||||
|
||||
# Test with Fill Op
|
||||
output_arr = test_config([py_vision.Decode(),
|
||||
py_vision.CenterCrop((4)), np.array,
|
||||
c_transforms.Fill(10)])
|
||||
|
||||
exp_arr = [np.ones((4, 4, 3)) * 10] * 2
|
||||
for exp_a, output in zip(exp_arr, output_arr):
|
||||
np.testing.assert_array_equal(exp_a, output)
|
||||
|
||||
# Test with Concatenate Op, which will raise an error since ConcatenateOp only supports rank 1 tensors.
|
||||
with pytest.raises(RuntimeError) as error_info:
|
||||
test_config([py_vision.Decode(),
|
||||
py_vision.CenterCrop((2)), np.array,
|
||||
c_transforms.Concatenate(0)])
|
||||
assert "Only 1D tensors supported" in str(error_info.value)
|
||||
|
||||
|
||||
def test_compose_with_custom_function():
|
||||
"""
|
||||
Test Python Compose with custom function
|
||||
"""
|
||||
|
||||
def custom_function(x):
|
||||
return (x, x * x)
|
||||
|
||||
# First dataset
|
||||
op_list = [
|
||||
lambda x: x * 3,
|
||||
custom_function,
|
||||
# convert two column output to one
|
||||
lambda *images: np.stack(images)
|
||||
]
|
||||
|
||||
data = ds.NumpySlicesDataset([[1, 2]], column_names=["col0"], shuffle=False)
|
||||
data = data.map(input_columns=["col0"], operations=op_list)
|
||||
#
|
||||
|
||||
res = []
|
||||
for i in data.create_dict_iterator(output_numpy=True):
|
||||
res.append(i["col0"].tolist())
|
||||
assert res == [[[3, 6], [9, 36]]]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_compose()
|
||||
test_lambdas()
|
||||
test_c_py_compose_transforms_module()
|
||||
test_c_py_compose_vision_module(plot=True)
|
||||
test_py_transforms_with_c_vision()
|
||||
test_py_vision_with_c_transforms()
|
||||
test_compose_with_custom_function()
|
||||
|
|
|
@ -48,7 +48,7 @@ def test_five_crop_op(plot=False):
|
|||
transforms_2 = [
|
||||
vision.Decode(),
|
||||
vision.FiveCrop(200),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images
|
||||
lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images
|
||||
]
|
||||
transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2)
|
||||
data2 = data2.map(operations=transform_2, input_columns=["image"])
|
||||
|
@ -91,7 +91,7 @@ def test_five_crop_error_msg():
|
|||
with pytest.raises(RuntimeError) as info:
|
||||
for _ in data:
|
||||
pass
|
||||
error_msg = "TypeError: img should be PIL image or NumPy array. Got <class 'tuple'>"
|
||||
error_msg = "TypeError: __call__() takes 2 positional arguments but 6 were given"
|
||||
|
||||
# error msg comes from ToTensor()
|
||||
assert error_msg in str(info.value)
|
||||
|
@ -108,7 +108,7 @@ def test_five_crop_md5():
|
|||
transforms = [
|
||||
vision.Decode(),
|
||||
vision.FiveCrop(100),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images
|
||||
lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 5 images
|
||||
]
|
||||
transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
|
||||
data = data.map(operations=transform, input_columns=["image"])
|
||||
|
|
|
@ -250,6 +250,32 @@ def test_case_9():
|
|||
i = i + 4
|
||||
|
||||
|
||||
def test_pyfunc_implicit_compose():
|
||||
"""
|
||||
Test Implicit Compose with pyfunc
|
||||
"""
|
||||
logger.info("Test n-m PyFunc : lambda x, y : (x , x + 1, x + y)")
|
||||
|
||||
col = ["col0", "col1"]
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
|
||||
data1 = data1.map(operations=[(lambda x, y: (x, x + y, x + y + 1)), (lambda x, y, z: (x, y, z))], input_columns=col,
|
||||
output_columns=["out0", "out1", "out2"], column_order=["out0", "out1", "out2"])
|
||||
|
||||
i = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
|
||||
# In this test, the dataset is 2x2 sequential tensors
|
||||
golden = np.array([[i, i + 1], [i + 2, i + 3]])
|
||||
np.testing.assert_array_equal(item["out0"], golden)
|
||||
golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
|
||||
np.testing.assert_array_equal(item["out1"], golden)
|
||||
golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
|
||||
np.testing.assert_array_equal(item["out2"], golden)
|
||||
i = i + 4
|
||||
|
||||
|
||||
def test_pyfunc_execption():
|
||||
logger.info("Test PyFunc Execption Throw: lambda x : raise Execption()")
|
||||
|
||||
|
@ -293,5 +319,6 @@ if __name__ == "__main__":
|
|||
test_case_7()
|
||||
test_case_8()
|
||||
test_case_9()
|
||||
test_pyfunc_implicit_compose()
|
||||
test_pyfunc_execption()
|
||||
skip_test_pyfunc_execption_multiprocess()
|
||||
|
|
|
@ -46,7 +46,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False):
|
|||
transforms_2 = [
|
||||
vision.Decode(),
|
||||
vision.TenCrop(crop_size, use_vertical_flip=vertical_flip),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
]
|
||||
transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2)
|
||||
data2 = data2.map(operations=transform_2, input_columns=["image"])
|
||||
|
@ -109,7 +109,7 @@ def test_ten_crop_md5():
|
|||
transforms_2 = [
|
||||
vision.Decode(),
|
||||
vision.TenCrop((200, 100), use_vertical_flip=True),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
lambda *images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
]
|
||||
transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2)
|
||||
data2 = data2.map(operations=transform_2, input_columns=["image"])
|
||||
|
@ -176,7 +176,7 @@ def test_ten_crop_wrong_img_error_msg():
|
|||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
data.create_tuple_iterator(num_epochs=1).get_next()
|
||||
error_msg = "TypeError: img should be PIL image or NumPy array. Got <class 'tuple'>"
|
||||
error_msg = "TypeError: __call__() takes 2 positional arguments but 11 were given"
|
||||
|
||||
# error msg comes from ToTensor()
|
||||
assert error_msg in str(info.value)
|
||||
|
|
Loading…
Reference in New Issue