forked from mindspore-Ecosystem/mindspore
!12348 Fix SequentialSampler issue
From: @mahdirahmanihanzaki Reviewed-by: Signed-off-by:
This commit is contained in:
commit
1e75ac45c4
|
@ -113,13 +113,11 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
|
|||
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
||||
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows
|
||||
// and start_index != 0 and num_samples >= n then we can't return all the n rows.
|
||||
if (child_num_rows - (start_index_ - current_id_) <= 0) {
|
||||
if (child_num_rows - start_index_ <= 0) {
|
||||
return 0;
|
||||
}
|
||||
if (child_num_rows - (start_index_ - current_id_) < num_samples)
|
||||
num_samples = child_num_rows - (start_index_ - current_id_) > num_samples
|
||||
? num_samples
|
||||
: num_samples - (start_index_ - current_id_);
|
||||
if (child_num_rows - start_index_ < num_samples)
|
||||
num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_;
|
||||
return num_samples;
|
||||
}
|
||||
|
||||
|
|
|
@ -70,16 +70,16 @@ def test_numpyslices_sampler_chain():
|
|||
# Use 1 statement to add child sampler
|
||||
np_data = [1, 2, 3, 4]
|
||||
sampler = ds.SequentialSampler(start_index=1, num_samples=2)
|
||||
sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
|
||||
sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
|
||||
data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 4
|
||||
assert data1_size == 1
|
||||
|
||||
# Verify number of rows
|
||||
assert sum([1 for _ in data1]) == 4
|
||||
assert sum([1 for _ in data1]) == 1
|
||||
|
||||
# Verify dataset contents
|
||||
res = []
|
||||
|
|
Loading…
Reference in New Issue