Merge pull request !45197 from shaojunsong/tensor1104
This commit is contained in:
i-robot 2022-11-10 06:53:05 +00:00 committed by Gitee
commit 19f2e20376
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 407 additions and 2 deletions

View File

@ -0,0 +1,6 @@
mindspore.Tensor.addbmm
=======================
.. py:method:: mindspore.Tensor.addbmm(batch1, batch2, *, beta=1, alpha=1)
详情请参考 :func:`mindspore.ops.addbmm`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.addmm
======================
.. py:method:: mindspore.Tensor.addmm(mat1, mat2, *, beta=1, alpha=1)
详情请参考 :func:`mindspore.ops.addmm`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.adjoint
========================
.. py:method:: mindspore.Tensor.adjoint()
详情请参考 :func:`mindspore.ops.adjoint`

View File

@ -24,10 +24,13 @@ mindspore.Tensor
mindspore.Tensor.absolute
mindspore.Tensor.acos
mindspore.Tensor.add
mindspore.Tensor.addbmm
mindspore.Tensor.addcdiv
mindspore.Tensor.addcmul
mindspore.Tensor.addmm
mindspore.Tensor.addmv
mindspore.Tensor.addr
mindspore.Tensor.adjoint
mindspore.Tensor.all
mindspore.Tensor.amax
mindspore.Tensor.amin

View File

@ -0,0 +1,25 @@
mindspore.ops.addbmm
=====================
.. py:class:: mindspore.ops.addbmm(x, batch1, batch2, *, beta=1, alpha=1)
`batch1``batch2` 应用批量矩阵乘法后进行reduced add。矩阵 `x` 和最终的结果相加。
`alpha``beta` 分别是 `batch1``batch2` 矩阵乘法和 `x` 的乘数。如果 `beta` 是0那么 `x` 将会被忽略。
.. math::
output = \beta x + \alpha (\sum_{i=0}^{b-1} {batch1 @ batch2})
参数:
- **x** (Tensor) - 被添加的tensor。
- **batch1** (Tensor) - 矩阵乘法中的第一个张量。
- **batch2** (Tensor) - 矩阵乘法中的第二个张量。
关键字参数:
- **beta** (Union[int, float],可选) - `x` 的乘数。默认值1。
- **alpha** (Union[int, float],可选) - `batch1` @ `batch2` 的乘数。默认值1。
返回:
Tensor`x` 具有相同的dtype。
异常:
- **ValueError**If `batch1` `batch2` 不能进行批量矩阵乘法。

View File

@ -0,0 +1,24 @@
mindspore.ops.addmm
====================
.. py:class:: mindspore.ops.addmm(x, mat1, mat2, *, beta=1, alpha=1)
`mat1``mat2` 应用矩阵乘法。矩阵 `x` 和最终的结果相加。 `alpha``beta` 分别是 `mat1``mat2` 矩阵乘法和 `x` 的乘数。如果 `beta` 是0那么 `x` 将会被忽略。
.. math::
output = \beta x + \alpha (mat1 @ mat2)
参数:
- **x** (Tensor) - 被添加的tensor。
- **mat1** (Tensor) - 矩阵乘法中的第一个张量。
- **mat2** (Tensor) - 矩阵乘法中的第二个张量。
关键字参数:
- **beta** (Union[int, float],可选) - `x` 的乘数。默认值1。
- **alpha** (Union[int, float],可选) - `mat1` @ `mat2` 的乘数。默认值1。
返回:
Tensor`x` 具有相同的dtype。
异常:
- **ValueError**If `mat1``mat2` 不能进行矩阵乘法。

View File

@ -0,0 +1,15 @@
mindspore.ops.adjoint
======================
.. py:class:: mindspore.ops.adjoint(x)
计算张量的共轭,并转置最后两个维度。
参数:
- **x** (Tensor) - 参与计算的tensor。
返回:
Tensor`x` 具有相同的dtype和shape。
异常:
- **TypeError**`x` 不是tensor。

View File

@ -30,10 +30,13 @@
mindspore.Tensor.absolute
mindspore.Tensor.acos
mindspore.Tensor.add
mindspore.Tensor.addbmm
mindspore.Tensor.addcdiv
mindspore.Tensor.addcmul
mindspore.Tensor.addmm
mindspore.Tensor.addmv
mindspore.Tensor.addr
mindspore.Tensor.adjoint
mindspore.Tensor.all
mindspore.Tensor.amax
mindspore.Tensor.amin

View File

