forked from mindspore-Ecosystem/mindspore
check num_samples
This commit is contained in:
parent
3c1d35295c
commit
854308ff75
|
@ -243,6 +243,8 @@ def check_param_type(param_list, param_dict, param_type):
|
||||||
if param_dict.get(param_name) is not None:
|
if param_dict.get(param_name) is not None:
|
||||||
if param_name == 'num_parallel_workers':
|
if param_name == 'num_parallel_workers':
|
||||||
check_num_parallel_workers(param_dict.get(param_name))
|
check_num_parallel_workers(param_dict.get(param_name))
|
||||||
|
if param_name == 'num_samples':
|
||||||
|
check_num_samples(param_dict.get(param_name))
|
||||||
else:
|
else:
|
||||||
check_type(param_dict.get(param_name), param_name, param_type)
|
check_type(param_dict.get(param_name), param_name, param_type)
|
||||||
|
|
||||||
|
@ -262,6 +264,12 @@ def check_num_parallel_workers(value):
|
||||||
raise ValueError("num_parallel_workers exceeds the boundary between 0 and {}!".format(cpu_count()))
|
raise ValueError("num_parallel_workers exceeds the boundary between 0 and {}!".format(cpu_count()))
|
||||||
|
|
||||||
|
|
||||||
|
def check_num_samples(value):
    """Validate the num_samples parameter.

    Args:
        value: candidate sample count; must be an int.

    Raises:
        TypeError-style error from check_type if value is not an int
        (exact exception type determined by check_type elsewhere in this module).
        ValueError: if value is zero or negative.
    """
    # Type must be verified first so the comparison below is well-defined.
    check_type(value, 'num_samples', int)
    if value < 1:
        # Identical bound for ints: value < 1  <=>  value <= 0.
        raise ValueError("num_samples must be greater than 0!")
|
||||||
|
|
||||||
|
|
||||||
def check_dataset_dir(dataset_dir):
    """Verify that dataset_dir is an existing directory readable by this process.

    Args:
        dataset_dir: filesystem path to the dataset folder.

    Raises:
        ValueError: if the path is not a directory or read access is denied.
    """
    # De Morgan form of the original check: accept only a readable directory.
    is_readable_dir = os.path.isdir(dataset_dir) and os.access(dataset_dir, os.R_OK)
    if not is_readable_dir:
        raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
|
||||||
|
|
|
@ -33,14 +33,14 @@ def test_imagefolder_shardings(print_res=False):
|
||||||
# total 44 rows in dataset
|
# total 44 rows in dataset
|
||||||
assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1]) # 5 rows
|
assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1]) # 5 rows
|
||||||
assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3]) # 11 rows
|
assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3]) # 11 rows
|
||||||
assert (sharding_config(4, 3, 0, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) # 11 rows
|
assert (sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) # 11 rows
|
||||||
# total 22 in dataset rows because of class indexing which takes only 2 folders
|
# total 22 in dataset rows because of class indexing which takes only 2 folders
|
||||||
assert (len(sharding_config(4, 0, 0, True, {"class1": 111, "class2": 999})) == 6)
|
assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
|
||||||
assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
|
assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
|
||||||
# test with repeat
|
# test with repeat
|
||||||
assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3)
|
assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3)
|
||||||
assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5)
|
assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5)
|
||||||
assert (len(sharding_config(5, 1, 0, True, {"class1": 111, "class2": 999}, 4)) == 20)
|
assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)
|
||||||
|
|
||||||
|
|
||||||
def test_manifest_shardings(print_res=False):
|
def test_manifest_shardings(print_res=False):
|
||||||
|
|
|
@ -18,6 +18,7 @@ import pytest
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
|
|
||||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
|
||||||
|
|
||||||
def skip_test_exception():
|
def skip_test_exception():
|
||||||
|
@ -29,5 +30,23 @@ def skip_test_exception():
|
||||||
assert "The shape size 1 of input tensor is invalid" in str(info.value)
|
assert "The shape size 1 of input tensor is invalid" in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample_exception():
    """Check num_samples validation on TFRecordDataset.

    Zero and negative counts must raise ValueError at construction time;
    a positive count must limit the pipeline to exactly that many rows.
    """
    # Both invalid counts are rejected with the same validation message.
    for num_samples in (0, -1):
        with pytest.raises(ValueError) as info:
            data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
        assert "num_samples must be greater than 0" in str(info.value)

    # A valid positive count builds a working pipeline that yields one row.
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=1)
    data = data.map(input_columns=["image"], operations=vision.Decode())
    data = data.map(input_columns=["image"], operations=vision.Resize((100, 100)))
    row_count = sum(1 for _ in data.create_dict_iterator())
    assert row_count == 1
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_exception()
|
test_exception()
|
||||||
|
|
Loading…
Reference in New Issue