diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 9524da4cfd3..8d4c4026391 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -168,15 +168,15 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, return true; } -bool IsAllGraphInValueSequence(const std::vector &value_vec) { +bool IsAllFuncInValueSequence(const std::vector &value_vec) { for (auto &elem : value_vec) { if (elem->isa() || elem->isa()) { const auto &vec = GetValue>(elem); - auto is_graph = IsAllGraphInValueSequence(vec); + auto is_graph = IsAllFuncInValueSequence(vec); if (!is_graph) { return false; } - } else if (!elem->isa()) { + } else if (!elem->isa() && !elem->isa()) { return false; } } @@ -196,6 +196,8 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F FuncGraphPtr new_fg = elem->cast(); manager->AddFuncGraph(new_fg); node = NewValueNode(new_fg); + } else if (elem->isa()) { + node = NewValueNode(elem); } else { MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString(); } @@ -205,19 +207,21 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F return cnode; } -// transform the ValueTuple or ValueList of graph node to make tuple of const graph node -bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, - const ValueNodePtr &value_node, AnfNodePtr *const transformed) { +// transform the ValueTuple or ValueList of graph/primitve node to make tuple of const graph/primitve node +bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, + const ValueNodePtr &value_node, AnfNodePtr *const transformed) { MS_EXCEPTION_IF_NULL(value_node); const auto &value_vec = GetValue>(value_node->value()); - if (!IsAllGraphInValueSequence(value_vec)) { + if (!IsAllFuncInValueSequence(value_vec)) { return false; } - // The celllist or ordered_cell will be parsed as valuetuple of const graph in it, + // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it, // So if has graph in list, try to replace the node with make tuple of graph value node. // we do this because the graphmanger won't investigate the graph inside valuetuple, - // change the vector of graph to be make_tuple of graph value node + // change the vector of graph to be make_tuple of graph value node. + // (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all + // independent nodes. auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); // replace the ret ptr to be make tuple of graph value node *transformed = node_tuple_graphs; @@ -251,8 +255,8 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr // if the constant node is constant of vector of graph ,add graph to manager if (IsValueNode(resolved_node) || IsValueNode(resolved_node)) { - (void)TransformVectorGraphValueNode(manager, node->func_graph(), resolved_node->cast(), - &resolved_node); + (void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast(), + &resolved_node); } TraceManager::EndTrace(); diff --git a/tests/ut/python/pipeline/parse/test_for_stmt.py b/tests/ut/python/pipeline/parse/test_for_stmt.py index 4930dae796d..748c73e8738 100644 --- a/tests/ut/python/pipeline/parse/test_for_stmt.py +++ b/tests/ut/python/pipeline/parse/test_for_stmt.py @@ -17,6 +17,9 @@ from dataclasses import dataclass import numpy as np from mindspore import Tensor, Model, context +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F from mindspore.nn import Cell from mindspore.nn import ReLU from ...ut_filter import non_graph_engine @@ -66,3 +69,58 @@ def function_access_base(number): def test_access_0040(): """ test_access_0040 """ function_access_base(2) + + +class OpSeqNet(Cell): + def __init__(self, loop_count=1): + super().__init__() + self.loop_count = loop_count + self.op_seq = (P.Sqrt(), P.Reciprocal(), P.Square()) + + def construct(self, x): + t = x + for op in self.op_seq: + t = op(t) + return t + + +def test_op_seq_test(): + context.set_context(mode=context.GRAPH_MODE) + net = OpSeqNet() + input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me = Tensor(input_np) + net(input_me) + + +_grad_fusion = C.MultitypeFuncGraph("grad_fushion") + + +@_grad_fusion.register("Tensor", "Function") +def tensor_grad_scale(x, op): + return op(x) + + +class AllReduceTest(Cell): + def __init__(self, loop_count=1): + super().__init__() + self.op_list = () + self.fushion_flag = [0, 1, 1, 0, 1, 0] + for i in self.fushion_flag: + op = P.AllReduce().add_prim_attr('fusion', i) + self.op_list = self.op_list + (op,) + self.hyper_map = C.HyperMap() + + def construct(self, x): + ret = () + for _ in self.fushion_flag: + ret = ret + (x,) + fushion_res = self.hyper_map(F.partial(_grad_fusion), ret, self.op_list) + return fushion_res + + +def test_allreduce_fushio_test(): + context.set_context(mode=context.GRAPH_MODE) + net = AllReduceTest() + input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me = Tensor(input_np) + net(input_me)