forked from mindspore-Ecosystem/mindspore
Allocate and free graph output addresses dynamically to support the dynamic shape sink mode
parent 1aa1655b41
commit 1ed1088cf7
@@ -82,7 +82,7 @@ bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context
 void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
-  if (input_datas_num_ != 0) {
+  if ((input_datas_num_ != 0) && (!input_op_datas_.empty())) {
     auto ret = input_op_datas_.erase(context->sequential_num_);
     if (ret == 0) {
       std::string error_info = "Erase input data failed: " + GetAID().Name();
@@ -92,7 +92,7 @@ void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
     }
   }

-  if (input_controls_num_ != 0) {
+  if ((input_controls_num_ != 0) && (!input_op_controls_.empty())) {
     auto ret = input_op_controls_.erase(context->sequential_num_);
     if (ret == 0) {
       std::string error_info = "Erase input controls failed: " + GetAID().Name();
@@ -191,7 +191,7 @@ void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_c
   UpdateRefCount(device_tensor.get(), is_max_ref_count);
 }

-void FreeMemoryInner(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
+void FreeMemory(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
   MS_EXCEPTION_IF_NULL(device_tensor);
   // The device context may be not accurate in the control flow scene, so need fetch by device name and device id.
   if ((device_context == nullptr) || (device_context->GetDeviceAddressType() != device_tensor->DeviceType())) {
@@ -213,7 +213,7 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext
     device_tensor->DecreaseRefCount();
     if (device_tensor->ref_count() == 0) {
       if (device_tensor->GetPtr() != nullptr) {
-        FreeMemoryInner(device_tensor, device_context);
+        FreeMemory(device_tensor, device_context);
       }
       device_tensor->ResetRefCount();
     }
@@ -222,7 +222,7 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext
     device_tensor->DecreaseDynamicRefCount(op_name);
     if ((device_tensor->dynamic_ref_count() == 0) && (device_tensor->GetPtr() != nullptr)) {
       MS_LOG(DEBUG) << "Free memory by the dynamic reference count, device address" << device_tensor->GetPtr();
-      FreeMemoryInner(device_tensor, device_context);
+      FreeMemory(device_tensor, device_context);
     }
   }
 }
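
The hunks above rename FreeMemoryInner to FreeMemory so it can be declared in the header (next hunk) and called directly by actors such as OutputActor, alongside the existing ref-count-driven path. Below is a minimal standalone sketch of the two freeing protocols; MockDeviceTensor and these free functions are simplified stand-ins, not the real DeviceTensor/DeviceContext API.

```cpp
#include <cstdlib>
#include <string>

// Simplified stand-in for DeviceTensor; the real class also tracks the
// device type, node index, and the allocator that owns the pointer.
struct MockDeviceTensor {
  void *ptr = nullptr;
  int ref_count = 0;           // static ref count, fixed at compile time
  int original_ref_count = 0;  // restored by ResetRefCount()
  int dynamic_ref_count = 0;   // used instead in the control flow scene
};

// Direct free (the renamed FreeMemory): release the address unconditionally.
// OutputActor uses this to drop a non-persisted output between steps so the
// next step can allocate it again with the new shape size.
void FreeMemory(MockDeviceTensor *t) {
  std::free(t->ptr);
  t->ptr = nullptr;
}

// Ref-count-driven free: each finished consumer decrements the count, and
// the memory is released only when the last consumer is done.
void FreeMemoryByRefCount(MockDeviceTensor *t, const std::string &op_name) {
  (void)op_name;  // the real code only uses this for debug logging
  if (t->ref_count > 0) {
    if (--t->ref_count == 0) {
      if (t->ptr != nullptr) {
        FreeMemory(t);
      }
      t->ref_count = t->original_ref_count;  // ResetRefCount()
    }
  } else if (t->dynamic_ref_count > 0 && --t->dynamic_ref_count == 0 &&
             t->ptr != nullptr) {
    FreeMemory(t);
  }
}
```
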
@@ -221,6 +221,7 @@ void UpdateRefCount(DeviceTensor *const device_tensor, bool is_max_ref_count = f
 void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count = false);
 void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext *device_context,
                           const std::string &op_name);
+void FreeMemory(DeviceTensor *const device_tensor, const DeviceContext *device_context);

 // Get front node by backend node.
 AnfNodePtr FetchFrontNodeByBackendNode(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
@@ -67,6 +67,9 @@ class DataPrepareActor : public DebugAwareActor {

  protected:
   void Init() override;
+  void Run(OpContext<DeviceTensor> *const context) override {
+    PrepareData({}, context, GraphExecutionStrategy::kPipeline);
+  }

  private:
   friend class GraphScheduler;
@@ -83,9 +83,7 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *const context) {
   }

   // Send to DataPrepareActor to trigger next step running.
-  std::vector<std::vector<TensorPtr>> input_tensors;
-  ActorDispatcher::Send(data_prepare_aid_, &DataPrepareActor::PrepareData, input_tensors, context,
-                        GraphExecutionStrategy::kPipeline);
+  ActorDispatcher::Send(data_prepare_aid_, &OpActor::RunOpControl, from_aid, context);
 }
 }  // namespace runtime
 }  // namespace mindspore
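
These two hunks are complementary: DataPrepareActor now overrides Run to call PrepareData with an empty tensor list, so LoopCountActor no longer needs to invoke PrepareData by name and can send the generic RunOpControl message instead. A rough model of that trigger path, assuming a much-simplified actor base rather than the real OpActor:

```cpp
#include <cstdio>

// Simplified model of the actor trigger path; not the real OpActor interface.
class MockOpActor {
 public:
  explicit MockOpActor(int input_controls_num) : input_controls_num_(input_controls_num) {}
  virtual ~MockOpActor() = default;

  // Counterpart of OpActor::RunOpControl: count the arrived controls and
  // fire Run() once the running condition is satisfied.
  void RunOpControl() {
    ++arrived_controls_;
    if (arrived_controls_ == input_controls_num_) {
      arrived_controls_ = 0;
      Run();
    }
  }

 protected:
  virtual void Run() = 0;

 private:
  int input_controls_num_;
  int arrived_controls_ = 0;
};

class MockDataPrepareActor : public MockOpActor {
 public:
  using MockOpActor::MockOpActor;

 protected:
  // Mirrors the new override: Run() simply delegates to PrepareData with an
  // empty tensor list, so a control arrow is enough to start the next step.
  void Run() override { PrepareData(); }

 private:
  void PrepareData() { std::puts("prepare data for the next step"); }
};

int main() {
  // Two upstream controls: the loop count actor and the output actor.
  MockDataPrepareActor data_prepare(/*input_controls_num=*/2);
  data_prepare.RunOpControl();  // from LoopCountActor
  data_prepare.RunOpControl();  // from OutputActor, triggers PrepareData
}
```

This uniformity is what lets the scheduler later add a second control arrow into DataPrepareActor (from OutputActor) without any special-case dispatch.
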
@@ -20,6 +20,38 @@

 namespace mindspore {
 namespace runtime {
+bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
+  MS_EXCEPTION_IF_NULL(output_node);
+  MS_EXCEPTION_IF_NULL(output_device_tensor);
+  // The persisted address can't be replaced.
+  if (output_device_tensor->is_ptr_persisted()) {
+    return true;
+  }
+
+  if (output_node->isa<ValueNode>()) {
+    return true;
+  }
+
+  // In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
+  // But in the dynamic shape scenario, need to free the old memory and alloc new memory using the new shape size.
+  if (output_node->isa<Parameter>() && !(output_node->cast<ParameterPtr>()->has_dynamic_shape())) {
+    return true;
+  }
+
+  return false;
+}
+
+void UpdateOutputTensorShape(const std::vector<TensorPtr> &output_tensors,
+                             const std::vector<KernelWithIndex> &output_nodes) {
+  for (size_t i = 0; i < output_tensors.size(); ++i) {
+    MS_EXCEPTION_IF_NULL(output_tensors[i]);
+    auto shape = common::AnfAlgo::GetOutputInferShape(output_nodes[i].first, output_nodes[i].second);
+    std::vector<int64_t> temp_shape;
+    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
+    output_tensors[i]->set_shape(temp_shape);
+  }
+}
+
 void OutputActor::Init() {
   // Check device contexts number.
   if (device_contexts_.size() != output_nodes_.size()) {
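
IsOutputAddressPersisted condenses the old scattered checks into one predicate: an output address is fixed when the device tensor is explicitly persisted, when the node is a ValueNode (a constant), or when it is a Parameter without dynamic shape; a dynamic-shape Parameter deliberately falls through so its memory can be freed and reallocated at the new size each step. The same logic written over plain flags, as a hypothetical condensed form:

```cpp
// Hypothetical flag struct mirroring the queries the real predicate makes.
struct OutputTraits {
  bool ptr_persisted;      // output_device_tensor->is_ptr_persisted()
  bool is_value_node;      // output_node->isa<ValueNode>()
  bool is_parameter;       // output_node->isa<Parameter>()
  bool has_dynamic_shape;  // cast<ParameterPtr>()->has_dynamic_shape()
};

bool IsOutputAddressPersisted(const OutputTraits &t) {
  return t.ptr_persisted || t.is_value_node ||
         (t.is_parameter && !t.has_dynamic_shape);
}

// E.g. a dynamic-shape Parameter output is NOT persisted, so its address is
// freed after every non-final step:
//   IsOutputAddressPersisted({false, false, true, true}) == false
```
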
@@ -29,6 +61,10 @@ void OutputActor::Init() {
   if (output_nodes_.size() != outputs_.size()) {
     MS_LOG(EXCEPTION) << "The outputs number is wrong.";
   }
+  // Check output device tensors number.
+  if (outputs_.size() != output_device_tensors_.size()) {
+    MS_LOG(EXCEPTION) << "The output device tensors number is wrong.";
+  }

   // Set the number of actor running dependent messages.
   running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
@@ -38,6 +74,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
   MS_EXCEPTION_IF_NULL(context);
+
   ++current_count_;
   // The last loop.
   if (loop_count_ == current_count_) {
     if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
       std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
@@ -61,26 +98,33 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
       output_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
     }

-    // For dynamic_shape, UpdateOp maybe run after RunOpData, so it's needed to update shape of output tensor here
-    // Check outputs number.
-    if (output_nodes_.size() != outputs_.size()) {
-      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The outputs number is wrong.");
-    }
-    // Update output tensor's shape
-    for (size_t i = 0; i < outputs_.size(); ++i) {
-      auto shape = common::AnfAlgo::GetOutputInferShape(output_nodes_[i].first, output_nodes_[i].second);
-      std::vector<int64_t> temp_shape;
-      (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
-      if (outputs_[i] == nullptr) {
-        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The outputs_[i] is nullptr.");
-      }
-      outputs_[i]->set_shape(temp_shape);
-    }
+    // For dynamic_shape, UpdateOp maybe run after RunOpData, so it's needed to update shape of output tensor here.
+    UpdateOutputTensorShape(outputs_, output_nodes_);

     current_outputs_num_ = 0;
     current_count_ = 0;
     SET_OPCONTEXT_SUCCESS_RET((*context));
   }
+
+  // The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory.
+  // 1.Avoid the memory leak when memory used by dynamic ref count in the control flow scene.
+  // 2.Alloc the new memory in the next step using the new shape size in the dynamic shape scene.
+  for (size_t i = 0; i < output_nodes_.size(); ++i) {
+    auto &output_node = output_nodes_[i].first;
+    auto &output_device_tensor = output_device_tensors_[i];
+    if ((output_node == nullptr) || (output_device_tensor == nullptr)) {
+      return;
+    }
+    if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
+      FreeMemory(output_device_tensor, device_contexts_[i]);
+    }
+  }
+
+  // Send control arrow to trigger next step running.
+  auto from_aid = const_cast<AID *>(&GetAID());
+  for (auto &output_control : output_control_arrows_) {
+    ActorDispatcher::Send(output_control, &OpActor::RunOpControl, from_aid, context);
+  }
 }

 void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
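
Taken together with the earlier hunks, RunOpControl now drives a per-step lifecycle for graph outputs in sink mode: on the final step the result tensors take over the device memory, and on every other step the non-persisted output addresses are freed so the next launch can allocate them at the freshly inferred shape. A standalone sketch of just that control flow, with stand-in types:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Stand-in for one graph output; `persisted` models IsOutputAddressPersisted.
struct Output {
  void *ptr = nullptr;
  bool persisted = false;
};

class OutputActorSketch {
 public:
  OutputActorSketch(size_t loop_count, std::vector<Output> outputs)
      : loop_count_(loop_count), outputs_(std::move(outputs)) {}

  // Condensed shape of the new RunOpControl (error handling omitted).
  void RunOpControl() {
    ++current_count_;
    if (current_count_ == loop_count_) {
      // Final step: the memory is handed over to the output tensors.
      current_count_ = 0;
      return;
    }
    // Non-final step: drop non-persisted addresses so the next step can
    // allocate them again with the new shape size.
    for (auto &out : outputs_) {
      if (!out.persisted && out.ptr != nullptr) {
        out.ptr = nullptr;  // FreeMemory(...)
      }
    }
    // A control arrow to DataPrepareActor then starts the next step.
  }

 private:
  size_t loop_count_;
  size_t current_count_ = 0;
  std::vector<Output> outputs_;
};
```
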
@@ -92,16 +136,17 @@ void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<De
   if (output_position >= outputs_.size()) {
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
   }
+  // Save the output nodes and output device tensors.
+  auto node_with_index = input_data->data_->GetNodeIndex();
+  MS_EXCEPTION_IF_NULL(node_with_index.first);
+  output_nodes_[output_position] = node_with_index;
+  output_device_tensors_[output_position] = input_data->data_;

   // Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
   if (loop_count_ - current_count_ != 1) {
-    // The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory in
-    // the no last loop to avoid the memory leak when memory used by dynamic ref count.
-    FreeMemoryByRefCount(input_data->data_, device_contexts_[output_position], GetAID().Name());
     return;
   }

-  auto node_with_index = input_data->data_->GetNodeIndex();
   auto tensor = CreateOutputTensor(node_with_index.first, node_with_index.second, output_position);
   if (tensor == nullptr) {
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Create output tensor failed.");
@@ -109,10 +154,6 @@ void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<De
   tensor->set_need_release_device_mem(true);
   outputs_[output_position] = tensor;
   current_outputs_num_++;
-
-  // Save the output nodes to clear the device tensor in the running end.
-  output_nodes_[output_position] = node_with_index;
-  output_device_tensors_[output_position] = input_data->data_;
 }

 TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) {
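
The RunOpData side is the mirror image of the control path above: the output node and device tensor are now recorded on every arrival rather than only in the last loop (so RunOpControl always has them available for freeing), the per-data FreeMemoryByRefCount call for non-final loops is dropped in favor of the uniform FreeMemory pass in RunOpControl, and the duplicate save at the end of the function becomes redundant and is removed.
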
@@ -193,9 +234,7 @@ void OutputActor::UpdateOutputDeviceAddress() {
     }

     // If the output node whose output address ptr can't be changed, then alloc the new device memory and copy the data:
-    // 1.In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
-    // 2.The persisted address can't be replaced.
-    if (output_node->isa<ValueNode>() || output_node->isa<Parameter>() || device_tensor->is_ptr_persisted()) {
+    if (IsOutputAddressPersisted(device_tensor, output_node)) {
       auto device_context = device_contexts_[i];
       MS_EXCEPTION_IF_NULL(device_context);
       device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name());
@@ -1506,11 +1506,14 @@ void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set,
     LinkControlArrowForDataPrepareActor(actor_set->data_prepare_actor_.get(), actor_set,
                                         graph_compiler_info.control_node_parser_);
   }
-  // Link control arrows for custom actor
+
+  // Link control arrows for custom actor.
   LinkControlArrowForCustomActor(actor_set, graph_compiler_info);

   LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
                                     graph_compiler_info.control_node_parser_);
+
+  LinkControlArrowForOutputActor(actor_set->output_actor_.get(), actor_set);
 }

 void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
@@ -1730,6 +1733,20 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
   // Loop count actor --> data prepare actor.
   MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
   loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();
+  actor_set->data_prepare_actor_->input_controls_num_++;
+  (void)actor_set->data_prepare_actor_->input_control_arrow_aids_.emplace_back(loop_count_actor->GetAID());
 }
+
+void GraphScheduler::LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set) {
+  MS_EXCEPTION_IF_NULL(actor_set);
+  // There is no output actor in step mode.
+  if (output_actor == nullptr) {
+    return;
+  }
+
+  // Output actor --> data prepare actor.
+  // The output actor needs to free the output memory in the running and needs this control arrow.
+  AddControlArrow(output_actor, actor_set->data_prepare_actor_.get());
+}

 void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
@@ -176,6 +176,7 @@ class GraphScheduler {
                                           const ControlNodeParserPtr &parser);
   void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
                                          const ControlNodeParserPtr &parser);
+  void LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set);

   // 3. The processing of linking output result arrows.
   void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
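
Net effect on scheduling: LinkGlobalControlArrow now closes two control rings into DataPrepareActor, one from LoopCountActor and one from OutputActor, so every step ends with the output actor freeing its non-persisted addresses before data preparation for the next step begins. That ordering is what makes the dynamic shape sink mode safe: the output memory is guaranteed to be released and reallocated at the new shape size between consecutive launches of the sunk graph.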