From 0e89813759d8958bc05eb7192607f07cdde8c998 Mon Sep 17 00:00:00 2001 From: huangdongrun Date: Sat, 9 May 2020 14:13:06 +0800 Subject: [PATCH] add resolve transform valuetuple to maketuple of graphs add testcase --- mindspore/ccsrc/pipeline/parse/resolve.cc | 87 +++++++++++++---------- tests/ut/python/ops/test_tuple.py | 74 +++++++++++++++++++ 2 files changed, 122 insertions(+), 39 deletions(-) create mode 100644 tests/ut/python/ops/test_tuple.py diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc index 18f186dbb15..576f63a1cfc 100644 --- a/mindspore/ccsrc/pipeline/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/parse/resolve.cc @@ -170,51 +170,59 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, return true; } +bool IsAllGraphInValueSequence(const std::vector &value_vec) { + for (auto &elem : value_vec) { + if (elem->isa() || elem->isa()) { + const auto &vec = GetValue>(elem); + auto is_graph = IsAllGraphInValueSequence(vec); + if (!is_graph) { + return false; + } + } else if (!elem->isa()) { + return false; + } + } + return true; +} + +AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, + const std::vector &value_vec) { + std::vector nodes; + nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + for (auto &elem : value_vec) { + AnfNodePtr node = nullptr; + if (elem->isa() || elem->isa()) { + const auto &vec = GetValue>(elem); + node = TransformToMakeTupleNodes(manager, func_graph, vec); + } else if (elem->isa()) { + FuncGraphPtr new_fg = elem->cast(); + manager->AddFuncGraph(new_fg); + node = NewValueNode(new_fg); + } else { + MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString(); + } + nodes.emplace_back(node); + } + auto cnode = func_graph->NewCNode(nodes); + return cnode; +} + // transform the ValueTuple or ValueList of graph node to make tuple of const graph node -bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, +bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, const ValueNodePtr &value_node, AnfNodePtr *const transformed) { MS_EXCEPTION_IF_NULL(value_node); const auto &value_vec = GetValue>(value_node->value()); - bool has_graph_in_list = false; - for (auto &elemv : value_vec) { - MS_EXCEPTION_IF_NULL(elemv); - if (elemv->isa()) { - FuncGraphPtr new_fg = elemv->cast(); - manager->AddFuncGraph(new_fg); - has_graph_in_list = true; - continue; - } - if (has_graph_in_list) { - MS_LOG(EXCEPTION) << "List has graph in it, but not all is graph"; - } + if (!IsAllGraphInValueSequence(value_vec)) { + return false; } + // The celllist or ordered_cell will be parsed as valuetuple of const graph in it, // So if has graph in list, try to replace the node with make tuple of graph value node. - if (has_graph_in_list) { - // change the vector of graph to be make_list of graph value node - std::vector list_vec; - auto make_list_op = NewValueNode(prim::kPrimMakeTuple); - list_vec.emplace_back(make_list_op); - (void)std::transform(std::begin(value_vec), std::end(value_vec), std::back_inserter(list_vec), - [](const ValuePtr &value) { return NewValueNode(value); }); - FuncGraphPtr cnode_graph = nullptr; - auto users = manager->node_users()[node]; - for (auto &use : users) { - auto use_node = use.first; - MS_EXCEPTION_IF_NULL(use_node); - if (use_node->isa()) { - cnode_graph = use_node->func_graph(); - } - } - - if (cnode_graph) { - CNodePtr list_app = cnode_graph->NewCNode(list_vec); - // replace the ret ptr to be make_list of graph value node - *transformed = list_app; - } else { - MS_LOG(EXCEPTION) << "Can not find apply for node use when replacing node of vector of graph"; - } - } + // we do this because the graphmanger won't investigate the graph inside valuetuple, + // change the vector of graph to be make_tuple of graph value node + auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); + // replace the ret ptr to be make tuple of graph value node + *transformed = node_tuple_graphs; return true; } @@ -245,7 +253,8 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr // if the constant node is constant of vector of graph ,add graph to manager if (IsValueNode(resolved_node) || IsValueNode(resolved_node)) { - (void)TransformVectorGraphValueNode(manager, node, resolved_node->cast(), &resolved_node); + (void)TransformVectorGraphValueNode(manager, node->func_graph(), resolved_node->cast(), + &resolved_node); } TraceManager::EndTrace(); diff --git a/tests/ut/python/ops/test_tuple.py b/tests/ut/python/ops/test_tuple.py new file mode 100644 index 00000000000..5b3d5d52ae1 --- /dev/null +++ b/tests/ut/python/ops/test_tuple.py @@ -0,0 +1,74 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.context as context +import functools +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import dtype as mstype +from mindspore.ops import operations as P +from mindspore import context +from ..ut_filter import non_graph_engine +from ....mindspore_test_framework.mindspore_test import mindspore_test +from ....mindspore_test_framework.pipeline.forward.compile_forward \ + import pipeline_for_compile_forward_ge_graph_for_case_by_case_config +context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + + +class TupleGraphNet(nn.Cell): + def __init__(self): + super(TupleGraphNet, self).__init__() + self.conv1 = nn.Conv2d(3, 1, 3, pad_mode='same') + self.conv2 = nn.Conv2d(3, 1, 7, pad_mode='same') + self.conv3 = nn.Conv2d(3, 3, 3, pad_mode='same') + self.layers = (self.conv1, self.conv2, self.conv3) + + def construct(self, x): + return self.layers[0](x) + + +class NestTupleGraphNet(nn.Cell): + def __init__(self): + super(NestTupleGraphNet, self).__init__() + self.conv1 = nn.Conv2d(3, 1, 3, pad_mode='same') + self.conv2 = nn.Conv2d(3, 1, 7, pad_mode='same') + self.conv3 = nn.Conv2d(3, 3, 3, pad_mode='same') + self.layers = ((self.conv1, self.conv2), + (self.conv2, self.conv1, self.conv3)) + + def construct(self, x): + return self.layers[0][1](x) + + +test_case_ops = [ + ('TupleGraph', { + 'block': TupleGraphNet(), + 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), + ('NestTupleGraph', { + 'block': NestTupleGraphNet(), + 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), +] + +test_case_lists = [test_case_ops] +test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) +# use -k to select certain testcast +# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm + + +@non_graph_engine +@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) +def test_exec(): + context.set_context(mode=context.GRAPH_MODE) + return test_exec_case