addmv、addr 接口差异优化

This commit is contained in:
gaoshuanglong 2023-02-18 18:09:47 +08:00
parent 5c88bef26d
commit 8e4c63a07d
1 changed files with 5 additions and 5 deletions

View File

@ -5136,8 +5136,8 @@ def addmv(x, mat, vec, beta=1, alpha=1):
input_dtype = dtypeop(x) input_dtype = dtypeop(x)
if not (isinstance(x, Tensor) and isinstance(mat, Tensor) and isinstance(vec, Tensor)): if not (isinstance(x, Tensor) and isinstance(mat, Tensor) and isinstance(vec, Tensor)):
raise TypeError("For Addmv, inputs must be all tensors.") raise TypeError("For Addmv, inputs must be all tensors.")
if not (input_dtype == dtypeop(mat) and input_dtype == dtypeop(vec)): if dtypeop(mat) != dtypeop(vec):
raise TypeError("For Addmv, the inputs should be the same dtype.") raise TypeError("For Addmv, the mat and vec should be the same dtype.")
_check_input_dtype("x", input_dtype, _check_input_dtype("x", input_dtype,
[mstype.float16, mstype.float32, mstype.float64, [mstype.float16, mstype.float32, mstype.float64,
mstype.int16, mstype.int32, mstype.int64], "Addmv") mstype.int16, mstype.int32, mstype.int64], "Addmv")
@ -5226,11 +5226,11 @@ def addr(x, vec1, vec2, beta=1, alpha=1):
input_dtype = dtypeop(x) input_dtype = dtypeop(x)
if not (isinstance(x, Tensor) and isinstance(vec1, Tensor) and isinstance(vec2, Tensor)): if not (isinstance(x, Tensor) and isinstance(vec1, Tensor) and isinstance(vec2, Tensor)):
raise TypeError("For Addr, inputs must be all tensors.") raise TypeError("For Addr, inputs must be all tensors.")
if not (input_dtype == dtypeop(vec1) and input_dtype == dtypeop(vec2)): if dtypeop(vec1) != dtypeop(vec2):
raise TypeError("For Addr, the inputs should be the same dtype.") raise TypeError("For Addr, the vec1 and vec2 should be the same dtype.")
_check_input_dtype("x", input_dtype, _check_input_dtype("x", input_dtype,
[mstype.float16, mstype.float32, mstype.float64, [mstype.float16, mstype.float32, mstype.float64,
mstype.int16, mstype.int32, mstype.int64], "Addmv") mstype.int16, mstype.int32, mstype.int64], "Addr")
_check_attr_dtype("alpha", alpha, [int, float, bool], "Addr") _check_attr_dtype("alpha", alpha, [int, float, bool], "Addr")
_check_attr_dtype("beta", beta, [int, float, bool], "Addr") _check_attr_dtype("beta", beta, [int, float, bool], "Addr")
if input_dtype in (mstype.int16, mstype.int32, mstype.int64): if input_dtype in (mstype.int16, mstype.int32, mstype.int64):