forked from mindspore-Ecosystem/mindspore
fixed num_samples bug for RandomSampler
This commit is contained in:
parent
6420f7248f
commit
3ef3d1433d
|
@ -76,6 +76,7 @@ Status RandomSampler::InitSampler() {
|
|||
|
||||
if (replacement_ == false) {
|
||||
num_samples_ = std::min(num_samples_, num_rows_);
|
||||
num_samples_ = std::min(num_samples_, user_num_samples_);
|
||||
|
||||
shuffled_ids_.reserve(num_rows_);
|
||||
for (int64_t i = 0; i < num_rows_; i++) {
|
||||
|
|
|
@ -57,6 +57,24 @@ def test_imagefolder_numsamples():
|
|||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 10
|
||||
|
||||
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
|
||||
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator():
|
||||
num_iter += 1
|
||||
|
||||
assert num_iter == 3
|
||||
|
||||
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
|
||||
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator():
|
||||
num_iter += 1
|
||||
|
||||
assert num_iter == 3
|
||||
|
||||
|
||||
def test_imagefolder_numshards():
|
||||
logger.info("Test Case numShards")
|
||||
|
|
Loading…
Reference in New Issue