forked from mindspore-Ecosystem/mindspore
!47568 Fix addbmm
Merge pull request !47568 from shaojunsong/fix/addbmm
This commit is contained in:
commit
04e107ec7d
|
@ -22,4 +22,5 @@
|
|||
Tensor,和 `x` 具有相同的dtype。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `alpha`,`beta` 不是int或者float。
|
||||
- **ValueError** - 如果 `batch1`, `batch2` 不能进行批量矩阵乘法。
|
||||
|
|
|
@ -4758,6 +4758,7 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
|
|||
Tensor, has the same dtype as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `alpha` or beta is not an int or float.
|
||||
ValueError: If `batch1`, `batch2` cannot apply batch matrix multiplication.
|
||||
|
||||
Supported Platforms:
|
||||
|
@ -4776,6 +4777,10 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
|
|||
[1285. 1377. 1469.]
|
||||
[1621. 1745. 1869.]]
|
||||
"""
|
||||
if not isinstance(alpha, (int, float)):
|
||||
raise TypeError(f"For 'addbmm', parameter 'alpha' must be an int or float, but got {type(alpha)}.")
|
||||
if not isinstance(beta, (int, float)):
|
||||
raise TypeError(f"For 'addbmm', parameter 'beta' must be an int or float, but got {type(beta)}.")
|
||||
bmm_op = _get_cache_prim(P.BatchMatMul)()
|
||||
bmm_res = bmm_op(batch1, batch2)
|
||||
return beta * x + alpha * (bmm_res.sum(axis=0))
|
||||
|
|
Loading…
Reference in New Issue