diff --git a/docs/api/api_python/ops/mindspore.ops.func_addbmm.rst b/docs/api/api_python/ops/mindspore.ops.func_addbmm.rst index e9b93bd929a..4830bc79ecd 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_addbmm.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_addbmm.rst @@ -22,4 +22,5 @@ Tensor,和 `x` 具有相同的dtype。 异常: + - **TypeError** - 如果 `alpha`,`beta` 不是int或者float。 - **ValueError** - 如果 `batch1`, `batch2` 不能进行批量矩阵乘法。 diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py index 0b1400a72ce..048e0f460e7 100644 --- a/mindspore/python/mindspore/ops/function/math_func.py +++ b/mindspore/python/mindspore/ops/function/math_func.py @@ -4758,6 +4758,7 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1): Tensor, has the same dtype as `x`. Raises: + TypeError: If `alpha` or beta is not an int or float. ValueError: If `batch1`, `batch2` cannot apply batch matrix multiplication. Supported Platforms: @@ -4776,6 +4777,10 @@ def addbmm(x, batch1, batch2, *, beta=1, alpha=1): [1285. 1377. 1469.] [1621. 1745. 1869.]] """ + if not isinstance(alpha, (int, float)): + raise TypeError(f"For 'addbmm', parameter 'alpha' must be an int or float, but got {type(alpha)}.") + if not isinstance(beta, (int, float)): + raise TypeError(f"For 'addbmm', parameter 'beta' must be an int or float, but got {type(beta)}.") bmm_op = _get_cache_prim(P.BatchMatMul)() bmm_res = bmm_op(batch1, batch2) return beta * x + alpha * (bmm_res.sum(axis=0))