!12348 Fix SequentialSampler issue

From: @mahdirahmanihanzaki
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-12 22:05:40 +08:00 committed by Gitee
commit 1e75ac45c4
2 changed files with 6 additions and 8 deletions

View File

@ -113,13 +113,11 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows // For this sampler we need to take start_index into account. Because for example in the case we are given n rows
// and start_index != 0 and num_samples >= n then we can't return all the n rows. // and start_index != 0 and num_samples >= n then we can't return all the n rows.
if (child_num_rows - (start_index_ - current_id_) <= 0) { if (child_num_rows - start_index_ <= 0) {
return 0; return 0;
} }
if (child_num_rows - (start_index_ - current_id_) < num_samples) if (child_num_rows - start_index_ < num_samples)
num_samples = child_num_rows - (start_index_ - current_id_) > num_samples num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_;
? num_samples
: num_samples - (start_index_ - current_id_);
return num_samples; return num_samples;
} }

View File

@ -70,16 +70,16 @@ def test_numpyslices_sampler_chain():
# Use 1 statement to add child sampler # Use 1 statement to add child sampler
np_data = [1, 2, 3, 4] np_data = [1, 2, 3, 4]
sampler = ds.SequentialSampler(start_index=1, num_samples=2) sampler = ds.SequentialSampler(start_index=1, num_samples=2)
sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
# Verify dataset size # Verify dataset size
data1_size = data1.get_dataset_size() data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size)) logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 4 assert data1_size == 1
# Verify number of rows # Verify number of rows
assert sum([1 for _ in data1]) == 4 assert sum([1 for _ in data1]) == 1
# Verify dataset contents # Verify dataset contents
res = [] res = []