change tensor dtype and shape from function to attr

This commit is contained in:
buxue 2020-06-09 12:18:51 +08:00
parent 553432c968
commit 66bbdb4a31
64 changed files with 348 additions and 342 deletions

View File

@ -79,12 +79,12 @@ if __name__ == '__main__':
for _, cell in net.cells_and_names(): for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d): if isinstance(cell, nn.Conv2d):
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
cell.weight.default_input.shape(), cell.weight.default_input.shape,
cell.weight.default_input.dtype()).to_tensor() cell.weight.default_input.dtype).to_tensor()
if isinstance(cell, nn.Dense): if isinstance(cell, nn.Dense):
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.default_input.shape(), cell.weight.default_input.shape,
cell.weight.default_input.dtype()).to_tensor() cell.weight.default_input.dtype).to_tensor()
if not config.use_label_smooth: if not config.use_label_smooth:
config.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0

View File

@ -338,15 +338,15 @@ class Dense_Thor(Cell):
self.has_bias = check_bool(has_bias) self.has_bias = check_bool(has_bias)
self.thor = True self.thor = True
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape()[1] != in_channels: weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error") raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
if self.has_bias: if self.has_bias:
if isinstance(bias_init, Tensor): if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error") raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

View File

@ -56,7 +56,7 @@ def init_net_param(network, init_value='ones'):
params = network.trainable_params() params = network.trainable_params()
for p in params: for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype())) p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype))
def main(): def main():

View File

@ -384,6 +384,28 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
.def(py::init<py::tuple, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) .def(py::init<py::tuple, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr)
.def(py::init<Tensor, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) .def(py::init<Tensor, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr)
.def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_) .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_)
.def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter(
Get the tensor's data type.
Returns:
type, the data type of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
>>> data.dtype
Int32
)mydelimiter")
.def_property_readonly("shape", &Tensor::GetPyTupleShape, R"mydelimiter(
Get the tensor's shape.
Returns:
tuple[int], the shape of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((3, 3)))
>>> data.shape()
(3, 3)
)mydelimiter")
.def("asnumpy", &Tensor::data_sync, R"mydelimiter( .def("asnumpy", &Tensor::data_sync, R"mydelimiter(
Convert tensor to numpy.ndarray. Convert tensor to numpy.ndarray.
@ -437,17 +459,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.dim() >>> data.dim()
2 2
)mydelimiter") )mydelimiter")
.def("dtype", &Tensor::Dtype, R"mydelimiter(
Get the tensor's data type.
Returns:
type, the data type of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
>>> data.dtype()
Int32
)mydelimiter")
.def("set_dtype", &Tensor::SetDtype, R"mydelimiter( .def("set_dtype", &Tensor::SetDtype, R"mydelimiter(
Set the tensor's data type. Set the tensor's data type.
@ -459,17 +470,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.set_dtype(mindspore.int32) >>> data.set_dtype(mindspore.int32)
mindspore.int32 mindspore.int32
)mydelimiter") )mydelimiter")
.def("shape", &Tensor::GetPyTupleShape, R"mydelimiter(
Get the tensor's shape.
Returns:
tuple[int], the shape of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((3, 3)))
>>> data.shape()
(3, 3)
)mydelimiter")
.def("__str__", &Tensor::ToString) .def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr) .def("__repr__", &Tensor::ToStringRepr)
.def(py::pickle( .def(py::pickle(
@ -488,8 +488,8 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
(void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor") (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
.def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape")) .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
.def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_) .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
.def("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def("shape", &MetaTensor::shape, "Get the MetaTensor's shape."); .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");
})); }));
} // namespace tensor } // namespace tensor
} // namespace mindspore } // namespace mindspore

View File

@ -170,8 +170,8 @@ def get_py_obj_dtype(obj):
Type of MindSpore type. Type of MindSpore type.
""" """
# Tensor # Tensor
if hasattr(obj, 'dtype') and callable(obj.dtype) and isinstance(obj.dtype(), typing.Type): if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
return tensor_type(obj.dtype()) return tensor_type(obj.dtype)
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
return function return function
if isinstance(obj, (typing.Type, type)): if isinstance(obj, (typing.Type, type)):

View File

@ -331,11 +331,11 @@ def initializer(init, shape=None, dtype=mstype.float32):
raise TypeError("Unsupported init type '{}'.".format(type(init))) raise TypeError("Unsupported init type '{}'.".format(type(init)))
if isinstance(init, Tensor): if isinstance(init, Tensor):
init_shape = init.shape() init_shape = init.shape
shape = shape if isinstance(shape, (tuple, list)) else [shape] shape = shape if isinstance(shape, (tuple, list)) else [shape]
if shape is not None and init_shape != tuple(shape): if shape is not None and init_shape != tuple(shape):
raise ValueError("The shape of init should be same as variable shape, but got the shape of init {} and " raise ValueError("The shape of init should be same as variable shape, but got the shape of init {} and "
"the variable shape {}.".format(list(init.shape()), shape)) "the variable shape {}.".format(list(init.shape), shape))
return init return init
if isinstance(shape, list): if isinstance(shape, list):

View File

@ -140,8 +140,8 @@ class Parameter:
x.name = prefix + '.' + x.name x.name = prefix + '.' + x.name
x.is_init = False x.is_init = False
if init != 'same': if init != 'same':
shape = self.default_input.shape() shape = self.default_input.shape
dtype = self.default_input.dtype() dtype = self.default_input.dtype
if isinstance(init, (str, Initializer, numbers.Number)): if isinstance(init, (str, Initializer, numbers.Number)):
x.init_mode = initializer(init, shape=shape, dtype=dtype) x.init_mode = initializer(init, shape=shape, dtype=dtype)
x.default_input = MetaTensor(dtype, shape) x.default_input = MetaTensor(dtype, shape)

View File

@ -45,13 +45,13 @@ class Tensor(Tensor_):
>>> # init a tensor with input data >>> # init a tensor with input data
>>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32) >>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32)
>>> assert isinstance(t1, Tensor) >>> assert isinstance(t1, Tensor)
>>> assert t1.shape() == (1, 2, 3) >>> assert t1.shape == (1, 2, 3)
>>> assert t1.dtype() == mindspore.float32 >>> assert t1.dtype == mindspore.float32
>>> >>>
>>> # init a tensor with a float scalar >>> # init a tensor with a float scalar
>>> t2 = Tensor(0.1) >>> t2 = Tensor(0.1)
>>> assert isinstance(t2, Tensor) >>> assert isinstance(t2, Tensor)
>>> assert t2.dtype() == mindspore.float64 >>> assert t2.dtype == mindspore.float64
""" """
def __init__(self, input_data, dtype=None): def __init__(self, input_data, dtype=None):
@ -80,7 +80,7 @@ class Tensor(Tensor_):
return False return False
# The GE backend don't support single `Equal` operator execution. # The GE backend don't support single `Equal` operator execution.
# bool type is not supported for `Equal` operator in backend. # bool type is not supported for `Equal` operator in backend.
if context.get_context("enable_ge") or self.dtype() == mstype.bool_ or other.dtype() == mstype.bool_: if context.get_context("enable_ge") or self.dtype == mstype.bool_ or other.dtype == mstype.bool_:
return Tensor(np.array(self.asnumpy() == other.asnumpy())) return Tensor(np.array(self.asnumpy() == other.asnumpy()))
return tensor_operator_registry.get('__eq__')(self, other) return tensor_operator_registry.get('__eq__')(self, other)
@ -166,7 +166,7 @@ class Tensor(Tensor_):
return out[0] return out[0]
def __str__(self): def __str__(self):
if self.dtype() == mstype.type_none: if self.dtype == mstype.type_none:
return "Unknown Tensor type!" return "Unknown Tensor type!"
return str(self.asnumpy()) return str(self.asnumpy())

View File

