forked from OSSInnovation/mindspore
Do not eliminate memcpy when the next node is a graph output
This commit is contained in:
parent
ab4c43007f
commit
2458431750
|
@ -56,12 +56,19 @@ const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, co
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. next_node is not nop node and it has only one input which is memcpy's output
|
// 3. next_node is not nop node, not graph output and it has only one input which is memcpy's output
|
||||||
for (auto &item : next_nodes) {
|
for (auto &item : next_nodes) {
|
||||||
auto next_node = item.first->cast<CNodePtr>();
|
auto next_node = item.first->cast<CNodePtr>();
|
||||||
if (opt::IsNopNode(next_node)) {
|
if (opt::IsNopNode(next_node)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto graph_outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||||
|
auto iter = std::find(graph_outputs.begin(), graph_outputs.end(), next_node);
|
||||||
|
if (iter != graph_outputs.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
if (next_node->inputs().size() != 2) {
|
if (next_node->inputs().size() != 2) {
|
||||||
MS_LOG(DEBUG) << "next node has more than one input";
|
MS_LOG(DEBUG) << "next node has more than one input";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -44,12 +44,14 @@ def test_getnext_memcpy_elimination(tag):
|
||||||
res = get_next()
|
res = get_next()
|
||||||
res = memcpy_async_attr(res)
|
res = memcpy_async_attr(res)
|
||||||
res = cast(res)
|
res = cast(res)
|
||||||
|
res = add(res)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
def after():
|
def after():
|
||||||
res = get_next()
|
res = get_next()
|
||||||
res = cast(res)
|
res = cast(res)
|
||||||
|
res = add(res)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
Loading…
Reference in New Issue