forked from mindspore-Ecosystem/mindspore
!2560 optimize the graph output of all nop node
Merge pull request !2560 from limingqi107/master
This commit is contained in:
commit
f7bf4bcd22
|
@ -30,6 +30,7 @@
|
|||
#include "kernel/common_utils.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "ir/value.h"
|
||||
#include "pre_activate/common/helper.h"
|
||||
using mindspore::kernel::Address;
|
||||
using mindspore::kernel::AddressPtr;
|
||||
|
||||
|
@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
|
||||
void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
|
||||
AddressPtrList *kernel_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
|
@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
|
|||
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
|
||||
return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
|
||||
}
|
||||
auto is_all_nop_node = opt::IsAllNopNode(&graph);
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
|
||||
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
|
||||
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
|
||||
DeviceAddressPtr device_address;
|
||||
if (is_all_nop_node) {
|
||||
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false);
|
||||
} else {
|
||||
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
|
|||
kernel_inputs->emplace_back(input);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
|
||||
auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
|
||||
DeviceAddressPtr device_address;
|
||||
if (is_all_nop_node) {
|
||||
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
|
||||
} else {
|
||||
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
kernel::AddressPtr output = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
output->addr = device_address->ptr_;
|
||||
|
@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
|
|||
kernel_outputs->emplace_back(output);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
|
||||
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
|
||||
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
|
||||
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(workspace);
|
||||
|
@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
|
|||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Launch kernel failed.";
|
||||
|
|
|
@ -96,8 +96,8 @@ class KernelRuntime {
|
|||
|
||||
private:
|
||||
void AssignStaticMemoryOutput(const session::KernelGraph *graph);
|
||||
void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
|
||||
void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
|
||||
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
|
||||
bool LaunchKernelMod(const session::KernelGraph &graph);
|
||||
void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
|
||||
size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index);
|
||||
|
|
|
@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
}
|
||||
}
|
||||
// If the process reaches here, it means item_with_index is a real node (Parameter or executable CNode).
|
||||
auto address = AnfAlgo::GetOutputAddr(node, output_index);
|
||||
DeviceAddressPtr address;
|
||||
auto is_all_nop_node = opt::IsAllNopNode(&graph);
|
||||
if (is_all_nop_node) {
|
||||
// The graph does not remove the nop node.
|
||||
address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
|
||||
} else {
|
||||
// The graph removes the nop node.
|
||||
address = AnfAlgo::GetMutableOutputAddr(node, output_index, true);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
|
||||
TypeId type_id = kNumberTypeFloat32;
|
||||
|
@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
|
||||
tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
|
||||
tensor->set_device_address(address);
|
||||
tensor->set_dirty(false);
|
||||
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
|
||||
|
|
Loading…
Reference in New Issue