!45954 修复函数接口lu_unpack对Primitive接口调用错误
Merge pull request !45954 from hedongdong/fix_luunpack
This commit is contained in:
commit
5b3dbdd137
|
@ -169,6 +169,7 @@ exp2_ = P.Pow()
|
|||
truncate_div_ = P.TruncateDiv()
|
||||
truncate_mod_ = P.TruncateMod()
|
||||
sparse_segment_mean_ = SparseSegmentMean()
|
||||
lu_unpack_ = LuUnpack()
|
||||
xlogy_ = P.Xlogy()
|
||||
square_ = P.Square()
|
||||
sqrt_ = P.Sqrt()
|
||||
|
@ -5505,7 +5506,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
|||
... [ 0.1015, -0.5363, 0.6165]]]), mstype.float64)
|
||||
>>> LU_pivots = Tensor(np.array([[1, 3, 3],
|
||||
... [2, 3, 3]]), mstype.int32)
|
||||
>>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots, unpack_data, unpack_pivots)
|
||||
>>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots)
|
||||
>>> print(pivots)
|
||||
[[[1. 0. 0.]
|
||||
[0. 0. 1.]
|
||||
|
@ -5528,7 +5529,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
|||
[ 0. -0.4779 0.6701]
|
||||
[ 0. 0. 0.6165]]]
|
||||
"""
|
||||
pivots, l, u = LuUnpack(LU_data, LU_pivots)
|
||||
pivots, l, u = lu_unpack_(LU_data, LU_pivots)
|
||||
if unpack_data:
|
||||
if unpack_pivots:
|
||||
return pivots, l, u
|
||||
|
|
Loading…
Reference in New Issue