forked from mindspore-Ecosystem/mindspore
add paramter check for numpyslices and num_shards
This commit is contained in:
parent
2f1b0dc531
commit
7fa0d9e7e4
|
@ -3069,7 +3069,7 @@ class GeneratorDataset(MappableDataset):
|
|||
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
|
||||
required (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
|
||||
This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
|
||||
When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
|
||||
when num_shards is also specified. Random accessible input is required.
|
||||
|
||||
|
@ -4878,6 +4878,11 @@ class _NumpySlicesDataset:
|
|||
else:
|
||||
self.data = (np.array(data),)
|
||||
|
||||
# check whether the data length in each column is equal
|
||||
data_len = [len(data_item) for data_item in self.data]
|
||||
if data_len[1:] != data_len[:-1]:
|
||||
raise ValueError("Data length in each column is not equal.")
|
||||
|
||||
# Init column_name
|
||||
if column_list is not None:
|
||||
self.column_list = column_list
|
||||
|
@ -4966,7 +4971,7 @@ class NumpySlicesDataset(GeneratorDataset):
|
|||
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
|
||||
required (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
|
||||
This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
|
||||
When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
|
||||
when num_shards is also specified. Random accessible input is required.
|
||||
|
||||
|
|
|
@ -153,6 +153,7 @@ def check_sampler_shuffle_shard_options(param_dict):
|
|||
raise RuntimeError("sampler and sharding cannot be specified at the same time.")
|
||||
|
||||
if num_shards is not None:
|
||||
check_positive_int32(num_shards, "num_shards")
|
||||
if shard_id is None:
|
||||
raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
|
||||
if shard_id < 0 or shard_id >= num_shards:
|
||||
|
@ -529,6 +530,7 @@ def check_generatordataset(method):
|
|||
# These two parameters appear together.
|
||||
raise ValueError("num_shards and shard_id need to be passed in together")
|
||||
if num_shards is not None:
|
||||
check_positive_int32(num_shards, "num_shards")
|
||||
if shard_id >= num_shards:
|
||||
raise ValueError("shard_id should be less than num_shards")
|
||||
|
||||
|
|
|
@ -185,7 +185,7 @@ def test_minddataset_invalidate_num_shards():
|
|||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception, match="shard_id is invalid, "):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1)
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
|
|
Loading…
Reference in New Issue