add triu, true_divide, subtract to Tensor

This commit is contained in:
shaojunsong 2022-10-27 20:06:19 +08:00
parent 2e0a40557b
commit 2967aa99a5
12 changed files with 255 additions and 27 deletions

View File

@ -0,0 +1,20 @@
mindspore.Tensor.subtract
==========================
.. py:method:: mindspore.Tensor.subtract(other, *, alpha=1)
对Tensor进行逐元素的减法。
.. math::
output[i] = x[i] - alpha * y[i]
参数:
- **other** (Union[Tensor, number.Number]) - 参与减法的Tensor或者Number。
- **alpha** (Number) - :math:`other` 的乘数。默认值1。
返回:
Tensorshape与广播后的shape相同数据类型为输入中精度较高的类型。
异常:
- **TypeError** - `other` 不是Tensor、number.Number。

View File

@ -0,0 +1,20 @@
mindspore.Tensor.triu
=====================
.. py:method:: mindspore.Tensor.triu(diagonal=0)
根据对角线返回相应的三角矩阵。默认为主对角线。
参数:
-**diagonal** (int) - 对角线的系数。默认值0。
返回:
Tensorshape和dtype与输入相同。
异常:
-**TypeError** - 如果 `diagonal` 不是int。
-**TypeError** - 如果 `x` 不是Tensor。
-**ValueError** - 如果shape的长度小于1。

View File

@ -0,0 +1,9 @@
mindspore.Tensor.true_divide
============================
.. py:method:: mindspore.Tensor.true_divide(value)
当Tensor.div()的 :math:`rounding_mode=None` 时的别名。
参考`Tensor.div() <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/Tensor/mindspore.Tensor.div.html>`_

View File

@ -57,6 +57,7 @@ mindspore.Tensor
mindspore.Tensor.less_equal
mindspore.Tensor.igamma
mindspore.Tensor.igammac
mindspore.Tensor.true_divide
逐元素运算
^^^^^^^^^^^^^
@ -104,6 +105,7 @@ mindspore.Tensor
mindspore.Tensor.square
mindspore.Tensor.std
mindspore.Tensor.sub
mindspore.Tensor.subtract
mindspore.Tensor.svd
mindspore.Tensor.tan
mindspore.Tensor.tanh
@ -158,10 +160,10 @@ Reduction方法
:toctree: Tensor
:nosignatures:
mindspore.Tensor.det
mindspore.Tensor.ger
mindspore.Tensor.log_matrix_determinant
mindspore.Tensor.matrix_determinant
mindspore.Tensor.det
Tensor操作方法
----------------
@ -253,6 +255,7 @@ Array操作
mindspore.Tensor.to_tensor
mindspore.Tensor.trace
mindspore.Tensor.transpose
mindspore.Tensor.triu
mindspore.Tensor.unfold
mindspore.Tensor.unique_consecutive
mindspore.Tensor.unique_with_pad

View File

@ -62,6 +62,7 @@ Mathematical Methods
mindspore.Tensor.less_equal
mindspore.Tensor.igamma
mindspore.Tensor.igammac
mindspore.Tensor.true_divide
Element-wise Methods
^^^^^^^^^^^^^^^^^^^^
@ -108,6 +109,7 @@ Element-wise Methods
mindspore.Tensor.sqrt
mindspore.Tensor.std
mindspore.Tensor.sub
mindspore.Tensor.subtract
mindspore.Tensor.svd
mindspore.Tensor.square
mindspore.Tensor.tan
@ -163,10 +165,10 @@ Linear Algebraic Methods
:toctree: Tensor
:nosignatures:
mindspore.Tensor.det
mindspore.Tensor.ger
mindspore.Tensor.log_matrix_determinant
mindspore.Tensor.matrix_determinant
mindspore.Tensor.det
Tensor Operation Methods
------------------------
@ -258,6 +260,7 @@ Array Methods
mindspore.Tensor.to_tensor
mindspore.Tensor.trace
mindspore.Tensor.transpose
mindspore.Tensor.triu
mindspore.Tensor.unfold
mindspore.Tensor.unique_consecutive
mindspore.Tensor.unique_with_pad

View File

