!45420 Add tensor.baddmm, doc and st test
Merge pull request !45420 from DavidFFFan/api_tensor
This commit is contained in:
commit
5ad5f358d6
|
@ -162,6 +162,7 @@ mindspore.ops.function
|
|||
mindspore.ops.atan
|
||||
mindspore.ops.atan2
|
||||
mindspore.ops.atanh
|
||||
mindspore.ops.baddbmm
|
||||
mindspore.ops.bernoulli
|
||||
mindspore.ops.bessel_i0
|
||||
mindspore.ops.bessel_i0e
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
mindspore.Tensor.baddbmm
|
||||
========================
|
||||
|
||||
.. py:method:: mindspore.Tensor.baddbmm(batch1, batch2, beta=1, alpha=1)
|
||||
|
||||
计算三维矩阵batch1、batch2的乘积与Tensor的和。
|
||||
|
||||
详情请参考 :func:`mindspore.ops.baddbmm`。
|
|
@ -55,6 +55,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.atan
|
||||
mindspore.Tensor.atan2
|
||||
mindspore.Tensor.atanh
|
||||
mindspore.Tensor.baddbmm
|
||||
mindspore.Tensor.bernoulli
|
||||
mindspore.Tensor.bitwise_and
|
||||
mindspore.Tensor.bitwise_or
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
mindspore.ops.baddbmm
|
||||
=====================
|
||||
|
||||
.. py:function:: mindspore.ops.baddbmm(x, batch1, batch2, beta=1, alpha=1)
|
||||
|
||||
对输入的两个三维矩阵batch1与batch2相乘,并将结果与x相加。
|
||||
计算公式定义如下:
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = \beta \text{x}_{i} + \alpha (\text{batch1}_{i} \mathbin{@} \text{batch2}_{i})
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 输入Tensor,shape为 :math:`(C, W, H)` 。
|
||||
- **batch1** (Tensor) - 输入Tensor,shape为 :math:`(C, W, T)` 。
|
||||
- **batch2** (Tensor) - 输入Tensor,shape为 :math:`(C, T, H)` 。
|
||||
- **beta** (float, int) - `x` 的系数,默认值为1。
|
||||
- **alpha** (float, int) - :math:`batch1 @ batch2` 的系数,默认值为1。
|
||||
|
||||
返回:
|
||||
Tensor,其数据类型与 `x` 相同,其维度与 `batch1@batch2` 的结果相同。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `batch1` 或 `batch2` 的不是三维Tensor。
|
||||
- **TypeError** - `x` 、 `batch1` 或 `batch2` 的类型不是Tensor。
|
||||
- **TypeError** - `x` 、 `batch1` 或 `batch2` 数据类型不一致。
|
||||
- **TypeError** - `beta` 、 `alpha` 不是实数类型。
|
||||
- **TypeError** - 如果 `x` 、 `batch1` 和 `batch2` 为整数类型, `beta` 、 `alpha` 必须是整数类型。
|
|
@ -61,6 +61,7 @@
|
|||
mindspore.Tensor.atan
|
||||
mindspore.Tensor.atan2
|
||||
mindspore.Tensor.atanh
|
||||
mindspore.Tensor.baddbmm
|
||||
mindspore.Tensor.bernoulli
|
||||
mindspore.Tensor.bitwise_and
|
||||
mindspore.Tensor.bitwise_or
|
||||
|
|
|
@ -163,6 +163,7 @@ Element-by-Element Operations
|
|||
mindspore.ops.atan
|
||||
mindspore.ops.atan2
|
||||
mindspore.ops.atanh
|
||||
mindspore.ops.baddbmm
|
||||
mindspore.ops.bernoulli
|
||||
mindspore.ops.bessel_i0
|
||||
mindspore.ops.bessel_i0e
|
||||
|
|
|
@ -352,6 +352,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"atan", std::string("atan")}, // atan()
|
||||
{"atanh", std::string("atanh")}, // atanh()
|
||||
{"arctanh", std::string("atanh")}, // arctanh()
|
||||
{"baddbmm", std::string("baddbmm")}, // baddbmm
|
||||
{"bmm", std::string("bmm")}, // bmm()
|
||||
{"value", std::string("value_")}, // P.Load(param, U)
|
||||
{"to", std::string("to")}, // to()
|
||||
|
|
|
@ -3486,6 +3486,13 @@ def arctanh(x):
|
|||
return F.atanh(x)
|
||||
|
||||
|
||||
def baddbmm(x, batch1, batch2, beta=1, alpha=1):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.baddbmm`.
|
||||
"""
|
||||
return F.baddbmm(x, batch1, batch2, beta=beta, alpha=alpha)
|
||||
|
||||
|
||||
def bmm(input_x, mat2):
|
||||
r"""
|
||||
Computes matrix multiplication between two tensors by batch.
|
||||
|
|
|
@ -1167,6 +1167,15 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('atan2')(self, y)
|
||||
|
||||
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
|
||||
r"""
|
||||
Calculate the sum of the product of the three-dimensional matrix batch1, batch2 and the Tensor.
|
||||
|
||||
For details, please refer to :func:`mindspore.ops.baddbmm`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('baddbmm')(self, batch1, batch2, beta=beta, alpha=alpha)
|
||||
|
||||
def view(self, *shape):
|
||||
"""
|
||||
Reshape the tensor according to the input shape. It's the same as :func:`mindspore.Tensor.reshape`,
|
||||
|
|
|
@ -7348,7 +7348,6 @@ __all__ = [
|
|||
'trunc',
|
||||
'gumbel_softmax',
|
||||
'matmul',
|
||||
'baddbmm',
|
||||
'cummin',
|
||||
'cummax',
|
||||
'cumsum',
|
||||
|
@ -7373,6 +7372,7 @@ __all__ = [
|
|||
'remainder',
|
||||
'accumulate_n',
|
||||
'iou',
|
||||
'baddbmm',
|
||||
'bmm',
|
||||
'trapz',
|
||||
'cholesky',
|
||||
|
|
|
@ -128,6 +128,7 @@ tensor_operator_registry.register('all', P.ReduceAll)
|
|||
tensor_operator_registry.register('any', P.ReduceAny)
|
||||
tensor_operator_registry.register('atan2', atan2)
|
||||
tensor_operator_registry.register('abs', P.Abs)
|
||||
tensor_operator_registry.register('baddbmm', baddbmm)
|
||||
tensor_operator_registry.register('sqrt', sqrt)
|
||||
tensor_operator_registry.register('square', square)
|
||||
tensor_operator_registry.register('sub', sub)
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, y, z, beta, alpha):
|
||||
return x.baddbmm(y, z, beta=beta, alpha=alpha)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tensor_baddbmm(mode):
|
||||
"""
|
||||
Feature: tensor.baddbmm
|
||||
Description: Verify the result of baddbmm
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
arr1 = np.arange(18).astype(np.float32).reshape((2, 3, 3))
|
||||
arr2 = np.arange(24).astype(np.float32).reshape((2, 3, 4))
|
||||
arr3 = np.arange(24).astype(np.float32).reshape((2, 4, 3))
|
||||
x = Tensor(arr1)
|
||||
y = Tensor(arr2)
|
||||
z = Tensor(arr3)
|
||||
net = Net()
|
||||
output = net(x, y, z, 2, 0.4)
|
||||
expect_output = np.array([[[16.8000, 21.2000, 25.6000],
|
||||
[51.6000, 62.4000, 73.2000],
|
||||
[86.4000, 103.6000, 120.8000]],
|
||||
[[380.4000, 404.0000, 427.6000],
|
||||
[492.0000, 522.0000, 552.0000],
|
||||
[603.6000, 640.0000, 676.4000]]])
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
Loading…
Reference in New Issue