forked from mindspore-Ecosystem/mindspore
!32810 [assistant][ops] Add new Baddbmm
Merge pull request !32810 from wxy220/Baddbmm
This commit is contained in:
commit
247da9fbf3
|
@ -205,6 +205,7 @@ from .math_func import (
|
|||
truncate_mod,
|
||||
gumbel_softmax,
|
||||
matmul,
|
||||
baddbmm,
|
||||
)
|
||||
from .nn_func import (
|
||||
adaptive_avgpool2d,
|
||||
|
|
|
@ -24,6 +24,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import check_is_number
|
||||
from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1,
|
||||
BesselK1e, Renorm, Lcm, Gcd)
|
||||
from ...common import dtype as mstype
|
||||
|
@ -3560,6 +3561,67 @@ def matmul(x1, x2):
|
|||
return reshape_op(res, shape_out)
|
||||
|
||||
|
||||
def baddbmm(x, batch1, batch2, beta=1, alpha=1):
|
||||
r"""
|
||||
Performs a batch matrix-matrix product of matrices in batch1 and batch2. input is added to the final result.
|
||||
The formula is defined as follows:
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = \beta \text{input}_{i} + \alpha (\text{batch1}_{i} \mathbin{@} \text{batch2}_{i})
|
||||
|
||||
Args:
|
||||
x (Tensor): The tensor to be added.
|
||||
batch1 (Tensor): The first batch of matrices to be multiplied.
|
||||
batch2 (Tensor): The second batch of matrices to be multiplied.
|
||||
alpha: multiplier for `batch1 @ batch2`.
|
||||
beta: multiplier for input.
|
||||
|
||||
Returns:
|
||||
Tensor, the output tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: For Baddbmm, inputs are not tensors.
|
||||
ValueError: If `batch1` and `batch2` are not 3-D tensors.
|
||||
TypeError: For inputs of type FloatTensor or DoubleTensor, \
|
||||
arguments beta and alpha not be real numbers, otherwise not be integers.
|
||||
ValueError: For Baddbmm, attributes alpha and beta are not real numbers
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input = Tensor(np.ones([1, 3, 3]).astype(np.float32))
|
||||
>>> batch1 = Tensor(np.ones([1, 3, 4]).astype(np.float32))
|
||||
>>> batch2 = Tensor(np.ones([1, 4, 3]).astype(np.float32))
|
||||
>>> output = ops.baddbmm(input, batch1, batch2)
|
||||
>>> print(output)
|
||||
[[[5. 5. 5.]
|
||||
[5. 5. 5.]
|
||||
[5. 5. 5.]]]
|
||||
"""
|
||||
dtypeop = P.DType()
|
||||
bmmop = P.BatchMatMul(False, False)
|
||||
if not (isinstance(x, Tensor) and isinstance(batch1, Tensor) and isinstance(batch2, Tensor)):
|
||||
raise TypeError("For Baddbmm, inputs must be all tensors.")
|
||||
if len(batch1.shape) != 3 or len(batch2.shape) != 3:
|
||||
raise ValueError("For batch1 and batch2 must be 3-D tensors each containing the same number of matrices, "
|
||||
f"but got length of batch1:'{len(batch1.shape)}', length of batch2:'{len(batch2.shape)}'.")
|
||||
input_dtype = dtypeop(x)
|
||||
if not (input_dtype == dtypeop(batch1) and input_dtype == dtypeop(batch2)):
|
||||
raise TypeError("For Baddbmm, the inputs should be the same dtype.")
|
||||
if input_dtype in (mstype.float16, mstype.float32, mstype.float64):
|
||||
if not (isinstance(alpha, (int, float)) and isinstance(beta, (int, float))):
|
||||
raise TypeError("For attributes alpha and beta should be real numbers.")
|
||||
check_is_number(alpha, (int, float))
|
||||
check_is_number(beta, (int, float))
|
||||
else:
|
||||
if not (isinstance(alpha, int) and isinstance(beta, int)):
|
||||
raise TypeError("For inputs of type not FloatTensor or DoubleTensor, "
|
||||
"arguments beta and alpha must be integers.")
|
||||
y = beta * x + alpha * (bmmop(batch1, batch2))
|
||||
return y
|
||||
|
||||
|
||||
__all__ = [
|
||||
'addn',
|
||||
'absolute',
|
||||
|
@ -3666,6 +3728,7 @@ __all__ = [
|
|||
'truncate_div',
|
||||
'truncate_mod',
|
||||
'gumbel_softmax',
|
||||
'matmul'
|
||||
'matmul',
|
||||
'baddbmm'
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -491,6 +491,17 @@ class Rad2degNet(nn.Cell):
|
|||
return self.rad2deg(x)
|
||||
|
||||
|
||||
class BaddbmmNet(nn.Cell):
|
||||
def __init__(self, beta=1, alpha=1):
|
||||
super(BaddbmmNet, self).__init__()
|
||||
self.beta = beta
|
||||
self.alpha = alpha
|
||||
self.baddbmm = ops.baddbmm
|
||||
|
||||
def construct(self, x, batch1, batch2):
|
||||
return self.baddbmm(x, batch1, batch2, self.beta, self.alpha)
|
||||
|
||||
|
||||
test_case_math_ops = [
|
||||
('MatMulGrad', {
|
||||
'block': GradWrap(NetWithLoss(MatMulNet())),
|
||||
|
@ -601,6 +612,13 @@ test_case_math_ops = [
|
|||
'desc_inputs': [Tensor(np.array([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]], np.float32))],
|
||||
'desc_bprop': [Tensor(np.array([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]], np.float32))],
|
||||
}),
|
||||
('Baddbmm', {
|
||||
'block': BaddbmmNet(),
|
||||
'desc_inputs': [Tensor(np.ones([1, 3, 3]).astype(np.float32)),
|
||||
Tensor(np.ones([1, 3, 4]).astype(np.float32)),
|
||||
Tensor(np.ones([1, 4, 3]).astype(np.float32))],
|
||||
'skip': ['backward']
|
||||
}),
|
||||
]
|
||||
|
||||
test_case_lists = [test_case_math_ops]
|
||||
|
@ -696,6 +714,12 @@ raise_set = [
|
|||
('Rad2deg_2_Error', {
|
||||
'block': (lambda x: Rad2degNet(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.array([[3, -3], [6, -6], [1, -1]], np.int32))]}),
|
||||
('Baddbmm_Error', {
|
||||
'block': (BaddbmmNet(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.ones([1, 3, 3]).astype(np.float32)),
|
||||
Tensor(np.ones([1, 3, 4]).astype(np.float32)),
|
||||
[1, 2]],
|
||||
'skip': ['backward']}),
|
||||
]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue