fix lu factor and solve for cpu backend

This commit is contained in:
z00512249 2021-11-26 16:24:01 +08:00
parent 2427f7b3a4
commit 872d8b37da
11 changed files with 546 additions and 354 deletions

View File

@ -16,10 +16,10 @@
#include "backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h"
#include <vector>
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include <algorithm>
#include <utility>
#include <unordered_map>
#include "utils/ms_utils.h"
#include "Eigen/Dense"
#include "Eigen/LU"
namespace mindspore {
namespace kernel {
@ -33,6 +33,7 @@ constexpr size_t kPermutationIndex = 2;
constexpr size_t kLUDefaultShape = 1;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
constexpr int kZeroThreshold = INT32_MIN;
} // namespace
template <typename T>
@ -70,51 +71,145 @@ void LUCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
}
template <typename T>
bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
void LUCPUKernel<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
size_t lu_size = lu_col_ * sizeof(T);
(void)workspace_size_list_.emplace_back(lu_size);
(void)workspace_size_list_.emplace_back(lu_size);
}
template <typename T>
T LUCPUKernel<T>::GetPermutatedValue(const T *lu_value, const int *per, size_t i, size_t j) {
const T *pered_lu_value = lu_value + per[i] * lu_col_ + j;
return *pered_lu_value;
}
template <typename T>
bool LUCPUKernel<T>::UpdateMajorPermutation(T *lu_value, int *per, size_t k, size_t rows) {
T max_major_value = static_cast<T>(kZeroThreshold);
int max_major_index = 0;
for (size_t i = k; i < rows; ++i) {
T value = GetPermutatedValue(lu_value, per, i, k);
T abs_value = std::abs(value);
if (abs_value > max_major_value) {
max_major_value = abs_value;
max_major_index = i;
}
}
size_t per_k = per[k];
per[k] = per[max_major_index];
per[max_major_index] = per_k;
return max_major_value != static_cast<T>(kZeroThreshold);
}
template <typename T>
void LUCPUKernel<T>::SetPermutatedValue(T *lu_value, const int *per, size_t i, size_t j, const T &value) {
T *pered_lu_value = lu_value + per[i] * lu_col_ + j;
*pered_lu_value = value;
}
template <typename T>
bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
// input matrix of (m,n) PA = LU
T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
Map<Matrix<T, RowMajor>> input_a(a_value, a_row_, a_col_);
T *lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
Map<Matrix<T, RowMajor>> output_lu(lu_value, lu_row_, lu_col_);
int *pivots_value = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
// pivots permutation value
int *per_value = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
// permutation matrix value
int *permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
Map<Matrix<int, RowMajor>> output_permutation(permutation_value, permutation_row_, permutation_col_);
if (a_row_ == a_col_) {
// partial_piv_lu
auto partial_lu = input_a.lu();
auto partial_p = partial_lu.permutationP();
output_lu.noalias() = partial_lu.matrixLU();
output_permutation.noalias() = partial_p.toDenseMatrix();
} else {
// full_piv_lu
auto full_piv_lu = input_a.fullPivLu();
auto full_piv_p = full_piv_lu.permutationP();
output_lu.noalias() = full_piv_lu.matrixLU();
output_permutation.noalias() = full_piv_p.toDenseMatrix();
T *lu_ori_wk = reinterpret_cast<T *>(workspace[kLuIndex]->addr);
T *lu_trans_wk = reinterpret_cast<T *>(workspace[kPivotsIndex]->addr);
// init pivots
for (size_t i = 0; i < pivots_row_; ++i) {
per_value[i] = i;
}
// calculate the permutation array from the permutation matrix to match scipy's pivots.
for (int i = 0; i < static_cast<int>(output_permutation.rows()); ++i) {
if (output_permutation(i, i) != 0) {
pivots_value[i] = i;
// 1. memcpy the input to the output, then do the full LU decomposition in place.
(void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T));
int s = std::min(a_row_, a_col_);
// 2. do the LU decomposition in place
for (int k = 0; k < s; ++k) {
// 2.1 choose the major (pivot) element of the current column; a false return means the column is all zero, so just continue.
if (!UpdateMajorPermutation(lu_value, per_value, k, lu_row_)) {
continue;
}
for (int j = 0; j < static_cast<int>(output_permutation.cols()); ++j) {
if (output_permutation(i, j) != 0) {
pivots_value[i] = j;
break;
// 2.2 invert the major element: x --> (1/x), reading the original lu matrix value in place.
T value = static_cast<T>(1.0 / GetPermutatedValue(lu_value, per_value, k, k));
// 2.3 scale the major column values below the pivot
for (size_t i = k + 1; i < lu_row_; ++i) {
T y = static_cast<T>(GetPermutatedValue(lu_value, per_value, i, k) * value);
// write the new lu matrix value in place.
SetPermutatedValue(lu_value, per_value, i, k, y);
}
// 2.4 Gauss elimination core
for (size_t i = k + 1; i < lu_row_; ++i) {
for (size_t j = k + 1; j < lu_col_; ++j) {
T y =
static_cast<T>(GetPermutatedValue(lu_value, per_value, i, j) -
GetPermutatedValue(lu_value, per_value, i, k) * GetPermutatedValue(lu_value, per_value, k, j));
SetPermutatedValue(lu_value, per_value, i, j, y);
}
}
}
// note that Eigen computes the permutation matrix in col-major order, so we transpose it to row-major;
// the permutation array is based on the permutation matrix before transposition, which is consistent with scipy and jax.
output_permutation.transposeInPlace();
if (output_lu.RowsAtCompileTime != 0 && output_lu.ColsAtCompileTime != 0 && output_permutation.size() != 0) {
return true;
// 3. calculate final lu by permutation list
std::unordered_map<int, std::pair<int, bool>> pivots_map;
for (int i = 0; i < static_cast<int>(lu_row_); ++i) {
pivots_map[per_value[i]] = {i, false};
}
MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
int pivots_count = 0;
for (const auto &pivot : pivots_map) {
pivots_count++;
int key = pivot.first;
int index = pivot.second.first;
bool is_visited = pivot.second.second;
if (is_visited || index == (pivots_count - 1)) {
continue;
}
T *lu_ori_row = lu_value + index * lu_col_;
T *lu_trans_row = lu_value + key * lu_col_;
// copy ori data to trans lu
(void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
// copy new data to ori data ptr
(void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_trans_row, lu_col_ * sizeof(T));
// update pivot map
pivots_map[key] = {index, true};
// move the original data stored in the workspace to its mapped new place
is_visited = pivots_map[index].second;
while (!is_visited) {
key = index;
index = pivots_map[key].first;
is_visited = pivots_map[key].second;
lu_ori_row = lu_value + index * lu_col_;
T *tmp_wk = lu_trans_wk;
lu_trans_wk = lu_ori_wk;
lu_ori_wk = tmp_wk;
// copy new ori data to trans workspace
(void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
// copy new data to ori data place
(void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_ori_wk, lu_col_ * sizeof(T));
pivots_map[key] = {index, true};
}
}
// 4. calculate final permutation matrix
// for PA = LU get: base + row * permutation_row_ + col
// for A = PLU get: base + col * permutation_row_ + row
// here, we do A = PLU, which is the same as scipy.
size_t count = permutation_col_ * permutation_row_ * sizeof(int);
(void)memset_s(reinterpret_cast<void *>(permutation_value), count, 0, count);
for (size_t i = 0; i < pivots_row_; ++i) {
int position = per_value[i];
int *per_addr = permutation_value + position * permutation_row_ + i;
*per_addr = 1;
}
return true;
}
} // namespace kernel
} // namespace mindspore
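
For reference, steps 2.1-2.4 above are the classic Doolittle LU factorization with partial pivoting: the kernel tracks row order in a permutation list (per_value) instead of physically swapping rows, then reorders the packed LU rows by following permutation cycles with two workspace rows, and finally writes the permutation matrix for A = PLU as scipy does. Below is a minimal NumPy sketch of the same flow; the function name and the fancy-indexing row reorder are illustrative, not the kernel's API.

import numpy as np

def lu_partial_pivoting(a):
    """Return (p, lu) with a = p @ l @ u, where lu packs l (unit diagonal) below and u above."""
    m, n = a.shape
    lu = a.astype(np.float64).copy()
    per = np.arange(m)  # permutation list, like per_value in the kernel
    for k in range(min(m, n)):
        # 2.1 choose the pivot row for column k; skip a column that is all zero
        pivot = k + int(np.argmax(np.abs(lu[per[k:], k])))
        if lu[per[pivot], k] == 0:
            continue
        per[k], per[pivot] = per[pivot], per[k]
        # 2.2 / 2.3 scale the entries below the pivot by 1 / pivot
        lu[per[k + 1:], k] /= lu[per[k], k]
        # 2.4 Gauss elimination of the trailing submatrix
        lu[per[k + 1:], k + 1:] -= np.outer(lu[per[k + 1:], k], lu[per[k], k + 1:])
    # 3. reorder the packed rows (the kernel does this in place via cycle following)
    lu = lu[per]
    # 4. permutation matrix for A = PLU: column i gets a 1 in row per[i],
    #    i.e. base + per[i] * permutation_row_ + i, matching the kernel comment
    p = np.zeros((m, m))
    p[per, np.arange(m)] = 1
    return p, lu

For the 4x4 matrix used in the lu_factor docstring below, this sketch should reproduce the documented packed lu and the pivot order [2, 0, 3, 1].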

View File

@ -34,6 +34,10 @@ class LUCPUKernel : public CPUKernel {
private:
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
void InitInputOutputSize(const CNodePtr &kernel_node) override;
T GetPermutatedValue(const T *lu_value, const int *per, size_t i, size_t j);
bool UpdateMajorPermutation(T *lu_value, int *per, size_t k, size_t rows);
void SetPermutatedValue(T *lu_value, const int *per, size_t i, size_t j, const T &value);
size_t a_row_{1};
size_t a_col_{1};
size_t lu_row_{1};

View File

@ -15,7 +15,6 @@
*/
#include "backend/kernel_compiler/cpu/eigen/matmul_double_cpu_kernel.h"
#define EIGEN_NO_MALLOC
#include <Eigen/Dense>
#include <vector>

View File

@ -17,7 +17,6 @@
#include "backend/kernel_compiler/cpu/eigen/matrix_inverse_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include "Eigen/Dense"
#define EIGEN_NO_MALLOC
namespace mindspore {
namespace kernel {

View File

@ -15,12 +15,10 @@
*/
#include "backend/kernel_compiler/cpu/eigen/solve_triangular_cpu_kernel.h"
#define EIGEN_NO_MALLOC
#include <Eigen/Dense>
#include <vector>
#include <string>
#include <type_traits>
namespace mindspore {
namespace kernel {
using Eigen::ColMajor;

View File

@ -17,7 +17,7 @@
from . import optimize, sparse, linalg
from .optimize import minimize, line_search
from .sparse import cg, gmres
from .linalg import block_diag, solve_triangular, inv, cho_factor, cholesky, cho_solve
from .linalg import block_diag, solve_triangular, inv, cho_factor, cholesky, cho_solve, lu, lu_factor, lu_solve
__all__ = []
__all__.extend(optimize.__all__)

View File

@ -18,10 +18,13 @@ from .. import ops
from .ops import SolveTriangular
from .ops import CholeskySolver
from .ops import Cholesky
from .ops import LU
from .ops import LUSolver
from .ops import EighNet
from ..ops import operations as P
__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve', 'eigh']
__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve', 'eigh', 'lu_factor', 'lu',
'lu_solve']
def block_diag(*arrs):
@ -102,14 +105,10 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
Default is to use upper triangle.
trans (0, 1, 2, 'N', 'T', 'C', optional):
Type of system to solve:
======== =========
trans system
======== =========
0 or 'N' a x = b
1 or 'T' a^T x = b
2 or 'C' a^H x = b
======== =========
trans: system:
0 or 'N' a x = b
1 or 'T' a^T x = b
2 or 'C' a^H x = b
unit_diagonal (bool, optional): If True, diagonal elements of :math:`A` are assumed to be 1 and
will not be referenced.
overwrite_b (bool, optional): Allow overwriting data in :math:`b` (may enhance performance)
@ -160,21 +159,18 @@ def inv(a, overwrite_a=False, check_finite=True):
Compute the inverse of a matrix.
Args:
a (Tensor): Tensor
Square matrix to be inverted.
overwrite_a (bool, optional): Discard data in `a` (may improve performance).
Default is False.
check_finite (bool, optional): Whether to check that the input matrix contains
only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
a (Tensor): Square matrix to be inverted.
overwrite_a (bool, optional): Discard data in `a` (may improve performance). Default is False.
check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems (crashes, non-termination)
if the inputs do contain infinities or NaNs.
Returns:
ainv (Tensor): Inverse of the matrix `a`.
Raises:
LinAlgError: If `a` is singular.
ValueError: If `a` is not square, or not 2D.
LinAlgError: If :math:`a` is singular.
ValueError: If :math:`a` is not square, or not 2D.
Supported Platforms:
``CPU`` ``GPU``
@ -328,52 +324,40 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
Solve a standard or generalized eigenvalue problem for a complex
Hermitian or real symmetric matrix.
Find eigenvalues Tensor ``w`` and optionally eigenvectors Tensor ``v`` of
Tensor ``a``, where ``b`` is positive definite such that for every
eigenvalue λ (i-th entry of w) and its eigenvector ``vi`` (i-th column of
``v``) satisfies::
a @ vi = λ * b @ vi
Find eigenvalues Tensor ``w`` and optionally eigenvectors Tensor ``v`` of Tensor ``a``,
where ``b`` is positive definite such that for every eigenvalue λ (i-th entry of w) and
its eigenvector ``vi`` (i-th column of ``v``) satisfies::
a @ vi = λ * b @ vi
vi.conj().T @ a @ vi = λ
vi.conj().T @ b @ vi = 1
In the standard problem, ``b`` is assumed to be the identity matrix.
Args:
a (Tensor): (M, M) Tensor
A complex Hermitian or real symmetric matrix whose eigenvalues and
a (Tensor): A (M, M) complex Hermitian or real symmetric matrix whose eigenvalues and
eigenvectors will be computed.
b (Tensor, optional): (M, M) Tensor
A complex Hermitian or real symmetric definite positive matrix in.
b (Tensor, optional): A (M, M) complex Hermitian or real symmetric positive definite matrix.
If omitted, identity matrix is assumed.
lower (bool, optional): Whether the pertinent Tensor data is taken from
the lower or upper triangle of ``a`` and, if applicable, ``b``. (Default: lower)
eigvals_only (bool, optional): Whether to calculate only eigenvalues
and no eigenvectors. (Default: both are calculated)
_type (int, optional): For the generalized problems, this keyword specifies
the problem type to be solved for ``w`` and ``v`` (only takes 1, 2, 3 as possible
inputs)::
lower (bool, optional): Whether the pertinent Tensor data is taken from the lower or upper
triangle of ``a`` and, if applicable, ``b``. (Default: lower)
eigvals_only (bool, optional): Whether to calculate only eigenvalues and no eigenvectors.
(Default: both are calculated)
_type (int, optional): For the generalized problems, this keyword specifies the problem type
to be solved for ``w`` and ``v`` (only takes 1, 2, 3 as possible inputs)::
1 => a @ v = w @ b @ v
2 => a @ b @ v = w @ v
3 => b @ a @ v = w @ v
This keyword is ignored for standard problems.
overwrite_a (bool, optional): Whether to overwrite data in ``a``
(may improve performance). Default is False.
overwrite_b (bool, optional): Whether to overwrite data in ``b``
(may improve performance). Default is False.
check_finite (bool, optional): Whether to check that the input matrices
contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
turbo (bool, optional): use divide and conquer algorithm (faster but
expensive in memory, only for generalized eigenvalue problem and
if full set of eigenvalues are requested.). Has no significant
effect if eigenvectors are not requested.
eigvals (tuple, optional): Indexes of the smallest and largest (in ascending order)
eigenvalues and corresponding eigenvectors to be returned: 0 <= lo <= hi <= M-1.
If omitted, all eigenvalues and eigenvectors are returned.
overwrite_a (bool, optional): Whether to overwrite data in ``a`` (may improve performance). Default is False.
overwrite_b (bool, optional): Whether to overwrite data in ``b`` (may improve performance). Default is False.
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems (crashes, non-termination)
if the inputs do contain infinities or NaNs.
turbo (bool, optional): use divide and conquer algorithm (faster but expensive in memory, only
for the generalized eigenvalue problem and if the full set of eigenvalues is requested).
Has no significant effect if eigenvectors are not requested.
eigvals (tuple, optional): Indexes of the smallest and largest (in ascending order) eigenvalues
and corresponding eigenvectors to be returned: 0 <= lo <= hi <= M-1. If omitted, all eigenvalues
and eigenvectors are returned.
Returns:
w (Tensor): (N,) Tensor, The N (1<=N<=M) selected eigenvalues, in ascending order,
@ -381,10 +365,9 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
v (Tensor): (M, N) Tensor, (if ``eigvals_only == False``)
Raises:
LinAlgError: If eigenvalue computation does not converge, an error occurred, or
b matrix is not definite positive. Note that if input matrices are
not symmetric or Hermitian, no error will be reported but results will
be wrong.
LinAlgError: If eigenvalue computation does not converge, an error occurred, or the b matrix is not
positive definite. Note that if input matrices are not symmetric or Hermitian, no error will
be reported but results will be wrong.
Supported Platforms:
``CPU`` ``GPU``
@ -400,3 +383,227 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
"""
eigh_net = EighNet(not eigvals_only, lower=True)
return eigh_net(a)
def lu_pivots_to_permutation(pivots, permutation_size: int):
"""transfer pivots to permutation"""
batch_dims = pivots.shape[:-1]
k = pivots.shape[-1]
per = mnp.arange(0, permutation_size)
permutation = mnp.broadcast_to(per, batch_dims + (permutation_size,))
permutation = mnp.array(permutation)
if permutation_size == 0:
return permutation
for i in range(k):
j = pivots[..., i]
loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
x = permutation[..., i]
y = permutation[loc + (j,)]
permutation[..., i] = y
permutation[loc + (j,)] = x
return permutation
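# Illustrative trace (assuming LAPACK/scipy-style swap pivots, where row i was
# interchanged with row pivots[i]): pivots = [2, 2, 3, 3] unrolls as
#   start:      [0, 1, 2, 3]
#   i=0, j=2 -> [2, 1, 0, 3]
#   i=1, j=2 -> [2, 0, 1, 3]
#   i=2, j=3 -> [2, 0, 3, 1]
#   i=3, j=3 -> [2, 0, 3, 1]
# i.e. row k of the packed LU matrix came from row permutation[k] of the input.
# Note the CPU kernel in this commit already outputs the flat permutation, which
# is why lu_solve below uses the pivots directly.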
def lu_solve_core(in_lu, permutation, b, trans):
""" core implementation of lu solve"""
m = in_lu.shape[0]
res_shape = b.shape[1:]
prod_result = 1
for sh in res_shape:
prod_result *= sh
x = mnp.reshape(b, (m, prod_result))
if trans == 0:
trans_str = "N"
x = x[permutation, :]
elif trans == 1:
trans_str = "T"
elif trans == 2:
trans_str = "C"
else:
raise ValueError("trans error, it's value must be 0, 1, 2")
ms_lu_solve = LUSolver(trans_str)
output = ms_lu_solve(in_lu, x)
return mnp.reshape(output, b.shape)
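# Why trans == 0 permutes b first: with a = p @ l @ u and the packed lu rows
# already in pivoted order, a x = b is equivalent to (l u) x = b[permutation],
# which LUSolver("N") then solves by forward and back substitution; trans 1 / 2
# simply dispatch the transposed / conjugate system to LUSolver as "T" / "C".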
def check_lu_shape(in_lu, b):
""" check lu input shape"""
if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
raise ValueError("last two dimensions of LU decomposition must be equal.")
if b.shape is None:
raise ValueError(" LU decomposition input b's rank must >=1.")
rhs_vector = in_lu.ndim == b.ndim + 1
if rhs_vector:
if b.shape[-1] != in_lu.shape[-1]:
raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions")
mnp.expand_dims(b, axis=1)
else:
if b.shape[-2] != in_lu.shape[-1]:
raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions")
def lu_factor(a, overwrite_a=False, check_finite=True):
"""
Compute pivoted LU decomposition of a matrix.
The decomposition is::
A = P L U
where P is a permutation matrix, L lower triangular with unit diagonal elements, and U upper triangular.
Args:
a (Tensor): square matrix of (M, M) to decompose
overwrite_a (bool, optional): Whether to overwrite data in A (may increase performance)
check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns:
lu (Tensor): a square matrix of (N, N) containing U in its upper triangle, and L in its lower triangle.
The unit diagonal elements of L are not stored.
piv (Tensor): (N,) Pivot indices representing the permutation matrix P:
row i of matrix was interchanged with row piv[i].
Supported Platforms:
``CPU`` ``GPU``
Examples:
>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> from mindspore.scipy.linalg import lu_factor
>>> A = Tensor(onp.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]).astype(onp.float64))
>>> lu, piv = lu_factor(A)
>>> lu
[[ 7. , 5. , 6. , 6. ],
[ 0.28571429, 3.57142857, 6.28571429, 5.28571429],
[ 0.71428571, 0.12 , -1.04 , 3.08 ],
[ 0.71428571, -0.44 , -0.46153846, 7.46153846]]
>>> piv
[2, 0, 3, 1]
"""
del overwrite_a, check_finite
msp_lu = LU()
m_lu, pivots, _ = msp_lu(a)
return m_lu, pivots
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
"""
Compute pivoted LU decomposition of a matrix.
The decomposition is::
A = P L U
where P is a permutation matrix, L lower triangular with unit
diagonal elements, and U upper triangular.
Args:
a (Tensor): a (M, N) matrix to decompose
permute_l (bool, optional): Perform the multiplication P*L (Default: do not permute)
overwrite_a (bool, optional): Whether to overwrite data in a (may improve performance)
check_finite (bool, optional): Whether to check that the input matrix contains
only finite numbers. Disabling may give a performance gain, but may result
in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns:
**(If permute_l == False)**
p (Tensor): (M, M) Permutation matrix
l (Tensor): (M, K) Lower triangular or trapezoidal matrix with unit diagonal.
K = min(M, N)
u (Tensor): (K, N) Upper triangular or trapezoidal matrix
**(If permute_l == True)**
pl (Tensor): (M, K) Permuted L matrix.
K = min(M, N)
u (Tensor): (K, N) Upper triangular or trapezoidal matrix
Supported Platforms:
``CPU`` ``GPU``
Examples:
>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> from mindspore.scipy.linalg import lu
>>> A = Tensor(onp.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]).astype(onp.float64))
>>> p, l, u = lu(A)
>>> p
[[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.]]
>>> l
[[ 1. , 0. , 0. , 0. ],
[ 0.28571429, 1. , 0. , 0. ],
[ 0.71428571, 0.12 , 1. , 0. ],
[ 0.71428571, -0.44 , -0.46153846, 1. ]]
>>> u
[[ 7. , 5. , 6. , 6. ],
[ 0. , 3.57142857, 6.28571429, 5.28571429],
[ 0. , 0. , -1.04 , 3.08 ],
[ 0. , 0. , 0. , 7.46153846]]
"""
del overwrite_a, check_finite
msp_lu = LU()
m_lu, _, p = msp_lu(a)
m = a.shape[-2]
n = a.shape[-1]
k = min(m, n)
a_dtype = a.dtype
l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype)
u = mnp.triu(m_lu)[:k, :]
if permute_l:
return mnp.dot(p, l), u
return p, l, u
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
"""Solve an equation system, a x = b, given the LU factorization of a
Args:
lu_and_piv (Tensor, Tensor): Factorization of the coefficient matrix a, as given by lu_factor
b (Tensor): Right-hand side
trans (int, optional): {0, 1, 2}
Type of system to solve:
===== =========
trans system
===== =========
0 a x = b
1 a^T x = b
2 a^H x = b
===== =========
overwrite_b (bool, optional): Whether to overwrite data in b (may increase performance)
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems (crashes, non-termination)
if the inputs do contain infinities or NaNs.
Returns:
x (Tensor): Solution to the system
Supported Platforms:
``CPU`` ``GPU``
Examples:
>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> from mindspore.scipy.linalg import lu_factor, lu_solve
>>> A = Tensor(onp.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]).astype(onp.float64))
>>> b = Tensor(onp.array([1, 1, 1, 1]).astype(onp.float64))
>>> lu, piv = lu_factor(A)
>>> lu_solve((lu, piv), b)
[ 0.05154639, -0.08247423, 0.08247423, 0.09278351]
"""
del overwrite_b, check_finite
m_lu, pivots = lu_and_piv
# 1. check shape
check_lu_shape(m_lu, b)
# 2. the permutation array has already been calculated, so just use it
permutation = pivots
# 3. rhs_vector
rhs_vector = m_lu.ndim == b.ndim + 1
x = lu_solve_core(m_lu, permutation, b, trans)
return x[..., 0] if rhs_vector else x

View File

@ -101,7 +101,7 @@ class SolveTriangular(PrimitiveWithInfer):
class Cholesky(PrimitiveWithInfer):
"""
Inner API for _Cholesky base class.
Cholesky decomposition for A.
"""
@prim_attr_register
@ -175,14 +175,14 @@ class CholeskySolver(PrimitiveWithInfer):
def __init__(self, lower=False):
super().__init__(name="CholeskySolver")
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
self.init_prim_io_names(inputs=['x', 'b'], outputs=['y'])
self.init_prim_io_names(inputs=['A', 'b'], outputs=['y'])
def __infer__(self, x, b):
def __infer__(self, A, b):
b_shape = b['shape']
x_dtype = x['dtype']
a_dtype = A['dtype']
return {
'shape': tuple(b_shape),
'dtype': x_dtype,
'dtype': a_dtype,
'value': None
}
@ -276,3 +276,53 @@ class EigNet(nn.Cell):
if self.bv:
return (r[0], r[1])
return r[0]
class LU(PrimitiveWithInfer):
"""
LU decomposition with partial pivoting
A = P.L.U
"""
@prim_attr_register
def __init__(self):
super().__init__(name="LU")
self.init_prim_io_names(inputs=['x'], outputs=['lu', 'pivots', 'permutation'])
def __infer__(self, x):
x_shape = list(x['shape'])
x_dtype = x['dtype']
ndim = len(x_shape)
if ndim in (1, 2):
permutation_shape = (x_shape[0], x_shape[0])
else:
permutation_shape = (x_shape[0], x_shape[1], x_shape[1])
output = {
'shape': (x_shape, permutation_shape[:-1], permutation_shape),
'dtype': (x_dtype, mstype.int32, mstype.int32),
'value': None
}
return output
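# Shape inference example (illustrative): a square input x of shape (4, 4)
# yields lu (4, 4), pivots (4,) and permutation (4, 4); a batched input of
# shape (b, 4, 4) yields (b, 4, 4), (b, 4) and (b, 4, 4).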
class LUSolver(PrimitiveWithInfer):
"""
LUSolver for Ax = b
"""
@prim_attr_register
def __init__(self, trans: str):
super().__init__(name="LUSolver")
self.init_prim_io_names(inputs=['a', 'b'], outputs=['output'])
self.trans = validator.check_value_type("trans", trans, [str], self.name)
def __infer__(self, a, b):
b_shape = list(b['shape'])
a_dtype = a['dtype']
output = {
'shape': tuple(b_shape),
'dtype': a_dtype,
'value': None
}
return output

View File

@ -1,247 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import Generic
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import PrimitiveWithInfer
from mindspore.ops import prim_attr_register
from mindspore._checkparam import Validator as validator
import mindspore.numpy as mnp
import scipy as scp
import numpy as np
import pytest
np.random.seed(0)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class LU(PrimitiveWithInfer):
"""
LU decomposition with partial pivoting
P.A = L.U
"""
@prim_attr_register
def __init__(self):
super().__init__(name="LU")
self.init_prim_io_names(inputs=['x'], outputs=['lu', 'pivots', 'permutation'])
def __infer__(self, x):
x_shape = list(x['shape'])
x_dtype = x['dtype']
ndim = len(x_shape)
permutation_shape = x_shape
if ndim == 0:
pivots_shape = x_shape
elif ndim == 1:
pivots_shape = x_shape[:-1]
else:
pivots_shape = x_shape[-2:-1]
output = {
'shape': (x_shape, pivots_shape, permutation_shape),
'dtype': (x_dtype, mstype.int32, mstype.int32),
'value': None
}
return output
class LUSolver(PrimitiveWithInfer):
"""
LUSolver for Ax = b
"""
@prim_attr_register
def __init__(self, trans: str):
super().__init__(name="LUSolver")
self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
self.trans = validator.check_value_type("trans", trans, [str], self.name)
def __infer__(self, x, b):
b_shape = list(b['shape'])
x_dtype = x['dtype']
output = {
'shape': tuple(b_shape),
'dtype': x_dtype,
'value': None
}
return output
class LuNet(nn.Cell):
def __init__(self):
super(LuNet, self).__init__()
self.lu = LU()
def construct(self, a):
return self.lu(a)
def lu_pivots_to_permutation(pivots, permutation_size: int):
batch_dims = pivots.shape[:-1]
k = pivots.shape[-1]
per = mnp.arange(0, permutation_size)
permutation = mnp.broadcast_to(per, batch_dims + (permutation_size,))
permutation = mnp.array(permutation)
if permutation_size == 0:
return permutation
for i in range(k):
j = pivots[..., i]
loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
x = permutation[..., i]
y = permutation[loc + (j,)]
permutation[..., i] = y
permutation[loc + (j,)] = x
return permutation
def _lu_solve_core(in_lu, permutation, b, trans):
m = in_lu.shape[0]
res_shape = b.shape[1:]
prod_result = 1
for sh in res_shape:
prod_result *= sh
x = mnp.reshape(b, (m, prod_result))
if trans == 0:
trans_str = "N"
x = x[permutation, :]
elif trans == 1:
trans_str = "T"
elif trans == 2:
trans_str = "C"
else:
raise ValueError("trans error, it's value must be 0, 1, 2")
ms_lu_solve = LUSolver(trans_str)
output = ms_lu_solve(in_lu, x)
return mnp.reshape(output, b.shape)
def _check_lu_shape(in_lu, b):
if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
raise ValueError("last two dimensions of LU decomposition must be equal.")
if b.shape is None:
raise ValueError(" LU decomposition input b's rank must >=1.")
rhs_vector = in_lu.ndim == b.ndim + 1
if rhs_vector:
if b.shape[-1] != in_lu.shape[-1]:
raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions")
mnp.expand_dims(b, axis=1)
else:
if b.shape[-2] != in_lu.shape[-1]:
raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions")
def lu_factor(a, overwrite_a=False, check_finite=True):
del overwrite_a, check_finite
mscp_lu = LuNet()
m_lu, pivots, _ = mscp_lu(a)
return m_lu, pivots
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
del overwrite_a, check_finite
mscp_lu = LuNet()
m_lu, _, p = mscp_lu(a)
m = a.shape[-2]
n = a.shape[-1]
k = min(m, n)
a_dtype = a.dtype
l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype)
u = mnp.triu(m_lu)[:k, :]
if permute_l:
return mnp.matmul(p, l), u
return p, l, u
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
del overwrite_b, check_finite
m_lu, pivots = lu_and_piv
# 1. check shape
_check_lu_shape(m_lu, b)
# 2. the permutation array has already been calculated, so just use it
permutation = pivots
# 3. rhs_vector
rhs_vector = m_lu.ndim == b.ndim + 1
x = _lu_solve_core(m_lu, permutation, b, trans)
return x[..., 0] if rhs_vector else x
def create_full_rank_matrix(m, n, dtype):
a_rank = 0
a = np.random.random((m, n)).astype(dtype)
while a_rank != m:
a = (a + np.eye(m, n)).astype(dtype)
a_rank = np.linalg.matrix_rank(a)
return a
def create_sym_pos_matrix(m, n, dtype):
a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype)
return np.dot(a, a.T)
@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_square_lu_net(n: int, dtype: Generic):
"""
Feature: ALL To ALL
Description: test cases for lu decomposition of A[N,N]
Expectation: the result matches scipy
"""
a = create_full_rank_matrix(n, n, dtype)
s_lu, _ = scp.linalg.lu_factor(a)
mscp_lu_net = LuNet()
tensor_a = Tensor(a)
mscp_lu, _, _ = mscp_lu_net(tensor_a)
assert np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3)
@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_lu_solver_net(n: int, dtype: Generic):
"""
Feature: ALL To ALL
Description: test cases for lu_solve of A[N,N]x = b[N,1]
Expectation: the result matches scipy
"""
a = create_full_rank_matrix(n, n, dtype)
b = np.random.random((n, 1)).astype(dtype)
s_lu, s_piv = scp.linalg.lu_factor(a)
tensor_a = Tensor(a)
tensor_b = Tensor(b)
mscp_lu_net = LuNet()
mscp_lu, pivots, _ = mscp_lu_net(tensor_a)
np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3)
lu_factor_x = (s_lu, s_piv)
msc_lu_factor = (mscp_lu, pivots)
scp_x = scp.linalg.lu_solve(lu_factor_x, b)
mscp_x = lu_solve(msc_lu_factor, tensor_b)
real_b = mnp.dot(tensor_a, mscp_x)
expected_b = np.dot(a, scp_x)
assert np.allclose(real_b.asnumpy(), expected_b, rtol=1.e-3, atol=1.e-3)
assert np.allclose(mscp_x.asnumpy(), scp_x, rtol=1.e-3, atol=1.e-3)

View File

@ -22,7 +22,8 @@ import scipy as osp
import mindspore.scipy as msp
from mindspore import context, Tensor
import mindspore.numpy as mnp
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
create_random_rank_matrix
onp.random.seed(0)
context.set_context(mode=context.PYNATIVE_MODE)
@ -227,3 +228,82 @@ def test_eigh_solver(n: int):
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=True)
assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(4, 4), (4, 5), (10, 5), (20, 20)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_lu(shape: (int, int), dtype):
"""
Feature: ALL To ALL
Description: test cases for lu decomposition of A[M,N]
Expectation: the result matches scipy
"""
a = create_random_rank_matrix(shape, dtype)
s_p, s_l, s_u = osp.linalg.lu(a)
tensor_a = Tensor(a)
m_p, m_l, m_u = msp.linalg.lu(tensor_a)
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(m_p.asnumpy(), s_p, rtol=rtol, atol=atol)
assert onp.allclose(m_l.asnumpy(), s_l, rtol=rtol, atol=atol)
assert onp.allclose(m_u.asnumpy(), s_u, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 10, 20])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_lu_factor(n: int, dtype):
"""
Feature: ALL To ALL
Description: test cases for lu_factor of A[N,N]
Expectation: the result matches scipy
"""
a = create_full_rank_matrix((n, n), dtype)
s_lu, _ = osp.linalg.lu_factor(a)
tensor_a = Tensor(a)
m_lu, pivots = msp.linalg.lu_factor(tensor_a)
m_l, m_u = onp.tril(m_lu.asnumpy(), k=-1) + onp.eye(n), onp.triu(m_lu.asnumpy())
s_l, s_u = onp.tril(s_lu, k=-1) + onp.eye(n), onp.triu(s_lu)
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(m_lu.asnumpy(), s_lu, rtol=rtol, atol=atol)
assert onp.allclose(a[pivots.asnumpy()], onp.dot(m_l, m_u), rtol=rtol, atol=atol)
assert onp.allclose(a[pivots.asnumpy()], onp.dot(s_l, s_u), rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 10, 20])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_lu_solve(n: int, dtype):
"""
Feature: ALL To ALL
Description: test cases for lu_solve of A[N,N]x = b[N,1]
Expectation: the result matches scipy
"""
a = create_full_rank_matrix((n, n), dtype)
b = onp.random.random((n, 1)).astype(dtype)
s_lu, s_piv = osp.linalg.lu_factor(a)
tensor_a = Tensor(a)
tensor_b = Tensor(b)
m_lu, m_piv = msp.linalg.lu_factor(tensor_a)
lu_factor_x = (s_lu, s_piv)
msp_lu_factor = (m_lu, m_piv)
osp_x = osp.linalg.lu_solve(lu_factor_x, b)
msp_x = msp.linalg.lu_solve(msp_lu_factor, tensor_b)
real_b = mnp.dot(tensor_a, msp_x)
expected_b = onp.dot(a, osp_x)
rtol = 1.e-3
atol = 1.e-3
assert onp.allclose(real_b.asnumpy(), expected_b, rtol=rtol, atol=atol)
assert onp.allclose(msp_x.asnumpy(), osp_x, rtol=rtol, atol=atol)

View File

@ -44,7 +44,7 @@ def match_array(actual, expected, error=0, err_msg=''):
def create_full_rank_matrix(shape, dtype):
if len(shape) < 2 and shape[-1] != shape[-2]:
if len(shape) < 2 or shape[-1] != shape[-2]:
raise ValueError(
'Full rank matrix must be a square matrix, but has shape: ', shape)
@ -61,8 +61,15 @@ def create_full_rank_matrix(shape, dtype):
return a
def create_random_rank_matrix(shape, dtype):
if len(shape) < 2:
raise ValueError(
'Random rank matrix shape must have at least two dims, but has shape: ', shape)
return onp.random.random(shape).astype(dtype)
def create_sym_pos_matrix(shape, dtype):
if len(shape) != 2 and shape[0] != shape[1]:
if len(shape) != 2 or shape[0] != shape[1]:
raise ValueError(
'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape)