diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc index 3d79f3561d..99c94acd2e 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc @@ -81,6 +81,7 @@ PYBIND_REGISTER( .def("GetNumClasses", &DEPipeline::GetNumClasses) .def("GetRepeatCount", &DEPipeline::GetRepeatCount) .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); }) + .def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); }) .def("SaveDataset", [](DEPipeline &de, const std::vector &file_names, const std::string &file_type) { THROW_IF_ERROR(de.SaveDataset(file_names, file_type)); return true; diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc index 73f45a2d2d..6bfdf0e4ff 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc @@ -291,6 +291,16 @@ Status DEPipeline::StopSend() { return Status::OK(); } +Status DEPipeline::ContinueSend() { + // tree_.root() must be DeviceQueueOp + DeviceQueueOp *op = dynamic_cast(tree_->root().get()); + if (op == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ContinueSend only supported by DeviceQueueOp"); + } + op->ContinueSend(); + return Status::OK(); +} + int ToInt(const py::handle &handle) { return py::reinterpret_borrow(handle); } bool ToBool(const py::handle &handle) { return py::reinterpret_borrow(handle); } diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h index 80a2c3c2ab..7002237e58 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h @@ -203,6 +203,9 @@ class DEPipeline { Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); Status StopSend(); + + Status ContinueSend(); + Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc index 2f99262d09..883c7e6d0e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc @@ -153,6 +153,10 @@ Status DeviceQueueOp::SendDataToAscend() { } if (current_buffer->eoe() && send_epoch_end_) { TensorRow currRow; + while (stop_send_) { + MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal..."; + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE); if (status == TdtStatus::FAILED) { @@ -163,6 +167,8 @@ Status DeviceQueueOp::SendDataToAscend() { return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); } } + MS_LOG(INFO) << "an epoch has already sent, now stop send data."; + stop_send_ = true; } if (isProfilingEnable) { connector_size = ChildOpConnectorSize(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h index 99feb4ea0e..6b84d60b16 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h @@ -123,6 +123,11 @@ class DeviceQueueOp : public PipelineOp { void StopSend() { stop_send_ = true; } + void ContinueSend() { + MS_LOG(INFO) << "continue send at the beginning of the epoch"; + stop_send_ = false; + } + // Name: Print() // Description: A function that prints info about the node void Print(std::ostream &out, // In: The output stream to print to diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index cc99dc9052..d8531b4fa5 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2588,6 +2588,9 @@ class TransferDataset(DatasetOp): def stop_send(self): self.iterator.depipeline.StopSend() + def continue_send(self): + self.iterator.depipeline.ContinueSend() + class RangeDataset(MappableDataset): """ diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 93dba6239a..2d3e2e7ce1 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -163,6 +163,10 @@ class DatasetHelper: """Free up resources about data sink.""" self.iter.stop_send() + def continue_send(self): + """continue send data to device at the beginning of epoch.""" + self.iter.continue_send() + class _DatasetIter: """Base iter for dataset helper""" @@ -182,6 +186,7 @@ class _DatasetIter: _send_data_no_flag(dataset, epoch_num) self.stop_send = dataset.__TRANSFER_DATASET__.stop_send + self.continue_send = dataset.__TRANSFER_DATASET__.continue_send self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) def __iter__(self): diff --git a/mindspore/train/model.py b/mindspore/train/model.py index a523d44f9a..7e997edfd4 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -442,6 +442,7 @@ class Model: cb_params.net_outputs = outputs list_callback.step_end(run_context) + dataset_helper.continue_send() list_callback.epoch_end(run_context) should_stop = should_stop or run_context.get_stop_requested() if should_stop: diff --git a/tests/dataset_mock.py b/tests/dataset_mock.py index 84f0ae00da..0ee0a90813 100644 --- a/tests/dataset_mock.py +++ b/tests/dataset_mock.py @@ -64,6 +64,9 @@ class MindData: def stop_send(self): pass + def continue_send(self): + pass + def __len__(self): return self._size diff --git a/tests/st/networks/models/deeplabv3/test_deeplabv3.py b/tests/st/networks/models/deeplabv3/test_deeplabv3.py index 5a5fab107a..9131385c9a 100644 --- a/tests/st/networks/models/deeplabv3/test_deeplabv3.py +++ b/tests/st/networks/models/deeplabv3/test_deeplabv3.py @@ -98,6 +98,6 @@ def test_deeplabv3_1p(): print("expect loss: ", callback.loss) print("expect time: ", callback.time) expect_loss = 0.92 - expect_time = 40 + expect_time = 43 assert callback.loss.asnumpy() <= expect_loss assert callback.time <= expect_time