Give concrete error when mappable dataset lacks __len__ method

This commit is contained in:
YangLuo 2021-06-26 11:17:13 +08:00
parent 49f012ad74
commit 17c121b087
6 changed files with 85 additions and 23 deletions

View File

@@ -122,17 +122,15 @@ class WaitedDSCallback(Callback, DSCallback):
This class can be used to execute a user defined logic right after the previous step or epoch.
For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.

Args:
    step_size: the number of rows in each step. Usually the step size will be equal to the batch size (Default=1).

Examples:
    >>> my_cb = MyWaitedCallback(32)
    >>> data = data.map(operations=AugOp(), callbacks=my_cb)
    >>> data = data.batch(32)
    >>> # define the model
    >>> model.train(epochs, data, callbacks=[my_cb])

Args:
    step_size: the number of rows in each step.
        Usually the step size will be equal to the batch size (Default=1)
"""

def __init__(self, step_size=1):
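
For reference, the MyWaitedCallback in the example above is a user subclass of WaitedDSCallback that overrides the sync hooks. A minimal sketch, assuming the sync_epoch_begin hook and the cur_epoch_num field of the dataset-side context (both from the DSCallback API); the printed message is illustrative:

    from mindspore.dataset import WaitedDSCallback

    class MyWaitedCallback(WaitedDSCallback):
        def __init__(self, step_size=1):
            super().__init__(step_size)

        def sync_epoch_begin(self, train_run_context, ds_run_context):
            # Runs once the previous training epoch has finished, so state
            # from that epoch (e.g. its loss) can be fed back into map ops.
            print("pipeline resuming at epoch", ds_run_context.cur_epoch_num)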

View File

@@ -31,13 +31,13 @@ class DatasetCache:
Args:
    session_id (int): A user assigned session id for the current pipeline.
    size (int, optional): Size of the memory set aside for the row caching (default=0 which means unlimited,
    size (int, optional): Size of the memory set aside for the row caching (default=0, which means unlimited,
        note that it might bring in the risk of running out of memory on the machine).
    spilling (bool, optional): Whether or not spilling to disk if out of memory (default=False).
    hostname (str, optional): Host name (default="127.0.0.1").
    port (int, optional): Port to connect to server (default=50052).
    num_connections (int, optional): Number of tcp/ip connections (default=12).
    prefetch_size (int, optional): Prefetch size (default=20).
    hostname (str, optional): Host name (default=None, use default hostname '127.0.0.1').
    port (int, optional): Port to connect to server (default=None, use default port 50052).
    num_connections (int, optional): Number of tcp/ip connections (default=None, use default value 12).
    prefetch_size (int, optional): Prefetch size (default=None, use default value 20).

Examples:
    >>> import mindspore.dataset as ds
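
To make the Args concrete: a cache is created against a running cache server and then attached to a pipeline. A minimal sketch; the session id and dataset path are placeholders, and the server must already be started with the cache_admin tool:

    import mindspore.dataset as ds

    # `cache_admin --start` launches the cache server; `cache_admin -g`
    # generates a session id (the value below is a placeholder).
    some_cache = ds.DatasetCache(session_id=1456416665, size=0, spilling=False)

    # Rows produced by this dataset are cached and reused across epochs.
    data = ds.ImageFolderDataset(dataset_dir="/path/to/image_folder", cache=some_cache)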

View File

@@ -3978,6 +3978,8 @@ class GeneratorDataset(MappableDataset):
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
    # The reason why there is a try catch here is because when the new op is being constructed with shared
    # memory enabled, there will be an exception thrown if there is not enough shared memory available
    if self.source_len == -1:
        raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
    try:
        if new_op.num_parallel_workers > 1:
            sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
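
The user-facing effect of the new check: a source that offers __getitem__ (so the random-access path is taken) but no __len__ now fails with this explicit RuntimeError instead of a harder-to-diagnose failure during shared-memory setup. A minimal reproduction, mirroring the new test below; the class name is illustrative:

    import numpy as np
    import mindspore.dataset as ds

    class NoLenSource:
        # Random-access source that deliberately omits __len__.
        def __init__(self):
            self._data = np.random.sample((5, 1))

        def __getitem__(self, index):
            return self._data[index]

    data = ds.GeneratorDataset(NoLenSource(), ["col0"])
    try:
        data.get_dataset_size()
    except RuntimeError as err:
        # Attempt to construct a random access dataset, '__len__' method is required!
        print(err)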

View File

@@ -68,7 +68,7 @@ def deserialize(input_dict=None, json_filepath=None):
    de.Dataset or None if error occurs.

Raises:
    OSError cannot open a file.
    OSError: Can not open the json file.

Examples:
    >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
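
For context, deserialize is the inverse of ds.serialize. A minimal round trip, with placeholder paths:

    import mindspore.dataset as ds

    # Build a pipeline, dump it to JSON, then rebuild it from the file.
    data = ds.MnistDataset("/path/to/mnist", num_samples=100)
    ds.serialize(data, json_filepath="pipeline.json")
    data_restored = ds.deserialize(json_filepath="pipeline.json")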

View File

@@ -36,6 +36,8 @@ class DatasetGenerator:
    def __len__(self):
        return 10


class DatasetGeneratorLarge:
    def __init__(self):
        self.data = np.array(range(4000))
@@ -485,6 +487,7 @@ def test_generator_17():
        np.testing.assert_array_equal(item["col1"], golden)
        i = i + 1


def test_generator_18():
    """
    Test multiprocessing flag (same as test 13 with python_multiprocessing=True flag)
@@ -512,6 +515,7 @@ def test_generator_18():
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item["out0"], golden)


def test_generator_19():
    """
    Test multiprocessing flag with 2 different large columns
@@ -532,6 +536,64 @@ def test_generator_19():
        i = i + 1


class RandomAccessDataset:
    def __init__(self):
        self.__data = np.random.sample((5, 1))

    def __getitem__(self, item):
        return self.__data[item]

    def __len__(self):
        return 5


class RandomAccessDatasetWithoutLen:
    def __init__(self):
        self.__data = np.random.sample((5, 1))

    def __getitem__(self, item):
        return self.__data[item]


class IterableDataset:
    def __init__(self):
        self.count = 0
        self.max = 10

    def __iter__(self):
        return self

    def __next__(self):
        if self.count >= self.max:
            raise StopIteration
        self.count += 1
        return (np.array(self.count),)


def test_generator_20():
    """
    Test mappable and unmappable dataset as source for GeneratorDataset.
    """
    logger.info("Test mappable and unmappable dataset as source for GeneratorDataset.")

    # Mappable dataset
    data1 = ds.GeneratorDataset(RandomAccessDataset(), ["col0"])
    dataset_size1 = data1.get_dataset_size()
    assert dataset_size1 == 5

    # Mappable dataset without __len__
    data2 = ds.GeneratorDataset(RandomAccessDatasetWithoutLen(), ["col0"])
    try:
        data2.get_dataset_size()
    except RuntimeError as e:
        assert "'__len__' method is required" in str(e)

    # Unmappable dataset
    data3 = ds.GeneratorDataset(IterableDataset(), ["col0"])
    dataset_size3 = data3.get_dataset_size()
    assert dataset_size3 == 10
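
Note that the try/except form above would also pass silently if no error were raised at all. A stricter variant, sketched here on the assumption that pytest is available to the suite:

    import pytest

    def test_generator_without_len_strict():
        # Hypothetical tightened version of the negative case above.
        data = ds.GeneratorDataset(RandomAccessDatasetWithoutLen(), ["col0"])
        with pytest.raises(RuntimeError) as info:
            data.get_dataset_size()
        assert "'__len__' method is required" in str(info.value)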


def test_generator_error_1():
    def generator_np():
        for i in range(64):