From 24584317501a623a9680f86a4515ba0ede837f87 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Thu, 30 Jul 2020 15:22:53 +0800 Subject: [PATCH] not eliminate memcpy when nexe node is graph output --- .../ascend/enhancer/getnext_memcpy_elimination.cc | 9 ++++++++- .../pre_activate/getnext_memcpy_elimination_test.py | 2 ++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc index a729cdd0f9..a1d957f72c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc @@ -56,12 +56,19 @@ const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, co return nullptr; } - // 3. next_node is not nop node and it has only one input which is memcpy's output + // 3. next_node is not nop node, not graph output and it has only one input which is memcpy's output for (auto &item : next_nodes) { auto next_node = item.first->cast(); if (opt::IsNopNode(next_node)) { return nullptr; } + + auto graph_outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); + auto iter = std::find(graph_outputs.begin(), graph_outputs.end(), next_node); + if (iter != graph_outputs.end()) { + return nullptr; + } + if (next_node->inputs().size() != 2) { MS_LOG(DEBUG) << "next node has more than one input"; return nullptr; diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/getnext_memcpy_elimination_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/getnext_memcpy_elimination_test.py index 61310d186f..444cf8282d 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/getnext_memcpy_elimination_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/getnext_memcpy_elimination_test.py @@ -44,12 +44,14 @@ def test_getnext_memcpy_elimination(tag): res = get_next() res = memcpy_async_attr(res) res = cast(res) + res = add(res) return res @fns def after(): res = get_next() res = cast(res) + res = add(res) return res return fns[tag]