Give concrete error when mappable dataset lacks __len__ method

YangLuo 2021-06-26 11:17:13 +08:00
parent 49f012ad74
commit 17c121b087
6 changed files with 85 additions and 23 deletions


@@ -122,17 +122,15 @@ class WaitedDSCallback(Callback, DSCallback):
    This class can be used to execute a user defined logic right after the previous step or epoch.
    For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.

+   Args:
+       step_size: the number of rows in each step. Usually the step size will be equal to the batch size (Default=1).

    Examples:
        >>> my_cb = MyWaitedCallback(32)
        >>> data = data.map(operations=AugOp(), callbacks=my_cb)
        >>> data = data.batch(32)
        >>> # define the model
        >>> model.train(epochs, data, callbacks=[my_cb])

-   Args:
-       step_size: the number of rows in each step.
-           Usually the step size will be equal to the batch size (Default=1)
    """

    def __init__(self, step_size=1):
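For orientation (not part of this diff), the MyWaitedCallback used in the example above is a user-defined subclass; a minimal sketch follows, assuming the base class exposes a sync_epoch_begin hook that receives the training and dataset run contexts.

from mindspore.dataset.callback import WaitedDSCallback


class MyWaitedCallback(WaitedDSCallback):
    """Waits for the previous training epoch before the next pipeline epoch starts (illustrative)."""

    def __init__(self, step_size=1):
        super().__init__(step_size)
        self.synced_epochs = 0

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        # Called once the previous epoch's training results are available;
        # an augmentation op could read the reported loss here and adjust
        # its own parameters before the next epoch's data is produced.
        self.synced_epochs += 1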


@@ -31,13 +31,13 @@ class DatasetCache:
    Args:
        session_id (int): A user assigned session id for the current pipeline.
-       size (int, optional): Size of the memory set aside for the row caching (default=0 which means unlimited,
+       size (int, optional): Size of the memory set aside for the row caching (default=0, which means unlimited,
            note that it might bring in the risk of running out of memory on the machine).
        spilling (bool, optional): Whether or not spilling to disk if out of memory (default=False).
-       hostname (str, optional): Host name (default="127.0.0.1").
-       port (int, optional): Port to connect to server (default=50052).
-       num_connections (int, optional): Number of tcp/ip connections (default=12).
-       prefetch_size (int, optional): Prefetch size (default=20).
+       hostname (str, optional): Host name (default=None, use default hostname '127.0.0.1').
+       port (int, optional): Port to connect to server (default=None, use default port 50052).
+       num_connections (int, optional): Number of tcp/ip connections (default=None, use default value 12).
+       prefetch_size (int, optional): Prefetch size (default=None, use default value 20).

    Examples:
        >>> import mindspore.dataset as ds
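As a usage illustration (not part of this diff), a cache built with only a session id can be attached to a source dataset; the session id and dataset directory below are placeholders, and creating a session normally requires a running cache server.

import mindspore.dataset as ds

# Placeholder session id: in practice it comes from a running cache server
# (created with the cache_admin tool), so this snippet is illustrative only.
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=False)

# Attach the cache to a source dataset so rows can be reused across epochs
# instead of being re-read and re-processed each time.
data = ds.ImageFolderDataset(dataset_dir="/path/to/image_folder_dataset_dir",
                             num_samples=10, cache=some_cache)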


@@ -3978,6 +3978,8 @@ class GeneratorDataset(MappableDataset):
        if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
            # The reason why there is a try catch here is because when the new op is being constructed with shared
            # memory enabled, there will be an exception thrown if there is not enough shared memory available
+           if self.source_len == -1:
+               raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
            try:
                if new_op.num_parallel_workers > 1:
                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,


@@ -68,7 +68,7 @@ def deserialize(input_dict=None, json_filepath=None):
        de.Dataset or None if error occurs.

    Raises:
-       OSError cannot open a file.
+       OSError: Can not open the json file.

    Examples:
        >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
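For reference (not part of this diff), the call pair the corrected Raises entry describes might look like this; the MNIST directory and JSON path are placeholders.

import mindspore.dataset as ds

# Placeholder paths. serialize() writes the pipeline definition to JSON and
# deserialize() rebuilds an equivalent pipeline from it; if the JSON file
# cannot be opened, deserialize() raises the OSError documented above.
dataset = ds.MnistDataset("/path/to/mnist_dataset_dir", num_samples=100)
ds.serialize(dataset, json_filepath="mnist_pipeline.json")
restored = ds.deserialize(json_filepath="mnist_pipeline.json")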


@@ -36,6 +36,8 @@ class DatasetGenerator:
    def __len__(self):
        return 10


class DatasetGeneratorLarge:
    def __init__(self):
        self.data = np.array(range(4000))
@@ -485,6 +487,7 @@ def test_generator_17():
        np.testing.assert_array_equal(item["col1"], golden)
        i = i + 1


def test_generator_18():
    """
    Test multiprocessing flag (same as test 13 with python_multiprocessing=True flag)
@@ -512,6 +515,7 @@ def test_generator_18():
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item["out0"], golden)


def test_generator_19():
    """
    Test multiprocessing flag with 2 different large columns
@@ -532,6 +536,64 @@ def test_generator_19():
        i = i + 1


+class RandomAccessDataset:
+    def __init__(self):
+        self.__data = np.random.sample((5, 1))
+
+    def __getitem__(self, item):
+        return self.__data[item]
+
+    def __len__(self):
+        return 5
+
+
+class RandomAccessDatasetWithoutLen:
+    def __init__(self):
+        self.__data = np.random.sample((5, 1))
+
+    def __getitem__(self, item):
+        return self.__data[item]
+
+
+class IterableDataset:
+    def __init__(self):
+        self.count = 0
+        self.max = 10
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.count >= self.max:
+            raise StopIteration
+        self.count += 1
+        return (np.array(self.count),)
+
+
+def test_generator_20():
+    """
+    Test mappable and unmappable dataset as source for GeneratorDataset.
+    """
+    logger.info("Test mappable and unmappable dataset as source for GeneratorDataset.")
+
+    # Mappable dataset
+    data1 = ds.GeneratorDataset(RandomAccessDataset(), ["col0"])
+    dataset_size1 = data1.get_dataset_size()
+    assert dataset_size1 == 5
+
+    # Mappable dataset without __len__
+    data2 = ds.GeneratorDataset(RandomAccessDatasetWithoutLen(), ["col0"])
+    try:
+        data2.get_dataset_size()
+    except RuntimeError as e:
+        assert "'__len__' method is required" in str(e)
+
+    # Unmappable dataset
+    data3 = ds.GeneratorDataset(IterableDataset(), ["col0"])
+    dataset_size3 = data3.get_dataset_size()
+    assert dataset_size3 == 10
+
+
def test_generator_error_1():
    def generator_np():
        for i in range(64):
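On the user side, the remedy for a source like RandomAccessDatasetWithoutLen is simply to implement __len__; a minimal sketch, assuming the same five-row data as in the test above.

import numpy as np


class FixedRandomAccessDataset:
    """Same random-access source as above, now reporting its length."""

    def __init__(self):
        self.__data = np.random.sample((5, 1))

    def __getitem__(self, item):
        return self.__data[item]

    def __len__(self):
        # Reporting the length lets GeneratorDataset build a sampler for the
        # source instead of raising the new RuntimeError.
        return len(self.__data)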