From 4949f6ca3989d4cd2abecc2264e643d037d2b9ca Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Wed, 24 Jun 2020 14:56:43 +0800 Subject: [PATCH] optimize the graph output of all nop node --- mindspore/ccsrc/device/kernel_runtime.cc | 27 ++++++++++++++++++------ mindspore/ccsrc/device/kernel_runtime.h | 4 ++-- mindspore/ccsrc/session/session_basic.cc | 12 +++++++++-- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index 07b9257bb2c..43b350fab71 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -30,6 +30,7 @@ #include "kernel/common_utils.h" #include "kernel/oplib/oplib.h" #include "ir/value.h" +#include "pre_activate/common/helper.h" using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; @@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { } } -void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, +void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, AddressPtrList *kernel_outputs) { MS_EXCEPTION_IF_NULL(kernel); @@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { return GenAddrCleanLaunchArgs(cnode, kernel_inputs); } + auto is_all_nop_node = opt::IsAllNopNode(&graph); for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); - auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); + DeviceAddressPtr device_address; + if (is_all_nop_node) { + device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false); + } else { + device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true); + } MS_EXCEPTION_IF_NULL(device_address); kernel::AddressPtr input = std::make_shared(); MS_EXCEPTION_IF_NULL(input); @@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod kernel_inputs->emplace_back(input); } - for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetOutputAddr(kernel, i); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) { + DeviceAddressPtr device_address; + if (is_all_nop_node) { + device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + } else { + device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true); + } + MS_EXCEPTION_IF_NULL(device_address); kernel::AddressPtr output = std::make_shared(); MS_EXCEPTION_IF_NULL(output); output->addr = device_address->ptr_; @@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod kernel_outputs->emplace_back(output); } - for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); kernel::AddressPtr workspace = std::make_shared(); MS_EXCEPTION_IF_NULL(workspace); @@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { AddressPtrList kernel_inputs; AddressPtrList kernel_workspaces; AddressPtrList kernel_outputs; - GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); if (!ret) { MS_LOG(ERROR) << "Launch kernel failed."; diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h index 8442342e322..c69487c6f17 100644 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ b/mindspore/ccsrc/device/kernel_runtime.h @@ -96,8 +96,8 @@ class KernelRuntime { private: void AssignStaticMemoryOutput(const session::KernelGraph *graph); - void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); + void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs, + AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); bool LaunchKernelMod(const session::KernelGraph &graph); void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index 730c20d6990..9f5ba81f904 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne } } // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) - auto address = AnfAlgo::GetOutputAddr(node, output_index); + DeviceAddressPtr address; + auto is_all_nop_node = opt::IsAllNopNode(&graph); + if (is_all_nop_node) { + // The graph does not remove the nop node. + address = AnfAlgo::GetMutableOutputAddr(node, output_index, false); + } else { + // The graph removes the nop node. + address = AnfAlgo::GetMutableOutputAddr(node, output_index, true); + } MS_EXCEPTION_IF_NULL(address); auto shape = AnfAlgo::GetOutputInferShape(node, output_index); TypeId type_id = kNumberTypeFloat32; @@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { - tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); + tensor->set_device_address(address); tensor->set_dirty(false); } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), LongToSize(tensor->data().nbytes()), tensor->data_type(),