!5194 fix: padded dataset when no div and with repeat op for br:r0.7

Merge pull request !5194 from guozhijian/fix_padded_with_no_div_repeat_r0.7
This commit is contained in:
mindspore-ci-bot 2020-08-26 10:47:37 +08:00 committed by Gitee
commit fd8ad73689
2 changed files with 18 additions and 0 deletions

View File

@ -75,6 +75,9 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
} else if (cnt_ == samples_per_buffer_ && (non_empty_ || !even_dist_)) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
if (!samples_per_buffer_) {
non_empty_ = false;
}
} else if (!samples_per_buffer_ && !non_empty_) {
// If the buffer is empty, we add samples with subscript 0 in the current dataset.
// This step is to make up for the solution that the code default buffer is not empty before.

View File

@ -454,6 +454,21 @@ def test_clue_padded_and_skip_with_0_samples():
count += 1
assert count == 2
def test_celeba_padded():
data = ds.CelebADataset("../data/dataset/testCelebAData/")
padded_samples = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}]
padded_ds = ds.PaddedDataset(padded_samples)
data = data + padded_ds
dis_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
data.use_sampler(dis_sampler)
data = data.repeat(2)
count = 0
for _ in data.create_dict_iterator():
count = count + 1
assert count == 2
if __name__ == '__main__':
test_TFRecord_Padded()
test_GeneratorDataSet_Padded()