!32810 [assistant][ops] Add new Baddbmm

Merge pull request !32810 from wxy220/Baddbmm
This commit is contained in:
i-robot 2022-06-13 02:52:27 +00:00 committed by Gitee
commit 247da9fbf3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 89 additions and 1 deletions

View File

@ -205,6 +205,7 @@ from .math_func import (
truncate_mod,
gumbel_softmax,
matmul,
baddbmm,
)
from .nn_func import (
adaptive_avgpool2d,

View File

@ -24,6 +24,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import check_is_number
from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1,
BesselK1e, Renorm, Lcm, Gcd)
from ...common import dtype as mstype
@ -3560,6 +3561,67 @@ def matmul(x1, x2):
return reshape_op(res, shape_out)
def baddbmm(x, batch1, batch2, beta=1, alpha=1):
r"""
Performs a batch matrix-matrix product of matrices in batch1 and batch2. input is added to the final result.
The formula is defined as follows:
.. math::
\text{out}_{i} = \beta \text{input}_{i} + \alpha (\text{batch1}_{i} \mathbin{@} \text{batch2}_{i})
Args:
x (Tensor): The tensor to be added.
batch1 (Tensor): The first batch of matrices to be multiplied.
batch2 (Tensor): The second batch of matrices to be multiplied.
alpha: multiplier for `batch1 @ batch2`.
beta: multiplier for input.
Returns:
Tensor, the output tensor.
Raises:
TypeError: For Baddbmm, inputs are not tensors.
ValueError: If `batch1` and `batch2` are not 3-D tensors.
TypeError: For inputs of type FloatTensor or DoubleTensor, \
arguments beta and alpha not be real numbers, otherwise not be integers.
ValueError: For Baddbmm, attributes alpha and beta are not real numbers
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input = Tensor(np.ones([1, 3, 3]).astype(np.float32))
>>> batch1 = Tensor(np.ones([1, 3, 4]).astype(np.float32))
>>> batch2 = Tensor(np.ones([1, 4, 3]).astype(np.float32))
>>> output = ops.baddbmm(input, batch1, batch2)
>>> print(output)
[[[5. 5. 5.]
[5. 5. 5.]
[5. 5. 5.]]]
"""
dtypeop = P.DType()
bmmop = P.BatchMatMul(False, False)
if not (isinstance(x, Tensor) and isinstance(batch1, Tensor) and isinstance(batch2, Tensor)):
raise TypeError("For Baddbmm, inputs must be all tensors.")
if len(batch1.shape) != 3 or len(batch2.shape) != 3:
raise ValueError("For batch1 and batch2 must be 3-D tensors each containing the same number of matrices, "
f"but got length of batch1:'{len(batch1.shape)}', length of batch2:'{len(batch2.shape)}'.")
input_dtype = dtypeop(x)
if not (input_dtype == dtypeop(batch1) and input_dtype == dtypeop(batch2)):
raise TypeError("For Baddbmm, the inputs should be the same dtype.")
if input_dtype in (mstype.float16, mstype.float32, mstype.float64):
if not (isinstance(alpha, (int, float)) and isinstance(beta, (int, float))):
raise TypeError("For attributes alpha and beta should be real numbers.")
check_is_number(alpha, (int, float))
check_is_number(beta, (int, float))
else:
if not (isinstance(alpha, int) and isinstance(beta, int)):
raise TypeError("For inputs of type not FloatTensor or DoubleTensor, "
"arguments beta and alpha must be integers.")
y = beta * x + alpha * (bmmop(batch1, batch2))
return y
__all__ = [
'addn',
'absolute',
@ -3666,6 +3728,7 @@ __all__ = [
'truncate_div',
'truncate_mod',
'gumbel_softmax',
'matmul'
'matmul',
'baddbmm'
]
__all__.sort()

View File

@ -491,6 +491,17 @@ class Rad2degNet(nn.Cell):
return self.rad2deg(x)
class BaddbmmNet(nn.Cell):
def __init__(self, beta=1, alpha=1):
super(BaddbmmNet, self).__init__()
self.beta = beta
self.alpha = alpha
self.baddbmm = ops.baddbmm
def construct(self, x, batch1, batch2):
return self.baddbmm(x, batch1, batch2, self.beta, self.alpha)
test_case_math_ops = [
('MatMulGrad', {
'block': GradWrap(NetWithLoss(MatMulNet())),
@ -601,6 +612,13 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]], np.float32))],
'desc_bprop': [Tensor(np.array([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]], np.float32))],
}),
('Baddbmm', {
'block': BaddbmmNet(),
'desc_inputs': [Tensor(np.ones([1, 3, 3]).astype(np.float32)),
Tensor(np.ones([1, 3, 4]).astype(np.float32)),
Tensor(np.ones([1, 4, 3]).astype(np.float32))],
'skip': ['backward']
}),
]
test_case_lists = [test_case_math_ops]
@ -696,6 +714,12 @@ raise_set = [
('Rad2deg_2_Error', {
'block': (lambda x: Rad2degNet(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.array([[3, -3], [6, -6], [1, -1]], np.int32))]}),
('Baddbmm_Error', {
'block': (BaddbmmNet(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.ones([1, 3, 3]).astype(np.float32)),
Tensor(np.ones([1, 3, 4]).astype(np.float32)),
[1, 2]],
'skip': ['backward']}),
]