From 1ed1088cf725857722c5e5cdab3ca25ba34acb77 Mon Sep 17 00:00:00 2001
From: limingqi107
Date: Fri, 11 Mar 2022 11:56:58 +0800
Subject: [PATCH] Allocate and free the graph output address dynamically to
 support the dynamic shape sink mode

---
 .../graph_scheduler/actor/abstract_actor.cc   |  4 +-
 .../graph_scheduler/actor/actor_common.cc     |  6 +-
 .../graph_scheduler/actor/actor_common.h      |  1 +
 .../actor/data_prepare_actor.h                |  3 +
 .../graph_scheduler/actor/loop_count_actor.cc |  4 +-
 .../graph_scheduler/actor/output_actor.cc     | 91 +++++++++++++------
 .../graph_scheduler/graph_scheduler.cc        | 19 +++-
 .../runtime/graph_scheduler/graph_scheduler.h |  1 +
 8 files changed, 94 insertions(+), 35 deletions(-)

diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/abstract_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/abstract_actor.cc
index 8085a482836..5e4ba2aded8 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/abstract_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/abstract_actor.cc
@@ -82,7 +82,7 @@ bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context
 void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
-  if (input_datas_num_ != 0) {
+  if ((input_datas_num_ != 0) && (!input_op_datas_.empty())) {
     auto ret = input_op_datas_.erase(context->sequential_num_);
     if (ret == 0) {
       std::string error_info = "Erase input data failed: " + GetAID().Name();
@@ -92,7 +92,7 @@ void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
     }
   }
 
-  if (input_controls_num_ != 0) {
+  if ((input_controls_num_ != 0) && (!input_op_controls_.empty())) {
     auto ret = input_op_controls_.erase(context->sequential_num_);
     if (ret == 0) {
       std::string error_info = "Erase input controls failed: " + GetAID().Name();
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc
index f64b51b139c..36af74fea57 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc
@@ -191,7 +191,7 @@ void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_c
   UpdateRefCount(device_tensor.get(), is_max_ref_count);
 }
 
-void FreeMemoryInner(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
+void FreeMemory(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
   MS_EXCEPTION_IF_NULL(device_tensor);
   // The device context may be not accurate in the control flow scene, so need fetch by device name and device id.
   if ((device_context == nullptr) || (device_context->GetDeviceAddressType() != device_tensor->DeviceType())) {
@@ -213,7 +213,7 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext
     device_tensor->DecreaseRefCount();
     if (device_tensor->ref_count() == 0) {
       if (device_tensor->GetPtr() != nullptr) {
-        FreeMemoryInner(device_tensor, device_context);
+        FreeMemory(device_tensor, device_context);
       }
       device_tensor->ResetRefCount();
     }
@@ -222,7 +222,7 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext
     device_tensor->DecreaseDynamicRefCount(op_name);
     if ((device_tensor->dynamic_ref_count() == 0) && (device_tensor->GetPtr() != nullptr)) {
       MS_LOG(DEBUG) << "Free memory by the dynamic reference count, device address" << device_tensor->GetPtr();
-      FreeMemoryInner(device_tensor, device_context);
+      FreeMemory(device_tensor, device_context);
     }
   }
 }
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.h
index 408b168fa56..d44c6bb9d6f 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.h
@@ -221,6 +221,7 @@ void UpdateRefCount(DeviceTensor *const device_tensor, bool is_max_ref_count = f
 void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count = false);
 void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext *device_context,
                           const std::string &op_name);
+void FreeMemory(DeviceTensor *const device_tensor, const DeviceContext *device_context);
 
 // Get front node by backend node.
 AnfNodePtr FetchFrontNodeByBackendNode(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.h
index c9236618fa6..8309ca978e0 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.h
@@ -67,6 +67,9 @@ class DataPrepareActor : public DebugAwareActor {
 
  protected:
   void Init() override;
+  void Run(OpContext<DeviceTensor> *const context) override {
+    PrepareData({}, context, GraphExecutionStrategy::kPipeline);
+  }
 
  private:
   friend class GraphScheduler;
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/loop_count_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/loop_count_actor.cc
index 44b5632907b..8150dc0508c 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/loop_count_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/loop_count_actor.cc
@@ -83,9 +83,7 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *const context) {
   }
 
   // Send to DataPrepareActor to trigger next step running.
-  std::vector<std::vector<TensorPtr>> input_tensors;
-  ActorDispatcher::Send(data_prepare_aid_, &DataPrepareActor::PrepareData, input_tensors, context,
-                        GraphExecutionStrategy::kPipeline);
+  ActorDispatcher::Send(data_prepare_aid_, &OpActor<DeviceTensor>::RunOpControl, from_aid, context);
 }
 }  // namespace runtime
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc
index 08fedc04e39..5207f3cf359 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc
@@ -20,6 +20,38 @@
 
 namespace mindspore {
 namespace runtime {
+bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
+  MS_EXCEPTION_IF_NULL(output_node);
+  MS_EXCEPTION_IF_NULL(output_device_tensor);
+  // The persisted address can't be replaced.
+  if (output_device_tensor->is_ptr_persisted()) {
+    return true;
+  }
+
+  if (output_node->isa<ValueNode>()) {
+    return true;
+  }
+
+  // In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
+  // But in the dynamic shape scenario, the old memory needs to be freed and new memory allocated with the new shape size.
+  if (output_node->isa<Parameter>() && !(output_node->cast<ParameterPtr>()->has_dynamic_shape())) {
+    return true;
+  }
+
+  return false;
+}
+
+void UpdateOutputTensorShape(const std::vector<TensorPtr> &output_tensors,
+                             const std::vector<KernelWithIndex> &output_nodes) {
+  for (size_t i = 0; i < output_tensors.size(); ++i) {
+    MS_EXCEPTION_IF_NULL(output_tensors[i]);
+    auto shape = common::AnfAlgo::GetOutputInferShape(output_nodes[i].first, output_nodes[i].second);
+    std::vector<int64_t> temp_shape;
+    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
+    output_tensors[i]->set_shape(temp_shape);
+  }
+}
+
 void OutputActor::Init() {
   // Check device contexts number.
   if (device_contexts_.size() != output_nodes_.size()) {
@@ -29,6 +61,10 @@ void OutputActor::Init() {
   if (output_nodes_.size() != outputs_.size()) {
     MS_LOG(EXCEPTION) << "The outputs number is wrong.";
   }
+  // Check output device tensors number.
+  if (outputs_.size() != output_device_tensors_.size()) {
+    MS_LOG(EXCEPTION) << "The output device tensors number is wrong.";
+  }
 
   // Set the number of actor running dependent messages.
   running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
@@ -38,6 +74,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const context) {
   MS_EXCEPTION_IF_NULL(context);
   ++current_count_;
 
+  // The last loop.
   if (loop_count_ == current_count_) {
     if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
       std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
@@ -61,26 +98,33 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const context) {
       output_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
     }
 
-    // For dynamic_shape, UpdateOp maybe run after RunOpData, so it's needed to update shape of output tensor here
-    // Check outputs number.
-    if (output_nodes_.size() != outputs_.size()) {
-      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The outputs number is wrong.");
-    }
-    // Update output tensor's shape
-    for (size_t i = 0; i < outputs_.size(); ++i) {
-      auto shape = common::AnfAlgo::GetOutputInferShape(output_nodes_[i].first, output_nodes_[i].second);
-      std::vector<int64_t> temp_shape;
-      (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
-      if (outputs_[i] == nullptr) {
-        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The outputs_[i] is nullptr.");
-      }
-      outputs_[i]->set_shape(temp_shape);
-    }
+    // For dynamic shape, UpdateOp may run after RunOpData, so the output tensor shape needs to be updated here.
+    UpdateOutputTensorShape(outputs_, output_nodes_);
     current_outputs_num_ = 0;
     current_count_ = 0;
     SET_OPCONTEXT_SUCCESS_RET((*context));
   }
+
+  // In the last loop the output device memory is taken over by the output tensor; otherwise it needs to be freed:
+  // 1. Avoid the memory leak caused by the dynamic ref count in the control flow scene.
+  // 2. Let the next step allocate new memory with the new shape size in the dynamic shape scene.
+  for (size_t i = 0; i < output_nodes_.size(); ++i) {
+    auto &output_node = output_nodes_[i].first;
+    auto &output_device_tensor = output_device_tensors_[i];
+    if ((output_node == nullptr) || (output_device_tensor == nullptr)) {
+      return;
+    }
+    if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
+      FreeMemory(output_device_tensor, device_contexts_[i]);
+    }
+  }
+
+  // Send control arrows to trigger next step running.
+  auto from_aid = const_cast<AID *>(&GetAID());
+  for (auto &output_control : output_control_arrows_) {
+    ActorDispatcher::Send(output_control, &OpActor<DeviceTensor>::RunOpControl, from_aid, context);
+  }
 }
 
 void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
@@ -92,16 +136,17 @@ void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
   if (output_position >= outputs_.size()) {
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
   }
+  // Save the output nodes and output device tensors.
+  auto node_with_index = input_data->data_->GetNodeIndex();
+  MS_EXCEPTION_IF_NULL(node_with_index.first);
+  output_nodes_[output_position] = node_with_index;
+  output_device_tensors_[output_position] = input_data->data_;
 
   // Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
   if (loop_count_ - current_count_ != 1) {
-    // The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory in
-    // the no last loop to avoid the memory leak when memory used by dynamic ref count.
-    FreeMemoryByRefCount(input_data->data_, device_contexts_[output_position], GetAID().Name());
     return;
   }
 
-  auto node_with_index = input_data->data_->GetNodeIndex();
   auto tensor = CreateOutputTensor(node_with_index.first, node_with_index.second, output_position);
   if (tensor == nullptr) {
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Create output tensor failed.");
@@ -109,10 +154,6 @@ void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
   tensor->set_need_release_device_mem(true);
   outputs_[output_position] = tensor;
   current_outputs_num_++;
-
-  // Save the output nodes to clear the device tensor in the running end.
-  output_nodes_[output_position] = node_with_index;
-  output_device_tensors_[output_position] = input_data->data_;
 }
 
 TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) {
@@ -193,9 +234,7 @@ void OutputActor::UpdateOutputDeviceAddress() {
     }
 
     // If the output node whose output address ptr can't be changed, then alloc the new device memory and copy the data:
-    // 1.In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
-    // 2.The persisted address can't be replaced.
-    if (output_node->isa<ValueNode>() || output_node->isa<Parameter>() || device_tensor->is_ptr_persisted()) {
+    if (IsOutputAddressPersisted(device_tensor, output_node)) {
       auto device_context = device_contexts_[i];
       MS_EXCEPTION_IF_NULL(device_context);
       device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name());
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc
index 707207c205c..628a436bb6a 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc
@@ -1506,11 +1506,14 @@ void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set,
     LinkControlArrowForDataPrepareActor(actor_set->data_prepare_actor_.get(), actor_set,
                                         graph_compiler_info.control_node_parser_);
   }
-  // Link control arrows for custom actor
+
+  // Link control arrows for custom actor.
   LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
 
   LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
                                     graph_compiler_info.control_node_parser_);
+
+  LinkControlArrowForOutputActor(actor_set->output_actor_.get(), actor_set);
 }
 
 void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
@@ -1730,6 +1733,20 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
   // Loop count actor --> data prepare actor.
   MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
   loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();
+  actor_set->data_prepare_actor_->input_controls_num_++;
+  (void)actor_set->data_prepare_actor_->input_control_arrow_aids_.emplace_back(loop_count_actor->GetAID());
+}
+
+void GraphScheduler::LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set) {
+  MS_EXCEPTION_IF_NULL(actor_set);
+  // There is no output actor in step mode.
+  if (output_actor == nullptr) {
+    return;
+  }
+
+  // Output actor --> data prepare actor.
+  // The output actor frees the output memory during running, so it needs this control arrow.
+  AddControlArrow(output_actor, actor_set->data_prepare_actor_.get());
 }
 
 void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.h b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.h
index 2240b31eccb..c6ee760b378 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.h
@@ -176,6 +176,7 @@ class GraphScheduler {
                                            const ControlNodeParserPtr &parser);
   void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
                                          const ControlNodeParserPtr &parser);
+  void LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set);
 
   // 3. The processing of linking output result arrows.
void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
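
Note (not part of the patch): the standalone sketch below only illustrates the per-step output-memory lifecycle this change introduces; all names in it (FakeDeviceTensor, RunStep, FreeOutput) are made up for illustration and are not MindSpore APIs. In every step except the last, the output actor frees each output device address that is not persisted (roughly what IsOutputAddressPersisted() guards), so the next step can allocate a buffer sized for the newly inferred shape; in the last step the memory is handed over to the output tensor instead.

// Simplified, self-contained model of the per-step output-memory lifecycle.
// All types and names here are illustrative, not the MindSpore APIs.
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <vector>

struct FakeDeviceTensor {
  void *ptr = nullptr;
  size_t size = 0;
  bool persisted = false;  // models IsOutputAddressPersisted() returning true
};

void FreeOutput(FakeDeviceTensor *t) {
  if (!t->persisted && t->ptr != nullptr) {
    std::free(t->ptr);
    t->ptr = nullptr;
  }
}

// One step: allocate the output buffer with the shape inferred for this step.
void RunStep(FakeDeviceTensor *output, const std::vector<int64_t> &shape) {
  size_t elems = std::accumulate(shape.begin(), shape.end(), size_t{1},
                                 [](size_t a, int64_t d) { return a * static_cast<size_t>(d); });
  output->size = elems * sizeof(float);
  output->ptr = std::malloc(output->size);  // buffer sized for the current shape
  std::cout << "step output: " << output->size << " bytes at " << output->ptr << "\n";
}

int main() {
  FakeDeviceTensor output;
  std::vector<std::vector<int64_t>> shapes = {{2, 3}, {4, 3}, {8, 3}};  // dynamic shape per step
  for (size_t step = 0; step < shapes.size(); ++step) {
    RunStep(&output, shapes[step]);
    bool last_step = (step + 1 == shapes.size());
    if (!last_step) {
      // Not the last loop: free the non-persisted output so the next step can
      // allocate a buffer that matches its own shape (what the patch adds).
      FreeOutput(&output);
    }
  }
  // Last loop: the buffer is handed over to the output tensor, which releases it later.
  std::free(output.ptr);
  return 0;
}

The new control arrow from the output actor to the data prepare actor plays the same role as the ordering in the loop above: the next step's PrepareData only runs after the previous step's non-persisted output addresses have been freed.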