forked from mindspore-Ecosystem/mindspore
!1380 make ShuffleOp have deterministic behavior for subsequent epochs
Merge pull request !1380 from Peilin/shuffle-subsequent-epoch-deterministic
This commit is contained in:
commit
a6b8451a33
|
@ -83,14 +83,14 @@ ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_con
|
||||||
// itself rather than waiting for the reset driven from operators above it in the pipeline.
|
// itself rather than waiting for the reset driven from operators above it in the pipeline.
|
||||||
Status ShuffleOp::SelfReset() {
|
Status ShuffleOp::SelfReset() {
|
||||||
MS_LOG(DEBUG) << "Shuffle operator performing a self-reset.";
|
MS_LOG(DEBUG) << "Shuffle operator performing a self-reset.";
|
||||||
// If ReshuffleEachEpoch is false, then we always use the same seed for every
|
// If reshuffle_each_epoch is false, then we always use the same seed for every
|
||||||
// epoch.
|
// epoch.
|
||||||
// If ReshuffleEachEpoch is true, then the first epoch uses the given seed,
|
// If reshuffle_each_epoch is true, then the first epoch uses the given seed,
|
||||||
// and all subsequent epochs will then reset the seed based on random device.
|
// and all subsequent epochs will then keep on using the rng_ without resetting it
|
||||||
if (reshuffle_each_epoch_) {
|
if (!reshuffle_each_epoch_) {
|
||||||
shuffle_seed_ = GetNewSeed();
|
rng_ = std::mt19937_64(shuffle_seed_);
|
||||||
}
|
}
|
||||||
rng_ = std::mt19937_64(shuffle_seed_);
|
|
||||||
shuffle_buffer_ = std::make_unique<TensorTable>();
|
shuffle_buffer_ = std::make_unique<TensorTable>();
|
||||||
buffer_counter_ = 0;
|
buffer_counter_ = 0;
|
||||||
shuffle_last_row_idx_ = 0;
|
shuffle_last_row_idx_ = 0;
|
||||||
|
|
Binary file not shown.
|
@ -47,7 +47,7 @@ def test_2ops_repeat_shuffle():
|
||||||
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||||
|
|
||||||
|
|
||||||
def skip_test_2ops_shuffle_repeat():
|
def test_2ops_shuffle_repeat():
|
||||||
"""
|
"""
|
||||||
Test Shuffle then Repeat
|
Test Shuffle then Repeat
|
||||||
"""
|
"""
|
||||||
|
@ -159,7 +159,7 @@ def test_2ops_shuffle_batch():
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_2ops_repeat_shuffle()
|
test_2ops_repeat_shuffle()
|
||||||
# test_2ops_shuffle_repeat()
|
test_2ops_shuffle_repeat()
|
||||||
test_2ops_repeat_batch()
|
test_2ops_repeat_batch()
|
||||||
test_2ops_batch_repeat()
|
test_2ops_batch_repeat()
|
||||||
test_2ops_batch_shuffle()
|
test_2ops_batch_shuffle()
|
||||||
|
|
Loading…
Reference in New Issue