!45954 修复函数接口lu_unpack对Primitive接口调用错误

Merge pull request !45954 from hedongdong/fix_luunpack
This commit is contained in:
i-robot 2022-11-25 02:32:41 +00:00 committed by Gitee
commit 5b3dbdd137
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 2 deletions

View File

@ -169,6 +169,7 @@ exp2_ = P.Pow()
truncate_div_ = P.TruncateDiv()
truncate_mod_ = P.TruncateMod()
sparse_segment_mean_ = SparseSegmentMean()
lu_unpack_ = LuUnpack()
xlogy_ = P.Xlogy()
square_ = P.Square()
sqrt_ = P.Sqrt()
@ -5505,7 +5506,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
... [ 0.1015, -0.5363, 0.6165]]]), mstype.float64)
>>> LU_pivots = Tensor(np.array([[1, 3, 3],
... [2, 3, 3]]), mstype.int32)
>>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots, unpack_data, unpack_pivots)
>>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots)
>>> print(pivots)
[[[1. 0. 0.]
[0. 0. 1.]
@ -5528,7 +5529,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
[ 0. -0.4779 0.6701]
[ 0. 0. 0.6165]]]
"""
pivots, l, u = LuUnpack(LU_data, LU_pivots)
pivots, l, u = lu_unpack_(LU_data, LU_pivots)
if unpack_data:
if unpack_pivots:
return pivots, l, u