forked from mindspore-Ecosystem/mindspore
optimize is all nop node detect in mem reuse
This commit is contained in:
parent
d6d93f16b1
commit
33d1427a14
|
@ -103,6 +103,7 @@ bool MemReuseUtil::InitDynamicWorkspaceKernelRef() {
|
||||||
bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) {
|
bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
graph_ = graph;
|
graph_ = graph;
|
||||||
|
is_all_nop_node_ = opt::IsAllNopNode(graph);
|
||||||
if (!InitDynamicOutputKernelRef()) {
|
if (!InitDynamicOutputKernelRef()) {
|
||||||
MS_LOG(INFO) << "InitDynamicOutputKernelRef fail";
|
MS_LOG(INFO) << "InitDynamicOutputKernelRef fail";
|
||||||
return false;
|
return false;
|
||||||
|
@ -223,7 +224,6 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) {
|
KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) {
|
||||||
auto is_all_nop_node = opt::IsAllNopNode(graph_);
|
|
||||||
if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) {
|
if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) {
|
||||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number "
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number "
|
||||||
<< AnfAlgo::GetInputTensorNum(kernel);
|
<< AnfAlgo::GetInputTensorNum(kernel);
|
||||||
|
@ -231,7 +231,7 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t
|
||||||
auto input_node = kernel->input(input_idx + 1);
|
auto input_node = kernel->input(input_idx + 1);
|
||||||
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
|
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
|
||||||
session::KernelWithIndex kernel_input;
|
session::KernelWithIndex kernel_input;
|
||||||
if (is_all_nop_node) {
|
if (is_all_nop_node_) {
|
||||||
// The graph does not remove the nop node.
|
// The graph does not remove the nop node.
|
||||||
kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
|
kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
|
||||||
} else {
|
} else {
|
||||||
|
@ -265,7 +265,6 @@ void MemReuseUtil::SetKernelDefMap() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemReuseUtil::SetKernelDefInputs() {
|
void MemReuseUtil::SetKernelDefInputs() {
|
||||||
auto is_all_nop_node = opt::IsAllNopNode(graph_);
|
|
||||||
for (const auto &kernel : graph_->execution_order()) {
|
for (const auto &kernel : graph_->execution_order()) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
auto key = kernel.get();
|
auto key = kernel.get();
|
||||||
|
@ -282,7 +281,7 @@ void MemReuseUtil::SetKernelDefInputs() {
|
||||||
auto input_node = AnfAlgo::GetInputNode(kernel, i);
|
auto input_node = AnfAlgo::GetInputNode(kernel, i);
|
||||||
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
|
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
|
||||||
session::KernelWithIndex input;
|
session::KernelWithIndex input;
|
||||||
if (is_all_nop_node) {
|
if (is_all_nop_node_) {
|
||||||
// The graph does not remove the nop node.
|
// The graph does not remove the nop node.
|
||||||
input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
|
input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
|
||||||
} else {
|
} else {
|
||||||
|
@ -349,11 +348,10 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemReuseUtil::SetGraphOutputRefCount() {
|
void MemReuseUtil::SetGraphOutputRefCount() {
|
||||||
auto is_all_nop_node = opt::IsAllNopNode(graph_);
|
|
||||||
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
|
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
|
||||||
for (const auto &node : nodes) {
|
for (const auto &node : nodes) {
|
||||||
session::KernelWithIndex kernel_input;
|
session::KernelWithIndex kernel_input;
|
||||||
if (is_all_nop_node) {
|
if (is_all_nop_node_) {
|
||||||
// The graph does not remove the nop node.
|
// The graph does not remove the nop node.
|
||||||
kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false);
|
kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -42,7 +42,7 @@ class MemReuseUtil {
|
||||||
KernelRefCountPtrList total_refs_list_;
|
KernelRefCountPtrList total_refs_list_;
|
||||||
KernelRefCountPtrList total_wk_ref_list_;
|
KernelRefCountPtrList total_wk_ref_list_;
|
||||||
KernelRefs kernel_workspace_refs_;
|
KernelRefs kernel_workspace_refs_;
|
||||||
MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr) {}
|
MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr), is_all_nop_node_(false) {}
|
||||||
~MemReuseUtil() {
|
~MemReuseUtil() {
|
||||||
if (graph_ != nullptr) {
|
if (graph_ != nullptr) {
|
||||||
graph_ = nullptr;
|
graph_ = nullptr;
|
||||||
|
@ -87,6 +87,7 @@ class MemReuseUtil {
|
||||||
private:
|
private:
|
||||||
int util_index_;
|
int util_index_;
|
||||||
const KernelGraph *graph_;
|
const KernelGraph *graph_;
|
||||||
|
bool is_all_nop_node_;
|
||||||
KernelRefCountPtrList ref_list_;
|
KernelRefCountPtrList ref_list_;
|
||||||
KernelDefPtrMaps kernel_def_ptr_list_;
|
KernelDefPtrMaps kernel_def_ptr_list_;
|
||||||
KernelRefCountPtrList last_ref_list_;
|
KernelRefCountPtrList last_ref_list_;
|
||||||
|
|
Loading…
Reference in New Issue