forked from mindspore-Ecosystem/mindspore
!49046 fix aicpu lusolve
Merge pull request !49046 from yanzhenxiang2020/br_fix_aicpu_lusolve
This commit is contained in:
commit
b330772618
|
@ -113,9 +113,8 @@ uint32_t LuSolveCpuKernel::LuSolve(CpuKernelContext &ctx, T *b_working_ptr, T *l
|
||||||
for (int64_t i = 0; i < input_0_Shape->GetDimSize(b_dim - 2); i++) {
|
for (int64_t i = 0; i < input_0_Shape->GetDimSize(b_dim - 2); i++) {
|
||||||
matrix_b.row(i).swap(matrix_b.row(*(pivots_working_ptr + i) - 1));
|
matrix_b.row(i).swap(matrix_b.row(*(pivots_working_ptr + i) - 1));
|
||||||
}
|
}
|
||||||
MatrixXd L = matrix_A.template triangularView<Eigen::UnitLower>();
|
MatrixXd result = matrix_A.template triangularView<Eigen::UnitLower>().solve(matrix_b);
|
||||||
MatrixXd U = matrix_A.template triangularView<Eigen::Upper>();
|
result.noalias() = matrix_A.template triangularView<Eigen::Upper>().solve(result);
|
||||||
MatrixXd result = (L * U).lu().solve(matrix_b);
|
|
||||||
for (int64_t m = 0; m < b_stride; m++) {
|
for (int64_t m = 0; m < b_stride; m++) {
|
||||||
*(output_y + a * b_stride + m) = (T2) * (result.data() + m);
|
*(output_y + a * b_stride + m) = (T2) * (result.data() + m);
|
||||||
}
|
}
|
||||||
|
@ -182,4 +181,4 @@ uint32_t LuSolveCpuKernel::LuSolveCompute(CpuKernelContext &ctx) {
|
||||||
return KERNEL_STATUS_OK;
|
return KERNEL_STATUS_OK;
|
||||||
}
|
}
|
||||||
REGISTER_CPU_KERNEL(kLuSolve, LuSolveCpuKernel);
|
REGISTER_CPU_KERNEL(kLuSolve, LuSolveCpuKernel);
|
||||||
} // namespace aicpu
|
} // namespace aicpu
|
||||||
|
|
Loading…
Reference in New Issue