forked from mindspore-Ecosystem/mindspore
Change switch to Switch
This commit is contained in:
parent
b730069942
commit
bbdb050fc7
|
@ -92,7 +92,7 @@ def zeros_like_tensor(x):
|
|||
return value
|
||||
|
||||
|
||||
def switch(c, x, y):
|
||||
def Switch(c, x, y):
|
||||
"""Implement `switch`."""
|
||||
return x if c else y
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ namespace abstract {
|
|||
using mindspore::parse::PyObjectWrapper;
|
||||
|
||||
std::unordered_set<std::string> prims_to_skip_undetermined_infer{
|
||||
"MakeTuple", "make_list", "switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
|
||||
"MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
|
||||
|
||||
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr out_conf) {
|
||||
|
|
|
@ -184,7 +184,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto";
|
|||
constexpr auto kBNInferGradOpName = "BNInferGrad";
|
||||
constexpr auto kCallOpName = "call";
|
||||
constexpr auto kPartialOpName = "partial";
|
||||
constexpr auto kSwitchOpName = "switch";
|
||||
constexpr auto kSwitchOpName = "Switch";
|
||||
constexpr auto kReturnOpName = "Return";
|
||||
constexpr auto kLarsV2OpName = "LarsV2";
|
||||
constexpr auto kLarsV2UpdateOpName = "LarsV2Update";
|
||||
|
|
|
@ -63,7 +63,7 @@ using InstType = std::pair<Instruction, VectorRef>;
|
|||
using InstSet = std::vector<InstType>;
|
||||
using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>;
|
||||
|
||||
const std::vector<std::string> inst_str{"call", "tail_call", "Return", "partial", "switch",
|
||||
const std::vector<std::string> inst_str{"call", "tail_call", "Return", "partial", "Switch",
|
||||
"switch_return", "tuple", "input", "external", "push",
|
||||
"primitive", "graph", "pad_stack", "switch_layer"};
|
||||
class StructPartial : public Base {
|
||||
|
|
|
@ -403,7 +403,7 @@ inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
|
|||
|
||||
// Statements
|
||||
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
|
||||
inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
|
||||
inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("Switch");
|
||||
inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
|
||||
inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
|
||||
inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
|
||||
|
|
|
@ -267,7 +267,7 @@ def bprop_control_depend(x, y, out, dout):
|
|||
return C.zeros_like(x), C.zeros_like(y)
|
||||
|
||||
|
||||
@bprops.register("switch")
|
||||
@bprops.register("Switch")
|
||||
def bprop_switch(cond, tb, fb, out, dout):
|
||||
"""Backpropagator for primitive `switch`."""
|
||||
return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
|
||||
|
|
|
@ -181,7 +181,7 @@ env_setitem = Primitive('env_setitem')
|
|||
env_getitem = Primitive('env_getitem')
|
||||
env_add = Primitive('env_add')
|
||||
J = Primitive('J')
|
||||
switch = Primitive('switch')
|
||||
switch = Primitive('Switch')
|
||||
switch_layer = Primitive('switch_layer')
|
||||
# for sum bprop
|
||||
reduced_shape = Primitive("reduced_shape")
|
||||
|
|
|
@ -310,7 +310,7 @@ TEST_F(TestOps, Col2ImV1Test) {
|
|||
|
||||
// Statements
|
||||
TEST_F(TestOps, SwitchTest) {
|
||||
auto prim = std::make_shared<Primitive>("switch");
|
||||
auto prim = std::make_shared<Primitive>("Switch");
|
||||
ASSERT_EQ(prim->name(), kPrimSwitch->name());
|
||||
}
|
||||
|
||||
|
|
|
@ -294,7 +294,7 @@ TEST_F(TestPrim, test_J_2) {
|
|||
|
||||
// tail half
|
||||
TEST_F(TestPrim, test_switch1) {
|
||||
PrimitivePtr switch_ = std::make_shared<Primitive>("switch");
|
||||
PrimitivePtr switch_ = std::make_shared<Primitive>("Switch");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
|
||||
|
||||
AbstractBasePtr arg0 = FromValue(true, false);
|
||||
|
@ -307,7 +307,7 @@ TEST_F(TestPrim, test_switch1) {
|
|||
}
|
||||
|
||||
TEST_F(TestPrim, test_switch2) {
|
||||
PrimitivePtr switch_ = std::make_shared<Primitive>("switch");
|
||||
PrimitivePtr switch_ = std::make_shared<Primitive>("Switch");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
|
||||
|
||||
AbstractBasePtr arg0 = FromValue(false, false);
|
||||
|
|
|
@ -30,7 +30,7 @@ from mindspore.ops.operations import _grad_ops as G
|
|||
scalar_add = Primitive(Constants.kScalarAdd)
|
||||
scalar_mul = Primitive(Constants.kScalarMul)
|
||||
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||
switch = Primitive('switch')
|
||||
switch = Primitive('Switch')
|
||||
|
||||
|
||||
def test_sexp_conversion():
|
||||
|
|
Loading…
Reference in New Issue