diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index b3b364b00c7..7eea5501d5a 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -261,8 +261,7 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph auto &kernels = graph->execution_order(); for (auto &kernel : kernels) { MS_EXCEPTION_IF_NULL(kernel); - auto kernel_name = AnfAlgo::GetCNodeName(kernel); - if (kernel_name == kAllReduceOpName) { + if (AnfAlgo::IsCommunicationOp(kernel)) { AllocCommunicationOpInputDynamicRes(kernel); AllocCommunicationOpOutputDynamicRes(kernel); } @@ -272,27 +271,31 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); + bool is_need_alloc_memory = false; + bool is_need_free_memory = false; size_t total_size = 0; std::vector size_list; DeviceAddressPtrList addr_list; for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); - // The inputs of communication kernel are not released. - if (device_address->ptr_ != nullptr) { - MS_LOG(INFO) << "The inputs of communication kernel are not released."; - mem_manager_->FreeMemFromMemPool(device_address); + if (device_address->ptr_ == nullptr) { + is_need_alloc_memory = true; + } else { + is_need_free_memory = true; } total_size += device_address->size_; size_list.emplace_back(device_address->size_); addr_list.emplace_back(device_address); } - mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); + AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list); } void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); + bool is_need_alloc_memory = false; + bool is_need_free_memory = false; size_t total_size = 0; std::vector size_list; DeviceAddressPtrList addr_list; @@ -302,15 +305,33 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf for (size_t i = 0; i < output_sizes.size(); ++i) { auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); - // The outputs of communication kernel are not released. - if (device_address->ptr_ != nullptr) { - MS_LOG(INFO) << "The outputs of communication kernel are not released."; - mem_manager_->FreeMemFromMemPool(device_address); + if (device_address->ptr_ == nullptr) { + is_need_alloc_memory = true; + } else { + is_need_free_memory = true; } total_size += output_sizes[i]; size_list.emplace_back(output_sizes[i]); addr_list.emplace_back(device_address); } + AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list); +} + +void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, + const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list) { + if (!is_need_alloc_memory) { + return; + } + if (is_need_free_memory) { + for (const auto &iter : addr_list) { + MS_EXCEPTION_IF_NULL(iter); + // Free the inputs/outputs of communication kernel which are not released. + if (iter->ptr_ != nullptr) { + mem_manager_->FreeMemFromMemPool(iter); + } + } + } mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); } diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h index 33d4b4be70c..6f0eefc27a5 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h @@ -58,6 +58,9 @@ class GPUKernelRuntime : public KernelRuntime { void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); + void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, + const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list); void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, uint32_t graph_id); std::unordered_map mem_reuse_util_map_; diff --git a/mindspore/ccsrc/device/memory_manager.cc b/mindspore/ccsrc/device/memory_manager.cc index dce54495b0e..8dd8dfb5e0b 100644 --- a/mindspore/ccsrc/device/memory_manager.cc +++ b/mindspore/ccsrc/device/memory_manager.cc @@ -172,7 +172,7 @@ void MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList ad std::vector size_list) { auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); if (addr_list.size() != device_ptr_list.size()) { - MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; + MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; } for (size_t i = 0; i < addr_list.size(); i++) { MS_EXCEPTION_IF_NULL(device_ptr_list[i]); diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 3d5be5298ae..e1a18d95da0 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -514,10 +514,6 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; } } - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " - << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; - } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetOutputAddr(output_idx); @@ -539,10 +535,6 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; } } - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " - << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; - } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetMutableOutputAddr(output_idx);