diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc
index 7f92ad9a1e1..0253cd2b39d 100644
--- a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc
+++ b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc
@@ -52,13 +52,17 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
   // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
   // converted to switch guarded.
   std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
-    {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}}, {prim::kPrimStateSetItem, {1}},
-     {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}}, {prim::kPrimReduceSum, {2}},
-     {prim::kPrimReduceMean, {2}}, {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
-     {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}}, {prim::kPrimGatherV2, {3}},
-     {prim::kPrimReshape, {2}}, {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
-     {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}}, {prim::kPrimImageSummary, {1}},
-     {prim::kPrimScalarSummary, {1}}, {prim::kPrimHistogramSummary, {1}}});
+    {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
+     {prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
+     {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
+     {prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
+     {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
+     {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
+     {prim::kPrimGatherV2, {3}}, {prim::kPrimReshape, {2}},
+     {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
+     {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}},
+     {prim::kPrimImageSummary, {1}}, {prim::kPrimScalarSummary, {1}},
+     {prim::kPrimHistogramSummary, {1}}});
   for (auto &item : white_list) {
     auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
       return IsPrimitiveCNode(node, item.first) && idx == index;
@@ -80,7 +84,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
 using NodeInputReplMap = std::unordered_map<std::pair<AnfNodePtr, size_t>, AnfNodePtr, PairHasher>;
 // replace the nodes which should be changed
 void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed,
-                          std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs) {
+                          std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs,
+                          const FuncGraphPtr &func_graph) {
   for (auto &node_pair : nodes_changed) {
     CNodePtr old_node = node_pair.first;
     CNodePtr new_node = node_pair.second;
@@ -99,9 +104,11 @@ void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed,
 
   for (auto &item : repl_node) {
-    if (!manager->Replace(item.first, item.second)) {
-      MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString()
-                        << " to new: " << item.second->DebugString();
+    if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) {
+      func_graph->set_output(item.second->cast<CNodePtr>()->input(1));
+    } else if (!manager->Replace(item.first, item.second)) {
+      MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2)
+                        << " to new: " << item.second->DebugString(2);
     }
   }
 }
 
@@ -154,7 +161,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
       nodes_changed.emplace_back(node->cast<CNodePtr>(), new_node);
     }
   }
-  RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs);
+  RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph);
   return graph;
 }
 
@@ -508,11 +515,12 @@ bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const Abstrac
 AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
                               const AbstractBasePtr &true_graph_output_abs,
-                              const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond) {
+                              const AbstractBasePtr &false_graph_output_abs, const FuncGraphPtr &switch_graph,
+                              const AnfNodePtr &cond) {
   MS_EXCEPTION_IF_NULL(true_graph_output_abs);
   MS_EXCEPTION_IF_NULL(false_graph_output_abs);
   MS_EXCEPTION_IF_NULL(cond);
-  MS_EXCEPTION_IF_NULL(cond->func_graph());
+  MS_EXCEPTION_IF_NULL(switch_graph);
   auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(PrimMerge);
 
   if (!true_graph_output_abs->isa<abstract::AbstractTuple>()) {
@@ -520,10 +528,10 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
     std::vector<AnfNodePtr> merge_nodes;
     merge_nodes.push_back(NewValueNode(PrimMerge));
     std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), true_output_node, false_output_node};
-    merge_nodes.push_back(cond->func_graph()->NewCNode(make_tuple_nodes));
+    merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes));
     std::vector<AnfNodePtr> tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem),
-                                                cond->func_graph()->NewCNode(merge_nodes), NewValueNode(MakeValue(0))};
-    return cond->func_graph()->NewCNode(tuple_getitem_nodes);
+                                                switch_graph->NewCNode(merge_nodes), NewValueNode(MakeValue(0))};
+    return switch_graph->NewCNode(tuple_getitem_nodes);
   } else {
     abstract::AbstractTuplePtr true_branch_tuple = true_graph_output_abs->cast<abstract::AbstractTuplePtr>();
     abstract::AbstractTuplePtr false_branch_tuple = false_graph_output_abs->cast<abstract::AbstractTuplePtr>();
@@ -533,27 +541,29 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
     for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) {
       std::vector<AnfNodePtr> true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), true_output_node,
                                                  NewValueNode(MakeValue(SizeToInt(i)))};
-      auto true_node = cond->func_graph()->NewCNode(true_getitem_nodes);
+      auto true_node = switch_graph->NewCNode(true_getitem_nodes);
       std::vector<AnfNodePtr> false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), false_output_node,
                                                   NewValueNode(MakeValue(SizeToInt(i)))};
-      auto false_node = cond->func_graph()->NewCNode(false_getitem_nodes);
+      auto false_node = switch_graph->NewCNode(false_getitem_nodes);
 
       auto merge_node = GenerateMergeNodes(true_node, false_node, true_branch_tuple->elements()[i],
-                                           false_branch_tuple->elements()[i], cond);
+                                           false_branch_tuple->elements()[i], switch_graph, cond);
       make_tuple_nodes.push_back(merge_node);
     }
-    return cond->func_graph()->NewCNode(make_tuple_nodes);
+    return switch_graph->NewCNode(make_tuple_nodes);
   }
 }
 
 AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
                                   const AbstractBasePtr &true_graph_output_abs,
-                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond) {
+                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond,
+                                  const FuncGraphPtr &switch_graph) {
   if (!GraphOutputCompatible(true_graph_output_abs, false_graph_output_abs)) {
     MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << true_graph_output_abs->ToString()
                       << ", false:" << false_graph_output_abs->ToString();
   }
-  return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs, cond);
+  return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs,
+                            switch_graph, cond);
 }
 }  // namespace internal
 }  // namespace irpass
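Reviewer note: the branch_culling.cc changes above thread the graph that owns the switch (`switch_graph`) through `GenerateMergeNodes`/`TransformMergeBranches` instead of relying on `cond->func_graph()`, and let `RunSwitchNodeReplace` rewrite a replaced `Return` by resetting the graph output. Below is a minimal sketch of the Python-level construct this pass lowers, written against the same 0.x test API as `tests/ut/python/ops/test_control_ops.py` later in this patch; the class name and tensor values are illustrative, not from the patch itself.

```python
import mindspore as ms
from mindspore import context, nn, Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)


class SingleIfNet(nn.Cell):
    """Single `if` on a tensor condition: the case ConvertSwitchReplacement handles."""

    def __init__(self):
        super(SingleIfNet, self).__init__()
        self.square = P.Square()

    def construct(self, x, y):
        # Both branches yield the same abstract value (a float32 scalar tensor),
        # so GraphOutputCompatible holds and the two inlined branch outputs can
        # be joined with Merge + tuple_getitem inside the caller graph.
        if x < y:
            out = self.square(x)
        else:
            out = self.square(y)
        return out


single_if = SingleIfNet()
single_if(Tensor(1.0, dtype=ms.float32), Tensor(2.0, dtype=ms.float32))
```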
diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/optimizer/irpass/branch_culling.h
index 80d07625ed9..b2d6718857b 100644
--- a/mindspore/ccsrc/optimizer/irpass/branch_culling.h
+++ b/mindspore/ccsrc/optimizer/irpass/branch_culling.h
@@ -168,7 +168,8 @@ FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const
 FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
 AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
                                   const AbstractBasePtr &true_graph_output_abs,
-                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond);
+                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond,
+                                  const FuncGraphPtr &func_graph);
 }  // namespace internal
 
 // {{prim::kPrimSwitch, X, G1, G2}, Xs}
@@ -190,6 +191,20 @@ class ConvertSwitchReplacement : public AnfVisitor {
     if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) {
       return nullptr;
     }
+    // for switch replace method, only graphs without graph inside can be replaced
+    for (auto &item : g1_->value_nodes()) {
+      auto value_node = item.first;
+      if (IsValueNode<FuncGraph>(value_node)) {
+        return nullptr;
+      }
+    }
+
+    for (auto &item : g2_->value_nodes()) {
+      auto value_node = item.first;
+      if (IsValueNode<FuncGraph>(value_node)) {
+        return nullptr;
+      }
+    }
 
     auto true_output = g1_->output()->abstract();
     auto false_output = g2_->output()->abstract();
@@ -200,8 +215,8 @@ class ConvertSwitchReplacement : public AnfVisitor {
     auto fg = node->func_graph();
     auto cloned_g1 = InlineClone(trans_g1, fg, params);
     auto cloned_g2 = InlineClone(trans_g2, fg, params);
-
-    return internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_);
+    auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
+    return nnode;
   }
 
   void Visit(const AnfNodePtr &node) override {
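Reviewer note: the new scan over `g1_->value_nodes()` and `g2_->value_nodes()` makes `ConvertSwitchReplacement` bail out (return `nullptr`) whenever a branch graph carries another `FuncGraph` as a value node, i.e. when the branch itself contains control flow. A hedged sketch of a net this guard now skips, leaving the nested switch to the later control/renormalize phases; it mirrors `test_if_nested_compile` added later in this patch, and the names are illustrative.

```python
import mindspore as ms
from mindspore import context, nn, Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)


class NestedIfNet(nn.Cell):
    """The false branch contains its own `if`, i.e. a nested FuncGraph."""

    def __init__(self):
        super(NestedIfNet, self).__init__()
        self.square = P.Square()

    def construct(self, x, y):
        out = x
        if x < y:
            out = out + y
        else:
            # This inner `if` shows up as a ValueNode holding a FuncGraph inside
            # the false branch, which the new guard detects, so the outer switch
            # is not merged in place.
            if x == y:
                out = self.square(y)
            else:
                out = self.square(x)
        return out


nested_if = NestedIfNet()
nested_if(Tensor(1.0, dtype=ms.float32), Tensor(2.0, dtype=ms.float32))
```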
diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc
index f4a3a49b25f..a95c02ded63 100644
--- a/mindspore/ccsrc/pipeline/pass.cc
+++ b/mindspore/ccsrc/pipeline/pass.cc
@@ -162,7 +162,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
 }
 
 OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) {
-  opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_});
+  opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true);
   OptPassGroupMap map({
     {"control_group", control_group},
     {"renormalize", opt::OptPassConfig::Renormalize()},
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
index ddf8acc9b00..cf5fd593906 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -346,7 +346,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     if ((*value == *kAnyValue)) {
       auto value_desc = abs_base->value_desc();
       MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
-                              << " for python primitive.";
+                              << " for python primitive." << abs_base->ToString();
     }
     MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
                             << value->ToString();
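Reviewer note: with `abs_base->ToString()` appended, the TypeError now also reports the abstract value of the offending argument. Below is a hedged sketch of code that can reach this path, where a primitive argument must be a compile-time constant but is only known at run time (`kAnyValue` during inference); the exact exception type and message text depend on this MindSpore version, hence the deliberately broad except clause.

```python
import numpy as np
import mindspore as ms
from mindspore import context, nn, Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)


class VariableAxisNet(nn.Cell):
    """Feeds a run-time tensor where ReduceSum's axis must be a constant."""

    def __init__(self):
        super(VariableAxisNet, self).__init__()
        self.reduce_sum = P.ReduceSum()

    def construct(self, x, axis):
        # `axis` is a graph input, so its value is unknown (kAnyValue) when
        # ConvertAbstractToPython runs during abstract inference.
        return self.reduce_sum(x, axis)


var_axis = VariableAxisNet()
try:
    var_axis(Tensor(np.ones((2, 2)).astype(np.float32)), Tensor(0, ms.int32))
except (TypeError, ValueError, RuntimeError) as err:
    print(err)  # the message now also carries abs_base->ToString()
```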
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index 0e965b6fb31..6d41d1cb5b1 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -24,6 +24,7 @@
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
+from mindspore.common import ms_function
 
 context.set_context(mode=context.GRAPH_MODE)
@@ -371,7 +373,8 @@ def test_switch_layer():
     class Layer1(nn.Cell):
         def __init__(self):
             super(Layer1, self).__init__()
-            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+            self.z1 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
 
         def construct(self, x):
             return x * self.z1
@@ -379,7 +382,8 @@ def test_switch_layer():
     class Layer2(nn.Cell):
         def __init__(self):
             super(Layer2, self).__init__()
-            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+            self.z2 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
 
         def construct(self, x):
             return x * self.z2
@@ -388,7 +392,8 @@ def __init__(self):
             super(SwitchLayerCell, self).__init__()
             self.layers = (Layer1(), Layer2())
-            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+            self.z3 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
 
         def construct(self, index, x):
             ret = F.switch_layer(index, self.layers)(x) * self.z3
             return ret
@@ -406,7 +411,8 @@ def test_index_to_switch_layer():
     class Layer1(nn.Cell):
         def __init__(self):
             super(Layer1, self).__init__()
-            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+            self.z1 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
 
         def construct(self, x):
             return x * self.z1
@@ -414,7 +420,8 @@ def test_index_to_switch_layer():
     class Layer2(nn.Cell):
         def __init__(self):
             super(Layer2, self).__init__()
-            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+            self.z2 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
 
         def construct(self, x):
             return x * self.z2
@@ -423,7 +430,8 @@ def __init__(self):
             super(SwitchLayerCell, self).__init__()
             self.layers = (Layer1(), Layer2())
-            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+            self.z3 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
 
         def construct(self, index, x):
             ret = self.layers[index](x) * self.z3
             return ret
@@ -444,3 +452,69 @@ def test_control_depend_check():
         depend = P.ControlDepend(2)
     with pytest.raises(TypeError) as e:
         depend = P.ControlDepend((2,))
+
+
+def test_if_nested_compile():
+    class Net(nn.Cell):
+        def __init__(self, auto_prefix=True):
+            super().__init__(auto_prefix=auto_prefix)
+            self.square = P.Square()
+            self.value = Tensor(3, dtype=ms.float32)
+
+        def construct(self, x, y):
+            res = self.value
+            if x <= y:
+                res = x + res
+                res = y + res
+            else:
+                if x == y:
+                    res = self.square(self.value * y)
+                else:
+                    res = self.square(self.value)
+            return res
+
+    x = Tensor(1.0, dtype=ms.float32)
+    y = Tensor(2.0, dtype=ms.float32)
+    net = Net()
+    net(x, y)
+
+
+def test_if_inside_for():
+    class Net(nn.Cell):
+        def __init__(self, auto_prefix=True):
+            super().__init__(auto_prefix=auto_prefix)
+            self.square = P.Square()
+            self.value = Tensor(3, dtype=ms.float32)
+            self.count = 4
+
+        def construct(self, x, y):
+            res = 0
+            for i in range(self.count):
+                if i == x:
+                    res = res + x
+                else:
+                    res = res - y
+            return res
+
+    c1 = Tensor(1, dtype=ms.int32)
+    c2 = Tensor(1, dtype=ms.int32)
+    net = Net()
+    out = net(c1, c2)
+
+
+def test_while_in_while():
+    c1 = Tensor(1, dtype=ms.int32)
+    c2 = Tensor(2, dtype=ms.int32)
+    c3 = Tensor(3, dtype=ms.int32)
+    c4 = Tensor(4, dtype=ms.int32)
+
+    @ms_function
+    def while_in_while(x, y, z, u):
+        out = c4
+        while x < y:
+            z = c4 + c4
+            while z < y:
+                z = z + 1
+                out = out + 1
+            x = x + 1
+
+        out = out + 3
+        return out
+
+    while_in_while(c1, c2, c3, c4)
diff --git a/tests/ut/python/ops/test_layer_switch.py b/tests/ut/python/ops/test_layer_switch.py
new file mode 100644
index 00000000000..af0da5a39ae
--- /dev/null
+++ b/tests/ut/python/ops/test_layer_switch.py
@@ -0,0 +1,81 @@
+import numpy as np
+
+import mindspore
+from mindspore import nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class Layer1(nn.Cell):
+    def __init__(self):
+        super(Layer1, self).__init__()
+        self.net = nn.Conv2d(3, 1, 3, pad_mode='same')
+        self.pad = nn.Pad(
+            paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
+
+    def construct(self, x):
+        y = self.net(x)
+        return self.pad(y)
+
+
+class Layer2(nn.Cell):
+    def __init__(self):
+        super(Layer2, self).__init__()
+        self.net = nn.Conv2d(3, 1, 7, pad_mode='same')
+        self.pad = nn.Pad(
+            paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
+
+    def construct(self, x):
+        y = self.net(x)
+        return self.pad(y)
+
+
+class Layer3(nn.Cell):
+    def __init__(self):
+        super(Layer3, self).__init__()
+        self.net = nn.Conv2d(3, 3, 3, pad_mode='same')
+
+    def construct(self, x):
+        return self.net(x)
+
+
+class SwitchNet(nn.Cell):
+    def __init__(self):
+        super(SwitchNet, self).__init__()
+        self.layer1 = Layer1()
+        self.layer2 = Layer2()
+        self.layer3 = Layer3()
+        self.layers = (self.layer1, self.layer2, self.layer3)
+        self.fill = P.Fill()
+
+    def construct(self, x, index):
+        y = self.layers[index](x)
+        return y
+
+
+class MySwitchNet(nn.Cell):
+    def __init__(self):
+        super(MySwitchNet, self).__init__()
+        self.layer1 = Layer1()
+        self.layer2 = Layer2()
+        self.layer3 = Layer3()
+        self.layers = (self.layer1, self.layer2, self.layer3)
+        self.fill = P.Fill()
+
+    def construct(self, x, index):
+        y = self.layers[0](x)
+        for i in range(len(self.layers)):
+            if i == index:
+                y = self.layers[i](x)
+        return y
+
+
+def test_layer_switch():
+    net = MySwitchNet()
+    x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
+    index = Tensor(0, dtype=mindspore.int32)
+    y = net(x, index)
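Reviewer note: the new test file defines both `MySwitchNet` (a loop plus `if` over the layer tuple) and `SwitchNet` (direct tensor indexing), but only exercises the former. A hypothetical companion test for the direct-index form is sketched below; it assumes the `SwitchNet` class from `test_layer_switch.py` above and is not part of this patch.

```python
import numpy as np
import mindspore
from mindspore import Tensor


def test_layer_switch_direct_index():
    # SwitchNet is the class defined in test_layer_switch.py above; indexing
    # self.layers with an int32 tensor index lowers to a layer switch the same
    # way MySwitchNet's for/if loop does.
    net = SwitchNet()
    x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
    index = Tensor(1, dtype=mindspore.int32)
    y = net(x, index)
```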