!11348 add check for outermost net inputs type and support isinstance first arg is an empty list

From: @zhangbuxue
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-20 19:00:08 +08:00 committed by Gitee
commit 55594db01c
9 changed files with 159 additions and 61 deletions

View File

@ -42,6 +42,7 @@ itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,
mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4, mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4,
mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8} mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8}
def mean(x, axis=(), keep_dims=False): def mean(x, axis=(), keep_dims=False):
""" """
Reduces a dimension of a tensor by averaging all elements in the dimension. Reduces a dimension of a tensor by averaging all elements in the dimension.
@ -218,11 +219,11 @@ def swapaxes(x, axis1, axis2):
perm = F.make_range(0, x.ndim) perm = F.make_range(0, x.ndim)
new_perm = None new_perm = None
if axis2 + 1 < x.ndim: if axis2 + 1 < x.ndim:
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:] perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
else: else:
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
perm[axis1+1:axis2] + perm[axis1:axis1+1] perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
return F.transpose(x, new_perm) return F.transpose(x, new_perm)
@ -343,7 +344,7 @@ def isinstance_(x, base_type):
def while_cond(x): def while_cond(x):
"""For while condtion, if the condition is a tensor, the loop will not be unrolled""" """For while condition, if the condition is a tensor, the loop will not be unrolled"""
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)): if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
is_cond = check_is_tensor_bool_cond(F.shape(x)) is_cond = check_is_tensor_bool_cond(F.shape(x))
if is_cond: if is_cond:
@ -373,7 +374,8 @@ def check_type_same(x_type, base_type):
target_type = pytype_to_mstype[base_type] target_type = pytype_to_mstype[base_type]
return isinstance(x_type, target_type) return isinstance(x_type, target_type)
except KeyError: except KeyError:
raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'") raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, "
f"Tensor, Parameter, or a tuple only including these types, but got {base_type}")
@constexpr @constexpr
@ -441,7 +443,7 @@ def check_view_shape(x):
return x return x
# convert noraml param_check functions to constexpr functions # convert normal param_check functions to constexpr functions
check_astype_dtype_const = constexpr(validator.check_astype_dtype) check_astype_dtype_const = constexpr(validator.check_astype_dtype)
check_transpose_axis_const = constexpr(validator.check_transpose_axis) check_transpose_axis_const = constexpr(validator.check_transpose_axis)
check_reshape_shp_const = constexpr(validator.check_reshape_shp) check_reshape_shp_const = constexpr(validator.check_reshape_shp)
@ -449,8 +451,9 @@ check_flatten_order_const = constexpr(validator.check_flatten_order)
check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis) check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis)
prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze) prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze)
def tensor_bool(x): def tensor_bool(x):
"""tensor as conditon, if is constant, return immediate bool value""" """tensor as condition, if is constant, return immediate bool value"""
is_cond = check_is_tensor_bool_cond(F.shape(x)) is_cond = check_is_tensor_bool_cond(F.shape(x))
if is_cond and F.isconstant(x): if is_cond and F.isconstant(x):
return const_tensor_to_bool(x) return const_tensor_to_bool(x)

View File

