forked from mindspore-Ecosystem/mindspore
Optimize eigh: backend supports 1 or 2 outputs for GPU/CPU
commit ba3b65b9af (parent fddedb03c8)
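In short: with compute_eigenvectors=True the op still returns (eigenvalues, eigenvectors); with False it now returns the eigenvalues alone on both backends. A minimal usage sketch; the import path follows the files touched in this diff and is an assumption, not verified API:

    from mindspore import Tensor, dtype as mstype
    from mindspore.scipy.ops import Eigh  # assumed location of the primitive changed below

    a = Tensor([[2.0, 1.0], [1.0, 2.0]], mstype.float32)
    w, v = Eigh(compute_eigenvectors=True, lower=True)(a)     # two outputs, as before
    w_only = Eigh(compute_eigenvectors=False, lower=True)(a)  # single output after this change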
@@ -47,52 +47,59 @@ void EighCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
 }
 
 template <typename T>
-void SolveSelfAdjointMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<T>> *output, Map<MatrixSquare<T>> *outputv,
-                            bool compute_eigen_vectors) {
-  Eigen::SelfAdjointEigenSolver<MatrixSquare<T>> solver(A);
+void SolveSelfAdjointMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<T>> *output, bool compute_eigen_vectors,
+                            size_t m, T *output_v_addr = nullptr) {
+  Eigen::SelfAdjointEigenSolver<MatrixSquare<T>> solver(
+    A, compute_eigen_vectors ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly);
   output->noalias() = solver.eigenvalues();
-  if (compute_eigen_vectors) {
-    outputv->noalias() = solver.eigenvectors();
+  if (compute_eigen_vectors && output_v_addr != nullptr) {
+    Map<MatrixSquare<T>> outputv(output_v_addr, m, m);
+    outputv.noalias() = solver.eigenvectors();
   }
 }
 
 template <typename T>
-void SolveComplexMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<T>> *output, Map<MatrixSquare<T>> *outputv,
-                        bool compute_eigen_vectors) {
-  Eigen::ComplexEigenSolver<MatrixSquare<T>> solver(A);
+void SolveComplexMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<T>> *output, bool compute_eigen_vectors,
+                        size_t m, T *output_v_addr = nullptr) {
+  Eigen::ComplexEigenSolver<MatrixSquare<T>> solver(A, compute_eigen_vectors);
   output->noalias() = solver.eigenvalues();
-  if (compute_eigen_vectors) {
-    outputv->noalias() = solver.eigenvectors();
+  if (compute_eigen_vectors && output_v_addr != nullptr) {
+    Map<MatrixSquare<T>> outputv(output_v_addr, m, m);
+    outputv.noalias() = solver.eigenvectors();
   }
 }
 
 template <typename T>
-bool EighCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+void EighCpuKernelMod<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
+  NativeCpuKernelMod::InitInputOutputSize(kernel_node);
+  (void)workspace_size_list_.emplace_back(m_ * m_ * sizeof(T));
+}
+
+template <typename T>
+bool EighCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                  const std::vector<AddressPtr> &outputs) {
   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
   auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
   // is the Matrix a symmetric matrix(true lower triangle, false upper triangle)
   auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
-  auto output_v_addr = reinterpret_cast<T *>(outputs[1]->addr);
+  T *a_Work_dir = reinterpret_cast<T *>(workspace[0]->addr);
   Map<MatrixSquare<T>> A(A_addr, m_, m_);
-  Map<MatrixSquare<T>> A_(A_addr, m_, m_);
+  Map<MatrixSquare<T>> A_(a_Work_dir, m_, m_);
   Map<MatrixSquare<T>> output(output_addr, m_, 1);
-  Map<MatrixSquare<T>> outputv(output_v_addr, m_, m_);
   // selfadjoint matrix
   if (lower_) {
     A_ = A.template selfadjointView<Lower>();
   } else {
     A_ = A.template selfadjointView<Upper>();
   }
   // Real scalar eigen solver
-  if constexpr (std::is_same_v<T, float>) {
-    SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors_);
-  } else if constexpr (std::is_same_v<T, double>) {
-    SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors_);
+  T *output_v_addr = nullptr;
+  if (compute_eigen_vectors_) {
+    output_v_addr = reinterpret_cast<T *>(outputs[1]->addr);
+  }
+  if constexpr (std::is_same<T, float>::value || std::is_same<T, double>::value) {
+    SolveSelfAdjointMatrix(A_, &output, compute_eigen_vectors_, m_, output_v_addr);
   } else {
     // complex eigen solver
-    SolveComplexMatrix(A_, &output, &outputv, compute_eigen_vectors_);
+    SolveComplexMatrix(A_, &output, compute_eigen_vectors_, m_, output_v_addr);
   }
   return true;
 }
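For reference, a minimal NumPy sketch of the contract SolveSelfAdjointMatrix now implements, with Eigen::EigenvaluesOnly mapping onto eigvalsh (illustrative helper, not MindSpore code):

    import numpy as np

    def eigh_like(a, compute_eigen_vectors=True, lower=True):
        # Symmetrize from the requested triangle, as the Launch body does via selfadjointView.
        tri = np.tril(a) if lower else np.triu(a)
        sym = tri + tri.conj().T - np.diag(np.diag(tri).real)
        if compute_eigen_vectors:
            return np.linalg.eigh(sym)   # (eigenvalues, eigenvectors)
        return np.linalg.eigvalsh(sym)   # eigenvalues only, no eigenvector buffer touched

    w = eigh_like(np.array([[2.0, 1.0], [1.0, 2.0]]), compute_eigen_vectors=False)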
@@ -40,6 +40,7 @@ class EighCpuKernelMod : public NativeCpuKernelMod {
   void InitKernel(const CNodePtr &kernel_node) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
 
  private:
   size_t m_{1};
@@ -48,6 +49,16 @@ class EighCpuKernelMod : public NativeCpuKernelMod {
   TypeId dtype_{kNumberTypeFloat32};
 };
 
+MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                    EighCpuKernelMod, float);
+MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                    EighCpuKernelMod, double);
+
+MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+                    EighCpuKernelMod, float_complex);
+MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+                    EighCpuKernelMod, double_complex);
+
 MS_REG_CPU_KERNEL_T(
   Eigh,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@@ -18,6 +18,10 @@
 
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+                      EighcGpuKernelMod, Complex<float>)
+MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+                      EighcGpuKernelMod, Complex<double>)
 MS_REG_GPU_KERNEL_ONE(Eigh,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeComplex64)
@@ -94,10 +94,10 @@ class EighcGpuKernelMod : public NativeGpuKernelMod {
       return true;
     }
     CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                                "cublasSetStream failed");
+                                "CublasSetStream failed");
     CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(cusolver_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                                  "cusolverDnSetStream failed");
-    // matrix A, input or output(eigenvector)
+                                  "CusolverDnSetStream failed");
+    // Matrix A, input or output(eigenvector)
     auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
     if (lower_) {
       uplo_ = CUBLAS_FILL_MODE_LOWER;
@@ -106,18 +106,23 @@ class EighcGpuKernelMod : public NativeGpuKernelMod {
     }
     size_t lda_ = m_;
     auto output_w_addr = GetDeviceAddress<T>(outputs, kDim0);
-    // output eigenvector
-    auto output_v_addr = GetDeviceAddress<T>(outputs, kDim1);
+    // Output eigenvector if needed
+    T *output_v_addr = nullptr;
+    if (compute_eigen_vectors_) {
+      output_v_addr = GetDeviceAddress<T>(outputs, kDim1);  // output eigenvectors
+    } else {
+      output_v_addr = GetDeviceAddress<T>(workspace, kDim6);  // eigenvectors not output, use workspace
+    }
     int *devInfo = GetDeviceAddress<int>(workspace, kDim0);
-    // temp output eigenvalues real scalar
+    // Temp output eigenvalues real scalar
     auto w_w_addr = GetDeviceAddress<D>(workspace, kDim1);
     auto w_w_c_addr = GetDeviceAddress<T>(workspace, kDim2);
-    // temp eigenvector before transpose
+    // Temp eigenvector before transpose
     auto w_v_addr = GetDeviceAddress<T>(workspace, kDim3);
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                cudaMemcpyAsync(output_v_addr, inout_A_addr, m_ * m_ * sizeof(T),
                                                cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "copy input matrix failed");
+                               "Copy input matrix failed");
     size_t input_shape[kShape2dDims] = {m_, m_};
     size_t input_axis[kShape2dDims] = {1, 0};
     size_t *dev_input_shape = GetDeviceAddress<size_t>(workspace, kDim4);
@@ -125,11 +130,11 @@ class EighcGpuKernelMod : public NativeGpuKernelMod {
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t),
                                                cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "malloc input shape workspace failed");
+                               "Malloc input shape workspace failed");
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t),
                                                cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "malloc input shape workspace failed");
+                               "Malloc input shape workspace failed");
     CalTranspose(m_ * m_, output_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, w_v_addr,
                  reinterpret_cast<cudaStream_t>(stream_ptr));
@@ -156,18 +161,20 @@ class EighcGpuKernelMod : public NativeGpuKernelMod {
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                cudaMemcpyAsync(w_w_c_addr, w_w_addr, m_ * sizeof(D), cudaMemcpyDeviceToDevice,
                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "copy eigenvalue from workspace to host failed");
-    // convert real scalar to complex
+                               "Copy eigenvalue from workspace to host failed");
+    // Convert real scalar to complex
     RealToComplex(m_, reinterpret_cast<D *>(w_w_c_addr), reinterpret_cast<D *>(output_w_addr),
                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
-                 reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (compute_eigen_vectors_) {
+      CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work);
     int info_gpu = 0;
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                cudaMemcpyAsync(&info_gpu, devInfo, sizeof(int), cudaMemcpyDeviceToHost,
                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "copy eigenvalues to outpu failed");
+                               "Copy eigenvalues to output failed");
     if (info_gpu != 0) {
       MS_LOG_EXCEPTION << kernel_name_ << " launch gpu kernel fail for dtype:" << dtype_;
     }
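The complex path above gets real eigenvalues back from cusolver and widens them to the complex output dtype (RealToComplex), transposing the eigenvectors back only when they are requested. A NumPy stand-in for that postprocessing (illustrative only):

    import numpy as np

    a = np.array([[2.0, 1.0j], [-1.0j, 2.0]], dtype=np.complex64)
    w_real, v = np.linalg.eigh(a)    # eigenvalues of a Hermitian matrix come back real (float32 here)
    w = w_real.astype(np.complex64)  # counterpart of RealToComplex: match the complex output dtype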
@@ -176,21 +183,27 @@ class EighcGpuKernelMod : public NativeGpuKernelMod {
 
  protected:
   void InitSizeLists() override {
-    // in/out matrix, eigenvector
+    // In/out matrix, eigenvector
     input_size_list_.push_back(m_ * m_ * sizeof(T));
-    // eigenvalues, cuda output original real scalar, should covert to complex<ft32/64>
+    // Eigenvalues: cuda outputs the original real scalar, which should be converted to complex<fp32/64>
     output_size_list_.push_back(m_ * sizeof(T));
-    output_size_list_.push_back(m_ * m_ * sizeof(T));
-    // result
+    // Eigenvector if needed
+    if (compute_eigen_vectors_) {
+      output_size_list_.push_back(m_ * m_ * sizeof(T));
+    }
     workspace_size_list_.push_back(sizeof(int));
-    // for temp original eigenvalue real scalar
+    // For temp original eigenvalue real scalar
     workspace_size_list_.push_back(m_ * sizeof(D));
-    // for temp pre-transpose complex mitrx
+    // For temp pre-transpose complex matrix
     workspace_size_list_.push_back(m_ * sizeof(T));
     workspace_size_list_.push_back(m_ * m_ * sizeof(T));
-    // transpose scalar workspace
+    // Transpose scalar workspace
     workspace_size_list_.push_back(kShape2dDims * sizeof(size_t));
     workspace_size_list_.push_back(kShape2dDims * sizeof(size_t));
+    // A temp space for input/eigenvectors if the eigenvectors are not output
+    if (!compute_eigen_vectors_) {
+      workspace_size_list_.push_back(m_ * m_ * sizeof(T));
+    }
   }
 
   size_t m_{1};
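The buffer plan above boils down to: the m×m eigenvector block is a real output when compute_eigen_vectors_ is set, and an extra workspace slot otherwise. A small sketch of that byte accounting (a paraphrase of the hunk, not framework API):

    def eighc_gpu_plan(m, elem, real_elem, compute_eigen_vectors, size_t_bytes=8):
        outputs = [m * elem]                  # eigenvalues (complex dtype)
        if compute_eigen_vectors:
            outputs.append(m * m * elem)      # eigenvectors as a real output
        workspace = [4,                       # devInfo (int)
                     m * real_elem,           # temp real-scalar eigenvalues
                     m * elem,                # temp complex eigenvalues
                     m * m * elem,            # pre-transpose eigenvector buffer
                     2 * size_t_bytes,        # transpose shape
                     2 * size_t_bytes]        # transpose axis
        if not compute_eigen_vectors:
            workspace.append(m * m * elem)    # stand-in for the missing eigenvector output
        return outputs, workspace

Note that the conditional stand-in buffer is the seventh workspace entry (index 6), which matches the GetDeviceAddress<T>(workspace, kDim6) lookup in the Launch hunk.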
@@ -18,6 +18,10 @@
 
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      EighGpuKernelMod, float)
+MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      EighGpuKernelMod, double)
 MS_REG_GPU_KERNEL_ONE(
   Eigh,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@@ -79,25 +79,31 @@ class EighGpuKernelMod : public NativeGpuKernelMod {
       return true;
     }
     CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(cusolver_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                                  "cusolverDnSetStream failed");
-    // matrix A, input or output(eigenvector)
+                                  "CusolverDnSetStream failed");
+    // Matrix A, input or output(eigenvector)
     auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
     // Notice :this is important
-    // a col or row major is different to cpu, so a lower triangle is a upper triangle, a upper is a lower in gpu mem
-    // so the upper is positive to it from var, but for real scalar matrix, upper eq lower, it's different from complex
+    // Col-/row-major layout differs from the cpu, so a lower triangle in host memory is an upper triangle in gpu mem and vice versa
+    // Hence the fill mode below is flipped relative to lower_; for a real symmetric matrix upper and lower coincide, unlike the complex case
     if (lower_) {
       uplo_ = CUBLAS_FILL_MODE_UPPER;
     } else {
      uplo_ = CUBLAS_FILL_MODE_LOWER;
    }
-    auto output_addr = GetDeviceAddress<T>(outputs, kDim0);  // output eigenvalues
-    auto output_v_addr = GetDeviceAddress<T>(outputs, kDim1);  // output eigenvalues
+    auto output_addr = GetDeviceAddress<T>(outputs, kDim0);  // output eigenvalues
+    // Output eigenvector if needed
+    T *output_v_addr = nullptr;
+    if (compute_eigen_vectors_) {
+      output_v_addr = GetDeviceAddress<T>(outputs, kDim1);  // output eigenvectors
+    } else {
+      output_v_addr = GetDeviceAddress<T>(workspace, kDim4);  // eigenvectors not output, use workspace
+    }
     int *devInfo = GetDeviceAddress<int>(workspace, kDim0);
     auto w_v_addr = GetDeviceAddress<T>(workspace, kDim1);  // temp eigenvector before transpose
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                cudaMemcpyAsync(w_v_addr, inout_A_addr, m_ * m_ * sizeof(T), cudaMemcpyDeviceToDevice,
                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "copy to input matrix failed");
+                               "Copy to input matrix failed");
     size_t lda_ = m_;
     int lwork = 0;
     if constexpr (std::is_same_v<T, float>) {
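The fill-mode flip in this hunk follows from memory layout: a row-major lower triangle, reinterpreted by a column-major library such as cusolver, is the upper triangle of the transpose. A one-line NumPy check of that identity:

    import numpy as np

    a = np.arange(9.0).reshape(3, 3)
    # The lower triangle seen column-major (i.e., transposed) is the upper triangle of a.T:
    assert np.array_equal(np.tril(a).T, np.triu(a.T))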
@@ -122,14 +128,16 @@ class EighGpuKernelMod : public NativeGpuKernelMod {
                     reinterpret_cast<cudaStream_t>(stream_ptr));
     cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
                     reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
-                 reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (compute_eigen_vectors_) {
+      CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work);
     int info_gpu = 0;
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                cudaMemcpyAsync(&info_gpu, devInfo, sizeof(int), cudaMemcpyDeviceToHost,
                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "copy to device result failed");
+                               "Copy to device result failed");
     if (info_gpu != 0) {
       MS_LOG_EXCEPTION << kernel_name_ << " launch gpu kernel fail for dtype:" << dtype_;
     }
@@ -138,18 +146,23 @@ class EighGpuKernelMod : public NativeGpuKernelMod {
 
  protected:
   void InitSizeLists() override {
-    // in/out matrix, eigenvector
+    // In/out matrix, eigenvector
     input_size_list_.push_back(m_ * m_ * sizeof(T));
-    // eigenvalues
+    // Eigenvalues
     output_size_list_.push_back(m_ * sizeof(T));
-    // eigenvector
-    output_size_list_.push_back(m_ * m_ * sizeof(T));
-    // result
+    // Eigenvector if needed
+    if (compute_eigen_vectors_) {
+      output_size_list_.push_back(m_ * m_ * sizeof(T));  // eigenvector output
+    }
     workspace_size_list_.push_back(sizeof(int));
     workspace_size_list_.push_back(m_ * m_ * sizeof(T));
-    // transpose scalar workspace
+    // Transpose scalar workspace
     workspace_size_list_.push_back(kShape2dDims * sizeof(size_t));
     workspace_size_list_.push_back(kShape2dDims * sizeof(size_t));
+    // A temp space for input/eigenvectors if the eigenvectors are not output
+    if (!compute_eigen_vectors_) {
+      workspace_size_list_.push_back(m_ * m_ * sizeof(T));
+    }
   }
 
   size_t m_{1};
@@ -227,12 +227,20 @@ class Eigh(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_types_same({"A_dtype": A['dtype']},
                                                     [mstype.float32, mstype.float64, mstype.complex64,
                                                      mstype.complex128], self.name, True)
-        shape = {
-            'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
-            'dtype': (A['dtype'], A['dtype']),
-            'value': None
-        }
-        return shape
+        output = None
+        if self.compute_eigenvectors:
+            output = {
+                'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
+                'dtype': (A['dtype'], A['dtype']),
+                'value': None
+            }
+        else:
+            output = {
+                'shape': (A['shape'][0],),
+                'dtype': A['dtype'],
+                'value': None
+            }
+        return output
 
 
 class EighNet(nn.Cell):
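The infer change means the declared output signature now depends on compute_eigenvectors. A sketch of the resulting shapes for an n×n input (illustrative helper mirroring the __infer__ branch above):

    def eigh_output_shapes(n, compute_eigenvectors):
        if compute_eigenvectors:
            return ((n,), (n, n))  # (eigenvalues, eigenvectors)
        return (n,)                # eigenvalues only

    assert eigh_output_shapes(3, True) == ((3,), (3, 3))
    assert eigh_output_shapes(3, False) == (3,)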
|
@@ -249,10 +257,7 @@ class EighNet(nn.Cell):
     def construct(self, A):
         if F.dtype(A) in (mstype.int32, mstype.int64):
             A = F.cast(A, mstype.float64)
-        r = self.eigh(A)
-        if self.bv:
-            return (r[0], r[1])
-        return r[0]
+        return self.eigh(A)
 
 
 class Eig(PrimitiveWithInfer):
@@ -121,12 +121,13 @@ def get_bprpo_eigh(self):
     eigh = Eigh(compute_eigenvectors=True)
 
     def bprop(a, out, dout):
-        w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
         if not is_compute_v:
+            w, grad_w = out, dout
             # w, _ = Eigh(compute_eigenvectors=False)(a) -> a * _ = w * _
             _, v = eigh(a)
             grad_a = _matmul(v * F.expand_dims(grad_w, -2), _adjoint(v))
         else:
+            w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
             # w, v = Eigh(compute_eigenvectors=True)(a) -> a * v = w * v
             vh_gv = _matmul(_adjoint(v), grad_v)
             f = _compute_f(w)
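For context, the two bprop branches implement the usual self-adjoint eigendecomposition gradient; _compute_f and vh_gv correspond to F and V^H(∂L/∂V) below (the standard formula, stated here for reference rather than copied from the source). With A = V diag(w) V^H:

    \frac{\partial L}{\partial A} = V \,\mathrm{diag}\!\Big(\frac{\partial L}{\partial w}\Big) V^{H}
    \qquad \text{(eigenvalues-only branch)}

    \frac{\partial L}{\partial A} = V \Big(\mathrm{diag}\!\Big(\frac{\partial L}{\partial w}\Big)
      + F \circ \big(V^{H}\,\tfrac{\partial L}{\partial V}\big)\Big) V^{H},
    \qquad F_{ij} = \frac{1}{w_j - w_i} \ (i \neq j), \quad F_{ii} = 0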
@@ -86,11 +86,12 @@ def test_eigh_grad(compute_eigenvectors, lower, shape, data_type):
         self.eigh = Eigh(compute_eigenvectors, lower)
 
     def construct(self, a):
-        w, v = self.eigh(a)
         res = None
         if self.compute_eigenvectors:
+            w, v = self.eigh(a)
             res = self.sum(w) + self.mean(v)
         else:
+            w = self.eigh(a)
             res = self.mean(w)
         return res