@ -336,12 +336,15 @@ BuiltInTypeMap &GetMethodMap() {
{"sigmoid", std::string("sigmoid")}, // P.Sigmoid()
{"addr", std::string("addr")}, // addr()
{"add", std::string("add")}, // P.Add()
{"asin", std::string("asin")}, // asin()
{"addbmm", std::string("addbmm")}, // addbmm()
{"addmm", std::string("addmm")}, // addmm()
{"addmv", std::string("addmv")}, // addmv()
{"adjoint", std::string("adjoint")}, // adjoint()
{"arccosh", std::string("acosh")}, // arccosh()
{"arcsin", std::string("asin")}, // arcsin()
{"arctan", std::string("atan")}, // arctan()
{"arctan2", std::string("atan2")}, // arctan2()
{"asin", std::string("asin")}, // asin()
{"asinh", std::string("asinh")}, // asinh()
{"arcsinh", std::string("asinh")}, // arcsinh()
{"atan", std::string("atan")}, // atan()

View File

@ -3405,6 +3405,20 @@ def addr(x, vec1, vec2, beta=1, alpha=1):
return F.addr(x, vec1, vec2, beta=beta, alpha=alpha)
def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
r"""
Performs matrix multiplication with a reduced sum, and add `x` to the result.
"""
return F.addbmm(x, batch1, batch2, beta=beta, alpha=alpha)
def addmm(x, mat1, mat2, *, beta=1, alpha=1):
r"""
Performs matrix multiplication, and add `x` to the result.
"""
return F.addmm(x, mat1, mat2, beta=beta, alpha=alpha)
def addmv(x, mat, vec, beta=1, alpha=1):
r"""
Multiplies matrix `mat` and vector `vec`. The vector `x` is added to the final result.
@ -3412,6 +3426,13 @@ def addmv(x, mat, vec, beta=1, alpha=1):
return F.addmv(x, mat, vec, beta, alpha)
def adjoint(x):
r"""
Computes the conjucated matrix with the last 2 dimensions transposed.
"""
return F.adjoint(x)
def asinh(x):
r"""
Computes inverse hyperbolic sine of the input element-wise.

View File

@ -1025,6 +1025,20 @@ class Tensor(Tensor_):
validator.check_value_type('diagonal', diagonal, [int], 'triu')
return tensor_operator_registry.get('triu')(diagonal)(self)
def addbmm(self, batch1, batch2, *, beta=1, alpha=1):
r"""
For details, please refer to :func:`mindspore.ops.addbmm`.
"""
self._init_check()
return tensor_operator_registry.get('addbmm')(self, batch1, batch2, beta=beta, alpha=alpha)
def addmm(self, mat1, mat2, *, beta=1, alpha=1):
r"""
For details, please refer to :func:`mindspore.ops.addmm`.
"""
self._init_check()
return tensor_operator_registry.get('addmm')(self, mat1, mat2, beta=beta, alpha=alpha)
def addr(self, vec1, vec2, beta=1, alpha=1):
r"""
Executes the outer-product of `vec1` and `vec2` and adds it to the input tensor.
@ -1072,6 +1086,13 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('addr')(self, vec1, vec2, beta=beta, alpha=alpha)
def adjoint(self):
r"""
For details, please refer to :func:`mindspore.ops.adjoint`.
"""
self._init_check()
return tensor_operator_registry.get('adjoint')(self)
def all(self, axis=(), keep_dims=False):
"""
Check all tensor elements along a given axis evaluate to True.

View File

@ -199,8 +199,11 @@ from .math_func import (
std,
ldexp,
mv,
addbmm,
addmv,
addmm,
addr,
adjoint,
inplace_add,
inplace_sub,
inplace_update,

View File

@ -2448,7 +2448,7 @@ def log_matrix_determinant(x):
return log_matrix_determinant_(x)
def matrix_solve(matrix, rhs, adjoint=False):
def matrix_solve(matrix, rhs, adjoint=False): # pylint: disable=redefined-outer-name
r"""
Solves systems of linear equations.
@ -3809,6 +3809,69 @@ def mv(mat, vec):
return out
def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
r"""
Applies batch matrix multiplication to `batch1` and `batch2`, with a reduced add step. The matrix `x` is add to
final result.
The optional values `alpha` and `beta` are the matrix-matrix product between `batch1` and `batch2` and the scale
factor for the added tensor `x` respectively. If `beta` is 0, then `x` will be ignored.
.. math::
output = \beta x + \alpha (\sum_{i=0}^{b-1} {batch1 @ batch2})
Args:
x (Tensor): Tensor to be added.
batch1 (Tensor): The first batch of tensor to be multiplied.
batch2 (Tensor): The second batch of tensor to be multiplied.
Keyword Args:
beta (scalar[int, float], optional): Multiplier for `x`. Default: 1.
alpha (scalar[int, float], optional): Multiplier for `batch1` @ `batch2`. Default: 1.
Returns:
Tensor, has the same dtype as `x`.
Raises:
ValueError: If `batch1`, `batch2` cannot apply batch matrix multiplication.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
bmm_op = _get_cache_prim(P.BatchMatMul)()
bmm_res = bmm_op(batch1, batch2)
return beta * x + alpha * (bmm_res.sum(axis=0))
def addmm(x, mat1, mat2, *, beta=1, alpha=1):
r"""
Multiplies matrix `mat1` and matrix `mat2`. The matrix `x` is added to the final result.
Args:
x (Tensor): Tensor to be added.
mat1 (Tensor): The first tensor to be multiplied.
mat2 (Tensor): The second tensor to be multiplied.
Keyword Args:
beta (scalar[int, float], optional): Multiplier for `x`. Default: 1.
alpha (scalar[int, float], optional): Multiplier for `mat1` @ `mat2`. Default: 1.
.. math::
output = \beta x + \alpha (mat1 @ mat2)
Returns:
Tensor, has the same dtype as `x`.
Raises:
ValueError: If `mat1`, `mat2` cannot apply matrix multiplication.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
matmul_op = _get_cache_prim(P.MatMul)()
return beta * x + alpha * (matmul_op(mat1, mat2))
def addmv(x, mat, vec, beta=1, alpha=1):
"""
Multiplies matrix `mat` and vector `vec`. The vector `x` is added to the final result.
@ -3874,6 +3937,32 @@ def addmv(x, mat, vec, beta=1, alpha=1):
return out
def adjoint(x):
r"""
Returns a view of the tensor conjugated and with the last two dimensions transposed.
Args:
x (Tensor): Input tensor.
Returns:
Tensor, the calculated result.
Raises:
TypeError: If `x` is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
_dtype = x.dtype
_dim = x.ndim
perm = [i for i in range(_dim)]
perm[-2], perm[-1] = perm[-1], perm[-2]
t = ops.transpose(x, tuple(perm))
if _dtype in (mstype.complex64, mstype.complex128):
return t.conj()
return t
def addr(x, vec1, vec2, beta=1, alpha=1):
"""
Executes the outer-product of `vec1` and `vec2` and adds it to the vec1rix `x`.
@ -7088,6 +7177,7 @@ __all__ = [
'abs',
'tensor_add',
'add',
'addbmm',
'addcdiv',
'addcmul',
'argmin',
@ -7108,7 +7198,9 @@ __all__ = [
'tensor_gt',
'logaddexp',
'mv',
'addmm',
'addmv',
'adjoint',
'outer',
'gt',
'tensor_ge',

View File

@ -214,7 +214,10 @@ tensor_operator_registry.register('sigmoid', P.Sigmoid)
tensor_operator_registry.register('median', Median)
tensor_operator_registry.register('tanh', tanh)
tensor_operator_registry.register('exp', P.Exp)
tensor_operator_registry.register('addbmm', addbmm)
tensor_operator_registry.register('addmm', addmm)
tensor_operator_registry.register('addmv', addmv)
tensor_operator_registry.register('adjoint', adjoint)
tensor_operator_registry.register('asinh', asinh)
tensor_operator_registry.register('atan', atan)
tensor_operator_registry.register('atanh', atanh)

View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x, y, z, beta, alpha):
return x.addbmm(y, z, beta=beta, alpha=alpha)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_addbmm(mode):
"""
Feature: tensor.addbmm
Description: Verify the result of addbmm
Expectation: success
"""
ms.set_context(mode=mode)
arr1 = np.arange(9).astype(np.float32).reshape((3, 3))
arr2 = np.arange(24).astype(np.float32).reshape((2, 3, 4))
arr3 = np.arange(24).astype(np.float32).reshape((2, 4, 3))
x = Tensor(arr1)
y = Tensor(arr2)
z = Tensor(arr3)
net = Net()
output = net(x, y, z, 0.5, 2)
expect_output = np.array([[1896.0000, 2016.5000, 2137.0000],
[2569.5000, 2754.0000, 2938.5000],
[3243.0000, 3491.5000, 3740.0000]])
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x, y, z, beta, alpha):
return x.addmm(y, z, beta=beta, alpha=alpha)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_addmm(mode):
"""
Feature: tensor.addmm
Description: Verify the result of addmm
Expectation: success
"""
ms.set_context(mode=mode)
arr1 = np.arange(9).astype(np.float32).reshape((3, 3))
arr2 = np.arange(12).astype(np.float32).reshape((3, 4))
arr3 = np.arange(12).astype(np.float32).reshape((4, 3))
x = Tensor(arr1)
y = Tensor(arr2)
z = Tensor(arr3)
net = Net()
output = net(x, y, z, 0.5, 2)
expect_output = np.array([[84.0000, 96.5000, 109.0000],
[229.5000, 274.0000, 318.5000],
[375.0000, 451.5000, 528.0000]])
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,68 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x):
return x.adjoint()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_adjoint(mode):
"""
Feature: tensor.adjoint
Description: Verify the result of adjoint
Expectation: success, however, when running on Ascend, transpose does not support Complex numbers.
"""
ms.set_context(mode=mode)
x = Tensor(np.array([[0., 1.], [2., 3.]]), ms.float32)
net = Net()
output = net(x)
expect_output = np.array([[0., 2.],
[1., 3.]])
assert np.allclose(output.asnumpy(), expect_output)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_adjoint_complex(mode):
"""
Feature: tensor.adjoint
Description: Verify the result of adjoint
Expectation: success.
"""
ms.set_context(mode=mode)
x = Tensor(np.array([[0. + 0.j, 1. + 1.j], [2. + 2.j, 3. + 3.j]]), ms.complex128)
net = Net()
output = net(x)
expect_output = np.array([[0. - 0.j, 2. - 2.j],
[1. - 1.j, 3. - 3.j]])
assert np.allclose(output.asnumpy(), expect_output)