fix wrong epoch number on dataset reset

This commit is contained in:
Xiao Tianci 2023-02-22 17:05:23 +08:00
parent ef8fed4302
commit 81b7facea0
4 changed files with 23 additions and 20 deletions

View File

@ -25,7 +25,7 @@ namespace mindspore {
namespace dataset {
PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) {
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer")
.def("Reset", [](TreeConsumer &self, int64_t step, uint64_t epoch) {
.def("Reset", [](TreeConsumer &self, int64_t step, int64_t epoch) {
THROW_IF_ERROR(self.Reset(step, epoch));
});
}));

View File

@ -115,17 +115,19 @@ def _get_training_dataset():
return _train_dataset
def _reset_training_dataset(step, epoch):
def _reset_training_dataset(global_step, dataset_size):
"""
Reset the training dataset to the given step and epoch number.
Reset the training dataset to the given global step.
Args:
step (int): Global step number.
epoch (int): Global epoch number
global_step (int): Number of global steps that have completed training.
Dataset will provide data from its next step after reset.
dataset_size (int): Number of steps per epoch.
"""
dataset = _get_training_dataset()
if dataset is not None:
dataset._reset(step, epoch) # pylint: disable=protected-access
epoch = global_step // dataset_size
dataset._reset(global_step, epoch) # pylint: disable=protected-access
else:
raise RuntimeError("Training dataset is not set.")

View File

@ -677,7 +677,7 @@ class Model:
cb_params.train_network = train_network
# Perform recovery for process which is restarted.
self._reset_training_step_for_abnormal_process(cb_params)
self._reset_training_step_for_abnormal_process(cb_params, dataset_helper)
# Perform recovery for process which is not restarted.
self._reset_training_step_for_normal_process(cb_params, dataset_helper)
@ -808,7 +808,7 @@ class Model:
else:
self.need_load_ckpt = False
def _reset_training_step_for_abnormal_process(self, cb_params):
def _reset_training_step_for_abnormal_process(self, cb_params, dataset_helper):
"""
Execute recovery for abnormal exit process when restart.
@ -823,7 +823,7 @@ class Model:
os.remove(cb_params.latest_ckpt_file)
raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
+ cb_params.latest_ckpt_file) from e
_reset_training_dataset(cb_params.cur_step_num, cb_params.cur_epoch_num)
_reset_training_dataset(cb_params.cur_step_num, dataset_helper.sink_size())
self.need_load_ckpt = False
def _reset_training_step_for_normal_process(self, cb_params, dataset_helper):
@ -852,9 +852,9 @@ class Model:
self.epoch_iter = recovery_epoch_num
cb_params.cur_epoch_num = self.epoch_iter + 1
cb_params.last_save_ckpt_step = cb_params.cur_step_num
_reset_training_dataset(cb_params.cur_step_num, cb_params.cur_epoch_num)
_reset_training_dataset(cb_params.cur_step_num, dataset_helper.sink_size())
else:
_reset_training_dataset(0, 0)
_reset_training_dataset(0, dataset_helper.sink_size())
_set_recovery_context(need_reset=False)

View File

@ -120,7 +120,7 @@ def run_reset(data, num_epochs: int, failure_point: int):
expected2.append(d)
if cur_step + 1 == failure_point:
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, failure_point // size)
ds.engine.datasets._reset_training_dataset(failure_point, size)
failed = True
break
cur_step += 1
@ -149,12 +149,12 @@ def run_reset_error(data, num_epochs: int, failure_point: int):
if failure_point > 0:
with pytest.raises(RuntimeError) as err:
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, failure_point % dataset_size)
ds.engine.datasets._reset_training_dataset(failure_point, dataset_size)
assert "Cannot reset the pipeline, reset step must be less than dataset_size * num_epochs." in str(err.value)
else:
with pytest.raises(RuntimeError) as err:
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, failure_point % dataset_size)
ds.engine.datasets._reset_training_dataset(failure_point, dataset_size)
assert "Cannot reset the pipeline, reset step must be >= 0." in str(err.value)
@ -291,7 +291,7 @@ def test_repeatable_reset_imagenet(sampler, num_parallel_workers, to_pil, batch_
break
if failure:
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, failure_point // dataset_size)
ds.engine.datasets._reset_training_dataset(failure_point, dataset_size)
failure = False
for d in expected2_itr:
expected2.append(d)
@ -353,7 +353,8 @@ def test_repeatable_reset_distributed(shard_id, num_parallel_workers, to_pil):
failure = True
break
if failure:
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, data.get_dataset_size())
failure = False
for d in expected2_itr:
expected2.append(d)
@ -401,7 +402,7 @@ def test_reset_shuffle():
failure = True
break
if failure:
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, data1.get_dataset_size()) # pylint: disable=W0212
failure = False
for step, d in enumerate(expected2_itr):
expected2.append(d)
@ -454,7 +455,7 @@ def test_reset_sampler(sampler):
failure = True
break
if failure:
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, data1.get_dataset_size()) # pylint: disable=W0212
failure = False
for step, d in enumerate(expected2_itr):
expected2.append(d)
@ -505,7 +506,7 @@ def test_reset_batch(fast_recovery):
failure = True
break
if failure:
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, data1.get_dataset_size()) # pylint: disable=W0212
failure = False
for step, d in enumerate(itr):
expected2.append(d)
@ -553,7 +554,7 @@ def test_reset_nonmappable():
break
if failure:
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, (failure_point//dataset_size))
ds.engine.datasets._reset_training_dataset(failure_point, dataset_size)
failure = False
# let's collect the remaining rows of this epoch
if failure_point % dataset_size != 0: