not eliminate memcpy when nexe node is graph output

This commit is contained in:
laiyongqiang 2020-07-30 15:22:53 +08:00
parent ab4c43007f
commit 2458431750
2 changed files with 10 additions and 1 deletions

View File

@ -56,12 +56,19 @@ const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, co
return nullptr; return nullptr;
} }
// 3. next_node is not nop node and it has only one input which is memcpy's output // 3. next_node is not nop node, not graph output and it has only one input which is memcpy's output
for (auto &item : next_nodes) { for (auto &item : next_nodes) {
auto next_node = item.first->cast<CNodePtr>(); auto next_node = item.first->cast<CNodePtr>();
if (opt::IsNopNode(next_node)) { if (opt::IsNopNode(next_node)) {
return nullptr; return nullptr;
} }
auto graph_outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
auto iter = std::find(graph_outputs.begin(), graph_outputs.end(), next_node);
if (iter != graph_outputs.end()) {
return nullptr;
}
if (next_node->inputs().size() != 2) { if (next_node->inputs().size() != 2) {
MS_LOG(DEBUG) << "next node has more than one input"; MS_LOG(DEBUG) << "next node has more than one input";
return nullptr; return nullptr;

View File

@ -44,12 +44,14 @@ def test_getnext_memcpy_elimination(tag):
res = get_next() res = get_next()
res = memcpy_async_attr(res) res = memcpy_async_attr(res)
res = cast(res) res = cast(res)
res = add(res)
return res return res
@fns @fns
def after(): def after():
res = get_next() res = get_next()
res = cast(res) res = cast(res)
res = add(res)
return res return res
return fns[tag] return fns[tag]