fix cholesky factor and solve for cpu backend

This commit is contained in:
z00512249 2021-11-01 20:09:19 +08:00
parent eff2318bb6
commit 9d13db2ac3
4 changed files with 41 additions and 47 deletions

View File

@ -16,6 +16,7 @@
#include "backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.h"
#include <vector>
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include "utils/ms_utils.h"
#include "Eigen/Dense"
#include "Eigen/Cholesky"
@ -46,7 +47,6 @@ void CholeskyCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size
if (*row != *col) {
MS_LOG_EXCEPTION << kernel_name_ << "input shape is invalid: " << *row << ", " << *col;
}
return;
}
template <typename T>
@ -70,15 +70,12 @@ template <typename T>
bool CholeskyCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input_value = reinterpret_cast<T *>(inputs[kInputIndex]->addr);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input(input_value, input_row_,
input_col_);
Map<Matrix<T, RowMajor>> input(input_value, input_row_, input_col_);
T *output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output(output_value, output_row_,
output_col_);
Eigen::LLT<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> llt;
Map<Matrix<T, RowMajor>> output(output_value, output_row_, output_col_);
Eigen::LLT<Matrix<T, RowMajor>> llt;
llt.compute(input);
if (clean_) {
if (lower_) {
output = llt.matrixL();
@ -86,12 +83,16 @@ bool CholeskyCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const s
output = llt.matrixU();
}
} else {
output = llt.matrixLLT();
if (lower_) {
output = llt.matrixLLT();
} else {
output = llt.matrixLLT().transpose();
}
}
if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
return true;
}
MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
MS_LOG_EXCEPTION << kernel_name_ << " output cholesky shape invalid.";
}
} // namespace kernel
} // namespace mindspore

View File

@ -15,6 +15,7 @@
*/
#include "backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include "Eigen/Dense"
#include "Eigen/Cholesky"
namespace mindspore {
@ -41,7 +42,6 @@ void CholeskySolverCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape
*row = shape.at(shape.size() - kRowIndex);
*col = shape.at(shape.size() - kColIndex);
}
return;
}
template <typename T>
@ -59,30 +59,37 @@ void CholeskySolverCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
InitMatrixInfo(input_b_shape, &input_b_row_, &input_b_col_);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, kOutputIndex);
InitMatrixInfo(output_shape, &output_row_, &output_col_);
lower_ = AnfAlgo ::GetNodeAttr<bool>(kernel_node, LOWER);
if (input_a_row_ != input_b_row_) {
MS_LOG_EXCEPTION << kernel_name_ << "llt solve input row is not equal to b row: " << input_a_row_ << " vs "
<< input_b_row_;
}
}
template <typename T>
bool CholeskySolverCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input_value = reinterpret_cast<T *>(inputs[kInputAIndex]->addr);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input(input_value, input_a_row_,
input_a_col_);
Map<Matrix<T, RowMajor>> input(input_value, input_a_row_, input_a_col_);
T *input_b_value = reinterpret_cast<T *>(inputs[kInputBIndex]->addr);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_b(input_b_value, input_b_row_,
input_b_col_);
Map<Matrix<T, RowMajor>> input_b(input_b_value, input_b_row_, input_b_col_);
T *output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output(output_value, output_row_,
output_col_);
Eigen::LLT<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> llt(input);
Map<Matrix<T, RowMajor>> output(output_value, output_row_, output_col_);
output = llt.solve(input_b);
if (lower_) {
output.noalias() = input.template triangularView<Lower>().solve(input_b);
input.adjoint().template triangularView<Upper>().solveInPlace(output);
} else {
output.noalias() = input.adjoint().template triangularView<Lower>().solve(input_b);
input.template triangularView<Upper>().solveInPlace(output);
}
if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
return true;
}
MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
MS_LOG_EXCEPTION << kernel_name_ << " output cholesky solve shape invalid.";
}
} // namespace kernel
} // namespace mindspore

View File

@ -42,6 +42,7 @@ class CholeskySolverCPUKernel : public CPUKernel {
size_t output_row_{1};
size_t output_col_{1};
TypeId dtype_{kNumberTypeFloat32};
bool lower_{false};
};
MS_REG_CPU_KERNEL_T(

View File

@ -16,28 +16,12 @@ class Cholesky(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, lower=False, clean=False, split_dim=0):
def __init__(self, lower=False, clean=True):
super().__init__(name="Cholesky")
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
self.split_dim = validator.check_value_type("split_dim", split_dim, [int], self.split_dim)
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x_shape):
if self.split_dim != 0:
height = x_shape[0]
width = x_shape[1]
if height <= self.split_dim:
out_shape = [1, height, width]
else:
batch = height // self.split_dim
if height != batch * self.split_dim:
batch += 1
out_shape = [batch, self.split_dim, self.split_dim]
else:
out_shape = x_shape
return out_shape
def __infer__(self, x):
x_shape = x['shape']
x_dtype = x['dtype']
@ -54,10 +38,9 @@ class CholeskySolver(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, lower=False, split_dim=0):
def __init__(self, lower=False):
super().__init__(name="CholeskySolver")
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
self.split_dim = validator.check_value_type("split_dim", split_dim, [int], self.split_dim)
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def __infer__(self, x, b):
@ -71,18 +54,18 @@ class CholeskySolver(PrimitiveWithInfer):
class CholeskyNet(nn.Cell):
def __init__(self, lower=False, clean=False, split_dim=0):
def __init__(self, lower=False, clean=False):
super(CholeskyNet, self).__init__()
self.cholesky = Cholesky(lower, clean, split_dim)
self.cholesky = Cholesky(lower, clean)
def construct(self, x):
return self.cholesky(x)
class CholeskySolverNet(nn.Cell):
def __init__(self, lower=False, split_dim=0):
def __init__(self, lower=False):
super(CholeskySolverNet, self).__init__()
self.cholesky_solver = CholeskySolver(lower, split_dim)
self.cholesky_solver = CholeskySolver(lower)
def construct(self, c, b):
return self.cholesky_solver(c, b)
@ -167,10 +150,12 @@ def test_cholesky_solver():
b = np.array([1, 1, 1, 1], dtype=np.float32)
tensor_a = Tensor(a)
tensor_b = Tensor(b)
scp_c, lower = scp.linalg.cho_factor(a, lower=True)
scp_x = scp.linalg.cho_solve((scp_c, lower), b)
mscp_c, mscp_lower = cho_factor(tensor_a, lower=True)
mscp_x = cho_solve((tensor_a, mscp_lower), tensor_b)
scp_c, lower = scp.linalg.cho_factor(a, lower=False)
mscp_c, mscp_lower = cho_factor(tensor_a, lower=False)
assert np.allclose(scp_c, mscp_c.asnumpy())
scp_factor = (scp_c, lower)
ms_cho_factor = (mscp_c, mscp_lower)
scp_x = scp.linalg.cho_solve(scp_factor, b)
mscp_x = cho_solve(ms_cho_factor, tensor_b)
assert np.allclose(scp_x, mscp_x.asnumpy())