forked from mindspore-Ecosystem/mindspore

Give concrete error when mappable dataset lacks __len__ method

parent 49f012ad74
commit 17c121b087
@@ -122,17 +122,15 @@ class WaitedDSCallback(Callback, DSCallback):
     This class can be used to execute a user defined logic right after the previous step or epoch.
     For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.
 
+    Args:
+        step_size: the number of rows in each step. Usually the step size will be equal to the batch size (Default=1).
     Examples:
         >>> my_cb = MyWaitedCallback(32)
         >>> data = data.map(operations=AugOp(), callbacks=my_cb)
         >>> data = data.batch(32)
         >>> # define the model
         >>> model.train(epochs, data, callbacks=[my_cb])
 
-
-    Args:
-        step_size: the number of rows in each step.
-            Usually the step size will be equal to the batch size (Default=1)
     """
 
     def __init__(self, step_size=1):
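To make the relocated Args/Examples concrete, here is a rough sketch of the kind of MyWaitedCallback the docstring example assumes. The sync_epoch_begin hook, the import path, and the way the previous epoch's information is stashed are assumptions about the WaitedDSCallback API, not part of this diff.

from mindspore.dataset import WaitedDSCallback  # import path assumed

class MyWaitedCallback(WaitedDSCallback):
    def __init__(self, step_size=1):
        super().__init__(step_size)
        self.last_epoch_info = None  # value an augmentation op could read before the next epoch

    # Assumed hook: invoked on the dataset side once the previous training epoch has finished.
    def sync_epoch_begin(self, train_run_context, ds_run_context):
        # Stash whatever the training side reported for the finished epoch (e.g. the loss);
        # the exact layout of train_run_context is an assumption in this sketch.
        self.last_epoch_info = train_run_context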
@@ -31,13 +31,13 @@ class DatasetCache:
 
     Args:
         session_id (int): A user assigned session id for the current pipeline.
-        size (int, optional): Size of the memory set aside for the row caching (default=0 which means unlimited,
+        size (int, optional): Size of the memory set aside for the row caching (default=0, which means unlimited,
             note that it might bring in the risk of running out of memory on the machine).
         spilling (bool, optional): Whether or not spilling to disk if out of memory (default=False).
-        hostname (str, optional): Host name (default="127.0.0.1").
-        port (int, optional): Port to connect to server (default=50052).
-        num_connections (int, optional): Number of tcp/ip connections (default=12).
-        prefetch_size (int, optional): Prefetch size (default=20).
+        hostname (str, optional): Host name (default=None, use default hostname '127.0.0.1').
+        port (int, optional): Port to connect to server (default=None, use default port 50052).
+        num_connections (int, optional): Number of tcp/ip connections (default=None, use default value 12).
+        prefetch_size (int, optional): Prefetch size (default=None, use default value 20).
 
     Examples:
         >>> import mindspore.dataset as ds
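Read together, the reworded defaults say that leaving hostname, port, num_connections and prefetch_size unset still resolves to 127.0.0.1, 50052, 12 and 20. A minimal sketch of attaching such a cache to a pipeline, assuming a cache server is already running; the session id and the tiny generator below are placeholders:

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(5):
        yield (np.array([i]),)

# Placeholder session id, normally created beforehand (e.g. via the cache_admin tool).
# size=0 keeps the documented "unlimited" default and spilling stays off.
some_cache = ds.DatasetCache(session_id=1456416665, size=0, spilling=False)

data = ds.GeneratorDataset(gen, ["col0"])
# The cache is attached to a map operation; the identity pyfunc is just for illustration.
data = data.map(operations=(lambda x: x), cache=some_cache)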
@@ -3978,6 +3978,8 @@ class GeneratorDataset(MappableDataset):
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
             # The reason why there is a try catch here is because when the new op is being constructed with shared
             # memory enabled, there will be an exception thrown if there is not enough shared memory available
+            if self.source_len == -1:
+                raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
             try:
                 if new_op.num_parallel_workers > 1:
                     sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
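The two added lines turn a previously obscure failure into an explicit one: a random-access source (one exposing __getitem__) must also expose __len__ so that source_len can be resolved. A small sketch of the situation the new RuntimeError targets; NoLenSource is an illustrative name, and the call that surfaces the error mirrors the new test added further down:

import numpy as np
import mindspore.dataset as ds

class NoLenSource:
    """Random-access source with __getitem__ but, deliberately, no __len__."""
    def __getitem__(self, index):
        return (np.array([index]),)

data = ds.GeneratorDataset(NoLenSource(), ["col0"])
try:
    data.get_dataset_size()
except RuntimeError as err:
    # Expected message: "Attempt to construct a random access dataset, '__len__' method is required!"
    print(err)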
@@ -68,7 +68,7 @@ def deserialize(input_dict=None, json_filepath=None):
         de.Dataset or None if error occurs.
 
     Raises:
-        OSError cannot open a file.
+        OSError: Can not open the json file.
 
     Examples:
         >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
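For orientation, a small round trip through the serialize/deserialize pair whose Raises wording is adjusted above; the MNIST directory and JSON file name are placeholders, and num_samples=100 mirrors the docstring example:

import mindspore.dataset as ds

mnist_dataset_dir = "/path/to/mnist_dataset_dir"  # placeholder path
data = ds.MnistDataset(mnist_dataset_dir, num_samples=100)

# Write the pipeline definition to JSON, then rebuild it from the file.
# A missing or unreadable file is the case the clarified OSError message covers.
ds.serialize(data, json_filepath="mnist_pipeline.json")
data_restored = ds.deserialize(json_filepath="mnist_pipeline.json")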
@@ -36,6 +36,8 @@ class DatasetGenerator:
 
     def __len__(self):
         return 10
 
 
 class DatasetGeneratorLarge:
     def __init__(self):
         self.data = np.array(range(4000))
@@ -485,6 +487,7 @@ def test_generator_17():
         np.testing.assert_array_equal(item["col1"], golden)
         i = i + 1
 
 
 def test_generator_18():
     """
     Test multiprocessing flag (same as test 13 with python_multiprocessing=True flag)
@@ -512,6 +515,7 @@ def test_generator_18():
         golden = np.array([i * 5])
         np.testing.assert_array_equal(item["out0"], golden)
 
 
 def test_generator_19():
     """
     Test multiprocessing flag with 2 different large columns
@@ -532,6 +536,64 @@ def test_generator_19():
         i = i + 1
 
 
+class RandomAccessDataset:
+    def __init__(self):
+        self.__data = np.random.sample((5, 1))
+
+    def __getitem__(self, item):
+        return self.__data[item]
+
+    def __len__(self):
+        return 5
+
+
+class RandomAccessDatasetWithoutLen:
+    def __init__(self):
+        self.__data = np.random.sample((5, 1))
+
+    def __getitem__(self, item):
+        return self.__data[item]
+
+
+class IterableDataset:
+    def __init__(self):
+        self.count = 0
+        self.max = 10
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.count >= self.max:
+            raise StopIteration
+        self.count += 1
+        return (np.array(self.count),)
+
+
+def test_generator_20():
+    """
+    Test mappable and unmappable dataset as source for GeneratorDataset.
+    """
+    logger.info("Test mappable and unmappable dataset as source for GeneratorDataset.")
+
+    # Mappable dataset
+    data1 = ds.GeneratorDataset(RandomAccessDataset(), ["col0"])
+    dataset_size1 = data1.get_dataset_size()
+    assert dataset_size1 == 5
+
+    # Mappable dataset without __len__
+    data2 = ds.GeneratorDataset(RandomAccessDatasetWithoutLen(), ["col0"])
+    try:
+        data2.get_dataset_size()
+    except RuntimeError as e:
+        assert "'__len__' method is required" in str(e)
+
+    # Unmappable dataset
+    data3 = ds.GeneratorDataset(IterableDataset(), ["col0"])
+    dataset_size3 = data3.get_dataset_size()
+    assert dataset_size3 == 10
+
+
 def test_generator_error_1():
     def generator_np():
         for i in range(64):