@ -382,7 +382,7 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
.def(py::init<>()); .def(py::init<>());
})); }));
FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) { FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {
MS_EXCEPTION_IF_NULL(sequeue); MS_EXCEPTION_IF_NULL(sequeue);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();

View File

@ -104,7 +104,7 @@ class Tail : public MetaFuncGraph {
MS_DECLARE_PARENT(Tail, MetaFuncGraph) MS_DECLARE_PARENT(Tail, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue); FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const;
friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }

View File

@ -156,9 +156,12 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
} }
bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) { bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
if (value_vec.empty()) {
return false;
}
for (auto &elem : value_vec) { for (auto &elem : value_vec) {
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) { if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
const auto &vec = GetValue<std::vector<ValuePtr>>(elem); const auto &vec = GetValue<ValuePtrList>(elem);
auto is_graph = IsAllFuncInValueSequence(vec); auto is_graph = IsAllFuncInValueSequence(vec);
if (!is_graph) { if (!is_graph) {
return false; return false;
@ -194,20 +197,20 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
return cnode; return cnode;
} }
// transform the ValueTuple or ValueList of graph/primitve node to make tuple of const graph/primitve node // transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
const ValueNodePtr &value_node, AnfNodePtr *const transformed) { const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value()); const auto &value_vec = GetValue<ValuePtrList>(value_node->value());
if (!IsAllFuncInValueSequence(value_vec)) { if (!IsAllFuncInValueSequence(value_vec)) {
return false; return false;
} }
// (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it, // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
// So if has graph in list, try to replace the node with make tuple of graph value node. // So if has graph in list, try to replace the node with make tuple of graph value node.
// we do this because the graphmanger won't investigate the graph inside valuetuple, // we do this because the graph manager won't investigate the graph inside valuetuple,
// change the vector of graph to be make_tuple of graph value node. // change the vector of graph to be make_tuple of graph value node.
// (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all // (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
// independent nodes. // independent nodes.
auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
// replace the ret ptr to be make tuple of graph value node // replace the ret ptr to be make tuple of graph value node

View File

@ -69,10 +69,6 @@ namespace pipeline {
using Tensor = mindspore::tensor::Tensor; using Tensor = mindspore::tensor::Tensor;
using MetaTensor = mindspore::tensor::MetaTensor; using MetaTensor = mindspore::tensor::MetaTensor;
using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>; using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractTensor; using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTensorPtr; using mindspore::abstract::AbstractTensorPtr;
using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractTuple;
@ -103,6 +99,33 @@ AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
return abstract::FromValue(value, broaden); return abstract::FromValue(value, broaden);
} }
bool CheckArgValid(const py::handle &arg) {
if (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg)) {
auto vector_arg = py::cast<py::list>(arg);
return std::all_of(vector_arg.begin(), vector_arg.end(), CheckArgValid);
}
if (py::isinstance<py::dict>(arg)) {
auto dict_arg = py::cast<py::dict>(arg);
return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
}
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<Number>(arg) ||
(py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
}
void CheckArgsValid(const py::tuple &args) {
for (size_t i = 0; i < args.size(); i++) {
if (!CheckArgValid(args[i])) {
MS_EXCEPTION(TypeError)
<< "The inputs types of the outermost network support bool, int, float, tensor, "
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), "
"and tuple or list containing only these types, and dict whose values are these types, but got "
<< i << "th arg is " << py::str(args[i]);
}
}
}
std::string GetCompileExceptionInfo() { std::string GetCompileExceptionInfo() {
std::ostringstream oss; std::ostringstream oss;
trace::TraceGraphEval(); trace::TraceGraphEval();
@ -470,11 +493,13 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
MS_LOG(ERROR) << "Arg phase must be string."; MS_LOG(ERROR) << "Arg phase must be string.";
return false; return false;
} }
// check the arg valid? // check the function or net is valid
if (py::isinstance<py::none>(obj)) { if (py::isinstance<py::none>(obj)) {
MS_LOG(ERROR) << "Find error: parse obj is None."; MS_LOG(ERROR) << "Find error: parse obj is None.";
return false; return false;
} }
// check the args of function or net is valid
CheckArgsValid(args);
#ifdef ENABLE_GE #ifdef ENABLE_GE
GetGeBackendPolicy(); GetGeBackendPolicy();
#endif #endif

View File

