!45420 Add tensor.baddmm, doc and st test

Merge pull request !45420 from DavidFFFan/api_tensor
This commit is contained in:
i-robot 2022-11-14 08:17:26 +00:00 committed by Gitee
commit 5ad5f358d6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
12 changed files with 114 additions and 1 deletions

View File

@ -162,6 +162,7 @@ mindspore.ops.function
mindspore.ops.atan
mindspore.ops.atan2
mindspore.ops.atanh
mindspore.ops.baddbmm
mindspore.ops.bernoulli
mindspore.ops.bessel_i0
mindspore.ops.bessel_i0e

View File

@ -0,0 +1,8 @@
mindspore.Tensor.baddbmm
========================
.. py:method:: mindspore.Tensor.baddbmm(batch1, batch2, beta=1, alpha=1)
计算三维矩阵batch1、batch2的乘积与Tensor的和。
详情请参考 :func:`mindspore.ops.baddbmm`

View File

@ -55,6 +55,7 @@ mindspore.Tensor
mindspore.Tensor.atan
mindspore.Tensor.atan2
mindspore.Tensor.atanh
mindspore.Tensor.baddbmm
mindspore.Tensor.bernoulli
mindspore.Tensor.bitwise_and
mindspore.Tensor.bitwise_or

View File

@ -0,0 +1,27 @@
mindspore.ops.baddbmm
=====================
.. py:function:: mindspore.ops.baddbmm(x, batch1, batch2, beta=1, alpha=1)
对输入的两个三维矩阵batch1与batch2相乘并将结果与x相加。
计算公式定义如下:
.. math::
\text{out}_{i} = \beta \text{x}_{i} + \alpha (\text{batch1}_{i} \mathbin{@} \text{batch2}_{i})
参数:
- **x** (Tensor) - 输入Tensorshape为 :math:`(C, W, H)`
- **batch1** (Tensor) - 输入Tensorshape为 :math:`(C, W, T)`
- **batch2** (Tensor) - 输入Tensorshape为 :math:`(C, T, H)`
- **beta** (float, int) - `x` 的系数默认值为1。
- **alpha** (float, int) - :math:`batch1 @ batch2` 的系数默认值为1。
返回:
Tensor其数据类型与 `x` 相同,其维度与 `batch1@batch2` 的结果相同。
异常:
- **ValueError** - `batch1``batch2` 的不是三维Tensor。
- **TypeError** - `x``batch1``batch2` 的类型不是Tensor。
- **TypeError** - `x``batch1``batch2` 数据类型不一致。
- **TypeError** - `beta``alpha` 不是实数类型。
- **TypeError** - 如果 `x``batch1``batch2` 为整数类型, `beta``alpha` 必须是整数类型。

View File

@ -61,6 +61,7 @@
mindspore.Tensor.atan
mindspore.Tensor.atan2
mindspore.Tensor.atanh
mindspore.Tensor.baddbmm
mindspore.Tensor.bernoulli
mindspore.Tensor.bitwise_and
mindspore.Tensor.bitwise_or

View File

@ -163,6 +163,7 @@ Element-by-Element Operations
mindspore.ops.atan
mindspore.ops.atan2
mindspore.ops.atanh
mindspore.ops.baddbmm
mindspore.ops.bernoulli
mindspore.ops.bessel_i0
mindspore.ops.bessel_i0e

View File

@ -352,6 +352,7 @@ BuiltInTypeMap &GetMethodMap() {
{"atan", std::string("atan")}, // atan()
{"atanh", std::string("atanh")}, // atanh()
{"arctanh", std::string("atanh")}, // arctanh()
{"baddbmm", std::string("baddbmm")}, // baddbmm
{"bmm", std::string("bmm")}, // bmm()
{"value", std::string("value_")}, // P.Load(param, U)
{"to", std::string("to")}, // to()

View File

@ -3486,6 +3486,13 @@ def arctanh(x):
return F.atanh(x)
def baddbmm(x, batch1, batch2, beta=1, alpha=1):
r"""
For details, please refer to :func:`mindspore.ops.baddbmm`.
"""
return F.baddbmm(x, batch1, batch2, beta=beta, alpha=alpha)
def bmm(input_x, mat2):
r"""
Computes matrix multiplication between two tensors by batch.

View File

@ -1167,6 +1167,15 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('atan2')(self, y)
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
r"""
Calculate the sum of the product of the three-dimensional matrix batch1, batch2 and the Tensor.
For details, please refer to :func:`mindspore.ops.baddbmm`.
"""
self._init_check()
return tensor_operator_registry.get('baddbmm')(self, batch1, batch2, beta=beta, alpha=alpha)
def view(self, *shape):
"""
Reshape the tensor according to the input shape. It's the same as :func:`mindspore.Tensor.reshape`,

View File

@ -7348,7 +7348,6 @@ __all__ = [
'trunc',
'gumbel_softmax',
'matmul',
'baddbmm',
'cummin',
'cummax',
'cumsum',
@ -7373,6 +7372,7 @@ __all__ = [
'remainder',
'accumulate_n',
'iou',
'baddbmm',
'bmm',
'trapz',
'cholesky',

View File

@ -128,6 +128,7 @@ tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('atan2', atan2)
tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('baddbmm', baddbmm)
tensor_operator_registry.register('sqrt', sqrt)
tensor_operator_registry.register('square', square)
tensor_operator_registry.register('sub', sub)

View File

@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x, y, z, beta, alpha):
return x.baddbmm(y, z, beta=beta, alpha=alpha)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_baddbmm(mode):
"""
Feature: tensor.baddbmm
Description: Verify the result of baddbmm
Expectation: success
"""
ms.set_context(mode=mode)
arr1 = np.arange(18).astype(np.float32).reshape((2, 3, 3))
arr2 = np.arange(24).astype(np.float32).reshape((2, 3, 4))
arr3 = np.arange(24).astype(np.float32).reshape((2, 4, 3))
x = Tensor(arr1)
y = Tensor(arr2)
z = Tensor(arr3)
net = Net()
output = net(x, y, z, 2, 0.4)
expect_output = np.array([[[16.8000, 21.2000, 25.6000],
[51.6000, 62.4000, 73.2000],
[86.4000, 103.6000, 120.8000]],
[[380.4000, 404.0000, 427.6000],
[492.0000, 522.0000, 552.0000],
[603.6000, 640.0000, 676.4000]]])
assert np.allclose(output.asnumpy(), expect_output)