forked from mindspore-Ecosystem/mindspore
!26546 Unify GPU/CPU ops input/output (col/row major), modify related test cases, add linalg function and test cases
Merge pull request !26546 from wuwenbing/master
This commit is contained in: commit 69c4f470e4
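For context, the user-visible effect of this PR is that the symmetric/Hermitian eigensolver now takes a single matrix input (the old boolean lower input becomes the lower attribute of the Eigh primitive) and is exposed as mindspore.scipy.linalg.eigh. A minimal usage sketch, assuming a build that contains this change; the matrix is the one from the docstring example added further down:

import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.linalg import eigh

# Symmetric 4x4 matrix; eigh returns eigenvalues w and eigenvectors v (column-wise).
a = onp.array([[6., 3., 1., 5.], [3., 0., 5., 1.], [1., 5., 6., 2.], [5., 1., 2., 2.]], dtype=onp.float64)
w, v = eigh(Tensor(a))
# A @ v should equal v @ diag(w) up to numerical tolerance.
print(onp.allclose(a @ v.asnumpy() - v.asnumpy() @ onp.diag(w.asnumpy()), onp.zeros((4, 4))))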
@@ -24,8 +24,6 @@ namespace kernel {
 namespace {
 constexpr size_t kInputsNum = 1;
 constexpr size_t kOutputsNum = 2;
-constexpr size_t kDefaultShape = 1;
-constexpr auto kAMatrixDimNum = 2;
 
 } // namespace
 using Eigen::Dynamic;
@@ -45,12 +43,8 @@ using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, Row
 template <typename T, typename C>
 void EigCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
 dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
-
 compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
-
 auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node));
-
 if (A_shape.size() != kShape2dDims || A_shape[0] != A_shape[1]) {
 MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[0] << " X " << A_shape[1]
 << "]";
@@ -22,10 +22,8 @@ namespace mindspore {
 namespace kernel {
 
 namespace {
-constexpr size_t kInputsNum = 2;
+constexpr size_t kInputsNum = 1;
 constexpr size_t kOutputsNum = 2;
-constexpr size_t kDefaultShape = 1;
-constexpr auto kAMatrixDimNum = 2;
 
 } // namespace
 using Eigen::Dynamic;
@@ -45,12 +43,9 @@ using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, Row
 template <typename T>
 void EighCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
 dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
-
-compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
-
+compute_eigen_vectors_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
+lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
 auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node));
-
 if (A_shape.size() != kShape2dDims || A_shape[0] != A_shape[1]) {
 MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[0] << " X " << A_shape[1]
 << "]";
@@ -91,10 +86,8 @@ bool EighCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::
 const std::vector<AddressPtr> &outputs) {
 CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
 CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
-
 auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
-// is the Matrix a symmetric matrix(0, all, general matxi, -1 lower triangle, 1 upper triangle)
+// is the Matrix a symmetric matrix(true lower triangle, false upper triangle)
-auto symmetric_type = reinterpret_cast<bool *>(inputs[1]->addr);
 auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
 auto output_v_addr = reinterpret_cast<T *>(outputs[1]->addr);
 Map<MatrixSquare<T>> A(A_addr, m_, m_);
@@ -102,19 +95,19 @@ bool EighCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::
 Map<MatrixSquare<T>> output(output_addr, m_, 1);
 Map<MatrixSquare<T>> outputv(output_v_addr, m_, m_);
 // selfadjoint matrix
-if (*symmetric_type) {
+if (lower_) {
 A_ = A.template selfadjointView<Lower>();
 } else {
 A_ = A.template selfadjointView<Upper>();
 }
 // Real scalar eigen solver
 if constexpr (std::is_same_v<T, float>) {
-SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors);
+SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors_);
 } else if constexpr (std::is_same_v<T, double>) {
-SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors);
+SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors_);
 } else {
 // complex eigen solver
-SolveComplexMatrix(A_, &output, &outputv, compute_eigen_vectors);
+SolveComplexMatrix(A_, &output, &outputv, compute_eigen_vectors_);
 }
 return true;
 }
@@ -46,36 +46,29 @@ class EighCPUKernel : public CPUKernel {
 
 private:
 size_t m_{1};
-bool compute_eigen_vectors{false};
+bool compute_eigen_vectors_{false};
+bool lower_{true};
 TypeId dtype_{kNumberTypeFloat32};
 };
 
-MS_REG_CPU_KERNEL_T(Eigh,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeBool)
-.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32),
-EighCPUKernel, float);
-MS_REG_CPU_KERNEL_T(Eigh,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat64)
-.AddInputAttr(kNumberTypeBool)
-.AddOutputAttr(kNumberTypeFloat64)
-.AddOutputAttr(kNumberTypeFloat64),
-EighCPUKernel, double);
+MS_REG_CPU_KERNEL_T(
+Eigh,
+KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+EighCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
+Eigh,
+KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+EighCPUKernel, double);
 
 MS_REG_CPU_KERNEL_T(Eigh,
 KernelAttr()
 .AddInputAttr(kNumberTypeComplex64)
-.AddInputAttr(kNumberTypeBool)
 .AddOutputAttr(kNumberTypeComplex64)
 .AddOutputAttr(kNumberTypeComplex64),
 EighCPUKernel, float_complex);
 MS_REG_CPU_KERNEL_T(Eigh,
 KernelAttr()
 .AddInputAttr(kNumberTypeComplex128)
-.AddInputAttr(kNumberTypeBool)
 .AddOutputAttr(kNumberTypeComplex128)
 .AddOutputAttr(kNumberTypeComplex128),
 EighCPUKernel, double_complex);
@@ -18,6 +18,10 @@
 
 #include "transpose_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
+#include "utils/complex.h"
 
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
+
 template <typename T>
 __global__ void Transpose(const size_t size, const T *input, const size_t *input_shape, const size_t *input_axis,
@@ -74,3 +78,9 @@ template void CalTranspose<int>(const size_t size, const int *input, const size_
 template void CalTranspose<int64_t>(const size_t size, const int64_t *input, const size_t *input_shape,
 const size_t *input_axis, const size_t shape_size, int64_t *output,
 cudaStream_t cuda_stream);
+template void CalTranspose<Complex<float>>(const size_t size, const Complex<float> *input, const size_t *input_shape,
+const size_t *input_axis, const size_t shape_size, Complex<float> *output,
+cudaStream_t cuda_stream);
+template void CalTranspose<Complex<double>>(const size_t size, const Complex<double> *input, const size_t *input_shape,
+const size_t *input_axis, const size_t shape_size, Complex<double> *output,
+cudaStream_t cuda_stream);
@@ -21,14 +21,12 @@ namespace kernel {
 MS_REG_GPU_KERNEL_ONE(Eigh,
 KernelAttr()
 .AddInputAttr(kNumberTypeComplex64)
-.AddInputAttr(kNumberTypeBool)
 .AddOutputAttr(kNumberTypeComplex64)
 .AddOutputAttr(kNumberTypeComplex64),
 EighcGpuKernel, Complex<float>)
 MS_REG_GPU_KERNEL_ONE(Eigh,
 KernelAttr()
 .AddInputAttr(kNumberTypeComplex128)
-.AddInputAttr(kNumberTypeBool)
 .AddOutputAttr(kNumberTypeComplex128)
 .AddOutputAttr(kNumberTypeComplex128),
 EighcGpuKernel, Complex<double>);
@@ -32,10 +32,12 @@
 #include "utils/convert_utils.h"
 #include "utils/complex.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/real_to_complex_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
 
 namespace mindspore {
 namespace kernel {
 constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors";
+constexpr char LOWER[] = "lower";
 template <typename T>
 using Complex = mindspore::utils::Complex<T>;
 
@@ -61,6 +63,7 @@ class EighcGpuKernel : public GpuKernel {
 dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
 auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
 compute_eigen_vectors_ = static_cast<bool>(GetAttr<bool>(kernel_node, C_EIEH_VECTOR));
+lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, LOWER));
 if (compute_eigen_vectors_) {
 jobz_ = CUSOLVER_EIG_MODE_VECTOR;
 } else {
@@ -84,13 +87,7 @@ class EighcGpuKernel : public GpuKernel {
 const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
 // matrix A, input or output(eigenvector)
 auto inout_A_addr = GetDeviceAddress<T>(inputs, 0);
-auto lower = GetDeviceAddress<bool>(inputs, 1);
-bool h_lower{true};
-CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-cudaMemcpyAsync(&h_lower, lower, sizeof(bool), cudaMemcpyDeviceToHost,
-reinterpret_cast<cudaStream_t>(stream_ptr)),
-"copy lower do device failed");
-if (h_lower) {
+if (lower_) {
 uplo_ = CUBLAS_FILL_MODE_LOWER;
 } else {
 uplo_ = CUBLAS_FILL_MODE_UPPER;
@@ -105,24 +102,39 @@ class EighcGpuKernel : public GpuKernel {
 // temp output eigenvalues real scalar
 auto w_w_addr = GetDeviceAddress<D>(workspace, 0);
 auto w_w_c_addr = GetDeviceAddress<T>(workspace, 1);
+// temp eigenvector before transpose
+auto w_v_addr = GetDeviceAddress<T>(workspace, 2);
 CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
 cudaMemcpyAsync(output_v_addr, inout_A_addr, m_ * m_ * sizeof(T),
 cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
 "copy input matrix failed");
+size_t input_shape[kShape2dDims] = {m_, m_};
+size_t input_axis[kShape2dDims] = {1, 0};
+size_t *dev_input_shape = nullptr;
+cudaMalloc(reinterpret_cast<void **>(&dev_input_shape), kShape2dDims * sizeof(size_t));
+size_t *dev_input_axis = nullptr;
+cudaMalloc(reinterpret_cast<void **>(&dev_input_axis), kShape2dDims * sizeof(size_t));
+cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
+reinterpret_cast<cudaStream_t>(stream_ptr));
+cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
+reinterpret_cast<cudaStream_t>(stream_ptr));
+CalTranspose(m_ * m_, output_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, w_v_addr,
+reinterpret_cast<cudaStream_t>(stream_ptr));
+
 int lwork = 0;
 if constexpr (std::is_same_v<T, Complex<float>>) {
 cusolverDnCheevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(output_v_addr),
 lda_, w_w_addr, &lwork);
 CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork),
 "cal eigenvalues workspace failed");
-cusolverDnCheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(output_v_addr), lda_, w_w_addr,
+cusolverDnCheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(w_v_addr), lda_, w_w_addr,
 reinterpret_cast<cuComplex *>(d_work), lwork, devInfo);
 } else {
 cusolverDnZheevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_,
 reinterpret_cast<cuDoubleComplex *>(output_v_addr), lda_, w_w_addr, &lwork);
 CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork),
 "cal eigenvalues workspace failed");
-cusolverDnZheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuDoubleComplex *>(output_v_addr), lda_,
+cusolverDnZheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuDoubleComplex *>(w_v_addr), lda_,
 w_w_addr, reinterpret_cast<cuDoubleComplex *>(d_work), lwork, devInfo);
 }
 CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
@@ -131,6 +143,8 @@ class EighcGpuKernel : public GpuKernel {
 "copy eigenvalue from workspace to host failed");
 RealToComplex(m_, reinterpret_cast<D *>(w_w_c_addr), reinterpret_cast<D *>(output_w_addr),
 reinterpret_cast<cudaStream_t>(stream_ptr));
+CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
+reinterpret_cast<cudaStream_t>(stream_ptr));
 // convert real scalar to complex
 if (d_work) {
 cudaFree(d_work);
@@ -153,8 +167,6 @@ class EighcGpuKernel : public GpuKernel {
 void InitSizeLists() override {
 // in/out matrix, eigenvector
 input_size_list_.push_back(m_ * m_ * sizeof(T));
-// uplo
-input_size_list_.push_back(sizeof(bool));
 // eigenvalues, cuda output original real scalar, should covert to complex<ft32/64>
 output_size_list_.push_back(m_ * sizeof(T));
 output_size_list_.push_back(m_ * m_ * sizeof(T));
@@ -162,6 +174,7 @@ class EighcGpuKernel : public GpuKernel {
 workspace_size_list_.push_back(m_ * sizeof(D));
 // for temp pre-transpose complex mitrx
 workspace_size_list_.push_back(m_ * sizeof(T));
+workspace_size_list_.push_back(m_ * m_ * sizeof(T));
 }
 
 size_t m_{1};
@@ -171,6 +184,7 @@ class EighcGpuKernel : public GpuKernel {
 cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
 cusolverEigMode_t jobz_ = CUSOLVER_EIG_MODE_NOVECTOR;
 bool compute_eigen_vectors_{false};
+bool lower_{true};
 std::vector<T *> h_array_{};
 std::vector<size_t> input_size_list_{};
 std::vector<size_t> output_size_list_{};
@@ -18,19 +18,13 @@
 
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(Eigh,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeBool)
-.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32),
-EighGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Eigh,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat64)
-.AddInputAttr(kNumberTypeBool)
-.AddOutputAttr(kNumberTypeFloat64)
-.AddOutputAttr(kNumberTypeFloat64),
-EighGpuKernel, double);
+MS_REG_GPU_KERNEL_ONE(
+Eigh,
+KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+EighGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+Eigh,
+KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+EighGpuKernel, double);
 } // namespace kernel
 } // namespace mindspore
@@ -30,10 +30,12 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
 #include "utils/convert_utils.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
 
 namespace mindspore {
 namespace kernel {
 constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors";
+constexpr char LOWER[] = "lower";
 template <typename T>
 class EighGpuKernel : public GpuKernel {
 public:
@@ -47,6 +49,7 @@ class EighGpuKernel : public GpuKernel {
 dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
 auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
 compute_eigen_vectors_ = static_cast<bool>(GetAttr<bool>(kernel_node, C_EIEH_VECTOR));
+lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, LOWER));
 if (compute_eigen_vectors_) {
 jobz_ = CUSOLVER_EIG_MODE_VECTOR;
 } else {
@@ -69,26 +72,23 @@ class EighGpuKernel : public GpuKernel {
 bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
 const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
 // matrix A, input or output(eigenvector)
-auto inout_A_addr = GetDeviceAddress<T>(inputs, 0);
-auto lower = GetDeviceAddress<bool>(inputs, 1);
-bool h_lower{true};
-CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-cudaMemcpyAsync(&h_lower, lower, sizeof(bool), cudaMemcpyDeviceToHost,
-reinterpret_cast<cudaStream_t>(stream_ptr)),
-"copy to lower to device failed");
-if (h_lower) {
-uplo_ = CUBLAS_FILL_MODE_LOWER;
-} else {
+auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
+// Notice :this is important
+// a col or row major is different to cpu, so a lower triangle is a upper triangle, a upper is a lower in gpu mem
+// so the upper is positive to it from var, but for real scalar matrix, upper eq lower, it's different from complex
+if (lower_) {
 uplo_ = CUBLAS_FILL_MODE_UPPER;
+} else {
+uplo_ = CUBLAS_FILL_MODE_LOWER;
 }
-auto output_addr = GetDeviceAddress<T>(outputs, 0); // output eigenvalues
-auto output_v_addr = GetDeviceAddress<T>(outputs, 1); // output eigenvalues
+auto output_addr = GetDeviceAddress<T>(outputs, kDim0); // output eigenvalues
+auto output_v_addr = GetDeviceAddress<T>(outputs, kDim1); // output eigenvalues
+auto w_v_addr = GetDeviceAddress<T>(workspace, kDim0); // temp eigenvector before transpose
 CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-cudaMemcpyAsync(output_v_addr, inout_A_addr, m_ * m_ * sizeof(T),
-cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+cudaMemcpyAsync(w_v_addr, inout_A_addr, m_ * m_ * sizeof(T), cudaMemcpyDeviceToDevice,
+reinterpret_cast<cudaStream_t>(stream_ptr)),
 "copy to input matrix failed");
 size_t lda_ = m_;
-
 int lwork = 0;
 if constexpr (std::is_same_v<T, float>) {
 cusolverDnSsyevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_, inout_A_addr, lda_, output_addr, &lwork);
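The three comment lines added above are the heart of the row-major/column-major unification: cuSOLVER reads the row-major MindSpore buffer as its column-major transpose, so selecting the lower triangle of the buffer actually hands cuSOLVER the upper triangle, which is why uplo_ is set to the opposite of lower_ here. A small NumPy sketch of the underlying identity (illustrative only, not part of the patch):

import numpy as np

a = np.arange(9).reshape(3, 3)      # a row-major 3x3 matrix
# The lower triangle of A is exactly the transpose of the upper triangle of A^T,
# so an upper-triangle routine applied to the transposed buffer sees A's lower triangle.
print(np.array_equal(np.tril(a), np.triu(a.T).T))   # True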
@@ -100,10 +100,22 @@ class EighGpuKernel : public GpuKernel {
 T *d_work = nullptr;
 cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork);
 if constexpr (std::is_same_v<T, float>) {
-cusolverDnSsyevd(cusolver_handle_, jobz_, uplo_, m_, output_v_addr, lda_, output_addr, d_work, lwork, devInfo);
+cusolverDnSsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, d_work, lwork, devInfo);
 } else if constexpr (std::is_same_v<T, double>) {
-cusolverDnDsyevd(cusolver_handle_, jobz_, uplo_, m_, output_v_addr, lda_, output_addr, d_work, lwork, devInfo);
+cusolverDnDsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, d_work, lwork, devInfo);
 }
+size_t input_shape[kShape2dDims] = {m_, m_};
+size_t input_axis[kShape2dDims] = {1, 0};
+size_t *dev_input_shape = nullptr;
+cudaMalloc(reinterpret_cast<void **>(&dev_input_shape), kShape2dDims * sizeof(size_t));
+size_t *dev_input_axis = nullptr;
+cudaMalloc(reinterpret_cast<void **>(&dev_input_axis), kShape2dDims * sizeof(size_t));
+cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
+reinterpret_cast<cudaStream_t>(stream_ptr));
+cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
+reinterpret_cast<cudaStream_t>(stream_ptr));
+CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
+reinterpret_cast<cudaStream_t>(stream_ptr));
 if (d_work) {
 cudaFree(d_work);
 }
@@ -125,12 +137,11 @@ class EighGpuKernel : public GpuKernel {
 void InitSizeLists() override {
 // in/out matrix, eigenvector
 input_size_list_.push_back(m_ * m_ * sizeof(T));
-// uplo
-input_size_list_.push_back(sizeof(bool));
 // eigenvalues
 output_size_list_.push_back(m_ * sizeof(T));
 // eigenvector
 output_size_list_.push_back(m_ * m_ * sizeof(T));
+workspace_size_list_.push_back(m_ * m_ * sizeof(T));
 }
 
 size_t m_{1};
@@ -139,6 +150,7 @@ class EighGpuKernel : public GpuKernel {
 cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
 cusolverEigMode_t jobz_ = CUSOLVER_EIG_MODE_NOVECTOR;
 bool compute_eigen_vectors_{false};
+bool lower_{true};
 std::vector<T *> h_array_{};
 std::vector<size_t> input_size_list_{};
 std::vector<size_t> output_size_list_{};
@@ -18,9 +18,10 @@ from .. import ops
 from .ops import SolveTriangular
 from .ops import CholeskySolver
 from .ops import Cholesky
+from .ops import EighNet
 from ..ops import operations as P
 
-__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve']
+__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve', 'eigh']
 
 
 def block_diag(*arrs):
@@ -318,3 +319,84 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
 cholesky_solver_net = CholeskySolver(lower=lower)
 x = cholesky_solver_net(c, b)
 return x
+
+
+def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
+overwrite_b=False, turbo=True, eigvals=None, _type=1,
+check_finite=True):
+"""
+Solve a standard or generalized eigenvalue problem for a complex
+Hermitian or real symmetric matrix.
+
+Find eigenvalues Tensor ``w`` and optionally eigenvectors Tensor ``v`` of
+Tensor ``a``, where ``b`` is positive definite such that for every
+eigenvalue λ (i-th entry of w) and its eigenvector ``vi`` (i-th column of
+``v``) satisfies::
+
+a @ vi = λ * b @ vi
+vi.conj().T @ a @ vi = λ
+vi.conj().T @ b @ vi = 1
+
+In the standard problem, ``b`` is assumed to be the identity matrix.
+
+Args:
+a (Tensor): (M, M) Tensor
+A complex Hermitian or real symmetric matrix whose eigenvalues and
+eigenvectors will be computed.
+b (Tensor, optional): (M, M) Tensor
+A complex Hermitian or real symmetric definite positive matrix in.
+If omitted, identity matrix is assumed.
+lower (bool, optional): Whether the pertinent Tensor data is taken from
+the lower or upper triangle of ``a`` and, if applicable, ``b``. (Default: lower)
+eigvals_only (bool, optional): Whether to calculate only eigenvalues
+and no eigenvectors. (Default: both are calculated)
+_type (int, optional): For the generalized problems, this keyword specifies
+the problem type to be solved for ``w`` and ``v`` (only takes 1, 2, 3 as possible
+inputs)::
+
+1 => a @ v = w @ b @ v
+2 => a @ b @ v = w @ v
+3 => b @ a @ v = w @ v
+
+This keyword is ignored for standard problems.
+overwrite_a (bool, optional): Whether to overwrite data in ``a``
+(may improve performance). Default is False.
+overwrite_b (bool, optional): Whether to overwrite data in ``b``
+(may improve performance). Default is False.
+check_finite (bool, optional): Whether to check that the input matrices
+contain only finite numbers.
+Disabling may give a performance gain, but may result in problems
+(crashes, non-termination) if the inputs do contain infinities or NaNs.
+turbo (bool, optional): use divide and conquer algorithm (faster but
+expensive in memory, only for generalized eigenvalue problem and
+if full set of eigenvalues are requested.). Has no significant
+effect if eigenvectors are not requested.
+eigvals (tuple, optional): Indexes of the smallest and largest (in ascending order)
+eigenvalues and corresponding eigenvectors to be returned: 0 <= lo <= hi <= M-1.
+If omitted, all eigenvalues and eigenvectors are returned.
+
+Returns:
+w (Tensor): (N,) Tensor, The N (1<=N<=M) selected eigenvalues, in ascending order,
+each repeated according to its multiplicity.
+v (Tensor): (M, N) Tensor, (if ``eigvals_only == False``)
+
+Raises:
+LinAlgError: If eigenvalue computation does not converge, an error occurred, or
+b matrix is not definite positive. Note that if input matrices are
+not symmetric or Hermitian, no error will be reported but results will
+be wrong.
+
+Supported Platforms:
+``CPU`` ``GPU``
+
+Examples:
+>>> import numpy as onp
+>>> from mindspore.common import Tensor
+>>> from mindspore.scipy.linalg import eigh
+>>> A = Tensor(onp.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]))
+>>> w, v = eigh(A)
+>>> onp.allclose(A @ v - v @ onp.diag(w), onp.zeros((4, 4)))
+True
+"""
+eigh_net = EighNet(not eigvals_only, lower=True)
+return eigh_net(a)
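One usage note on the wrapper just added: as committed, only the standard problem (b=None) is wired through EighNet, and eigvals_only toggles whether eigenvectors are returned alongside the eigenvalues. A brief hedged sketch (matrix values illustrative):

import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.linalg import eigh

a = onp.array([[2., 1.], [1., 2.]], dtype=onp.float32)
w_only = eigh(Tensor(a), eigvals_only=True)   # eigenvalues only
w, v = eigh(Tensor(a))                        # eigenvalues and eigenvectors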
@@ -194,42 +194,21 @@ class Eigh(PrimitiveWithInfer):
 """
 
 @prim_attr_register
-def __init__(self, compute_eigenvectors):
+def __init__(self, compute_eigenvectors=True, lower=True):
 super().__init__(name="Eigh")
-self.init_prim_io_names(inputs=['A', 's'], outputs=['output', 'output_v'])
+self.init_prim_io_names(inputs=['A'], outputs=['output_w', 'output_v'])
 self.compute_eigenvectors = validator.check_value_type(
 "compute_eigenvectors", compute_eigenvectors, [bool], self.name)
+self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
+self.add_prim_attr('lower', self.lower)
+self.add_prim_attr('compute_eigenvectors', self.compute_eigenvectors)
 
-def __infer__(self, A, s):
+def __infer__(self, A):
 shape = {
 'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
 'dtype': (A['dtype'], A['dtype']),
 'value': None
 }
-if A['dtype'] == mstype.tensor_type(mstype.float32):
-shape = {
-'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
-'dtype': (mstype.float32, mstype.float32),
-'value': None
-}
-elif A['dtype'] == mstype.tensor_type(mstype.float64):
-shape = {
-'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
-'dtype': (mstype.float64, mstype.float64),
-'value': None
-}
-elif A['dtype'] == mstype.tensor_type(mstype.complex64):
-shape = {
-'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
-'dtype': (A['dtype'], A['dtype']),
-'value': None
-}
-elif A['dtype'] == mstype.tensor_type(mstype.complex128):
-shape = {
-'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
-'dtype': (mstype.complex128, mstype.complex128),
-'value': None
-}
 return shape
 
 
@@ -238,16 +217,17 @@ class EighNet(nn.Cell):
 EigenValue /eigenvector solver for symmetric/Hermitian matrix
 Ax = lambda * x
 """
-def __init__(self, b):
-super(EighNet, self).__init__()
-self.b = b
-self.eigh = Eigh(b)
 
-def construct(self, A, s=True):
-r = self.eigh(A, s)
-if self.b:
+def __init__(self, bv=True, lower=True):
+super(EighNet, self).__init__()
+self.bv = bv
+self.eigh = Eigh(bv, lower)
+
+def construct(self, A):
+r = self.eigh(A)
+if self.bv:
 return (r[0], r[1])
-return (r[0],)
+return r[0]
 
 
 class Eig(PrimitiveWithInfer):
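The tests further down exercise this reworked interface: the lower/upper choice now travels through the EighNet constructor instead of a second Tensor input, and compute_eigenvectors=False yields a single eigenvalue Tensor. A hedged call-pattern sketch (module path taken from the relative import in linalg.py above; matrix values illustrative):

import numpy as np
from mindspore.common import Tensor
from mindspore.scipy.ops import EighNet

a = np.array([[4., 1.], [1., 3.]], dtype=np.float32)
w, v = EighNet(True, True)(Tensor(a))     # eigenvalues and eigenvectors, lower triangle
w_only = EighNet(False, True)(Tensor(a))  # eigenvalues only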
@@ -257,7 +237,7 @@ class Eig(PrimitiveWithInfer):
 """
 
 @prim_attr_register
-def __init__(self, compute_eigenvectors):
+def __init__(self, compute_eigenvectors=True):
 super().__init__(name="Eig")
 self.init_prim_io_names(inputs=['A'], outputs=['output', 'output_v'])
 self.compute_eigenvectors = validator.check_value_type(
@@ -285,13 +265,14 @@ class EigNet(nn.Cell):
 EigenValue /eigenvector solver for generic matrix
 Ax = lambda * x
 """
-def __init__(self, b):
+def __init__(self, bv=True):
 super(EigNet, self).__init__()
-self.b = b
-self.eig = Eig(b)
+self.bv = bv
+self.eig = Eig(bv)
 
 def construct(self, A):
 r = self.eig(A)
-if self.b:
+if self.bv:
 return (r[0], r[1])
-return (r[0],)
+return r[0]
@@ -31,11 +31,6 @@ def match(v, v_, error=0):
 np.testing.assert_equal(v, v_)
 
 
-def create_sym_pos_matrix(m, n, dtype):
-a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype)
-return np.dot(a, a.T)
-
-
 @pytest.mark.parametrize('n', [4, 6, 9, 10])
 @pytest.mark.platform_x86_cpu
 def test_eig_net(n: int):
@@ -48,13 +43,13 @@ def test_eig_net(n: int):
 rtol = 1e-3
 atol = 1e-4
 msp_eig = EigNet(True)
-A = create_sym_pos_matrix(n, n, np.float32)
+A = np.array(np.random.rand(n, n), dtype=np.float32)
 tensor_a = Tensor(np.array(A).astype(np.float32))
 msp_w, msp_v = msp_eig(tensor_a)
 assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
 
 # test case for real scalar double 64
-A = np.random.rand(n, n)
+A = np.array(np.random.rand(n, n), dtype=np.float64)
 rtol = 1e-5
 atol = 1e-8
 msp_eig = EigNet(True)
@@ -98,6 +93,7 @@ def test_eig_net(n: int):
 # Com`pare with scipy, scipy passed
 # sp_w, sp_v = sp.linalg.eig(A.astype(np.complex128))
 # assert np.allclose(A @ sp_v - sp_v @ np.diag(sp_w), np.zeros((n, n)), rtol, atol)
-
-# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
 assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
+msp_eig = EigNet(False)
+msp_w0 = msp_eig(Tensor(np.array(A).astype(np.complex128)))
+assert np.allclose(msp_w0.asnumpy() - msp_w.asnumpy(), np.zeros((n, n)), rtol, atol)
@@ -47,10 +47,12 @@ def test_eigh_net(n: int):
 # test for real scalar float 32
 rtol = 1e-3
 atol = 1e-4
-msp_eigh = EighNet(True)
 A = create_sym_pos_matrix(n, n, np.float32)
-msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)), True)
-msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)), False)
+msp_eigh = EighNet(True, True)
+msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)))
+msp_eigh = EighNet(True, False)
+msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)))
 sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T)
 sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T)
 assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
@@ -62,19 +64,23 @@ def test_eigh_net(n: int):
 A = np.random.rand(n, n)
 rtol = 1e-5
 atol = 1e-8
-msp_eigh = EighNet(True)
-msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), True)
-msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), False)
-# Compare with scipy
-# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.float64), lower=True, eigvals_only=False)
-# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.float64), lower=False, eigvals_only=False)
+msp_eigh = EighNet(True, True)
+msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)))
+msp_eigh = EighNet(True, False)
+msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)))
 sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T)
 sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T)
 assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
 atol)
 assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
 atol)
+# test for real scalar float64 no vector
+msp_eigh = EighNet(False, True)
+msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.float64)))
+msp_eigh = EighNet(False, False)
+msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.float64)))
+assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
+assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)
 
 # test case for complex64
 rtol = 1e-3
@@ -86,18 +92,12 @@ def test_eigh_net(n: int):
 A[i][j] = complex(np.random.rand(1, 1), 0)
 else:
 A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
-msp_eigh = EighNet(True)
 sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
 sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
-msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)), True)
-msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)), False)
-# Compare with scipy, scipy passed
-# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False)
-# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False)
-# assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol)
-# assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol)
-
-# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
+msp_eigh = EighNet(True, True)
+msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)))
+msp_eigh = EighNet(True, False)
+msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)))
 assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
 atol)
 assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
@@ -113,19 +113,21 @@ def test_eigh_net(n: int):
 A[i][j] = complex(np.random.rand(1, 1), 0)
 else:
 A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
-msp_eigh = EighNet(True)
 sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
 sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
-msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)), True)
-msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)), False)
-# Compare with scipy, scipy passed
-# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False)
-# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False)
-# assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol)
-# assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol)
-
-# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
+msp_eigh = EighNet(True, True)
+msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
+msp_eigh = EighNet(True, False)
+msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
 assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
 atol)
 assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
 atol)
+
+# test for real scalar complex128 no vector
+msp_eigh = EighNet(False, True)
+msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
+msp_eigh = EighNet(False, False)
+msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
+assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
+assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)
@@ -47,26 +47,35 @@ def test_eigh_net(n: int):
 # test for real scalar float 32
 rtol = 1e-3
 atol = 1e-4
-msp_eigh = EighNet(True)
 A = create_sym_pos_matrix(n, n, np.float32)
-msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)), True)
-msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)), False)
-assert np.allclose(A @ msp_vl.T.asnumpy() - msp_vl.T.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
+msp_eigh = EighNet(True, True)
+msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)))
+msp_eigh = EighNet(True, False)
+msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)))
+assert np.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
 atol)
-assert np.allclose(A @ msp_vu.T.asnumpy() - msp_vu.T.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
+assert np.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
 atol)
 
 # test case for real scalar double 64
 A = create_sym_pos_matrix(n, n, np.float64)
 rtol = 1e-5
 atol = 1e-8
-msp_eigh = EighNet(True)
-msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), True)
-msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), False)
-assert np.allclose(A @ msp_vl.T.asnumpy() - msp_vl.T.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
+msp_eigh = EighNet(True, True)
+msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)))
+msp_eigh = EighNet(True, False)
+msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)))
+assert np.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
 atol)
-assert np.allclose(A @ msp_vu.T.asnumpy() - msp_vu.T.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
+assert np.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
 atol)
+# test for real scalar float64 no vector
+msp_eigh = EighNet(False, True)
+msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.float64)))
+msp_eigh = EighNet(False, False)
+msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.float64)))
+assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
+assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)
 
 # test case for complex64
 rtol = 1e-3
@@ -78,14 +87,15 @@ def test_eigh_net(n: int):
 A[i][j] = complex(np.random.rand(1, 1), 0)
 else:
 A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
-msp_eigh = EighNet(True)
 sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
 sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
-msp_wl, msp_vl = msp_eigh(Tensor(np.array(sym_Al).astype(np.complex64)), True)
-msp_wu, msp_vu = msp_eigh(Tensor(np.array(sym_Au).astype(np.complex64)), False)
-assert np.allclose(sym_Al @ msp_vl.asnumpy().conj().T - msp_vl.asnumpy().conj().T @ np.diag(msp_wl.asnumpy()),
+msp_eigh = EighNet(True, True)
+msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)))
+msp_eigh = EighNet(True, False)
+msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)))
+assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()),
 np.zeros((n, n)), rtol, atol)
-assert np.allclose(sym_Au @ msp_vu.asnumpy().conj().T - msp_vu.asnumpy().conj().T @ np.diag(msp_wu.asnumpy()),
+assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()),
 np.zeros((n, n)), rtol, atol)
 
 # test for complex128
@@ -94,17 +104,24 @@ def test_eigh_net(n: int):
 A = np.array(np.random.rand(n, n), dtype=np.complex128)
 for i in range(0, n):
 for j in range(0, n):
-
 if i == j:
 A[i][j] = complex(np.random.rand(1, 1), 0)
 else:
 A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
-msp_eigh = EighNet(True)
 sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
 sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
-msp_wl, msp_vl = msp_eigh(Tensor(np.array(sym_Al).astype(np.complex128)), True)
-msp_wu, msp_vu = msp_eigh(Tensor(np.array(sym_Au).astype(np.complex128)), False)
-assert np.allclose(sym_Al @ msp_vl.asnumpy().conj().T - msp_vl.asnumpy().conj().T @ np.diag(msp_wl.asnumpy()),
+msp_eigh = EighNet(True, True)
+msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
+msp_eigh = EighNet(True, False)
+msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
+assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()),
 np.zeros((n, n)), rtol, atol)
-assert np.allclose(sym_Au @ msp_vu.asnumpy().conj().T - msp_vu.asnumpy().conj().T @ np.diag(msp_wu.asnumpy()),
+assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()),
 np.zeros((n, n)), rtol, atol)
+# test for real scalar complex128 no vector
+msp_eigh = EighNet(False, True)
+msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
+msp_eigh = EighNet(False, False)
+msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.complex128)))
+assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol)
+assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol)
@@ -139,3 +139,91 @@ def test_cholesky_solver(n: int, lower: bool, dtype):
 # pre tensor_a has been inplace.
 tensor_a = Tensor(a)
 assert onp.allclose(onp.dot(a, osp_x), mnp.dot(tensor_a, msp_x).asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('n', [4, 6, 9, 20])
+def test_eigh_solver(n: int):
+"""
+Feature: ALL TO ALL
+Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N]
+Expectation: the result match scipy cholesky_solve
+"""
+# test for real scalar float 32
+rtol = 1e-3
+atol = 1e-4
+A = create_sym_pos_matrix([n, n], onp.float32)
+msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=True, eigvals_only=False)
+msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=False, eigvals_only=False)
+assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
+rtol,
+atol)
+assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)),
+rtol,
+atol)
+
+# test case for real scalar double 64
+A = create_sym_pos_matrix([n, n], onp.float64)
+rtol = 1e-5
+atol = 1e-8
+msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=False)
+msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=False)
+assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
+rtol,
+atol)
+assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)),
+rtol,
+atol)
+# test for real scalar float64 no vector
+msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=True)
+msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=True)
+assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
+assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
+
+# test case for complex64
+rtol = 1e-3
+atol = 1e-4
+A = onp.array(onp.random.rand(n, n), dtype=onp.complex64)
+for i in range(0, n):
+for j in range(0, n):
+if i == j:
+A[i][j] = complex(onp.random.rand(1, 1), 0)
+else:
+A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1))
+sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T)
+sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T)
+msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex64)), lower=True, eigvals_only=False)
+msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex64)), lower=False, eigvals_only=False)
+assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()),
+onp.zeros((n, n)), rtol, atol)
+assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()),
+onp.zeros((n, n)), rtol, atol)
+
+# test for complex128
+rtol = 1e-5
+atol = 1e-8
+A = onp.array(onp.random.rand(n, n), dtype=onp.complex128)
+for i in range(0, n):
+for j in range(0, n):
+
+if i == j:
+A[i][j] = complex(onp.random.rand(1, 1), 0)
+else:
+A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1))
+sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T)
+sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T)
+msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=False)
+msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=False)
+assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()),
+onp.zeros((n, n)), rtol, atol)
+assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()),
+onp.zeros((n, n)), rtol, atol)
+
+# test for real scalar float64 no vector
+msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=True)
+msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=True)
+assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
+assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)