fixed num_samples bug for RandomSampler

This commit is contained in:
Peilin Wang 2020-05-29 13:09:26 -04:00
parent 6420f7248f
commit 3ef3d1433d
2 changed files with 19 additions and 0 deletions

View File

@ -76,6 +76,7 @@ Status RandomSampler::InitSampler() {
if (replacement_ == false) {
num_samples_ = std::min(num_samples_, num_rows_);
num_samples_ = std::min(num_samples_, user_num_samples_);
shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) {

View File

@ -57,6 +57,24 @@ def test_imagefolder_numsamples():
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 10
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for item in data1.create_dict_iterator():
num_iter += 1
assert num_iter == 3
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for item in data1.create_dict_iterator():
num_iter += 1
assert num_iter == 3
def test_imagefolder_numshards():
logger.info("Test Case numShards")