!6221 stop send data to device after end of sequence

Merge pull request !6221 from anzhengqi/stop-send-at-eos
This commit is contained in:
mindspore-ci-bot 2020-09-15 16:10:35 +08:00 committed by Gitee
commit f6ac30ef29
10 changed files with 38 additions and 1 deletions

View File

@ -81,6 +81,7 @@ PYBIND_REGISTER(
.def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
.def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true;

View File

@ -291,6 +291,16 @@ Status DEPipeline::StopSend() {
return Status::OK();
}
Status DEPipeline::ContinueSend() {
// tree_.root() must be DeviceQueueOp
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
if (op == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ContinueSend only supported by DeviceQueueOp");
}
op->ContinueSend();
return Status::OK();
}
int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }

View File

@ -203,6 +203,9 @@ class DEPipeline {
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status StopSend();
Status ContinueSend();
Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);

View File

@ -153,6 +153,10 @@ Status DeviceQueueOp::SendDataToAscend() {
}
if (current_buffer->eoe() && send_epoch_end_) {
TensorRow currRow;
while (stop_send_) {
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
auto status =
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
if (status == TdtStatus::FAILED) {
@ -163,6 +167,8 @@ Status DeviceQueueOp::SendDataToAscend() {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
}
MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
stop_send_ = true;
}
if (isProfilingEnable) {
connector_size = ChildOpConnectorSize();

View File

@ -123,6 +123,11 @@ class DeviceQueueOp : public PipelineOp {
void StopSend() { stop_send_ = true; }
void ContinueSend() {
MS_LOG(INFO) << "continue send at the beginning of the epoch";
stop_send_ = false;
}
// Name: Print()
// Description: A function that prints info about the node
void Print(std::ostream &out, // In: The output stream to print to

View File

@ -2585,6 +2585,9 @@ class TransferDataset(DatasetOp):
def stop_send(self):
self.iterator.depipeline.StopSend()
def continue_send(self):
self.iterator.depipeline.ContinueSend()
class RangeDataset(MappableDataset):
"""

View File

@ -163,6 +163,10 @@ class DatasetHelper:
"""Free up resources about data sink."""
self.iter.stop_send()
def continue_send(self):
"""continue send data to device at the beginning of epoch."""
self.iter.continue_send()
class _DatasetIter:
"""Base iter for dataset helper"""
@ -182,6 +186,7 @@ class _DatasetIter:
_send_data_no_flag(dataset, epoch_num)
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
self.continue_send = dataset.__TRANSFER_DATASET__.continue_send
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
def __iter__(self):

View File

@ -442,6 +442,7 @@ class Model:
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
dataset_helper.continue_send()
list_callback.epoch_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:

View File

@ -64,6 +64,9 @@ class MindData:
def stop_send(self):
pass
def continue_send(self):
pass
def __len__(self):
return self._size

View File

@ -98,6 +98,6 @@ def test_deeplabv3_1p():
print("expect loss: ", callback.loss)
print("expect time: ", callback.time)
expect_loss = 0.92
expect_time = 40
expect_time = 43
assert callback.loss.asnumpy() <= expect_loss
assert callback.time <= expect_time