Add tensor.new_zeros, ops.zeros, ops.zeros_like

This commit is contained in:
shaojunsong 2022-11-11 15:39:35 +08:00
parent c05a7d37bd
commit 9fb0538b68
21 changed files with 471 additions and 24 deletions

View File

@ -340,6 +340,8 @@ Tensor创建
mindspore.ops.one_hot mindspore.ops.one_hot
mindspore.ops.ones mindspore.ops.ones
mindspore.ops.ones_like mindspore.ops.ones_like
mindspore.ops.zeros
mindspore.ops.zeros_like
随机生成函数 随机生成函数
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^

View File

@ -0,0 +1,18 @@
mindspore.Tensor.new_ones
==========================
.. py:method:: mindspore.Tensor.new_ones(size, *, dtype=None)
返回一个大小为 `size` 的Tensor填充值为1。默认情况下返回的Tensor和 `self` 具有相同的数据类型。
参数:
- **size** (Union[int, tuple, list]) - 定义输出的shape。
关键字参数:
- **dtype** (mindspore.dtype, 可选) - 输出的数据类型。默认值None使用和 `self` 相同的数据类型。
返回:
Tensorshape和dtype由输入定义填充值为1。
异常:
- **TypeError** - 如果 `size` 不是一个int或int的列表/元组。

View File

@ -0,0 +1,18 @@
mindspore.Tensor.new_zeros
===========================
.. py:method:: mindspore.Tensor.new_zeros(size, *, dtype=None)
返回一个大小为 `size` 的Tensor填充值为0。默认情况下返回的Tensor和 `self` 具有相同的数据类型。
参数:
- **size** (Union[int, tuple, list]) - 定义输出的shape。
关键字参数:
- **dtype** (mindspore.dtype, 可选) - 输出的数据类型。默认值None使用和 `self` 相同的数据类型。
返回:
Tensorshape和dtype由输入定义填充值为0。
异常:
- **TypeError** - 如果 `size` 不是一个int或int的列表/元组。

View File

@ -184,6 +184,8 @@ mindspore.Tensor
mindspore.Tensor.ndimension mindspore.Tensor.ndimension
mindspore.Tensor.negative mindspore.Tensor.negative
mindspore.Tensor.nelement mindspore.Tensor.nelement
mindspore.Tensor.new_ones
mindspore.Tensor.new_zeros
mindspore.Tensor.numel mindspore.Tensor.numel
mindspore.Tensor.nonzero mindspore.Tensor.nonzero
mindspore.Tensor.norm mindspore.Tensor.norm

View File

@ -22,4 +22,4 @@
Tensor`x` 具有相同的dtype。 Tensor`x` 具有相同的dtype。
异常: 异常:
- **ValueError**If `batch1` `batch2` 不能进行批量矩阵乘法。 - **ValueError** - 如果 `batch1` `batch2` 不能进行批量矩阵乘法。

View File

@ -1,15 +1,15 @@
mindspore.ops.ones mindspore.ops.ones
=================== ===================
.. py:function:: mindspore.ops.ones(shape, type) .. py:function:: mindspore.ops.ones(shape, dtype)
创建一个值全为1的Tensor。 创建一个值全为1的Tensor。
第一个参数指定Tensor的shape第二个参数指定填充值的数据类型。 第一个参数指定Tensor的shape第二个参数指定填充值的数据类型。
参数: 参数:
- **shape** (Union[tuple[int], int]) - 指定输出Tensor的shape,只能是正整数常量 - **shape** (Union[tuple[int], int]) - 指定输出Tensor的shape。
- **type** (mindspore.dtype) - 指定输出Tensor的数据类型只能是常量值 - **dtype** (:class:`mindspore.dtype`) - 用来描述所创建的Tensor的 `dtype`。如果为None那么将会使用mindspore.float32。默认值None
返回: 返回:
Tensorshape和数据类型与输入相同。 Tensorshape和数据类型与输入相同。

View File

@ -0,0 +1,16 @@
mindspore.ops.zeros
====================
.. py:function:: mindspore.ops.zeros(shape, dtype=None)
创建一个填满0的Tensorshape由 `size` 决定, dtype由 `dtype` 决定。
参数:
- **shape** (Union[tuple[int], int]) - 用来描述所创建的Tensor的 `shape`
- **dtype** (:class:`mindspore.dtype`) - 用来描述所创建的Tensor的 `dtype`。如果为None那么将会使用mindspore.float32。默认值None。
返回:
Tensordtype和shape由入参决定。
异常:
- **TypeError** - 如果 `shape` 既不是int也不是int的元组。

View File

@ -0,0 +1,18 @@
mindspore.ops.zeros_like
=========================
.. py:function:: mindspore.ops.zeros_like(x, *, dtype=None)
创建一个填满0的Tensorshape由 `x` 决定dtype由 `dtype` 决定。
参数:
- **x** (Tensor) - 用来描述所创建的Tensor的shape 。
关键字参数:
- **dtype** (:class:`mindspore.dtype`, 可选) - 用来描述所创建的Tensor的 `dtype`。如果为None那么将会使用 `x` 的dtype。默认值None。
返回:
Tensordtype和shape由入参决定。
异常:
- **TypeError** - 如果 `dtype` 不是MindSpore的dtype。

View File

@ -190,6 +190,8 @@
mindspore.Tensor.ndimension mindspore.Tensor.ndimension
mindspore.Tensor.negative mindspore.Tensor.negative
mindspore.Tensor.nelement mindspore.Tensor.nelement
mindspore.Tensor.new_ones
mindspore.Tensor.new_zeros
mindspore.Tensor.numel mindspore.Tensor.numel
mindspore.Tensor.nonzero mindspore.Tensor.nonzero
mindspore.Tensor.norm mindspore.Tensor.norm

View File

@ -340,6 +340,8 @@ Tensor Building
mindspore.ops.one_hot mindspore.ops.one_hot
mindspore.ops.ones mindspore.ops.ones
mindspore.ops.ones_like mindspore.ops.ones_like
mindspore.ops.zeros
mindspore.ops.zeros_like
Randomly Generating Functions Randomly Generating Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -421,6 +421,8 @@ BuiltInTypeMap &GetMethodMap() {
{"neg", std::string("neg")}, // neg() {"neg", std::string("neg")}, // neg()
{"ne", std::string("ne")}, // ne() {"ne", std::string("ne")}, // ne()
{"not_equal", std::string("not_equal")}, // not_equal() {"not_equal", std::string("not_equal")}, // not_equal()
{"new_zeros", std::string("new_zeros")}, // new_zeros()
{"new_ones", std::string("new_ones")}, // new_ones()
{"sinh", std::string("sinh")}, // sinh() {"sinh", std::string("sinh")}, // sinh()
{"sort", std::string("sort")}, // sort() {"sort", std::string("sort")}, // sort()
{"trunc", std::string("trunc")}, // trunc() {"trunc", std::string("trunc")}, // trunc()

View File

@ -2861,6 +2861,22 @@ def nonzero(x):
return F.nonzero(x) return F.nonzero(x)
def new_zeros(x, size, *, dtype=None):
r"""
Return a tensor of `size` filled with zeros. By default, the returned tensor has the same dtype as `x`.
"""
_dtype = x.dtype if dtype is None else dtype
return F.zeros(size, dtype=_dtype)
def new_ones(x, size, *, dtype=None):
r"""
Return a tensor of `size` filled with ones. By default, the returned tensor has the same dtype as `x`.
"""
_dtype = x.dtype if dtype is None else dtype
return F.ones(size, dtype=_dtype)
def diag(x): def diag(x):
""" """
Constructs a diagonal tensor with a given diagonal values. Constructs a diagonal tensor with a given diagonal values.
@ -2897,7 +2913,7 @@ def coo_to_csr(x):
def coo_to_dense(x): def coo_to_dense(x):
"""convert coo to dense.""" """convert coo to dense."""
zeros_tensor = F.zeros(x.shape, x.values.dtype) zeros_tensor = F.zeros(x.shape, dtype=x.values.dtype)
return F.tensor_scatter_update(zeros_tensor, x.indices, x.values) return F.tensor_scatter_update(zeros_tensor, x.indices, x.values)

View File

@ -4418,6 +4418,74 @@ class Tensor(Tensor_):
self._init_check() self._init_check()
return tensor_operator_registry.get('not_equal')(self, other) return tensor_operator_registry.get('not_equal')(self, other)
def new_zeros(self, size, *, dtype=None):
r"""
Return a tensor of `size` filled with zeros. By default, the returned tensor has the same dtype as `self`.
Args:
size (Union[int, tuple, list]): An int, list or tuple of integers defining the output shape.
Keyword Args:
dtype (mindspore.dtype, optional): The desired dtype of the output tensor. If None, same dtype as `self`.
Default: None.
Returns:
Tensor, the shape and dtype is defined above and filled with zeros.
Raises:
TypeError: If `size` is not an int, list or tuple of integers.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> output = x.new_zeros((2, 2))
>>> print(output)
[[0. 0.]
[0. 0.]]
"""
validator.check_value_type('size', size, [list, int, tuple], 'Tensor.new_zeros')
if isinstance(size, list):
size = tuple(size)
self._init_check()
_dtype = self.dtype if dtype is None else dtype
return tensor_operator_registry.get('zeros')(size, _dtype)
def new_ones(self, size, *, dtype=None):
r"""
Return a tensor of `size` filled with ones. By default, the returned tensor has the same dtype as `self`.
Args:
size (Union[int, tuple, list]): An int, list or tuple of integers defining the output shape.
Keyword Args:
dtype (mindspore.dtype, optional): The desired dtype of the output tensor. Default: if None, same dtype as
`self`.
Returns:
Tensor, the shape and dtype is defined above and filled with ones.
Raises:
TypeError: If `size` is not an int, list or tuple of integers.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> output = x.new_ones((2, 2))
>>> print(output)
[[1. 1.]
[1. 1.]]
"""
validator.check_value_type('size', size, [list, int, tuple], 'Tensor.new_zeros')
if isinstance(size, list):
size = tuple(size)
self._init_check()
_dtype = self.dtype if dtype is None else dtype
return tensor_operator_registry.get('ones')(size, _dtype)
def sinh(self): def sinh(self):
r""" r"""
Computes hyperbolic sine of the input element-wise. Computes hyperbolic sine of the input element-wise.

View File

@ -39,6 +39,8 @@ from .array_func import (
size, size,
ones, ones,
ones_like, ones_like,
zeros,
zeros_like,
shape, shape,
shape_, shape_,
ger, ger,

View File

@ -545,16 +545,17 @@ def fills(x, value):
return fills_(x, value_) return fills_(x, value_)
def ones(shape, type): def ones(shape, dtype=None): # pylint: disable=redefined-outer-name
r""" r"""
Creates a tensor filled with value ones. Creates a tensor filled with value ones.
Creates a tensor with shape described by the first argument and Creates a tensor with shape described by the first argument and fills it with value ones in type of the second
fills it with value ones in type of the second argument. argument.
Args: Args:
shape (Union[tuple[int], int]): The specified shape of output tensor. Only constant positive int is allowed. shape (Union[tuple[int], int]): The specified shape of output tensor. Only constant positive int is allowed.
type (mindspore.dtype): The specified type of output tensor. Only constant value is allowed. dtype (:class:`mindspore.dtype`): The specified type of output tensor. If `dtype` is None,
`mindspore.float32` will be used. Default: None.
Returns: Returns:
Tensor, has the same type and shape as input shape value. Tensor, has the same type and shape as input shape value.
@ -570,13 +571,11 @@ def ones(shape, type):
>>> print(output) >>> print(output)
[[1. 1.] [[1. 1.]
[1. 1.]] [1. 1.]]
>>> output = ops.ones((3, 3), mindspore.float32)
>>> print(output)
[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]]
""" """
return ones_(shape, type) _dtype = mstype.float32 if dtype is None else dtype
ones_op = P.Ones()
output = ones_op(shape, _dtype)
return output
def ones_like(input_x): def ones_like(input_x):
@ -602,7 +601,74 @@ def ones_like(input_x):
[[1 1] [[1 1]
[1 1]] [1 1]]
""" """
return ones_like_(input_x) ones_like_op = P.OnesLike()
output = ones_like_op(input_x)
return output
def zeros(shape, dtype=None): # pylint: disable=redefined-outer-name
r"""
Creates a tensor filled with 0 with shape described by `size` and fills it with value 0 in type of `dtype`.
Args:
shape (Union[tuple[int], int]): The specified shape of output tensor. Only constant positive int is allowed.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): The specified type of output tensor. If `dtype` is None,
mindspore.float32 will be used. Default: None.
Returns:
Tensor, has the same dtype and shape as input.
Raises:
TypeError: If `shape` is neither a tuple of int nor an int.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = ops.zeros((2, 2), mindspore.float32)
>>> print(output)
[[0. 0.]
[0. 0.]]
"""
zero_op = P.Zeros()
_dtype = mstype.float32 if dtype is None else dtype
output = zero_op(shape, _dtype)
return output
def zeros_like(x, *, dtype=None):
r"""
Creates a tensor filled with 0, with the same size as x, and the given dtype.
If `dtype = None`, the tensor will have the same dtype as input `x`.
Args:
x (Tensor): Tensor of any dimension.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): The specified dtype of the output tensor. If `dtype` is None,
the dtype of the input tensor will be used. Default: None.
Returns:
Tensor, filled with 0.
Raises:
ValueError: If dtype is not a MindSpore dtype.
Examples:
>>> x = Tensor(np.arange(4).reshape(2, 2))
>>> output = ops.zeros_like(x, mindspore.float32)
>>> print(output)
[[0. 0.]
[0. 0.]]
"""
_dtype = x.dtype if dtype is None else dtype
zeros_like_op = P.ZerosLike()
output = zeros_like_op(x)
output = cast_(output, _dtype)
return output
def tile(input_x, multiples): def tile(input_x, multiples):
@ -955,7 +1021,7 @@ def ger(x1, x2):
return ger_(x1, x2) return ger_(x1, x2)
def size(input_x): def size(input_x): # pylint: disable=redefined-outer-name
r""" r"""
Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
Tensor. Tensor.
@ -5205,6 +5271,8 @@ __all__ = [
'ger', 'ger',
'ones', 'ones',
'ones_like', 'ones_like',
'zeros',
'zeros_like',
'shape', 'shape',
'shape_', 'shape_',
'reverse', 'reverse',

View File

@ -4146,6 +4146,19 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples:
>>> m = np.ones((3, 3)).astype(np.float32)
>>> arr1 = np.arange(24).astype(np.float32).reshape((2, 3, 4))
>>> arr2 = np.arange(24).astype(np.float32).reshape((2, 4, 3))
>>> a = Tensor(arr1)
>>> b = Tensor(arr2)
>>> c = Tensor(m)
>>> output = ops.addbmm(c, a, b)
>>> print(output)
[[ 949. 1009. 1069.]
[1285. 1377. 1469.]
[1621. 1745. 1869.]]
""" """
bmm_op = _get_cache_prim(P.BatchMatMul)() bmm_op = _get_cache_prim(P.BatchMatMul)()
bmm_res = bmm_op(batch1, batch2) bmm_res = bmm_op(batch1, batch2)
@ -4176,6 +4189,19 @@ def addmm(x, mat1, mat2, *, beta=1, alpha=1):
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples:
>>> m = np.ones((3, 3)).astype(np.float32)
>>> arr1 = np.arange(12).astype(np.float32).reshape((3, 4))
>>> arr2 = np.arange(12).astype(np.float32).reshape((4, 3))
>>> a = Tensor(arr1)
>>> b = Tensor(arr2)
>>> c = Tensor(m)
>>> output = ops.addmm(c, a, b)
>>> print(output)
[[ 43. 49. 55.]
[115. 137. 159.]
[187. 225. 263.]]
""" """
matmul_op = _get_cache_prim(P.MatMul)() matmul_op = _get_cache_prim(P.MatMul)()
return beta * x + alpha * (matmul_op(mat1, mat2)) return beta * x + alpha * (matmul_op(mat1, mat2))
@ -4261,13 +4287,17 @@ def adjoint(x):
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples:
>>> a = Tensor(np.array([[0. + 0.j, 1. + 1.j], [2. + 2.j, 3. + 3.j]]), mindspore.complex128)
>>> output = ops.adjoint(a)
>>> print(output)
[[0.-0.j 2.-2.j]
[1.-1.j 3.-3.j]]
""" """
_dtype = x.dtype _dtype = x.dtype
_dim = x.ndim t = x.swapaxes(-1, -2)
perm = [i for i in range(_dim)] if _dtype in (mstype.complex128, mstype.complex64):
perm[-2], perm[-1] = perm[-1], perm[-2]
t = ops.transpose(x, tuple(perm))
if _dtype in (mstype.complex64, mstype.complex128):
return t.conj() return t.conj()
return t return t

View File

@ -104,8 +104,6 @@ in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict") not_in_dict = Primitive("not_in_dict")
broadcast_gradient_args = Primitive('BroadcastGradientArgs') broadcast_gradient_args = Primitive('BroadcastGradientArgs')
array_reduce = Primitive('array_reduce') array_reduce = Primitive('array_reduce')
zeros = P.Zeros()
zeros_like = P.ZerosLike()
distribute = Primitive('distribute') distribute = Primitive('distribute')
embed = Primitive('embed') embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed() ref_to_embed = _grad_ops.RefToEmbed()
@ -311,6 +309,7 @@ tensor_operator_registry.register('mm', mm)
tensor_operator_registry.register('nan_to_num', nan_to_num) tensor_operator_registry.register('nan_to_num', nan_to_num)
tensor_operator_registry.register('csr_to_coo', csr_to_coo) tensor_operator_registry.register('csr_to_coo', csr_to_coo)
tensor_operator_registry.register('zeros', zeros) tensor_operator_registry.register('zeros', zeros)
tensor_operator_registry.register('ones', ones)
tensor_operator_registry.register('unsorted_segment_min', unsorted_segment_min) tensor_operator_registry.register('unsorted_segment_min', unsorted_segment_min)
tensor_operator_registry.register('unsorted_segment_max', unsorted_segment_max) tensor_operator_registry.register('unsorted_segment_max', unsorted_segment_max)
tensor_operator_registry.register('unsorted_segment_prod', unsorted_segment_prod) tensor_operator_registry.register('unsorted_segment_prod', unsorted_segment_prod)

View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, dtype):
out = ops.zeros((3, 4, 5), dtype=dtype)
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, ms.int32])
def test_zeros(mode, dtype):
"""
Feature: ops.zeros
Description: Verify the result of ops.zeros
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
output = net(dtype)
if dtype is None:
assert output.dtype == ms.float32
else:
assert output.dtype == dtype
expect_out = np.zeros((3, 4, 5))
assert np.array_equal(output.asnumpy(), expect_out)

View File

@ -0,0 +1,55 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x, dtype):
out = ops.zeros_like(x, dtype=dtype)
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, ms.int32])
def test_zeros_like(mode, dtype):
"""
Feature: ops.zeros_like
Description: Verify the result of zeros_like
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.arange(9).reshape((3, 3)))
net = Net()
output = net(x, dtype)
if dtype is None:
assert output.dtype == x.dtype
else:
assert output.dtype == dtype
expect_out = np.zeros((3, 3))
assert np.array_equal(output.asnumpy(), expect_out)

View File

@ -0,0 +1,38 @@
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
class Net(nn.Cell):
def construct(self, x, size, dtype):
return x.new_ones(size, dtype=dtype)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, mstype.int32])
def test_new_ones(mode, dtype):
"""
Feature: tensor.new_ones()
Description: Verify the result of tensor.new_ones
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor(np.arange(4).reshape((2, 2)), dtype=mstype.float32)
output = net(x, (3, 3), dtype)
expected = np.ones((3, 3))
if dtype is None:
assert output.dtype == mstype.float32
else:
assert output.dtype == dtype
assert np.allclose(output.asnumpy(), expected)

View File

@ -0,0 +1,38 @@
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
class Net(nn.Cell):
def construct(self, x, size, dtype):
return x.new_zeros(size, dtype=dtype)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, mstype.int32])
def test_new_zeros(mode, dtype):
"""
Feature: tensor.new_zeros()
Description: Verify the result of tensor.new_zeros
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor(np.arange(4).reshape((2, 2)), dtype=mstype.float32)
output = net(x, (3, 3), dtype)
expected = np.zeros((3, 3))
if dtype is None:
assert output.dtype == mstype.float32
else:
assert output.dtype == dtype
assert np.allclose(output.asnumpy(), expected)