forked from mindspore-Ecosystem/mindspore
fix lu factor and solve for cpu backend
parent 2427f7b3a4
commit 872d8b37da
@ -16,10 +16,10 @@
#include "backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h"
#include <vector>
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include <algorithm>
#include <utility>
#include <unordered_map>
#include "utils/ms_utils.h"
#include "Eigen/Dense"
#include "Eigen/LU"

namespace mindspore {
namespace kernel {
@ -33,6 +33,7 @@ constexpr size_t kPermutationIndex = 2;
constexpr size_t kLUDefaultShape = 1;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
+constexpr int kZeroThreshold = INT32_MIN;
}  // namespace

template <typename T>
@ -70,51 +71,145 @@ void LUCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
}

template <typename T>
-bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
-                            const std::vector<kernel::AddressPtr> &outputs) {
-  // input matrix of (m,n) PA = LU
-  T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
-  Map<Matrix<T, RowMajor>> input_a(a_value, a_row_, a_col_);
-  T *lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
-  Map<Matrix<T, RowMajor>> output_lu(lu_value, lu_row_, lu_col_);
-  int *pivots_value = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
-  int *permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
-  Map<Matrix<int, RowMajor>> output_permutation(permutation_value, permutation_row_, permutation_col_);
-  if (a_row_ == a_col_) {
-    // partial_piv_lu
-    auto partial_lu = input_a.lu();
-    auto partial_p = partial_lu.permutationP();
-    output_lu.noalias() = partial_lu.matrixLU();
-    output_permutation.noalias() = partial_p.toDenseMatrix();
-  } else {
-    // full_piv_lu
-    auto full_piv_lu = input_a.fullPivLu();
-    auto full_piv_p = full_piv_lu.permutationP();
-    output_lu.noalias() = full_piv_lu.matrixLU();
-    output_permutation.noalias() = full_piv_p.toDenseMatrix();
-  }
-  // calculate permutation array from permutation matrix to indicate scipy's pivots.
-  for (int i = 0; i < static_cast<int>(output_permutation.rows()); ++i) {
-    if (output_permutation(i, i) != 0) {
-      pivots_value[i] = i;
-      continue;
-    }
-    for (int j = 0; j < static_cast<int>(output_permutation.cols()); ++j) {
-      if (output_permutation(i, j) != 0) {
-        pivots_value[i] = j;
-        break;
-      }
-    }
-  }
-  // here, we note that eigen calculate permutation matrix is col major, so transpose it to row major,
-  // but permutation array is based on permutation matrix before transposed, which is consistent to scipy and jax.
-  output_permutation.transposeInPlace();
-  if (output_lu.RowsAtCompileTime != 0 && output_lu.ColsAtCompileTime != 0 && output_permutation.size() != 0) {
-    return true;
-  }
-  MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
-}
+void LUCPUKernel<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  size_t lu_size = lu_col_ * sizeof(T);
+  (void)workspace_size_list_.emplace_back(lu_size);
+  (void)workspace_size_list_.emplace_back(lu_size);
+}
+
+template <typename T>
+T LUCPUKernel<T>::GetPermutatedValue(const T *lu_value, const int *per, size_t i, size_t j) {
+  const T *pered_lu_value = lu_value + per[i] * lu_col_ + j;
+  return *pered_lu_value;
+}
+
+template <typename T>
+bool LUCPUKernel<T>::UpdateMajorPermutation(T *lu_value, int *per, size_t k, size_t rows) {
+  T max_major_value = static_cast<T>(kZeroThreshold);
+  int max_major_index = 0;
+  for (size_t i = k; i < rows; ++i) {
+    T value = GetPermutatedValue(lu_value, per, i, k);
+    T abs_value = std::abs(value);
+    if (abs_value > max_major_value) {
+      max_major_value = abs_value;
+      max_major_index = i;
+    }
+  }
+  size_t per_k = per[k];
+  per[k] = per[max_major_index];
+  per[max_major_index] = per_k;
+  return max_major_value != static_cast<T>(kZeroThreshold);
+}
+
+template <typename T>
+void LUCPUKernel<T>::SetPermutatedValue(T *lu_value, const int *per, size_t i, size_t j, const T &value) {
+  T *pered_lu_value = lu_value + per[i] * lu_col_ + j;
+  *pered_lu_value = value;
+}
+
+template <typename T>
+bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                            const std::vector<kernel::AddressPtr> &workspace,
+                            const std::vector<kernel::AddressPtr> &outputs) {
+  // input matrix of (m,n) PA = LU
+  T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
+  T *lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
+  // pivots permutation value
+  int *per_value = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
+  // permutation matrix value
+  int *permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
+  T *lu_ori_wk = reinterpret_cast<T *>(workspace[kLuIndex]->addr);
+  T *lu_trans_wk = reinterpret_cast<T *>(workspace[kPivotsIndex]->addr);
+  // init pivots
+  for (size_t i = 0; i < pivots_row_; ++i) {
+    per_value[i] = i;
+  }
+  // 1. memcpy input to output, do the full lu in place.
+  (void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T));
+
+  int s = std::min(a_row_, a_col_);
+  // 2. do lu decompose inplace
+  for (int k = 0; k < s; ++k) {
+    // 2.1 choose the major element of the current col; a false return means the col is all zero, so just continue.
+    if (!UpdateMajorPermutation(lu_value, per_value, k, lu_row_)) {
+      continue;
+    }
+    // 2.2 major element x --> (1/x), read from the in-place lu matrix.
+    T value = static_cast<T>(1.0 / GetPermutatedValue(lu_value, per_value, k, k));
+    // 2.3 change major col values
+    for (size_t i = k + 1; i < lu_row_; ++i) {
+      T y = static_cast<T>(GetPermutatedValue(lu_value, per_value, i, k) * value);
+      // set the new lu matrix value in place.
+      SetPermutatedValue(lu_value, per_value, i, k, y);
+    }
+
+    // 2.4 Gauss elimination core
+    for (size_t i = k + 1; i < lu_row_; ++i) {
+      for (size_t j = k + 1; j < lu_col_; ++j) {
+        T y = static_cast<T>(GetPermutatedValue(lu_value, per_value, i, j) -
+                             GetPermutatedValue(lu_value, per_value, i, k) * GetPermutatedValue(lu_value, per_value, k, j));
+        SetPermutatedValue(lu_value, per_value, i, j, y);
+      }
+    }
+  }
+
+  // 3. calculate the final lu by the permutation list
+  std::unordered_map<int, std::pair<int, bool>> pivots_map;
+  for (int i = 0; i < static_cast<int>(lu_row_); ++i) {
+    pivots_map[per_value[i]] = {i, false};
+  }
+  int pivots_count = 0;
+  for (const auto &pivot : pivots_map) {
+    pivots_count++;
+    int key = pivot.first;
+    int index = pivot.second.first;
+    bool is_visited = pivot.second.second;
+    if (is_visited || index == (pivots_count - 1)) {
+      continue;
+    }
+
+    T *lu_ori_row = lu_value + index * lu_col_;
+    T *lu_trans_row = lu_value + key * lu_col_;
+    // copy ori data to trans lu
+    (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
+    // copy new data to ori data ptr
+    (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_trans_row, lu_col_ * sizeof(T));
+    // update pivot map
+    pivots_map[key] = {index, true};
+    // put the ori data, which is stored in the workspace, to its mapped new place
+    is_visited = pivots_map[index].second;
+    while (!is_visited) {
+      key = index;
+      index = pivots_map[key].first;
+      is_visited = pivots_map[key].second;
+      lu_ori_row = lu_value + index * lu_col_;
+      T *tmp_wk = lu_trans_wk;
+      lu_trans_wk = lu_ori_wk;
+      lu_ori_wk = tmp_wk;
+      // copy the new ori data to the trans workspace
+      (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
+      // copy the new data to the ori data place
+      (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_ori_wk, lu_col_ * sizeof(T));
+      pivots_map[key] = {index, true};
+    }
+  }
+
+  // 4. calculate the final permutation matrix
+  // for PA = LU get: base + row * permutation_row_ + col
+  // for A = PLU get: base + col * permutation_row_ + row
+  // here, we do A = PLU, which is the same as scipy.
+  size_t count = permutation_col_ * permutation_row_ * sizeof(int);
+  (void)memset_s(reinterpret_cast<void *>(permutation_value), count, 0, count);
+  for (size_t i = 0; i < pivots_row_; ++i) {
+    int position = per_value[i];
+    int *per_addr = permutation_value + position * permutation_row_ + i;
+    *per_addr = 1;
+  }
+
+  return true;
+}
}  // namespace kernel
}  // namespace mindspore
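The in-place kernel above can be mirrored in a few lines of NumPy. The sketch below (illustrative only, not code from the commit; `plu_sketch` is a hypothetical name) follows the same row-pivoted Doolittle steps 2.1-2.4 and the A = P L U convention from step 4:

import numpy as np

def plu_sketch(a):
    """Row-pivoted, in-place Doolittle LU, mirroring the kernel's steps."""
    lu = np.array(a, dtype=np.float64)
    m, n = lu.shape
    per = np.arange(m)  # like per_value: per[i] is the physical row of logical row i
    for k in range(min(m, n)):
        # 2.1 pick the largest |entry| in column k as the pivot (UpdateMajorPermutation)
        p = k + int(np.argmax(np.abs(lu[per[k:], k])))
        per[[k, p]] = per[[p, k]]
        if lu[per[k], k] == 0:  # all-zero column: skip, as the kernel does
            continue
        # 2.2 + 2.3 scale the sub-column by 1/pivot to form L's multipliers
        lu[per[k + 1:], k] = lu[per[k + 1:], k] / lu[per[k], k]
        # 2.4 Gauss elimination on the trailing block
        lu[np.ix_(per[k + 1:], np.arange(k + 1, n))] -= np.outer(lu[per[k + 1:], k], lu[per[k], k + 1:])
    # 3. reorder rows so lu holds L (strict lower) and U (upper) contiguously
    lu = lu[per]
    # 4. A = P L U convention: P[per[i], i] = 1, same as scipy
    perm = np.zeros((m, m))
    perm[per, np.arange(m)] = 1
    return lu, per, perm

a = np.array([[2., 5., 8., 7.], [5., 2., 2., 8.], [7., 5., 6., 6.], [5., 4., 4., 8.]])
lu, per, p = plu_sketch(a)
l = np.tril(lu, -1) + np.eye(4)
u = np.triu(lu)
assert np.allclose(p @ l @ u, a)   # A = P L U
assert np.allclose(a[per], l @ u)  # equivalently, permuting A's rows gives L U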
@ -34,6 +34,10 @@ class LUCPUKernel : public CPUKernel {
 private:
  void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
+  T GetPermutatedValue(const T *lu_value, const int *per, size_t i, size_t j);
+  bool UpdateMajorPermutation(T *lu_value, int *per, size_t k, size_t rows);
+  void SetPermutatedValue(T *lu_value, const int *per, size_t i, size_t j, const T &value);
  size_t a_row_{1};
  size_t a_col_{1};
  size_t lu_row_{1};
@ -15,7 +15,6 @@
 */

#include "backend/kernel_compiler/cpu/eigen/matmul_double_cpu_kernel.h"
-#define EIGEN_NO_MALLOC
#include <Eigen/Dense>
#include <vector>
@ -17,7 +17,6 @@
#include "backend/kernel_compiler/cpu/eigen/matrix_inverse_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
#include "Eigen/Dense"
-#define EIGEN_NO_MALLOC

namespace mindspore {
namespace kernel {
@ -15,12 +15,10 @@
 */

#include "backend/kernel_compiler/cpu/eigen/solve_triangular_cpu_kernel.h"
-#define EIGEN_NO_MALLOC
#include <Eigen/Dense>
#include <vector>
#include <string>
#include <type_traits>

namespace mindspore {
namespace kernel {
using Eigen::ColMajor;
@ -17,7 +17,7 @@
from . import optimize, sparse, linalg
from .optimize import minimize, line_search
from .sparse import cg, gmres
-from .linalg import block_diag, solve_triangular, inv, cho_factor, cholesky, cho_solve
+from .linalg import block_diag, solve_triangular, inv, cho_factor, cholesky, cho_solve, lu, lu_factor, lu_solve

__all__ = []
__all__.extend(optimize.__all__)
@ -18,10 +18,13 @@ from .. import ops
from .ops import SolveTriangular
from .ops import CholeskySolver
from .ops import Cholesky
+from .ops import LU
+from .ops import LUSolver
from .ops import EighNet
from ..ops import operations as P

-__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve', 'eigh']
+__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve', 'eigh', 'lu_factor', 'lu',
+           'lu_solve']


def block_diag(*arrs):
@ -102,14 +105,10 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
            Default is to use upper triangle.
        trans (0, 1, 2, 'N', 'T', 'C', optional):
            Type of system to solve:

-            ======== =========
-            trans    system
-            ======== =========
-            0 or 'N' a x = b
-            1 or 'T' a^T x = b
-            2 or 'C' a^H x = b
-            ======== =========
+            trans:        system:
+            0 or 'N'      a x = b
+            1 or 'T'      a^T x = b
+            2 or 'C'      a^H x = b
        unit_diagonal (bool, optional): If True, diagonal elements of :math:`A` are assumed to be 1 and
            will not be referenced.
        overwrite_b (bool, optional): Allow overwriting data in :math:`b` (may enhance performance)
@ -160,21 +159,18 @@ def inv(a, overwrite_a=False, check_finite=True):
    Compute the inverse of a matrix.

    Args:
-        a (Tensor): Tensor
-            Square matrix to be inverted.
-        overwrite_a (bool, optional): Discard data in `a` (may improve performance).
-            Default is False.
-        check_finite (bool, optional): Whether to check that the input matrix contains
-            only finite numbers.
-            Disabling may give a performance gain, but may result in problems
-            (crashes, non-termination) if the inputs do contain infinities or NaNs.
+        a (Tensor): Square matrix to be inverted.
+        overwrite_a (bool, optional): Discard data in `a` (may improve performance). Default is False.
+        check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
+            Disabling may give a performance gain, but may result in problems (crashes, non-termination)
+            if the inputs do contain infinities or NaNs.

    Returns:
        ainv (Tensor): Inverse of the matrix `a`.

    Raises:
-        LinAlgError: If `a` is singular.
-        ValueError: If `a` is not square, or not 2D.
+        LinAlgError: If :math:`a` is singular.
+        ValueError: If :math:`a` is not square, or not 2D.

    Supported Platforms:
        ``CPU`` ``GPU``
@ -328,52 +324,40 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
    Solve a standard or generalized eigenvalue problem for a complex
    Hermitian or real symmetric matrix.

-    Find eigenvalues Tensor ``w`` and optionally eigenvectors Tensor ``v`` of
-    Tensor ``a``, where ``b`` is positive definite such that for every
-    eigenvalue λ (i-th entry of w) and its eigenvector ``vi`` (i-th column of
-    ``v``) satisfies::
+    Find eigenvalues Tensor ``w`` and optionally eigenvectors Tensor ``v`` of Tensor ``a``,
+    where ``b`` is positive definite such that for every eigenvalue λ (i-th entry of w) and
+    its eigenvector ``vi`` (i-th column of ``v``) satisfies::

        a @ vi = λ * b @ vi
        vi.conj().T @ a @ vi = λ
        vi.conj().T @ b @ vi = 1

    In the standard problem, ``b`` is assumed to be the identity matrix.

    Args:
-        a (Tensor): (M, M) Tensor
-            A complex Hermitian or real symmetric matrix whose eigenvalues and
+        a (Tensor): A (M, M) complex Hermitian or real symmetric matrix whose eigenvalues and
            eigenvectors will be computed.
-        b (Tensor, optional): (M, M) Tensor
-            A complex Hermitian or real symmetric definite positive matrix in.
+        b (Tensor, optional): A (M, M) complex Hermitian or real symmetric positive definite matrix.
            If omitted, identity matrix is assumed.
-        lower (bool, optional): Whether the pertinent Tensor data is taken from
-            the lower or upper triangle of ``a`` and, if applicable, ``b``. (Default: lower)
-        eigvals_only (bool, optional): Whether to calculate only eigenvalues
-            and no eigenvectors. (Default: both are calculated)
-        _type (int, optional): For the generalized problems, this keyword specifies
-            the problem type to be solved for ``w`` and ``v`` (only takes 1, 2, 3 as possible
-            inputs)::
+        lower (bool, optional): Whether the pertinent Tensor data is taken from the lower or upper
+            triangle of ``a`` and, if applicable, ``b``. (Default: lower)
+        eigvals_only (bool, optional): Whether to calculate only eigenvalues and no eigenvectors.
+            (Default: both are calculated)
+        _type (int, optional): For the generalized problems, this keyword specifies the problem type
+            to be solved for ``w`` and ``v`` (only takes 1, 2, 3 as possible inputs)::

            1 => a @ v = w @ b @ v
            2 => a @ b @ v = w @ v
            3 => b @ a @ v = w @ v

            This keyword is ignored for standard problems.
-        overwrite_a (bool, optional): Whether to overwrite data in ``a``
-            (may improve performance). Default is False.
-        overwrite_b (bool, optional): Whether to overwrite data in ``b``
-            (may improve performance). Default is False.
-        check_finite (bool, optional): Whether to check that the input matrices
-            contain only finite numbers.
-            Disabling may give a performance gain, but may result in problems
-            (crashes, non-termination) if the inputs do contain infinities or NaNs.
-        turbo (bool, optional): use divide and conquer algorithm (faster but
-            expensive in memory, only for generalized eigenvalue problem and
-            if full set of eigenvalues are requested.). Has no significant
-            effect if eigenvectors are not requested.
-        eigvals (tuple, optional): Indexes of the smallest and largest (in ascending order)
-            eigenvalues and corresponding eigenvectors to be returned: 0 <= lo <= hi <= M-1.
-            If omitted, all eigenvalues and eigenvectors are returned.
+        overwrite_a (bool, optional): Whether to overwrite data in ``a`` (may improve performance). Default is False.
+        overwrite_b (bool, optional): Whether to overwrite data in ``b`` (may improve performance). Default is False.
+        check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
+            Disabling may give a performance gain, but may result in problems (crashes, non-termination)
+            if the inputs do contain infinities or NaNs.
+        turbo (bool, optional): Use divide and conquer algorithm (faster but expensive in memory, only
+            for generalized eigenvalue problems and when a full set of eigenvalues is requested).
+            Has no significant effect if eigenvectors are not requested.
+        eigvals (tuple, optional): Indexes of the smallest and largest (in ascending order) eigenvalues
+            and corresponding eigenvectors to be returned: 0 <= lo <= hi <= M-1. If omitted, all eigenvalues
+            and eigenvectors are returned.

    Returns:
        w (Tensor): (N,) Tensor, The N (1<=N<=M) selected eigenvalues, in ascending order,
@ -381,10 +365,9 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
        v (Tensor): (M, N) Tensor, (if ``eigvals_only == False``)

    Raises:
-        LinAlgError: If eigenvalue computation does not converge, an error occurred, or
-            b matrix is not definite positive. Note that if input matrices are
-            not symmetric or Hermitian, no error will be reported but results will
-            be wrong.
+        LinAlgError: If eigenvalue computation does not converge, an error occurred, or b matrix is not
+            positive definite. Note that if input matrices are not symmetric or Hermitian, no error will
+            be reported but results will be wrong.

    Supported Platforms:
        ``CPU`` ``GPU``
@ -400,3 +383,227 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
    """
    eigh_net = EighNet(not eigvals_only, lower=True)
    return eigh_net(a)

def lu_pivots_to_permutation(pivots, permutation_size: int):
    """convert pivot indices into a full row permutation"""
    batch_dims = pivots.shape[:-1]
    k = pivots.shape[-1]
    per = mnp.arange(0, permutation_size)
    permutation = mnp.broadcast_to(per, batch_dims + (permutation_size,))
    permutation = mnp.array(permutation)
    if permutation_size == 0:
        return permutation

    for i in range(k):
        j = pivots[..., i]
        loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
        x = permutation[..., i]
        y = permutation[loc + (j,)]
        permutation[..., i] = y
        permutation[loc + (j,)] = x
    return permutation

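# A behavior sketch for the helper above, assuming scipy-style successive-swap
# pivots (scipy.linalg.lu_factor would return [2, 2, 3, 3] for the 4x4 example
# matrix used in the docstrings below); each pivots[..., i] swaps entries i and
# pivots[..., i], producing the final row permutation:
#
#     >>> import numpy as onp
#     >>> from mindspore.common import Tensor
#     >>> pivots = Tensor(onp.array([2, 2, 3, 3]).astype(onp.int32))
#     >>> lu_pivots_to_permutation(pivots, 4)
#     [2, 0, 3, 1]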

def lu_solve_core(in_lu, permutation, b, trans):
    """core implementation of lu solve"""
    m = in_lu.shape[0]
    res_shape = b.shape[1:]
    prod_result = 1
    for sh in res_shape:
        prod_result *= sh
    x = mnp.reshape(b, (m, prod_result))
    if trans == 0:
        trans_str = "N"
        x = x[permutation, :]
    elif trans == 1:
        trans_str = "T"
    elif trans == 2:
        trans_str = "C"
    else:
        raise ValueError("trans error, its value must be 0, 1 or 2")
    ms_lu_solve = LUSolver(trans_str)
    output = ms_lu_solve(in_lu, x)
    return mnp.reshape(output, b.shape)

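# Usage sketch (assumed shapes): with in_lu of shape (m, m) and b of shape
# (m, k), trans=0 first permutes the rows of b and solves a x = b, while
# trans=1 / trans=2 solve a^T x = b / a^H x = b without permuting, matching
# the LUSolver primitive's 'N' / 'T' / 'C' modes.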
def check_lu_shape(in_lu, b):
    """check lu input shape"""
    if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
        raise ValueError("last two dimensions of LU decomposition must be equal.")

    if b.shape is None:
        raise ValueError("LU decomposition input b's rank must be >= 1.")
    rhs_vector = in_lu.ndim == b.ndim + 1
    if rhs_vector:
        if b.shape[-1] != in_lu.shape[-1]:
            raise ValueError("LU decomposition: lu matrix and b must have the same number of dimensions")
        mnp.expand_dims(b, axis=1)
    else:
        if b.shape[-2] != in_lu.shape[-1]:
            raise ValueError("LU decomposition: lu matrix and b must have the same number of dimensions")

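# Shape-handling sketch: for in_lu of shape (4, 4), a b of shape (4,) is the
# rhs_vector case (in_lu.ndim == b.ndim + 1) and must match in_lu's last axis,
# while a b of shape (4, k) must instead match along its second-to-last axis.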
def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit diagonal elements, and U upper triangular.

    Args:
        a (Tensor): square matrix of (M, M) to decompose.
        overwrite_a (bool, optional): Whether to overwrite data in `a` (may increase performance).
        check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
            Disabling may give a performance gain, but may result in problems
            (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns:
        lu (Tensor): a square matrix of (N, N) containing U in its upper triangle, and L in its lower triangle.
            The unit diagonal elements of L are not stored.
        piv (Tensor): (N,) pivot indices representing the permutation matrix P: row i of the
            decomposition was taken from row piv[i] of the original matrix, so that a[piv] = l @ u.

    Supported Platforms:
        ``CPU`` ``GPU``

    Examples:
        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.linalg import lu_factor
        >>> A = Tensor(onp.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]).astype(onp.float64))
        >>> lu, piv = lu_factor(A)
        >>> lu
        [[ 7.        ,  5.        ,  6.        ,  6.        ],
         [ 0.28571429,  3.57142857,  6.28571429,  5.28571429],
         [ 0.71428571,  0.12      , -1.04      ,  3.08      ],
         [ 0.71428571, -0.44      , -0.46153846,  7.46153846]]
        >>> piv
        [2, 0, 3, 1]
    """
    del overwrite_a, check_finite
    msp_lu = LU()
    m_lu, pivots, _ = msp_lu(a)
    return m_lu, pivots

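# A quick consistency check of the factorization, continuing the docstring
# example above (this mirrors what the CPU tests assert): permuting the input
# rows by the returned pivots reproduces l @ u.
#
#     >>> import numpy as onp
#     >>> lu_np, piv_np = lu.asnumpy(), piv.asnumpy()
#     >>> l = onp.tril(lu_np, k=-1) + onp.eye(4)
#     >>> u = onp.triu(lu_np)
#     >>> onp.allclose(A.asnumpy()[piv_np], onp.dot(l, u))
#     True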
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Args:
        a (Tensor): a (M, N) matrix to decompose.
        permute_l (bool, optional): Perform the multiplication P*L (Default: do not permute).
        overwrite_a (bool, optional): Whether to overwrite data in a (may improve performance).
        check_finite (bool, optional): Whether to check that the input matrix contains
            only finite numbers. Disabling may give a performance gain, but may result
            in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns:
        **(If permute_l == False)**

        p (Tensor): (M, M) Permutation matrix
        l (Tensor): (M, K) Lower triangular or trapezoidal matrix with unit diagonal. K = min(M, N)
        u (Tensor): (K, N) Upper triangular or trapezoidal matrix

        **(If permute_l == True)**

        pl (Tensor): (M, K) Permuted L matrix. K = min(M, N)
        u (Tensor): (K, N) Upper triangular or trapezoidal matrix

    Supported Platforms:
        ``CPU`` ``GPU``

    Examples:
        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.linalg import lu
        >>> A = Tensor(onp.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]).astype(onp.float64))
        >>> p, l, u = lu(A)
        >>> p
        [[0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.]]
        >>> l
        [[ 1.        ,  0.        ,  0.        ,  0.        ],
         [ 0.28571429,  1.        ,  0.        ,  0.        ],
         [ 0.71428571,  0.12      ,  1.        ,  0.        ],
         [ 0.71428571, -0.44      , -0.46153846,  1.        ]]
        >>> u
        [[ 7.        ,  5.        ,  6.        ,  6.        ],
         [ 0.        ,  3.57142857,  6.28571429,  5.28571429],
         [ 0.        ,  0.        , -1.04      ,  3.08      ],
         [ 0.        ,  0.        ,  0.        ,  7.46153846]]
    """
    del overwrite_a, check_finite
    msp_lu = LU()
    m_lu, _, p = msp_lu(a)
    m = a.shape[-2]
    n = a.shape[-1]
    k = min(m, n)
    a_dtype = a.dtype
    l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype)
    u = mnp.triu(m_lu)[:k, :]
    if permute_l:
        return mnp.dot(p, l), u
    return p, l, u

def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Args:
        lu_and_piv (Tensor, Tensor): Factorization of the coefficient matrix a, as given by lu_factor.
        b (Tensor): Right-hand side
        trans (int, optional): {0, 1, 2}
            Type of system to solve:

            ===== =========
            trans system
            ===== =========
            0     a x = b
            1     a^T x = b
            2     a^H x = b
            ===== =========
        overwrite_b (bool, optional): Whether to overwrite data in b (may increase performance).
        check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
            Disabling may give a performance gain, but may result in problems (crashes, non-termination)
            if the inputs do contain infinities or NaNs.

    Returns:
        x (Tensor): Solution to the system

    Supported Platforms:
        ``CPU`` ``GPU``

    Examples:
        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.linalg import lu_factor, lu_solve
        >>> A = Tensor(onp.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]).astype(onp.float64))
        >>> b = Tensor(onp.array([1, 1, 1, 1]).astype(onp.float64))
        >>> lu, piv = lu_factor(A)
        >>> lu_solve((lu, piv), b)
        [ 0.05154639, -0.08247423,  0.08247423,  0.09278351]
    """
    del overwrite_b, check_finite
    m_lu, pivots = lu_and_piv
    # 1. check shape
    check_lu_shape(m_lu, b)
    # 2. the permutation array has already been calculated by the kernel, so use it directly.
    permutation = pivots
    # 3. rhs_vector
    rhs_vector = m_lu.ndim == b.ndim + 1
    x = lu_solve_core(m_lu, permutation, b, trans)
    return x[..., 0] if rhs_vector else x

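Taken together, the new public API round-trips as follows; this is a condensed sketch assembled from the docstring examples above and the tests below, not code from the commit:

import numpy as onp
from mindspore import context, Tensor
import mindspore.scipy as msp

context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')

a = Tensor(onp.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]).astype(onp.float64))
b = Tensor(onp.array([[1], [1], [1], [1]]).astype(onp.float64))

lu, piv = msp.linalg.lu_factor(a)      # packed L/U factors plus pivot indices
x = msp.linalg.lu_solve((lu, piv), b)  # reuse the factorization to solve a x = b
p, l, u = msp.linalg.lu(a)             # explicit factors, a = p @ l @ u

assert onp.allclose(onp.dot(a.asnumpy(), x.asnumpy()), b.asnumpy(), rtol=1.e-3, atol=1.e-3)
assert onp.allclose(onp.dot(p.asnumpy(), onp.dot(l.asnumpy(), u.asnumpy())), a.asnumpy(), rtol=1.e-5, atol=1.e-5)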
@ -101,7 +101,7 @@ class SolveTriangular(PrimitiveWithInfer):
class Cholesky(PrimitiveWithInfer):
    """
-    Inner API for _Cholesky base class.
+    Cholesky decomposition for A.
    """

    @prim_attr_register
@ -175,14 +175,14 @@ class CholeskySolver(PrimitiveWithInfer):
    def __init__(self, lower=False):
        super().__init__(name="CholeskySolver")
        self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
-        self.init_prim_io_names(inputs=['x', 'b'], outputs=['y'])
+        self.init_prim_io_names(inputs=['A', 'b'], outputs=['y'])

-    def __infer__(self, x, b):
+    def __infer__(self, A, b):
        b_shape = b['shape']
-        x_dtype = x['dtype']
+        a_dtype = A['dtype']
        return {
            'shape': tuple(b_shape),
-            'dtype': x_dtype,
+            'dtype': a_dtype,
            'value': None
        }
@ -276,3 +276,53 @@ class EigNet(nn.Cell):
        if self.bv:
            return (r[0], r[1])
        return r[0]


class LU(PrimitiveWithInfer):
    """
    LU decomposition with partial pivoting
    A = P.L.U
    """

    @prim_attr_register
    def __init__(self):
        super().__init__(name="LU")
        self.init_prim_io_names(inputs=['x'], outputs=['lu', 'pivots', 'permutation'])

    def __infer__(self, x):
        x_shape = list(x['shape'])
        x_dtype = x['dtype']
        ndim = len(x_shape)

        if ndim in (1, 2):
            permutation_shape = (x_shape[0], x_shape[0])
        else:
            permutation_shape = (x_shape[0], x_shape[1], x_shape[1])
        output = {
            'shape': (x_shape, permutation_shape[:-1], permutation_shape),
            'dtype': (x_dtype, mstype.int32, mstype.int32),
            'value': None
        }
        return output

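# Shape-inference sketch for LU (assumed example): a (4, 4) input infers output
# shapes lu (4, 4), pivots (4,) and permutation (4, 4); a batched (B, n, n)
# input infers pivots (B, n) and permutation (B, n, n).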

class LUSolver(PrimitiveWithInfer):
    """
    LUSolver for Ax = b
    """

    @prim_attr_register
    def __init__(self, trans: str):
        super().__init__(name="LUSolver")
        self.init_prim_io_names(inputs=['a', 'b'], outputs=['output'])
        self.trans = validator.check_value_type("trans", trans, [str], self.name)

    def __infer__(self, a, b):
        b_shape = list(b['shape'])
        a_dtype = a['dtype']
        output = {
            'shape': tuple(b_shape),
            'dtype': a_dtype,
            'value': None
        }
        return output

@ -1,247 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import Generic
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import PrimitiveWithInfer
from mindspore.ops import prim_attr_register
from mindspore._checkparam import Validator as validator
import mindspore.numpy as mnp
import scipy as scp
import numpy as np
import pytest

np.random.seed(0)

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class LU(PrimitiveWithInfer):
    """
    LU decomposition with partial pivoting
    P.A = L.U
    """

    @prim_attr_register
    def __init__(self):
        super().__init__(name="LU")
        self.init_prim_io_names(inputs=['x'], outputs=['lu', 'pivots', 'permutation'])

    def __infer__(self, x):
        x_shape = list(x['shape'])
        x_dtype = x['dtype']
        ndim = len(x_shape)
        permutation_shape = x_shape
        if ndim == 0:
            pivots_shape = x_shape
        elif ndim == 1:
            pivots_shape = x_shape[:-1]
        else:
            pivots_shape = x_shape[-2:-1]
        output = {
            'shape': (x_shape, pivots_shape, permutation_shape),
            'dtype': (x_dtype, mstype.int32, mstype.int32),
            'value': None
        }
        return output


class LUSolver(PrimitiveWithInfer):
    """
    LUSolver for Ax = b
    """

    @prim_attr_register
    def __init__(self, trans: str):
        super().__init__(name="LUSolver")
        self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
        self.trans = validator.check_value_type("trans", trans, [str], self.name)

    def __infer__(self, x, b):
        b_shape = list(b['shape'])
        x_dtype = x['dtype']
        output = {
            'shape': tuple(b_shape),
            'dtype': x_dtype,
            'value': None
        }
        return output


class LuNet(nn.Cell):
    def __init__(self):
        super(LuNet, self).__init__()
        self.lu = LU()

    def construct(self, a):
        return self.lu(a)


def lu_pivots_to_permutation(pivots, permutation_size: int):
    batch_dims = pivots.shape[:-1]
    k = pivots.shape[-1]
    per = mnp.arange(0, permutation_size)
    permutation = mnp.broadcast_to(per, batch_dims + (permutation_size,))
    permutation = mnp.array(permutation)
    if permutation_size == 0:
        return permutation

    for i in range(k):
        j = pivots[..., i]
        loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
        x = permutation[..., i]
        y = permutation[loc + (j,)]
        permutation[..., i] = y
        permutation[loc + (j,)] = x
    return permutation


def _lu_solve_core(in_lu, permutation, b, trans):
    m = in_lu.shape[0]
    res_shape = b.shape[1:]
    prod_result = 1
    for sh in res_shape:
        prod_result *= sh
    x = mnp.reshape(b, (m, prod_result))
    if trans == 0:
        trans_str = "N"
        x = x[permutation, :]
    elif trans == 1:
        trans_str = "T"
    elif trans == 2:
        trans_str = "C"
    else:
        raise ValueError("trans error, it's value must be 0, 1, 2")
    ms_lu_solve = LUSolver(trans_str)
    output = ms_lu_solve(in_lu, x)
    return mnp.reshape(output, b.shape)


def _check_lu_shape(in_lu, b):
    if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
        raise ValueError("last two dimensions of LU decomposition must be equal.")

    if b.shape is None:
        raise ValueError(" LU decomposition input b's rank must >=1.")
    rhs_vector = in_lu.ndim == b.ndim + 1
    if rhs_vector:
        if b.shape[-1] != in_lu.shape[-1]:
            raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions")
        mnp.expand_dims(b, axis=1)
    else:
        if b.shape[-2] != in_lu.shape[-1]:
            raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions")


def lu_factor(a, overwrite_a=False, check_finite=True):
    del overwrite_a, check_finite
    mscp_lu = LuNet()
    m_lu, pivots, _ = mscp_lu(a)
    return m_lu, pivots


def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    del overwrite_a, check_finite
    mscp_lu = LuNet()
    m_lu, _, p = mscp_lu(a)
    m = a.shape[-2]
    n = a.shape[-1]
    k = min(m, n)
    a_dtype = a.dtype
    l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype)
    u = mnp.triu(m_lu)[:k, :]
    if permute_l:
        return mnp.matmul(p, l), u
    return p, l, u


def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    del overwrite_b, check_finite
    m_lu, pivots = lu_and_piv
    # 1. check shape
    _check_lu_shape(m_lu, b)
    # here permutation array has been calculated, just use it.
    # 2. calculate permutation
    permutation = pivots
    # 3. rhs_vector
    rhs_vector = m_lu.ndim == b.ndim + 1
    x = _lu_solve_core(m_lu, permutation, b, trans)

    return x[..., 0] if rhs_vector else x


def create_full_rank_matrix(m, n, dtype):
    a_rank = 0
    a = np.random.random((m, n)).astype(dtype)
    while a_rank != m:
        a = (a + np.eye(m, n)).astype(dtype)
        a_rank = np.linalg.matrix_rank(a)
    return a


def create_sym_pos_matrix(m, n, dtype):
    a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype)
    return np.dot(a, a.T)


@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_square_lu_net(n: int, dtype: Generic):
    """
    Feature: ALL To ALL
    Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
    Expectation: the result match to scipy
    """
    a = create_full_rank_matrix(n, n, dtype)
    s_lu, _ = scp.linalg.lu_factor(a)
    mscp_lu_net = LuNet()
    tensor_a = Tensor(a)
    mscp_lu, _, _ = mscp_lu_net(tensor_a)
    assert np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3)


@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_lu_solver_net(n: int, dtype: Generic):
    """
    Feature: ALL To ALL
    Description: test cases for lu_solve test cases for A[N,N]x = b[N,1]
    Expectation: the result match to scipy
    """
    a = create_full_rank_matrix(n, n, dtype)
    b = np.random.random((n, 1)).astype(dtype)
    s_lu, s_piv = scp.linalg.lu_factor(a)

    tensor_a = Tensor(a)
    tensor_b = Tensor(b)
    mscp_lu_net = LuNet()
    mscp_lu, pivots, _ = mscp_lu_net(tensor_a)
    np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3)

    lu_factor_x = (s_lu, s_piv)
    msc_lu_factor = (mscp_lu, pivots)

    scp_x = scp.linalg.lu_solve(lu_factor_x, b)
    mscp_x = lu_solve(msc_lu_factor, tensor_b)

    real_b = mnp.dot(tensor_a, mscp_x)
    expected_b = np.dot(a, scp_x)

    assert np.allclose(real_b.asnumpy(), expected_b, rtol=1.e-3, atol=1.e-3)
    assert np.allclose(mscp_x.asnumpy(), scp_x, rtol=1.e-3, atol=1.e-3)
@ -22,7 +22,8 @@ import scipy as osp
import mindspore.scipy as msp
from mindspore import context, Tensor
import mindspore.numpy as mnp
-from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix
+from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
+    create_random_rank_matrix

onp.random.seed(0)
context.set_context(mode=context.PYNATIVE_MODE)
@ -227,3 +228,82 @@ def test_eigh_solver(n: int):
    msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=True)
    assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
    assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(4, 4), (4, 5), (10, 5), (20, 20)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_lu(shape: (int, int), dtype):
    """
    Feature: ALL To ALL
    Description: test cases for lu decomposition of A[M,N]
    Expectation: the result matches scipy
    """
    a = create_random_rank_matrix(shape, dtype)
    s_p, s_l, s_u = osp.linalg.lu(a)
    tensor_a = Tensor(a)
    m_p, m_l, m_u = msp.linalg.lu(tensor_a)
    rtol = 1.e-5
    atol = 1.e-5
    assert onp.allclose(m_p.asnumpy(), s_p, rtol=rtol, atol=atol)
    assert onp.allclose(m_l.asnumpy(), s_l, rtol=rtol, atol=atol)
    assert onp.allclose(m_u.asnumpy(), s_u, rtol=rtol, atol=atol)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 10, 20])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_lu_factor(n: int, dtype):
    """
    Feature: ALL To ALL
    Description: test cases for lu factorization of A[N,N]
    Expectation: the result matches scipy
    """
    a = create_full_rank_matrix((n, n), dtype)
    s_lu, _ = osp.linalg.lu_factor(a)
    tensor_a = Tensor(a)
    m_lu, pivots = msp.linalg.lu_factor(tensor_a)
    m_l, m_u = onp.tril(m_lu.asnumpy(), k=-1) + onp.eye(n), onp.triu(m_lu.asnumpy())
    s_l, s_u = onp.tril(s_lu, k=-1) + onp.eye(n), onp.triu(s_lu)
    rtol = 1.e-5
    atol = 1.e-5
    assert onp.allclose(m_lu.asnumpy(), s_lu, rtol=rtol, atol=atol)
    assert onp.allclose(a[pivots.asnumpy()], onp.dot(m_l, m_u), rtol=rtol, atol=atol)
    assert onp.allclose(a[pivots.asnumpy()], onp.dot(s_l, s_u), rtol=rtol, atol=atol)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 10, 20])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_lu_solve(n: int, dtype):
    """
    Feature: ALL To ALL
    Description: test cases for lu_solve on A[N,N]x = b[N,1]
    Expectation: the result matches scipy
    """
    a = create_full_rank_matrix((n, n), dtype)
    b = onp.random.random((n, 1)).astype(dtype)
    s_lu, s_piv = osp.linalg.lu_factor(a)

    tensor_a = Tensor(a)
    tensor_b = Tensor(b)

    m_lu, m_piv = msp.linalg.lu_factor(tensor_a)

    lu_factor_x = (s_lu, s_piv)
    msp_lu_factor = (m_lu, m_piv)

    osp_x = osp.linalg.lu_solve(lu_factor_x, b)
    msp_x = msp.linalg.lu_solve(msp_lu_factor, tensor_b)
    real_b = mnp.dot(tensor_a, msp_x)
    expected_b = onp.dot(a, osp_x)
    rtol = 1.e-3
    atol = 1.e-3
    assert onp.allclose(real_b.asnumpy(), expected_b, rtol=rtol, atol=atol)
    assert onp.allclose(msp_x.asnumpy(), osp_x, rtol=rtol, atol=atol)
@ -44,7 +44,7 @@ def match_array(actual, expected, error=0, err_msg=''):
def create_full_rank_matrix(shape, dtype):
-    if len(shape) < 2 and shape[-1] != shape[-2]:
+    if len(shape) < 2 or shape[-1] != shape[-2]:
        raise ValueError(
            'Full rank matrix must be a square matrix, but has shape: ', shape)
@ -61,8 +61,15 @@ def create_full_rank_matrix(shape, dtype):
    return a


def create_random_rank_matrix(shape, dtype):
    if len(shape) < 2:
        raise ValueError(
            'random rank matrix shape must have at least two dims, but has shape: ', shape)
    return onp.random.random(shape).astype(dtype)


def create_sym_pos_matrix(shape, dtype):
-    if len(shape) != 2 and shape[0] != shape[1]:
+    if len(shape) != 2 or shape[0] != shape[1]:
        raise ValueError(
            'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape)