Change switch to Switch

This commit is contained in:
l00591931 2021-03-11 17:06:55 +08:00
parent b730069942
commit bbdb050fc7
10 changed files with 11 additions and 11 deletions

View File

@ -92,7 +92,7 @@ def zeros_like_tensor(x):
return value
def switch(c, x, y):
def Switch(c, x, y):
"""Implement `switch`."""
return x if c else y

View File

@ -47,7 +47,7 @@ namespace abstract {
using mindspore::parse::PyObjectWrapper;
std::unordered_set<std::string> prims_to_skip_undetermined_infer{
"MakeTuple", "make_list", "switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
"MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {

View File

@ -184,7 +184,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto";
constexpr auto kBNInferGradOpName = "BNInferGrad";
constexpr auto kCallOpName = "call";
constexpr auto kPartialOpName = "partial";
constexpr auto kSwitchOpName = "switch";
constexpr auto kSwitchOpName = "Switch";
constexpr auto kReturnOpName = "Return";
constexpr auto kLarsV2OpName = "LarsV2";
constexpr auto kLarsV2UpdateOpName = "LarsV2Update";

View File

@ -63,7 +63,7 @@ using InstType = std::pair<Instruction, VectorRef>;
using InstSet = std::vector<InstType>;
using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>;
const std::vector<std::string> inst_str{"call", "tail_call", "Return", "partial", "switch",
const std::vector<std::string> inst_str{"call", "tail_call", "Return", "partial", "Switch",
"switch_return", "tuple", "input", "external", "push",
"primitive", "graph", "pad_stack", "switch_layer"};
class StructPartial : public Base {

View File

@ -403,7 +403,7 @@ inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
// Statements
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("Switch");
inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");

View File

@ -267,7 +267,7 @@ def bprop_control_depend(x, y, out, dout):
return C.zeros_like(x), C.zeros_like(y)
@bprops.register("switch")
@bprops.register("Switch")
def bprop_switch(cond, tb, fb, out, dout):
"""Backpropagator for primitive `switch`."""
return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \

View File

@ -181,7 +181,7 @@ env_setitem = Primitive('env_setitem')
env_getitem = Primitive('env_getitem')
env_add = Primitive('env_add')
J = Primitive('J')
switch = Primitive('switch')
switch = Primitive('Switch')
switch_layer = Primitive('switch_layer')
# for sum bprop
reduced_shape = Primitive("reduced_shape")

View File

@ -310,7 +310,7 @@ TEST_F(TestOps, Col2ImV1Test) {
// Statements
TEST_F(TestOps, SwitchTest) {
auto prim = std::make_shared<Primitive>("switch");
auto prim = std::make_shared<Primitive>("Switch");
ASSERT_EQ(prim->name(), kPrimSwitch->name());
}

View File

@ -294,7 +294,7 @@ TEST_F(TestPrim, test_J_2) {
// tail half
TEST_F(TestPrim, test_switch1) {
PrimitivePtr switch_ = std::make_shared<Primitive>("switch");
PrimitivePtr switch_ = std::make_shared<Primitive>("Switch");
FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
AbstractBasePtr arg0 = FromValue(true, false);
@ -307,7 +307,7 @@ TEST_F(TestPrim, test_switch1) {
}
TEST_F(TestPrim, test_switch2) {
PrimitivePtr switch_ = std::make_shared<Primitive>("switch");
PrimitivePtr switch_ = std::make_shared<Primitive>("Switch");
FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3);
AbstractBasePtr arg0 = FromValue(false, false);

View File

@ -30,7 +30,7 @@ from mindspore.ops.operations import _grad_ops as G
scalar_add = Primitive(Constants.kScalarAdd)
scalar_mul = Primitive(Constants.kScalarMul)
tuple_getitem = Primitive(Constants.kTupleGetItem)
switch = Primitive('switch')
switch = Primitive('Switch')
def test_sexp_conversion():