forked from mindspore-Ecosystem/mindspore
Give concrete error when mappable dataset lacks __len__ method
parent 49f012ad74
commit 17c121b087
@@ -122,17 +122,15 @@ class WaitedDSCallback(Callback, DSCallback):
     This class can be used to execute a user defined logic right after the previous step or epoch.
     For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.
 
+    Args:
+        step_size: the number of rows in each step. Usually the step size will be equal to the batch size (Default=1).
+
     Examples:
         >>> my_cb = MyWaitedCallback(32)
         >>> data = data.map(operations=AugOp(), callbacks=my_cb)
         >>> data = data.batch(32)
         >>> # define the model
         >>> model.train(epochs, data, callbacks=[my_cb])
-
-
-    Args:
-        step_size: the number of rows in each step.
-            Usually the step size will be equal to the batch size (Default=1)
     """
 
     def __init__(self, step_size=1):
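For orientation, the pattern this docstring describes can be sketched as below. The MyWaitedCallback name comes from the docstring's own example; the sync_epoch_begin hook is an assumption about the WaitedDSCallback API, not part of this diff.

import mindspore.dataset as ds

class MyWaitedCallback(ds.WaitedDSCallback):
    def __init__(self, step_size=1):
        super().__init__(step_size)
        self.epoch_loss = None

    # Assumed hook: WaitedDSCallback is expected to block the dataset side
    # until the previous training epoch has finished, then invoke this method,
    # so augmentation parameters can be updated from that epoch's results.
    def sync_epoch_begin(self, train_run_context, ds_run_context):
        pass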
@@ -31,13 +31,13 @@ class DatasetCache:
 
     Args:
         session_id (int): A user assigned session id for the current pipeline.
-        size (int, optional): Size of the memory set aside for the row caching (default=0 which means unlimited,
+        size (int, optional): Size of the memory set aside for the row caching (default=0, which means unlimited,
             note that it might bring in the risk of running out of memory on the machine).
         spilling (bool, optional): Whether or not spilling to disk if out of memory (default=False).
-        hostname (str, optional): Host name (default="127.0.0.1").
-        port (int, optional): Port to connect to server (default=50052).
-        num_connections (int, optional): Number of tcp/ip connections (default=12).
-        prefetch_size (int, optional): Prefetch size (default=20).
+        hostname (str, optional): Host name (default=None, use default hostname '127.0.0.1').
+        port (int, optional): Port to connect to server (default=None, use default port 50052).
+        num_connections (int, optional): Number of tcp/ip connections (default=None, use default value 12).
+        prefetch_size (int, optional): Prefetch size (default=None, use default value 20).
 
     Examples:
         >>> import mindspore.dataset as ds
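A minimal usage sketch for the arguments documented above. The session id and dataset path are placeholders; a real session id is normally generated on the cache server side before the pipeline runs.

import mindspore.dataset as ds

# Placeholder session id; per the docstring, size=0 means unlimited memory
# and spilling=False keeps all cached rows in memory.
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=False)

# A leaf dataset built with cache=... serves repeated epochs from the cache.
data = ds.ImageFolderDataset(dataset_dir="/path/to/images", cache=some_cache)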
@@ -3978,6 +3978,8 @@ class GeneratorDataset(MappableDataset):
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
             # The reason why there is a try catch here is because when the new op is being constructed with shared
             # memory enabled, there will be an exception thrown if there is not enough shared memory available
+            if self.source_len == -1:
+                raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
             try:
                 if new_op.num_parallel_workers > 1:
                     sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
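The guard added above makes the failure explicit: source_len appears to be -1 exactly when len(source) was unavailable at construction time. A minimal sketch of the failure mode, mirroring the new test below (the class name is illustrative):

import numpy as np
import mindspore.dataset as ds

class NoLenSource:
    def __getitem__(self, index):   # random access, but no __len__
        return (np.array([index]),)

data = ds.GeneratorDataset(NoLenSource(), ["col0"])
# Building the random-access pipeline now fails fast with:
# RuntimeError: Attempt to construct a random access dataset, '__len__' method is required!
data.get_dataset_size()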
@@ -68,7 +68,7 @@ def deserialize(input_dict=None, json_filepath=None):
         de.Dataset or None if error occurs.
 
     Raises:
-        OSError cannot open a file.
+        OSError: Can not open the json file.
 
     Examples:
         >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
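For reference, the round trip this Raises clause belongs to might look like the following sketch (mnist_dataset_dir is the same placeholder used in the docstring example; the output filename is illustrative):

import mindspore.dataset as ds

dataset = ds.MnistDataset(mnist_dataset_dir, 100)
ds.serialize(dataset, json_filepath="mnist_pipeline.json")     # write pipeline to JSON
dataset = ds.deserialize(json_filepath="mnist_pipeline.json")  # rebuild it
# A json_filepath that cannot be opened raises the OSError documented above.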
@@ -36,6 +36,8 @@ class DatasetGenerator:
 
     def __len__(self):
         return 10
+
+
 class DatasetGeneratorLarge:
     def __init__(self):
         self.data = np.array(range(4000))
@@ -485,6 +487,7 @@ def test_generator_17():
         np.testing.assert_array_equal(item["col1"], golden)
         i = i + 1
 
+
 def test_generator_18():
     """
     Test multiprocessing flag (same as test 13 with python_multiprocessing=True flag)
@@ -512,6 +515,7 @@ def test_generator_18():
         golden = np.array([i * 5])
         np.testing.assert_array_equal(item["out0"], golden)
 
+
 def test_generator_19():
     """
     Test multiprocessing flag with 2 different large columns
@@ -532,6 +536,64 @@ def test_generator_19():
         i = i + 1
 
 
+class RandomAccessDataset:
+    def __init__(self):
+        self.__data = np.random.sample((5, 1))
+
+    def __getitem__(self, item):
+        return self.__data[item]
+
+    def __len__(self):
+        return 5
+
+
+class RandomAccessDatasetWithoutLen:
+    def __init__(self):
+        self.__data = np.random.sample((5, 1))
+
+    def __getitem__(self, item):
+        return self.__data[item]
+
+
+class IterableDataset:
+    def __init__(self):
+        self.count = 0
+        self.max = 10
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.count >= self.max:
+            raise StopIteration
+        self.count += 1
+        return (np.array(self.count),)
+
+
+def test_generator_20():
+    """
+    Test mappable and unmappable dataset as source for GeneratorDataset.
+    """
+    logger.info("Test mappable and unmappable dataset as source for GeneratorDataset.")
+
+    # Mappable dataset
+    data1 = ds.GeneratorDataset(RandomAccessDataset(), ["col0"])
+    dataset_size1 = data1.get_dataset_size()
+    assert dataset_size1 == 5
+
+    # Mappable dataset without __len__
+    data2 = ds.GeneratorDataset(RandomAccessDatasetWithoutLen(), ["col0"])
+    try:
+        data2.get_dataset_size()
+    except RuntimeError as e:
+        assert "'__len__' method is required" in str(e)
+
+    # Unmappable dataset
+    data3 = ds.GeneratorDataset(IterableDataset(), ["col0"])
+    dataset_size3 = data3.get_dataset_size()
+    assert dataset_size3 == 10
+
+
 def test_generator_error_1():
     def generator_np():
         for i in range(64):