From bb2bfb161d2a139c00893cf768d102825770f0f5 Mon Sep 17 00:00:00 2001 From: hedongdong Date: Wed, 23 Nov 2022 21:49:11 +0800 Subject: [PATCH] [OCCM] fix function lu_unpack --- mindspore/python/mindspore/ops/function/math_func.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py index e1bb75c7140..49e16aad364 100644 --- a/mindspore/python/mindspore/ops/function/math_func.py +++ b/mindspore/python/mindspore/ops/function/math_func.py @@ -168,6 +168,7 @@ exp2_ = P.Pow() truncate_div_ = P.TruncateDiv() truncate_mod_ = P.TruncateMod() sparse_segment_mean_ = SparseSegmentMean() +lu_unpack_ = LuUnpack() xlogy_ = P.Xlogy() square_ = P.Square() sqrt_ = P.Sqrt() @@ -5459,7 +5460,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): ... [ 0.1015, -0.5363, 0.6165]]]), mstype.float64) >>> LU_pivots = Tensor(np.array([[1, 3, 3], ... [2, 3, 3]]), mstype.int32) - >>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots, unpack_data, unpack_pivots) + >>> pivots, L, U = F.lu_unpack(LU_data, LU_pivots) >>> print(pivots) [[[1. 0. 0.] [0. 0. 1.] @@ -5482,7 +5483,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): [ 0. -0.4779 0.6701] [ 0. 0. 0.6165]]] """ - pivots, l, u = LuUnpack(LU_data, LU_pivots) + pivots, l, u = lu_unpack_(LU_data, LU_pivots) if unpack_data: if unpack_pivots: return pivots, l, u