forked from OSSInnovation/mindspore
stop send data to device after end of sequence
This commit is contained in:
parent
144998894f
commit
8e1a2ef5ae
|
@ -81,6 +81,7 @@ PYBIND_REGISTER(
|
||||||
.def("GetNumClasses", &DEPipeline::GetNumClasses)
|
.def("GetNumClasses", &DEPipeline::GetNumClasses)
|
||||||
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
|
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
|
||||||
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
|
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
|
||||||
|
.def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
|
||||||
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
|
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
|
||||||
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
|
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -291,6 +291,16 @@ Status DEPipeline::StopSend() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status DEPipeline::ContinueSend() {
|
||||||
|
// tree_.root() must be DeviceQueueOp
|
||||||
|
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
|
||||||
|
if (op == nullptr) {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ContinueSend only supported by DeviceQueueOp");
|
||||||
|
}
|
||||||
|
op->ContinueSend();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
|
int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
|
||||||
|
|
||||||
bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
|
bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
|
||||||
|
|
|
@ -203,6 +203,9 @@ class DEPipeline {
|
||||||
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||||
|
|
||||||
Status StopSend();
|
Status StopSend();
|
||||||
|
|
||||||
|
Status ContinueSend();
|
||||||
|
|
||||||
Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||||
std::shared_ptr<DatasetOp> *bottom);
|
std::shared_ptr<DatasetOp> *bottom);
|
||||||
|
|
||||||
|
|
|
@ -153,6 +153,10 @@ Status DeviceQueueOp::SendDataToAscend() {
|
||||||
}
|
}
|
||||||
if (current_buffer->eoe() && send_epoch_end_) {
|
if (current_buffer->eoe() && send_epoch_end_) {
|
||||||
TensorRow currRow;
|
TensorRow currRow;
|
||||||
|
while (stop_send_) {
|
||||||
|
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
|
||||||
|
std::this_thread::sleep_for(std::chrono::microseconds(100));
|
||||||
|
}
|
||||||
auto status =
|
auto status =
|
||||||
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
|
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
|
||||||
if (status == TdtStatus::FAILED) {
|
if (status == TdtStatus::FAILED) {
|
||||||
|
@ -163,6 +167,8 @@ Status DeviceQueueOp::SendDataToAscend() {
|
||||||
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
|
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
|
||||||
|
stop_send_ = true;
|
||||||
}
|
}
|
||||||
if (isProfilingEnable) {
|
if (isProfilingEnable) {
|
||||||
connector_size = ChildOpConnectorSize();
|
connector_size = ChildOpConnectorSize();
|
||||||
|
|
|
@ -123,6 +123,11 @@ class DeviceQueueOp : public PipelineOp {
|
||||||
|
|
||||||
void StopSend() { stop_send_ = true; }
|
void StopSend() { stop_send_ = true; }
|
||||||
|
|
||||||
|
void ContinueSend() {
|
||||||
|
MS_LOG(INFO) << "continue send at the beginning of the epoch";
|
||||||
|
stop_send_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
// Name: Print()
|
// Name: Print()
|
||||||
// Description: A function that prints info about the node
|
// Description: A function that prints info about the node
|
||||||
void Print(std::ostream &out, // In: The output stream to print to
|
void Print(std::ostream &out, // In: The output stream to print to
|
||||||
|
|
|
@ -2588,6 +2588,9 @@ class TransferDataset(DatasetOp):
|
||||||
def stop_send(self):
|
def stop_send(self):
|
||||||
self.iterator.depipeline.StopSend()
|
self.iterator.depipeline.StopSend()
|
||||||
|
|
||||||
|
def continue_send(self):
|
||||||
|
self.iterator.depipeline.ContinueSend()
|
||||||
|
|
||||||
|
|
||||||
class RangeDataset(MappableDataset):
|
class RangeDataset(MappableDataset):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -163,6 +163,10 @@ class DatasetHelper:
|
||||||
"""Free up resources about data sink."""
|
"""Free up resources about data sink."""
|
||||||
self.iter.stop_send()
|
self.iter.stop_send()
|
||||||
|
|
||||||
|
def continue_send(self):
|
||||||
|
"""continue send data to device at the beginning of epoch."""
|
||||||
|
self.iter.continue_send()
|
||||||
|
|
||||||
|
|
||||||
class _DatasetIter:
|
class _DatasetIter:
|
||||||
"""Base iter for dataset helper"""
|
"""Base iter for dataset helper"""
|
||||||
|
@ -182,6 +186,7 @@ class _DatasetIter:
|
||||||
_send_data_no_flag(dataset, epoch_num)
|
_send_data_no_flag(dataset, epoch_num)
|
||||||
|
|
||||||
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
|
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
|
||||||
|
self.continue_send = dataset.__TRANSFER_DATASET__.continue_send
|
||||||
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
|
|
@ -442,6 +442,7 @@ class Model:
|
||||||
cb_params.net_outputs = outputs
|
cb_params.net_outputs = outputs
|
||||||
list_callback.step_end(run_context)
|
list_callback.step_end(run_context)
|
||||||
|
|
||||||
|
dataset_helper.continue_send()
|
||||||
list_callback.epoch_end(run_context)
|
list_callback.epoch_end(run_context)
|
||||||
should_stop = should_stop or run_context.get_stop_requested()
|
should_stop = should_stop or run_context.get_stop_requested()
|
||||||
if should_stop:
|
if should_stop:
|
||||||
|
|
|
@ -64,6 +64,9 @@ class MindData:
|
||||||
def stop_send(self):
|
def stop_send(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def continue_send(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._size
|
return self._size
|
||||||
|
|
||||||
|
|
|
@ -98,6 +98,6 @@ def test_deeplabv3_1p():
|
||||||
print("expect loss: ", callback.loss)
|
print("expect loss: ", callback.loss)
|
||||||
print("expect time: ", callback.time)
|
print("expect time: ", callback.time)
|
||||||
expect_loss = 0.92
|
expect_loss = 0.92
|
||||||
expect_time = 40
|
expect_time = 43
|
||||||
assert callback.loss.asnumpy() <= expect_loss
|
assert callback.loss.asnumpy() <= expect_loss
|
||||||
assert callback.time <= expect_time
|
assert callback.time <= expect_time
|
||||||
|
|
Loading…
Reference in New Issue