@ -267,21 +267,21 @@ class MobileNetV2(nn.Cell):
if isinstance(m, (nn.Conv2d, DepthwiseConv)): if isinstance(m, (nn.Conv2d, DepthwiseConv)):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
m.weight.data.shape()).astype("float32"))) m.weight.data.shape).astype("float32")))
if m.bias is not None: if m.bias is not None:
m.bias.set_parameter_data( m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data( m.gamma.set_parameter_data(
Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_parameter_data( m.beta.set_parameter_data(
Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
elif isinstance(m, nn.Dense): elif isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal( m.weight.set_parameter_data(Tensor(np.random.normal(
0, 0.01, m.weight.data.shape()).astype("float32"))) 0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None: if m.bias is not None:
m.bias.set_parameter_data( m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
def mobilenet_v2(**kwargs): def mobilenet_v2(**kwargs):

View File

@ -322,21 +322,21 @@ class MobileNetV3(nn.Cell):
if isinstance(m, (nn.Conv2d)): if isinstance(m, (nn.Conv2d)):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
m.weight.data.shape()).astype("float32"))) m.weight.data.shape).astype("float32")))
if m.bias is not None: if m.bias is not None:
m.bias.set_parameter_data( m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data( m.gamma.set_parameter_data(
Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_parameter_data( m.beta.set_parameter_data(
Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
elif isinstance(m, nn.Dense): elif isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal( m.weight.set_parameter_data(Tensor(np.random.normal(
0, 0.01, m.weight.data.shape()).astype("float32"))) 0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None: if m.bias is not None:
m.bias.set_parameter_data( m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
def mobilenet_v3(model_name, **kwargs): def mobilenet_v3(model_name, **kwargs):

View File

@ -131,7 +131,7 @@ class Flatten(Cell):
Examples: Examples:
>>> net = nn.Flatten() >>> net = nn.Flatten()
>>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32) >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
>>> input.shape() >>> input.shape
(2, 2, 2) (2, 2, 2)
>>> net(input) >>> net(input)
[[1.2 1.2 2.1 2.1] [[1.2 1.2 2.1 2.1]
@ -198,15 +198,15 @@ class Dense(Cell):
self.has_bias = check_bool(has_bias) self.has_bias = check_bool(has_bias)
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape()[1] != in_channels: weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error") raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
if self.has_bias: if self.has_bias:
if isinstance(bias_init, Tensor): if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error") raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

View File

@ -69,7 +69,7 @@ class Conv2d(Cell):
Examples: Examples:
>>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU') >>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape() >>> net(input).shape
(1, 240, 1024, 640) (1, 240, 1024, 640)
""" """

View File

@ -168,7 +168,7 @@ class Conv2d(_Conv):
Examples: Examples:
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal') >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape() >>> net(input).shape
(1, 240, 1024, 640) (1, 240, 1024, 640)
""" """
@cell_attr_register @cell_attr_register

View File

@ -56,7 +56,7 @@ class Embedding(Cell):
>>> >>>
>>> # Maps the input word IDs to word embedding. >>> # Maps the input word IDs to word embedding.
>>> output = net(input_data) >>> output = net(input_data)
>>> output.shape() >>> output.shape
(8, 128, 768) (8, 128, 768)
""" """
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):

View File

@ -474,7 +474,7 @@ class LayerNorm(Cell):
Examples: Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape()[1:] >>> shape1 = x.shape[1:]
>>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x) >>> m(x)
""" """

View File

@ -113,7 +113,7 @@ class MaxPool2d(_PoolNd):
[0. 0. 4. 0.] [0. 0. 4. 0.]
[1. 8. 7. 0.]]]] [1. 8. 7. 0.]]]]
>>> output = pool(x) >>> output = pool(x)
>>> output.shape() >>> output.shape
(1, 2, 2, 2) (1, 2, 2, 2)
>>> output >>> output
[[[[7. 8.] [[[[7. 8.]
@ -195,7 +195,7 @@ class AvgPool2d(_PoolNd):
[0. 8. 9. 7.] [0. 8. 9. 7.]
[2. 1. 4. 9.]]]] [2. 1. 4. 9.]]]]
>>> output = pool(x) >>> output = pool(x)
>>> output.shape() >>> output.shape
(1, 2, 2, 2) (1, 2, 2, 2)
>>> output >>> output
[[[[4.888889 4.4444447] [[[[4.888889 4.4444447]
@ -260,7 +260,7 @@ class AvgPool1d(_PoolNd):
>>> pool = nn.AvgPool1d(kernel_size=6, strides=1) >>> pool = nn.AvgPool1d(kernel_size=6, strides=1)
>>> x = Tensor(np.random.randint(0, 10, [1, 3, 6]), mindspore.float32) >>> x = Tensor(np.random.randint(0, 10, [1, 3, 6]), mindspore.float32)
>>> output = pool(x) >>> output = pool(x)
>>> output.shape() >>> output.shape
(1, 3, 1) (1, 3, 1)
""" """

View File

@ -571,8 +571,8 @@ class DenseQuant(Cell):
self.has_bias = check_bool(has_bias) self.has_bias = check_bool(has_bias)
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape()[1] != in_channels: weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error") raise ValueError("weight_init shape error")
self.weight = Parameter(initializer( self.weight = Parameter(initializer(
@ -580,7 +580,7 @@ class DenseQuant(Cell):
if self.has_bias: if self.has_bias:
if isinstance(bias_init, Tensor): if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error") raise ValueError("bias_init shape error")
self.bias = Parameter(initializer( self.bias = Parameter(initializer(

View File

@ -23,7 +23,7 @@ greater = base.MultitypeFuncGraph("greater")
@greater.register("Number", "Number") @greater.register("Number", "Number")
def _greater_scala(x, y): def _greater_scalar(x, y):
""" """
Determine whether two numbers are greater. Determine whether two numbers are greater.

View File

@ -145,10 +145,10 @@ class SameTypeShape(PrimitiveWithInfer):
def __call__(self, x, y): def __call__(self, x, y):
"""run in PyNative mode""" """run in PyNative mode"""
validator.check_value_type("x", x, Tensor, self.name) validator.check_value_type('x', x, Tensor, self.name)
validator.check_value_type("y", y, Tensor, self.name) validator.check_value_type('y', y, Tensor, self.name)
validator.check('x dtype', x.dtype(), 'y dtype', y.dtype(), Rel.EQ, self.name, TypeError) validator.check('x dtype', x.dtype, 'y dtype', y.dtype, Rel.EQ, self.name, TypeError)
validator.check('x shape', x.shape(), 'y shape', y.shape(), Rel.EQ, self.name) validator.check('x shape', x.shape, 'y shape', y.shape, Rel.EQ, self.name)
return x return x
def __infer__(self, x, y): def __infer__(self, x, y):
@ -187,7 +187,7 @@ class Cast(PrimitiveWithInfer):
def check_elim(self, x, dtype): def check_elim(self, x, dtype):
if isinstance(x, Tensor): if isinstance(x, Tensor):
if x.dtype() == dtype: if x.dtype == dtype:
return (True, x) return (True, x)
return (False, None) return (False, None)
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs)) raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))
@ -498,7 +498,7 @@ class GatherV2(PrimitiveWithInfer):
The original Tensor. The original Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Must be in the range Specifies the indices of elements of the original Tensor. Must be in the range
`[0, input_param.shape()[axis])`. `[0, input_param.shape[axis])`.
- **axis** (int) - Specifies the dimension index to gather indices. - **axis** (int) - Specifies the dimension index to gather indices.
Outputs: Outputs:
@ -542,7 +542,7 @@ class SparseGatherV2(GatherV2):
The original Tensor. The original Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Must be in the range Specifies the indices of elements of the original Tensor. Must be in the range
`[0, input_param.shape()[axis])`. `[0, input_param.shape[axis])`.
- **axis** (int) - Specifies the dimension index to gather indices. - **axis** (int) - Specifies the dimension index to gather indices.
Outputs: Outputs:
@ -700,7 +700,7 @@ class Split(PrimitiveWithInfer):
output_num (int): The number of output tensors. Default: 1. output_num (int): The number of output tensors. Default: 1.
Raises: Raises:
ValueError: If axis is out of the range [-len(input_x.shape()), len(input_x.shape())), ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)),
or if the output_num is less than or equal to 0, or if the or if the output_num is less than or equal to 0, or if the
dimension which to split cannot be evenly divided by output_num. dimension which to split cannot be evenly divided by output_num.
@ -1644,7 +1644,7 @@ class Unpack(PrimitiveWithInfer):
A tuple of Tensors, the shape of each objects is same. A tuple of Tensors, the shape of each objects is same.
Raises: Raises:
ValueError: If axis is out of the range [-len(input_x.shape()), len(input_x.shape())). ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)).
Examples: Examples:
>>> unpack = P.Unpack() >>> unpack = P.Unpack()
@ -1850,7 +1850,7 @@ class StridedSlice(PrimitiveWithInfer):
>>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32) >>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
>>> slice = P.StridedSlice() >>> slice = P.StridedSlice()
>>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1)) >>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
>>> output.shape() >>> output.shape
(1, 1, 3) (1, 1, 3)
>>> output >>> output
[[[3, 3, 3]]] [[[3, 3, 3]]]
@ -1974,7 +1974,7 @@ class Diag(PrimitiveWithInfer):
if x is None: if x is None:
return None return None
# do constant-folding only when x rank is 1 # do constant-folding only when x rank is 1
if len(x.shape()) != 1: if len(x.shape) != 1:
return None return None
ret = np.diag(x.asnumpy()) ret = np.diag(x.asnumpy())
return Tensor(ret) return Tensor(ret)
@ -2026,7 +2026,7 @@ class DiagPart(PrimitiveWithInfer):
if x is None: if x is None:
return None return None
# do constant-folding only when x rank is 2 # do constant-folding only when x rank is 2
if len(x.shape()) != 2: if len(x.shape) != 2:
return None return None
ret = np.diag(x.asnumpy()) ret = np.diag(x.asnumpy())
return Tensor(ret) return Tensor(ret)

View File

@ -2329,8 +2329,8 @@ class NMSWithMask(PrimitiveWithInfer):
def infer_shape(self, bboxes_shape): def infer_shape(self, bboxes_shape):
cls_name = self.name cls_name = self.name
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) validator.check_integer("bboxes.shape[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
num = bboxes_shape[0] num = bboxes_shape[0]
return (bboxes_shape, (num,), (num,)) return (bboxes_shape, (num,), (num,))

View File

@ -78,7 +78,7 @@ class Flatten(PrimitiveWithInfer):
>>> input_tensor = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32) >>> input_tensor = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
>>> flatten = P.Flatten() >>> flatten = P.Flatten()
>>> output = flatten(input_tensor) >>> output = flatten(input_tensor)
>>> assert output.shape() == (1, 24) >>> assert output.shape == (1, 24)
""" """
@prim_attr_register @prim_attr_register
@ -840,7 +840,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
>>> weight = Tensor(np.ones([1, 32, 3, 3]), mindspore.float32) >>> weight = Tensor(np.ones([1, 32, 3, 3]), mindspore.float32)
>>> depthwise_conv2d = P.DepthwiseConv2dNative(channel_multiplier = 3, kernel_size = (3, 3)) >>> depthwise_conv2d = P.DepthwiseConv2dNative(channel_multiplier = 3, kernel_size = (3, 3))
>>> output = depthwise_conv2d(input, weight) >>> output = depthwise_conv2d(input, weight)
>>> assert output.shape() == (10, 96, 30, 30) >>> assert output.shape == (10, 96, 30, 30)
""" """
@prim_attr_register @prim_attr_register
@ -2057,7 +2057,7 @@ class DropoutDoMask(PrimitiveWithInfer):
>>> dropout_do_mask = P.DropoutDoMask() >>> dropout_do_mask = P.DropoutDoMask()
>>> mask = dropout_gen_mask(shape, keep_prob) >>> mask = dropout_gen_mask(shape, keep_prob)
>>> output = dropout_do_mask(x, mask, keep_prob) >>> output = dropout_do_mask(x, mask, keep_prob)
>>> assert output.shape() == (20, 16, 50) >>> assert output.shape == (20, 16, 50)
""" """
@prim_attr_register @prim_attr_register
@ -2114,7 +2114,7 @@ class ResizeBilinear(PrimitiveWithInfer):
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.int32) >>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.int32)
>>> resize_bilinear = P.ResizeBilinear((5, 5)) >>> resize_bilinear = P.ResizeBilinear((5, 5))
>>> result = resize_bilinear(tensor) >>> result = resize_bilinear(tensor)
>>> assert result.shape() == (5, 5) >>> assert result.shape == (5, 5)
""" """
@prim_attr_register @prim_attr_register

View File

@ -157,8 +157,8 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
data = Tensor(data) data = Tensor(data)
if not isinstance(data, Tensor): if not isinstance(data, Tensor):
raise ValueError("elements in tensors must be Tensor") raise ValueError("elements in tensors must be Tensor")
shape_ = data.shape() shape_ = data.shape
type_ = data.dtype() type_ = data.dtype
new_shape = () new_shape = ()
batchsize_per_device = 1 batchsize_per_device = 1
for i, item in enumerate(shape_): for i, item in enumerate(shape_):

View File

@ -42,17 +42,17 @@ def _special_process_par(par, new_par):
Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor. Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor.
""" """
par_shape_len = len(par.data.shape()) par_shape_len = len(par.data.shape)
new_par_shape_len = len(new_par.data.shape()) new_par_shape_len = len(new_par.data.shape)
delta_len = new_par_shape_len - par_shape_len delta_len = new_par_shape_len - par_shape_len
delta_i = 0 delta_i = 0
for delta_i in range(delta_len): for delta_i in range(delta_len):
if new_par.data.shape()[par_shape_len + delta_i] != 1: if new_par.data.shape[par_shape_len + delta_i] != 1:
break break
if delta_i == delta_len - 1: if delta_i == delta_len - 1:
new_val = new_par.data.asnumpy() new_val = new_par.data.asnumpy()
new_val = new_val.reshape(par.data.shape()) new_val = new_val.reshape(par.data.shape)
par.set_parameter_data(Tensor(new_val, par.data.dtype())) par.set_parameter_data(Tensor(new_val, par.data.dtype))
return True return True
return False return False
@ -61,17 +61,17 @@ def _update_param(param, new_param):
"""Updates param's data from new_param's data.""" """Updates param's data from new_param's data."""
if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor): if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
if param.data.dtype() != new_param.data.dtype(): if param.data.dtype != new_param.data.dtype:
logger.error("Failed to combine the net and the parameters for param %s.", param.name) logger.error("Failed to combine the net and the parameters for param %s.", param.name)
msg = ("Net parameters {} type({}) different from parameter_dict's({})" msg = ("Net parameters {} type({}) different from parameter_dict's({})"
.format(param.name, param.data.dtype(), new_param.data.dtype())) .format(param.name, param.data.dtype, new_param.data.dtype))
raise RuntimeError(msg) raise RuntimeError(msg)
if param.data.shape() != new_param.data.shape(): if param.data.shape != new_param.data.shape:
if not _special_process_par(param, new_param): if not _special_process_par(param, new_param):
logger.error("Failed to combine the net and the parameters for param %s.", param.name) logger.error("Failed to combine the net and the parameters for param %s.", param.name)
msg = ("Net parameters {} shape({}) different from parameter_dict's({})" msg = ("Net parameters {} shape({}) different from parameter_dict's({})"
.format(param.name, param.data.shape(), new_param.data.shape())) .format(param.name, param.data.shape, new_param.data.shape))
raise RuntimeError(msg) raise RuntimeError(msg)
return return
@ -79,12 +79,12 @@ def _update_param(param, new_param):
return return
if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor): if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
if param.data.shape() != (1,) and param.data.shape() != (): if param.data.shape != (1,) and param.data.shape != ():
logger.error("Failed to combine the net and the parameters for param %s.", param.name) logger.error("Failed to combine the net and the parameters for param %s.", param.name)
msg = ("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)." msg = ("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)."
.format(param.name, param.data.shape())) .format(param.name, param.data.shape))
raise RuntimeError(msg) raise RuntimeError(msg)
param.set_parameter_data(initializer(new_param.data, param.data.shape(), param.data.dtype())) param.set_parameter_data(initializer(new_param.data, param.data.shape, param.data.dtype))
elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor): elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor):
logger.error("Failed to combine the net and the parameters for param %s.", param.name) logger.error("Failed to combine the net and the parameters for param %s.", param.name)
@ -120,12 +120,12 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
param["data"].init_data() param["data"].init_data()
param_data = param["data"].asnumpy().reshape(-1) param_data = param["data"].asnumpy().reshape(-1)
param_tensor.tensor_content = param_data.tostring() param_tensor.tensor_content = param_data.tostring()
param_tensor.tensor_type = str(param["data"].dtype()) param_tensor.tensor_type = str(param["data"].dtype)
if param['data'].shape() == (): if param['data'].shape == ():
param_tensor.dims.append(0) param_tensor.dims.append(0)
else: else:
for dim in param['data'].shape(): for dim in param['data'].shape:
param_tensor.dims.append(dim) param_tensor.dims.append(dim)
with open(ckpoint_file_name, "wb") as f: with open(ckpoint_file_name, "wb") as f:

