fix: padded dataset with non div & repeat

This commit is contained in:
jonyguo 2020-08-25 22:43:21 +08:00
parent 9e20e17590
commit d262c63214
2 changed files with 18 additions and 0 deletions

View File

@ -75,6 +75,9 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
} else if (cnt_ == samples_per_buffer_ && (non_empty_ || !even_dist_)) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
if (!samples_per_buffer_) {
non_empty_ = false;
}
} else if (!samples_per_buffer_ && !non_empty_) {
// If the buffer is empty, we add samples with subscript 0 in the current dataset.
// This step is to make up for the solution that the code default buffer is not empty before.

View File

@ -454,6 +454,21 @@ def test_clue_padded_and_skip_with_0_samples():
count += 1
assert count == 2
def test_celeba_padded():
data = ds.CelebADataset("../data/dataset/testCelebAData/")
padded_samples = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}]
padded_ds = ds.PaddedDataset(padded_samples)
data = data + padded_ds
dis_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
data.use_sampler(dis_sampler)
data = data.repeat(2)
count = 0
for _ in data.create_dict_iterator():
count = count + 1
assert count == 2
if __name__ == '__main__':
test_TFRecord_Padded()
test_GeneratorDataSet_Padded()