forked from mindspore-Ecosystem/mindspore
add cholesky, cho_factor primitive and backend gpu implements
This commit is contained in:
parent
f537f1ae1a
commit
36032e7ee2
|
@ -37,3 +37,4 @@ void Eye(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_
|
|||
}
|
||||
|
||||
template void Eye<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream);
|
||||
template void Eye<double>(const size_t size, const size_t dim, double *output_addr, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -68,3 +68,5 @@ void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T
|
|||
template void MatrixSplit<float>(const size_t size, const size_t split_dim, const size_t dim, float *input_addr,
|
||||
float *output_addr, cudaStream_t cuda_stream);
|
||||
|
||||
template void MatrixSplit<double>(const size_t size, const size_t split_dim, const size_t dim, double *input_addr,
|
||||
double *output_addr, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -18,47 +18,25 @@
|
|||
template <typename T>
|
||||
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count,
|
||||
const size_t ldb, const size_t m) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
size_t batchIdx = i / (ldb * m);
|
||||
size_t row = (i - batchIdx * ldb * m) / m;
|
||||
size_t col = (i - batchIdx * ldb * m) % m;
|
||||
|
||||
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the upper half of the matrix should be all 0;
|
||||
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the lower half of the matrix should be all 0;
|
||||
if (uplo == CUBLAS_FILL_MODE_UPPER) {
|
||||
if (col > row) {
|
||||
output[i] = 0;
|
||||
} else {
|
||||
output[i] = input[i];
|
||||
}
|
||||
} else if (uplo == CUBLAS_FILL_MODE_LOWER) {
|
||||
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the upper half of the matrix should be all 0;
|
||||
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the lower half of the matrix should be all 0;
|
||||
// special case, only upper triangle data is correct, so copy up to lower, when lower case.
|
||||
if (uplo == CUBLAS_FILL_MODE_UPPER) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
size_t batchIdx = i / (ldb * m);
|
||||
size_t row = (i - batchIdx * ldb * m) / m;
|
||||
size_t col = (i - batchIdx * ldb * m) % m;
|
||||
if (col < row) {
|
||||
output[i] = 0;
|
||||
} else {
|
||||
output[i] = input[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ScipyTriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count,
|
||||
const size_t ldb, const size_t m) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
size_t batchIdx = i / (ldb * m);
|
||||
size_t row = (i - batchIdx * ldb * m) / m;
|
||||
size_t col = (i - batchIdx * ldb * m) % m;
|
||||
|
||||
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the upper half of the matrix should be all 0;
|
||||
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the lower half of the matrix should be all 0;
|
||||
// special case, only upper triangle data is correct, so copy up to lower, when lower case.
|
||||
if (uplo == CUBLAS_FILL_MODE_UPPER) {
|
||||
if (col < row) {
|
||||
output[i] = 0;
|
||||
} else {
|
||||
output[i] = input[i];
|
||||
}
|
||||
} else if (uplo == CUBLAS_FILL_MODE_LOWER) {
|
||||
} else {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
size_t batchIdx = i / (ldb * m);
|
||||
size_t row = (i - batchIdx * ldb * m) / m;
|
||||
size_t col = (i - batchIdx * ldb * m) % m;
|
||||
if (col > row) {
|
||||
output[i] = 0;
|
||||
} else {
|
||||
|
@ -87,12 +65,8 @@ template void TriangleMatrixCopy<float>(const float *input, float *output, cubla
|
|||
template void TriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
|
||||
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
|
||||
const size_t m, cudaStream_t cuda_stream) {
|
||||
ScipyTriangleMatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, uplo, count, ldb, m);
|
||||
return;
|
||||
}
|
||||
template void TriangleMatrixCopy<double>(const double *input, double *output, cublasFillMode_t uplo, const size_t count,
|
||||
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
|
@ -100,15 +74,6 @@ void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda
|
|||
return;
|
||||
}
|
||||
|
||||
template void ScipyTriangleMatrixCopy<float>(const float *input, float *output, cublasFillMode_t uplo,
|
||||
const size_t count, const size_t ldb, const size_t m,
|
||||
cudaStream_t cuda_stream);
|
||||
template void ScipyTriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
|
||||
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
||||
template void ScipyTriangleMatrixCopy<double>(const double *input, double *output, cublasFillMode_t uplo,
|
||||
const size_t count, const size_t ldb, const size_t m,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template void MatrixCopy<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void MatrixCopy<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void MatrixCopy<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -22,10 +22,6 @@ template <typename T>
|
|||
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
|
||||
const size_t m, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
|
||||
const size_t m, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_
|
||||
|
|
|
@ -39,6 +39,12 @@ static constexpr char kAvgPoolingModeUpperCase[] = "AVG";
|
|||
// Used by Pooling
|
||||
static constexpr char kAvgPoolingModeLowerCase[] = "avg";
|
||||
|
||||
// Used by cholesky
|
||||
static constexpr char kLower[] = "lower";
|
||||
|
||||
// Used by cholesky
|
||||
static constexpr char kSplitDim[] = "split_dim";
|
||||
|
||||
// Used by MaxPool pad: The minimum value of float32
|
||||
static constexpr float kSignedMinFloat = -3.402823466e+38F;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,10 +17,10 @@
|
|||
#include "backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ScipyCholeskyGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CholeskyGpuKernel, float)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ScipyCholeskyGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CholeskyGpuKernel, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,7 +20,8 @@
|
|||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
@ -38,81 +39,53 @@ constexpr size_t kCholeskyNormalShape = 2;
|
|||
constexpr size_t kCholeskyBatchedShape = 3;
|
||||
|
||||
template <typename T>
|
||||
class ScipyCholeskyGpuKernel : public GpuKernel {
|
||||
class CholeskyGpuKernel : public GpuKernel {
|
||||
public:
|
||||
ScipyCholeskyGpuKernel() = default;
|
||||
~ScipyCholeskyGpuKernel() = default;
|
||||
using pointer = T *;
|
||||
|
||||
CholeskyGpuKernel() = default;
|
||||
~CholeskyGpuKernel() = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
// here all addresses are malloc by cuda, so deal with them as device's address.
|
||||
auto input1_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||
|
||||
auto d_array_addr = GetDeviceAddress<T *>(workspace, kDim0);
|
||||
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
|
||||
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
h_array_[i] = input1_addr + i * lda_ * m_;
|
||||
if (!use_split_matrix_) {
|
||||
return NoSplitLaunch(inputs, workspace, outputs, stream_ptr);
|
||||
}
|
||||
|
||||
// 5. copy input's addr to d_array_addr
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
|
||||
// 6. solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
|
||||
}
|
||||
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
|
||||
// 7. copy results from written input's matrix to output's matrix by up or lower flag.
|
||||
ScipyTriangleMatrixCopy(input1_addr, output_addr, uplo_, output_elements, ldb_, m_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
return SplitLaunch(inputs, workspace, outputs, stream_ptr);
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, "lower"));
|
||||
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
|
||||
split_dim_ = static_cast<int>(GetAttr<int64_t>(kernel_node, kSplitDim));
|
||||
if (lower_) {
|
||||
uplo_ = CUBLAS_FILL_MODE_LOWER;
|
||||
} else {
|
||||
uplo_ = CUBLAS_FILL_MODE_UPPER;
|
||||
}
|
||||
// 1. get CuSolver Dense matrix handler
|
||||
// get CuSolver Dense matrix handler
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
|
||||
// 2. get Cublas handler
|
||||
// get Cublas handler
|
||||
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
|
||||
|
||||
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex);
|
||||
|
||||
// 3. check input shape not null
|
||||
bool is_null_input = CHECK_NULL_INPUT(in_shape);
|
||||
if (is_null_input) {
|
||||
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input is null";
|
||||
is_null_input_ = CHECK_NULL_INPUT(in_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(EXCEPTION) << "For 'CholeskyGpuKernel', input is null";
|
||||
}
|
||||
// 4. calculate input size
|
||||
if (!InitInputSize(in_shape)) {
|
||||
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input shape init failed.";
|
||||
|
||||
if (split_dim_ == 0) {
|
||||
return InitNoSplitDim(in_shape);
|
||||
}
|
||||
return true;
|
||||
return InitSplitDim(in_shape);
|
||||
}
|
||||
|
||||
private:
|
||||
bool InitInputSize(const std::vector<size_t> &in_shape) {
|
||||
protected:
|
||||
bool InitNoSplitDim(const std::vector<size_t> &in_shape) {
|
||||
if (in_shape.size() == kCholeskyDefaultShape) {
|
||||
batch_ = 1;
|
||||
cho_row_ = in_shape.at(kDim0);
|
||||
|
@ -134,7 +107,46 @@ class ScipyCholeskyGpuKernel : public GpuKernel {
|
|||
return false;
|
||||
}
|
||||
// set matrix row or col to be lead dimension
|
||||
m_ = SizeToInt(cho_row_);
|
||||
m_ = SizeToInt(in_shape.at(kDim1));
|
||||
lda_ = m_;
|
||||
ldb_ = m_;
|
||||
h_array_.resize(batch_);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InitSplitDim(const std::vector<size_t> &in_shape) {
|
||||
if (in_shape.size() != kCholeskyNormalShape) {
|
||||
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
|
||||
return false;
|
||||
}
|
||||
cho_row_ = in_shape.at(kDim0);
|
||||
cho_col_ = in_shape.at(kDim1);
|
||||
if (cho_row_ != cho_col_) {
|
||||
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (SizeToInt(cho_row_) <= split_dim_) {
|
||||
use_split_matrix_ = false;
|
||||
batch_ = 1;
|
||||
m_ = SizeToInt(in_shape.at(kDim1));
|
||||
lda_ = m_;
|
||||
ldb_ = m_;
|
||||
h_array_.resize(batch_);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
use_split_matrix_ = true;
|
||||
size_t batch = cho_col_ / split_dim_;
|
||||
res_dim_ = cho_col_ - batch * split_dim_;
|
||||
if (res_dim_ == 0) {
|
||||
batch_ = batch;
|
||||
} else {
|
||||
batch_ = batch + 1;
|
||||
}
|
||||
m_ = split_dim_;
|
||||
lda_ = m_;
|
||||
ldb_ = m_;
|
||||
h_array_.resize(batch_);
|
||||
|
@ -143,20 +155,99 @@ class ScipyCholeskyGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
input_size_ = batch_ * m_ * lda_ * unit_size_;
|
||||
input_size_list_.push_back(input_size_);
|
||||
|
||||
output_size_ = batch_ * m_ * lda_ * unit_size_;
|
||||
output_size_list_.push_back(output_size_);
|
||||
|
||||
size_t workspace_size = batch_ * sizeof(T *);
|
||||
workspace_size_list_.resize(kDim2);
|
||||
workspace_size_list_[kDim0] = workspace_size;
|
||||
|
||||
size_t workspace_size = batch_ * sizeof(pointer);
|
||||
workspace_size_list_.emplace_back(workspace_size);
|
||||
workspace_size = batch_ * sizeof(int);
|
||||
workspace_size_list_[kDim1] = workspace_size;
|
||||
workspace_size_list_.emplace_back(workspace_size);
|
||||
|
||||
size_t input_size;
|
||||
if (!use_split_matrix_) {
|
||||
input_size = batch_ * m_ * lda_ * unit_size_;
|
||||
} else {
|
||||
input_size = cho_row_ * cho_col_ * unit_size_;
|
||||
workspace_size = batch_ * m_ * lda_ * unit_size_;
|
||||
workspace_size_list_.emplace_back(workspace_size);
|
||||
}
|
||||
input_size_list_.push_back(input_size);
|
||||
size_t output_size = batch_ * m_ * lda_ * unit_size_;
|
||||
output_size_list_.push_back(output_size);
|
||||
}
|
||||
|
||||
bool NoSplitLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
// here all addresses are malloc by cuda, so deal with them as device's address.
|
||||
auto input1_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||
|
||||
auto d_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
|
||||
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
|
||||
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
h_array_[i] = input1_addr + i * lda_ * m_;
|
||||
}
|
||||
|
||||
// copy input's addr to d_array_addr
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
|
||||
// solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
|
||||
}
|
||||
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
|
||||
// copy results from written input's matrix to output's matrix by up or lower flag.
|
||||
TriangleMatrixCopy(input1_addr, output_addr, uplo_, output_elements, ldb_, m_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SplitLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
auto input1_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||
|
||||
auto d_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
|
||||
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
|
||||
auto d_batch_input_addr = GetDeviceAddress<T>(workspace, kDim2);
|
||||
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
h_array_[i] = d_batch_input_addr + i * lda_ * m_;
|
||||
}
|
||||
Eye(batch_ * split_dim_ * split_dim_, split_dim_, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
MatrixSplit(batch_ * split_dim_ * split_dim_, split_dim_, cho_col_, input1_addr, d_batch_input_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
|
||||
}
|
||||
|
||||
TriangleMatrixCopy(d_batch_input_addr, output_addr, uplo_, outputs[0]->size / sizeof(T), ldb_, m_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t unit_size_{sizeof(T)};
|
||||
size_t cho_row_{0};
|
||||
size_t cho_col_{0};
|
||||
|
@ -164,18 +255,20 @@ class ScipyCholeskyGpuKernel : public GpuKernel {
|
|||
size_t m_{0};
|
||||
size_t lda_{0};
|
||||
size_t ldb_{0};
|
||||
size_t input_size_{0};
|
||||
size_t output_size_{0};
|
||||
int res_dim_{0};
|
||||
int split_dim_{0};
|
||||
bool is_null_input_{false};
|
||||
bool use_split_matrix_{false};
|
||||
cusolverDnHandle_t handle_{nullptr};
|
||||
cublasHandle_t blas_handle_{nullptr};
|
||||
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
|
||||
std::vector<pointer> h_array_;
|
||||
bool lower_{false};
|
||||
std::vector<T *> h_array_{};
|
||||
std::vector<size_t> input_size_list_{};
|
||||
std::vector<size_t> output_size_list_{};
|
||||
std::vector<size_t> workspace_size_list_{};
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_SOLVE_GPU_KERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,7 +17,14 @@
|
|||
#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CholeskyGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
CholeskySolver,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CholeskySolveGpuKernel, float)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
CholeskySolver,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CholeskySolveGpuKernel, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,8 +20,6 @@
|
|||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
@ -30,248 +28,150 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t kCholeskyInputsNum = 1;
|
||||
constexpr size_t kInputIndex = 0;
|
||||
constexpr size_t kCholeskyOutputsNum = 1;
|
||||
constexpr size_t kOutputIndex = 0;
|
||||
constexpr size_t kCholeskyDefaultShape = 1;
|
||||
constexpr size_t kCholeskyNormalShape = 2;
|
||||
constexpr size_t kCholeskyBatchedShape = 3;
|
||||
|
||||
template <typename T>
|
||||
class CholeskyGpuKernel : public GpuKernel {
|
||||
class CholeskySolveGpuKernel : public GpuKernel {
|
||||
public:
|
||||
CholeskyGpuKernel()
|
||||
: batch_(0),
|
||||
m_(0),
|
||||
lda_(0),
|
||||
ldb_(0),
|
||||
res_dim_(0),
|
||||
split_dim_(0),
|
||||
is_null_input_(false),
|
||||
use_split_matrix_(false),
|
||||
height_(0),
|
||||
width_(0),
|
||||
handle_(nullptr),
|
||||
blas_handle_(nullptr) {}
|
||||
~CholeskyGpuKernel() = default;
|
||||
using pointer = T *;
|
||||
|
||||
CholeskySolveGpuKernel() = default;
|
||||
~CholeskySolveGpuKernel() = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
auto input_a_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
auto input_b_addr = GetDeviceAddress<T>(inputs, kDim1);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||
auto d_a_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
|
||||
auto d_b_array_addr = GetDeviceAddress<pointer>(workspace, kDim1);
|
||||
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim2);
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
h_a_array_[i] = input_a_addr + i * lda_ * m_;
|
||||
h_b_array_[i] = input_b_addr + i * ldb_ * nrhs_;
|
||||
}
|
||||
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
auto d_array_addr = GetDeviceAddress<T *>(workspace, 0);
|
||||
auto d_identity_addr = GetDeviceAddress<T *>(workspace, 1);
|
||||
if (!use_split_matrix_) {
|
||||
auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
h_array_[i] = input1_addr + i * lda_ * m_;
|
||||
h_identity_[i] = output_addr + i * ldb_ * m_;
|
||||
CHECK_CUDA_RET_WITH_ERROR(
|
||||
kernel_node_,
|
||||
cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data_.data(), sizeof(T) * ldb_ * m_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_identity_addr, h_identity_.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
TriangleMatrixCopy(input1_addr, output_addr, uplo_, outputs[0]->size / sizeof(T), ldb_, m_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_a_array_addr, h_a_array_.data(), sizeof(pointer) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_b_array_addr, h_b_array_.data(), sizeof(pointer) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
// only support rhs = 1
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
|
||||
cusolverDnSpotrsBatched(handle_, uplo_, m_, nrhs_, d_a_array_addr, lda_,
|
||||
d_b_array_addr, ldb_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky solve batched Fail");
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
|
||||
cusolverDnDpotrsBatched(handle_, uplo_, m_, nrhs_, d_a_array_addr, lda_,
|
||||
d_b_array_addr, ldb_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky solve batched Fail");
|
||||
} else {
|
||||
auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
|
||||
auto d_batch_input_addr = GetDeviceAddress<T>(workspace, 3);
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
h_array_[i] = d_batch_input_addr + i * lda_ * m_;
|
||||
h_identity_[i] = output_addr + i * ldb_ * m_;
|
||||
}
|
||||
Eye(batch_ * split_dim_ * split_dim_, split_dim_, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
MatrixSplit(batch_ * split_dim_ * split_dim_, split_dim_, width_, input1_addr, d_batch_input_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_identity_addr, h_identity_.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
TriangleMatrixCopy(d_batch_input_addr, output_addr, uplo_, outputs[0]->size / sizeof(T), ldb_, m_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
MS_LOG(EXCEPTION) << "cholesky solve do not support other data type but only float or double, right now.";
|
||||
}
|
||||
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
|
||||
// copy results from written input's matrix to output's matrix.
|
||||
MatrixCopy(input_b_addr, output_addr, output_elements, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
|
||||
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
|
||||
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
is_null_input_ = CHECK_NULL_INPUT(in_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "For 'CholeskySolveGpuKernel', input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
split_dim_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "split_dim"));
|
||||
if (split_dim_ == 0) {
|
||||
if (!InitNoSpltDim(in_shape)) {
|
||||
return false;
|
||||
}
|
||||
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
|
||||
// gpu input is col major default, so need to change row major.
|
||||
// in order to speedup it, just change lower to upper, because of cholesky input a is triangle matrix
|
||||
// when input b_col is not equal to one, maybe need a normal transpose op inplace.
|
||||
if (lower_) {
|
||||
uplo_ = CUBLAS_FILL_MODE_UPPER;
|
||||
} else {
|
||||
if (!InitSpltDim(in_shape)) {
|
||||
return false;
|
||||
}
|
||||
uplo_ = CUBLAS_FILL_MODE_LOWER;
|
||||
}
|
||||
return true;
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
|
||||
auto in_a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kDim0);
|
||||
auto in_b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kDim1);
|
||||
if (CHECK_NULL_INPUT(in_a_shape) || CHECK_NULL_INPUT(in_b_shape)) {
|
||||
MS_LOG(EXCEPTION) << "For 'CholeskySolveGpuKernel', input is null";
|
||||
}
|
||||
return InitDim(in_a_shape, in_b_shape);
|
||||
}
|
||||
|
||||
protected:
|
||||
bool InitNoSpltDim(const std::vector<size_t> &in_shape) {
|
||||
use_split_matrix_ = false;
|
||||
if (in_shape.size() == 2) {
|
||||
void InitSizeLists() override {
|
||||
size_t input_size = batch_ * m_ * lda_ * unit_size_;
|
||||
input_size_list_.emplace_back(input_size);
|
||||
input_size = batch_ * nrhs_ * ldb_ * unit_size_;
|
||||
input_size_list_.emplace_back(input_size);
|
||||
|
||||
size_t workspace_size = batch_ * sizeof(pointer);
|
||||
workspace_size_list_.emplace_back(workspace_size);
|
||||
workspace_size_list_.emplace_back(workspace_size);
|
||||
workspace_size = batch_ * sizeof(int);
|
||||
workspace_size_list_.emplace_back(workspace_size);
|
||||
|
||||
size_t output_size = batch_ * m_ * unit_size_;
|
||||
output_size_list_.push_back(output_size);
|
||||
}
|
||||
|
||||
private:
|
||||
bool InitDim(const std::vector<size_t> &in_a_shape, const std::vector<size_t> &in_b_shape) {
|
||||
if (in_a_shape.size() == kCholeskyDefaultShape) {
|
||||
batch_ = 1;
|
||||
if (in_shape[0] != in_shape[1]) {
|
||||
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
||||
return false;
|
||||
}
|
||||
} else if (in_shape.size() == 3) {
|
||||
batch_ = SizeToInt(in_shape[0]);
|
||||
if (in_shape[1] != in_shape[2]) {
|
||||
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
||||
return false;
|
||||
}
|
||||
cho_row_ = in_a_shape.at(kDim0);
|
||||
cho_col_ = cho_row_;
|
||||
} else if (in_a_shape.size() == kCholeskyNormalShape) {
|
||||
batch_ = 1;
|
||||
cho_row_ = in_a_shape.at(kDim0);
|
||||
cho_col_ = in_a_shape.at(kDim1);
|
||||
} else if (in_a_shape.size() == kCholeskyBatchedShape) {
|
||||
batch_ = SizeToInt(in_a_shape.at(kDim0));
|
||||
cho_row_ = in_a_shape.at(kDim1);
|
||||
cho_col_ = in_a_shape.at(kDim2);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
|
||||
return false;
|
||||
}
|
||||
|
||||
m_ = SizeToInt(in_shape[1]);
|
||||
if (cho_row_ != cho_col_) {
|
||||
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
||||
return false;
|
||||
}
|
||||
size_t b_row = in_b_shape.size() == kCholeskyBatchedShape ? in_b_shape.at(kDim1) : in_b_shape.at(kDim0);
|
||||
if (cho_row_ != b_row) {
|
||||
MS_LOG(ERROR) << "Cholesky right hand matrix is not equal to left matrix.";
|
||||
return false;
|
||||
}
|
||||
m_ = SizeToInt(in_a_shape.at(kDim1));
|
||||
lda_ = m_;
|
||||
ldb_ = m_;
|
||||
h_array_.resize(batch_);
|
||||
h_identity_.resize(batch_);
|
||||
h_identity_data_.resize(m_ * m_);
|
||||
for (size_t i = 0; i < m_; i++) {
|
||||
for (size_t j = 0; j < m_; j++) {
|
||||
if (i == j) {
|
||||
h_identity_data_[i * m_ + j] = 1;
|
||||
} else {
|
||||
h_identity_data_[i * m_ + j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
h_a_array_.resize(batch_);
|
||||
h_b_array_.resize(batch_);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InitSpltDim(const std::vector<size_t> &in_shape) {
|
||||
if (in_shape.size() != 2) {
|
||||
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
|
||||
return false;
|
||||
}
|
||||
height_ = in_shape[0];
|
||||
width_ = in_shape[1];
|
||||
if (height_ != width_) {
|
||||
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
|
||||
return false;
|
||||
}
|
||||
if (SizeToInt(height_) <= split_dim_) {
|
||||
use_split_matrix_ = false;
|
||||
batch_ = 1;
|
||||
m_ = SizeToInt(in_shape[1]);
|
||||
lda_ = m_;
|
||||
ldb_ = m_;
|
||||
h_array_.resize(batch_);
|
||||
h_identity_.resize(batch_);
|
||||
h_identity_data_.resize(m_ * m_);
|
||||
for (size_t i = 0; i < m_; i++) {
|
||||
for (size_t j = 0; j < m_; j++) {
|
||||
if (i == j) {
|
||||
h_identity_data_[i * m_ + j] = 1;
|
||||
} else {
|
||||
h_identity_data_[i * m_ + j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
InitSizeLists();
|
||||
} else {
|
||||
use_split_matrix_ = true;
|
||||
int batch = SizeToInt(in_shape[1]) / split_dim_;
|
||||
res_dim_ = in_shape[1] - batch * split_dim_;
|
||||
if (res_dim_ == 0) {
|
||||
batch_ = batch;
|
||||
} else {
|
||||
batch_ = batch + 1;
|
||||
}
|
||||
m_ = split_dim_;
|
||||
lda_ = m_;
|
||||
ldb_ = m_;
|
||||
h_array_.resize(batch_);
|
||||
h_identity_.resize(batch_);
|
||||
h_identity_data_.resize(m_ * m_);
|
||||
for (size_t i = 0; i < m_; i++) {
|
||||
for (size_t j = 0; j < m_; j++) {
|
||||
if (i == j) {
|
||||
h_identity_data_[i * m_ + j] = 1;
|
||||
} else {
|
||||
h_identity_data_[i * m_ + j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
InitSizeLists();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
size_t unit_size = sizeof(T);
|
||||
size_t input_size;
|
||||
size_t workspace_size;
|
||||
if (!use_split_matrix_) {
|
||||
input_size = batch_ * m_ * lda_ * unit_size;
|
||||
} else {
|
||||
input_size = height_ * width_ * unit_size;
|
||||
workspace_size = batch_ * m_ * lda_ * unit_size;
|
||||
workspace_size_list_.push_back(workspace_size);
|
||||
}
|
||||
input_size_list_.push_back(input_size);
|
||||
size_t output_size = batch_ * m_ * lda_ * unit_size;
|
||||
output_size_list_.push_back(output_size);
|
||||
workspace_size = batch_ * sizeof(T *);
|
||||
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
|
||||
workspace_size = batch_ * sizeof(T *);
|
||||
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
|
||||
workspace_size = batch_ * sizeof(int);
|
||||
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t batch_;
|
||||
size_t m_;
|
||||
size_t lda_;
|
||||
size_t ldb_;
|
||||
int res_dim_;
|
||||
int split_dim_;
|
||||
bool is_null_input_;
|
||||
bool use_split_matrix_;
|
||||
size_t height_;
|
||||
size_t width_;
|
||||
cusolverDnHandle_t handle_;
|
||||
cublasHandle_t blas_handle_;
|
||||
size_t cho_row_{0};
|
||||
size_t cho_col_{0};
|
||||
size_t unit_size_{sizeof(T)};
|
||||
size_t nrhs_{1};
|
||||
size_t batch_{0};
|
||||
size_t m_{0};
|
||||
size_t lda_{0};
|
||||
size_t ldb_{0};
|
||||
cusolverDnHandle_t handle_{nullptr};
|
||||
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
|
||||
std::vector<T *> h_array_;
|
||||
std::vector<T *> h_identity_;
|
||||
std::vector<T> h_identity_data_;
|
||||
std::vector<pointer> h_a_array_;
|
||||
std::vector<pointer> h_b_array_;
|
||||
bool lower_{false};
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
|
|
@ -708,8 +708,14 @@ class _Cholesky(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, split_dim=0):
|
||||
def __init__(self, lower=False, clean=True, split_dim=0):
|
||||
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
|
||||
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
|
||||
self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
|
||||
self.lower = lower
|
||||
self.add_prim_attr('lower', self.lower)
|
||||
self.clean = clean
|
||||
self.add_prim_attr('clean', self.clean)
|
||||
self.split_dim = split_dim
|
||||
self.add_prim_attr('split_dim', self.split_dim)
|
||||
|
||||
|
@ -729,7 +735,7 @@ class _Cholesky(PrimitiveWithInfer):
|
|||
return out_shape
|
||||
|
||||
def infer_dtype(self, x1_dtype):
|
||||
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
|
||||
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32, mstype.float64], self.name)
|
||||
return x1_dtype
|
||||
|
||||
|
||||
|
|
|
@ -16,9 +16,11 @@
|
|||
from .. import numpy as mnp
|
||||
from .. import ops
|
||||
from .ops import SolveTriangular
|
||||
from .ops import CholeskySolver
|
||||
from .ops import Cholesky
|
||||
from ..ops import operations as P
|
||||
|
||||
__all__ = ['block_diag', 'solve_triangular', 'inv']
|
||||
__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve']
|
||||
|
||||
|
||||
def block_diag(*arrs):
|
||||
|
@ -191,3 +193,128 @@ def inv(a, overwrite_a=False, check_finite=True):
|
|||
"""
|
||||
matrix_inverse = P.MatrixInverse(adjoint=False)
|
||||
return matrix_inverse(a)
|
||||
|
||||
|
||||
def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
|
||||
"""
|
||||
Compute the Cholesky decomposition of a matrix, to use in cho_solve
|
||||
|
||||
Returns a matrix containing the Cholesky decomposition,
|
||||
``A = L L*`` or ``A = U* U`` of a Hermitian positive-definite matrix `a`.
|
||||
The return value can be directly used as the first parameter to cho_solve.
|
||||
|
||||
.. warning::
|
||||
The returned matrix also contains random data in the entries not
|
||||
used by the Cholesky decomposition. If you need to zero these
|
||||
entries, use the function `cholesky` instead.
|
||||
|
||||
Args:
|
||||
a (Tensor): square Matrix of (M, M) to be decomposed
|
||||
lower (bool, optional): Whether to compute the upper or lower triangular Cholesky factorization
|
||||
(Default: upper-triangular)
|
||||
overwrite_a(bool, optional): Whether to overwrite data in a (may improve performance)
|
||||
check_finite(bool, optional): Whether to check that the input matrix contains only finite numbers.
|
||||
Disabling may give a performance gain, but may result in problems
|
||||
(crashes, non-termination) if the inputs do contain infinities or NaNs.
|
||||
|
||||
Returns:
|
||||
c (Tensor): Matrix whose upper or lower triangle contains the Cholesky factor of `a`.
|
||||
Other parts of the matrix contain random data.
|
||||
lower (bool, optional): Flag indicating whether the factor is in the lower or upper triangle
|
||||
|
||||
Raises:
|
||||
LinAlgError: Raised if decomposition fails.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as onp
|
||||
>>> from mindspore.common import Tensor
|
||||
>>> from mindspore.scipy.linalg import cho_factor
|
||||
>>> A = Tensor(onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]]).astype(onp.float32))
|
||||
>>> c, low = cho_factor(A)
|
||||
>>> c
|
||||
[[ 2.9999998 0.99999994 0.3333333 1.6666665 ]
|
||||
[ 0. 2.4494896 1.9051585 -0.27216542]
|
||||
[ 0. 0. 2.2933078 0.8559527 ]
|
||||
[ 0. 0. 0. 1.5541859 ]]
|
||||
"""
|
||||
cholesky_net = Cholesky(lower=lower, clean=False)
|
||||
c = cholesky_net(a)
|
||||
return c, lower
|
||||
|
||||
|
||||
def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
|
||||
"""
|
||||
Compute the Cholesky decomposition of a matrix.
|
||||
|
||||
Returns the Cholesky decomposition, :math:`A = L L^*` or
|
||||
:math:`A = U^* U` of a Hermitian positive-definite matrix A.
|
||||
|
||||
Args:
|
||||
a (Tensor): square Matrix of (M, M) to be decomposed
|
||||
lower (bool, optional): Whether to compute the upper- or lower-triangular Cholesky
|
||||
factorization. Default is upper-triangular.
|
||||
overwrite_a (bool, optional): Whether to overwrite data in `a` (may improve performance).
|
||||
check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
|
||||
Disabling may give a performance gain, but may result in problems
|
||||
(crashes, non-termination) if the inputs do contain infinities or NaNs.
|
||||
|
||||
Returns:
|
||||
c (Tensor): Upper- or lower-triangular Cholesky factor of `a`.
|
||||
|
||||
Raises:
|
||||
LinAlgError: if decomposition fails.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as onp
|
||||
>>> from mindspore.common import Tensor
|
||||
>>> from mindspore.scipy.linalg import cholesky
|
||||
>>> a = Tensor(onp.array([[1, -2],[2, 5]]).astype(onp.float32))
|
||||
>>> L = cholesky(a, lower=True)
|
||||
>>> L
|
||||
[[1., 0.],
|
||||
[2., 1.]]
|
||||
"""
|
||||
cholesky_net = Cholesky(lower=lower, clean=True)
|
||||
c = cholesky_net(a)
|
||||
return c
|
||||
|
||||
|
||||
def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
|
||||
"""Solve the linear equations Ax = b, given the Cholesky factorization of A.
|
||||
|
||||
Args:
|
||||
c_and_lower ((Tensor, bool)): Cholesky factorization of a, as given by cho_factor
|
||||
b (Tensor): Right-hand side
|
||||
overwrite_b (bool, optional): Whether to overwrite data in b (may improve performance)
|
||||
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
|
||||
Disabling may give a performance gain, but may result in problems
|
||||
(crashes, non-termination) if the inputs do contain infinities or NaNs.
|
||||
|
||||
Returns:
|
||||
x (Tensor):
|
||||
The solution to the system A x = b
|
||||
|
||||
Supported Platforms:
|
||||
``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as onp
|
||||
>>> from mindspore.common import Tensor
|
||||
>>> from mindspore.scipy.linalg import cho_factor, cho_solve
|
||||
>>> A = Tensor(onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]]).astype(onp.float32))
|
||||
>>> b = Tensor(onp.array([1, 1, 1, 1]).astype(onp.float32))
|
||||
>>> c, low = cho_factor(A)
|
||||
>>> x = cho_solve((c, low), b)
|
||||
>>> x
|
||||
[-0.01749271, 0.11953353, 0.01166181, 0.1574344 ]
|
||||
"""
|
||||
(c, lower) = c_and_lower
|
||||
cholesky_solver_net = CholeskySolver(lower=lower)
|
||||
x = cholesky_solver_net(c, b)
|
||||
return x
|
||||
|
|
|
@ -96,3 +96,91 @@ class SolveTriangular(PrimitiveWithInfer):
|
|||
validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64],
|
||||
self.name, True)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class Cholesky(PrimitiveWithInfer):
|
||||
"""
|
||||
Inner API for _Cholesky base class.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lower=False, clean=True, split_dim=0):
|
||||
super().__init__("Cholesky")
|
||||
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
|
||||
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
|
||||
self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
|
||||
self.lower = lower
|
||||
self.add_prim_attr('lower', self.lower)
|
||||
self.clean = clean
|
||||
self.add_prim_attr('clean', self.clean)
|
||||
self.split_dim = split_dim
|
||||
self.add_prim_attr('split_dim', self.split_dim)
|
||||
|
||||
def infer_shape(self, x1_shape):
|
||||
if self.split_dim != 0:
|
||||
height = x1_shape[0]
|
||||
width = x1_shape[1]
|
||||
if height <= self.split_dim:
|
||||
out_shape = [1, height, width]
|
||||
else:
|
||||
batch = height // self.split_dim
|
||||
if height != batch * self.split_dim:
|
||||
batch += 1
|
||||
out_shape = [batch, self.split_dim, self.split_dim]
|
||||
else:
|
||||
out_shape = x1_shape
|
||||
return out_shape
|
||||
|
||||
def infer_dtype(self, x1_dtype):
|
||||
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32, mstype.float64], self.name)
|
||||
return x1_dtype
|
||||
|
||||
|
||||
class CholeskySolver(PrimitiveWithInfer):
|
||||
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lower : bool, optional
|
||||
Whether to compute the upper or lower triangular Cholesky factorization
|
||||
(Default: upper-triangular)
|
||||
b : array
|
||||
Right-hand side
|
||||
|
||||
Inputs:
|
||||
- **A** (Tensor) - A matrix of shape :math:`(M, M)` to be decomposed.
|
||||
- **b** (Tensor) - A tensor of shape :math:`(M,)` or :math:`(..., M)`.
|
||||
Right-hand side matrix in :math:`A x = b`.
|
||||
Returns
|
||||
-------
|
||||
x : array
|
||||
The solution to the system A x = b
|
||||
Supported Platforms:
|
||||
``CPU`` ``GPU``
|
||||
Examples:
|
||||
>>> import numpy as onp
|
||||
>>> from mindspore.common import Tensor
|
||||
>>> from mindspore.scipy.ops import CholeskySolver
|
||||
>>> from mindspore.scipy.linalg import cho_factor
|
||||
|
||||
>>> A = Tensor(onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]]))
|
||||
>>> b = Tensor(onp.array([1.0, 1.0, 1.0, 1.0], dtype=onp.dtype))
|
||||
>>> c, lower = cho_factor(A)
|
||||
>>> cholesky_solver = CholeskySolver(lower=lower)
|
||||
>>> x = cholesky_solver(c, b)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lower=False):
|
||||
super().__init__(name="CholeskySolver")
|
||||
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
|
||||
self.init_prim_io_names(inputs=['x', 'b'], outputs=['y'])
|
||||
|
||||
def __infer__(self, x, b):
|
||||
b_shape = b['shape']
|
||||
x_dtype = x['dtype']
|
||||
return {
|
||||
'shape': tuple(b_shape),
|
||||
'dtype': x_dtype,
|
||||
'value': None
|
||||
}
|
||||
|
|
|
@ -1,161 +1,43 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from typing import Generic
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import PrimitiveWithInfer
|
||||
from mindspore.ops import prim_attr_register
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore import Tensor
|
||||
import numpy as np
|
||||
from mindspore.scipy.linalg import cholesky
|
||||
import scipy as scp
|
||||
import numpy as np
|
||||
|
||||
import pytest
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
class Cholesky(PrimitiveWithInfer):
|
||||
"""
|
||||
Inner API for Cholesky base class.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lower=False, clean=True):
|
||||
super().__init__(name="Cholesky")
|
||||
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
|
||||
self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
def __infer__(self, x):
|
||||
x_shape = x['shape']
|
||||
x_dtype = x['dtype']
|
||||
return {
|
||||
'shape': tuple(x_shape),
|
||||
'dtype': x_dtype,
|
||||
'value': None
|
||||
}
|
||||
|
||||
|
||||
class CholeskySolver(PrimitiveWithInfer):
|
||||
"""
|
||||
Inner API for CholeskySolver class.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lower=False):
|
||||
super().__init__(name="CholeskySolver")
|
||||
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
def __infer__(self, x, b):
|
||||
b_shape = b['shape']
|
||||
x_dtype = x['dtype']
|
||||
return {
|
||||
'shape': tuple(b_shape),
|
||||
'dtype': x_dtype,
|
||||
'value': None
|
||||
}
|
||||
|
||||
|
||||
class CholeskyNet(nn.Cell):
|
||||
def __init__(self, lower=False, clean=False):
|
||||
super(CholeskyNet, self).__init__()
|
||||
self.cholesky = Cholesky(lower, clean)
|
||||
|
||||
def construct(self, x):
|
||||
return self.cholesky(x)
|
||||
|
||||
|
||||
class CholeskySolverNet(nn.Cell):
|
||||
def __init__(self, lower=False):
|
||||
super(CholeskySolverNet, self).__init__()
|
||||
self.cholesky_solver = CholeskySolver(lower)
|
||||
|
||||
def construct(self, c, b):
|
||||
return self.cholesky_solver(c, b)
|
||||
|
||||
|
||||
def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
|
||||
"""
|
||||
ompute the Cholesky decomposition of a matrix, to use in cho_solve.
|
||||
Returns a matrix containing the Cholesky decomposition
|
||||
"""
|
||||
cholesky_net = CholeskyNet(lower=lower, clean=False)
|
||||
c = cholesky_net(a)
|
||||
return c, lower
|
||||
|
||||
|
||||
def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
|
||||
"""
|
||||
Compute the Cholesky decomposition of a matrix.
|
||||
Returns the Cholesky decomposition
|
||||
"""
|
||||
cholesky_net = CholeskyNet(lower=lower, clean=True)
|
||||
c = cholesky_net(a)
|
||||
return c
|
||||
|
||||
|
||||
def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
|
||||
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
c_and_lower: (c, lower) tuple, (array, bool)
|
||||
Cholesky factorization of a, as given by cho_factor
|
||||
b : array
|
||||
Right-hand side
|
||||
overwrite_b : bool, optional
|
||||
Whether to overwrite data in b (may improve performance)
|
||||
check_finite : bool, optional
|
||||
Whether to check that the input matrices contain only finite numbers.
|
||||
Disabling may give a performance gain, but may result in problems
|
||||
(crashes, non-termination) if the inputs do contain infinities or NaNs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
x : array
|
||||
The solution to the system A x = b
|
||||
|
||||
See also
|
||||
--------
|
||||
cho_factor : Cholesky factorization of a matrix
|
||||
|
||||
"""
|
||||
(c, lower) = c_and_lower
|
||||
cholesky_solver_net = CholeskySolverNet(lower=lower)
|
||||
x = cholesky_solver_net(c, b)
|
||||
return x
|
||||
|
||||
|
||||
def test_cholesky():
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('lower', [True])
|
||||
@pytest.mark.parametrize('dtype', [np.float64])
|
||||
def test_cholesky(lower: bool, dtype: Generic):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for cholesky [N,N]
|
||||
Expectation: the result match scipy cholesky
|
||||
"""
|
||||
a = np.array([[4, 12, -6], [12, 37, -43], [-16, -43, 98]], dtype=np.float32)
|
||||
a = np.array([[4, 12, -6], [12, 37, -43], [-16, -43, 98]], dtype=dtype)
|
||||
tensor_a = Tensor(a)
|
||||
scp_c_1, _ = scp.linalg.cho_factor(a, lower=True)
|
||||
mscp_c_1, _ = cho_factor(tensor_a, lower=True)
|
||||
|
||||
scp_c_2 = scp.linalg.cholesky(a, lower=True)
|
||||
mscp_c_2 = cholesky(tensor_a, lower=True)
|
||||
assert np.allclose(scp_c_1, mscp_c_1.asnumpy())
|
||||
assert np.allclose(scp_c_2, mscp_c_2.asnumpy())
|
||||
|
||||
|
||||
def test_cholesky_solver():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for cholesky solver [N,N]
|
||||
Expectation: the result match scipy cholesky_solve
|
||||
"""
|
||||
a = np.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]], dtype=np.float32)
|
||||
b = np.array([1, 1, 1, 1], dtype=np.float32)
|
||||
tensor_a = Tensor(a)
|
||||
tensor_b = Tensor(b)
|
||||
scp_c, lower = scp.linalg.cho_factor(a, lower=False)
|
||||
mscp_c, mscp_lower = cho_factor(tensor_a, lower=False)
|
||||
scp_c = scp.linalg.cholesky(a, lower=lower)
|
||||
mscp_c = cholesky(tensor_a, lower=lower)
|
||||
assert np.allclose(scp_c, mscp_c.asnumpy())
|
||||
|
||||
scp_factor = (scp_c, lower)
|
||||
ms_cho_factor = (mscp_c, mscp_lower)
|
||||
scp_x = scp.linalg.cho_solve(scp_factor, b)
|
||||
mscp_x = cho_solve(ms_cho_factor, tensor_b)
|
||||
assert np.allclose(scp_x, mscp_x.asnumpy())
|
||||
|
|
|
@ -12,99 +12,30 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from typing import Generic
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.scipy.linalg import cholesky
|
||||
import numpy as np
|
||||
import scipy as scp
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import PrimitiveWithInfer
|
||||
from mindspore.ops import prim_attr_register
|
||||
from mindspore._checkparam import Validator as validator
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
class NetCholesky(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetCholesky, self).__init__()
|
||||
self.cholesky = P.Cholesky()
|
||||
|
||||
def construct(self, x):
|
||||
return self.cholesky(x)
|
||||
|
||||
|
||||
class ScipyCholesky(PrimitiveWithInfer):
|
||||
"""
|
||||
Inner API for Cholesky base class.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lower=False, clean=False):
|
||||
super().__init__(name="PureCholesky")
|
||||
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
|
||||
self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
def __infer__(self, x):
|
||||
x_shape = x['shape']
|
||||
x_dtype = x['dtype']
|
||||
return {
|
||||
'shape': tuple(x_shape),
|
||||
'dtype': x_dtype,
|
||||
'value': None
|
||||
}
|
||||
|
||||
|
||||
class ScipyNetCholesky(nn.Cell):
|
||||
def __init__(self, lower=False, clean=False):
|
||||
super(ScipyNetCholesky, self).__init__()
|
||||
self.cholesky = ScipyCholesky(lower, clean)
|
||||
|
||||
def construct(self, x):
|
||||
return self.cholesky(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cholesky_fp32():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for origin cholesky [N,N]
|
||||
Expectation: the result match np cholesky
|
||||
"""
|
||||
cholesky = NetCholesky()
|
||||
x = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
|
||||
output = cholesky(Tensor(x, dtype=mstype.float32))
|
||||
expect = np.linalg.cholesky(x)
|
||||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect) < tol).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scipy_cholesky_fp32():
|
||||
@pytest.mark.parametrize('lower', [True])
|
||||
@pytest.mark.parametrize('dtype', [np.float64])
|
||||
def test_scipy_cholesky(lower: bool, dtype: Generic):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for new scipy cholesky [N,N]
|
||||
Expectation: the result match scipy cholesky
|
||||
"""
|
||||
a = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
|
||||
a = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(dtype=dtype)
|
||||
tensor_a = Tensor(a)
|
||||
cholesky = ScipyNetCholesky(lower=True, clean=False)
|
||||
output = cholesky(tensor_a)
|
||||
|
||||
cholesky1 = ScipyNetCholesky(lower=False, clean=False)
|
||||
output1 = cholesky1(tensor_a)
|
||||
|
||||
expect = scp.linalg.cholesky(a, lower=True)
|
||||
expect1 = scp.linalg.cholesky(a, lower=False)
|
||||
|
||||
rtol = 1.e-4
|
||||
atol = 1.e-5
|
||||
assert np.allclose(expect, output.asnumpy(), rtol=rtol, atol=atol)
|
||||
assert np.allclose(expect1, output1.asnumpy(), rtol=rtol, atol=atol)
|
||||
output = cholesky(tensor_a, lower=lower)
|
||||
expect = scp.linalg.cholesky(a, lower=lower)
|
||||
assert np.allclose(expect, output.asnumpy())
|
||||
|
|
|
@ -14,13 +14,19 @@
|
|||
# ============================================================================
|
||||
"""st for scipy.linalg."""
|
||||
|
||||
from typing import Generic
|
||||
import pytest
|
||||
import numpy as onp
|
||||
import scipy as osp
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.scipy as msp
|
||||
from .utils import match_array, create_full_rank_matrix
|
||||
from mindspore import context, Tensor
|
||||
import mindspore.numpy as mnp
|
||||
from .utils import match_array, create_full_rank_matrix, create_sym_pos_matrix
|
||||
|
||||
onp.random.seed(0)
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -59,3 +65,77 @@ def test_inv(dtype, shape):
|
|||
ms_res = msp.linalg.inv(Tensor(x))
|
||||
scipy_res = onp.linalg.inv(x)
|
||||
match_array(ms_res.asnumpy(), scipy_res, error=3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [4, 5, 6])
|
||||
@pytest.mark.parametrize('lower', [True, False])
|
||||
@pytest.mark.parametrize('dtype', [onp.float64])
|
||||
def test_cholesky(n: int, lower: bool, dtype: Generic):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for cholesky [N,N]
|
||||
Expectation: the result match scipy cholesky
|
||||
"""
|
||||
a = create_sym_pos_matrix((n, n), dtype)
|
||||
tensor_a = Tensor(a)
|
||||
rtol = 1.e-5
|
||||
atol = 1.e-8
|
||||
osp_c = osp.linalg.cholesky(a, lower=lower)
|
||||
msp_c = msp.linalg.cholesky(tensor_a, lower=lower)
|
||||
assert onp.allclose(osp_c, msp_c.asnumpy(), rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [4, 5, 6])
|
||||
@pytest.mark.parametrize('lower', [True, False])
|
||||
@pytest.mark.parametrize('dtype', [onp.float64])
|
||||
def test_cho_factor(n: int, lower: bool, dtype: Generic):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for cholesky [N,N]
|
||||
Expectation: the result match scipy cholesky
|
||||
"""
|
||||
a = create_sym_pos_matrix((n, n), dtype)
|
||||
tensor_a = Tensor(a)
|
||||
msp_c, _ = msp.linalg.cho_factor(tensor_a, lower=lower)
|
||||
if lower:
|
||||
msp_reconstruct_a = mnp.dot(mnp.tril(msp_c), mnp.tril(msp_c).T)
|
||||
else:
|
||||
msp_reconstruct_a = mnp.dot(mnp.triu(msp_c).T, mnp.triu(msp_c))
|
||||
assert onp.allclose(a, msp_reconstruct_a.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [4, 5, 6])
|
||||
@pytest.mark.parametrize('lower', [True, False])
|
||||
@pytest.mark.parametrize('dtype', [onp.float64])
|
||||
def test_cholesky_solver(n: int, lower: bool, dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for cholesky solver [N,N]
|
||||
Expectation: the result match scipy cholesky_solve
|
||||
"""
|
||||
a = create_sym_pos_matrix((n, n), dtype)
|
||||
b = onp.ones((n, 1), dtype=dtype)
|
||||
tensor_a = Tensor(a)
|
||||
tensor_b = Tensor(b)
|
||||
osp_c, lower = osp.linalg.cho_factor(a, lower=lower)
|
||||
msp_c, msp_lower = msp.linalg.cho_factor(tensor_a, lower=lower)
|
||||
osp_factor = (osp_c, lower)
|
||||
|
||||
ms_cho_factor = (msp_c, msp_lower)
|
||||
osp_x = osp.linalg.cho_solve(osp_factor, b)
|
||||
msp_x = msp.linalg.cho_solve(ms_cho_factor, tensor_b)
|
||||
# pre tensor_a has been inplace.
|
||||
tensor_a = Tensor(a)
|
||||
assert onp.allclose(onp.dot(a, osp_x), mnp.dot(tensor_a, msp_x).asnumpy())
|
||||
|
|
Loading…
Reference in New Issue