add func type check for switch layer

This commit is contained in:
panyifeng 2020-08-21 15:24:47 +08:00
parent 492e41a4af
commit abab21ed57
3 changed files with 30 additions and 14 deletions

View File

@ -989,19 +989,13 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar
// args: tuple of items, index
const std::string op_name = std::string("TupleGetItemTensor");
abstract::CheckArgsSize(op_name, args_spec_list, 2);
AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
AbstractBasePtrList branches = branches_abs->elements();
if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
auto ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr functions = ret_graph->add_parameter();
auto functions = ret_graph->add_parameter();
auto index = ret_graph->add_parameter();
ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
return ret_graph;
}
MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << ".";
}
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {

View File

@ -114,14 +114,14 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
AbstractBasePtrList branches = branches_abs->elements();
const size_t maximum_layer_num = 1000;
if (branches.size() < 0 || branches.size() > maximum_layer_num) {
if (branches.size() < 1 || branches.size() > maximum_layer_num) {
MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
<< branches.size() << " branches.";
}
for (size_t i = 0; i < branches.size(); i++) {
MS_EXCEPTION_IF_NULL(branches[i]);
if (!branches[i]->isa<AbstractFunction>()) {
if (!branches[i]->isa<FuncGraphAbstractClosure>()) {
MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
<< branches[i]->ToString() << " as the " << i << "th element.";
}

View File

@ -851,3 +851,25 @@ def test_tensor_all_construct_lack_branch():
net = NetConditionLackBranch()
with pytest.raises(Exception):
net(input_tensor_1, input_tensor_2)
def test_parser_switch_layer_func_primitive():
class FinalNet(nn.Cell):
def __init__(self, funcs):
super().__init__()
self.funcs = funcs
def construct(self, i, input1):
x = self.funcs[i](input1)
return x
func1 = P.ReLU()
func2 = P.Softmax()
funcs = (func1, func2)
net = FinalNet(funcs)
input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
i = Tensor(1, mstype.int32)
with pytest.raises(ValueError):
net(i, input1)