support None argument for the outermost net

This commit is contained in:
huanghui 2021-10-11 14:19:53 +08:00
parent 2c738757c3
commit 1af32f74f9
3 changed files with 13 additions and 13 deletions

View File

@ -144,8 +144,8 @@ bool CheckArgValid(const py::handle &arg) {
return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
}
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<Number>(arg) ||
(py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
py::isinstance<Number>(arg) || (py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
}
std::string GetCompileExceptionInfo() {
@ -235,7 +235,7 @@ void CheckArgsValid(const py::tuple &args) {
for (size_t i = 0; i < args.size(); i++) {
if (!CheckArgValid(args[i])) {
MS_EXCEPTION(TypeError)
<< "The inputs types of the outermost network support bool, int, float, tensor, "
<< "The inputs types of the outermost network support bool, int, float, None, tensor, "
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), "
"and tuple or list containing only these types, and dict whose values are these types, but the "
<< i << "th arg type is " << args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";

View File

@ -95,7 +95,7 @@ def test_grad_first_input_net():
def test_net_inputs_including_str():
with pytest.raises(TypeError) as err:
grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'str'>, value is 'ok'" in str(err.value)
@ -104,7 +104,7 @@ def test_net_inputs_including_str():
def test_outermost_net_pass_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'mindspore.common.parameter.ParameterTensor'>, " \
@ -115,7 +115,7 @@ def test_outermost_net_pass_parameter():
def test_outermost_net_pass_tuple_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p))
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 6th arg type is <class 'tuple'>, value is '(" in str(err.value)
@ -124,7 +124,7 @@ def test_outermost_net_pass_tuple_including_parameter():
def test_outermost_net_pass_list_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 4th arg type is <class 'list'>, value is '[" in str(err.value)
@ -133,7 +133,7 @@ def test_outermost_net_pass_list_including_parameter():
def test_grad_net_pass_dict_including_parameter():
with pytest.raises(TypeError) as err:
grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 3th arg type is <class 'dict'>, value is '{" in str(err.value)

View File

@ -95,7 +95,7 @@ def test_grad_first_input_net():
def test_net_inputs_including_str():
with pytest.raises(TypeError) as err:
grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'str'>, value is 'ok'" in str(err.value)
@ -104,7 +104,7 @@ def test_net_inputs_including_str():
def test_outermost_net_pass_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'mindspore.common.parameter.ParameterTensor'>, " \
@ -115,7 +115,7 @@ def test_outermost_net_pass_parameter():
def test_outermost_net_pass_tuple_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p))
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 6th arg type is <class 'tuple'>, value is '(" in str(err.value)
@ -124,7 +124,7 @@ def test_outermost_net_pass_tuple_including_parameter():
def test_outermost_net_pass_list_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 4th arg type is <class 'list'>, value is '[" in str(err.value)
@ -133,7 +133,7 @@ def test_outermost_net_pass_list_including_parameter():
def test_grad_net_pass_dict_including_parameter():
with pytest.raises(TypeError) as err:
grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 3th arg type is <class 'dict'>, value is '{" in str(err.value)