View File

@ -73,7 +73,7 @@ class FusedLayerNorm(Cell):
Examples: Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape()[1:] >>> shape1 = x.shape[1:]
>>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x) >>> m(x)
""" """

View File

@ -267,21 +267,21 @@ class MobileNetV2(nn.Cell):
if isinstance(m, (nn.Conv2d, DepthwiseConv)): if isinstance(m, (nn.Conv2d, DepthwiseConv)):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
m.weight.data.shape()).astype("float32"))) m.weight.data.shape).astype("float32")))
if m.bias is not None: if m.bias is not None:
m.bias.set_parameter_data( m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data( m.gamma.set_parameter_data(
Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_parameter_data( m.beta.set_parameter_data(
Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
elif isinstance(m, nn.Dense): elif isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal( m.weight.set_parameter_data(Tensor(np.random.normal(
0, 0.01, m.weight.data.shape()).astype("float32"))) 0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None: if m.bias is not None:
m.bias.set_parameter_data( m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
def mobilenet_v2(**kwargs): def mobilenet_v2(**kwargs):

View File

@ -322,21 +322,21 @@ class MobileNetV3(nn.Cell):
if isinstance(m, (nn.Conv2d)): if isinstance(m, (nn.Conv2d)):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
m.weight.data.shape()).astype("float32"))) m.weight.data.shape).astype("float32")))
if m.bias is not None: if m.bias is not None:
m.bias.set_parameter_data( m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data( m.gamma.set_parameter_data(
Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_parameter_data( m.beta.set_parameter_data(
Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
elif isinstance(m, nn.Dense): elif isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal( m.weight.set_parameter_data(Tensor(np.random.normal(
0, 0.01, m.weight.data.shape()).astype("float32"))) 0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None: if m.bias is not None:
m.bias.set_parameter_data( m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
def mobilenet_v3(model_name, **kwargs): def mobilenet_v3(model_name, **kwargs):

View File

@ -66,12 +66,12 @@ if __name__ == '__main__':
for _, cell in net.cells_and_names(): for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d): if isinstance(cell, nn.Conv2d):
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
cell.weight.default_input.shape(), cell.weight.default_input.shape,
cell.weight.default_input.dtype()).to_tensor() cell.weight.default_input.dtype).to_tensor()
if isinstance(cell, nn.Dense): if isinstance(cell, nn.Dense):
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.default_input.shape(), cell.weight.default_input.shape,
cell.weight.default_input.dtype()).to_tensor() cell.weight.default_input.dtype).to_tensor()
if not config.label_smooth: if not config.label_smooth:
config.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

View File

@ -23,9 +23,9 @@ def init_net_param(network, initialize_mode='TruncatedNormal'):
for p in params: for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
if initialize_mode == 'TruncatedNormal': if initialize_mode == 'TruncatedNormal':
p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape(), p.data.dtype())) p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape, p.data.dtype))
else: else:
p.set_parameter_data(initialize_mode, p.data.shape(), p.data.dtype()) p.set_parameter_data(initialize_mode, p.data.shape, p.data.dtype)
def load_backbone_params(network, param_dict): def load_backbone_params(network, param_dict):

View File

@ -78,15 +78,15 @@ class GNNFeatureTransform(nn.Cell):
self.has_bias = check_bool(has_bias) self.has_bias = check_bool(has_bias)
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape()[1] != in_channels: weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error") raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
if self.has_bias: if self.has_bias:
if isinstance(bias_init, Tensor): if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error") raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

View File

@ -51,4 +51,4 @@ def test_AllGather():
diff = output.asnumpy() - expect diff = output.asnumpy() - expect
error = np.ones(shape=expect.shape) * 1.0e-5 error = np.ones(shape=expect.shape) * 1.0e-5
assert np.all(diff < error) assert np.all(diff < error)
assert output.shape() == expect.shape assert output.shape == expect.shape

View File

@ -62,19 +62,19 @@ def test_AllReduce():
diff0 = output[0].asnumpy() - expect0 diff0 = output[0].asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output[0].shape() == expect0.shape assert output[0].shape == expect0.shape
expect1 = expect0 expect1 = expect0
diff1 = output[1].asnumpy() - expect1 diff1 = output[1].asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output[1].shape() == expect1.shape assert output[1].shape == expect1.shape
expect2 = expect1 expect2 = expect1
diff2 = output[2].asnumpy() - expect2 diff2 = output[2].asnumpy() - expect2
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output[2].shape() == expect2.shape assert output[2].shape == expect2.shape
class Net2(nn.Cell): class Net2(nn.Cell):
@ -108,16 +108,16 @@ def test_AllReduce2():
diff0 = abs(output[0].asnumpy() - expect0) diff0 = abs(output[0].asnumpy() - expect0)
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output[0].shape() == expect0.shape assert output[0].shape == expect0.shape
expect1 = expect0 * size expect1 = expect0 * size
diff1 = abs(output[1].asnumpy() - expect1) diff1 = abs(output[1].asnumpy() - expect1)
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output[1].shape() == expect1.shape assert output[1].shape == expect1.shape
expect2 = expect1 * size expect2 = expect1 * size
diff2 = abs(output[2].asnumpy() - expect2) diff2 = abs(output[2].asnumpy() - expect2)
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output[2].shape() == expect2.shape assert output[2].shape == expect2.shape

View File

@ -61,16 +61,16 @@ def test_ReduceScatter():
diff0 = output[0].asnumpy() - expect0 diff0 = output[0].asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output[0].shape() == expect0.shape assert output[0].shape == expect0.shape
expect1 = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * size expect1 = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * size
diff1 = output[1].asnumpy() - expect1 diff1 = output[1].asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output[1].shape() == expect1.shape assert output[1].shape == expect1.shape
expect2 = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * 1 expect2 = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * 1
diff2 = output[2].asnumpy() - expect2 diff2 = output[2].asnumpy() - expect2
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output[2].shape() == expect2.shape assert output[2].shape == expect2.shape

View File

@ -73,7 +73,7 @@ class FusedLayerNorm(Cell):
Examples: Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape()[1:] >>> shape1 = x.shape[1:]
>>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x) >>> m(x)
""" """

View File

@ -75,93 +75,93 @@ def test_tensor_auto_cast():
t_fp64 = Tensor(np.ones([2, 1, 2, 2]), mstype.float64) t_fp64 = Tensor(np.ones([2, 1, 2, 2]), mstype.float64)
net = TensorAutoCast() net = TensorAutoCast()
rs = net(t_uint8, t_int8) rs = net(t_uint8, t_int8)
assert rs.dtype() == mstype.int16 assert rs.dtype == mstype.int16
rs = net(t_uint8, t_int16) rs = net(t_uint8, t_int16)
assert rs.dtype() == mstype.int16 assert rs.dtype == mstype.int16
rs = net(t_uint8, t_int32) rs = net(t_uint8, t_int32)
assert rs.dtype() == mstype.int32 assert rs.dtype == mstype.int32
rs = net(t_uint8, t_int64) rs = net(t_uint8, t_int64)
assert rs.dtype() == mstype.int64 assert rs.dtype == mstype.int64
rs = net(t_int8, t_int16) rs = net(t_int8, t_int16)
assert rs.dtype() == mstype.int16 assert rs.dtype == mstype.int16
rs = net(t_int8, t_int32) rs = net(t_int8, t_int32)
assert rs.dtype() == mstype.int32 assert rs.dtype == mstype.int32
rs = net(t_int8, t_int64) rs = net(t_int8, t_int64)
assert rs.dtype() == mstype.int64 assert rs.dtype == mstype.int64
rs = net(t_int16, t_int32) rs = net(t_int16, t_int32)
assert rs.dtype() == mstype.int32 assert rs.dtype == mstype.int32
rs = net(t_int16, t_int64) rs = net(t_int16, t_int64)
assert rs.dtype() == mstype.int64 assert rs.dtype == mstype.int64
rs = net(t_int32, t_int64) rs = net(t_int32, t_int64)
assert rs.dtype() == mstype.int64 assert rs.dtype == mstype.int64
rs = net(t_fp16, t_fp32) rs = net(t_fp16, t_fp32)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = net(t_fp16, t_fp64) rs = net(t_fp16, t_fp64)
assert rs.dtype() == mstype.float64 assert rs.dtype == mstype.float64
rs = net(t_fp32, t_fp64) rs = net(t_fp32, t_fp64)
assert rs.dtype() == mstype.float64 assert rs.dtype == mstype.float64
rs = net(t_uint8, t_fp16) rs = net(t_uint8, t_fp16)
assert rs.dtype() == mstype.float16 assert rs.dtype == mstype.float16
rs = net(t_uint8, t_fp32) rs = net(t_uint8, t_fp32)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = net(t_uint8, t_fp64) rs = net(t_uint8, t_fp64)
assert rs.dtype() == mstype.float64 assert rs.dtype == mstype.float64
rs = net(t_int8, t_fp64) rs = net(t_int8, t_fp64)
assert rs.dtype() == mstype.float64 assert rs.dtype == mstype.float64
rs = net(t_int16, t_fp64) rs = net(t_int16, t_fp64)
assert rs.dtype() == mstype.float64 assert rs.dtype == mstype.float64
rs = net(t_int32, t_fp64) rs = net(t_int32, t_fp64)
assert rs.dtype() == mstype.float64 assert rs.dtype == mstype.float64
rs = net(t_int64, t_fp64) rs = net(t_int64, t_fp64)
assert rs.dtype() == mstype.float64 assert rs.dtype == mstype.float64
rs = net(t_fp16, t_int8) rs = net(t_fp16, t_int8)
assert rs.dtype() == mstype.float16 assert rs.dtype == mstype.float16
rs = net(t_fp16, t_uint8) rs = net(t_fp16, t_uint8)
assert rs.dtype() == mstype.float16 assert rs.dtype == mstype.float16
rs = net(t_fp16, t_int16) rs = net(t_fp16, t_int16)
assert rs.dtype() == mstype.float16 assert rs.dtype == mstype.float16
rs = net(t_fp16, t_int32) rs = net(t_fp16, t_int32)
assert rs.dtype() == mstype.float16 assert rs.dtype == mstype.float16
rs = net(t_fp16, t_int64) rs = net(t_fp16, t_int64)
assert rs.dtype() == mstype.float16 assert rs.dtype == mstype.float16
tint = TensorIntAutoCast() tint = TensorIntAutoCast()
rs = tint(t_uint8) rs = tint(t_uint8)
assert rs.dtype() == mstype.uint8 assert rs.dtype == mstype.uint8
rs = tint(t_int8) rs = tint(t_int8)
assert rs.dtype() == mstype.int8 assert rs.dtype == mstype.int8
rs = tint(t_int16) rs = tint(t_int16)
assert rs.dtype() == mstype.int16 assert rs.dtype == mstype.int16
rs = tint(t_int32) rs = tint(t_int32)
assert rs.dtype() == mstype.int32 assert rs.dtype == mstype.int32
rs = tint(t_int64) rs = tint(t_int64)
assert rs.dtype() == mstype.int64 assert rs.dtype == mstype.int64
rs = tint(t_fp16) rs = tint(t_fp16)
assert rs.dtype() == mstype.float16 assert rs.dtype == mstype.float16
rs = tint(t_fp32) rs = tint(t_fp32)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = tint(t_fp64) rs = tint(t_fp64)
assert rs.dtype() == mstype.float64 assert rs.dtype == mstype.float64
tfp = TensorFPAutoCast() tfp = TensorFPAutoCast()
rs = tfp(t_uint8) rs = tfp(t_uint8)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = tfp(t_int8) rs = tfp(t_int8)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = tfp(t_int16) rs = tfp(t_int16)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = tfp(t_int32) rs = tfp(t_int32)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = tfp(t_int64) rs = tfp(t_int64)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = tfp(t_fp16) rs = tfp(t_fp16)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = tfp(t_fp32) rs = tfp(t_fp32)
assert rs.dtype() == mstype.float32 assert rs.dtype == mstype.float32
rs = tfp(t_fp64) rs = tfp(t_fp64)
assert rs.dtype() == mstype.float64 assert rs.dtype == mstype.float64
t_uint16 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint16) t_uint16 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint16)
t_uint32 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint32) t_uint32 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint32)

View File

@ -35,7 +35,7 @@ class Net(nn.Cell):
self.biasAdd = P.BiasAdd() self.biasAdd = P.BiasAdd()
if isinstance(bias_init, Tensor): if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != output_channels: if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
raise ValueError("bias_init shape error") raise ValueError("bias_init shape error")
self.bias = Parameter(initializer( self.bias = Parameter(initializer(

View File

@ -64,7 +64,7 @@ def convert_type(shapes, types):
for np_shape, np_type in zip(shapes, types): for np_shape, np_type in zip(shapes, types):
input_np = np.zeros(np_shape, np_type) input_np = np.zeros(np_shape, np_type)
tensor = Tensor(input_np) tensor = Tensor(input_np)
ms_types.append(tensor.dtype()) ms_types.append(tensor.dtype)
return ms_types return ms_types

View File

@ -34,7 +34,7 @@ class NetArgmax(nn.Cell):
x = Tensor(np.array([[1., 20., 5.], x = Tensor(np.array([[1., 20., 5.],
[67., 8., 9.], [67., 8., 9.],
[130., 24., 15.]]).astype(np.float32)) [130., 24., 15.]]).astype(np.float32))
self.x = Parameter(initializer(x, x.shape()), name='x') self.x = Parameter(initializer(x, x.shape), name='x')
def construct(self): def construct(self):
return self.argmax(self.x) return self.argmax(self.x)

View File

@ -32,8 +32,8 @@ class NetEqualCount(nn.Cell):
self.equalcount = P.EqualCount() self.equalcount = P.EqualCount()
x = Tensor(np.array([1, 20, 5]).astype(np.int32)) x = Tensor(np.array([1, 20, 5]).astype(np.int32))
y = Tensor(np.array([2, 20, 5]).astype(np.int32)) y = Tensor(np.array([2, 20, 5]).astype(np.int32))
self.x = Parameter(initializer(x, x.shape()), name='x') self.x = Parameter(initializer(x, x.shape), name='x')
self.y = Parameter(initializer(y, y.shape()), name='y') self.y = Parameter(initializer(y, y.shape), name='y')
def construct(self): def construct(self):
return self.equalcount(self.x, self.y) return self.equalcount(self.x, self.y)

View File

@ -33,7 +33,7 @@ class NetSoftmax(nn.Cell):
x = Tensor(np.array([[0.1, 0.3, 0.6], x = Tensor(np.array([[0.1, 0.3, 0.6],
[0.2, -0.6, 0.8], [0.2, -0.6, 0.8],
[0.6, 1, 0.4]]).astype(np.float32)) [0.6, 1, 0.4]]).astype(np.float32))
self.x = Parameter(initializer(x, x.shape()), name='x') self.x = Parameter(initializer(x, x.shape), name='x')
def construct(self): def construct(self):
return self.softmax(self.x) return self.softmax(self.x)

View File

@ -32,9 +32,9 @@ class NetSoftmaxWithCrossEntropy(nn.Cell):
logits = Tensor(np.array([[1, 1, 10], logits = Tensor(np.array([[1, 1, 10],
[1, 10, 1], [1, 10, 1],
[10, 1, 1]]).astype(np.float32)) [10, 1, 1]]).astype(np.float32))
self.logits = Parameter(initializer(logits, logits.shape()), name='logits') self.logits = Parameter(initializer(logits, logits.shape), name='logits')
labels = Tensor(np.array([2, 1, 0]).astype(np.int32)) labels = Tensor(np.array([2, 1, 0]).astype(np.int32))
self.labels = Parameter(initializer(labels, labels.shape()), name='labels') self.labels = Parameter(initializer(labels, labels.shape), name='labels')
self.SoftmaxWithCrossEntropy = P.SparseSoftmaxCrossEntropyWithLogits(True) self.SoftmaxWithCrossEntropy = P.SparseSoftmaxCrossEntropyWithLogits(True)
def construct(self): def construct(self):

View File

@ -50,4 +50,4 @@ def test_correction_mul():
diff = output.asnumpy() - expect diff = output.asnumpy() - expect
assert np.all(diff < error) assert np.all(diff < error)
assert np.all(diff > error * -1) assert np.all(diff > error * -1)
assert output.shape() == expect.shape assert output.shape == expect.shape

View File

@ -65,19 +65,19 @@ def test_equal():
equal = NetEqual() equal = NetEqual()
output0 = equal(x0, y0) output0 = equal(x0, y0)
assert np.all(output0.asnumpy() == expect0) assert np.all(output0.asnumpy() == expect0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = equal(x1, y1) output1 = equal(x1, y1)
assert np.all(output1.asnumpy() == expect1) assert np.all(output1.asnumpy() == expect1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
equal = NetEqual() equal = NetEqual()
output0 = equal(x0, y0) output0 = equal(x0, y0)
assert np.all(output0.asnumpy() == expect0) assert np.all(output0.asnumpy() == expect0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = equal(x1, y1) output1 = equal(x1, y1)
assert np.all(output1.asnumpy() == expect1) assert np.all(output1.asnumpy() == expect1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
@pytest.mark.level0 @pytest.mark.level0
@ -92,13 +92,13 @@ def test_notequal():
notequal = NetNotEqual() notequal = NetNotEqual()
output0 = notequal(x0, y0) output0 = notequal(x0, y0)
assert np.all(output0.asnumpy() == expect0) assert np.all(output0.asnumpy() == expect0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
notequal = NetNotEqual() notequal = NetNotEqual()
output0 = notequal(x0, y0) output0 = notequal(x0, y0)
assert np.all(output0.asnumpy() == expect0) assert np.all(output0.asnumpy() == expect0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
@pytest.mark.level0 @pytest.mark.level0
@ -113,10 +113,10 @@ def test_greaterqual():
gequal = NetGreaterEqual() gequal = NetGreaterEqual()
output0 = gequal(x0, y0) output0 = gequal(x0, y0)
assert np.all(output0.asnumpy() == expect0) assert np.all(output0.asnumpy() == expect0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gequal = NetGreaterEqual() gequal = NetGreaterEqual()
output0 = gequal(x0, y0) output0 = gequal(x0, y0)
assert np.all(output0.asnumpy() == expect0) assert np.all(output0.asnumpy() == expect0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape

View File

@ -49,19 +49,19 @@ def test_exp():
output0 = exp(x0) output0 = exp(x0)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = exp(x1) output1 = exp(x1)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
exp = NetExp() exp = NetExp()
output0 = exp(x0) output0 = exp(x0)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = exp(x1) output1 = exp(x1)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape

View File

@ -50,10 +50,10 @@ def test_log():
output1 = log(x1) output1 = log(x1)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
log = NetLog() log = NetLog()
@ -61,7 +61,7 @@ def test_log():
output1 = log(x1) output1 = log(x1)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape

View File

@ -64,35 +64,35 @@ def test_mul():
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = mul(x1, y1) output1 = mul(x1, y1)
expect1 = np.multiply(x1_np, y1_np) expect1 = np.multiply(x1_np, y1_np)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
output2 = mul(x2, y2) output2 = mul(x2, y2)
expect2 = np.multiply(x2_np, y2_np) expect2 = np.multiply(x2_np, y2_np)
diff2 = output2.asnumpy() - expect2 diff2 = output2.asnumpy() - expect2
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output2.shape() == expect2.shape assert output2.shape == expect2.shape
output3 = mul(x3, y3) output3 = mul(x3, y3)
expect3 = np.multiply(x3_np, y3_np) expect3 = np.multiply(x3_np, y3_np)
diff3 = output3.asnumpy() - expect3 diff3 = output3.asnumpy() - expect3
error3 = np.ones(shape=expect3.shape) * 1.0e-5 error3 = np.ones(shape=expect3.shape) * 1.0e-5
assert np.all(diff3 < error3) assert np.all(diff3 < error3)
assert output3.shape() == expect3.shape assert output3.shape == expect3.shape
output4 = mul(x4, y4) output4 = mul(x4, y4)
expect4 = np.multiply(x4_np, y4_np) expect4 = np.multiply(x4_np, y4_np)
diff4 = output4.asnumpy() - expect4 diff4 = output4.asnumpy() - expect4
error4 = np.ones(shape=expect4.shape) * 1.0e-5 error4 = np.ones(shape=expect4.shape) * 1.0e-5
assert np.all(diff4 < error4) assert np.all(diff4 < error4)
assert output4.shape() == expect4.shape assert output4.shape == expect4.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
mul = NetMul() mul = NetMul()
@ -101,32 +101,32 @@ def test_mul():
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = mul(x1, y1) output1 = mul(x1, y1)
expect1 = np.multiply(x1_np, y1_np) expect1 = np.multiply(x1_np, y1_np)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
output2 = mul(x2, y2) output2 = mul(x2, y2)
expect2 = np.multiply(x2_np, y2_np) expect2 = np.multiply(x2_np, y2_np)
diff2 = output2.asnumpy() - expect2 diff2 = output2.asnumpy() - expect2
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output2.shape() == expect2.shape assert output2.shape == expect2.shape
output3 = mul(x3, y3) output3 = mul(x3, y3)
expect3 = np.multiply(x3_np, y3_np) expect3 = np.multiply(x3_np, y3_np)
diff3 = output3.asnumpy() - expect3 diff3 = output3.asnumpy() - expect3
error3 = np.ones(shape=expect3.shape) * 1.0e-5 error3 = np.ones(shape=expect3.shape) * 1.0e-5
assert np.all(diff3 < error3) assert np.all(diff3 < error3)
assert output3.shape() == expect3.shape assert output3.shape == expect3.shape
output4 = mul(x4, y4) output4 = mul(x4, y4)
expect4 = np.multiply(x4_np, y4_np) expect4 = np.multiply(x4_np, y4_np)
diff4 = output4.asnumpy() - expect4 diff4 = output4.asnumpy() - expect4
error4 = np.ones(shape=expect4.shape) * 1.0e-5 error4 = np.ones(shape=expect4.shape) * 1.0e-5
assert np.all(diff4 < error4) assert np.all(diff4 < error4)
assert output4.shape() == expect4.shape assert output4.shape == expect4.shape

View File

@ -49,19 +49,19 @@ def test_neg():
output0 = neg(x0) output0 = neg(x0)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = neg(x1) output1 = neg(x1)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
neg = NetNeg() neg = NetNeg()
output0 = neg(x0) output0 = neg(x0)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = neg(x1) output1 = neg(x1)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape

View File

@ -64,35 +64,35 @@ def test_real_div():
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = real_div(x1, y1) output1 = real_div(x1, y1)
expect1 = np.divide(x1_np, y1_np) expect1 = np.divide(x1_np, y1_np)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
output2 = real_div(x2, y2) output2 = real_div(x2, y2)
expect2 = np.divide(x2_np, y2_np) expect2 = np.divide(x2_np, y2_np)
diff2 = output2.asnumpy() - expect2 diff2 = output2.asnumpy() - expect2
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output2.shape() == expect2.shape assert output2.shape == expect2.shape
output3 = real_div(x3, y3) output3 = real_div(x3, y3)
expect3 = np.divide(x3_np, y3_np) expect3 = np.divide(x3_np, y3_np)
diff3 = output3.asnumpy() - expect3 diff3 = output3.asnumpy() - expect3
error3 = np.ones(shape=expect3.shape) * 1.0e-5 error3 = np.ones(shape=expect3.shape) * 1.0e-5
assert np.all(diff3 < error3) assert np.all(diff3 < error3)
assert output3.shape() == expect3.shape assert output3.shape == expect3.shape
output4 = real_div(x4, y4) output4 = real_div(x4, y4)
expect4 = np.divide(x4_np, y4_np) expect4 = np.divide(x4_np, y4_np)
diff4 = output4.asnumpy() - expect4 diff4 = output4.asnumpy() - expect4
error4 = np.ones(shape=expect4.shape) * 1.0e-5 error4 = np.ones(shape=expect4.shape) * 1.0e-5
assert np.all(diff4 < error4) assert np.all(diff4 < error4)
assert output4.shape() == expect4.shape assert output4.shape == expect4.shape
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
real_div = NetRealDiv() real_div = NetRealDiv()
@ -101,32 +101,32 @@ def test_real_div():
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = real_div(x1, y1) output1 = real_div(x1, y1)
expect1 = np.divide(x1_np, y1_np) expect1 = np.divide(x1_np, y1_np)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
output2 = real_div(x2, y2) output2 = real_div(x2, y2)
expect2 = np.divide(x2_np, y2_np) expect2 = np.divide(x2_np, y2_np)
diff2 = output2.asnumpy() - expect2 diff2 = output2.asnumpy() - expect2
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output2.shape() == expect2.shape assert output2.shape == expect2.shape
output3 = real_div(x3, y3) output3 = real_div(x3, y3)
expect3 = np.divide(x3_np, y3_np) expect3 = np.divide(x3_np, y3_np)
diff3 = output3.asnumpy() - expect3 diff3 = output3.asnumpy() - expect3
error3 = np.ones(shape=expect3.shape) * 1.0e-5 error3 = np.ones(shape=expect3.shape) * 1.0e-5
assert np.all(diff3 < error3) assert np.all(diff3 < error3)
assert output3.shape() == expect3.shape assert output3.shape == expect3.shape
output4 = real_div(x4, y4) output4 = real_div(x4, y4)
expect4 = np.divide(x4_np, y4_np) expect4 = np.divide(x4_np, y4_np)
diff4 = output4.asnumpy() - expect4 diff4 = output4.asnumpy() - expect4
error4 = np.ones(shape=expect4.shape) * 1.0e-5 error4 = np.ones(shape=expect4.shape) * 1.0e-5
assert np.all(diff4 < error4) assert np.all(diff4 < error4)
assert output4.shape() == expect4.shape assert output4.shape == expect4.shape

View File

@ -49,19 +49,19 @@ def test_Reciprocal():
output0 = reciprocal(x0) output0 = reciprocal(x0)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = reciprocal(x1) output1 = reciprocal(x1)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
reciprocal = NetReciprocal() reciprocal = NetReciprocal()
output0 = reciprocal(x0) output0 = reciprocal(x0)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = reciprocal(x1) output1 = reciprocal(x1)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape

View File

@ -128,43 +128,43 @@ def test_ReduceMax():
diff0 = abs(output[0].asnumpy() - expect0) diff0 = abs(output[0].asnumpy() - expect0)
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output[0].shape() == expect0.shape assert output[0].shape == expect0.shape
expect1 = np.max(x1, axis=axis1, keepdims=keep_dims1) expect1 = np.max(x1, axis=axis1, keepdims=keep_dims1)
diff1 = abs(output[1].asnumpy() - expect1) diff1 = abs(output[1].asnumpy() - expect1)
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output[1].shape() == expect1.shape assert output[1].shape == expect1.shape
expect2 = np.max(x2, axis=axis2, keepdims=keep_dims2) expect2 = np.max(x2, axis=axis2, keepdims=keep_dims2)
diff2 = abs(output[2].asnumpy() - expect2) diff2 = abs(output[2].asnumpy() - expect2)
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output[2].shape() == expect2.shape assert output[2].shape == expect2.shape
expect3 = np.max(x3, axis=axis3, keepdims=keep_dims3) expect3 = np.max(x3, axis=axis3, keepdims=keep_dims3)
diff3 = abs(output[3].asnumpy() - expect3) diff3 = abs(output[3].asnumpy() - expect3)
error3 = np.ones(shape=expect3.shape) * 1.0e-5 error3 = np.ones(shape=expect3.shape) * 1.0e-5
assert np.all(diff3 < error3) assert np.all(diff3 < error3)
assert output[3].shape() == expect3.shape assert output[3].shape == expect3.shape
expect4 = np.max(x4, axis=np_axis4, keepdims=keep_dims4) expect4 = np.max(x4, axis=np_axis4, keepdims=keep_dims4)
diff4 = abs(output[4].asnumpy() - expect4) diff4 = abs(output[4].asnumpy() - expect4)
error4 = np.ones(shape=expect4.shape) * 1.0e-5 error4 = np.ones(shape=expect4.shape) * 1.0e-5
assert np.all(diff4 < error4) assert np.all(diff4 < error4)
assert output[4].shape() == expect4.shape assert output[4].shape == expect4.shape
expect5 = np.max(x5, axis=np_axis5, keepdims=keep_dims5) expect5 = np.max(x5, axis=np_axis5, keepdims=keep_dims5)
diff5 = abs(output[5].asnumpy() - expect5) diff5 = abs(output[5].asnumpy() - expect5)
error5 = np.ones(shape=expect5.shape) * 1.0e-5 error5 = np.ones(shape=expect5.shape) * 1.0e-5
assert np.all(diff5 < error5) assert np.all(diff5 < error5)
assert output[5].shape() == expect5.shape assert output[5].shape == expect5.shape
expect6 = np.max(x6, axis=axis6, keepdims=keep_dims6) expect6 = np.max(x6, axis=axis6, keepdims=keep_dims6)
diff6 = abs(output[6].asnumpy() - expect6) diff6 = abs(output[6].asnumpy() - expect6)
error6 = np.ones(shape=expect6.shape) * 1.0e-5 error6 = np.ones(shape=expect6.shape) * 1.0e-5
assert np.all(diff6 < error6) assert np.all(diff6 < error6)
assert output[6].shape() == expect6.shape assert output[6].shape == expect6.shape
expect7 = np.max(x7, axis=axis7, keepdims=keep_dims7) expect7 = np.max(x7, axis=axis7, keepdims=keep_dims7)
diff7 = abs(output[7].asnumpy() - expect7) diff7 = abs(output[7].asnumpy() - expect7)

View File

@ -180,88 +180,88 @@ def test_ReduceMean():
diff0 = abs(output[0].asnumpy() - expect0) diff0 = abs(output[0].asnumpy() - expect0)
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output[0].shape() == expect0.shape assert output[0].shape == expect0.shape
expect1 = np.mean(x1, axis=axis1, keepdims=keep_dims1) expect1 = np.mean(x1, axis=axis1, keepdims=keep_dims1)
diff1 = abs(output[1].asnumpy() - expect1) diff1 = abs(output[1].asnumpy() - expect1)
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output[1].shape() == expect1.shape assert output[1].shape == expect1.shape
expect2 = np.mean(x2, axis=axis2, keepdims=keep_dims2) expect2 = np.mean(x2, axis=axis2, keepdims=keep_dims2)
diff2 = abs(output[2].asnumpy() - expect2) diff2 = abs(output[2].asnumpy() - expect2)
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output[2].shape() == expect2.shape assert output[2].shape == expect2.shape
expect3 = np.mean(x3, axis=axis3, keepdims=keep_dims3) expect3 = np.mean(x3, axis=axis3, keepdims=keep_dims3)
diff3 = abs(output[3].asnumpy() - expect3) diff3 = abs(output[3].asnumpy() - expect3)
error3 = np.ones(shape=expect3.shape) * 1.0e-5 error3 = np.ones(shape=expect3.shape) * 1.0e-5
assert np.all(diff3 < error3) assert np.all(diff3 < error3)
assert output[3].shape() == expect3.shape assert output[3].shape == expect3.shape
expect4 = np.mean(x4, axis=axis4, keepdims=keep_dims4) expect4 = np.mean(x4, axis=axis4, keepdims=keep_dims4)
diff4 = abs(output[4].asnumpy() - expect4) diff4 = abs(output[4].asnumpy() - expect4)
error4 = np.ones(shape=expect4.shape) * 1.0e-5 error4 = np.ones(shape=expect4.shape) * 1.0e-5
assert np.all(diff4 < error4) assert np.all(diff4 < error4)
assert output[4].shape() == expect4.shape assert output[4].shape == expect4.shape
expect5 = np.mean(x5, axis=axis5, keepdims=keep_dims5) expect5 = np.mean(x5, axis=axis5, keepdims=keep_dims5)
diff5 = abs(output[5].asnumpy() - expect5) diff5 = abs(output[5].asnumpy() - expect5)
error5 = np.ones(shape=expect5.shape) * 1.0e-5 error5 = np.ones(shape=expect5.shape) * 1.0e-5
assert np.all(diff5 < error5) assert np.all(diff5 < error5)
assert output[5].shape() == expect5.shape assert output[5].shape == expect5.shape
expect6 = np.mean(x6, axis=axis6, keepdims=keep_dims6) expect6 = np.mean(x6, axis=axis6, keepdims=keep_dims6)
diff6 = abs(output[6].asnumpy() - expect6) diff6 = abs(output[6].asnumpy() - expect6)
error6 = np.ones(shape=expect6.shape) * 1.0e-5 error6 = np.ones(shape=expect6.shape) * 1.0e-5
assert np.all(diff6 < error6) assert np.all(diff6 < error6)
assert output[6].shape() == expect6.shape assert output[6].shape == expect6.shape
expect7 = np.mean(x7, axis=axis7, keepdims=keep_dims7) expect7 = np.mean(x7, axis=axis7, keepdims=keep_dims7)
diff7 = abs(output[7].asnumpy() - expect7) diff7 = abs(output[7].asnumpy() - expect7)
error7 = np.ones(shape=expect7.shape) * 1.0e-5 error7 = np.ones(shape=expect7.shape) * 1.0e-5
assert np.all(diff7 < error7) assert np.all(diff7 < error7)
assert output[7].shape() == expect7.shape assert output[7].shape == expect7.shape
expect8 = np.mean(x8, axis=axis8, keepdims=keep_dims8) expect8 = np.mean(x8, axis=axis8, keepdims=keep_dims8)
diff8 = abs(output[8].asnumpy() - expect8) diff8 = abs(output[8].asnumpy() - expect8)
error8 = np.ones(shape=expect8.shape) * 1.0e-5 error8 = np.ones(shape=expect8.shape) * 1.0e-5
assert np.all(diff8 < error8) assert np.all(diff8 < error8)
assert output[8].shape() == expect8.shape assert output[8].shape == expect8.shape
expect9 = np.mean(x9, axis=axis9, keepdims=keep_dims9) expect9 = np.mean(x9, axis=axis9, keepdims=keep_dims9)
diff9 = abs(output[9].asnumpy() - expect9) diff9 = abs(output[9].asnumpy() - expect9)
error9 = np.ones(shape=expect9.shape) * 1.0e-5 error9 = np.ones(shape=expect9.shape) * 1.0e-5
assert np.all(diff9 < error9) assert np.all(diff9 < error9)
assert output[9].shape() == expect9.shape assert output[9].shape == expect9.shape
expect10 = np.mean(x10, axis=axis10, keepdims=keep_dims10) expect10 = np.mean(x10, axis=axis10, keepdims=keep_dims10)
diff10 = abs(output[10].asnumpy() - expect10) diff10 = abs(output[10].asnumpy() - expect10)
error10 = np.ones(shape=expect10.shape) * 1.0e-5 error10 = np.ones(shape=expect10.shape) * 1.0e-5
assert np.all(diff10 < error10) assert np.all(diff10 < error10)
assert output[10].shape() == expect10.shape assert output[10].shape == expect10.shape
expect11 = np.mean(x11, axis=axis11, keepdims=keep_dims11) expect11 = np.mean(x11, axis=axis11, keepdims=keep_dims11)
diff11 = abs(output[11].asnumpy() - expect11) diff11 = abs(output[11].asnumpy() - expect11)
error11 = np.ones(shape=expect11.shape) * 1.0e-5 error11 = np.ones(shape=expect11.shape) * 1.0e-5
assert np.all(diff11 < error11) assert np.all(diff11 < error11)
assert output[11].shape() == expect11.shape assert output[11].shape == expect11.shape
expect12 = np.mean(x12, axis=axis12, keepdims=keep_dims12) expect12 = np.mean(x12, axis=axis12, keepdims=keep_dims12)
diff12 = abs(output[12].asnumpy() - expect12) diff12 = abs(output[12].asnumpy() - expect12)
error12 = np.ones(shape=expect12.shape) * 1.0e-5 error12 = np.ones(shape=expect12.shape) * 1.0e-5
assert np.all(diff12 < error12) assert np.all(diff12 < error12)
assert output[12].shape() == expect12.shape assert output[12].shape == expect12.shape
expect13 = np.mean(x13, axis=axis13, keepdims=keep_dims13) expect13 = np.mean(x13, axis=axis13, keepdims=keep_dims13)
diff13 = abs(output[13].asnumpy() - expect13) diff13 = abs(output[13].asnumpy() - expect13)
error13 = np.ones(shape=expect13.shape) * 1.0e-5 error13 = np.ones(shape=expect13.shape) * 1.0e-5
assert np.all(diff13 < error13) assert np.all(diff13 < error13)
assert output[13].shape() == expect13.shape assert output[13].shape == expect13.shape
expect14 = np.mean(x14, axis=np_axis14, keepdims=keep_dims14) expect14 = np.mean(x14, axis=np_axis14, keepdims=keep_dims14)
diff14 = abs(output[14].asnumpy() - expect14) diff14 = abs(output[14].asnumpy() - expect14)
error14 = np.ones(shape=expect14.shape) * 1.0e-5 error14 = np.ones(shape=expect14.shape) * 1.0e-5
assert np.all(diff14 < error14) assert np.all(diff14 < error14)
assert output[14].shape() == expect14.shape assert output[14].shape == expect14.shape

View File

@ -182,88 +182,88 @@ def test_ReduceSum():
diff0 = abs(output[0].asnumpy() - expect0) diff0 = abs(output[0].asnumpy() - expect0)
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output[0].shape() == expect0.shape assert output[0].shape == expect0.shape
expect1 = np.sum(x1, axis=axis1, keepdims=keep_dims1) expect1 = np.sum(x1, axis=axis1, keepdims=keep_dims1)
diff1 = abs(output[1].asnumpy() - expect1) diff1 = abs(output[1].asnumpy() - expect1)
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output[1].shape() == expect1.shape assert output[1].shape == expect1.shape
expect2 = np.sum(x2, axis=axis2, keepdims=keep_dims2) expect2 = np.sum(x2, axis=axis2, keepdims=keep_dims2)
diff2 = abs(output[2].asnumpy() - expect2) diff2 = abs(output[2].asnumpy() - expect2)
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output[2].shape() == expect2.shape assert output[2].shape == expect2.shape
expect3 = np.sum(x3, axis=axis3, keepdims=keep_dims3) expect3 = np.sum(x3, axis=axis3, keepdims=keep_dims3)
diff3 = abs(output[3].asnumpy() - expect3) diff3 = abs(output[3].asnumpy() - expect3)
error3 = np.ones(shape=expect3.shape) * 1.0e-5 error3 = np.ones(shape=expect3.shape) * 1.0e-5
assert np.all(diff3 < error3) assert np.all(diff3 < error3)
assert output[3].shape() == expect3.shape assert output[3].shape == expect3.shape
expect4 = np.sum(x4, axis=np_axis4, keepdims=keep_dims4) expect4 = np.sum(x4, axis=np_axis4, keepdims=keep_dims4)
diff4 = abs(output[4].asnumpy() - expect4) diff4 = abs(output[4].asnumpy() - expect4)
error4 = np.ones(shape=expect4.shape) * 1.0e-5 error4 = np.ones(shape=expect4.shape) * 1.0e-5
assert np.all(diff4 < error4) assert np.all(diff4 < error4)
assert output[4].shape() == expect4.shape assert output[4].shape == expect4.shape
expect5 = np.sum(x5, axis=np_axis5, keepdims=keep_dims5) expect5 = np.sum(x5, axis=np_axis5, keepdims=keep_dims5)
diff5 = abs(output[5].asnumpy() - expect5) diff5 = abs(output[5].asnumpy() - expect5)
error5 = np.ones(shape=expect5.shape) * 1.0e-5 error5 = np.ones(shape=expect5.shape) * 1.0e-5
assert np.all(diff5 < error5) assert np.all(diff5 < error5)
assert output[5].shape() == expect5.shape assert output[5].shape == expect5.shape
expect6 = np.sum(x6, axis=axis6, keepdims=keep_dims6) expect6 = np.sum(x6, axis=axis6, keepdims=keep_dims6)
diff6 = abs(output[6].asnumpy() - expect6) diff6 = abs(output[6].asnumpy() - expect6)
error6 = np.ones(shape=expect6.shape) * 1.0e-5 error6 = np.ones(shape=expect6.shape) * 1.0e-5
assert np.all(diff6 < error6) assert np.all(diff6 < error6)
assert output[6].shape() == expect6.shape assert output[6].shape == expect6.shape
expect7 = np.sum(x7, axis=axis7, keepdims=keep_dims7) expect7 = np.sum(x7, axis=axis7, keepdims=keep_dims7)
diff7 = abs(output[7].asnumpy() - expect7) diff7 = abs(output[7].asnumpy() - expect7)
error7 = np.ones(shape=expect7.shape) * 1.0e-5 error7 = np.ones(shape=expect7.shape) * 1.0e-5
assert np.all(diff7 < error7) assert np.all(diff7 < error7)
assert output[7].shape() == expect7.shape assert output[7].shape == expect7.shape
expect8 = np.sum(x8, axis=axis8, keepdims=keep_dims8) expect8 = np.sum(x8, axis=axis8, keepdims=keep_dims8)
diff8 = abs(output[8].asnumpy() - expect8) diff8 = abs(output[8].asnumpy() - expect8)
error8 = np.ones(shape=expect8.shape) * 1.0e-5 error8 = np.ones(shape=expect8.shape) * 1.0e-5
assert np.all(diff8 < error8) assert np.all(diff8 < error8)
assert output[8].shape() == expect8.shape assert output[8].shape == expect8.shape
expect9 = np.sum(x9, axis=axis9, keepdims=keep_dims9) expect9 = np.sum(x9, axis=axis9, keepdims=keep_dims9)
diff9 = abs(output[9].asnumpy() - expect9) diff9 = abs(output[9].asnumpy() - expect9)
error9 = np.ones(shape=expect9.shape) * 1.0e-5 error9 = np.ones(shape=expect9.shape) * 1.0e-5
assert np.all(diff9 < error9) assert np.all(diff9 < error9)
assert output[9].shape() == expect9.shape assert output[9].shape == expect9.shape
expect10 = np.sum(x10, axis=axis10, keepdims=keep_dims10) expect10 = np.sum(x10, axis=axis10, keepdims=keep_dims10)
diff10 = abs(output[10].asnumpy() - expect10) diff10 = abs(output[10].asnumpy() - expect10)
error10 = np.ones(shape=expect10.shape) * 1.0e-5 error10 = np.ones(shape=expect10.shape) * 1.0e-5
assert np.all(diff10 < error10) assert np.all(diff10 < error10)
assert output[10].shape() == expect10.shape assert output[10].shape == expect10.shape
expect11 = np.sum(x11, axis=axis11, keepdims=keep_dims11) expect11 = np.sum(x11, axis=axis11, keepdims=keep_dims11)
diff11 = abs(output[11].asnumpy() - expect11) diff11 = abs(output[11].asnumpy() - expect11)
error11 = np.ones(shape=expect11.shape) * 1.0e-5 error11 = np.ones(shape=expect11.shape) * 1.0e-5
assert np.all(diff11 < error11) assert np.all(diff11 < error11)
assert output[11].shape() == expect11.shape assert output[11].shape == expect11.shape
expect12 = np.sum(x12, axis=axis12, keepdims=keep_dims12) expect12 = np.sum(x12, axis=axis12, keepdims=keep_dims12)
diff12 = abs(output[12].asnumpy() - expect12) diff12 = abs(output[12].asnumpy() - expect12)
error12 = np.ones(shape=expect12.shape) * 1.0e-5 error12 = np.ones(shape=expect12.shape) * 1.0e-5
assert np.all(diff12 < error12) assert np.all(diff12 < error12)
assert output[12].shape() == expect12.shape assert output[12].shape == expect12.shape
expect13 = np.sum(x13, axis=axis13, keepdims=keep_dims13) expect13 = np.sum(x13, axis=axis13, keepdims=keep_dims13)
diff13 = abs(output[13].asnumpy() - expect13) diff13 = abs(output[13].asnumpy() - expect13)
error13 = np.ones(shape=expect13.shape) * 1.0e-5 error13 = np.ones(shape=expect13.shape) * 1.0e-5
assert np.all(diff13 < error13) assert np.all(diff13 < error13)
assert output[13].shape() == expect13.shape assert output[13].shape == expect13.shape
expect14 = np.sum(x14, axis=np_axis14, keepdims=keep_dims14) expect14 = np.sum(x14, axis=np_axis14, keepdims=keep_dims14)
diff14 = abs(output[14].asnumpy() - expect14) diff14 = abs(output[14].asnumpy() - expect14)
error14 = np.ones(shape=expect14.shape) * 1.0e-5 error14 = np.ones(shape=expect14.shape) * 1.0e-5
assert np.all(diff14 < error14) assert np.all(diff14 < error14)
assert output[14].shape() == expect14.shape assert output[14].shape == expect14.shape

View File

@ -76,19 +76,19 @@ def test_Sub():
output4 = sub(x4, y4) output4 = sub(x4, y4)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
diff2 = output2.asnumpy() - expect2 diff2 = output2.asnumpy() - expect2
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output2.shape() == expect2.shape assert output2.shape == expect2.shape
diff3 = output3.asnumpy() - expect3 diff3 = output3.asnumpy() - expect3
assert np.all(diff3 < error3) assert np.all(diff3 < error3)
assert output3.shape() == expect3.shape assert output3.shape == expect3.shape
diff4 = output4.asnumpy() - expect4 diff4 = output4.asnumpy() - expect4
assert np.all(diff4 < error4) assert np.all(diff4 < error4)
assert output4.shape() == expect4.shape assert output4.shape == expect4.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sub = Net() sub = Net()
@ -99,16 +99,16 @@ def test_Sub():
output4 = sub(x4, y4) output4 = sub(x4, y4)
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
diff2 = output2.asnumpy() - expect2 diff2 = output2.asnumpy() - expect2
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output2.shape() == expect2.shape assert output2.shape == expect2.shape
diff3 = output3.asnumpy() - expect3 diff3 = output3.asnumpy() - expect3
assert np.all(diff3 < error3) assert np.all(diff3 < error3)
assert output3.shape() == expect3.shape assert output3.shape == expect3.shape
diff4 = output4.asnumpy() - expect4 diff4 = output4.asnumpy() - expect4
assert np.all(diff4 < error4) assert np.all(diff4 < error4)
assert output4.shape() == expect4.shape assert output4.shape == expect4.shape

View File

@ -65,16 +65,16 @@ def test_tile():
diff0 = output[0].asnumpy() - expect0 diff0 = output[0].asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output[0].shape() == expect0.shape assert output[0].shape == expect0.shape
expect1 = np.tile(input_x1, mul1) expect1 = np.tile(input_x1, mul1)
diff1 = output[1].asnumpy() - expect1 diff1 = output[1].asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output[1].shape() == expect1.shape assert output[1].shape == expect1.shape
expect2 = np.tile(input_x2, mul2) expect2 = np.tile(input_x2, mul2)
diff2 = output[2].asnumpy() - expect2 diff2 = output[2].asnumpy() - expect2
error2 = np.ones(shape=expect2.shape) * 1.0e-5 error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2) assert np.all(diff2 < error2)
assert output[2].shape() == expect2.shape assert output[2].shape == expect2.shape

View File

@ -50,14 +50,14 @@ def test_ZerosLike():
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = zeros_like(x1) output1 = zeros_like(x1)
expect1 = np.zeros_like(x1_np) expect1 = np.zeros_like(x1_np)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
zeros_like = NetZerosLike() zeros_like = NetZerosLike()
@ -66,11 +66,11 @@ def test_ZerosLike():
diff0 = output0.asnumpy() - expect0 diff0 = output0.asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5 error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0) assert np.all(diff0 < error0)
assert output0.shape() == expect0.shape assert output0.shape == expect0.shape
output1 = zeros_like(x1) output1 = zeros_like(x1)
expect1 = np.zeros_like(x1_np) expect1 = np.zeros_like(x1_np)
diff1 = output1.asnumpy() - expect1 diff1 = output1.asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5 error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1) assert np.all(diff1 < error1)
assert output1.shape() == expect1.shape assert output1.shape == expect1.shape

View File

@ -20,6 +20,7 @@ import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from tests.ut.python.ut_filter import non_graph_engine from tests.ut.python.ut_filter import non_graph_engine
from tests.mindspore_test_framework.mindspore_test import mindspore_test from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \ from tests.mindspore_test_framework.pipeline.forward.compile_forward \
@ -44,7 +45,12 @@ def test_list_equal():
y = Tensor(np.zeros([3, 4, 5], np.int32)) y = Tensor(np.zeros([3, 4, 5], np.int32))
z = [1, 2, 3] z = [1, 2, 3]
net = Net(z) net = Net(z)
assert net(x, y) == x ret = net(x, y)
print(ret.asnumpy())
assert ret == x
assert ret.dtype == mstype.int32
assert ret.shape == (6, 8, 10)
def test_list_not_equal(): def test_list_not_equal():

View File

@ -33,7 +33,7 @@ class Net(nn.Cell):
self.biasAdd = P.BiasAdd() self.biasAdd = P.BiasAdd()
if isinstance(bias_init, Tensor): if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != output_channels: if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
raise ValueError("bias_init shape error") raise ValueError("bias_init shape error")
self.bias = Parameter(initializer( self.bias = Parameter(initializer(

View File

@ -65,7 +65,7 @@ def test_bias_add(test_with_simu):
self.biasAdd = P.BiasAdd() self.biasAdd = P.BiasAdd()
if isinstance(bias_init, Tensor): if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != output_channels: if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
raise ValueError("bias_init shape error") raise ValueError("bias_init shape error")
self.bias = Parameter(initializer( self.bias = Parameter(initializer(

View File

@ -50,148 +50,148 @@ def test_tensor():
"""test_tensor""" """test_tensor"""
t1 = ms.Tensor(ndarr) t1 = ms.Tensor(ndarr)
assert isinstance(t1, ms.Tensor) assert isinstance(t1, ms.Tensor)
assert t1.dtype() == ms.float64 assert t1.dtype == ms.float64
t2 = ms.Tensor(np.zeros([1, 2, 3]), ms.float32) t2 = ms.Tensor(np.zeros([1, 2, 3]), ms.float32)
assert isinstance(t2, ms.Tensor) assert isinstance(t2, ms.Tensor)
assert t2.shape() == (1, 2, 3) assert t2.shape == (1, 2, 3)
assert t2.dtype() == ms.float32 assert t2.dtype == ms.float32
t3 = ms.Tensor(0.1) t3 = ms.Tensor(0.1)
assert isinstance(t3, ms.Tensor) assert isinstance(t3, ms.Tensor)
assert t3.dtype() == ms.float64 assert t3.dtype == ms.float64
t4 = ms.Tensor(1) t4 = ms.Tensor(1)
assert isinstance(t4, ms.Tensor) assert isinstance(t4, ms.Tensor)
assert t4.dtype() == ms.int64 assert t4.dtype == ms.int64
def test_tensor_type_float16(): def test_tensor_type_float16():
t_float16 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float16)) t_float16 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float16))
assert isinstance(t_float16, ms.Tensor) assert isinstance(t_float16, ms.Tensor)
assert t_float16.shape() == (2, 3) assert t_float16.shape == (2, 3)
assert t_float16.dtype() == ms.float16 assert t_float16.dtype == ms.float16
def test_tensor_type_float32(): def test_tensor_type_float32():
t_float32 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)) t_float32 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
assert isinstance(t_float32, ms.Tensor) assert isinstance(t_float32, ms.Tensor)
assert t_float32.shape() == (2, 3) assert t_float32.shape == (2, 3)
assert t_float32.dtype() == ms.float32 assert t_float32.dtype == ms.float32
def test_tensor_type_float32_user_define(): def test_tensor_type_float32_user_define():
t = ms.Tensor(np.zeros([1, 2, 3]), ms.float32) t = ms.Tensor(np.zeros([1, 2, 3]), ms.float32)
assert isinstance(t, ms.Tensor) assert isinstance(t, ms.Tensor)
assert t.shape() == (1, 2, 3) assert t.shape == (1, 2, 3)
assert t.dtype() == ms.float32 assert t.dtype == ms.float32
def test_tensor_type_float64(): def test_tensor_type_float64():
t = ms.Tensor([[1.0, 2, 3], [4, 5, 6]]) t = ms.Tensor([[1.0, 2, 3], [4, 5, 6]])
assert isinstance(t, ms.Tensor) assert isinstance(t, ms.Tensor)
assert t.shape() == (2, 3) assert t.shape == (2, 3)
assert t.dtype() == ms.float64 assert t.dtype == ms.float64
t_zero = ms.Tensor(np.zeros([1, 2, 3])) t_zero = ms.Tensor(np.zeros([1, 2, 3]))
assert isinstance(t_zero, ms.Tensor) assert isinstance(t_zero, ms.Tensor)
assert t_zero.shape() == (1, 2, 3) assert t_zero.shape == (1, 2, 3)
assert t_zero.dtype() == ms.float64 assert t_zero.dtype == ms.float64
def test_tensor_type_float64_user_define(): def test_tensor_type_float64_user_define():
t = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=float)) t = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=float))
assert isinstance(t, ms.Tensor) assert isinstance(t, ms.Tensor)
assert t.shape() == (2, 3) assert t.shape == (2, 3)
assert t.dtype() == ms.float64 assert t.dtype == ms.float64
t_float64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]]), ms.float64) t_float64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]]), ms.float64)
assert isinstance(t_float64, ms.Tensor) assert isinstance(t_float64, ms.Tensor)
assert t_float64.shape() == (2, 3) assert t_float64.shape == (2, 3)
assert t_float64.dtype() == ms.float64 assert t_float64.dtype == ms.float64
def test_tensor_type_bool(): def test_tensor_type_bool():
# init a tensor with bool type # init a tensor with bool type
ts_bool_array = ms.Tensor(np.zeros([2, 3], np.bool), ms.bool_) ts_bool_array = ms.Tensor(np.zeros([2, 3], np.bool), ms.bool_)
assert isinstance(ts_bool_array, ms.Tensor) assert isinstance(ts_bool_array, ms.Tensor)
assert ts_bool_array.dtype() == ms.bool_ assert ts_bool_array.dtype == ms.bool_
t_bool = ms.Tensor(True) t_bool = ms.Tensor(True)
assert isinstance(t_bool, ms.Tensor) assert isinstance(t_bool, ms.Tensor)
assert t_bool.dtype() == ms.bool_ assert t_bool.dtype == ms.bool_
t_bool_array = ms.Tensor(np.array([[True, False, True], [False, False, False]])) t_bool_array = ms.Tensor(np.array([[True, False, True], [False, False, False]]))
assert isinstance(t_bool_array, ms.Tensor) assert isinstance(t_bool_array, ms.Tensor)
assert t_bool_array.shape() == (2, 3) assert t_bool_array.shape == (2, 3)
assert t_bool_array.dtype() == ms.bool_ assert t_bool_array.dtype == ms.bool_
def test_tensor_type_int8(): def test_tensor_type_int8():
t_int8_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8)) t_int8_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8))
assert isinstance(t_int8_array, ms.Tensor) assert isinstance(t_int8_array, ms.Tensor)
assert t_int8_array.shape() == (2, 3) assert t_int8_array.shape == (2, 3)
assert t_int8_array.dtype() == ms.int8 assert t_int8_array.dtype == ms.int8
def test_tensor_type_int16(): def test_tensor_type_int16():
t_int16_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)) t_int16_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16))
assert isinstance(t_int16_array, ms.Tensor) assert isinstance(t_int16_array, ms.Tensor)
assert t_int16_array.shape() == (2, 3) assert t_int16_array.shape == (2, 3)
assert t_int16_array.dtype() == ms.int16 assert t_int16_array.dtype == ms.int16
def test_tensor_type_int32(): def test_tensor_type_int32():
t_int = ms.Tensor([[1, 2, 3], [4, 5, 6]]) t_int = ms.Tensor([[1, 2, 3], [4, 5, 6]])
assert isinstance(t_int, ms.Tensor) assert isinstance(t_int, ms.Tensor)
assert t_int.shape() == (2, 3) assert t_int.shape == (2, 3)
assert t_int.dtype() == ms.int64 assert t_int.dtype == ms.int64
t_int_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) t_int_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
assert isinstance(t_int_array, ms.Tensor) assert isinstance(t_int_array, ms.Tensor)
assert t_int_array.shape() == (2, 3) assert t_int_array.shape == (2, 3)
assert t_int_array.dtype() == ms.int32 assert t_int_array.dtype == ms.int32
def test_tensor_type_int64(): def test_tensor_type_int64():
t_int64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)) t_int64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64))
assert isinstance(t_int64, ms.Tensor) assert isinstance(t_int64, ms.Tensor)
assert t_int64.shape() == (2, 3) assert t_int64.shape == (2, 3)
assert t_int64.dtype() == ms.int64 assert t_int64.dtype == ms.int64
def test_tensor_type_uint8(): def test_tensor_type_uint8():
t_uint8_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8)) t_uint8_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8))
assert isinstance(t_uint8_array, ms.Tensor) assert isinstance(t_uint8_array, ms.Tensor)
assert t_uint8_array.shape() == (2, 3) assert t_uint8_array.shape == (2, 3)
assert t_uint8_array.dtype() == ms.uint8 assert t_uint8_array.dtype == ms.uint8
def test_tensor_type_uint16(): def test_tensor_type_uint16():
t_uint16_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)) t_uint16_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16))
assert isinstance(t_uint16_array, ms.Tensor) assert isinstance(t_uint16_array, ms.Tensor)
assert t_uint16_array.shape() == (2, 3) assert t_uint16_array.shape == (2, 3)
assert t_uint16_array.dtype() == ms.uint16 assert t_uint16_array.dtype == ms.uint16
def test_tensor_type_uint32(): def test_tensor_type_uint32():
t_uint32_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)) t_uint32_array = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32))
assert isinstance(t_uint32_array, ms.Tensor) assert isinstance(t_uint32_array, ms.Tensor)
assert t_uint32_array.shape() == (2, 3) assert t_uint32_array.shape == (2, 3)
assert t_uint32_array.dtype() == ms.uint32 assert t_uint32_array.dtype == ms.uint32
def test_tensor_type_uint64(): def test_tensor_type_uint64():
t_uint64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)) t_uint64 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64))
assert isinstance(t_uint64, ms.Tensor) assert isinstance(t_uint64, ms.Tensor)
assert t_uint64.shape() == (2, 3) assert t_uint64.shape == (2, 3)
assert t_uint64.dtype() == ms.uint64 assert t_uint64.dtype == ms.uint64
def test_set_type(): def test_set_type():
t = ms.Tensor(ndarr) t = ms.Tensor(ndarr)
t.set_dtype(ms.float32) t.set_dtype(ms.float32)
assert t.dtype() == ms.float32 assert t.dtype == ms.float32
@non_graph_engine @non_graph_engine
@ -250,11 +250,11 @@ def test_return_tensor():
tensor_ = exe(net, input_data) tensor_ = exe(net, input_data)
# get shape # get shape
shape_ = tensor_.shape() shape_ = tensor_.shape
print("shape = ", shape_) print("shape = ", shape_)
# get type # get type
type_ = tensor_.dtype() type_ = tensor_.dtype
print("type = ", type_) print("type = ", type_)
# get value # get value

