From 9d13db2ac3e108ffbccdbe41699e4c5cf776d6fc Mon Sep 17 00:00:00 2001
From: z00512249
Date: Mon, 1 Nov 2021 20:09:19 +0800
Subject: [PATCH] fix cholesky factor and solve for cpu backend

---
 .../cpu/eigen/cholesky_cpu_kernel.cc       | 19 +++++----
 .../cpu/eigen/cholesky_solve_cpu_kernel.cc | 27 +++++++-----
 .../cpu/eigen/cholesky_solve_cpu_kernel.h  |  1 +
 tests/st/ops/cpu/test_cholesky_op.py       | 41 ++++++-------------
 4 files changed, 41 insertions(+), 47 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc
index 4ed132fc292..b2a8d4b7d76 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc
@@ -16,6 +16,7 @@
 #include "backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.h"
 #include
+#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
 #include "utils/ms_utils.h"
 #include "Eigen/Dense"
 #include "Eigen/Cholesky"
@@ -46,7 +47,6 @@ void CholeskyCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
   if (*row != *col) {
     MS_LOG_EXCEPTION << kernel_name_ << "input shape is invalid: " << *row << ", " << *col;
   }
-  return;
 }
 
 template <typename T>
@@ -70,15 +70,12 @@
 template <typename T>
 bool CholeskyCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                   const std::vector<AddressPtr> &outputs) {
   T *input_value = reinterpret_cast<T *>(inputs[kInputIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input(input_value, input_row_,
-                                                                                      input_col_);
+  Map<Matrix<T, RowMajor>> input(input_value, input_row_, input_col_);
   T *output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output(output_value, output_row_,
-                                                                                       output_col_);
-  Eigen::LLT<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> llt;
+  Map<Matrix<T, RowMajor>> output(output_value, output_row_, output_col_);
+  Eigen::LLT<Matrix<T, RowMajor>> llt;
   llt.compute(input);
-
   if (clean_) {
     if (lower_) {
       output = llt.matrixL();
@@ -86,12 +83,16 @@ bool CholeskyCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const s
       output = llt.matrixU();
     }
   } else {
-    output = llt.matrixLLT();
+    if (lower_) {
+      output = llt.matrixLLT();
+    } else {
+      output = llt.matrixLLT().transpose();
+    }
   }
   if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
     return true;
   }
-  MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
+  MS_LOG_EXCEPTION << kernel_name_ << " output cholesky shape invalid.";
 }
 }  // namespace kernel
 }  // namespace mindspore
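Note for reviewers (illustrative only, not part of the patch): with clean_ left false the factor kernel returns Eigen's packed LLT matrix, which keeps the factor in the lower triangle, so the lower=False path now has to return the transpose to place the factor in the upper triangle, matching how scipy.linalg.cho_factor packs its result. A minimal NumPy sketch of the relationship the two outputs are meant to satisfy:

# Illustrative sketch, not part of the change.
import numpy as np

a = np.array([[4.0, 2.0],
              [2.0, 3.0]])
l = np.linalg.cholesky(a)  # lower-triangular L with a == L @ L.T
u = l.T                    # upper-triangular U with a == U.T @ U

assert np.allclose(a, l @ l.T)  # what the lower=True output satisfies
assert np.allclose(a, u.T @ u)  # what the lower=False output satisfies after the transpose
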
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.cc
index d0ffd417501..a3a85da6f19 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.cc
@@ -15,6 +15,7 @@
  */
 #include "backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
 #include "Eigen/Dense"
 #include "Eigen/Cholesky"
 namespace mindspore {
@@ -41,7 +42,6 @@ void CholeskySolverCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape
     *row = shape.at(shape.size() - kRowIndex);
     *col = shape.at(shape.size() - kColIndex);
   }
-  return;
 }
 
 template <typename T>
@@ -59,30 +59,37 @@ void CholeskySolverCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
   InitMatrixInfo(input_b_shape, &input_b_row_, &input_b_col_);
   auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, kOutputIndex);
   InitMatrixInfo(output_shape, &output_row_, &output_col_);
+  lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
+  if (input_a_row_ != input_b_row_) {
+    MS_LOG_EXCEPTION << kernel_name_ << "llt solve input row is not equal to b row: " << input_a_row_ << " vs "
+                     << input_b_row_;
+  }
 }
 
 template <typename T>
 bool CholeskySolverCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                         const std::vector<AddressPtr> &outputs) {
   T *input_value = reinterpret_cast<T *>(inputs[kInputAIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input(input_value, input_a_row_,
-                                                                                      input_a_col_);
+  Map<Matrix<T, RowMajor>> input(input_value, input_a_row_, input_a_col_);
   T *input_b_value = reinterpret_cast<T *>(inputs[kInputBIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_b(input_b_value, input_b_row_,
-                                                                                        input_b_col_);
+  Map<Matrix<T, RowMajor>> input_b(input_b_value, input_b_row_, input_b_col_);
   T *output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output(output_value, output_row_,
-                                                                                       output_col_);
-  Eigen::LLT<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> llt(input);
+  Map<Matrix<T, RowMajor>> output(output_value, output_row_, output_col_);
 
-  output = llt.solve(input_b);
+  if (lower_) {
+    output.noalias() = input.template triangularView<Lower>().solve(input_b);
+    input.adjoint().template triangularView<Upper>().solveInPlace(output);
+  } else {
+    output.noalias() = input.adjoint().template triangularView<Lower>().solve(input_b);
+    input.template triangularView<Upper>().solveInPlace(output);
+  }
   if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
     return true;
   }
-  MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
+  MS_LOG_EXCEPTION << kernel_name_ << " output cholesky solve shape invalid.";
 }
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h
index a0b2a154fef..67823428cea 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h
@@ -42,6 +42,7 @@ class CholeskySolverCPUKernel : public CPUKernel {
   size_t output_row_{1};
   size_t output_col_{1};
   TypeId dtype_{kNumberTypeFloat32};
+  bool lower_{false};
 };
 
 MS_REG_CPU_KERNEL_T(
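Note for reviewers (illustrative only, not part of the patch): the solver no longer rebuilds an Eigen::LLT from its first input; it treats that input as the factor produced by cho_factor and applies two triangular solves, L y = b then L^T x = y when lower is true, and U^T y = b then U x = y when lower is false. A rough SciPy equivalent of the two branches:

# Illustrative sketch, not part of the change.
import numpy as np
from scipy.linalg import cholesky, solve_triangular

a = np.array([[4.0, 2.0],
              [2.0, 3.0]])
b = np.array([1.0, 1.0])

# lower=true branch: the input holds L with a == L @ L.T
l = cholesky(a, lower=True)
y = solve_triangular(l, b, lower=True)     # L y = b
x = solve_triangular(l.T, y, lower=False)  # L.T x = y
assert np.allclose(a @ x, b)

# lower=false branch: the input holds U with a == U.T @ U
u = cholesky(a, lower=False)
y = solve_triangular(u.T, b, lower=True)   # U.T y = b
x = solve_triangular(u, y, lower=False)    # U x = y
assert np.allclose(a @ x, b)
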
diff --git a/tests/st/ops/cpu/test_cholesky_op.py b/tests/st/ops/cpu/test_cholesky_op.py
index bd68c9852ab..8cbf02f9959 100644
--- a/tests/st/ops/cpu/test_cholesky_op.py
+++ b/tests/st/ops/cpu/test_cholesky_op.py
@@ -16,28 +16,12 @@ class Cholesky(PrimitiveWithInfer):
     """
 
     @prim_attr_register
-    def __init__(self, lower=False, clean=False, split_dim=0):
+    def __init__(self, lower=False, clean=True):
         super().__init__(name="Cholesky")
         self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
         self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
-        self.split_dim = validator.check_value_type("split_dim", split_dim, [int], self.split_dim)
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
 
-    def infer_shape(self, x_shape):
-        if self.split_dim != 0:
-            height = x_shape[0]
-            width = x_shape[1]
-            if height <= self.split_dim:
-                out_shape = [1, height, width]
-            else:
-                batch = height // self.split_dim
-                if height != batch * self.split_dim:
-                    batch += 1
-                out_shape = [batch, self.split_dim, self.split_dim]
-        else:
-            out_shape = x_shape
-        return out_shape
-
     def __infer__(self, x):
         x_shape = x['shape']
         x_dtype = x['dtype']
@@ -54,10 +38,9 @@ class CholeskySolver(PrimitiveWithInfer):
     """
 
     @prim_attr_register
-    def __init__(self, lower=False, split_dim=0):
+    def __init__(self, lower=False):
         super().__init__(name="CholeskySolver")
         self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
-        self.split_dim = validator.check_value_type("split_dim", split_dim, [int], self.split_dim)
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
 
     def __infer__(self, x, b):
@@ -71,18 +54,18 @@ class CholeskySolver(PrimitiveWithInfer):
 
 
 class CholeskyNet(nn.Cell):
-    def __init__(self, lower=False, clean=False, split_dim=0):
+    def __init__(self, lower=False, clean=False):
         super(CholeskyNet, self).__init__()
-        self.cholesky = Cholesky(lower, clean, split_dim)
+        self.cholesky = Cholesky(lower, clean)
 
     def construct(self, x):
         return self.cholesky(x)
 
 
 class CholeskySolverNet(nn.Cell):
-    def __init__(self, lower=False, split_dim=0):
+    def __init__(self, lower=False):
         super(CholeskySolverNet, self).__init__()
-        self.cholesky_solver = CholeskySolver(lower, split_dim)
+        self.cholesky_solver = CholeskySolver(lower)
 
     def construct(self, c, b):
         return self.cholesky_solver(c, b)
@@ -167,10 +150,12 @@ def test_cholesky_solver():
     b = np.array([1, 1, 1, 1], dtype=np.float32)
     tensor_a = Tensor(a)
     tensor_b = Tensor(b)
-    scp_c, lower = scp.linalg.cho_factor(a, lower=True)
-    scp_x = scp.linalg.cho_solve((scp_c, lower), b)
-
-    mscp_c, mscp_lower = cho_factor(tensor_a, lower=True)
-    mscp_x = cho_solve((tensor_a, mscp_lower), tensor_b)
+    scp_c, lower = scp.linalg.cho_factor(a, lower=False)
+    mscp_c, mscp_lower = cho_factor(tensor_a, lower=False)
     assert np.allclose(scp_c, mscp_c.asnumpy())
+
+    scp_factor = (scp_c, lower)
+    ms_cho_factor = (mscp_c, mscp_lower)
+    scp_x = scp.linalg.cho_solve(scp_factor, b)
+    mscp_x = cho_solve(ms_cho_factor, tensor_b)
     assert np.allclose(scp_x, mscp_x.asnumpy())
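Note (illustrative only, not part of the patch): end to end, the updated test drives the two kernels through the cho_factor/cho_solve wrappers imported at the top of the test file (assumed to be the mindspore.scipy ones) and compares them against SciPy. A standalone sketch of the SciPy side, using an arbitrary SPD matrix rather than the one defined earlier in the test:

# Illustrative sketch, not part of the change.
import numpy as np
import scipy.linalg as sla

a = np.array([[4.0, 12.0, -16.0],
              [12.0, 37.0, -43.0],
              [-16.0, -43.0, 98.0]])
b = np.array([1.0, 1.0, 1.0])

c, lower = sla.cho_factor(a, lower=False)  # upper-triangular factor, as in the test
x = sla.cho_solve((c, lower), b)
assert np.allclose(a @ x, b)

# The MindSpore side is expected to mirror these calls on Tensor inputs, e.g.
#   mscp_c, mscp_lower = cho_factor(Tensor(a), lower=False)
#   mscp_x = cho_solve((mscp_c, mscp_lower), Tensor(b))
# with np.allclose(x, mscp_x.asnumpy()) holding once both kernels agree with SciPy.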