forked from mindspore-Ecosystem/mindspore
fixed num_samples bug for RandomSampler
This commit is contained in:
parent
6420f7248f
commit
3ef3d1433d
|
@ -76,6 +76,7 @@ Status RandomSampler::InitSampler() {
|
||||||
|
|
||||||
if (replacement_ == false) {
|
if (replacement_ == false) {
|
||||||
num_samples_ = std::min(num_samples_, num_rows_);
|
num_samples_ = std::min(num_samples_, num_rows_);
|
||||||
|
num_samples_ = std::min(num_samples_, user_num_samples_);
|
||||||
|
|
||||||
shuffled_ids_.reserve(num_rows_);
|
shuffled_ids_.reserve(num_rows_);
|
||||||
for (int64_t i = 0; i < num_rows_; i++) {
|
for (int64_t i = 0; i < num_rows_; i++) {
|
||||||
|
|
|
@ -57,6 +57,24 @@ def test_imagefolder_numsamples():
|
||||||
logger.info("Number of data in data1: {}".format(num_iter))
|
logger.info("Number of data in data1: {}".format(num_iter))
|
||||||
assert num_iter == 10
|
assert num_iter == 10
|
||||||
|
|
||||||
|
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
|
||||||
|
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
|
||||||
|
|
||||||
|
num_iter = 0
|
||||||
|
for item in data1.create_dict_iterator():
|
||||||
|
num_iter += 1
|
||||||
|
|
||||||
|
assert num_iter == 3
|
||||||
|
|
||||||
|
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
|
||||||
|
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
|
||||||
|
|
||||||
|
num_iter = 0
|
||||||
|
for item in data1.create_dict_iterator():
|
||||||
|
num_iter += 1
|
||||||
|
|
||||||
|
assert num_iter == 3
|
||||||
|
|
||||||
|
|
||||||
def test_imagefolder_numshards():
|
def test_imagefolder_numshards():
|
||||||
logger.info("Test Case numShards")
|
logger.info("Test Case numShards")
|
||||||
|
|
Loading…
Reference in New Issue