optimize the graph output of all nop node

This commit is contained in:
limingqi107 2020-06-24 14:56:43 +08:00
parent 107618d607
commit 4949f6ca39
3 changed files with 33 additions and 10 deletions

View File

@ -30,6 +30,7 @@
#include "kernel/common_utils.h"
#include "kernel/oplib/oplib.h"
#include "ir/value.h"
#include "pre_activate/common/helper.h"
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
}
}
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
AddressPtrList *kernel_outputs) {
MS_EXCEPTION_IF_NULL(kernel);
@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
}
auto is_all_nop_node = opt::IsAllNopNode(&graph);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
DeviceAddressPtr device_address;
if (is_all_nop_node) {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false);
} else {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true);
}
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input);
@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_inputs->emplace_back(input);
}
for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
DeviceAddressPtr device_address;
if (is_all_nop_node) {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
} else {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true);
}
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr output = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(output);
output->addr = device_address->ptr_;
@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_outputs->emplace_back(output);
}
for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace);
@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";

View File

@ -96,8 +96,8 @@ class KernelRuntime {
private:
void AssignStaticMemoryOutput(const session::KernelGraph *graph);
void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
bool LaunchKernelMod(const session::KernelGraph &graph);
void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index);

View File

@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
}
}
// if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
auto address = AnfAlgo::GetOutputAddr(node, output_index);
DeviceAddressPtr address;
auto is_all_nop_node = opt::IsAllNopNode(&graph);
if (is_all_nop_node) {
// The graph does not remove the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
} else {
// The graph removes the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, true);
}
MS_EXCEPTION_IF_NULL(address);
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
TypeId type_id = kNumberTypeFloat32;
@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
tensor->set_device_address(address);
tensor->set_dirty(false);
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(),