implementation of new API: apply

xiefangqi 2020-04-14 19:21:15 +08:00
parent 94589ce611
commit 1a1cbc6814
2 changed files with 281 additions and 0 deletions

View File

@@ -499,6 +499,51 @@ class Dataset:
        return ProjectDataset(self, columns)

    def apply(self, apply_func):
        """
        Apply a function to this dataset.

        The specified apply_func is a function that must take one 'Dataset' as an argument
        and return a preprocessed 'Dataset'.

        Args:
            apply_func (function): A function that must take one 'Dataset' as an argument and
                return a preprocessed 'Dataset'.

        Returns:
            Dataset, the dataset returned by apply_func.

        Examples:
            >>> import numpy as np
            >>> import mindspore.dataset as ds
            >>> # Generate a 1d int numpy array from 0 - 5
            >>> def generator_1d():
            >>>     for i in range(6):
            >>>         yield (np.array([i]),)
            >>> # 1) get all data from the dataset
            >>> data = ds.GeneratorDataset(generator_1d, ["data"])
            >>> # 2) declare an apply_func function
            >>> def apply_func(ds):
            >>>     ds = ds.batch(2)
            >>>     return ds
            >>> # 3) use apply to call apply_func
            >>> data = data.apply(apply_func)
            >>> for item in data.create_dict_iterator():
            >>>     print(item["data"])

        Raises:
            TypeError: If apply_func is not a function.
            TypeError: If apply_func doesn't return a Dataset.
        """
        if not hasattr(apply_func, '__call__'):
            raise TypeError("apply_func must be a function.")

        dataset = apply_func(self)
        if not isinstance(dataset, Dataset):
            raise TypeError("apply_func must return a dataset.")
        return dataset

    def device_que(self, prefetch_size=None):
        """
        Return a transferred Dataset that transfers data through tdt.

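For reference, here is a minimal usage sketch of the new apply API, mirroring the docstring example above. This is an illustrative sketch rather than part of the commit: apply accepts any callable that maps a Dataset to a Dataset, so a plain lambda also satisfies the callable check.

import numpy as np
import mindspore.dataset as ds

# Illustrative source: a 1d int generator, as in the docstring example.
def generator_1d():
    for i in range(6):
        yield (np.array([i]),)

data = ds.GeneratorDataset(generator_1d, ["data"])

# apply takes any callable that receives a Dataset and returns a Dataset;
# a non-callable argument, or a callable that returns something other than
# a Dataset, raises TypeError per the checks in the implementation above.
data = data.apply(lambda d: d.batch(2))

for item in data.create_dict_iterator():
    print(item["data"])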
View File

@@ -0,0 +1,236 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.transforms.vision.c_transforms as vision
import numpy as np

DATA_DIR = "../data/dataset/testPK/data"


# Generate a 1d int numpy array from 0 - 63
def generator_1d():
    for i in range(64):
        yield (np.array([i]),)


def test_apply_generator_case():
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data2 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        ds = ds.repeat(2)
        return ds.batch(4)

    data1 = data1.apply(dataset_fn)

    data2 = data2.repeat(2)
    data2 = data2.batch(4)

    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        assert np.array_equal(item1["data"], item2["data"])


def test_apply_imagefolder_case():
    # apply dataset map operations
    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=3)
    data2 = ds.ImageFolderDatasetV2(DATA_DIR, num_shards=4, shard_id=3)

    decode_op = vision.Decode()
    normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0])

    def dataset_fn(ds):
        ds = ds.map(operations=decode_op)
        ds = ds.map(operations=normalize_op)
        ds = ds.repeat(2)
        return ds

    data1 = data1.apply(dataset_fn)

    data2 = data2.map(operations=decode_op)
    data2 = data2.map(operations=normalize_op)
    data2 = data2.repeat(2)

    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
        assert np.array_equal(item1["image"], item2["image"])


def test_apply_flow_case_0(id=0):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        if id == 0:
            ds = ds.batch(4)
        elif id == 1:
            ds = ds.repeat(2)
        elif id == 2:
            ds = ds.batch(4)
            ds = ds.repeat(2)
        else:
            ds = ds.shuffle(buffer_size=4)
        return ds

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id == 0:
        assert num_iter == 16
    elif id == 1:
        assert num_iter == 128
    elif id == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64


def test_apply_flow_case_1(id=1):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        if id == 0:
            ds = ds.batch(4)
        elif id == 1:
            ds = ds.repeat(2)
        elif id == 2:
            ds = ds.batch(4)
            ds = ds.repeat(2)
        else:
            ds = ds.shuffle(buffer_size=4)
        return ds

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id == 0:
        assert num_iter == 16
    elif id == 1:
        assert num_iter == 128
    elif id == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64


def test_apply_flow_case_2(id=2):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        if id == 0:
            ds = ds.batch(4)
        elif id == 1:
            ds = ds.repeat(2)
        elif id == 2:
            ds = ds.batch(4)
            ds = ds.repeat(2)
        else:
            ds = ds.shuffle(buffer_size=4)
        return ds

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id == 0:
        assert num_iter == 16
    elif id == 1:
        assert num_iter == 128
    elif id == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64


def test_apply_flow_case_3(id=3):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        if id == 0:
            ds = ds.batch(4)
        elif id == 1:
            ds = ds.repeat(2)
        elif id == 2:
            ds = ds.batch(4)
            ds = ds.repeat(2)
        else:
            ds = ds.shuffle(buffer_size=4)
        return ds

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter = num_iter + 1

    if id == 0:
        assert num_iter == 16
    elif id == 1:
        assert num_iter == 128
    elif id == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64


def test_apply_exception_case():
    # apply exception operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds):
        ds = ds.repeat(2)
        return ds.batch(4)

    def exception_fn(ds):
        return np.array([[0], [1], [3], [4], [5]])

    try:
        data1 = data1.apply("123")
        for _ in data1.create_dict_iterator():
            pass
        assert False
    except TypeError:
        pass

    try:
        data1 = data1.apply(exception_fn)
        for _ in data1.create_dict_iterator():
            pass
        assert False
    except TypeError:
        pass

    try:
        data2 = data1.apply(dataset_fn)
        data3 = data1.apply(dataset_fn)
        for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
            pass
        assert False
    except ValueError:
        pass


if __name__ == '__main__':
    logger.info("Running test_apply.py test_apply_generator_case() function")
    test_apply_generator_case()

    logger.info("Running test_apply.py test_apply_imagefolder_case() function")
    test_apply_imagefolder_case()

    logger.info("Running test_apply.py test_apply_flow_case(id) function")
    test_apply_flow_case_0()
    test_apply_flow_case_1()
    test_apply_flow_case_2()
    test_apply_flow_case_3()

    logger.info("Running test_apply.py test_apply_exception_case() function")
    test_apply_exception_case()