clean codex of cholesky

This commit is contained in:
z00512249 2022-01-24 19:22:42 +08:00
parent e0fcbdc2c7
commit da4fd37f93
2 changed files with 3 additions and 6 deletions

View File

@ -18,7 +18,6 @@
#include <vector>
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include "utils/ms_utils.h"
#include "Eigen/Dense"
#include "Eigen/Cholesky"
namespace mindspore {
namespace kernel {
@ -68,7 +67,7 @@ void CholeskyCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
}
template <typename T>
bool CholeskyCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool CholeskyCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
T *input_value = reinterpret_cast<T *>(inputs[kInputIndex]->addr);
Map<Matrix<T, RowMajor>> input(input_value, input_row_, input_col_);
@ -76,7 +75,7 @@ bool CholeskyCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, cons
T *output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
Map<Matrix<T, RowMajor>> output(output_value, output_row_, output_col_);
Eigen::LLT<Matrix<T, RowMajor>> llt;
llt.compute(input);
(void)llt.compute(input);
if (clean_) {
if (lower_) {
output = llt.matrixL();

View File

@ -16,7 +16,6 @@
#include "backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include "Eigen/Dense"
#include "Eigen/Cholesky"
namespace mindspore {
namespace kernel {
@ -67,8 +66,7 @@ void CholeskySolverCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
}
template <typename T>
bool CholeskySolverCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
bool CholeskySolverCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
T *input_value = reinterpret_cast<T *>(inputs[kInputAIndex]->addr);
Map<Matrix<T, RowMajor>> input(input_value, input_a_row_, input_a_col_);