From 1af32f74f93883d1e28989c121a2c04c036725d1 Mon Sep 17 00:00:00 2001 From: huanghui Date: Mon, 11 Oct 2021 14:19:53 +0800 Subject: [PATCH] support None argument for the outermost net --- mindspore/ccsrc/pipeline/jit/pipeline.cc | 6 +++--- .../parse/test_outermost_net_pass_non_tensor_inputs.py | 10 +++++----- .../pynative_mode/test_outermost_non_tensor_input.py | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 28f12c21b42..99190a6e9cd 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -144,8 +144,8 @@ bool CheckArgValid(const py::handle &arg) { return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); }); } - return py::isinstance(arg) || py::isinstance(arg) || py::isinstance(arg) || - (py::isinstance(arg) && !py::hasattr(arg, "__parameter__")); + return py::isinstance(arg) || py::isinstance(arg) || py::isinstance(arg) || + py::isinstance(arg) || (py::isinstance(arg) && !py::hasattr(arg, "__parameter__")); } std::string GetCompileExceptionInfo() { @@ -235,7 +235,7 @@ void CheckArgsValid(const py::tuple &args) { for (size_t i = 0; i < args.size(); i++) { if (!CheckArgValid(args[i])) { MS_EXCEPTION(TypeError) - << "The inputs types of the outermost network support bool, int, float, tensor, " + << "The inputs types of the outermost network support bool, int, float, None, tensor, " "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " "and tuple or list containing only these types, and dict whose values are these types, but the " << i << "th arg type is " << args[i].get_type() << ", value is '" << py::str(args[i]) << "'."; diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py index 8ff8e4b25e2..368914a0b34 100644 --- a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py +++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py @@ -95,7 +95,7 @@ def test_grad_first_input_net(): def test_net_inputs_including_str(): with pytest.raises(TypeError) as err: grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 1th arg type is , value is 'ok'" in str(err.value) @@ -104,7 +104,7 @@ def test_net_inputs_including_str(): def test_outermost_net_pass_parameter(): with pytest.raises(TypeError) as err: forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 1th arg type is , " \ @@ -115,7 +115,7 @@ def test_outermost_net_pass_parameter(): def test_outermost_net_pass_tuple_including_parameter(): with pytest.raises(TypeError) as err: forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p)) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 6th arg type is , value is '(" in str(err.value) @@ -124,7 +124,7 @@ def test_outermost_net_pass_tuple_including_parameter(): def test_outermost_net_pass_list_including_parameter(): with pytest.raises(TypeError) as err: forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 4th arg type is , value is '[" in str(err.value) @@ -133,7 +133,7 @@ def test_outermost_net_pass_list_including_parameter(): def test_grad_net_pass_dict_including_parameter(): with pytest.raises(TypeError) as err: grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 3th arg type is , value is '{" in str(err.value) diff --git a/tests/ut/python/pynative_mode/test_outermost_non_tensor_input.py b/tests/ut/python/pynative_mode/test_outermost_non_tensor_input.py index e3303cec6ed..d56687e1308 100644 --- a/tests/ut/python/pynative_mode/test_outermost_non_tensor_input.py +++ b/tests/ut/python/pynative_mode/test_outermost_non_tensor_input.py @@ -95,7 +95,7 @@ def test_grad_first_input_net(): def test_net_inputs_including_str(): with pytest.raises(TypeError) as err: grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 1th arg type is , value is 'ok'" in str(err.value) @@ -104,7 +104,7 @@ def test_net_inputs_including_str(): def test_outermost_net_pass_parameter(): with pytest.raises(TypeError) as err: forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 1th arg type is , " \ @@ -115,7 +115,7 @@ def test_outermost_net_pass_parameter(): def test_outermost_net_pass_tuple_including_parameter(): with pytest.raises(TypeError) as err: forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p)) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 6th arg type is , value is '(" in str(err.value) @@ -124,7 +124,7 @@ def test_outermost_net_pass_tuple_including_parameter(): def test_outermost_net_pass_list_including_parameter(): with pytest.raises(TypeError) as err: forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 4th arg type is , value is '[" in str(err.value) @@ -133,7 +133,7 @@ def test_outermost_net_pass_list_including_parameter(): def test_grad_net_pass_dict_including_parameter(): with pytest.raises(TypeError) as err: grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0) - assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ "but the 3th arg type is , value is '{" in str(err.value)