[OCCM] fix function lu_unpack
This commit is contained in:
parent
c0445d8a4e
commit
bb2bfb161d
|
@ -168,6 +168,7 @@ exp2_ = P.Pow()
|
|||
truncate_div_ = P.TruncateDiv()
|
||||
truncate_mod_ = P.TruncateMod()
|
||||
sparse_segment_mean_ = SparseSegmentMean()
|
||||
lu_unpack_ = LuUnpack()
|
||||
xlogy_ = P.Xlogy()
|
||||
square_ = P.Square()
|
||||
sqrt_ = P.Sqrt()
|
||||
|
@ -5459,7 +5460,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
|||
... [ 0.1015, -0.5363, 0.6165]]]), mstype.float64)
|
||||
>>> LU_pivots = Tensor(np.array([[1, 3, 3],
|
||||
... [2, 3, 3]]), mstype.int32)
|
||||
>>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots, unpack_data, unpack_pivots)
|
||||
>>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots)
|
||||
>>> print(pivots)
|
||||
[[[1. 0. 0.]
|
||||
[0. 0. 1.]
|
||||
|
@ -5482,7 +5483,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
|||
[ 0. -0.4779 0.6701]
|
||||
[ 0. 0. 0.6165]]]
|
||||
"""
|
||||
pivots, l, u = LuUnpack(LU_data, LU_pivots)
|
||||
pivots, l, u = lu_unpack_(LU_data, LU_pivots)
|
||||
if unpack_data:
|
||||
if unpack_pivots:
|
||||
return pivots, l, u
|
||||
|
|
Loading…
Reference in New Issue