Merge pull request !47568 from shaojunsong/fix/addbmm
This commit is contained in:
i-robot 2023-01-06 04:23:36 +00:00 committed by Gitee
commit 04e107ec7d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 6 additions and 0 deletions

View File

@ -22,4 +22,5 @@
Tensor`x` 具有相同的dtype。
异常:
- **TypeError** - 如果 `alpha``beta` 不是int或者float。
- **ValueError** - 如果 `batch1` `batch2` 不能进行批量矩阵乘法。

View File

@ -4758,6 +4758,7 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
Tensor, has the same dtype as `x`.
Raises:
TypeError: If `alpha` or beta is not an int or float.
ValueError: If `batch1`, `batch2` cannot apply batch matrix multiplication.
Supported Platforms:
@ -4776,6 +4777,10 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
[1285. 1377. 1469.]
[1621. 1745. 1869.]]
"""
if not isinstance(alpha, (int, float)):
raise TypeError(f"For 'addbmm', parameter 'alpha' must be an int or float, but got {type(alpha)}.")
if not isinstance(beta, (int, float)):
raise TypeError(f"For 'addbmm', parameter 'beta' must be an int or float, but got {type(beta)}.")
bmm_op = _get_cache_prim(P.BatchMatMul)()
bmm_res = bmm_op(batch1, batch2)
return beta * x + alpha * (bmm_res.sum(axis=0))