!5184 fix: padded dataset when no div and with repeat op
Merge pull request !5184 from guozhijian/fix_padded_with_no_div_repeat
This commit is contained in:
commit
20b3134785
|
@ -75,6 +75,9 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
|
|||
RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
|
||||
} else if (cnt_ == samples_per_buffer_ && (non_empty_ || !even_dist_)) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
if (!samples_per_buffer_) {
|
||||
non_empty_ = false;
|
||||
}
|
||||
} else if (!samples_per_buffer_ && !non_empty_) {
|
||||
// If the buffer is empty, we add samples with subscript 0 in the current dataset.
|
||||
// This step is to make up for the solution that the code default buffer is not empty before.
|
||||
|
|
|
@ -454,6 +454,21 @@ def test_clue_padded_and_skip_with_0_samples():
|
|||
count += 1
|
||||
assert count == 2
|
||||
|
||||
def test_celeba_padded():
|
||||
data = ds.CelebADataset("../data/dataset/testCelebAData/")
|
||||
|
||||
padded_samples = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}]
|
||||
padded_ds = ds.PaddedDataset(padded_samples)
|
||||
data = data + padded_ds
|
||||
dis_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
|
||||
data.use_sampler(dis_sampler)
|
||||
data = data.repeat(2)
|
||||
|
||||
count = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
count = count + 1
|
||||
assert count == 2
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_TFRecord_Padded()
|
||||
test_GeneratorDataSet_Padded()
|
||||
|
|
Loading…
Reference in New Issue