!26546 Unify GPU/CPU ops input/output (col/row major), modify related test cases, add linalg function and test cases

Merge pull request !26546 from wuwenbing/master
This commit is contained in:
i-robot 2021-11-22 04:34:33 +00:00 committed by Gitee
commit 69c4f470e4
14 changed files with 360 additions and 186 deletions
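
Illustrative aside (not part of this commit): a minimal numpy sketch of the row-major/column-major mismatch this change unifies — the same contiguous buffer describes transposed matrices depending on which convention reads it.

import numpy as np

a = np.arange(9, dtype=np.float32).reshape(3, 3)             # row-major logical matrix (C order)
as_col_major = a.ravel(order='C').reshape(3, 3, order='F')   # reinterpret the buffer column-major
assert np.array_equal(as_col_major, a.T)                     # a column-major consumer "sees" a.T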

View File

@ -24,8 +24,6 @@ namespace kernel {
namespace {
constexpr size_t kInputsNum = 1;
constexpr size_t kOutputsNum = 2;
constexpr size_t kDefaultShape = 1;
constexpr auto kAMatrixDimNum = 2;
} // namespace
using Eigen::Dynamic;
@ -45,12 +43,8 @@ using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, Row
template <typename T, typename C>
void EigCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node));
if (A_shape.size() != kShape2dDims || A_shape[0] != A_shape[1]) {
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[0] << " X " << A_shape[1]
<< "]";

View File

@ -22,10 +22,8 @@ namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputsNum = 2;
constexpr size_t kInputsNum = 1;
constexpr size_t kOutputsNum = 2;
constexpr size_t kDefaultShape = 1;
constexpr auto kAMatrixDimNum = 2;
} // namespace
using Eigen::Dynamic;
@ -45,12 +43,9 @@ using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, Row
template <typename T>
void EighCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
compute_eigen_vectors_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node));
if (A_shape.size() != kShape2dDims || A_shape[0] != A_shape[1]) {
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[0] << " X " << A_shape[1]
<< "]";
@ -91,10 +86,8 @@ bool EighCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
// is the Matrix a symmetric matrix(0, all, general matxi, -1 lower triangle, 1 upper triangle)
auto symmetric_type = reinterpret_cast<bool *>(inputs[1]->addr);
// whether to use the lower (true) or upper (false) triangle of the symmetric matrix
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
auto output_v_addr = reinterpret_cast<T *>(outputs[1]->addr);
Map<MatrixSquare<T>> A(A_addr, m_, m_);
@ -102,19 +95,19 @@ bool EighCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::
Map<MatrixSquare<T>> output(output_addr, m_, 1);
Map<MatrixSquare<T>> outputv(output_v_addr, m_, m_);
// selfadjoint matrix
if (*symmetric_type) {
if (lower_) {
A_ = A.template selfadjointView<Lower>();
} else {
A_ = A.template selfadjointView<Upper>();
}
// Real scalar eigen solver
if constexpr (std::is_same_v<T, float>) {
SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors);
SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors_);
} else if constexpr (std::is_same_v<T, double>) {
SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors);
SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors_);
} else {
// complex eigen solver
SolveComplexMatrix(A_, &output, &outputv, compute_eigen_vectors);
SolveComplexMatrix(A_, &output, &outputv, compute_eigen_vectors_);
}
return true;
}
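
For reference, a hedged numpy analogue of what the lower_ branch of Launch computes (illustrative only; eigh_lower_reference is not the kernel's API, and a real diagonal is assumed, as the tests construct):

import numpy as np

def eigh_lower_reference(a):
    # selfadjointView<Lower> mirrors the lower triangle into a Hermitian matrix.
    sym = np.tril(a) + np.tril(a, -1).conj().T
    # eigenvalues in ascending order, eigenvectors in the columns of v.
    return np.linalg.eigh(sym)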

View File

@ -46,36 +46,29 @@ class EighCPUKernel : public CPUKernel {
private:
size_t m_{1};
bool compute_eigen_vectors{false};
bool compute_eigen_vectors_{false};
bool lower_{true};
TypeId dtype_{kNumberTypeFloat32};
};
MS_REG_CPU_KERNEL_T(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
EighCPUKernel, float);
MS_REG_CPU_KERNEL_T(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
EighCPUKernel, double);
MS_REG_CPU_KERNEL_T(
Eigh,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EighCPUKernel, float);
MS_REG_CPU_KERNEL_T(
Eigh,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
EighCPUKernel, double);
MS_REG_CPU_KERNEL_T(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighCPUKernel, float_complex);
MS_REG_CPU_KERNEL_T(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighCPUKernel, double_complex);

View File

@ -18,6 +18,10 @@
#include "transpose_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "utils/complex.h"
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
__global__ void Transpose(const size_t size, const T *input, const size_t *input_shape, const size_t *input_axis,
@ -74,3 +78,9 @@ template void CalTranspose<int>(const size_t size, const int *input, const size_
template void CalTranspose<int64_t>(const size_t size, const int64_t *input, const size_t *input_shape,
const size_t *input_axis, const size_t shape_size, int64_t *output,
cudaStream_t cuda_stream);
template void CalTranspose<Complex<float>>(const size_t size, const Complex<float> *input, const size_t *input_shape,
const size_t *input_axis, const size_t shape_size, Complex<float> *output,
cudaStream_t cuda_stream);
template void CalTranspose<Complex<double>>(const size_t size, const Complex<double> *input, const size_t *input_shape,
const size_t *input_axis, const size_t shape_size, Complex<double> *output,
cudaStream_t cuda_stream);

View File

@ -21,14 +21,12 @@ namespace kernel {
MS_REG_GPU_KERNEL_ONE(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighcGpuKernel, Complex<float>)
MS_REG_GPU_KERNEL_ONE(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighcGpuKernel, Complex<double>);

View File

@ -32,10 +32,12 @@
#include "utils/convert_utils.h"
#include "utils/complex.h"
#include "backend/kernel_compiler/gpu/cuda_impl/real_to_complex_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors";
constexpr char LOWER[] = "lower";
template <typename T>
using Complex = mindspore::utils::Complex<T>;
@ -61,6 +63,7 @@ class EighcGpuKernel : public GpuKernel {
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
compute_eigen_vectors_ = static_cast<bool>(GetAttr<bool>(kernel_node, C_EIEH_VECTOR));
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, LOWER));
if (compute_eigen_vectors_) {
jobz_ = CUSOLVER_EIG_MODE_VECTOR;
} else {
@ -84,13 +87,7 @@ class EighcGpuKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
// matrix A, input or output(eigenvector)
auto inout_A_addr = GetDeviceAddress<T>(inputs, 0);
auto lower = GetDeviceAddress<bool>(inputs, 1);
bool h_lower{true};
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&h_lower, lower, sizeof(bool), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"copy lower do device failed");
if (h_lower) {
if (lower_) {
uplo_ = CUBLAS_FILL_MODE_LOWER;
} else {
uplo_ = CUBLAS_FILL_MODE_UPPER;
@ -105,24 +102,39 @@ class EighcGpuKernel : public GpuKernel {
// temp output eigenvalues real scalar
auto w_w_addr = GetDeviceAddress<D>(workspace, 0);
auto w_w_c_addr = GetDeviceAddress<T>(workspace, 1);
// temp eigenvector before transpose
auto w_v_addr = GetDeviceAddress<T>(workspace, 2);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(output_v_addr, inout_A_addr, m_ * m_ * sizeof(T),
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"copy input matrix failed");
size_t input_shape[kShape2dDims] = {m_, m_};
size_t input_axis[kShape2dDims] = {1, 0};
size_t *dev_input_shape = nullptr;
cudaMalloc(reinterpret_cast<void **>(&dev_input_shape), kShape2dDims * sizeof(size_t));
size_t *dev_input_axis = nullptr;
cudaMalloc(reinterpret_cast<void **>(&dev_input_axis), kShape2dDims * sizeof(size_t));
cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalTranspose(m_ * m_, output_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, w_v_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
int lwork = 0;
if constexpr (std::is_same_v<T, Complex<float>>) {
cusolverDnCheevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(output_v_addr),
lda_, w_w_addr, &lwork);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork),
"cal eigenvalues workspace failed");
cusolverDnCheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(output_v_addr), lda_, w_w_addr,
cusolverDnCheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(w_v_addr), lda_, w_w_addr,
reinterpret_cast<cuComplex *>(d_work), lwork, devInfo);
} else {
cusolverDnZheevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_,
reinterpret_cast<cuDoubleComplex *>(output_v_addr), lda_, w_w_addr, &lwork);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork),
"cal eigenvalues workspace failed");
cusolverDnZheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuDoubleComplex *>(output_v_addr), lda_,
cusolverDnZheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuDoubleComplex *>(w_v_addr), lda_,
w_w_addr, reinterpret_cast<cuDoubleComplex *>(d_work), lwork, devInfo);
}
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
@ -131,6 +143,8 @@ class EighcGpuKernel : public GpuKernel {
"copy eigenvalue from workspace to host failed");
RealToComplex(m_, reinterpret_cast<D *>(w_w_c_addr), reinterpret_cast<D *>(output_w_addr),
reinterpret_cast<cudaStream_t>(stream_ptr));
CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
// convert real scalar to complex
if (d_work) {
cudaFree(d_work);
@ -153,8 +167,6 @@ class EighcGpuKernel : public GpuKernel {
void InitSizeLists() override {
// in/out matrix, eigenvector
input_size_list_.push_back(m_ * m_ * sizeof(T));
// uplo
input_size_list_.push_back(sizeof(bool));
// eigenvalues: cuda outputs real scalars that must be converted to complex<float32/64>
output_size_list_.push_back(m_ * sizeof(T));
output_size_list_.push_back(m_ * m_ * sizeof(T));
@ -162,6 +174,7 @@ class EighcGpuKernel : public GpuKernel {
workspace_size_list_.push_back(m_ * sizeof(D));
// temp pre-transpose complex matrix
workspace_size_list_.push_back(m_ * sizeof(T));
workspace_size_list_.push_back(m_ * m_ * sizeof(T));
}
size_t m_{1};
@ -171,6 +184,7 @@ class EighcGpuKernel : public GpuKernel {
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
cusolverEigMode_t jobz_ = CUSOLVER_EIG_MODE_NOVECTOR;
bool compute_eigen_vectors_{false};
bool lower_{true};
std::vector<T *> h_array_{};
std::vector<size_t> input_size_list_{};
std::vector<size_t> output_size_list_{};
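
A hedged numpy sketch (illustrative only, not the kernel's API) of why this complex kernel transposes before cusolverDnCheevd/cusolverDnZheevd and transposes the eigenvectors back afterwards: cuSOLVER reads and writes column-major buffers, while MindSpore tensors are row-major.

import numpy as np

a = np.array([[2, 1 - 1j], [1 + 1j, 3]], dtype=np.complex128)    # logical row-major input
pre = a.T.copy()                                                 # CalTranspose before the solve
seen = pre.ravel(order='C').reshape(a.shape, order='F')
assert np.allclose(seen, a)                                      # the column-major solver now sees `a`

w, v = np.linalg.eigh(seen)                                      # stand-in for cusolverDnZheevd
colmajor_buf = v.ravel(order='F')                                # eigenvectors as cuSOLVER leaves them
v_rowmajor = colmajor_buf.reshape(a.shape, order='C').T          # CalTranspose after the solve
assert np.allclose(a @ v_rowmajor, v_rowmajor @ np.diag(w))      # output columns are eigenvectors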

View File

@ -18,19 +18,13 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
EighGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
EighGpuKernel, double);
MS_REG_GPU_KERNEL_ONE(
Eigh,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EighGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Eigh,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
EighGpuKernel, double);
} // namespace kernel
} // namespace mindspore

View File

@ -30,10 +30,12 @@
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/convert_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors";
constexpr char LOWER[] = "lower";
template <typename T>
class EighGpuKernel : public GpuKernel {
public:
@ -47,6 +49,7 @@ class EighGpuKernel : public GpuKernel {
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
compute_eigen_vectors_ = static_cast<bool>(GetAttr<bool>(kernel_node, C_EIEH_VECTOR));
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, LOWER));
if (compute_eigen_vectors_) {
jobz_ = CUSOLVER_EIG_MODE_VECTOR;
} else {
@ -69,26 +72,23 @@ class EighGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
// matrix A, input or output(eigenvector)
auto inout_A_addr = GetDeviceAddress<T>(inputs, 0);
auto lower = GetDeviceAddress<bool>(inputs, 1);
bool h_lower{true};
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&h_lower, lower, sizeof(bool), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"copy to lower to device failed");
if (h_lower) {
uplo_ = CUBLAS_FILL_MODE_LOWER;
} else {
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
// Note: cuSOLVER assumes column-major storage while MindSpore tensors are row-major, so the buffer
// it reads is the transpose of the logical matrix and the user's lower triangle lands in the
// upper-triangle positions (and vice versa). For a real symmetric matrix the mirrored triangle holds
// the same values, so flipping uplo_ is enough here; the complex (Hermitian) kernel transposes explicitly.
if (lower_) {
uplo_ = CUBLAS_FILL_MODE_UPPER;
} else {
uplo_ = CUBLAS_FILL_MODE_LOWER;
}
auto output_addr = GetDeviceAddress<T>(outputs, 0); // output eigenvalues
auto output_v_addr = GetDeviceAddress<T>(outputs, 1); // output eigenvalues
auto output_addr = GetDeviceAddress<T>(outputs, kDim0); // output eigenvalues
auto output_v_addr = GetDeviceAddress<T>(outputs, kDim1);  // output eigenvectors
auto w_v_addr = GetDeviceAddress<T>(workspace, kDim0); // temp eigenvector before transpose
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(output_v_addr, inout_A_addr, m_ * m_ * sizeof(T),
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
cudaMemcpyAsync(w_v_addr, inout_A_addr, m_ * m_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"copy to input matrix failed");
size_t lda_ = m_;
int lwork = 0;
if constexpr (std::is_same_v<T, float>) {
cusolverDnSsyevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_, inout_A_addr, lda_, output_addr, &lwork);
@ -100,10 +100,22 @@ class EighGpuKernel : public GpuKernel {
T *d_work = nullptr;
cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork);
if constexpr (std::is_same_v<T, float>) {
cusolverDnSsyevd(cusolver_handle_, jobz_, uplo_, m_, output_v_addr, lda_, output_addr, d_work, lwork, devInfo);
cusolverDnSsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, d_work, lwork, devInfo);
} else if constexpr (std::is_same_v<T, double>) {
cusolverDnDsyevd(cusolver_handle_, jobz_, uplo_, m_, output_v_addr, lda_, output_addr, d_work, lwork, devInfo);
cusolverDnDsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, d_work, lwork, devInfo);
}
size_t input_shape[kShape2dDims] = {m_, m_};
size_t input_axis[kShape2dDims] = {1, 0};
size_t *dev_input_shape = nullptr;
cudaMalloc(reinterpret_cast<void **>(&dev_input_shape), kShape2dDims * sizeof(size_t));
size_t *dev_input_axis = nullptr;
cudaMalloc(reinterpret_cast<void **>(&dev_input_axis), kShape2dDims * sizeof(size_t));
cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (d_work) {
cudaFree(d_work);
}
@ -125,12 +137,11 @@ class EighGpuKernel : public GpuKernel {
void InitSizeLists() override {
// in/out matrix, eigenvector
input_size_list_.push_back(m_ * m_ * sizeof(T));
// uplo
input_size_list_.push_back(sizeof(bool));
// eigenvalues
output_size_list_.push_back(m_ * sizeof(T));
// eigenvector
output_size_list_.push_back(m_ * m_ * sizeof(T));
workspace_size_list_.push_back(m_ * m_ * sizeof(T));
}
size_t m_{1};
@ -139,6 +150,7 @@ class EighGpuKernel : public GpuKernel {
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
cusolverEigMode_t jobz_ = CUSOLVER_EIG_MODE_NOVECTOR;
bool compute_eigen_vectors_{false};
bool lower_{true};
std::vector<T *> h_array_{};
std::vector<size_t> input_size_list_{};
std::vector<size_t> output_size_list_{};
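
A hedged numpy illustration (not part of the commit) of the uplo flip above: read column-major, the row-major buffer holds A^T, so the user's lower triangle occupies the upper-triangle positions; for a real symmetric problem the mirrored data is identical, so flipping CUBLAS_FILL_MODE_* is sufficient.

import numpy as np

a = np.arange(16.0).reshape(4, 4)
as_seen_by_cusolver = a.ravel(order='C').reshape(4, 4, order='F')   # == a.T
assert np.array_equal(np.triu(as_seen_by_cusolver), np.tril(a).T)   # lower triangle of `a` becomes the upper triangle seen by cuSOLVER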

View File

@ -18,9 +18,10 @@ from .. import ops
from .ops import SolveTriangular
from .ops import CholeskySolver
from .ops import Cholesky
from .ops import EighNet
from ..ops import operations as P
__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve']
__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve', 'eigh']
def block_diag(*arrs):
@ -318,3 +319,84 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
cholesky_solver_net = CholeskySolver(lower=lower)
x = cholesky_solver_net(c, b)
return x
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
overwrite_b=False, turbo=True, eigvals=None, _type=1,
check_finite=True):
"""
Solve a standard or generalized eigenvalue problem for a complex
Hermitian or real symmetric matrix.
Find eigenvalues Tensor ``w`` and optionally eigenvectors Tensor ``v`` of
Tensor ``a``, where ``b`` is positive definite such that for every
eigenvalue λ (i-th entry of w) and its eigenvector ``vi`` (i-th column of
``v``) satisfies::
a @ vi = λ * b @ vi
vi.conj().T @ a @ vi = λ
vi.conj().T @ b @ vi = 1
In the standard problem, ``b`` is assumed to be the identity matrix.
Args:
a (Tensor): (M, M) Tensor
A complex Hermitian or real symmetric matrix whose eigenvalues and
eigenvectors will be computed.
b (Tensor, optional): (M, M) Tensor
A complex Hermitian or real symmetric positive definite matrix.
If omitted, identity matrix is assumed.
lower (bool, optional): Whether the pertinent Tensor data is taken from
the lower or upper triangle of ``a`` and, if applicable, ``b``. (Default: lower)
eigvals_only (bool, optional): Whether to calculate only eigenvalues
and no eigenvectors. (Default: both are calculated)
_type (int, optional): For the generalized problems, this keyword specifies
the problem type to be solved for ``w`` and ``v`` (only takes 1, 2, 3 as possible
inputs)::
1 => a @ v = w @ b @ v
2 => a @ b @ v = w @ v
3 => b @ a @ v = w @ v
This keyword is ignored for standard problems.
overwrite_a (bool, optional): Whether to overwrite data in ``a``
(may improve performance). Default is False.
overwrite_b (bool, optional): Whether to overwrite data in ``b``
(may improve performance). Default is False.
check_finite (bool, optional): Whether to check that the input matrices
contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
turbo (bool, optional): use divide and conquer algorithm (faster but
expensive in memory, only for generalized eigenvalue problem and
if full set of eigenvalues are requested.). Has no significant
effect if eigenvectors are not requested.
eigvals (tuple, optional): Indexes of the smallest and largest (in ascending order)
eigenvalues and corresponding eigenvectors to be returned: 0 <= lo <= hi <= M-1.
If omitted, all eigenvalues and eigenvectors are returned.
Returns:
w (Tensor): (N,) Tensor, The N (1<=N<=M) selected eigenvalues, in ascending order,
each repeated according to its multiplicity.
v (Tensor): (M, N) Tensor, (if ``eigvals_only == False``)
Raises:
LinAlgError: If eigenvalue computation does not converge, an error occurred, or
b matrix is not positive definite. Note that if input matrices are
not symmetric or Hermitian, no error will be reported but results will
be wrong.
Supported Platforms:
``CPU`` ``GPU``
Examples:
>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> from mindspore.scipy.linalg import eigh
>>> A = Tensor(onp.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]))
>>> w, v = eigh(A)
>>> onp.allclose(A @ v - v @ onp.diag(w), onp.zeros((4, 4)))
True
"""
eigh_net = EighNet(not eigvals_only, lower=lower)
return eigh_net(a)
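
A hedged usage sketch of the remaining code paths (illustrative only; reuses the A from the docstring example):

w = eigh(A, eigvals_only=True)    # eigenvalues only, in ascending order
w, v = eigh(A, lower=False)       # take the pertinent data from the upper triangle of A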

View File

@ -194,42 +194,21 @@ class Eigh(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, compute_eigenvectors):
def __init__(self, compute_eigenvectors=True, lower=True):
super().__init__(name="Eigh")
self.init_prim_io_names(inputs=['A', 's'], outputs=['output', 'output_v'])
self.init_prim_io_names(inputs=['A'], outputs=['output_w', 'output_v'])
self.compute_eigenvectors = validator.check_value_type(
"compute_eigenvectors", compute_eigenvectors, [bool], self.name)
self.lower = validator.check_value_type("lower", lower, [bool], self.name)
self.add_prim_attr('lower', self.lower)
self.add_prim_attr('compute_eigenvectors', self.compute_eigenvectors)
def __infer__(self, A, s):
def __infer__(self, A):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (A['dtype'], A['dtype']),
'value': None
}
if A['dtype'] == mstype.tensor_type(mstype.float32):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (mstype.float32, mstype.float32),
'value': None
}
elif A['dtype'] == mstype.tensor_type(mstype.float64):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (mstype.float64, mstype.float64),
'value': None
}
elif A['dtype'] == mstype.tensor_type(mstype.complex64):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (A['dtype'], A['dtype']),
'value': None
}
elif A['dtype'] == mstype.tensor_type(mstype.complex128):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (mstype.complex128, mstype.complex128),
'value': None
}
return shape
@ -238,16 +217,17 @@ class EighNet(nn.Cell):
Eigenvalue/eigenvector solver for a symmetric/Hermitian matrix
Ax = lambda * x
"""
def __init__(self, b):
super(EighNet, self).__init__()
self.b = b
self.eigh = Eigh(b)
def construct(self, A, s=True):
r = self.eigh(A, s)
if self.b:
def __init__(self, bv=True, lower=True):
super(EighNet, self).__init__()
self.bv = bv
self.eigh = Eigh(bv, lower)
def construct(self, A):
r = self.eigh(A)
if self.bv:
return (r[0], r[1])
return (r[0],)
return r[0]
class Eig(PrimitiveWithInfer):
@ -257,7 +237,7 @@ class Eig(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, compute_eigenvectors):
def __init__(self, compute_eigenvectors=True):
super().__init__(name="Eig")
self.init_prim_io_names(inputs=['A'], outputs=['output', 'output_v'])
self.compute_eigenvectors = validator.check_value_type(
@ -285,13 +265,14 @@ class EigNet(nn.Cell):
Eigenvalue/eigenvector solver for a generic matrix
Ax = lambda * x
"""
def __init__(self, b):
def __init__(self, bv=True):
super(EigNet, self).__init__()
self.b = b
self.eig = Eig(b)
self.bv = bv
self.eig = Eig(bv)
def construct(self, A):
r = self.eig(A)
if self.b:
if self.bv:
return (r[0], r[1])
return (r[0],)
return r[0]
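
A hedged usage sketch of the refactored nets, mirroring how the test cases below drive them; the import path is an assumption based on this repository's layout.

import numpy as np
from mindspore.common import Tensor
from mindspore.scipy.ops import EighNet   # assumed module path

a = Tensor(np.random.rand(4, 4).astype(np.float32))
w, v = EighNet(True, lower=True)(a)       # eigenvalues and eigenvectors from the lower triangle
w_only = EighNet(False, lower=True)(a)    # eigenvalues only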

View File

@ -31,11 +31,6 @@ def match(v, v_, error=0):
np.testing.assert_equal(v, v_)
def create_sym_pos_matrix(m, n, dtype):
a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype)
return np.dot(a, a.T)
@pytest.mark.parametrize('n', [4, 6, 9, 10])
@pytest.mark.platform_x86_cpu
def test_eig_net(n: int):
@ -48,13 +43,13 @@ def test_eig_net(n: int):
rtol = 1e-3
atol = 1e-4
msp_eig = EigNet(True)
A = create_sym_pos_matrix(n, n, np.float32)
A = np.array(np.random.rand(n, n), dtype=np.float32)
tensor_a = Tensor(np.array(A).astype(np.float32))
msp_w, msp_v = msp_eig(tensor_a)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
# test case for real scalar double 64
A = np.random.rand(n, n)
A = np.array(np.random.rand(n, n), dtype=np.float64)
rtol = 1e-5
atol = 1e-8
msp_eig = EigNet(True)
@ -98,6 +93,7 @@ def test_eig_net(n: int):
# Compare with scipy, scipy passed
# sp_w, sp_v = sp.linalg.eig(A.astype(np.complex128))
# assert np.allclose(A @ sp_v - sp_v @ np.diag(sp_w), np.zeros((n, n)), rtol, atol)
# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
msp_eig = EigNet(False)
msp_w0 = msp_eig(Tensor(np.array(A).astype(np.complex128)))
assert np.allclose(msp_w0.asnumpy() - msp_w.asnumpy(), np.zeros((n, n)), rtol, atol)

View File

@ -47,10 +47,12 @@ def test_eigh_net(n: int):
# test for real scalar float 32
rtol = 1e-3
atol = 1e-4
msp_eigh = EighNet(True)
A = create_sym_pos_matrix(n, n, np.float32)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)), False)
msp_eigh = EighNet(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)))
msp_eigh = EighNet(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)))
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T)
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
@ -62,19 +64,23 @@ def test_eigh_net(n: int):
A = np.random.rand(n, n)
rtol = 1e-5
atol = 1e-8
msp_eigh = EighNet(True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), False)
# Compare with scipy
# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.float64), lower=True, eigvals_only=False)
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.float64), lower=False, eigvals_only=False)
msp_eigh = EighNet(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)))
msp_eigh = EighNet(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)))
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T)
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
# test for real scalar float64 no vector
msp_eigh = EighNet(False, True)
msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.float64)))
msp_eigh = EighNet(False, False)
msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.float64)))
assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)
# test case for complex64
rtol = 1e-3
@ -86,18 +92,12 @@ def test_eigh_net(n: int):
A[i][j] = complex(np.random.rand(1, 1), 0)
else:
A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
msp_eigh = EighNet(True)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)), False)
# Compare with scipy, scipy passed
# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False)
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False)
# assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol)
# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
msp_eigh = EighNet(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)))
msp_eigh = EighNet(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)))
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
@ -113,19 +113,21 @@ def test_eigh_net(n: int):
A[i][j] = complex(np.random.rand(1, 1), 0)
else:
A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
msp_eigh = EighNet(True)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)), False)
# Compare with scipy, scipy passed
# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False)
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False)
# assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol)
# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
msp_eigh = EighNet(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
msp_eigh = EighNet(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
# test for complex128 no vector
msp_eigh = EighNet(False, True)
msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
msp_eigh = EighNet(False, False)
msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)
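
A hedged numpy check (illustrative only) of the symmetrization expression used throughout these tests: it mirrors A's lower triangle into a Hermitian matrix and, assuming the real diagonal the tests construct, is equivalent to the simpler form below.

import numpy as np

A = np.random.rand(5, 5) + 1j * np.random.rand(5, 5)
np.fill_diagonal(A, np.random.rand(5))                # real diagonal, as the tests build it
sym_Al = np.tril(np.tril(A) - np.tril(A).T) + np.tril(A).conj().T
assert np.allclose(sym_Al, np.tril(A, -1) + np.tril(A).conj().T)
assert np.allclose(sym_Al, sym_Al.conj().T)           # Hermitian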

View File

@ -47,26 +47,35 @@ def test_eigh_net(n: int):
# test for real scalar float 32
rtol = 1e-3
atol = 1e-4
msp_eigh = EighNet(True)
A = create_sym_pos_matrix(n, n, np.float32)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)), False)
assert np.allclose(A @ msp_vl.T.asnumpy() - msp_vl.T.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
msp_eigh = EighNet(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)))
msp_eigh = EighNet(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)))
assert np.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(A @ msp_vu.T.asnumpy() - msp_vu.T.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
assert np.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
# test case for real scalar double 64
A = create_sym_pos_matrix(n, n, np.float64)
rtol = 1e-5
atol = 1e-8
msp_eigh = EighNet(True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), False)
assert np.allclose(A @ msp_vl.T.asnumpy() - msp_vl.T.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
msp_eigh = EighNet(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)))
msp_eigh = EighNet(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)))
assert np.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(A @ msp_vu.T.asnumpy() - msp_vu.T.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
assert np.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
# test for real scalar float64 no vector
msp_eigh = EighNet(False, True)
msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.float64)))
msp_eigh = EighNet(False, False)
msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.float64)))
assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)
# test case for complex64
rtol = 1e-3
@ -78,14 +87,15 @@ def test_eigh_net(n: int):
A[i][j] = complex(np.random.rand(1, 1), 0)
else:
A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
msp_eigh = EighNet(True)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(sym_Al).astype(np.complex64)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(sym_Au).astype(np.complex64)), False)
assert np.allclose(sym_Al @ msp_vl.asnumpy().conj().T - msp_vl.asnumpy().conj().T @ np.diag(msp_wl.asnumpy()),
msp_eigh = EighNet(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)))
msp_eigh = EighNet(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)))
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()),
np.zeros((n, n)), rtol, atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy().conj().T - msp_vu.asnumpy().conj().T @ np.diag(msp_wu.asnumpy()),
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()),
np.zeros((n, n)), rtol, atol)
# test for complex128
@ -94,17 +104,24 @@ def test_eigh_net(n: int):
A = np.array(np.random.rand(n, n), dtype=np.complex128)
for i in range(0, n):
for j in range(0, n):
if i == j:
A[i][j] = complex(np.random.rand(1, 1), 0)
else:
A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
msp_eigh = EighNet(True)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(sym_Al).astype(np.complex128)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(sym_Au).astype(np.complex128)), False)
assert np.allclose(sym_Al @ msp_vl.asnumpy().conj().T - msp_vl.asnumpy().conj().T @ np.diag(msp_wl.asnumpy()),
msp_eigh = EighNet(True, True)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
msp_eigh = EighNet(True, False)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()),
np.zeros((n, n)), rtol, atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy().conj().T - msp_vu.asnumpy().conj().T @ np.diag(msp_wu.asnumpy()),
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()),
np.zeros((n, n)), rtol, atol)
# test for complex128 no vector
msp_eigh = EighNet(False, True)
msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
msp_eigh = EighNet(False, False)
msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)

View File

@ -139,3 +139,91 @@ def test_cholesky_solver(n: int, lower: bool, dtype):
# pre tensor_a has been inplace.
tensor_a = Tensor(a)
assert onp.allclose(onp.dot(a, osp_x), mnp.dot(tensor_a, msp_x).asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 6, 9, 20])
def test_eigh_solver(n: int):
"""
Feature: ALL TO ALL
Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N]
Expectation: the result matches scipy.linalg.eigh
"""
# test for real scalar float 32
rtol = 1e-3
atol = 1e-4
A = create_sym_pos_matrix([n, n], onp.float32)
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=False, eigvals_only=False)
assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
rtol,
atol)
assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)),
rtol,
atol)
# test case for real scalar double 64
A = create_sym_pos_matrix([n, n], onp.float64)
rtol = 1e-5
atol = 1e-8
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=False)
assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
rtol,
atol)
assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)),
rtol,
atol)
# test for real scalar float64 no vector
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=True)
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=True)
assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
# test case for complex64
rtol = 1e-3
atol = 1e-4
A = onp.array(onp.random.rand(n, n), dtype=onp.complex64)
for i in range(0, n):
for j in range(0, n):
if i == j:
A[i][j] = complex(onp.random.rand(1, 1), 0)
else:
A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1))
sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T)
sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T)
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex64)), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex64)), lower=False, eigvals_only=False)
assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()),
onp.zeros((n, n)), rtol, atol)
assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()),
onp.zeros((n, n)), rtol, atol)
# test for complex128
rtol = 1e-5
atol = 1e-8
A = onp.array(onp.random.rand(n, n), dtype=onp.complex128)
for i in range(0, n):
for j in range(0, n):
if i == j:
A[i][j] = complex(onp.random.rand(1, 1), 0)
else:
A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1))
sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T)
sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T)
msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=False)
msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=False)
assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()),
onp.zeros((n, n)), rtol, atol)
assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()),
onp.zeros((n, n)), rtol, atol)
# test for complex128 no vector
msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=True)
msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=True)
assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
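
A hedged sketch of the scipy cross-check that the commented-out comparisons in the earlier tests allude to (illustrative only; the aliases follow the test files and the mindspore.scipy import path is an assumption):

import numpy as onp
import scipy.linalg as osp_linalg
from mindspore.common import Tensor
import mindspore.scipy as msp   # assumed module path

a = onp.random.rand(6, 6)
a = onp.tril(a, -1) + onp.tril(a).T                               # symmetric, built from the lower triangle
sp_w = osp_linalg.eigh(a, lower=True, eigvals_only=True)
ms_w = msp.linalg.eigh(Tensor(a), lower=True, eigvals_only=True)
assert onp.allclose(sp_w, ms_w.asnumpy(), 1e-5, 1e-8)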