- Release the wait if training has ended
This commit is contained in:
parent
3da5ca4170
commit
8eeceb267b
|
@ -53,7 +53,11 @@ Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param
|
|||
if (Py_IsInitialized() == 0) {
|
||||
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
|
||||
}
|
||||
try {
|
||||
f(cb_param);
|
||||
} catch (const py::error_already_set &e) {
|
||||
return Status(StatusCode::kPyFuncException, e.what());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -144,6 +144,8 @@ class WaitedDSCallback(Callback, DSCallback):
|
|||
self.epoch_event = threading.Event()
|
||||
self.epoch_run_context = None
|
||||
|
||||
self.training_ended = False
|
||||
|
||||
def sync_epoch_begin(self, train_run_context, ds_run_context):
|
||||
"""
|
||||
Called before a new dataset epoch is started and after the previous training epoch is ended.
|
||||
|
@ -180,6 +182,7 @@ class WaitedDSCallback(Callback, DSCallback):
|
|||
ds_run_context: Include some information of the pipeline.
|
||||
"""
|
||||
if ds_run_context.cur_epoch_num > 1:
|
||||
if not self.training_ended:
|
||||
success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
|
||||
self.epoch_event.clear()
|
||||
if not success:
|
||||
|
@ -205,6 +208,7 @@ class WaitedDSCallback(Callback, DSCallback):
|
|||
ds_run_context: Include some information of the pipeline.
|
||||
"""
|
||||
if ds_run_context.cur_step_num > self.step_size:
|
||||
if not self.training_ended:
|
||||
success = self.step_event.wait(timeout=ds.config.get_callback_timeout())
|
||||
self.step_event.clear()
|
||||
if not success:
|
||||
|
@ -233,3 +237,8 @@ class WaitedDSCallback(Callback, DSCallback):
|
|||
raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")
|
||||
|
||||
return c_cb
|
||||
|
||||
def end(self, run_context):
    """
    Mark training as finished and release any dataset-side wait.

    Args:
        run_context: Training run context, forwarded unchanged to
            `epoch_end` and `step_end`.
    """
    # Raise the flag BEFORE releasing the waiters. sync_epoch_begin and
    # sync_step_begin only skip their wait when `training_ended` is True, so
    # a pipeline thread woken by the calls below could otherwise loop back,
    # still see the flag as False, and block again until the callback timeout.
    self.training_ended = True
    # NOTE(review): epoch_end/step_end are presumably what set the epoch/step
    # events that the sync_* methods wait on - confirm against the class body.
    self.epoch_end(run_context)
    self.step_end(run_context)
|
||||
|
|
|
@ -410,6 +410,22 @@ def test_callbacks_exceptions():
|
|||
assert "RuntimeError: Bad begin" in str(err.value)
|
||||
|
||||
|
||||
def test_callbacks_train_end():
    """
    Check that calling `end` on a waited callback lets the pipeline finish.

    The dataset is sent to device for two epochs while no training loop is
    driving the callback; `end` is then called once. The test passes if no
    deadlock or exception occurs.
    """
    logger.info("test_callbacks_train_end")
    # No asserts are needed, just test there is no deadlock or exceptions
    events = []
    epochs = 2

    # NOTE(review): the second argument is presumably the callback's step size - confirm.
    my_cb = MyWaitedCallback(events, 1)
    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=[my_cb])
    data = data.to_device()
    data.send(num_epochs=epochs)
    # Give the pipeline time to start before signalling the end of training.
    time.sleep(0.5)
    my_cb.end(run_context={})
    # Give the pipeline time to react after training has been marked as ended.
    time.sleep(0.5)
|
||||
|
||||
|
||||
def test_callbacks_one_cb():
|
||||
logger.info("test_callbacks_one_cb")
|
||||
|
||||
|
@ -458,3 +474,4 @@ if __name__ == '__main__':
|
|||
test_callbacks_non_sink()
|
||||
test_callbacks_one_cb()
|
||||
test_callbacks_non_sink_mismatch_size()
|
||||
test_callbacks_train_end()
|
||||
|
|
Loading…
Reference in New Issue