forked from mindspore-Ecosystem/mindspore
addmv、addr 接口差异优化
This commit is contained in:
parent
5c88bef26d
commit
8e4c63a07d
|
@ -5136,8 +5136,8 @@ def addmv(x, mat, vec, beta=1, alpha=1):
|
||||||
input_dtype = dtypeop(x)
|
input_dtype = dtypeop(x)
|
||||||
if not (isinstance(x, Tensor) and isinstance(mat, Tensor) and isinstance(vec, Tensor)):
|
if not (isinstance(x, Tensor) and isinstance(mat, Tensor) and isinstance(vec, Tensor)):
|
||||||
raise TypeError("For Addmv, inputs must be all tensors.")
|
raise TypeError("For Addmv, inputs must be all tensors.")
|
||||||
if not (input_dtype == dtypeop(mat) and input_dtype == dtypeop(vec)):
|
if dtypeop(mat) != dtypeop(vec):
|
||||||
raise TypeError("For Addmv, the inputs should be the same dtype.")
|
raise TypeError("For Addmv, the mat and vec should be the same dtype.")
|
||||||
_check_input_dtype("x", input_dtype,
|
_check_input_dtype("x", input_dtype,
|
||||||
[mstype.float16, mstype.float32, mstype.float64,
|
[mstype.float16, mstype.float32, mstype.float64,
|
||||||
mstype.int16, mstype.int32, mstype.int64], "Addmv")
|
mstype.int16, mstype.int32, mstype.int64], "Addmv")
|
||||||
|
@ -5226,11 +5226,11 @@ def addr(x, vec1, vec2, beta=1, alpha=1):
|
||||||
input_dtype = dtypeop(x)
|
input_dtype = dtypeop(x)
|
||||||
if not (isinstance(x, Tensor) and isinstance(vec1, Tensor) and isinstance(vec2, Tensor)):
|
if not (isinstance(x, Tensor) and isinstance(vec1, Tensor) and isinstance(vec2, Tensor)):
|
||||||
raise TypeError("For Addr, inputs must be all tensors.")
|
raise TypeError("For Addr, inputs must be all tensors.")
|
||||||
if not (input_dtype == dtypeop(vec1) and input_dtype == dtypeop(vec2)):
|
if dtypeop(vec1) != dtypeop(vec2):
|
||||||
raise TypeError("For Addr, the inputs should be the same dtype.")
|
raise TypeError("For Addr, the vec1 and vec2 should be the same dtype.")
|
||||||
_check_input_dtype("x", input_dtype,
|
_check_input_dtype("x", input_dtype,
|
||||||
[mstype.float16, mstype.float32, mstype.float64,
|
[mstype.float16, mstype.float32, mstype.float64,
|
||||||
mstype.int16, mstype.int32, mstype.int64], "Addmv")
|
mstype.int16, mstype.int32, mstype.int64], "Addr")
|
||||||
_check_attr_dtype("alpha", alpha, [int, float, bool], "Addr")
|
_check_attr_dtype("alpha", alpha, [int, float, bool], "Addr")
|
||||||
_check_attr_dtype("beta", beta, [int, float, bool], "Addr")
|
_check_attr_dtype("beta", beta, [int, float, bool], "Addr")
|
||||||
if input_dtype in (mstype.int16, mstype.int32, mstype.int64):
|
if input_dtype in (mstype.int16, mstype.int32, mstype.int64):
|
||||||
|
|
Loading…
Reference in New Issue