forked from mindspore-Ecosystem/mindspore

commit 55594db01c

!11348 add check for outermost net inputs type and support isinstance first arg is an empty list

From: @zhangbuxue
@@ -42,6 +42,7 @@ itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,
                 mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4,
                 mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8}
 
+
 def mean(x, axis=(), keep_dims=False):
     """
     Reduces a dimension of a tensor by averaging all elements in the dimension.
@@ -218,11 +219,11 @@ def swapaxes(x, axis1, axis2):
     perm = F.make_range(0, x.ndim)
     new_perm = None
     if axis2 + 1 < x.ndim:
-        new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
-                   perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:]
+        new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
+                   perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
     else:
-        new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
-                   perm[axis1+1:axis2] + perm[axis1:axis1+1]
+        new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
+                   perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
 
     return F.transpose(x, new_perm)
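Note: aside from the PEP 8 spacing, both branches build the same swap permutation. A minimal pure-Python sketch of that logic (swap_perm is a hypothetical helper name, assuming 0 <= axis1 < axis2 < ndim):

    def swap_perm(ndim, axis1, axis2):
        perm = tuple(range(ndim))
        if axis2 + 1 < ndim:
            return (perm[0:axis1] + perm[axis2:axis2 + 1] +
                    perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:])
        # axis2 is the last axis, so there is no tail slice
        return (perm[0:axis1] + perm[axis2:axis2 + 1] +
                perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1])

    assert swap_perm(4, 1, 2) == (0, 2, 1, 3)
    assert swap_perm(3, 0, 2) == (2, 1, 0)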
@@ -343,7 +344,7 @@ def isinstance_(x, base_type):
 
 
 def while_cond(x):
-    """For while condtion, if the condition is a tensor, the loop will not be unrolled"""
+    """For while condition, if the condition is a tensor, the loop will not be unrolled"""
     if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
         is_cond = check_is_tensor_bool_cond(F.shape(x))
         if is_cond:
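Note: the corrected docstring is the behavior contract here; a Python bool condition is unrolled at compile time, while a Tensor condition stays a real graph-mode loop. A hedged usage sketch (CountDown is an illustrative net, not part of this patch):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    context.set_context(mode=context.GRAPH_MODE)

    class CountDown(nn.Cell):
        def construct(self, x):
            while x > 0:  # Tensor condition: compiled as a graph while-loop, not unrolled
                x = x - 1
            return x

    out = CountDown()(Tensor(np.array(3, np.int32)))  # expected Tensor(0)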
@@ -373,7 +374,8 @@ def check_type_same(x_type, base_type):
         target_type = pytype_to_mstype[base_type]
         return isinstance(x_type, target_type)
     except KeyError:
-        raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'")
+        raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, "
+                        f"Tensor, Parameter, or a tuple only including these types, but got {base_type}")
 
 
 @constexpr
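Note: the new message enumerates every supported second argument instead of echoing only the offending one. A standalone sketch of the same try/except pattern (the dict contents below are an illustrative subset, not the real pytype_to_mstype):

    pytype_to_mstype = {bool: "Bool", int: "Int", float: "Float", str: "String"}

    def check_type_same(x_type, base_type):
        try:
            target_type = pytype_to_mstype[base_type]
            return x_type == target_type  # stand-in for the real mstype check
        except KeyError:
            raise TypeError("The second arg of 'isinstance' should be bool, int, float, str, "
                            "list, tuple, Tensor, Parameter, or a tuple only including these "
                            f"types, but got {base_type}")

    try:
        check_type_same("Int", type(None))
    except TypeError as e:
        assert "but got" in str(e)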
@@ -441,7 +443,7 @@ def check_view_shape(x):
     return x
 
 
-# convert noraml param_check functions to constexpr functions
+# convert normal param_check functions to constexpr functions
 check_astype_dtype_const = constexpr(validator.check_astype_dtype)
 check_transpose_axis_const = constexpr(validator.check_transpose_axis)
 check_reshape_shp_const = constexpr(validator.check_reshape_shp)
@@ -449,8 +451,9 @@ check_flatten_order_const = constexpr(validator.check_flatten_order)
 check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis)
 prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze)
 
+
 def tensor_bool(x):
-    """tensor as conditon, if is constant, return immediate bool value"""
+    """tensor as condition, if is constant, return immediate bool value"""
     is_cond = check_is_tensor_bool_cond(F.shape(x))
     if is_cond and F.isconstant(x):
         return const_tensor_to_bool(x)
@@ -382,7 +382,7 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                          .def(py::init<>());
                        }));
 
-FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) {
+FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {
   MS_EXCEPTION_IF_NULL(sequeue);
 
   FuncGraphPtr ret = std::make_shared<FuncGraph>();
@@ -104,7 +104,7 @@ class Tail : public MetaFuncGraph {
   MS_DECLARE_PARENT(Tail, MetaFuncGraph)
 
   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
-  FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue);
+  FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const;
 
   friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
 
@@ -156,9 +156,12 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
 }
 
 bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
+  if (value_vec.empty()) {
+    return false;
+  }
   for (auto &elem : value_vec) {
     if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
-      const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
+      const auto &vec = GetValue<ValuePtrList>(elem);
       auto is_graph = IsAllFuncInValueSequence(vec);
       if (!is_graph) {
         return false;
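Note: the added empty check means an empty sequence is never treated as a sequence of function values, which is what lets an empty list flow through as an ordinary value later in the pipeline. A Python paraphrase of the recursion (is_func stands in for the graph/primitive checks elided from this hunk):

    def is_all_func_in_value_sequence(values, is_func):
        if not values:
            return False
        for elem in values:
            if isinstance(elem, (tuple, list)):
                if not is_all_func_in_value_sequence(elem, is_func):
                    return False
            elif not is_func(elem):
                return False
        return True

    assert not is_all_func_in_value_sequence([], callable)
    assert is_all_func_in_value_sequence([len, [abs]], callable)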
@@ -194,20 +197,20 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
   return cnode;
 }
 
-// transform the ValueTuple or ValueList of graph/primitve node to make tuple of const graph/primitve node
+// transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node
 bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
                                   const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
   MS_EXCEPTION_IF_NULL(value_node);
-  const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value());
+  const auto &value_vec = GetValue<ValuePtrList>(value_node->value());
   if (!IsAllFuncInValueSequence(value_vec)) {
     return false;
   }
 
   // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
   // So if has graph in list, try to replace the node with make tuple of graph value node.
-  // we do this because the graphmanger won't investigate the graph inside valuetuple,
+  // we do this because the graph manager won't investigate the graph inside valuetuple,
   // change the vector of graph to be make_tuple of graph value node.
-  // (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all
+  // (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
   // independent nodes.
   auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
   // replace the ret ptr to be make tuple of graph value node
@@ -69,10 +69,6 @@ namespace pipeline {
 using Tensor = mindspore::tensor::Tensor;
 using MetaTensor = mindspore::tensor::MetaTensor;
 using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
-using mindspore::abstract::AbstractDictionary;
-using mindspore::abstract::AbstractDictionaryPtr;
-using mindspore::abstract::AbstractList;
-using mindspore::abstract::AbstractListPtr;
 using mindspore::abstract::AbstractTensor;
 using mindspore::abstract::AbstractTensorPtr;
 using mindspore::abstract::AbstractTuple;
|
@ -103,6 +99,33 @@ AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
|
|||
return abstract::FromValue(value, broaden);
|
||||
}
|
||||
|
||||
bool CheckArgValid(const py::handle &arg) {
|
||||
if (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg)) {
|
||||
auto vector_arg = py::cast<py::list>(arg);
|
||||
return std::all_of(vector_arg.begin(), vector_arg.end(), CheckArgValid);
|
||||
}
|
||||
|
||||
if (py::isinstance<py::dict>(arg)) {
|
||||
auto dict_arg = py::cast<py::dict>(arg);
|
||||
return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
|
||||
}
|
||||
|
||||
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<Number>(arg) ||
|
||||
(py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
|
||||
}
|
||||
|
||||
void CheckArgsValid(const py::tuple &args) {
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (!CheckArgValid(args[i])) {
|
||||
MS_EXCEPTION(TypeError)
|
||||
<< "The inputs types of the outermost network support bool, int, float, tensor, "
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), "
|
||||
"and tuple or list containing only these types, and dict whose values are these types, but got "
|
||||
<< i << "th arg is " << py::str(args[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetCompileExceptionInfo() {
|
||||
std::ostringstream oss;
|
||||
trace::TraceGraphEval();
|
||||
|
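Note: the validity rule is recursive over containers. A hedged Python paraphrase of CheckArgValid (is_valid_outer_input is a made-up name; the mstype.Number branch is omitted for brevity):

    from mindspore import Tensor

    def is_valid_outer_input(arg):
        # containers: every element / dict value must itself be valid
        if isinstance(arg, (list, tuple)):
            return all(is_valid_outer_input(item) for item in arg)
        if isinstance(arg, dict):
            return all(is_valid_outer_input(v) for v in arg.values())
        # scalars pass; Tensors pass unless they carry the Parameter marker attribute
        return isinstance(arg, (bool, int, float)) or \
               (isinstance(arg, Tensor) and not hasattr(arg, "__parameter__"))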
@@ -470,11 +493,13 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
     MS_LOG(ERROR) << "Arg phase must be string.";
     return false;
   }
-  // check the arg valid?
+  // check the function or net is valid
   if (py::isinstance<py::none>(obj)) {
     MS_LOG(ERROR) << "Find error: parse obj is None.";
     return false;
   }
+  // check the args of function or net is valid
+  CheckArgsValid(args);
 #ifdef ENABLE_GE
   GetGeBackendPolicy();
 #endif
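Note: with CheckArgsValid wired into CompileInner, invalid outermost inputs now fail fast before compilation instead of surfacing as a confusing downstream error. A hedged usage sketch (AddOne is illustrative, not from this patch):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    context.set_context(mode=context.GRAPH_MODE)

    class AddOne(nn.Cell):
        def construct(self, x):
            return x + 1

    net = AddOne()
    net(Tensor(np.ones((2, 2), np.float32)))  # fine: Tensor is a supported input
    try:
        net(object())  # unsupported outermost input type
    except TypeError as e:
        print(e)  # the CheckArgsValid message shown above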
@@ -52,7 +52,7 @@ REGISTER_PYBIND_DEFINE(
                            return t->DeepCopy();
                          });
   (void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init());
-  (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
+  (void)py::class_<Bool, Number, std::shared_ptr<Bool>>(m_sub, "Bool")
     .def(py::init())
     .def(py::pickle(
       [](const Bool &) {  // __getstate__
@@ -61,7 +61,7 @@ REGISTER_PYBIND_DEFINE(
       [](const py::tuple &) {  // __setstate__
         return std::make_shared<Bool>();
       }));
-  (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
+  (void)py::class_<Int, Number, std::shared_ptr<Int>>(m_sub, "Int")
     .def(py::init())
     .def(py::init<int>(), py::arg("nbits"))
     .def(py::pickle(
@@ -77,7 +77,7 @@ REGISTER_PYBIND_DEFINE(
         Int data(t[0].cast<py::int_>());
         return data;
       }));
-  (void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt")
+  (void)py::class_<UInt, Number, std::shared_ptr<UInt>>(m_sub, "UInt")
     .def(py::init())
     .def(py::init<int>(), py::arg("nbits"))
     .def(py::pickle(
@@ -93,7 +93,7 @@ REGISTER_PYBIND_DEFINE(
        UInt data(t[0].cast<py::int_>());
        return data;
      }));
-  (void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float")
+  (void)py::class_<Float, Number, std::shared_ptr<Float>>(m_sub, "Float")
     .def(py::init())
     .def(py::init<int>(), py::arg("nbits"))
     .def(py::pickle(
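Note: re-parenting Bool/Int/UInt/Float from Type to Number is what makes the concrete dtypes count as mstype.Number on the Python side, matching the wording of the new input-check message. A hedged check (assuming the pybind classes registered above are exposed as usual under mindspore.common.dtype):

    from mindspore.common import dtype as mstype

    # After this change, concrete numeric dtypes are Number instances.
    assert isinstance(mstype.int32, mstype.Number)
    assert isinstance(mstype.bool_, mstype.Number)
    assert isinstance(mstype.float32, mstype.Number)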
@@ -20,7 +20,7 @@ import mindspore.nn as nn
 from mindspore import Tensor, Parameter
 from mindspore import context
 
-context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE)
 
 
 def test_isinstance():
@@ -35,6 +35,7 @@ def test_isinstance():
         self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member)
         self.list_member = list(self.tuple_member)
         self.weight = Parameter(1.0)
+        self.empty_list = []
 
     def construct(self, x, y):
         is_int = isinstance(self.int_member, int)
@@ -54,7 +55,9 @@ def test_isinstance():
         bool_is_string = isinstance(self.bool_member, str)
         tensor_is_tuple = isinstance(x, tuple)
         tuple_is_list = isinstance(self.tuple_member, list)
-        return is_int, is_float, is_bool, is_string, is_parameter, is_tensor_const, is_tensor_var, \
+        is_empty_list = isinstance(self.empty_list, list)
+        return is_int, is_float, is_bool, is_string, \
+               is_empty_list, is_parameter, is_tensor_const, is_tensor_var, \
                is_tuple_const, is_tuple_var, is_list_const, is_list_var, \
                is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \
                float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list
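Note: is_empty_list exercises the second half of the commit title; before this patch, isinstance with an empty-list first argument failed inside construct because the empty sequence had no element type to inspect. The expected graph-mode behavior now matches plain Python:

    assert isinstance([], list) is True
    assert isinstance([], tuple) is False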
@@ -62,7 +65,7 @@ def test_isinstance():
     net = Net()
     x = Tensor(np.arange(4))
     y = Tensor(np.arange(5))
-    assert net(x, y) == (True,) * 13 + (False,) * 4
+    assert net(x, y) == (True,) * 14 + (False,) * 4
 
 
 def test_isinstance_not_supported():
@@ -77,7 +80,8 @@ def test_isinstance_not_supported():
     net = Net()
     with pytest.raises(TypeError) as err:
         net()
-    assert "The type 'None' is not supported for 'isinstance'" in str(err.value)
+    assert "The second arg of 'isinstance' should be bool, int, float, str, list, tuple, Tensor, Parameter, " \
+           "or a tuple only including these types, but got None" in str(err.value)
 
 
 def test_isinstance_second_arg_is_list():
@@ -15,7 +15,7 @@
 """ test ms_function pass non_tensor inputs"""
 import numpy as np
 
-from mindspore import Tensor, ms_function, Parameter
+from mindspore import Tensor, ms_function
 from mindspore import context
 from mindspore.ops import operations as P
 
@@ -56,5 +56,5 @@ def tensor_reduce(tensor_x, axis, tensor_y):
 def test_tensor_reduce():
     tensor_x = Tensor(np.ones((2, 3, 4, 5), np.float32))
     axis = (0, 1)
-    tensor_y = Parameter(Tensor(np.ones((4, 5), np.float32) * 2))
+    tensor_y = Tensor(np.ones((4, 5), np.float32) * 2)
     tensor_reduce(tensor_x, axis, tensor_y)
@@ -14,47 +14,110 @@
 # ============================================================================
 """ test outermost net pass non_tensor inputs"""
 import numpy as np
 import pytest
 
 import mindspore.nn as nn
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from mindspore import context
 from mindspore.ops import composite as C
 
 context.set_context(mode=context.GRAPH_MODE)
 
 
-def test_outermost_net_pass_scalar_tuple_list_dict():
-    class TestNet(nn.Cell):
-        def __init__(self):
-            super(TestNet, self).__init__()
-
-        def construct(self, tuple_a, z, list_m, w, s, dict_n):
-            return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]
-
-    class GradNet(nn.Cell):
-        def __init__(self, net):
-            super(GradNet, self).__init__()
-            self.forward_net = net
-            self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
-            self.grad_all = C.GradOperation(get_all=True)
-
-        def construct(self, tuple_a, z, list_m, w, s, dict_n):
-            return self.grad_all(self.forward_net)(tuple_a, z, list_m, w, s, dict_n)
-
-    x = Tensor(np.ones((2, 2), np.float32))
-    y = Tensor(np.ones((2, 2), np.float32) * 2)
-    z = Tensor(np.ones((2, 2), np.float32) * 3)
-    w = Tensor(np.ones((2, 2), np.float32) * 4)
-    arg_t0 = (x, y, z, w)
-    arg_t1 = (w, y, z, w)
-    arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
-    arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
-    args_d0 = {"x": x, "y": y}
-    args_d1 = {"x": x, "y": y}
-    forward_net = TestNet()
-    forward_net(arg_t0, z, arg_l0, w, 6, args_d0)
-    forward_net(arg_t1, z, arg_l1, x, 6, args_d1)
-
-    grad_net = GradNet(forward_net)
-    grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
-    grad_net(arg_t1, z, arg_l1, x, 6, args_d1)
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+
+    def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag):
+        if flag:
+            return tensor_x - tuple_a[2] + list_b[1][1]["x"] - tensor_y + scalar - dict_c["x"]
+        return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - scalar + dict_c["y"]
+
+
+class GradNet(nn.Cell):
+    def __init__(self, net):
+        super(GradNet, self).__init__()
+        self.forward_net = net
+        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
+        self.grad_all = C.GradOperation(get_all=True)
+
+    def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag):
+        return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag)
+
+
+x = Tensor(np.ones((2, 2), np.float32))
+y = Tensor(np.ones((2, 2), np.float32) * 2)
+z = Tensor(np.ones((2, 2), np.float32) * 3)
+w = Tensor(np.ones((2, 2), np.float32) * 4)
+sl = 6
+s = "ok"
+arg_t0 = (x, y, z, w)
+arg_t1 = (w, y, z, w)
+arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
+arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
+args_d0 = {"x": x, "y": y}
+args_d1 = {"x": x, "y": y}
+flag_0 = True
+flag_1 = False
+
+p = Parameter(x, name="weight")
+a = np.ones((2, 2))
+
+forward_net = Net()
+grad_net = GradNet(forward_net)
+
+
+def test_outermost_net_inputs_including_non_tensor():
+    forward_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
+    forward_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)
+
+
+def test_grad_net_inputs_including_non_tensor():
+    grad_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
+    grad_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)
+
+
+def test_net_inputs_including_str():
+    with pytest.raises(TypeError) as err:
+        grad_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
+    assert "The inputs types of the outermost network support bool, int, float, tensor, " \
+           "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
+           "and tuple or list containing only these types, and dict whose values are these types, " \
+           "but got 1th arg is ok" in str(err.value)
+
+
+def test_outermost_net_pass_parameter():
+    with pytest.raises(TypeError) as err:
+        forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0)
+    assert "The inputs types of the outermost network support bool, int, float, tensor, " \
+           "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
+           "and tuple or list containing only these types, and dict whose values are these types, " \
+           "but got 1th arg is Parameter (name=weight)" in str(err.value)
+
+
+def test_outermost_net_pass_tuple_including_parameter():
+    with pytest.raises(TypeError) as err:
+        forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p))
+    assert "The inputs types of the outermost network support bool, int, float, tensor, " \
+           "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
+           "and tuple or list containing only these types, and dict whose values are these types, " \
+           "but got 6th arg is (" in str(err.value)
+
+
+def test_outermost_net_pass_list_including_parameter():
+    with pytest.raises(TypeError) as err:
+        forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0)
+    assert "The inputs types of the outermost network support bool, int, float, tensor, " \
+           "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
+           "and tuple or list containing only these types, and dict whose values are these types, " \
+           "but got 4th arg is [" in str(err.value)
+
+
+def test_grad_net_pass_dict_including_parameter():
+    with pytest.raises(TypeError) as err:
+        grad_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
+    assert "The inputs types of the outermost network support bool, int, float, tensor, " \
+           "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
+           "and tuple or list containing only these types, and dict whose values are these types, " \
+           "but got 3th arg is {" in str(err.value)
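Note: GradNet relies on C.GradOperation(get_all=True), which differentiates the wrapped net with respect to every input, so the whole mixed tuple/list/dict signature must pass the new outermost-input check. A minimal hedged usage sketch (Square and Grad are illustrative, not from this patch):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import composite as C

    context.set_context(mode=context.GRAPH_MODE)

    class Square(nn.Cell):
        def construct(self, a, b):
            return a * a + b

    class Grad(nn.Cell):
        def __init__(self, net):
            super(Grad, self).__init__()
            self.net = net
            self.grad_all = C.GradOperation(get_all=True)

        def construct(self, a, b):
            return self.grad_all(self.net)(a, b)

    a = Tensor(np.ones((2, 2), np.float32))
    b = Tensor(np.ones((2, 2), np.float32))
    grads = Grad(Square())(a, b)  # tuple of gradients, one per input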