forked from mindspore-Ecosystem/mindspore
!1922 check parameter num_samples of sampler
Merge pull request !1922 from yanghaitao/yht_check_num_samples_sampler
This commit is contained in:
commit
3085e51e45
|
@ -218,6 +218,11 @@ class DistributedSampler(BuiltinSampler):
|
|||
if not isinstance(shuffle, bool):
|
||||
raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
|
||||
|
||||
if num_samples is not None:
|
||||
if num_samples <= 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(num_samples))
|
||||
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.shuffle = shuffle
|
||||
|
@ -282,6 +287,11 @@ class PKSampler(BuiltinSampler):
|
|||
if not isinstance(shuffle, bool):
|
||||
raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
|
||||
|
||||
if num_samples is not None:
|
||||
if num_samples <= 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(num_samples))
|
||||
|
||||
self.num_val = num_val
|
||||
self.shuffle = shuffle
|
||||
self.class_column = class_column # work for minddataset
|
||||
|
@ -385,6 +395,16 @@ class SequentialSampler(BuiltinSampler):
|
|||
"""
|
||||
|
||||
def __init__(self, start_index=None, num_samples=None):
|
||||
if num_samples is not None:
|
||||
if num_samples <= 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(num_samples))
|
||||
|
||||
if start_index is not None:
|
||||
if start_index < 0:
|
||||
raise ValueError("start_index should be a positive integer "
|
||||
"value or 0, but got start_index={}".format(start_index))
|
||||
|
||||
self.start_index = start_index
|
||||
super().__init__(num_samples)
|
||||
|
||||
|
@ -430,6 +450,11 @@ class SubsetRandomSampler(BuiltinSampler):
|
|||
"""
|
||||
|
||||
def __init__(self, indices, num_samples=None):
|
||||
if num_samples is not None:
|
||||
if num_samples <= 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(num_samples))
|
||||
|
||||
if not isinstance(indices, list):
|
||||
indices = [indices]
|
||||
|
||||
|
|
|
@ -43,7 +43,6 @@ def test_sequential_sampler(print_res=False):
|
|||
|
||||
assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2]
|
||||
assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2
|
||||
assert test_config(num_samples=0, num_repeats=2) == [0, 1, 2, 3, 4] * 2
|
||||
assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue