Add tensor.new_zeros, ops.zeros, ops.zeros_like
parent c05a7d37bd
commit 9fb0538b68

@@ -340,6 +340,8 @@ Tensor Creation

mindspore.ops.one_hot
mindspore.ops.ones
mindspore.ops.ones_like
mindspore.ops.zeros
mindspore.ops.zeros_like

Random Generation Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -0,0 +1,18 @@

mindspore.Tensor.new_ones
==========================

.. py:method:: mindspore.Tensor.new_ones(size, *, dtype=None)

    Returns a Tensor of size `size`, filled with ones. By default, the returned Tensor has the same dtype as `self`.

    Args:
        - **size** (Union[int, tuple, list]) - Defines the output shape.

    Keyword Args:
        - **dtype** (mindspore.dtype, optional) - The dtype of the output. Default: None, which uses the same dtype as `self`.

    Returns:
        Tensor, whose shape and dtype are defined by the inputs, filled with ones.

    Raises:
        - **TypeError** - If `size` is not an int, or a list/tuple of ints.

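As a quick illustration of the documented signature, a minimal usage sketch (assuming a MindSpore build that already contains this commit):

import numpy as np
import mindspore as ms
from mindspore import Tensor

x = Tensor(np.array([1, 2, 3]), ms.float32)
# Same dtype as `self` by default.
print(x.new_ones((2, 2)))                   # [[1. 1.] [1. 1.]], float32
# Override the dtype with the keyword-only argument.
print(x.new_ones((2, 2), dtype=ms.int32))
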
@@ -0,0 +1,18 @@

mindspore.Tensor.new_zeros
===========================

.. py:method:: mindspore.Tensor.new_zeros(size, *, dtype=None)

    Returns a Tensor of size `size`, filled with zeros. By default, the returned Tensor has the same dtype as `self`.

    Args:
        - **size** (Union[int, tuple, list]) - Defines the output shape.

    Keyword Args:
        - **dtype** (mindspore.dtype, optional) - The dtype of the output. Default: None, which uses the same dtype as `self`.

    Returns:
        Tensor, whose shape and dtype are defined by the inputs, filled with zeros.

    Raises:
        - **TypeError** - If `size` is not an int, or a list/tuple of ints.

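A corresponding sketch for new_zeros, exercising a list-valued `size` and the dtype override (again assuming a build with this commit):

import numpy as np
import mindspore as ms
from mindspore import Tensor

x = Tensor(np.array([1, 2, 3]), ms.float32)
print(x.new_zeros([2, 3]))                  # list size; dtype follows x (float32)
print(x.new_zeros((2, 2), dtype=ms.int32))  # keyword-only dtype override
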
@@ -184,6 +184,8 @@ mindspore.Tensor

mindspore.Tensor.ndimension
mindspore.Tensor.negative
mindspore.Tensor.nelement
mindspore.Tensor.new_ones
mindspore.Tensor.new_zeros
mindspore.Tensor.numel
mindspore.Tensor.nonzero
mindspore.Tensor.norm

@@ -22,4 +22,4 @@

    Tensor, with the same dtype as `x`.

    Raises:
        - **ValueError**: If `batch1`, `batch2` cannot be used for batch matrix multiplication.
        - **ValueError** - If `batch1` and `batch2` cannot be used for batch matrix multiplication.

@@ -1,15 +1,15 @@

mindspore.ops.ones
===================

.. py:function:: mindspore.ops.ones(shape, type)
.. py:function:: mindspore.ops.ones(shape, dtype)

    Creates a Tensor filled with ones.

    The first argument specifies the shape of the Tensor, and the second argument specifies the data type of the fill value.

    Args:
        - **shape** (Union[tuple[int], int]) - The shape of the output Tensor. Only constant positive ints are allowed.
        - **type** (mindspore.dtype) - The data type of the output Tensor. Only constant values are allowed.
        - **shape** (Union[tuple[int], int]) - The shape of the output Tensor.
        - **dtype** (:class:`mindspore.dtype`) - The `dtype` of the Tensor to be created. If None, mindspore.float32 will be used. Default: None.

    Returns:
        Tensor, whose shape and data type match the inputs.

@@ -0,0 +1,16 @@

mindspore.ops.zeros
====================

.. py:function:: mindspore.ops.zeros(shape, dtype=None)

    Creates a Tensor filled with zeros, whose shape is determined by `shape` and whose dtype is determined by `dtype`.

    Args:
        - **shape** (Union[tuple[int], int]) - The `shape` of the Tensor to be created.
        - **dtype** (:class:`mindspore.dtype`) - The `dtype` of the Tensor to be created. If None, mindspore.float32 will be used. Default: None.

    Returns:
        Tensor, whose dtype and shape are determined by the inputs.

    Raises:
        - **TypeError** - If `shape` is neither an int nor a tuple of ints.

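A minimal usage sketch of the new functional API, showing the float32 default (assuming a build with this commit):

import mindspore as ms
from mindspore import ops

print(ops.zeros((2, 2)))                    # dtype defaults to mindspore.float32
print(ops.zeros((2, 2), dtype=ms.int32))    # explicit dtype
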
@@ -0,0 +1,18 @@

mindspore.ops.zeros_like
=========================

.. py:function:: mindspore.ops.zeros_like(x, *, dtype=None)

    Creates a Tensor filled with zeros, whose shape is determined by `x` and whose dtype is determined by `dtype`.

    Args:
        - **x** (Tensor) - The Tensor whose shape determines the shape of the Tensor to be created.

    Keyword Args:
        - **dtype** (:class:`mindspore.dtype`, 可选 -> optional) - The `dtype` of the Tensor to be created. If None, the dtype of `x` will be used. Default: None.

    Returns:
        Tensor, whose dtype and shape are determined by the inputs.

    Raises:
        - **TypeError** - If `dtype` is not a MindSpore dtype.

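A short sketch of zeros_like, showing that the dtype defaults to the input's dtype and that the override is keyword-only (assuming a build with this commit):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.arange(4).reshape((2, 2)))     # integer dtype inherited from numpy
print(ops.zeros_like(x))                     # keeps the dtype of x
print(ops.zeros_like(x, dtype=ms.float32))   # dtype is keyword-only here
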
@@ -190,6 +190,8 @@

mindspore.Tensor.ndimension
mindspore.Tensor.negative
mindspore.Tensor.nelement
mindspore.Tensor.new_ones
mindspore.Tensor.new_zeros
mindspore.Tensor.numel
mindspore.Tensor.nonzero
mindspore.Tensor.norm

@@ -340,6 +340,8 @@ Tensor Building

mindspore.ops.one_hot
mindspore.ops.ones
mindspore.ops.ones_like
mindspore.ops.zeros
mindspore.ops.zeros_like

Randomly Generating Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -421,6 +421,8 @@ BuiltInTypeMap &GetMethodMap() {

{"neg", std::string("neg")},                // neg()
{"ne", std::string("ne")},                  // ne()
{"not_equal", std::string("not_equal")},    // not_equal()
{"new_zeros", std::string("new_zeros")},    // new_zeros()
{"new_ones", std::string("new_ones")},      // new_ones()
{"sinh", std::string("sinh")},              // sinh()
{"sort", std::string("sort")},              // sort()
{"trunc", std::string("trunc")},            // trunc()

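Registering the names in GetMethodMap is what makes the new methods resolvable on a Tensor inside compiled graph code; a minimal sketch of that usage, mirroring the new tests (ZerosNet is a hypothetical name used only for illustration):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

class ZerosNet(nn.Cell):
    def construct(self, x):
        # Resolved through the method map when compiled in GRAPH_MODE.
        return x.new_zeros((2, 2))

ms.set_context(mode=ms.GRAPH_MODE)
x = Tensor(np.array([1, 2, 3]), ms.float32)
print(ZerosNet()(x))
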
@@ -2861,6 +2861,22 @@ def nonzero(x):

    return F.nonzero(x)


def new_zeros(x, size, *, dtype=None):
    r"""
    Return a tensor of `size` filled with zeros. By default, the returned tensor has the same dtype as `x`.
    """
    _dtype = x.dtype if dtype is None else dtype
    return F.zeros(size, dtype=_dtype)


def new_ones(x, size, *, dtype=None):
    r"""
    Return a tensor of `size` filled with ones. By default, the returned tensor has the same dtype as `x`.
    """
    _dtype = x.dtype if dtype is None else dtype
    return F.ones(size, dtype=_dtype)


def diag(x):
    """
    Constructs a diagonal tensor with the given diagonal values.

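The helpers above just forward to the functional zeros/ones with the source tensor's dtype as the default; a small sketch of that equivalence, written against the public ops API rather than the internal F alias:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([1, 2, 3]), ms.int32)
a = x.new_zeros((2, 2))                  # dtype taken from x
b = ops.zeros((2, 2), dtype=x.dtype)     # what the helper expands to
assert a.dtype == b.dtype == ms.int32
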
@@ -2897,7 +2913,7 @@ def coo_to_csr(x):


def coo_to_dense(x):
    """convert coo to dense."""
    zeros_tensor = F.zeros(x.shape, x.values.dtype)
    zeros_tensor = F.zeros(x.shape, dtype=x.values.dtype)
    return F.tensor_scatter_update(zeros_tensor, x.indices, x.values)

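The same zeros-then-scatter pattern can be written against the public API; a hedged sketch (the indices and values here are made-up sample data, not taken from the commit):

import mindspore as ms
from mindspore import Tensor, ops

indices = Tensor([[0, 1], [1, 2]], ms.int32)
values = Tensor([1.0, 2.0], ms.float32)
dense = ops.zeros((3, 3), dtype=values.dtype)
dense = ops.tensor_scatter_update(dense, indices, values)
print(dense)   # zeros everywhere except positions (0, 1) and (1, 2)
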
@@ -4418,6 +4418,74 @@ class Tensor(Tensor_):

        self._init_check()
        return tensor_operator_registry.get('not_equal')(self, other)

    def new_zeros(self, size, *, dtype=None):
        r"""
        Return a tensor of `size` filled with zeros. By default, the returned tensor has the same dtype as `self`.

        Args:
            size (Union[int, tuple, list]): An int, list or tuple of integers defining the output shape.

        Keyword Args:
            dtype (mindspore.dtype, optional): The desired dtype of the output tensor. If None, the same dtype as
                `self` is used. Default: None.

        Returns:
            Tensor, whose shape and dtype are defined above, filled with zeros.

        Raises:
            TypeError: If `size` is not an int, list or tuple of integers.

        Supported Platforms:
            ``Ascend`` ``CPU`` ``GPU``

        Examples:
            >>> x = Tensor(np.array([1, 2, 3]), mindspore.float32)
            >>> output = x.new_zeros((2, 2))
            >>> print(output)
            [[0. 0.]
             [0. 0.]]
        """
        validator.check_value_type('size', size, [list, int, tuple], 'Tensor.new_zeros')
        if isinstance(size, list):
            size = tuple(size)
        self._init_check()
        _dtype = self.dtype if dtype is None else dtype
        return tensor_operator_registry.get('zeros')(size, _dtype)

    def new_ones(self, size, *, dtype=None):
        r"""
        Return a tensor of `size` filled with ones. By default, the returned tensor has the same dtype as `self`.

        Args:
            size (Union[int, tuple, list]): An int, list or tuple of integers defining the output shape.

        Keyword Args:
            dtype (mindspore.dtype, optional): The desired dtype of the output tensor. If None, the same dtype as
                `self` is used. Default: None.

        Returns:
            Tensor, whose shape and dtype are defined above, filled with ones.

        Raises:
            TypeError: If `size` is not an int, list or tuple of integers.

        Supported Platforms:
            ``Ascend`` ``CPU`` ``GPU``

        Examples:
            >>> x = Tensor(np.array([1, 2, 3]), mindspore.float32)
            >>> output = x.new_ones((2, 2))
            >>> print(output)
            [[1. 1.]
             [1. 1.]]
        """
        validator.check_value_type('size', size, [list, int, tuple], 'Tensor.new_ones')
        if isinstance(size, list):
            size = tuple(size)
        self._init_check()
        _dtype = self.dtype if dtype is None else dtype
        return tensor_operator_registry.get('ones')(size, _dtype)

    def sinh(self):
        r"""
        Computes hyperbolic sine of the input element-wise.

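The docstring examples can be checked numerically the same way the new tests do; a small sketch (assuming a build with this commit):

import numpy as np
import mindspore as ms
from mindspore import Tensor

x = Tensor(np.array([1, 2, 3]), ms.float32)
out = x.new_ones((2, 2))
assert out.dtype == ms.float32
assert np.array_equal(out.asnumpy(), np.ones((2, 2), np.float32))
out = x.new_zeros((2, 2), dtype=ms.int32)
assert out.dtype == ms.int32
assert np.array_equal(out.asnumpy(), np.zeros((2, 2), np.int32))
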
@@ -39,6 +39,8 @@ from .array_func import (

    size,
    ones,
    ones_like,
    zeros,
    zeros_like,
    shape,
    shape_,
    ger,

@@ -545,16 +545,17 @@ def fills(x, value):

    return fills_(x, value_)


def ones(shape, type):
def ones(shape, dtype=None):  # pylint: disable=redefined-outer-name
    r"""
    Creates a tensor filled with value ones.

    Creates a tensor with shape described by the first argument and
    fills it with value ones in type of the second argument.
    Creates a tensor with shape described by the first argument and fills it with value ones in type of the second
    argument.

    Args:
        shape (Union[tuple[int], int]): The specified shape of output tensor. Only constant positive int is allowed.
        type (mindspore.dtype): The specified type of output tensor. Only constant value is allowed.
        dtype (:class:`mindspore.dtype`): The specified type of output tensor. If `dtype` is None,
            `mindspore.float32` will be used. Default: None.

    Returns:
        Tensor, has the same type and shape as the input shape value.

@@ -570,13 +571,11 @@ def ones(shape, type):

        >>> print(output)
        [[1. 1.]
         [1. 1.]]
        >>> output = ops.ones((3, 3), mindspore.float32)
        >>> print(output)
        [[1. 1. 1.]
         [1. 1. 1.]
         [1. 1. 1.]]
    """
    return ones_(shape, type)
    _dtype = mstype.float32 if dtype is None else dtype
    ones_op = P.Ones()
    output = ones_op(shape, _dtype)
    return output


def ones_like(input_x):

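With the reworked signature, the second argument is optional; a minimal sketch of the new default behaviour (assuming a build with this commit):

import mindspore as ms
from mindspore import ops

out = ops.ones((2, 2))               # dtype defaults to mindspore.float32
assert out.dtype == ms.float32
print(ops.ones((3, 3), ms.int32))    # explicit dtype, as in the docstring example
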
@@ -602,7 +601,74 @@ def ones_like(input_x):

        [[1 1]
         [1 1]]
    """
    return ones_like_(input_x)
    ones_like_op = P.OnesLike()
    output = ones_like_op(input_x)
    return output


def zeros(shape, dtype=None):  # pylint: disable=redefined-outer-name
    r"""
    Creates a tensor filled with value 0, whose shape is described by `shape` and whose dtype is `dtype`.

    Args:
        shape (Union[tuple[int], int]): The specified shape of output tensor. Only constant positive int is allowed.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): The specified type of output tensor. If `dtype` is None,
            mindspore.float32 will be used. Default: None.

    Returns:
        Tensor, whose dtype and shape are determined by the inputs.

    Raises:
        TypeError: If `shape` is neither a tuple of int nor an int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> output = ops.zeros((2, 2), mindspore.float32)
        >>> print(output)
        [[0. 0.]
         [0. 0.]]
    """
    zero_op = P.Zeros()
    _dtype = mstype.float32 if dtype is None else dtype
    output = zero_op(shape, _dtype)
    return output


def zeros_like(x, *, dtype=None):
    r"""
    Creates a tensor filled with 0, with the same size as `x` and the given dtype.

    If `dtype = None`, the tensor will have the same dtype as the input `x`.

    Args:
        x (Tensor): Tensor of any dimension.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): The specified dtype of the output tensor. If `dtype` is None,
            the dtype of the input tensor will be used. Default: None.

    Returns:
        Tensor, filled with 0.

    Raises:
        ValueError: If `dtype` is not a MindSpore dtype.

    Examples:
        >>> x = Tensor(np.arange(4).reshape(2, 2))
        >>> output = ops.zeros_like(x, dtype=mindspore.float32)
        >>> print(output)
        [[0. 0.]
         [0. 0.]]
    """
    _dtype = x.dtype if dtype is None else dtype
    zeros_like_op = P.ZerosLike()
    output = zeros_like_op(x)
    output = cast_(output, _dtype)
    return output


def tile(input_x, multiples):

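A short sketch of the cast behaviour of the rewritten zeros_like, whose dtype keyword defaults to the input's dtype (assuming a build with this commit):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.arange(4).reshape((2, 2)), ms.int32)
assert ops.zeros_like(x).dtype == ms.int32
assert ops.zeros_like(x, dtype=ms.float16).dtype == ms.float16
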
@@ -955,7 +1021,7 @@ def ger(x1, x2):

    return ger_(x1, x2)


def size(input_x):
def size(input_x):  # pylint: disable=redefined-outer-name
    r"""
    Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
    Tensor.

@@ -5205,6 +5271,8 @@ __all__ = [

    'ger',
    'ones',
    'ones_like',
    'zeros',
    'zeros_like',
    'shape',
    'shape_',
    'reverse',

@@ -4146,6 +4146,19 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> m = np.ones((3, 3)).astype(np.float32)
        >>> arr1 = np.arange(24).astype(np.float32).reshape((2, 3, 4))
        >>> arr2 = np.arange(24).astype(np.float32).reshape((2, 4, 3))
        >>> a = Tensor(arr1)
        >>> b = Tensor(arr2)
        >>> c = Tensor(m)
        >>> output = ops.addbmm(c, a, b)
        >>> print(output)
        [[ 949. 1009. 1069.]
         [1285. 1377. 1469.]
         [1621. 1745. 1869.]]
    """
    bmm_op = _get_cache_prim(P.BatchMatMul)()
    bmm_res = bmm_op(batch1, batch2)

@@ -4176,6 +4189,19 @@ def addmm(x, mat1, mat2, *, beta=1, alpha=1):

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> m = np.ones((3, 3)).astype(np.float32)
        >>> arr1 = np.arange(12).astype(np.float32).reshape((3, 4))
        >>> arr2 = np.arange(12).astype(np.float32).reshape((4, 3))
        >>> a = Tensor(arr1)
        >>> b = Tensor(arr2)
        >>> c = Tensor(m)
        >>> output = ops.addmm(c, a, b)
        >>> print(output)
        [[ 43.  49.  55.]
         [115. 137. 159.]
         [187. 225. 263.]]
    """
    matmul_op = _get_cache_prim(P.MatMul)()
    return beta * x + alpha * (matmul_op(mat1, mat2))

@@ -4261,13 +4287,17 @@ def adjoint(x):

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> a = Tensor(np.array([[0. + 0.j, 1. + 1.j], [2. + 2.j, 3. + 3.j]]), mindspore.complex128)
        >>> output = ops.adjoint(a)
        >>> print(output)
        [[0.-0.j 2.-2.j]
         [1.-1.j 3.-3.j]]
    """
    _dtype = x.dtype
    _dim = x.ndim
    perm = [i for i in range(_dim)]
    perm[-2], perm[-1] = perm[-1], perm[-2]
    t = ops.transpose(x, tuple(perm))
    if _dtype in (mstype.complex64, mstype.complex128):
    t = x.swapaxes(-1, -2)
    if _dtype in (mstype.complex128, mstype.complex64):
        return t.conj()
    return t

@@ -104,8 +104,6 @@ in_dict = Primitive("in_dict")

not_in_dict = Primitive("not_in_dict")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
array_reduce = Primitive('array_reduce')
zeros = P.Zeros()
zeros_like = P.ZerosLike()
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()

@@ -311,6 +309,7 @@ tensor_operator_registry.register('mm', mm)

tensor_operator_registry.register('nan_to_num', nan_to_num)
tensor_operator_registry.register('csr_to_coo', csr_to_coo)
tensor_operator_registry.register('zeros', zeros)
tensor_operator_registry.register('ones', ones)
tensor_operator_registry.register('unsorted_segment_min', unsorted_segment_min)
tensor_operator_registry.register('unsorted_segment_max', unsorted_segment_max)
tensor_operator_registry.register('unsorted_segment_prod', unsorted_segment_prod)

@@ -0,0 +1,53 @@

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops


class Net(nn.Cell):
    def construct(self, dtype):
        out = ops.zeros((3, 4, 5), dtype=dtype)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, ms.int32])
def test_zeros(mode, dtype):
    """
    Feature: ops.zeros
    Description: Verify the result of ops.zeros
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net()
    output = net(dtype)
    if dtype is None:
        assert output.dtype == ms.float32
    else:
        assert output.dtype == dtype
    expect_out = np.zeros((3, 4, 5))
    assert np.array_equal(output.asnumpy(), expect_out)

@@ -0,0 +1,55 @@

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor


class Net(nn.Cell):
    def construct(self, x, dtype):
        out = ops.zeros_like(x, dtype=dtype)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, ms.int32])
def test_zeros_like(mode, dtype):
    """
    Feature: ops.zeros_like
    Description: Verify the result of zeros_like
    Expectation: success
    """
    ms.set_context(mode=mode)
    x = Tensor(np.arange(9).reshape((3, 3)))
    net = Net()
    output = net(x, dtype)
    if dtype is None:
        assert output.dtype == x.dtype
    else:
        assert output.dtype == dtype
    expect_out = np.zeros((3, 3))
    assert np.array_equal(output.asnumpy(), expect_out)

@@ -0,0 +1,38 @@

import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context


class Net(nn.Cell):
    def construct(self, x, size, dtype):
        return x.new_ones(size, dtype=dtype)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, mstype.int32])
def test_new_ones(mode, dtype):
    """
    Feature: tensor.new_ones()
    Description: Verify the result of tensor.new_ones
    Expectation: success
    """
    context.set_context(mode=mode)
    net = Net()
    x = Tensor(np.arange(4).reshape((2, 2)), dtype=mstype.float32)
    output = net(x, (3, 3), dtype)
    expected = np.ones((3, 3))
    if dtype is None:
        assert output.dtype == mstype.float32
    else:
        assert output.dtype == dtype
    assert np.allclose(output.asnumpy(), expected)

@@ -0,0 +1,38 @@

import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context


class Net(nn.Cell):
    def construct(self, x, size, dtype):
        return x.new_zeros(size, dtype=dtype)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, mstype.int32])
def test_new_zeros(mode, dtype):
    """
    Feature: tensor.new_zeros()
    Description: Verify the result of tensor.new_zeros
    Expectation: success
    """
    context.set_context(mode=mode)
    net = Net()
    x = Tensor(np.arange(4).reshape((2, 2)), dtype=mstype.float32)
    output = net(x, (3, 3), dtype)
    expected = np.zeros((3, 3))
    if dtype is None:
        assert output.dtype == mstype.float32
    else:
        assert output.dtype == dtype
    assert np.allclose(output.asnumpy(), expected)