forked from mindspore-Ecosystem/mindspore
Implementation of new API: apply
This commit is contained in:
parent
94589ce611
commit
1a1cbc6814
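This adds a Dataset.apply(apply_func) method: apply_func must take one Dataset and return a transformed Dataset, which apply validates and passes back. A minimal usage sketch, condensed from the docstring example in the diff below (a GeneratorDataset yielding 0 to 5, batched by 2):

import numpy as np
import mindspore.dataset as ds

def generator_1d():
    for i in range(6):
        yield (np.array([i]),)

data = ds.GeneratorDataset(generator_1d, ["data"])
# apply accepts any callable that returns a Dataset; a lambda works too
data = data.apply(lambda d: d.batch(2))
for item in data.create_dict_iterator():
    print(item["data"])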
@@ -499,6 +499,51 @@ class Dataset:
        return ProjectDataset(self, columns)

    def apply(self, apply_func):
        """
        Apply a function to this dataset.

        The specified apply_func is a function that must take one 'Dataset' as an argument
        and return a preprocessed 'Dataset'.

        Args:
            apply_func (function): A function that must take one 'Dataset' as an argument and
                return a preprocessed 'Dataset'.

        Returns:
            Dataset, the dataset returned by apply_func.

        Examples:
            >>> import numpy as np
            >>> import mindspore.dataset as ds
            >>> # Generate a 1-d integer numpy array with values 0 to 5
            >>> def generator_1d():
            >>>     for i in range(6):
            >>>         yield (np.array([i]),)
            >>> # 1) get all data from the dataset
            >>> data = ds.GeneratorDataset(generator_1d, ["data"])
            >>> # 2) declare an apply_func function
            >>> def apply_func(ds):
            >>>     ds = ds.batch(2)
            >>>     return ds
            >>> # 3) use apply to call apply_func
            >>> data = data.apply(apply_func)
            >>> for item in data.create_dict_iterator():
            >>>     print(item["data"])

        Raises:
            TypeError: If apply_func is not a function.
            TypeError: If apply_func does not return a Dataset.
        """

        if not hasattr(apply_func, '__call__'):
            raise TypeError("apply_func must be a function.")

        dataset = apply_func(self)
        if not isinstance(dataset, Dataset):
            raise TypeError("apply_func must return a Dataset.")
        return dataset

    def device_que(self, prefetch_size=None):
        """
        Return a TransferDataset that transfers data through tdt.
@@ -0,0 +1,236 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.transforms.vision.c_transforms as vision
import numpy as np

DATA_DIR = "../data/dataset/testPK/data"

# Generate a 1-d integer numpy array with values 0 to 63
def generator_1d():
    for i in range(64):
        yield (np.array([i]),)

def test_apply_generator_case():
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data2 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        ds = ds.repeat(2)
        return ds.batch(4)

    data1 = data1.apply(dataset_fn)
    data2 = data2.repeat(2)
    data2 = data2.batch(4)

    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        assert np.array_equal(item1["data"], item2["data"])

def test_apply_imagefolder_case():
    # apply dataset map operations
    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=3)
    data2 = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=3)

    decode_op = vision.Decode()
    normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0])

    def dataset_fn(ds):
        ds = ds.map(operations=decode_op)
        ds = ds.map(operations=normalize_op)
        ds = ds.repeat(2)
        return ds

    data1 = data1.apply(dataset_fn)
    data2 = data2.map(operations=decode_op)
    data2 = data2.map(operations=normalize_op)
    data2 = data2.repeat(2)

    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        assert np.array_equal(item1["image"], item2["image"])

def test_apply_flow_case_0(id=0):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        if id == 0:
            ds = ds.batch(4)
        elif id == 1:
            ds = ds.repeat(2)
        elif id == 2:
            ds = ds.batch(4)
            ds = ds.repeat(2)
        else:
            ds = ds.shuffle(buffer_size=4)
        return ds

    data1 = data1.apply(dataset_fn)
    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id == 0:
        assert num_iter == 16   # 64 rows / batch(4) -> 16 batches
    elif id == 1:
        assert num_iter == 128  # 64 rows * repeat(2) -> 128 rows
    elif id == 2:
        assert num_iter == 32   # 64 rows * repeat(2) / batch(4) -> 32 batches
    else:
        assert num_iter == 64   # shuffle preserves the row count

def test_apply_flow_case_1(id=1):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        if id == 0:
            ds = ds.batch(4)
        elif id == 1:
            ds = ds.repeat(2)
        elif id == 2:
            ds = ds.batch(4)
            ds = ds.repeat(2)
        else:
            ds = ds.shuffle(buffer_size=4)
        return ds

    data1 = data1.apply(dataset_fn)
    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id == 0:
        assert num_iter == 16   # 64 rows / batch(4) -> 16 batches
    elif id == 1:
        assert num_iter == 128  # 64 rows * repeat(2) -> 128 rows
    elif id == 2:
        assert num_iter == 32   # 64 rows * repeat(2) / batch(4) -> 32 batches
    else:
        assert num_iter == 64   # shuffle preserves the row count

def test_apply_flow_case_2(id=2):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        if id == 0:
            ds = ds.batch(4)
        elif id == 1:
            ds = ds.repeat(2)
        elif id == 2:
            ds = ds.batch(4)
            ds = ds.repeat(2)
        else:
            ds = ds.shuffle(buffer_size=4)
        return ds

    data1 = data1.apply(dataset_fn)
    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id == 0:
        assert num_iter == 16   # 64 rows / batch(4) -> 16 batches
    elif id == 1:
        assert num_iter == 128  # 64 rows * repeat(2) -> 128 rows
    elif id == 2:
        assert num_iter == 32   # 64 rows * repeat(2) / batch(4) -> 32 batches
    else:
        assert num_iter == 64   # shuffle preserves the row count

def test_apply_flow_case_3(id=3):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        if id == 0:
            ds = ds.batch(4)
        elif id == 1:
            ds = ds.repeat(2)
        elif id == 2:
            ds = ds.batch(4)
            ds = ds.repeat(2)
        else:
            ds = ds.shuffle(buffer_size=4)
        return ds

    data1 = data1.apply(dataset_fn)
    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id == 0:
        assert num_iter == 16   # 64 rows / batch(4) -> 16 batches
    elif id == 1:
        assert num_iter == 128  # 64 rows * repeat(2) -> 128 rows
    elif id == 2:
        assert num_iter == 32   # 64 rows * repeat(2) / batch(4) -> 32 batches
    else:
        assert num_iter == 64   # shuffle preserves the row count

def test_apply_exception_case():
    # apply exception operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        ds = ds.repeat(2)
        return ds.batch(4)

    def exception_fn(ds):
        # returns a numpy array instead of a Dataset
        return np.array([[0], [1], [3], [4], [5]])

    try:
        # not a callable, so apply must raise TypeError
        data1 = data1.apply("123")
        for _ in data1.create_dict_iterator():
            pass
        assert False
    except TypeError:
        pass

    try:
        # apply_func does not return a Dataset, so apply must raise TypeError
        data1 = data1.apply(exception_fn)
        for _ in data1.create_dict_iterator():
            pass
        assert False
    except TypeError:
        pass

    try:
        # applying dataset_fn twice to the same dataset and iterating the
        # resulting branches is expected to raise ValueError
        data2 = data1.apply(dataset_fn)
        data3 = data1.apply(dataset_fn)
        for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
            pass
        assert False
    except ValueError:
        pass

if __name__ == '__main__':
    logger.info("Running test_apply.py test_apply_generator_case() function")
    test_apply_generator_case()

    logger.info("Running test_apply.py test_apply_imagefolder_case() function")
    test_apply_imagefolder_case()

    logger.info("Running test_apply.py test_apply_flow_case(id) function")
    test_apply_flow_case_0()
    test_apply_flow_case_1()
    test_apply_flow_case_2()
    test_apply_flow_case_3()

    logger.info("Running test_apply.py test_apply_exception_case() function")
    test_apply_exception_case()