@ -52,7 +52,7 @@ REGISTER_PYBIND_DEFINE(
return t->DeepCopy(); return t->DeepCopy();
}); });
(void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init()); (void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init());
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool") (void)py::class_<Bool, Number, std::shared_ptr<Bool>>(m_sub, "Bool")
.def(py::init()) .def(py::init())
.def(py::pickle( .def(py::pickle(
[](const Bool &) { // __getstate__ [](const Bool &) { // __getstate__
@ -61,7 +61,7 @@ REGISTER_PYBIND_DEFINE(
[](const py::tuple &) { // __setstate__ [](const py::tuple &) { // __setstate__
return std::make_shared<Bool>(); return std::make_shared<Bool>();
})); }));
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int") (void)py::class_<Int, Number, std::shared_ptr<Int>>(m_sub, "Int")
.def(py::init()) .def(py::init())
.def(py::init<int>(), py::arg("nbits")) .def(py::init<int>(), py::arg("nbits"))
.def(py::pickle( .def(py::pickle(
@ -77,7 +77,7 @@ REGISTER_PYBIND_DEFINE(
Int data(t[0].cast<py::int_>()); Int data(t[0].cast<py::int_>());
return data; return data;
})); }));
(void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt") (void)py::class_<UInt, Number, std::shared_ptr<UInt>>(m_sub, "UInt")
.def(py::init()) .def(py::init())
.def(py::init<int>(), py::arg("nbits")) .def(py::init<int>(), py::arg("nbits"))
.def(py::pickle( .def(py::pickle(
@ -93,7 +93,7 @@ REGISTER_PYBIND_DEFINE(
UInt data(t[0].cast<py::int_>()); UInt data(t[0].cast<py::int_>());
return data; return data;
})); }));
(void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float") (void)py::class_<Float, Number, std::shared_ptr<Float>>(m_sub, "Float")
.def(py::init()) .def(py::init())
.def(py::init<int>(), py::arg("nbits")) .def(py::init<int>(), py::arg("nbits"))
.def(py::pickle( .def(py::pickle(

View File

@ -20,7 +20,7 @@ import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
from mindspore import context from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) context.set_context(mode=context.GRAPH_MODE)
def test_isinstance(): def test_isinstance():
@ -35,6 +35,7 @@ def test_isinstance():
self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member) self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member)
self.list_member = list(self.tuple_member) self.list_member = list(self.tuple_member)
self.weight = Parameter(1.0) self.weight = Parameter(1.0)
self.empty_list = []
def construct(self, x, y): def construct(self, x, y):
is_int = isinstance(self.int_member, int) is_int = isinstance(self.int_member, int)
@ -54,7 +55,9 @@ def test_isinstance():
bool_is_string = isinstance(self.bool_member, str) bool_is_string = isinstance(self.bool_member, str)
tensor_is_tuple = isinstance(x, tuple) tensor_is_tuple = isinstance(x, tuple)
tuple_is_list = isinstance(self.tuple_member, list) tuple_is_list = isinstance(self.tuple_member, list)
return is_int, is_float, is_bool, is_string, is_parameter, is_tensor_const, is_tensor_var, \ is_empty_list = isinstance(self.empty_list, list)
return is_int, is_float, is_bool, is_string, \
is_empty_list, is_parameter, is_tensor_const, is_tensor_var, \
is_tuple_const, is_tuple_var, is_list_const, is_list_var, \ is_tuple_const, is_tuple_var, is_list_const, is_list_var, \
is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \ is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \
float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list
@ -62,7 +65,7 @@ def test_isinstance():
net = Net() net = Net()
x = Tensor(np.arange(4)) x = Tensor(np.arange(4))
y = Tensor(np.arange(5)) y = Tensor(np.arange(5))
assert net(x, y) == (True,) * 13 + (False,) * 4 assert net(x, y) == (True,) * 14 + (False,) * 4
def test_isinstance_not_supported(): def test_isinstance_not_supported():
@ -77,7 +80,8 @@ def test_isinstance_not_supported():
net = Net() net = Net()
with pytest.raises(TypeError) as err: with pytest.raises(TypeError) as err:
net() net()
assert "The type 'None' is not supported for 'isinstance'" in str(err.value) assert "The second arg of 'isinstance' should be bool, int, float, str, list, tuple, Tensor, Parameter, " \
"or a tuple only including these types, but got None" in str(err.value)
def test_isinstance_second_arg_is_list(): def test_isinstance_second_arg_is_list():

View File

@ -15,7 +15,7 @@
""" test ms_function pass non_tensor inputs""" """ test ms_function pass non_tensor inputs"""
import numpy as np import numpy as np
from mindspore import Tensor, ms_function, Parameter from mindspore import Tensor, ms_function
from mindspore import context from mindspore import context
from mindspore.ops import operations as P from mindspore.ops import operations as P
@ -56,5 +56,5 @@ def tensor_reduce(tensor_x, axis, tensor_y):
def test_tensor_reduce(): def test_tensor_reduce():
tensor_x = Tensor(np.ones((2, 3, 4, 5), np.float32)) tensor_x = Tensor(np.ones((2, 3, 4, 5), np.float32))
axis = (0, 1) axis = (0, 1)
tensor_y = Parameter(Tensor(np.ones((4, 5), np.float32) * 2)) tensor_y = Tensor(np.ones((4, 5), np.float32) * 2)
tensor_reduce(tensor_x, axis, tensor_y) tensor_reduce(tensor_x, axis, tensor_y)

View File

@ -14,47 +14,110 @@
# ============================================================================ # ============================================================================
""" test outermost net pass non_tensor inputs""" """ test outermost net pass non_tensor inputs"""
import numpy as np import numpy as np
import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor, Parameter
from mindspore import context from mindspore import context
from mindspore.ops import composite as C from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
def test_outermost_net_pass_scalar_tuple_list_dict(): class Net(nn.Cell):
class TestNet(nn.Cell): def __init__(self):
def __init__(self): super(Net, self).__init__()
super(TestNet, self).__init__()
def construct(self, tuple_a, z, list_m, w, s, dict_n): def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag):
return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"] if flag:
return tensor_x - tuple_a[2] + list_b[1][1]["x"] - tensor_y + scalar - dict_c["x"]
return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - scalar + dict_c["y"]
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.forward_net = net
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
self.grad_all = C.GradOperation(get_all=True)
def construct(self, tuple_a, z, list_m, w, s, dict_n): class GradNet(nn.Cell):
return self.grad_all(self.forward_net)(tuple_a, z, list_m, w, s, dict_n) def __init__(self, net):
super(GradNet, self).__init__()
self.forward_net = net
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
self.grad_all = C.GradOperation(get_all=True)
x = Tensor(np.ones((2, 2), np.float32)) def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag):
y = Tensor(np.ones((2, 2), np.float32) * 2) return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag)
z = Tensor(np.ones((2, 2), np.float32) * 3)
w = Tensor(np.ones((2, 2), np.float32) * 4)
arg_t0 = (x, y, z, w)
arg_t1 = (w, y, z, w)
arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
args_d0 = {"x": x, "y": y}
args_d1 = {"x": x, "y": y}
forward_net = TestNet()
forward_net(arg_t0, z, arg_l0, w, 6, args_d0)
forward_net(arg_t1, z, arg_l1, x, 6, args_d1)
grad_net = GradNet(forward_net)
grad_net(arg_t0, z, arg_l0, w, 6, args_d0) x = Tensor(np.ones((2, 2), np.float32))
grad_net(arg_t1, z, arg_l1, x, 6, args_d1) y = Tensor(np.ones((2, 2), np.float32) * 2)
z = Tensor(np.ones((2, 2), np.float32) * 3)
w = Tensor(np.ones((2, 2), np.float32) * 4)
sl = 6
s = "ok"
arg_t0 = (x, y, z, w)
arg_t1 = (w, y, z, w)
arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
args_d0 = {"x": x, "y": y}
args_d1 = {"x": x, "y": y}
flag_0 = True
flag_1 = False
p = Parameter(x, name="weight")
a = np.ones((2, 2))
forward_net = Net()
grad_net = GradNet(forward_net)
def test_outermost_net_inputs_including_non_tensor():
forward_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
forward_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)
def test_grad_net_inputs_including_non_tensor():
grad_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
grad_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)
def test_net_inputs_including_str():
with pytest.raises(TypeError) as err:
grad_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but got 1th arg is ok" in str(err.value)
def test_outermost_net_pass_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but got 1th arg is Parameter (name=weight)" in str(err.value)
def test_outermost_net_pass_tuple_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p))
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but got 6th arg is (" in str(err.value)
def test_outermost_net_pass_list_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but got 4th arg is [" in str(err.value)
def test_grad_net_pass_dict_including_parameter():
with pytest.raises(TypeError) as err:
grad_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but got 3th arg is {" in str(err.value)