View File

@ -71,7 +71,7 @@ def test_tensor_size():
def test_dtype(): def test_dtype():
a = ms.Tensor(np.ones((2, 3), dtype=np.int32)) a = ms.Tensor(np.ones((2, 3), dtype=np.int32))
assert a.dtype() == ms.int32 assert a.dtype == ms.int32
def test_asnumpy(): def test_asnumpy():
@ -89,7 +89,7 @@ def test_print():
def test_float(): def test_float():
a = ms.Tensor(np.ones((2, 3)), ms.float16) a = ms.Tensor(np.ones((2, 3)), ms.float16)
assert a.dtype() == ms.float16 assert a.dtype == ms.float16
def test_tensor_method_sub(): def test_tensor_method_sub():

View File

@ -71,7 +71,7 @@ def test(name, file_path, batch_size):
data_list.append(data.asnumpy()) data_list.append(data.asnumpy())
batch_data = np.concatenate(data_list, axis=0).transpose((0, 3, 1, 2)) batch_data = np.concatenate(data_list, axis=0).transpose((0, 3, 1, 2))
input_tensor = Tensor(batch_data) input_tensor = Tensor(batch_data)
print(input_tensor.shape()) print(input_tensor.shape)
network(input_tensor) network(input_tensor)

View File

@ -23,7 +23,7 @@ from mindspore import dtype as mstype
def test_check_layer_norm_1(): def test_check_layer_norm_1():
x = Tensor(np.ones([20, 5, 10, 10]), mstype.float32) x = Tensor(np.ones([20, 5, 10, 10]), mstype.float32)
shape1 = x.shape()[1:] shape1 = x.shape[1:]
m = nn.LayerNorm(shape1, -1, 1) m = nn.LayerNorm(shape1, -1, 1)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
m(x) m(x)
@ -31,7 +31,7 @@ def test_check_layer_norm_1():
def test_check_layer_norm_2(): def test_check_layer_norm_2():
x = Tensor(np.ones([20, 5, 10, 10]), mstype.float32) x = Tensor(np.ones([20, 5, 10, 10]), mstype.float32)
shape1 = x.shape()[1:] shape1 = x.shape[1:]
m = nn.LayerNorm(shape1, begin_params_axis=1) m = nn.LayerNorm(shape1, begin_params_axis=1)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
m(x) m(x)

