diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h
index 9286379a759..c337aad857d 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h
@@ -462,7 +462,8 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
     std::vector<FuncGraphPtr> graphs{};
     auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
     auto &graphs_inputs = graphs_cnode->inputs();
-    if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(graphs_inputs[1])) {
+    if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= 2 &&
+        IsValueNode<FuncGraph>(graphs_inputs[1])) {
       (void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
                            [](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
     }
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
index a4a2f494e60..9bfa3669406 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
@@ -89,6 +89,7 @@ class GetItemTransformACrossGraph {
       ss << idx;

       auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
+      fg->manager()->AddFuncGraph(new_fg_outer);
       auto output_outer = new_fg_outer->output();
       if (!IsValueNode<FuncGraph>(output_outer)) {
         MS_LOG(WARNING) << "Output of outer graph should be a func_graph";
@@ -486,7 +487,7 @@ class IncorporateGetitemSwitchLayerA : public AnfVisitor {
       switch_layer_ = inputs[0];
       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));
     }
-    if (is_in_switch_ && cnode->size() > 2) {
+    if (is_in_switch_ && cnode->size() >= 2) {
       auto &inputs = cnode->inputs();
       if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
         (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
@@ -578,7 +579,7 @@ class IncorporateGetitemSwitchLayerB : public AnfVisitor {
       switch_layer_call_ = inputs[0];
       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outer_call_args_));
     }
-    if (is_in_switch_ && cnode->size() > 2) {
+    if (is_in_switch_ && cnode->size() >= 2) {
       auto &inputs = cnode->inputs();
       if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
         (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h
index f355a54b86f..e360e016cc8 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h
@@ -36,10 +36,9 @@ class SwitchLayerDeferInline : public AnfVisitor {
     auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
     for (auto elem : tuple->elements()) {
       auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
-      if (abstract == nullptr) {
-        return nullptr;
+      if (abstract != nullptr) {
+        *(abstract->func_graph()->switch_layer_input()) = true;
       }
-      *(abstract->func_graph()->switch_layer_input()) = true;
     }
     return nullptr;
   }
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 966275b964f..113545491f8 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -137,6 +137,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.arithmetic_simplify2_,
     irpass.same_eliminate_,
     irpass.check_bprop_eliminate_,
+    irpass.switch_layer_defer_inline_,
     irpass.replace_applicator_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
diff --git a/mindspore/core/abstract/prim_statement.cc b/mindspore/core/abstract/prim_statement.cc
index 24a95709f9b..1cc077d300a 100644
--- a/mindspore/core/abstract/prim_statement.cc
+++ b/mindspore/core/abstract/prim_statement.cc
@@ -16,6 +16,7 @@
 #include "abstract/param_validator.h"
 #include "abstract/infer_functions.h"
+#include "abstract/abstract_function.h"
 #include "abstract/utils.h"
 #include "utils/symbolic.h"
@@ -121,12 +122,18 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
   for (size_t i = 0; i < branches.size(); i++) {
     MS_EXCEPTION_IF_NULL(branches[i]);
     if (!branches[i]->isa<FuncGraphAbstractClosure>()) {
-      MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got "
-                        << branches[i]->ToString() << " as the " << i << "th element.";
+      MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
+                               << branches[i]->ToString() << " as the " << i << "th element.";
     }
   }

   auto b = branches[0];
+  // Return AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0]
+  // which will cancel the out of bound checking for index
+  if (branches.size() == 1) {
+    AbstractFuncAtomPtrList func_list{b->cast<AbstractFuncAtomPtr>()};
+    return std::make_shared<AbstractFuncUnion>(func_list);
+  }
   for (size_t i = 1; i < branches.size(); i++) {
     b = b->Join(branches[i]);
   }
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index 85cc97f4c14..6e8467b700d 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -444,6 +444,86 @@ def test_index_to_switch_layer():
     C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))


+def test_parser_switch_layer_switch_in_bprop():
+    class OneInputBprop(nn.Cell):
+        def __init__(self, funcs):
+            super(OneInputBprop, self).__init__()
+            self.op = P.ReLU()
+            self.funcs = funcs
+        def construct(self, i, x):
+            return self.op(x)
+        def bprop(self, i, x, out, dout):
+            return i, self.funcs[i](x, dout)
+
+    class Add(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.TensorAdd()
+
+        def construct(self, x, y):
+            return self.op(x, y)
+
+    class Mul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.Mul()
+
+        def construct(self, x, y):
+            return self.op(x, y)
+    func1 = Add()
+    func2 = Mul()
+    funcs = (func1, func2)
+    net = OneInputBprop(funcs)
+    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
+    grad = Tensor(np.random.randn(2, 2).astype(np.float32))
+    i = Tensor(1, mstype.int32)
+    grad_net = C.grad_all_with_sens(net)
+    grad_net(i, input1, grad)
+
+
+def test_parser_switch_layer_inputs_tuple():
+    class TwoInputTupleFinalNet(nn.Cell):
+        def __init__(self, funcs):
+            super().__init__()
+            self.funcs = funcs
+
+        def construct(self, i, inputa, inputb):
+            inputs = (inputa, inputb)
+            x = self.funcs[i](inputs)
+            return x
+
+    class Add(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.TensorAdd()
+
+        def construct(self, x):
+            y = self.op(x[0], x[1])
+            return self.op(x[0], y)
+
+    class Mul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.Mul()
+
+        def construct(self, x):
+            y = self.op(x[0], x[1])
+            return self.op(x[0], y)
+
+    func1 = Add()
+    func2 = Mul()
+
+    funcs = (func1, func2)
+    net = TwoInputTupleFinalNet(funcs)
+
+    input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    i = Tensor(1, mstype.int32)
+    grad = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    back_net = C.grad_all_with_sens(net)
+    back_out = back_net(i, input1, input2, grad)
+
+
 def test_switch_layer_with_single_prim():
     class SwitchLayerCell(nn.Cell):
         def __init__(self):
@@ -494,6 +574,35 @@ def test_switch_layer_env_eliminate():
     net2(x, i)


+def test_switch_layer_single_layer():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
+            self.funs = (self.conv,)
+
+        def construct(self, x, index):
+            x = self.funs[index](x)
+            return x
+
+    class NetGrad(nn.Cell):
+        def __init__(self, net):
+            super(NetGrad, self).__init__()
+            self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False)
+            self.net = net
+            self.weights = ParameterTuple(self.net.trainable_params())
+
+        def construct(self, x, index):
+            weights = self.weights
+            grad = self.grad_op(self.net, weights)(x, index)
+            return grad
+    net = Net()
+    net2 = NetGrad(net)
+    x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
+    i = Tensor(1, ms.int32)
+    net2(x, i)
+
+
 def test_control_depend_check():
     with pytest.raises(TypeError) as e:
         P.ControlDepend(0.0)