diff --git a/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc b/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc index 56416f72f98..6763dada429 100644 --- a/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc +++ b/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc @@ -53,7 +53,11 @@ Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param if (Py_IsInitialized() == 0) { return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); } - f(cb_param); + try { + f(cb_param); + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } } return Status::OK(); } diff --git a/mindspore/dataset/callback/ds_callback.py b/mindspore/dataset/callback/ds_callback.py index 52ffdaffbcb..e4a1a45412a 100644 --- a/mindspore/dataset/callback/ds_callback.py +++ b/mindspore/dataset/callback/ds_callback.py @@ -144,6 +144,8 @@ class WaitedDSCallback(Callback, DSCallback): self.epoch_event = threading.Event() self.epoch_run_context = None + self.training_ended = False + def sync_epoch_begin(self, train_run_context, ds_run_context): """ Called before a new dataset epoch is started and after the previous training epoch is ended. @@ -180,10 +182,11 @@ class WaitedDSCallback(Callback, DSCallback): ds_run_context: Include some information of the pipeline. """ if ds_run_context.cur_epoch_num > 1: - success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout()) - self.epoch_event.clear() - if not success: - raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)") + if not self.training_ended: + success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout()) + self.epoch_event.clear() + if not success: + raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)") # by the time this thread wakes up, self.epoch_run_context is already available self.sync_epoch_begin(self.epoch_run_context, ds_run_context) @@ -205,11 +208,12 @@ class WaitedDSCallback(Callback, DSCallback): ds_run_context: Include some information of the pipeline. """ if ds_run_context.cur_step_num > self.step_size: - success = self.step_event.wait(timeout=ds.config.get_callback_timeout()) - self.step_event.clear() - if not success: - raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)") - # by the time this thread wakes up, self.epoch_run_context is already available + if not self.training_ended: + success = self.step_event.wait(timeout=ds.config.get_callback_timeout()) + self.step_event.clear() + if not success: + raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)") + # by the time this thread wakes up, self.epoch_run_context is already available self.sync_step_begin(self.step_run_context, ds_run_context) def create_runtime_obj(self): @@ -233,3 +237,8 @@ class WaitedDSCallback(Callback, DSCallback): raise AttributeError("Provided Callback class did not override any of the 2 callback methods.") return c_cb + + def end(self, run_context): + self.epoch_end(run_context) + self.step_end(run_context) + self.training_ended = True diff --git a/tests/ut/python/dataset/test_callbacks.py b/tests/ut/python/dataset/test_callbacks.py index d75ab4fbe06..77c9d23953e 100644 --- a/tests/ut/python/dataset/test_callbacks.py +++ b/tests/ut/python/dataset/test_callbacks.py @@ -410,6 +410,22 @@ def test_callbacks_exceptions(): assert "RuntimeError: Bad begin" in str(err.value) +def test_callbacks_train_end(): + logger.info("test_callback_sink_simulation") + # No asserts are needed, just test there is no deadlock or exceptions + events = [] + epochs = 2 + + my_cb = MyWaitedCallback(events, 1) + data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) + data = data.map(operations=(lambda x: x), callbacks=[my_cb]) + data = data.to_device() + data.send(num_epochs=epochs) + time.sleep(0.5) + my_cb.end(run_context={}) + time.sleep(0.5) + + def test_callbacks_one_cb(): logger.info("test_callbacks_one_cb") @@ -458,3 +474,4 @@ if __name__ == '__main__': test_callbacks_non_sink() test_callbacks_one_cb() test_callbacks_non_sink_mismatch_size() + test_callbacks_train_end()