[OCCM] fix function lu_unpack

This commit is contained in:
hedongdong 2022-11-23 21:49:11 +08:00
parent c0445d8a4e
commit bb2bfb161d
1 changed files with 3 additions and 2 deletions

View File

@ -168,6 +168,7 @@ exp2_ = P.Pow()
truncate_div_ = P.TruncateDiv()
truncate_mod_ = P.TruncateMod()
sparse_segment_mean_ = SparseSegmentMean()
lu_unpack_ = LuUnpack()
xlogy_ = P.Xlogy()
square_ = P.Square()
sqrt_ = P.Sqrt()
@ -5459,7 +5460,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
... [ 0.1015, -0.5363, 0.6165]]]), mstype.float64)
>>> LU_pivots = Tensor(np.array([[1, 3, 3],
... [2, 3, 3]]), mstype.int32)
>>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots, unpack_data, unpack_pivots)
>>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots)
>>> print(pivots)
[[[1. 0. 0.]
[0. 0. 1.]
@ -5482,7 +5483,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
[ 0. -0.4779 0.6701]
[ 0. 0. 0.6165]]]
"""
pivots, l, u = LuUnpack(LU_data, LU_pivots)
pivots, l, u = lu_unpack_(LU_data, LU_pivots)
if unpack_data:
if unpack_pivots:
return pivots, l, u