@ -280,6 +280,9 @@ BuiltInTypeMap &GetMethodMap() {
{"sqrt", std::string("sqrt")}, // P.Sqrt()
{"square", std::string("square")}, // P.Square()
{"sub", std::string("sub")}, // P.Sub()
{"true_divide", std::string("true_divide")}, // true_divide()
{"triu", std::string("triu")}, // triu()
{"subtract", std::string("subtract")}, // true_divide()
{"exp", std::string("exp")}, // P.Exp()
{"repeat", std::string("repeat")}, // C.repeat_elements
{"bernoulli", prim::kPrimBernoulli}, // P.Bernoulli()

View File

@ -2908,6 +2908,28 @@ def top_k(input_x, k, sorted=True):
return F.top_k(input_x, k, sorted)
def subtract(x, other, *, alpha=1):
r"""
Computes the element-wise subtraction of input tensors.
"""
return F.sub(x, other * alpha)
def true_divide(divident, divisor):
r"""
Computes the element-wise division of input tensors.
"""
return F.div(divident, divisor, None)
# pylint: disable=redefined-outer-name
def triu(x, diagonal=0):
r"""
Returns the triangular matrix based on the diagonal.
"""
return F.Triu(diagonal)(x)
#############
# Iteration #
#############
@ -3227,7 +3249,7 @@ def addr(x, vec1, vec2, beta=1, alpha=1):
r"""
Computes the outer-product of `vec1` and `vec2` and adds it to `x`.
"""
return F.addr(x, vec1, vec2, beta=1, alpha=1)
return F.addr(x, vec1, vec2, beta=beta, alpha=alpha)
def addmv(x, mat, vec, beta=1, alpha=1):

View File

@ -630,7 +630,7 @@ class Tensor(Tensor_):
def ndimension(self):
r"""
Refer to `Tensor.ndim()
<https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/Tensor/mindspore.Tensor.ndim.html>` _.
<https://www.mindspore.cn/docs/en/master/api_python/mindspore/Tensor/mindspore.Tensor.ndim.html>`_.
"""
return len(self._shape)
@ -970,6 +970,67 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('add')()(self, y)
def subtract(self, other, *, alpha=1):
r"""
Performs the element-wise subtraction of input tensors.
.. math::
output[i] = x[i] - alpha * y[i]
Args:
other (Tensor, Number): The tensor or number to be subtracted.
alpha (Number): The multiplier for `other`. Default: 1.
Returns:
Tensor, has the same shape and dtype as input tensors.
Raises:
TypeError: `other` is not Tensor, number.Number.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([4, 5, 6]), mindspore.float32)
>>> y = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> z = x.subtract(y, alpha=1)
>>> print(z)
[3. 3. 3.]
"""
self._init_check()
return tensor_operator_registry.get('sub')(self, alpha * other)
def true_divide(self, value):
r"""
Alias for Tensor.div() with :math:`rounding_mode=None`.
Refer to `Tensor.div()
<https://www.mindspore.cn/docs/en/master/api_python/mindspore/Tensor/mindspore.Tensor.div.html>`_.
"""
self._init_check()
return tensor_operator_registry.get('div')(self, value, None)
def triu(self, diagonal=0):
r"""
Returns a triangular matrix based on the diagonal. Default is the main diagonal.
Args:
diagonal (int): The index of diagonal. Default: 0.
Returns:
Tensor, a tensor has the same shape and data type as input.
Raises:
TypeError: If `diagonal` is not an int.
TypeError: If `x` is not an Tensor.
ValueError: If length of shape of x is less than 1.
Supported Platforms:
``GPU`` ``CPU``
"""
self._init_check()
validator.check_value_type('diagonal', diagonal, [int], 'triu')
return tensor_operator_registry.get('triu')(diagonal)(self)
def addr(self, vec1, vec2, beta=1, alpha=1):
r"""
Executes the outer-product of `vec1` and `vec2` and adds it to the input tensor.
@ -1015,7 +1076,7 @@ class Tensor(Tensor_):
"""
self._init_check()
return tensor_operator_registry.get('addr')(self, vec1, vec2, beta=1, alpha=1)
return tensor_operator_registry.get('addr')(self, vec1, vec2, beta=beta, alpha=alpha)
def all(self, axis=(), keep_dims=False):
"""
@ -4273,8 +4334,8 @@ class Tensor(Tensor_):
def det(self):
r"""
Refer to `Tensor.matrix_determinant()
<https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/Tensor/mindspore.Tensor.matrix_determinant.html
>` _.
<https://www.mindspore.cn/docs/en/master/api_python/mindspore/Tensor/mindspore.Tensor.matrix_determinant.html
>`_.
"""
self._init_check()
return tensor_operator_registry.get('matrix_determinant')(self)
@ -5214,7 +5275,7 @@ class Tensor(Tensor_):
dtype_ = self.dtype
x = self.asnumpy()
n = x.strides[1]
strides = tuple(np.array(strides)*n)
strides = tuple(np.array(strides) * n)
return Tensor(np.lib.stride_tricks.as_strided(x, shape, strides, subok, writeable), dtype=dtype_)
def randperm(self, max_length=1, pad=-1):
@ -6368,7 +6429,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('bmm')(self, mat2)
def to(self, dtype):
r"""
Performs tensor dtype conversion.
@ -6396,7 +6456,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('to')()(self, dtype)
def bool(self):
r"""
Converts input tensor dtype to `bool`.
@ -6417,7 +6476,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('bool')()(self, mstype.bool_)
def float(self):
r"""
Converts input tensor dtype to `float32`.
@ -6437,7 +6495,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('float')()(self, mstype.float32)
def half(self):
r"""
Converts input tensor dtype to `float16`.
@ -6457,7 +6514,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('half')()(self, mstype.float16)
def int(self):
r"""
Converts input tensor dtype to `int32`. If the value in tensor is float or half, the decimal will be discarded.
@ -6477,7 +6533,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('int')()(self, mstype.int32)
def long(self):
r"""
Converts input tensor dtype to `int64`. If the value in tensor is float or half, the decimal will be discarded.
@ -6497,7 +6552,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('long')()(self, mstype.int64)
def cholesky(self, upper=False):
r"""
Computes the Cholesky decomposition of a symmetric positive-definite matrix :math:`A`
@ -6541,7 +6595,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('cholesky')(upper=upper)(self)
def cholesky_inverse(self, upper=False):
r"""
Returns the inverse of the positive definite matrix using cholesky matrix factorization.
@ -6583,7 +6636,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('cholesky_inverse')(upper=upper)(self)
def conj(self):
r"""
Returns a tensor of complex numbers that are the complex conjugate of each element in input.
@ -6611,7 +6663,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('conj')(self)
def cross(self, other, dim=None):
r"""
Returns the cross product of vectors in dimension `dim` of input tensor and `other`. input tensor and `other`
@ -6647,7 +6698,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('cross')(self, other, dim)
def erfinv(self):
r"""
Computes the inverse error function of input. The inverse error function is defined in the range `(-1, 1)` as:
@ -6674,7 +6724,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('erfinv')(self)
def less_equal(self, other):
r"""
Computes the boolean value of :math:`input <= other` element-wise.
@ -6709,7 +6758,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('less_equal')(self, other)
def fold(self, output_size, kernel_size, dilation=1, padding=0, stride=1):
r"""
Combines an array of sliding local blocks into a large containing tensor.
@ -6753,7 +6801,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('fold')(self, output_size, kernel_size, dilation, padding, stride)
def unfold(self, kernel_size, dilation=1, padding=0, stride=1):
r"""
Extracts sliding local blocks from a batched input tensor.
@ -6799,7 +6846,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('unfold')(self, kernel_size, dilation, padding, stride)
def expand(self, size):
r"""
Returns a new view of the self tensor with singleton dimensions expanded to a larger size.
@ -6839,7 +6885,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('expand')(self, size)
def cumprod(self, dim, dtype=None):
r"""
Computes the cumulative product of the tensor along dimension `dim`. For example, if input tensor is a
@ -6871,7 +6916,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('cumprod')(self, dim, dtype)
def div(self, other, rounding_mode=None):
r"""
Divides the tensor `input` by the given input tensor `other` in floating-point type element-wise.
@ -6917,7 +6961,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('div')(self, other, rounding_mode)
def equal(self, other):
r"""
Computes the equivalence between the tensor input tensor `input` and the
@ -6959,7 +7002,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('equal')(self, other)
def expm1(self):
r"""
Returns exponential then minus 1 of a tensor element-wise.

View File

@ -36,6 +36,7 @@ from mindspore.ops.operations import _inner_ops
from mindspore.ops.operations import linalg_ops
from mindspore.ops.operations.math_ops import Median
from mindspore.ops.operations.array_ops import UniqueConsecutive
from mindspore.ops.operations.array_ops import Triu
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
from mindspore.ops.composite import _Vmap
@ -338,6 +339,7 @@ tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('sqrt', sqrt)
tensor_operator_registry.register('square', square)
tensor_operator_registry.register('sub', sub)
tensor_operator_registry.register('triu', Triu)
tensor_operator_registry.register('tan', P.Tan)
tensor_operator_registry.register('acos', acos)
tensor_operator_registry.register('cos', cos)

View File

@ -0,0 +1,34 @@
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
class Net(nn.Cell):
def construct(self, x, other):
return x.subtract(other, alpha=2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_subtract(mode):
"""
Feature: tensor.subtract()
Description: Verify the result of tensor.subtract
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor([4, 5, 6], dtype=mstype.float32)
y = Tensor([1, 2, 3], dtype=mstype.float32)
output = net(x, y)
expected = np.array([2, 1, 0], dtype=np.float32)
assert np.allclose(output.asnumpy(), expected)

View File

@ -0,0 +1,36 @@
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
class Net(nn.Cell):
def construct(self, x):
return x.triu(diagonal=1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_subtract(mode):
"""
Feature: tensor.subtract()
Description: Verify the result of tensor.subtract
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor(np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[10, 11, 12, 13],
[14, 15, 16, 17]]))
output = net(x)
expected = np.array([[0, 2, 3, 4],
[0, 0, 7, 8],
[0, 0, 0, 13],
[0, 0, 0, 0]])
assert np.array_equal(output.asnumpy(), expected)

View File

@ -0,0 +1,34 @@
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
class Net(nn.Cell):
def construct(self, x, other):
return x.true_divide(other)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_subtract(mode):
"""
Feature: tensor.subtract()
Description: Verify the result of tensor.subtract
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor(np.array([1.0, 2.0, 3.0]), mstype.float32)
y = Tensor(np.array([4.0, 5.0, 6.0]), mstype.float32)
output = net(x, y)
expected = np.array([0.25, 0.4, 0.5], dtype=np.float32)
assert np.allclose(output.asnumpy(), expected)