forked from mindspore-Ecosystem/mindspore
add func type check for switch layer
This commit is contained in:
parent
492e41a4af
commit
abab21ed57
|
@ -989,19 +989,13 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar
|
|||
// args: tuple of items, index
|
||||
const std::string op_name = std::string("TupleGetItemTensor");
|
||||
abstract::CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
AbstractBasePtrList branches = branches_abs->elements();
|
||||
if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
|
||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||
auto ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
AnfNodePtr functions = ret_graph->add_parameter();
|
||||
auto functions = ret_graph->add_parameter();
|
||||
auto index = ret_graph->add_parameter();
|
||||
|
||||
ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
|
||||
return ret_graph;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << ".";
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
|
||||
|
|
|
@ -114,14 +114,14 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
|
|||
AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
AbstractBasePtrList branches = branches_abs->elements();
|
||||
const size_t maximum_layer_num = 1000;
|
||||
if (branches.size() < 0 || branches.size() > maximum_layer_num) {
|
||||
if (branches.size() < 1 || branches.size() > maximum_layer_num) {
|
||||
MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
|
||||
<< branches.size() << " branches.";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < branches.size(); i++) {
|
||||
MS_EXCEPTION_IF_NULL(branches[i]);
|
||||
if (!branches[i]->isa<AbstractFunction>()) {
|
||||
if (!branches[i]->isa<FuncGraphAbstractClosure>()) {
|
||||
MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
|
||||
<< branches[i]->ToString() << " as the " << i << "th element.";
|
||||
}
|
||||
|
|
|
@ -851,3 +851,25 @@ def test_tensor_all_construct_lack_branch():
|
|||
net = NetConditionLackBranch()
|
||||
with pytest.raises(Exception):
|
||||
net(input_tensor_1, input_tensor_2)
|
||||
|
||||
|
||||
def test_parser_switch_layer_func_primitive():
|
||||
class FinalNet(nn.Cell):
|
||||
def __init__(self, funcs):
|
||||
super().__init__()
|
||||
self.funcs = funcs
|
||||
|
||||
def construct(self, i, input1):
|
||||
x = self.funcs[i](input1)
|
||||
return x
|
||||
|
||||
func1 = P.ReLU()
|
||||
func2 = P.Softmax()
|
||||
funcs = (func1, func2)
|
||||
net = FinalNet(funcs)
|
||||
|
||||
input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
|
||||
i = Tensor(1, mstype.int32)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
net(i, input1)
|
||||
|
|
Loading…
Reference in New Issue