forked from mindspore-Ecosystem/mindspore
!12348 Fix SequentialSampler issue
From: @mahdirahmanihanzaki Reviewed-by: Signed-off-by:
This commit is contained in:
commit
1e75ac45c4
|
@ -113,13 +113,11 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
|
||||||
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
||||||
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows
|
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows
|
||||||
// and start_index != 0 and num_samples >= n then we can't return all the n rows.
|
// and start_index != 0 and num_samples >= n then we can't return all the n rows.
|
||||||
if (child_num_rows - (start_index_ - current_id_) <= 0) {
|
if (child_num_rows - start_index_ <= 0) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
if (child_num_rows - (start_index_ - current_id_) < num_samples)
|
if (child_num_rows - start_index_ < num_samples)
|
||||||
num_samples = child_num_rows - (start_index_ - current_id_) > num_samples
|
num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_;
|
||||||
? num_samples
|
|
||||||
: num_samples - (start_index_ - current_id_);
|
|
||||||
return num_samples;
|
return num_samples;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -70,16 +70,16 @@ def test_numpyslices_sampler_chain():
|
||||||
# Use 1 statement to add child sampler
|
# Use 1 statement to add child sampler
|
||||||
np_data = [1, 2, 3, 4]
|
np_data = [1, 2, 3, 4]
|
||||||
sampler = ds.SequentialSampler(start_index=1, num_samples=2)
|
sampler = ds.SequentialSampler(start_index=1, num_samples=2)
|
||||||
sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
|
sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
|
||||||
data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
|
data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
|
||||||
|
|
||||||
# Verify dataset size
|
# Verify dataset size
|
||||||
data1_size = data1.get_dataset_size()
|
data1_size = data1.get_dataset_size()
|
||||||
logger.info("dataset size is: {}".format(data1_size))
|
logger.info("dataset size is: {}".format(data1_size))
|
||||||
assert data1_size == 4
|
assert data1_size == 1
|
||||||
|
|
||||||
# Verify number of rows
|
# Verify number of rows
|
||||||
assert sum([1 for _ in data1]) == 4
|
assert sum([1 for _ in data1]) == 1
|
||||||
|
|
||||||
# Verify dataset contents
|
# Verify dataset contents
|
||||||
res = []
|
res = []
|
||||||
|
|
Loading…
Reference in New Issue