Fixed addmv tensor API.

Fixed diagonal API dtype support.
This commit is contained in:
alashkari 2022-12-28 14:43:08 -05:00
parent b05b142ac6
commit d054809185
2 changed files with 3 additions and 3 deletions

View File

@ -3503,7 +3503,7 @@ class Tensor(Tensor_):
[30. 27.]
"""
self._init_check()
return tensor_operator_registry.get('addmv')(self, mat, vec, beta=1, alpha=1)
return tensor_operator_registry.get('addmv')(self, mat, vec, beta=beta, alpha=alpha)
def asinh(self):
r"""

View File

@ -5725,11 +5725,11 @@ def diagonal(input, offset=0, dim1=0, dim2=1):
elif offset != 0:
e = e.astype(mstype.float32)
if offset > 0:
e_left = fill_op(dtype, (n, offset), 0)
e_left = fill_op(mstype.float32, (n, offset), 0)
e_right = e[..., 0:m - offset:1]
e = _get_cache_prim(P.Concat)(1)((e_left, e_right)).astype(dtype)
elif offset < 0:
e_upper = fill_op(dtype, (-offset, m), 0)
e_upper = fill_op(mstype.float32, (-offset, m), 0)
e_lower = e[0:n + offset:1, ...]
e = _get_cache_prim(P.Concat)(0)((e_upper, e_lower)).astype(dtype)
e = _get_cache_prim(P.BroadcastTo)(x_shape)(e)