diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py index 37228788767..d894a4d488d 100644 --- a/mindspore/_extends/builtin_operations.py +++ b/mindspore/_extends/builtin_operations.py @@ -92,7 +92,7 @@ def zeros_like_tensor(x): return value -def switch(c, x, y): +def Switch(c, x, y): """Implement `switch`.""" return x if c else y diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 95317b6dba7..0efd4341e5a 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -47,7 +47,7 @@ namespace abstract { using mindspore::parse::PyObjectWrapper; std::unordered_set prims_to_skip_undetermined_infer{ - "MakeTuple", "make_list", "switch", "env_setitem", "env_getitem", "Load", "UpdateState"}; + "MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"}; EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 5ad0c97e776..1d6008c6a87 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -184,7 +184,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto"; constexpr auto kBNInferGradOpName = "BNInferGrad"; constexpr auto kCallOpName = "call"; constexpr auto kPartialOpName = "partial"; -constexpr auto kSwitchOpName = "switch"; +constexpr auto kSwitchOpName = "Switch"; constexpr auto kReturnOpName = "Return"; constexpr auto kLarsV2OpName = "LarsV2"; constexpr auto kLarsV2UpdateOpName = "LarsV2Update"; diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h index db7850faa99..29ba65d1304 100644 --- a/mindspore/ccsrc/vm/vm.h +++ b/mindspore/ccsrc/vm/vm.h @@ -63,7 +63,7 @@ using InstType = std::pair; using InstSet = std::vector; using InstFunctionMap = std::map>; -const std::vector inst_str{"call", "tail_call", "Return", "partial", "switch", +const std::vector inst_str{"call", "tail_call", "Return", "partial", "Switch", "switch_return", "tuple", "input", "external", "push", "primitive", "graph", "pad_stack", "switch_layer"}; class StructPartial : public Base { diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 56dfbf235e5..f7627d20e68 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -403,7 +403,7 @@ inline const PrimitivePtr kPrimWhere = std::make_shared("Where"); // Statements inline const PrimitivePtr kPrimReturn = std::make_shared("Return"); -inline const PrimitivePtr kPrimSwitch = std::make_shared("switch"); +inline const PrimitivePtr kPrimSwitch = std::make_shared("Switch"); inline const PrimitivePtr kPrimSwitchLayer = std::make_shared("switch_layer"); inline const PrimitivePtr kPrimAssign = std::make_shared("Assign"); inline const PrimitivePtr kPrimAssignAdd = std::make_shared("AssignAdd"); diff --git a/mindspore/ops/_grad/grad_implementations.py b/mindspore/ops/_grad/grad_implementations.py index cf9d88c25b1..c72b34d199f 100644 --- a/mindspore/ops/_grad/grad_implementations.py +++ b/mindspore/ops/_grad/grad_implementations.py @@ -267,7 +267,7 @@ def bprop_control_depend(x, y, out, dout): return C.zeros_like(x), C.zeros_like(y) -@bprops.register("switch") +@bprops.register("Switch") def bprop_switch(cond, tb, fb, out, dout): """Backpropagator for primitive `switch`.""" return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \ diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index c3044295748..5051936d5ff 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -181,7 +181,7 @@ env_setitem = Primitive('env_setitem') env_getitem = Primitive('env_getitem') env_add = Primitive('env_add') J = Primitive('J') -switch = Primitive('switch') +switch = Primitive('Switch') switch_layer = Primitive('switch_layer') # for sum bprop reduced_shape = Primitive("reduced_shape") diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index 11984d305bf..43b5e94ae30 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -310,7 +310,7 @@ TEST_F(TestOps, Col2ImV1Test) { // Statements TEST_F(TestOps, SwitchTest) { - auto prim = std::make_shared("switch"); + auto prim = std::make_shared("Switch"); ASSERT_EQ(prim->name(), kPrimSwitch->name()); } diff --git a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc index 8470c5b81ba..7d8274a2671 100644 --- a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc @@ -294,7 +294,7 @@ TEST_F(TestPrim, test_J_2) { // tail half TEST_F(TestPrim, test_switch1) { - PrimitivePtr switch_ = std::make_shared("switch"); + PrimitivePtr switch_ = std::make_shared("Switch"); FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3); AbstractBasePtr arg0 = FromValue(true, false); @@ -307,7 +307,7 @@ TEST_F(TestPrim, test_switch1) { } TEST_F(TestPrim, test_switch2) { - PrimitivePtr switch_ = std::make_shared("switch"); + PrimitivePtr switch_ = std::make_shared("Switch"); FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3); AbstractBasePtr arg0 = FromValue(false, false); diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py index 29ea3b43243..c2c773c2e5f 100644 --- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py @@ -30,7 +30,7 @@ from mindspore.ops.operations import _grad_ops as G scalar_add = Primitive(Constants.kScalarAdd) scalar_mul = Primitive(Constants.kScalarMul) tuple_getitem = Primitive(Constants.kTupleGetItem) -switch = Primitive('switch') +switch = Primitive('Switch') def test_sexp_conversion():