forked from mindspore-Ecosystem/mindspore
!47568 Fix addbmm
Merge pull request !47568 from shaojunsong/fix/addbmm
This commit is contained in:
commit
04e107ec7d
|
@ -22,4 +22,5 @@
|
||||||
Tensor,和 `x` 具有相同的dtype。
|
Tensor,和 `x` 具有相同的dtype。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
|
- **TypeError** - 如果 `alpha`,`beta` 不是int或者float。
|
||||||
- **ValueError** - 如果 `batch1`, `batch2` 不能进行批量矩阵乘法。
|
- **ValueError** - 如果 `batch1`, `batch2` 不能进行批量矩阵乘法。
|
||||||
|
|
|
@ -4758,6 +4758,7 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
|
||||||
Tensor, has the same dtype as `x`.
|
Tensor, has the same dtype as `x`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
TypeError: If `alpha` or beta is not an int or float.
|
||||||
ValueError: If `batch1`, `batch2` cannot apply batch matrix multiplication.
|
ValueError: If `batch1`, `batch2` cannot apply batch matrix multiplication.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
|
@ -4776,6 +4777,10 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
|
||||||
[1285. 1377. 1469.]
|
[1285. 1377. 1469.]
|
||||||
[1621. 1745. 1869.]]
|
[1621. 1745. 1869.]]
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(alpha, (int, float)):
|
||||||
|
raise TypeError(f"For 'addbmm', parameter 'alpha' must be an int or float, but got {type(alpha)}.")
|
||||||
|
if not isinstance(beta, (int, float)):
|
||||||
|
raise TypeError(f"For 'addbmm', parameter 'beta' must be an int or float, but got {type(beta)}.")
|
||||||
bmm_op = _get_cache_prim(P.BatchMatMul)()
|
bmm_op = _get_cache_prim(P.BatchMatMul)()
|
||||||
bmm_res = bmm_op(batch1, batch2)
|
bmm_res = bmm_op(batch1, batch2)
|
||||||
return beta * x + alpha * (bmm_res.sum(axis=0))
|
return beta * x + alpha * (bmm_res.sum(axis=0))
|
||||||
|
|
Loading…
Reference in New Issue