Allocate and free graph output addresses dynamically to support the dynamic shape sink mode

This commit is contained in:
limingqi107 2022-03-11 11:56:58 +08:00
parent 1aa1655b41
commit 1ed1088cf7
8 changed files with 94 additions and 35 deletions

View File

@ -82,7 +82,7 @@ bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context
void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
if ((input_datas_num_ != 0) && (!input_op_datas_.empty())) {
auto ret = input_op_datas_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input data failed: " + GetAID().Name();
@ -92,7 +92,7 @@ void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
}
}
if (input_controls_num_ != 0) {
if ((input_controls_num_ != 0) && (!input_op_controls_.empty())) {
auto ret = input_op_controls_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input controls failed: " + GetAID().Name();

View File

@ -191,7 +191,7 @@ void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_c
UpdateRefCount(device_tensor.get(), is_max_ref_count);
}
void FreeMemoryInner(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
void FreeMemory(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_tensor);
// The device context may be not accurate in the control flow scene, so need fetch by device name and device id.
if ((device_context == nullptr) || (device_context->GetDeviceAddressType() != device_tensor->DeviceType())) {
@ -213,7 +213,7 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext
device_tensor->DecreaseRefCount();
if (device_tensor->ref_count() == 0) {
if (device_tensor->GetPtr() != nullptr) {
FreeMemoryInner(device_tensor, device_context);
FreeMemory(device_tensor, device_context);
}
device_tensor->ResetRefCount();
}
@ -222,7 +222,7 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext
device_tensor->DecreaseDynamicRefCount(op_name);
if ((device_tensor->dynamic_ref_count() == 0) && (device_tensor->GetPtr() != nullptr)) {
MS_LOG(DEBUG) << "Free memory by the dynamic reference count, device address" << device_tensor->GetPtr();
FreeMemoryInner(device_tensor, device_context);
FreeMemory(device_tensor, device_context);
}
}
}

View File

@ -221,6 +221,7 @@ void UpdateRefCount(DeviceTensor *const device_tensor, bool is_max_ref_count = f
void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count = false);
void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext *device_context,
const std::string &op_name);
void FreeMemory(DeviceTensor *const device_tensor, const DeviceContext *device_context);
// Get front node by backend node.
AnfNodePtr FetchFrontNodeByBackendNode(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);

View File

@ -67,6 +67,9 @@ class DataPrepareActor : public DebugAwareActor {
protected:
void Init() override;
// Actor entry point triggered by an upstream control arrow: runs data preparation
// for the next step using the pipeline execution strategy.
void Run(OpContext<DeviceTensor> *const context) override {
// NOTE(review): the empty vector means no fresh host input tensors are supplied on
// this path — presumably data fed in an earlier step is reused; confirm with callers.
PrepareData({}, context, GraphExecutionStrategy::kPipeline);
}
private:
friend class GraphScheduler;

View File

@ -83,9 +83,7 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *const context) {
}
// Send to DataPrepareActor to trigger next step running.
std::vector<std::vector<TensorPtr>> input_tensors;
ActorDispatcher::Send(data_prepare_aid_, &DataPrepareActor::PrepareData, input_tensors, context,
GraphExecutionStrategy::kPipeline);
ActorDispatcher::Send(data_prepare_aid_, &OpActor::RunOpControl, from_aid, context);
}
} // namespace runtime
} // namespace mindspore

View File

@ -20,6 +20,38 @@
namespace mindspore {
namespace runtime {
// Returns true when the output device address is persistent and therefore must not
// be freed or replaced by the output actor; returns false when the address may be
// released and re-allocated (e.g. to fit a new shape in the dynamic shape scene).
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(output_device_tensor);
  // The persisted address can't be replaced.
  if (output_device_tensor->is_ptr_persisted()) {
    return true;
  }
  // A value node's output address holds constant data and is kept as-is.
  if (output_node->isa<ValueNode>()) {
    return true;
  }
  // In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
  // But in the dynamic shape scenario, need to free the old memory and alloc new memory using the new shape size.
  const bool is_static_shape_parameter =
    output_node->isa<Parameter>() && !(output_node->cast<ParameterPtr>()->has_dynamic_shape());
  return is_static_shape_parameter;
}
void UpdateOutputTensorShape(const std::vector<TensorPtr> &output_tensors,
const std::vector<KernelWithIndex> &output_nodes) {
for (size_t i = 0; i < output_tensors.size(); ++i) {
MS_EXCEPTION_IF_NULL(output_tensors[i]);
auto shape = common::AnfAlgo::GetOutputInferShape(output_nodes[i].first, output_nodes[i].second);
std::vector<int64_t> temp_shape;
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
output_tensors[i]->set_shape(temp_shape);
}
}
void OutputActor::Init() {
// Check device contexts number.
if (device_contexts_.size() != output_nodes_.size()) {
@ -29,6 +61,10 @@ void OutputActor::Init() {
if (output_nodes_.size() != outputs_.size()) {
MS_LOG(EXCEPTION) << "The outputs number is wrong.";
}
// Check output device tensors number.
if (outputs_.size() != output_device_tensors_.size()) {
MS_LOG(EXCEPTION) << "The output device tensors number is wrong.";
}
// Set the number of actor running dependent messages.
running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
@ -38,6 +74,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
MS_EXCEPTION_IF_NULL(context);
++current_count_;
// The last loop.
if (loop_count_ == current_count_) {
if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
@ -61,26 +98,33 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
output_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
}
// For dynamic_shape, UpdateOp maybe run after RunOpData, so it's needed to update shape of output tensor here
// Check outputs number.
if (output_nodes_.size() != outputs_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The outputs number is wrong.");
}
// Update output tensor's shape
for (size_t i = 0; i < outputs_.size(); ++i) {
auto shape = common::AnfAlgo::GetOutputInferShape(output_nodes_[i].first, output_nodes_[i].second);
std::vector<int64_t> temp_shape;
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
if (outputs_[i] == nullptr) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The outputs_[i] is nullptr.");
}
outputs_[i]->set_shape(temp_shape);
}
// For dynamic_shape, UpdateOp maybe run after RunOpData, so it's needed to update shape of output tensor here.
UpdateOutputTensorShape(outputs_, output_nodes_);
current_outputs_num_ = 0;
current_count_ = 0;
SET_OPCONTEXT_SUCCESS_RET((*context));
}
// The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory.
// 1.Avoid the memory leak when memory used by dynamic ref count in the control flow scene.
// 2.Alloc the new memory in the next step using the new shape size in the dynamic shape scene.
for (size_t i = 0; i < output_nodes_.size(); ++i) {
auto &output_node = output_nodes_[i].first;
auto &output_device_tensor = output_device_tensors_[i];
if ((output_node == nullptr) || (output_device_tensor == nullptr)) {
return;
}
if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
FreeMemory(output_device_tensor, device_contexts_[i]);
}
}
// Send control arrow to trigger next step running.
auto from_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_control_arrows_) {
ActorDispatcher::Send(output_control, &OpActor::RunOpControl, from_aid, context);
}
}
void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
@ -92,16 +136,17 @@ void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<De
if (output_position >= outputs_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
}
// Save the output nodes and output device tensors.
auto node_with_index = input_data->data_->GetNodeIndex();
MS_EXCEPTION_IF_NULL(node_with_index.first);
output_nodes_[output_position] = node_with_index;
output_device_tensors_[output_position] = input_data->data_;
// Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
if (loop_count_ - current_count_ != 1) {
// The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory in
// the no last loop to avoid the memory leak when memory used by dynamic ref count.
FreeMemoryByRefCount(input_data->data_, device_contexts_[output_position], GetAID().Name());
return;
}
auto node_with_index = input_data->data_->GetNodeIndex();
auto tensor = CreateOutputTensor(node_with_index.first, node_with_index.second, output_position);
if (tensor == nullptr) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Create output tensor failed.");
@ -109,10 +154,6 @@ void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<De
tensor->set_need_release_device_mem(true);
outputs_[output_position] = tensor;
current_outputs_num_++;
// Save the output nodes to clear the device tensor in the running end.
output_nodes_[output_position] = node_with_index;
output_device_tensors_[output_position] = input_data->data_;
}
TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) {
@ -193,9 +234,7 @@ void OutputActor::UpdateOutputDeviceAddress() {
}
// If the output node whose output address ptr can't be changed, then alloc the new device memory and copy the data:
// 1.In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
// 2.The persisted address can't be replaced.
if (output_node->isa<ValueNode>() || output_node->isa<Parameter>() || device_tensor->is_ptr_persisted()) {
if (IsOutputAddressPersisted(device_tensor, output_node)) {
auto device_context = device_contexts_[i];
MS_EXCEPTION_IF_NULL(device_context);
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name());

View File

@ -1506,11 +1506,14 @@ void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set,
LinkControlArrowForDataPrepareActor(actor_set->data_prepare_actor_.get(), actor_set,
graph_compiler_info.control_node_parser_);
}
// Link control arrows for custom actor
// Link control arrows for custom actor.
LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
graph_compiler_info.control_node_parser_);
LinkControlArrowForOutputActor(actor_set->output_actor_.get(), actor_set);
}
void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
@ -1730,6 +1733,20 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
// Loop count actor --> data prepare actor.
MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();
actor_set->data_prepare_actor_->input_controls_num_++;
(void)actor_set->data_prepare_actor_->input_control_arrow_aids_.emplace_back(loop_count_actor->GetAID());
}
// Link a control arrow from the output actor to the data prepare actor, so the
// next step's data preparation only starts after the output actor has finished
// (and has freed the non-persistent output memory of the current step).
void GraphScheduler::LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
// There is no output actor in step mode.
if (output_actor == nullptr) {
return;
}
// Output actor --> data prepare actor.
// The output actor frees the output memory while running, so the data prepare actor
// must wait on this control arrow before re-allocating for the next step.
AddControlArrow(output_actor, actor_set->data_prepare_actor_.get());
}
void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,

View File

@ -176,6 +176,7 @@ class GraphScheduler {
const ControlNodeParserPtr &parser);
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
const ControlNodeParserPtr &parser);
void LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set);
// 3. The processing of linking output result arrows.
void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);