!25590 fix lu api bugs

Merge pull request !25590 from zhuzhongrui/pub_master
Commit 073ff9d2b9 by i-robot, 2021-10-29 09:04:24 +00:00, committed by Gitee
8 changed files with 230 additions and 64 deletions

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_
#include "Eigen/Dense"
#include "Eigen/Core"
namespace mindspore {
namespace kernel {
using Eigen::ColMajor;
using Eigen::Dynamic;
using Eigen::Lower;
using Eigen::Map;
using Eigen::MatrixBase;
using Eigen::RowMajor;
using Eigen::UnitLower;
using Eigen::UnitUpper;
using Eigen::Upper;
template <typename T, int Major>
using Matrix = Eigen::Matrix<T, Dynamic, Dynamic, Major>;
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_

View File

@@ -16,13 +16,13 @@
#include "backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h"
#include <vector>
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include "utils/ms_utils.h"
#include "Eigen/Dense"
#include "Eigen/LU"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kLUInputsNum = 1;
constexpr size_t kLUaIndex = 0;
@@ -73,27 +73,44 @@ template <typename T>
bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                            const std::vector<kernel::AddressPtr> &outputs) {
  T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
  Map<Matrix<T, RowMajor>> input_a(a_value, a_row_, a_col_);
  T *lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
  Map<Matrix<T, RowMajor>> output_lu(lu_value, lu_row_, lu_col_);
  int *pivots_value = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
  Map<Matrix<int, RowMajor>> output_pivots(pivots_value, pivots_row_, pivots_col_);
  int *permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
  Map<Matrix<int, RowMajor>> output_permutation(permutation_value, permutation_row_, permutation_col_);
  if (a_row_ == a_col_) {
    // partial_piv_lu
    auto partial_lu = input_a.lu();
    auto partial_p = partial_lu.permutationP();
    output_lu.noalias() = partial_lu.matrixLU();
    output_permutation.noalias() = partial_p.toDenseMatrix();
  } else {
    // full_piv_lu
    auto full_piv_lu = input_a.fullPivLu();
    auto full_piv_p = full_piv_lu.permutationP();
    output_lu.noalias() = full_piv_lu.matrixLU();
    output_permutation.noalias() = full_piv_p.toDenseMatrix();
  }
  // Derive the pivot array from the permutation matrix so that it matches scipy's pivots.
  for (int i = 0; i < static_cast<int>(output_permutation.rows()); ++i) {
    if (output_permutation(i, i) != 0) {
      pivots_value[i] = i;
      continue;
    }
    for (int j = 0; j < static_cast<int>(output_permutation.cols()); ++j) {
      if (output_permutation(i, j) != 0) {
        pivots_value[i] = j;
        break;
      }
    }
  }
  // Eigen builds the permutation matrix in column-major order, so transpose it to row-major here;
  // the pivot array above is derived from the untransposed matrix, which is consistent with scipy and jax.
  output_permutation.transposeInPlace();
  if (output_lu.rows() != 0 && output_lu.cols() != 0 && output_permutation.size() != 0) {
    return true;
  }
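For reference, the pivot extraction above amounts to "find the 1 in each row of the permutation matrix". A minimal NumPy sketch of the same conversion, on toy data (p is a hypothetical matrix standing in for Eigen's permutationP() output):

import numpy as np

# A hypothetical 3x3 permutation matrix, standing in for Eigen's permutationP().
p = np.array([[0, 1, 0],
              [0, 0, 1],
              [1, 0, 0]])
# Same scan as the kernel loop: the pivot for row i is the column holding the 1.
perm = np.array([int(np.nonzero(row)[0][0]) for row in p])
b = np.array([10., 20., 30.])
# Applying P to a vector is the same as fancy-indexing with the derived array.
assert np.allclose(p @ b, b[perm])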

View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
@@ -62,4 +62,4 @@ MS_REG_CPU_KERNEL_T(LU,
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_

View File

@@ -16,13 +16,13 @@
#include "backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h"
#include <vector>
#include <string>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include "Eigen/Dense"
#include "Eigen/LU"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kLUInputsNum = 2;
constexpr size_t kLUaIndex = 0;
@@ -70,6 +70,7 @@ void LUSolverCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
    out_row_ = output_lu_shape.at(output_lu_shape.size() - kRowIndex);
    out_col_ = output_lu_shape.at(output_lu_shape.size() - kColIndex);
  }
  trans_ = AnfAlgo::GetNodeAttr<std::string>(kernel_node, TRANS);
}
template <typename T>
@@ -77,23 +78,28 @@ bool LUSolverCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                  const std::vector<kernel::AddressPtr> &,
                                  const std::vector<kernel::AddressPtr> &outputs) {
  T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
  Map<Matrix<T, RowMajor>> input_a(a_value, a_row_, a_col_);
  T *b_value = reinterpret_cast<T *>(inputs[kLUbIndex]->addr);
  Map<Matrix<T, RowMajor>> input_b(b_value, b_row_, b_col_);
  T *output_lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
  Map<Matrix<T, RowMajor>> output_lu(output_lu_value, out_row_, out_col_);
  if (trans_ == "N") {
    output_lu.noalias() = input_a.template triangularView<UnitLower>().solve(input_b);
    output_lu.noalias() = input_a.template triangularView<Upper>().solve(output_lu);
  } else if (trans_ == "T") {
    output_lu.noalias() = input_a.template triangularView<Upper>().solve(input_b);
    output_lu.noalias() = input_a.template triangularView<UnitLower>().solve(output_lu);
  } else if (trans_ == "C") {
    MS_LOG_EXCEPTION << kernel_name_ << " trans_ flag C is not supported yet: " << trans_;
  } else {
    MS_LOG_EXCEPTION << kernel_name_ << " trans_ flag is invalid: " << trans_;
  }
  if (output_lu.rows() == 0 || output_lu.cols() == 0) {
    MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
  }
  return true;
}
} // namespace kernel
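The trans == "N" branch above is the textbook two-stage triangular solve on packed LU factors (the caller is expected to permute b beforehand). A small SciPy sketch of the same idea, on toy data:

import numpy as np
from scipy.linalg import lu_factor, solve_triangular

a = np.array([[4., 3.], [6., 3.]])
b = np.array([1., 2.])
lu, piv = lu_factor(a)  # L and U packed into one matrix; L has a unit diagonal
pb = b.copy()
for i, j in enumerate(piv):  # apply the recorded row swaps to b first
    pb[[i, j]] = pb[[j, i]]
y = solve_triangular(lu, pb, lower=True, unit_diagonal=True)  # L y = P b
x = solve_triangular(lu, y, lower=False)                      # U x = y
assert np.allclose(a @ x, b)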

View File

@@ -14,10 +14,11 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_
#include <vector>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
@@ -39,6 +40,7 @@ class LUSolverCPUKernel : public CPUKernel {
  size_t b_col_{1};
  size_t out_row_{1};
  size_t out_col_{1};
  std::string trans_{};
  TypeId dtype_{kNumberTypeFloat32};
};
@@ -53,4 +55,4 @@ MS_REG_CPU_KERNEL_T(
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_

View File

@@ -106,6 +106,8 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
    LaunchKernel<float16>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeInt32) {
    LaunchKernel<int>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
  }

View File

@@ -72,6 +72,13 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate,
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  ScatterNdUpdateCPUKernel);
MS_REG_CPU_KERNEL(ScatterNdUpdate,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeInt32)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddOutputAttr(kNumberTypeInt32),
                  ScatterNdUpdateCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@@ -19,11 +19,14 @@ from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import PrimitiveWithInfer
from mindspore.ops import prim_attr_register
from mindspore._checkparam import Validator as validator
import mindspore.numpy as mnp
import scipy as scp
import numpy as np
import pytest
np.random.seed(0)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@@ -41,19 +44,14 @@ class LU(PrimitiveWithInfer):
    def __infer__(self, x):
        x_shape = list(x['shape'])
        x_dtype = x['dtype']
        pivots_shape = []
        ndim = len(x_shape)
        permutation_shape = x_shape
        if ndim == 0:
            pivots_shape = x_shape
        elif ndim == 1:
            pivots_shape = x_shape[:-1]
        else:
            pivots_shape = x_shape[-2:-1]
        output = {
            'shape': (x_shape, pivots_shape, permutation_shape),
            'dtype': (x_dtype, mstype.int32, mstype.int32),
@@ -68,9 +66,10 @@ class LUSolver(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self):
def __init__(self, trans: str):
super().__init__(name="LUSolver")
self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
self.trans = validator.check_value_type("trans", trans, [str], self.name)
def __infer__(self, x, b):
b_shape = list(b['shape'])
@@ -92,42 +91,128 @@ class LuNet(nn.Cell):
        return self.lu(a)
def lu_pivots_to_permutation(pivots, permutation_size: int):
    batch_dims = pivots.shape[:-1]
    k = pivots.shape[-1]
    per = mnp.arange(0, permutation_size)
    permutation = mnp.broadcast_to(per, batch_dims + (permutation_size,))
    permutation = mnp.array(permutation)
    if permutation_size == 0:
        return permutation
    for i in range(k):
        j = pivots[..., i]
        loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
        x = permutation[..., i]
        y = permutation[loc + (j,)]
        permutation[..., i] = y
        permutation[loc + (j,)] = x
    return permutation
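A plain-NumPy equivalent of the swap replay may be easier to follow (hypothetical helper; pivots follow the scipy.linalg.lu_factor convention that at step i row i was swapped with row pivots[i]):

import numpy as np

def pivots_to_permutation(pivots, n):
    # Replay the recorded row swaps in order to build the full permutation.
    perm = np.arange(n)
    for i, j in enumerate(pivots):
        perm[i], perm[j] = perm[j], perm[i]
    return perm

print(pivots_to_permutation(np.array([2, 2]), 3))  # -> [2 0 1]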
def _lu_solve_core(in_lu, permutation, b, trans):
    m = in_lu.shape[0]
    res_shape = b.shape[1:]
    prod_result = 1
    for sh in res_shape:
        prod_result *= sh
    x = mnp.reshape(b, (m, prod_result))
    if trans == 0:
        trans_str = "N"
        x = x[permutation, :]
    elif trans == 1:
        trans_str = "T"
    elif trans == 2:
        trans_str = "C"
    else:
        raise ValueError("trans error, its value must be 0, 1 or 2")
    ms_lu_solve = LUSolver(trans_str)
    output = ms_lu_solve(in_lu, x)
    return mnp.reshape(output, b.shape)
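The integer trans codes mirror scipy.linalg.lu_solve, which the following snippet demonstrates for the 0 and 1 cases on toy data:

import numpy as np
from scipy.linalg import lu_factor, lu_solve

a = np.array([[4., 3.], [6., 3.]])
b = np.array([1., 2.])
lu_piv = lu_factor(a)
x0 = lu_solve(lu_piv, b, trans=0)  # solves a @ x = b
x1 = lu_solve(lu_piv, b, trans=1)  # solves a.T @ x = b
assert np.allclose(a @ x0, b)
assert np.allclose(a.T @ x1, b)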
def _check_lu_shape(in_lu, b):
    if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
        raise ValueError("last two dimensions of LU decomposition must be equal.")
    if b.shape is None:
        raise ValueError("LU decomposition input b's rank must be >= 1.")
    rhs_vector = in_lu.ndim == b.ndim + 1
    if rhs_vector:
        if b.shape[-1] != in_lu.shape[-1]:
            raise ValueError("LU decomposition: b's last dimension must match the LU matrix size.")
        b = mnp.expand_dims(b, axis=1)
    else:
        if b.shape[-2] != in_lu.shape[-1]:
            raise ValueError("LU decomposition: b's second-to-last dimension must match the LU matrix size.")
def lu_factor(a, overwrite_a=False, check_finite=True):
    del overwrite_a, check_finite
    mscp_lu = LuNet()
    m_lu, pivots, _ = mscp_lu(a)
    return m_lu, pivots


def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    del overwrite_a, check_finite
    mscp_lu = LuNet()
    m_lu, _, p = mscp_lu(a)
    m = a.shape[-2]
    n = a.shape[-1]
    k = min(m, n)
    a_dtype = a.dtype
    l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype)
    u = mnp.triu(m_lu)[:k, :]
    if permute_l:
        return mnp.matmul(p, l), u
    return p, l, u
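The tril/triu split in lu() is the standard way to unpack getrf-style packed output; a NumPy/SciPy sketch of the same unpacking on toy data, checked against the pivots:

import numpy as np
import scipy.linalg

a = np.random.default_rng(0).random((3, 3))
lu_packed, piv = scipy.linalg.lu_factor(a)
k = 3
l = np.tril(lu_packed, -1)[:, :k] + np.eye(3, k)  # restore L's unit diagonal
u = np.triu(lu_packed)[:k, :]
perm = np.arange(3)
for i, j in enumerate(piv):  # rebuild the row permutation from the swaps
    perm[i], perm[j] = perm[j], perm[i]
assert np.allclose(a[perm], l @ u)  # P a == L U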
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    del overwrite_b, check_finite
    m_lu, pivots = lu_and_piv
    # 1. check shapes
    _check_lu_shape(m_lu, b)
    # 2. the kernel already returns the permutation array, so the pivots can be used directly
    permutation = pivots
    # 3. solve, squeezing the result back to a vector if b was one
    rhs_vector = m_lu.ndim == b.ndim + 1
    x = _lu_solve_core(m_lu, permutation, b, trans)
    return x[..., 0] if rhs_vector else x
def create_full_rank_matrix(m, n, dtype):
    a_rank = 0
    a = np.random.random((m, n)).astype(dtype)
    while a_rank != m:
        a = (a + np.eye(m, n)).astype(dtype)
        a_rank = np.linalg.matrix_rank(a)
    return a


def create_sym_pos_matrix(m, n, dtype):
    a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype)
    return np.dot(a, a.T)
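A quick sanity check of the two helpers above (toy sizes, assumed usage):

a = create_full_rank_matrix(4, 4, np.float64)
assert np.linalg.matrix_rank(a) == 4
s = create_sym_pos_matrix(4, 4, np.float64)
assert np.allclose(s, s.T) and np.all(np.linalg.eigvalsh(s) > 0)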
@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_square_lu_net(n: int, dtype: Generic):
    """
    Feature: ALL To ALL
    Description: test cases for lu decomposition of a square matrix A[N,N]
    Expectation: the result matches scipy
    """
    a = create_full_rank_matrix(n, n, dtype)
    s_lu, _ = scp.linalg.lu_factor(a)
    mscp_lu_net = LuNet()
    tensor_a = Tensor(a)
    mscp_lu, _, _ = mscp_lu_net(tensor_a)
    assert np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3)
@pytest.mark.platform_x86_cpu
@@ -139,13 +224,24 @@ def test_lu_solver_net(n: int, dtype: Generic):
    """
    Feature: ALL To ALL
    Description: test cases for lu_solve for A[N,N]x = b[N,1]
    Expectation: the result matches scipy
    """
    a = create_full_rank_matrix(n, n, dtype)
    b = np.random.random((n, 1)).astype(dtype)
    s_lu, s_piv = scp.linalg.lu_factor(a)
    tensor_a = Tensor(a)
    tensor_b = Tensor(b)
    mscp_lu_net = LuNet()
    mscp_lu, pivots, _ = mscp_lu_net(tensor_a)
    assert np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3)
    lu_factor_x = (s_lu, s_piv)
    msc_lu_factor = (mscp_lu, pivots)
    scp_x = scp.linalg.lu_solve(lu_factor_x, b)
    mscp_x = lu_solve(msc_lu_factor, tensor_b)
    real_b = mnp.dot(tensor_a, mscp_x)
    expected_b = np.dot(a, scp_x)
    assert np.allclose(real_b.asnumpy(), expected_b, rtol=1.e-3, atol=1.e-3)
    assert np.allclose(mscp_x.asnumpy(), scp_x, rtol=1.e-3, atol=1.e-3)