From d3726b435d2454870a9f42bad7768400f44b3208 Mon Sep 17 00:00:00 2001 From: z00512249 Date: Thu, 28 Oct 2021 19:07:12 +0800 Subject: [PATCH] fix lu api bugs --- .../cpu/eigen/eigen_common_utils.h | 36 ++++ .../cpu/eigen/lu_cpu_kernel.cc | 41 +++-- .../kernel_compiler/cpu/eigen/lu_cpu_kernel.h | 6 +- .../cpu/eigen/lu_solve_cpu_kernel.cc | 28 +-- .../cpu/eigen/lu_solve_cpu_kernel.h | 8 +- .../cpu/scatter_nd_update_cpu_kernel.cc | 2 + .../cpu/scatter_nd_update_cpu_kernel.h | 7 + tests/st/ops/cpu/test_lu_op.py | 166 ++++++++++++++---- 8 files changed, 230 insertions(+), 64 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eigen_common_utils.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eigen_common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eigen_common_utils.h new file mode 100644 index 00000000000..729c265ba77 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eigen_common_utils.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_ +#include "Eigen/Dense" +#include "Eigen/Core" +namespace mindspore { +namespace kernel { +using Eigen::ColMajor; +using Eigen::Dynamic; +using Eigen::Lower; +using Eigen::Map; +using Eigen::MatrixBase; +using Eigen::RowMajor; +using Eigen::UnitLower; +using Eigen::UnitUpper; +using Eigen::Upper; +template +using Matrix = Eigen::Matrix; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc index 2de89d0ff69..61519e8d2c6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc @@ -16,13 +16,13 @@ #include "backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h" #include +#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h" #include "utils/ms_utils.h" #include "Eigen/Dense" #include "Eigen/LU" namespace mindspore { namespace kernel { - namespace { constexpr size_t kLUInputsNum = 1; constexpr size_t kLUaIndex = 0; @@ -73,27 +73,44 @@ template bool LUCPUKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { T *a_value = reinterpret_cast(inputs[kLUaIndex]->addr); - Eigen::Map> input_a(a_value, a_row_, a_col_); + Map> input_a(a_value, a_row_, a_col_); T *lu_value = reinterpret_cast(outputs[kLuIndex]->addr); - Eigen::Map> output_lu(lu_value, lu_row_, lu_col_); + Map> output_lu(lu_value, lu_row_, lu_col_); int *pivots_value = reinterpret_cast(outputs[kPivotsIndex]->addr); - Eigen::Map> output_pivots( - pivots_value, pivots_row_, pivots_col_); int *permutation_value = reinterpret_cast(outputs[kPermutationIndex]->addr); - Eigen::Map> output_permutation( - permutation_value, permutation_row_, permutation_col_); + Map> output_permutation(permutation_value, permutation_row_, permutation_col_); if (a_row_ == a_col_) { // partial_piv_lu - output_lu = input_a.lu().matrixLU(); - output_pivots = input_a.lu().permutationP().indices(); + auto partial_lu = input_a.lu(); + auto partial_p = partial_lu.permutationP(); + output_lu.noalias() = partial_lu.matrixLU(); + output_permutation.noalias() = partial_p.toDenseMatrix(); } else { // full_piv_lu - output_lu = input_a.fullPivLu().matrixLU(); - output_pivots = input_a.fullPivLu().permutationP().indices(); + auto full_piv_lu = input_a.fullPivLu(); + auto full_piv_p = full_piv_lu.permutationP(); + output_lu.noalias() = full_piv_lu.matrixLU(); + output_permutation.noalias() = full_piv_p.toDenseMatrix(); } - output_permutation = output_pivots; + + // calculate permutation array from permutation matrix to indicate scipy's pivots. + for (int i = 0; i < static_cast(output_permutation.rows()); ++i) { + if (output_permutation(i, i) != 0) { + pivots_value[i] = i; + continue; + } + for (int j = 0; j < static_cast(output_permutation.cols()); ++j) { + if (output_permutation(i, j) != 0) { + pivots_value[i] = j; + break; + } + } + } + // here, we note that eigen calculate permutation matrix is col major, so transpose it to row major, + // but permutation array is based on permutation matrix before transposed, which is consistent to scipy and jax. + output_permutation.transposeInPlace(); if (output_lu.RowsAtCompileTime != 0 && output_lu.ColsAtCompileTime != 0 && output_permutation.size() != 0) { return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h index 39b301f0ac4..94813196815 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_ #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" @@ -62,4 +62,4 @@ MS_REG_CPU_KERNEL_T(LU, } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.cc index 9bb7971f8aa..83e869379dc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.cc @@ -16,13 +16,13 @@ #include "backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h" #include +#include #include "utils/ms_utils.h" +#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h" #include "Eigen/Dense" #include "Eigen/LU" - namespace mindspore { namespace kernel { - namespace { constexpr size_t kLUInputsNum = 2; constexpr size_t kLUaIndex = 0; @@ -70,6 +70,7 @@ void LUSolverCPUKernel::InitKernel(const CNodePtr &kernel_node) { out_row_ = output_lu_shape.at(output_lu_shape.size() - kRowIndex); out_col_ = output_lu_shape.at(output_lu_shape.size() - kColIndex); } + trans_ = AnfAlgo ::GetNodeAttr(kernel_node, TRANS); } template @@ -77,23 +78,28 @@ bool LUSolverCPUKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { T *a_value = reinterpret_cast(inputs[kLUaIndex]->addr); - Eigen::Map> input_a(a_value, a_row_, a_col_); + Map> input_a(a_value, a_row_, a_col_); T *b_value = reinterpret_cast(inputs[kLUbIndex]->addr); - Eigen::Map> input_b(b_value, b_row_, b_col_); + Map> input_b(b_value, b_row_, b_col_); T *output_lu_value = reinterpret_cast(outputs[kLuIndex]->addr); - Eigen::Map> output_lu(output_lu_value, out_row_, - out_col_); - if (a_row_ == a_col_) { - // partial_piv_lu - output_lu = input_a.lu().solve(input_b); + Map> output_lu(output_lu_value, out_row_, out_col_); + if (trans_ == "N") { + output_lu.noalias() = input_a.template triangularView().solve(input_b); + output_lu.noalias() = input_a.template triangularView().solve(output_lu); + } else if (trans_ == "T") { + output_lu.noalias() = input_a.template triangularView().solve(input_b); + output_lu.noalias() = input_a.template triangularView().solve(output_lu); + } else if (trans_ == "C") { + MS_LOG_EXCEPTION << kernel_name_ << " trans_ flag is not supported C: " << trans_; } else { - // full_piv_lu - output_lu = input_a.fullPivLu().solve(input_b); + MS_LOG_EXCEPTION << kernel_name_ << " trans_ flag is invalid: " << trans_; } + if (output_lu.RowsAtCompileTime == 0 || output_lu.ColsAtCompileTime == 0) { MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid."; } + return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h index 573ae15e74d..540864cc057 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h @@ -14,10 +14,11 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_SOLVER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LUSOLVER_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_ #include +#include #include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" @@ -39,6 +40,7 @@ class LUSolverCPUKernel : public CPUKernel { size_t b_col_{1}; size_t out_row_{1}; size_t out_col_{1}; + std::string trans_{}; TypeId dtype_{kNumberTypeFloat32}; }; @@ -53,4 +55,4 @@ MS_REG_CPU_KERNEL_T( } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LUSOLVER_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc index 268f69b27e0..e3c123e0069 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc @@ -106,6 +106,8 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector &inp LaunchKernel(inputs, outputs); } else if (dtype_ == kNumberTypeFloat32) { LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeInt32) { + LaunchKernel(inputs, outputs); } else { MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h index 2cf15b61c47..1d739c2fb25 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h @@ -72,6 +72,13 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate, .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), ScatterNdUpdateCPUKernel); +MS_REG_CPU_KERNEL(ScatterNdUpdate, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + ScatterNdUpdateCPUKernel) } // namespace kernel } // namespace mindspore diff --git a/tests/st/ops/cpu/test_lu_op.py b/tests/st/ops/cpu/test_lu_op.py index 32e3c09fc4b..e323b7a5bfe 100644 --- a/tests/st/ops/cpu/test_lu_op.py +++ b/tests/st/ops/cpu/test_lu_op.py @@ -19,11 +19,14 @@ from mindspore import Tensor import mindspore.common.dtype as mstype from mindspore.ops import PrimitiveWithInfer from mindspore.ops import prim_attr_register -from scipy.linalg import lu_factor -from scipy.linalg import lu_solve +from mindspore._checkparam import Validator as validator +import mindspore.numpy as mnp +import scipy as scp import numpy as np import pytest +np.random.seed(0) + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') @@ -41,19 +44,14 @@ class LU(PrimitiveWithInfer): def __infer__(self, x): x_shape = list(x['shape']) x_dtype = x['dtype'] - pivots_shape = [] - permutation_shape = [] ndim = len(x_shape) + permutation_shape = x_shape if ndim == 0: pivots_shape = x_shape - permutation_shape = x_shape elif ndim == 1: pivots_shape = x_shape[:-1] - permutation_shape = x_shape[:-1] else: pivots_shape = x_shape[-2:-1] - permutation_shape = x_shape[-2:-1] - output = { 'shape': (x_shape, pivots_shape, permutation_shape), 'dtype': (x_dtype, mstype.int32, mstype.int32), @@ -68,9 +66,10 @@ class LUSolver(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self): + def __init__(self, trans: str): super().__init__(name="LUSolver") self.init_prim_io_names(inputs=['x', 'b'], outputs=['output']) + self.trans = validator.check_value_type("trans", trans, [str], self.name) def __infer__(self, x, b): b_shape = list(b['shape']) @@ -92,42 +91,128 @@ class LuNet(nn.Cell): return self.lu(a) -class LUSolverNet(nn.Cell): - def __init__(self): - super(LUSolverNet, self).__init__() - self.lu_solver = LUSolver() +def lu_pivots_to_permutation(pivots, permutation_size: int): + batch_dims = pivots.shape[:-1] + k = pivots.shape[-1] + per = mnp.arange(0, permutation_size) + permutation = mnp.broadcast_to(per, batch_dims + (permutation_size,)) + permutation = mnp.array(permutation) + if permutation_size == 0: + return permutation - def construct(self, a, b): - return self.lu_solver(a, b) + for i in range(k): + j = pivots[..., i] + loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims)) + x = permutation[..., i] + y = permutation[loc + (j,)] + permutation[..., i] = y + permutation[loc + (j,)] = x + return permutation -def _match_array(actual, expected, error=0): - if isinstance(actual, int): - actual = np.asarray(actual) - if isinstance(actual, tuple): - actual = np.asarray(actual) - - if error > 0: - np.testing.assert_almost_equal(actual, expected, decimal=error) +def _lu_solve_core(in_lu, permutation, b, trans): + m = in_lu.shape[0] + res_shape = b.shape[1:] + prod_result = 1 + for sh in res_shape: + prod_result *= sh + x = mnp.reshape(b, (m, prod_result)) + if trans == 0: + trans_str = "N" + x = x[permutation, :] + elif trans == 1: + trans_str = "T" + elif trans == 2: + trans_str = "C" else: - np.testing.assert_equal(actual, expected) + raise ValueError("trans error, it's value must be 0, 1, 2") + ms_lu_solve = LUSolver(trans_str) + output = ms_lu_solve(in_lu, x) + return mnp.reshape(output, b.shape) + + +def _check_lu_shape(in_lu, b): + if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]: + raise ValueError("last two dimensions of LU decomposition must be equal.") + + if b.shape is None: + raise ValueError(" LU decomposition input b's rank must >=1.") + rhs_vector = in_lu.ndim == b.ndim + 1 + if rhs_vector: + if b.shape[-1] != in_lu.shape[-1]: + raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions") + mnp.expand_dims(b, axis=1) + else: + if b.shape[-2] != in_lu.shape[-1]: + raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions") + + +def lu_factor(a, overwrite_a=False, check_finite=True): + del overwrite_a, check_finite + mscp_lu = LuNet() + m_lu, pivots, _ = mscp_lu(a) + return m_lu, pivots + + +def lu(a, permute_l=False, overwrite_a=False, check_finite=True): + del overwrite_a, check_finite + mscp_lu = LuNet() + m_lu, _, p = mscp_lu(a) + m = a.shape[-2] + n = a.shape[-1] + k = min(m, n) + a_dtype = a.dtype + l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype) + u = mnp.triu(m_lu)[:k, :] + if permute_l: + return mnp.matmul(p, l), u + return p, l, u + + +def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): + del overwrite_b, check_finite + m_lu, pivots = lu_and_piv + # 1. check shape + _check_lu_shape(m_lu, b) + # here permutation array has been calculated, just use it. + # 2. calculate permutation + permutation = pivots + # 3. rhs_vector + rhs_vector = m_lu.ndim == b.ndim + 1 + x = _lu_solve_core(m_lu, permutation, b, trans) + + return x[..., 0] if rhs_vector else x + + +def create_full_rank_matrix(m, n, dtype): + a_rank = 0 + a = np.random.random((m, n)).astype(dtype) + while a_rank != m: + a = (a + np.eye(m, n)).astype(dtype) + a_rank = np.linalg.matrix_rank(a) + return a + + +def create_sym_pos_matrix(m, n, dtype): + a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype) + return np.dot(a, a.T) @pytest.mark.platform_x86_cpu @pytest.mark.parametrize('n', [10, 20]) @pytest.mark.parametrize('dtype', [np.float32, np.float64]) -def test_lu_net(n: int, dtype: Generic): +def test_square_lu_net(n: int, dtype: Generic): """ Feature: ALL To ALL Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1] Expectation: the result match to scipy """ - a = (np.random.random((n, n)) + np.eye(n)).astype(dtype) - s_lu, _ = lu_factor(a) + a = create_full_rank_matrix(n, n, dtype) + s_lu, _ = scp.linalg.lu_factor(a) mscp_lu_net = LuNet() tensor_a = Tensor(a) mscp_lu, _, _ = mscp_lu_net(tensor_a) - _match_array(mscp_lu.asnumpy(), s_lu, error=4) + assert np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3) @pytest.mark.platform_x86_cpu @@ -139,13 +224,24 @@ def test_lu_solver_net(n: int, dtype: Generic): Description: test cases for lu_solve test cases for A[N,N]x = b[N,1] Expectation: the result match to scipy """ - a = (np.random.random((n, n)) + np.eye(n)).astype(dtype) + a = create_full_rank_matrix(n, n, dtype) b = np.random.random((n, 1)).astype(dtype) - s_lu, s_piv = lu_factor(a) - lu_factor_x = (s_lu, s_piv) - scp_x = lu_solve(lu_factor_x, b) - mscp_lu_net = LUSolverNet() + s_lu, s_piv = scp.linalg.lu_factor(a) + tensor_a = Tensor(a) tensor_b = Tensor(b) - mscp_x = mscp_lu_net(tensor_a, tensor_b) - _match_array(mscp_x.asnumpy(), scp_x, error=4) + mscp_lu_net = LuNet() + mscp_lu, pivots, _ = mscp_lu_net(tensor_a) + np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3) + + lu_factor_x = (s_lu, s_piv) + msc_lu_factor = (mscp_lu, pivots) + + scp_x = scp.linalg.lu_solve(lu_factor_x, b) + mscp_x = lu_solve(msc_lu_factor, tensor_b) + + real_b = mnp.dot(tensor_a, mscp_x) + expected_b = np.dot(a, scp_x) + + assert np.allclose(real_b.asnumpy(), expected_b, rtol=1.e-3, atol=1.e-3) + assert np.allclose(mscp_x.asnumpy(), scp_x, rtol=1.e-3, atol=1.e-3)