From 2996b8d7f0c416cbeb0e30faa03a19e3e84025ce Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Fri, 17 Feb 2023 14:08:30 +0800 Subject: [PATCH] fix aicpu lusolve --- .../aicpu/aicpu_ops/cpu_kernel/ms_kernel/lu_solve.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/lu_solve.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/lu_solve.cc index d1e70be960b..09d061d28c4 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/lu_solve.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/lu_solve.cc @@ -113,9 +113,8 @@ uint32_t LuSolveCpuKernel::LuSolve(CpuKernelContext &ctx, T *b_working_ptr, T *l for (int64_t i = 0; i < input_0_Shape->GetDimSize(b_dim - 2); i++) { matrix_b.row(i).swap(matrix_b.row(*(pivots_working_ptr + i) - 1)); } - MatrixXd L = matrix_A.template triangularView(); - MatrixXd U = matrix_A.template triangularView(); - MatrixXd result = (L * U).lu().solve(matrix_b); + MatrixXd result = matrix_A.template triangularView().solve(matrix_b); + result.noalias() = matrix_A.template triangularView().solve(result); for (int64_t m = 0; m < b_stride; m++) { *(output_y + a * b_stride + m) = (T2) * (result.data() + m); } @@ -182,4 +181,4 @@ uint32_t LuSolveCpuKernel::LuSolveCompute(CpuKernelContext &ctx) { return KERNEL_STATUS_OK; } REGISTER_CPU_KERNEL(kLuSolve, LuSolveCpuKernel); -} // namespace aicpu \ No newline at end of file +} // namespace aicpu