Fix addbmm

This commit is contained in:
shaojunsong 2023-01-05 17:30:24 +08:00
parent 2ef887450b
commit 2b31fcaf44
2 changed files with 6 additions and 0 deletions

View File

@ -22,4 +22,5 @@
Tensor`x` 具有相同的dtype。
异常:
- **TypeError** - 如果 `alpha``beta` 不是int或者float。
- **ValueError** - 如果 `batch1` `batch2` 不能进行批量矩阵乘法。

View File

@ -4809,6 +4809,7 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
Tensor, has the same dtype as `x`.
Raises:
TypeError: If `alpha` or beta is not an int or float.
ValueError: If `batch1`, `batch2` cannot apply batch matrix multiplication.
Supported Platforms:
@ -4827,6 +4828,10 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1):
[1285. 1377. 1469.]
[1621. 1745. 1869.]]
"""
if not isinstance(alpha, (int, float)):
raise TypeError(f"For 'addbmm', parameter 'alpha' must be an int or float, but got {type(alpha)}.")
if not isinstance(beta, (int, float)):
raise TypeError(f"For 'addbmm', parameter 'beta' must be an int or float, but got {type(beta)}.")
bmm_op = _get_cache_prim(P.BatchMatMul)()
bmm_res = bmm_op(batch1, batch2)
return beta * x + alpha * (bmm_res.sum(axis=0))