!816 faster combine_like step

Merge pull request !816 from flywind/combine_opt
This commit is contained in:
mindspore-ci-bot 2020-04-29 09:13:51 +08:00 committed by Gitee
commit 3dd369cefa
2 changed files with 31 additions and 8 deletions

View File

@ -130,7 +130,7 @@ bool ParseAction(const ResourcePtr &res) {
// This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}->
// graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
// all obj_map's graph shared base_graph
bool CombineLikeGraphs(const ResourcePtr &) {
bool CombineLikeGraphs(const ResourcePtr &res) {
auto &obj_map = parse::data_converter::GetObjGraphs();
for (auto it : obj_map) {
@ -147,13 +147,15 @@ bool CombineLikeGraphs(const ResourcePtr &) {
if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) {
continue;
}
auto mng = Manage(base_graph, false);
for (auto &fv : fg->paramter_obj_nodes()) {
TraceManager::DebugTrace(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
auto param = base_graph->add_parameter();
TraceManager::EndTrace();
auto repl_node = (*cloner->cloned_node())[fv];
(void)mng->Replace(repl_node, param);
auto &node_users = res->manager()->node_users()[fv];
for (auto &n : node_users) {
auto repl_n = (*cloner->cloned_node())[n.first]->cast<CNodePtr>();
repl_n->set_input(n.second, param);
}
}
MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size();

View File

@ -24,9 +24,7 @@ from mindspore.ops import operations as P
def setup_module(module):
context.set_context(mode = context.PYNATIVE_MODE, save_graphs = True, device_target = "Ascend")
context.set_context(enable_task_sink = True, device_id = 0)
context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend")
c1 = Tensor([2], mstype.int32)
c2 = Tensor([14], mstype.int32)
@ -135,6 +133,10 @@ def while_in_while_in_while(x, y, z):
return out
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_simple_if():
output = simple_if(c1, c2, c3)
expect = Tensor([6], mstype.int32)
@ -153,30 +155,49 @@ def test_if_in_if():
assert output == expect
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_simple_while():
output = simple_while(c1, c2, c3)
expect = Tensor([21], mstype.int32)
assert output == expect
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_by_while():
output = while_by_while(c1, c2, c3)
expect = Tensor([28], mstype.int32)
assert output == expect
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_in_while():
output = while_in_while(c1, c2, c3)
expect = Tensor([1274], mstype.int32)
assert output == expect
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_by_while_in_while():
output = while_by_while_in_while(c1, c2, c3)
expect = Tensor([350], mstype.int32)
assert output == expect
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_in_while_in_while():
output = while_in_while_in_while(c1, c2, c3)
expect = Tensor([2534], mstype.int32)