forked from mindspore-Ecosystem/mindspore
!6221 stop send data to device after end of sequence
Merge pull request !6221 from anzhengqi/stop-send-at-eos
commit f6ac30ef29
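For context before the per-file hunks: the change pauses host-to-device sending once an epoch's end-of-sequence marker has been pushed, and resumes it at the start of the next epoch. Below is a minimal sketch of how the new hooks are meant to be driven from a sink-mode training loop, based on the Model and DatasetHelper changes further down; the import path, the DatasetHelper constructor arguments, and the train_step placeholder are assumptions for illustration, not part of this diff.

from mindspore.train.dataset_helper import DatasetHelper

def train_sink(train_step, dataset, epochs):
    # Assumed construction; only stop_send()/continue_send() come from this PR.
    helper = DatasetHelper(dataset, dataset_sink_mode=True)
    for _ in range(epochs):
        for inputs in helper:
            train_step(*inputs)  # one device-side step (placeholder)
        # After the epoch's end-of-sequence is pushed, DeviceQueueOp sets stop_send_
        # and its sender thread waits; resume it before the next epoch begins.
        helper.continue_send()
    helper.stop_send()  # release data-sink resources once training finishes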
@@ -81,6 +81,7 @@ PYBIND_REGISTER(
     .def("GetNumClasses", &DEPipeline::GetNumClasses)
     .def("GetRepeatCount", &DEPipeline::GetRepeatCount)
     .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
+    .def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
     .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
       THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
       return true;
@@ -291,6 +291,16 @@ Status DEPipeline::StopSend() {
   return Status::OK();
 }

+Status DEPipeline::ContinueSend() {
+  // tree_.root() must be DeviceQueueOp
+  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
+  if (op == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ContinueSend only supported by DeviceQueueOp");
+  }
+  op->ContinueSend();
+  return Status::OK();
+}
+
 int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }

 bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
@@ -203,6 +203,9 @@ class DEPipeline {
   Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

   Status StopSend();
+
+  Status ContinueSend();
+
   Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                                         std::shared_ptr<DatasetOp> *bottom);

@@ -153,6 +153,10 @@ Status DeviceQueueOp::SendDataToAscend() {
     }
     if (current_buffer->eoe() && send_epoch_end_) {
       TensorRow currRow;
+      while (stop_send_) {
+        MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
+        std::this_thread::sleep_for(std::chrono::microseconds(100));
+      }
       auto status =
         tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
       if (status == TdtStatus::FAILED) {
@@ -163,6 +167,8 @@ Status DeviceQueueOp::SendDataToAscend() {
         return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
       }
     }
+    MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
+    stop_send_ = true;
   }
   if (isProfilingEnable) {
     connector_size = ChildOpConnectorSize();
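The additions above form a simple pause/resume handshake on the sending thread: before pushing an epoch's end-of-sequence marker it waits while stop_send_ is still raised, and after the push it raises the flag again until the trainer clears it. A plain-Python sketch of the same pattern, for illustration only (not MindSpore code; names and the push callback are placeholders):

import time

class EpochSender:
    # Mirrors DeviceQueueOp's stop_send_ flag; one instance per sending thread.
    def __init__(self):
        self._stop_send = False

    def send_end_of_epoch(self, push):
        while self._stop_send:       # still paused from the previous epoch: wait
            time.sleep(100e-6)       # 100 microseconds, as in SendDataToAscend()
        push("END_OF_SEQUENCE")      # epoch-boundary marker (hostPush with TDT_END_OF_SEQUENCE)
        self._stop_send = True       # one epoch sent; pause until the trainer resumes us

    def continue_send(self):         # trainer side, called before the next epoch starts
        self._stop_send = False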
@@ -123,6 +123,11 @@ class DeviceQueueOp : public PipelineOp {

   void StopSend() { stop_send_ = true; }

+  void ContinueSend() {
+    MS_LOG(INFO) << "continue send at the beginning of the epoch";
+    stop_send_ = false;
+  }
+
   // Name: Print()
   // Description: A function that prints info about the node
   void Print(std::ostream &out,  // In: The output stream to print to
@@ -2585,6 +2585,9 @@ class TransferDataset(DatasetOp):
     def stop_send(self):
         self.iterator.depipeline.StopSend()

+    def continue_send(self):
+        self.iterator.depipeline.ContinueSend()
+

 class RangeDataset(MappableDataset):
     """
@@ -163,6 +163,10 @@ class DatasetHelper:
         """Free up resources about data sink."""
         self.iter.stop_send()

+    def continue_send(self):
+        """continue send data to device at the beginning of epoch."""
+        self.iter.continue_send()
+

 class _DatasetIter:
     """Base iter for dataset helper"""
@@ -182,6 +186,7 @@ class _DatasetIter:
             _send_data_no_flag(dataset, epoch_num)

         self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
+        self.continue_send = dataset.__TRANSFER_DATASET__.continue_send
         self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

     def __iter__(self):
@@ -442,6 +442,7 @@ class Model:
                 cb_params.net_outputs = outputs
                 list_callback.step_end(run_context)

+            dataset_helper.continue_send()
             list_callback.epoch_end(run_context)
             should_stop = should_stop or run_context.get_stop_requested()
             if should_stop:
@@ -64,6 +64,9 @@ class MindData:
     def stop_send(self):
         pass

+    def continue_send(self):
+        pass
+
     def __len__(self):
         return self._size

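The no-op continue_send added to the MindData stub mirrors the existing stop_send stub, presumably so that unit tests that drive Model with this mock dataset keep passing now that the training loop calls continue_send at the end of every epoch.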
@@ -98,6 +98,6 @@ def test_deeplabv3_1p():
     print("expect loss: ", callback.loss)
     print("expect time: ", callback.time)
     expect_loss = 0.92
-    expect_time = 40
+    expect_time = 43
     assert callback.loss.asnumpy() <= expect_loss
     assert callback.time <= expect_time
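The time budget in the single-card DeepLabV3 smoke test is relaxed from 40 to 43, presumably to leave headroom for the per-epoch pause/resume handshake introduced above.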