forked from mindspore-Ecosystem/mindspore
fix cholesky factor and solve for cpu backend
parent eff2318bb6
commit 9d13db2ac3
@@ -16,6 +16,7 @@
 #include "backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.h"
 #include <vector>
+#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
 #include "utils/ms_utils.h"
 #include "Eigen/Dense"
 #include "Eigen/Cholesky"
@@ -46,7 +47,6 @@ void CholeskyCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size
   if (*row != *col) {
     MS_LOG_EXCEPTION << kernel_name_ << "input shape is invalid: " << *row << ", " << *col;
   }
-  return;
 }
 
 template <typename T>
@@ -70,15 +70,12 @@ template <typename T>
 bool CholeskyCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                   const std::vector<AddressPtr> &outputs) {
   T *input_value = reinterpret_cast<T *>(inputs[kInputIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input(input_value, input_row_,
-                                                                                      input_col_);
+  Map<Matrix<T, RowMajor>> input(input_value, input_row_, input_col_);
 
   T *output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output(output_value, output_row_,
-                                                                                       output_col_);
-  Eigen::LLT<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> llt;
+  Map<Matrix<T, RowMajor>> output(output_value, output_row_, output_col_);
+  Eigen::LLT<Matrix<T, RowMajor>> llt;
   llt.compute(input);
 
   if (clean_) {
     if (lower_) {
       output = llt.matrixL();
@@ -86,12 +83,16 @@ bool CholeskyCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const s
       output = llt.matrixU();
     }
   } else {
-    output = llt.matrixLLT();
+    if (lower_) {
+      output = llt.matrixLLT();
+    } else {
+      output = llt.matrixLLT().transpose();
+    }
   }
   if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
     return true;
   }
-  MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
+  MS_LOG_EXCEPTION << kernel_name_ << " output cholesky shape invalid.";
 }
 }  // namespace kernel
 }  // namespace mindspore
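
Note on the factor hunk above: with clean=True the kernel returns a proper triangular factor (matrixL()/matrixU(), other triangle zeroed), while the reworked clean=False branch returns Eigen's packed matrixLLT() (transposed for the upper case), which appears intended to line up with what scipy.linalg.cho_factor hands back. A minimal NumPy/SciPy sketch of that contract, illustration only and not the kernel's code path; the matrix `a` is an arbitrary SPD example:

```python
# Illustration only (NumPy/SciPy, not the kernel code): the output modes the
# Launch change above encodes. `a` is an arbitrary SPD example matrix.
import numpy as np
import scipy.linalg

a = np.array([[4., 12., -16.],
              [12., 37., -43.],
              [-16., -43., 98.]])

# clean=True, lower=True: zero-filled lower factor L with a == L @ L.T
l_clean = np.linalg.cholesky(a)
# clean=True, lower=False: zero-filled upper factor U with a == U.T @ U
u_clean = scipy.linalg.cholesky(a, lower=False)
# clean=False: packed factor as scipy.linalg.cho_factor returns it; only the
# requested triangle is meaningful, the other triangle is not zeroed out.
c_packed, lower = scipy.linalg.cho_factor(a, lower=True)

assert np.allclose(l_clean @ l_clean.T, a)
assert np.allclose(u_clean.T @ u_clean, a)
assert np.allclose(np.tril(c_packed), l_clean)
```

The packed form is what cho_solve expects to receive, which is why the solver kernel below stops re-factorizing its input.
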
@@ -15,6 +15,7 @@
  */
 
 #include "backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
 #include "Eigen/Dense"
 #include "Eigen/Cholesky"
 namespace mindspore {
@@ -41,7 +42,6 @@ void CholeskySolverCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape
     *row = shape.at(shape.size() - kRowIndex);
     *col = shape.at(shape.size() - kColIndex);
   }
-  return;
 }
 
 template <typename T>
@@ -59,30 +59,37 @@ void CholeskySolverCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
   InitMatrixInfo(input_b_shape, &input_b_row_, &input_b_col_);
   auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, kOutputIndex);
   InitMatrixInfo(output_shape, &output_row_, &output_col_);
+  lower_ = AnfAlgo ::GetNodeAttr<bool>(kernel_node, LOWER);
+  if (input_a_row_ != input_b_row_) {
+    MS_LOG_EXCEPTION << kernel_name_ << "llt solve input row is not equal to b row: " << input_a_row_ << " vs "
+                     << input_b_row_;
+  }
 }
 
 template <typename T>
 bool CholeskySolverCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                         const std::vector<AddressPtr> &outputs) {
   T *input_value = reinterpret_cast<T *>(inputs[kInputAIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input(input_value, input_a_row_,
-                                                                                      input_a_col_);
+  Map<Matrix<T, RowMajor>> input(input_value, input_a_row_, input_a_col_);
 
   T *input_b_value = reinterpret_cast<T *>(inputs[kInputBIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_b(input_b_value, input_b_row_,
-                                                                                        input_b_col_);
+  Map<Matrix<T, RowMajor>> input_b(input_b_value, input_b_row_, input_b_col_);
 
   T *output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
-  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output(output_value, output_row_,
-                                                                                       output_col_);
-  Eigen::LLT<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> llt(input);
+  Map<Matrix<T, RowMajor>> output(output_value, output_row_, output_col_);
 
-  output = llt.solve(input_b);
+  if (lower_) {
+    output.noalias() = input.template triangularView<Lower>().solve(input_b);
+    input.adjoint().template triangularView<Upper>().solveInPlace(output);
+  } else {
+    output.noalias() = input.adjoint().template triangularView<Lower>().solve(input_b);
+    input.template triangularView<Upper>().solveInPlace(output);
+  }
 
   if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
     return true;
   }
-  MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
+  MS_LOG_EXCEPTION << kernel_name_ << " output cholesky solve shape invalid.";
 }
 }  // namespace kernel
 }  // namespace mindspore
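
The substance of the solver fix above: the kernel no longer builds an Eigen::LLT from its first input and calls llt.solve; it treats that input as an already computed Cholesky factor, which is what cho_solve passes in, and applies two triangular solves instead. A short SciPy sketch of the lower-factor case, an illustration of the math under that reading rather than the kernel code; `a` and `b` are arbitrary example data:

```python
# Illustration of the two-triangular-solve sequence the fixed kernel performs
# when lower=True: with a == L @ L.T, solving a @ x = b is L @ y = b followed
# by L.T @ x = y.
import numpy as np
from scipy.linalg import cho_factor, cho_solve, solve_triangular

a = np.array([[4., 12., -16.],
              [12., 37., -43.],
              [-16., -43., 98.]])
b = np.array([1., 1., 1.])

c, lower = cho_factor(a, lower=True)  # packed lower factor, as the kernel receives it
l = np.tril(c)                        # only the lower triangle is meaningful

y = solve_triangular(l, b, lower=True)     # forward substitution: L y = b
x = solve_triangular(l.T, y, lower=False)  # back substitution:    L.T x = y

assert np.allclose(x, cho_solve((c, lower), b))  # same answer as SciPy's cho_solve
assert np.allclose(a @ x, b)
```

The else branch in the kernel is the mirror image for an upper factor U: a forward solve with U.T followed by a back solve with U.
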
@@ -42,6 +42,7 @@ class CholeskySolverCPUKernel : public CPUKernel {
   size_t output_row_{1};
   size_t output_col_{1};
   TypeId dtype_{kNumberTypeFloat32};
+  bool lower_{false};
 };
 
 MS_REG_CPU_KERNEL_T(
@@ -16,28 +16,12 @@ class Cholesky(PrimitiveWithInfer):
     """
 
     @prim_attr_register
-    def __init__(self, lower=False, clean=False, split_dim=0):
+    def __init__(self, lower=False, clean=True):
         super().__init__(name="Cholesky")
         self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
         self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
-        self.split_dim = validator.check_value_type("split_dim", split_dim, [int], self.split_dim)
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
-
-    def infer_shape(self, x_shape):
-        if self.split_dim != 0:
-            height = x_shape[0]
-            width = x_shape[1]
-            if height <= self.split_dim:
-                out_shape = [1, height, width]
-            else:
-                batch = height // self.split_dim
-                if height != batch * self.split_dim:
-                    batch += 1
-                out_shape = [batch, self.split_dim, self.split_dim]
-        else:
-            out_shape = x_shape
-        return out_shape
 
     def __infer__(self, x):
         x_shape = x['shape']
         x_dtype = x['dtype']
@@ -54,10 +38,9 @@ class CholeskySolver(PrimitiveWithInfer):
     """
 
     @prim_attr_register
-    def __init__(self, lower=False, split_dim=0):
+    def __init__(self, lower=False):
         super().__init__(name="CholeskySolver")
         self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
-        self.split_dim = validator.check_value_type("split_dim", split_dim, [int], self.split_dim)
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
 
     def __infer__(self, x, b):
@@ -71,18 +54,18 @@ class CholeskySolver(PrimitiveWithInfer):
 
 
 class CholeskyNet(nn.Cell):
-    def __init__(self, lower=False, clean=False, split_dim=0):
+    def __init__(self, lower=False, clean=False):
         super(CholeskyNet, self).__init__()
-        self.cholesky = Cholesky(lower, clean, split_dim)
+        self.cholesky = Cholesky(lower, clean)
 
     def construct(self, x):
         return self.cholesky(x)
 
 
 class CholeskySolverNet(nn.Cell):
-    def __init__(self, lower=False, split_dim=0):
+    def __init__(self, lower=False):
         super(CholeskySolverNet, self).__init__()
-        self.cholesky_solver = CholeskySolver(lower, split_dim)
+        self.cholesky_solver = CholeskySolver(lower)
 
     def construct(self, c, b):
         return self.cholesky_solver(c, b)
@@ -167,10 +150,12 @@ def test_cholesky_solver():
     b = np.array([1, 1, 1, 1], dtype=np.float32)
     tensor_a = Tensor(a)
     tensor_b = Tensor(b)
-    scp_c, lower = scp.linalg.cho_factor(a, lower=True)
-    scp_x = scp.linalg.cho_solve((scp_c, lower), b)
-
-    mscp_c, mscp_lower = cho_factor(tensor_a, lower=True)
-    mscp_x = cho_solve((tensor_a, mscp_lower), tensor_b)
+    scp_c, lower = scp.linalg.cho_factor(a, lower=False)
+    mscp_c, mscp_lower = cho_factor(tensor_a, lower=False)
+    assert np.allclose(scp_c, mscp_c.asnumpy())
 
+    scp_factor = (scp_c, lower)
+    ms_cho_factor = (mscp_c, mscp_lower)
+    scp_x = scp.linalg.cho_solve(scp_factor, b)
+    mscp_x = cho_solve(ms_cho_factor, tensor_b)
     assert np.allclose(scp_x, mscp_x.asnumpy())