forked from mindspore-Ecosystem/mindspore
!24656 Support None argument for the outermost net
Merge pull request !24656 from huanghui/support-none-arg
This commit is contained in:
commit
6250998108
|
@ -144,8 +144,8 @@ bool CheckArgValid(const py::handle &arg) {
|
|||
return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
|
||||
}
|
||||
|
||||
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<Number>(arg) ||
|
||||
(py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
|
||||
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
|
||||
py::isinstance<Number>(arg) || (py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
|
||||
}
|
||||
|
||||
std::string GetCompileExceptionInfo() {
|
||||
|
@ -245,7 +245,7 @@ void CheckArgsValid(const py::tuple &args) {
|
|||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (!CheckArgValid(args[i])) {
|
||||
MS_EXCEPTION(TypeError)
|
||||
<< "The inputs types of the outermost network support bool, int, float, tensor, "
|
||||
<< "The inputs types of the outermost network support bool, int, float, None, tensor, "
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), "
|
||||
"and tuple or list containing only these types, and dict whose values are these types, but the "
|
||||
<< i << "th arg type is " << args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
|
||||
|
|
|
@ -95,7 +95,7 @@ def test_grad_first_input_net():
|
|||
def test_net_inputs_including_str():
|
||||
with pytest.raises(TypeError) as err:
|
||||
grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 1th arg type is <class 'str'>, value is 'ok'" in str(err.value)
|
||||
|
@ -104,7 +104,7 @@ def test_net_inputs_including_str():
|
|||
def test_outermost_net_pass_parameter():
|
||||
with pytest.raises(TypeError) as err:
|
||||
forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0)
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 1th arg type is <class 'mindspore.common.parameter.ParameterTensor'>, " \
|
||||
|
@ -115,7 +115,7 @@ def test_outermost_net_pass_parameter():
|
|||
def test_outermost_net_pass_tuple_including_parameter():
|
||||
with pytest.raises(TypeError) as err:
|
||||
forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p))
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 6th arg type is <class 'tuple'>, value is '(" in str(err.value)
|
||||
|
@ -124,7 +124,7 @@ def test_outermost_net_pass_tuple_including_parameter():
|
|||
def test_outermost_net_pass_list_including_parameter():
|
||||
with pytest.raises(TypeError) as err:
|
||||
forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0)
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 4th arg type is <class 'list'>, value is '[" in str(err.value)
|
||||
|
@ -133,7 +133,7 @@ def test_outermost_net_pass_list_including_parameter():
|
|||
def test_grad_net_pass_dict_including_parameter():
|
||||
with pytest.raises(TypeError) as err:
|
||||
grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 3th arg type is <class 'dict'>, value is '{" in str(err.value)
|
||||
|
|
|
@ -95,7 +95,7 @@ def test_grad_first_input_net():
|
|||
def test_net_inputs_including_str():
|
||||
with pytest.raises(TypeError) as err:
|
||||
grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 1th arg type is <class 'str'>, value is 'ok'" in str(err.value)
|
||||
|
@ -104,7 +104,7 @@ def test_net_inputs_including_str():
|
|||
def test_outermost_net_pass_parameter():
|
||||
with pytest.raises(TypeError) as err:
|
||||
forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0)
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 1th arg type is <class 'mindspore.common.parameter.ParameterTensor'>, " \
|
||||
|
@ -115,7 +115,7 @@ def test_outermost_net_pass_parameter():
|
|||
def test_outermost_net_pass_tuple_including_parameter():
|
||||
with pytest.raises(TypeError) as err:
|
||||
forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p))
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 6th arg type is <class 'tuple'>, value is '(" in str(err.value)
|
||||
|
@ -124,7 +124,7 @@ def test_outermost_net_pass_tuple_including_parameter():
|
|||
def test_outermost_net_pass_list_including_parameter():
|
||||
with pytest.raises(TypeError) as err:
|
||||
forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0)
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 4th arg type is <class 'list'>, value is '[" in str(err.value)
|
||||
|
@ -133,7 +133,7 @@ def test_outermost_net_pass_list_including_parameter():
|
|||
def test_grad_net_pass_dict_including_parameter():
|
||||
with pytest.raises(TypeError) as err:
|
||||
grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
|
||||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but the 3th arg type is <class 'dict'>, value is '{" in str(err.value)
|
||||
|
|
Loading…
Reference in New Issue