View File

@ -65,7 +65,7 @@ def test_init_Initializer():
def test_init_tensor(): def test_init_tensor():
tensor = ms.Tensor(np.zeros([1, 2, 3])) tensor = ms.Tensor(np.zeros([1, 2, 3]))
tensor = init.initializer(tensor, [1, 2, 3], ms.float32) tensor = init.initializer(tensor, [1, 2, 3], ms.float32)
assert tensor.shape() == (1, 2, 3) assert tensor.shape == (1, 2, 3)
def test_init_zero_default_dtype(): def test_init_zero_default_dtype():

View File

@ -126,8 +126,8 @@ def test_load_checkpoint():
assert len(par_dict) == 3 assert len(par_dict) == 3
assert par_dict['param_test'].name == 'param_test' assert par_dict['param_test'].name == 'param_test'
assert par_dict['param_test'].data.dtype() == mstype.float32 assert par_dict['param_test'].data.dtype == mstype.float32
assert par_dict['param_test'].data.shape() == (1, 3, 224, 224) assert par_dict['param_test'].data.shape == (1, 3, 224, 224)
assert isinstance(par_dict, dict) assert isinstance(par_dict, dict)

View File

@ -46,7 +46,7 @@ def vm_impl_dType(self):
def vm_impl(x): def vm_impl(x):
# update the src type # update the src type
return x.dtype() return x.dtype
return vm_impl return vm_impl