!49046 fix aicpu lusolve

Merge pull request !49046 from yanzhenxiang2020/br_fix_aicpu_lusolve
This commit is contained in:
i-robot 2023-02-18 07:14:02 +00:00 committed by Gitee
commit b330772618
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 4 deletions

View File

@ -113,9 +113,8 @@ uint32_t LuSolveCpuKernel::LuSolve(CpuKernelContext &ctx, T *b_working_ptr, T *l
for (int64_t i = 0; i < input_0_Shape->GetDimSize(b_dim - 2); i++) { for (int64_t i = 0; i < input_0_Shape->GetDimSize(b_dim - 2); i++) {
matrix_b.row(i).swap(matrix_b.row(*(pivots_working_ptr + i) - 1)); matrix_b.row(i).swap(matrix_b.row(*(pivots_working_ptr + i) - 1));
} }
MatrixXd L = matrix_A.template triangularView<Eigen::UnitLower>(); MatrixXd result = matrix_A.template triangularView<Eigen::UnitLower>().solve(matrix_b);
MatrixXd U = matrix_A.template triangularView<Eigen::Upper>(); result.noalias() = matrix_A.template triangularView<Eigen::Upper>().solve(result);
MatrixXd result = (L * U).lu().solve(matrix_b);
for (int64_t m = 0; m < b_stride; m++) { for (int64_t m = 0; m < b_stride; m++) {
*(output_y + a * b_stride + m) = (T2) * (result.data() + m); *(output_y + a * b_stride + m) = (T2) * (result.data() + m);
} }
@ -182,4 +181,4 @@ uint32_t LuSolveCpuKernel::LuSolveCompute(CpuKernelContext &ctx) {
return KERNEL_STATUS_OK; return KERNEL_STATUS_OK;
} }
REGISTER_CPU_KERNEL(kLuSolve, LuSolveCpuKernel); REGISTER_CPU_KERNEL(kLuSolve, LuSolveCpuKernel);
} // namespace aicpu } // namespace aicpu