Fix wrong epoch number when resetting the training dataset
This commit is contained in:
parent
ef8fed4302
commit
81b7facea0
|
@ -25,7 +25,7 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) {
|
||||
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer")
|
||||
.def("Reset", [](TreeConsumer &self, int64_t step, uint64_t epoch) {
|
||||
.def("Reset", [](TreeConsumer &self, int64_t step, int64_t epoch) {
|
||||
THROW_IF_ERROR(self.Reset(step, epoch));
|
||||
});
|
||||
}));
|
||||
|
|
|
@ -115,17 +115,19 @@ def _get_training_dataset():
|
|||
return _train_dataset
|
||||
|
||||
|
||||
def _reset_training_dataset(step, epoch):
|
||||
def _reset_training_dataset(global_step, dataset_size):
|
||||
"""
|
||||
Reset the training dataset to the given step and epoch number.
|
||||
Reset the training dataset to the given global step.
|
||||
|
||||
Args:
|
||||
step (int): Global step number.
|
||||
epoch (int): Global epoch number
|
||||
global_step (int): Number of global steps that have completed training.
|
||||
Dataset will provide data from its next step after reset.
|
||||
dataset_size (int): Number of steps per epoch.
|
||||
"""
|
||||
dataset = _get_training_dataset()
|
||||
if dataset is not None:
|
||||
dataset._reset(step, epoch) # pylint: disable=protected-access
|
||||
epoch = global_step // dataset_size
|
||||
dataset._reset(global_step, epoch) # pylint: disable=protected-access
|
||||
else:
|
||||
raise RuntimeError("Training dataset is not set.")
|
||||
|
||||
|
|
|
@ -677,7 +677,7 @@ class Model:
|
|||
cb_params.train_network = train_network
|
||||
|
||||
# Perform recovery for process which is restarted.
|
||||
self._reset_training_step_for_abnormal_process(cb_params)
|
||||
self._reset_training_step_for_abnormal_process(cb_params, dataset_helper)
|
||||
# Perform recovery for process which is not restarted.
|
||||
self._reset_training_step_for_normal_process(cb_params, dataset_helper)
|
||||
|
||||
|
@ -808,7 +808,7 @@ class Model:
|
|||
else:
|
||||
self.need_load_ckpt = False
|
||||
|
||||
def _reset_training_step_for_abnormal_process(self, cb_params):
|
||||
def _reset_training_step_for_abnormal_process(self, cb_params, dataset_helper):
|
||||
"""
|
||||
Execute recovery for abnormal exit process when restart.
|
||||
|
||||
|
@ -823,7 +823,7 @@ class Model:
|
|||
os.remove(cb_params.latest_ckpt_file)
|
||||
raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
|
||||
+ cb_params.latest_ckpt_file) from e
|
||||
_reset_training_dataset(cb_params.cur_step_num, cb_params.cur_epoch_num)
|
||||
_reset_training_dataset(cb_params.cur_step_num, dataset_helper.sink_size())
|
||||
self.need_load_ckpt = False
|
||||
|
||||
def _reset_training_step_for_normal_process(self, cb_params, dataset_helper):
|
||||
|
@ -852,9 +852,9 @@ class Model:
|
|||
self.epoch_iter = recovery_epoch_num
|
||||
cb_params.cur_epoch_num = self.epoch_iter + 1
|
||||
cb_params.last_save_ckpt_step = cb_params.cur_step_num
|
||||
_reset_training_dataset(cb_params.cur_step_num, cb_params.cur_epoch_num)
|
||||
_reset_training_dataset(cb_params.cur_step_num, dataset_helper.sink_size())
|
||||
else:
|
||||
_reset_training_dataset(0, 0)
|
||||
_reset_training_dataset(0, dataset_helper.sink_size())
|
||||
|
||||
_set_recovery_context(need_reset=False)
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@ def run_reset(data, num_epochs: int, failure_point: int):
|
|||
expected2.append(d)
|
||||
if cur_step + 1 == failure_point:
|
||||
# pylint: disable=W0212
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, failure_point // size)
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, size)
|
||||
failed = True
|
||||
break
|
||||
cur_step += 1
|
||||
|
@ -149,12 +149,12 @@ def run_reset_error(data, num_epochs: int, failure_point: int):
|
|||
if failure_point > 0:
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
# pylint: disable=W0212
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, failure_point % dataset_size)
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, dataset_size)
|
||||
assert "Cannot reset the pipeline, reset step must be less than dataset_size * num_epochs." in str(err.value)
|
||||
else:
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
# pylint: disable=W0212
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, failure_point % dataset_size)
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, dataset_size)
|
||||
assert "Cannot reset the pipeline, reset step must be >= 0." in str(err.value)
|
||||
|
||||
|
||||
|
@ -291,7 +291,7 @@ def test_repeatable_reset_imagenet(sampler, num_parallel_workers, to_pil, batch_
|
|||
break
|
||||
if failure:
|
||||
# pylint: disable=W0212
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, failure_point // dataset_size)
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, dataset_size)
|
||||
failure = False
|
||||
for d in expected2_itr:
|
||||
expected2.append(d)
|
||||
|
@ -353,7 +353,8 @@ def test_repeatable_reset_distributed(shard_id, num_parallel_workers, to_pil):
|
|||
failure = True
|
||||
break
|
||||
if failure:
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
|
||||
# pylint: disable=W0212
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, data.get_dataset_size())
|
||||
failure = False
|
||||
for d in expected2_itr:
|
||||
expected2.append(d)
|
||||
|
@ -401,7 +402,7 @@ def test_reset_shuffle():
|
|||
failure = True
|
||||
break
|
||||
if failure:
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, data1.get_dataset_size()) # pylint: disable=W0212
|
||||
failure = False
|
||||
for step, d in enumerate(expected2_itr):
|
||||
expected2.append(d)
|
||||
|
@ -454,7 +455,7 @@ def test_reset_sampler(sampler):
|
|||
failure = True
|
||||
break
|
||||
if failure:
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, data1.get_dataset_size()) # pylint: disable=W0212
|
||||
failure = False
|
||||
for step, d in enumerate(expected2_itr):
|
||||
expected2.append(d)
|
||||
|
@ -505,7 +506,7 @@ def test_reset_batch(fast_recovery):
|
|||
failure = True
|
||||
break
|
||||
if failure:
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, data1.get_dataset_size()) # pylint: disable=W0212
|
||||
failure = False
|
||||
for step, d in enumerate(itr):
|
||||
expected2.append(d)
|
||||
|
@ -553,7 +554,7 @@ def test_reset_nonmappable():
|
|||
break
|
||||
if failure:
|
||||
# pylint: disable=W0212
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, (failure_point//dataset_size))
|
||||
ds.engine.datasets._reset_training_dataset(failure_point, dataset_size)
|
||||
failure = False
|
||||
# let's collect the remaining rows of this epoch
|
||||
if failure_point % dataset_size != 0:
|
||||
|
|
Loading…
Reference in New Issue