Add cholesky and cho_factor primitives and their GPU backend implementations

This commit is contained in:
z00512249 2021-11-15 14:21:32 +08:00
parent f537f1ae1a
commit 36032e7ee2
15 changed files with 661 additions and 577 deletions

View File

@ -37,3 +37,4 @@ void Eye(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_
}
template void Eye<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream);
template void Eye<double>(const size_t size, const size_t dim, double *output_addr, cudaStream_t cuda_stream);

View File

@ -68,3 +68,5 @@ void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T
template void MatrixSplit<float>(const size_t size, const size_t split_dim, const size_t dim, float *input_addr,
float *output_addr, cudaStream_t cuda_stream);
template void MatrixSplit<double>(const size_t size, const size_t split_dim, const size_t dim, double *input_addr,
double *output_addr, cudaStream_t cuda_stream);

View File

@ -18,47 +18,25 @@
template <typename T>
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
size_t col = (i - batchIdx * ldb * m) % m;
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the upper half of the matrix should be all 0;
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the lower half of the matrix should be all 0;
if (uplo == CUBLAS_FILL_MODE_UPPER) {
if (col > row) {
output[i] = 0;
} else {
output[i] = input[i];
}
} else if (uplo == CUBLAS_FILL_MODE_LOWER) {
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the upper half of the matrix should be all 0;
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the lower half of the matrix should be all 0;
// special case, only upper triangle data is correct, so copy up to lower, when lower case.
if (uplo == CUBLAS_FILL_MODE_UPPER) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
size_t col = (i - batchIdx * ldb * m) % m;
if (col < row) {
output[i] = 0;
} else {
output[i] = input[i];
}
}
}
}
template <typename T>
__global__ void ScipyTriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
size_t col = (i - batchIdx * ldb * m) % m;
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the upper half of the matrix should be all 0;
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the lower half of the matrix should be all 0;
// special case, only upper triangle data is correct, so copy up to lower, when lower case.
if (uplo == CUBLAS_FILL_MODE_UPPER) {
if (col < row) {
output[i] = 0;
} else {
output[i] = input[i];
}
} else if (uplo == CUBLAS_FILL_MODE_LOWER) {
} else {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
size_t col = (i - batchIdx * ldb * m) % m;
if (col > row) {
output[i] = 0;
} else {
@ -87,12 +65,8 @@ template void TriangleMatrixCopy<float>(const float *input, float *output, cubla
template void TriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template <typename T>
void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
const size_t m, cudaStream_t cuda_stream) {
ScipyTriangleMatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, uplo, count, ldb, m);
return;
}
template void TriangleMatrixCopy<double>(const double *input, double *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template <typename T>
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
@ -100,15 +74,6 @@ void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda
return;
}
template void ScipyTriangleMatrixCopy<float>(const float *input, float *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m,
cudaStream_t cuda_stream);
template void ScipyTriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template void ScipyTriangleMatrixCopy<double>(const double *input, double *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m,
cudaStream_t cuda_stream);
template void MatrixCopy<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void MatrixCopy<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void MatrixCopy<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);

View File

@ -22,10 +22,6 @@ template <typename T>
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
const size_t m, cudaStream_t cuda_stream);
template <typename T>
void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
const size_t m, cudaStream_t cuda_stream);
template <typename T>
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_

View File

@ -39,6 +39,12 @@ static constexpr char kAvgPoolingModeUpperCase[] = "AVG";
// Used by Pooling: lower-case spelling of the average-pooling mode string.
static constexpr char kAvgPoolingModeLowerCase[] = "avg";
// Used by the Cholesky and Cholesky-solve GPU kernels: name of the bool attribute that selects
// which triangle is used; the kernels translate it into a cublasFillMode_t.
static constexpr char kLower[] = "lower";
// Used by the Cholesky GPU kernel: name of the integer attribute holding the tile size for the
// split-matrix path; the kernel takes the non-split path when the attribute is 0.
static constexpr char kSplitDim[] = "split_dim";
// Used by MaxPool pad: the lowest finite float32 value (-FLT_MAX).
static constexpr float kSignedMinFloat = -3.402823466e+38F;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,10 +17,10 @@
#include "backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ScipyCholeskyGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskyGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ScipyCholeskyGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CholeskyGpuKernel, double)
} // namespace kernel
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,7 +20,8 @@
#include <cuda_runtime_api.h>
#include <vector>
#include <algorithm>
#include <type_traits>
#include "backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
@ -38,81 +39,53 @@ constexpr size_t kCholeskyNormalShape = 2;
constexpr size_t kCholeskyBatchedShape = 3;
template <typename T>
class ScipyCholeskyGpuKernel : public GpuKernel {
class CholeskyGpuKernel : public GpuKernel {
public:
ScipyCholeskyGpuKernel() = default;
~ScipyCholeskyGpuKernel() = default;
using pointer = T *;
CholeskyGpuKernel() = default;
~CholeskyGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
// here all addresses are malloc by cuda, so deal with them as device's address.
auto input1_addr = GetDeviceAddress<T>(inputs, kDim0);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
auto d_array_addr = GetDeviceAddress<T *>(workspace, kDim0);
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
for (size_t i = 0; i < batch_; i++) {
h_array_[i] = input1_addr + i * lda_ * m_;
if (!use_split_matrix_) {
return NoSplitLaunch(inputs, workspace, outputs, stream_ptr);
}
// 5. copy input's addr to d_array_addr
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
// 6. solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
} else {
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
}
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
// 7. copy results from written input's matrix to output's matrix by up or lower flag.
ScipyTriangleMatrixCopy(input1_addr, output_addr, uplo_, output_elements, ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
return SplitLaunch(inputs, workspace, outputs, stream_ptr);
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, "lower"));
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
split_dim_ = static_cast<int>(GetAttr<int64_t>(kernel_node, kSplitDim));
if (lower_) {
uplo_ = CUBLAS_FILL_MODE_LOWER;
} else {
uplo_ = CUBLAS_FILL_MODE_UPPER;
}
// 1. get CuSolver Dense matrix handler
// get CuSolver Dense matrix handler
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
// 2. get Cublas handler
// get Cublas handler
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex);
// 3. check input shape not null
bool is_null_input = CHECK_NULL_INPUT(in_shape);
if (is_null_input) {
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input is null";
is_null_input_ = CHECK_NULL_INPUT(in_shape);
if (is_null_input_) {
MS_LOG(EXCEPTION) << "For 'CholeskyGpuKernel', input is null";
}
// 4. calculate input size
if (!InitInputSize(in_shape)) {
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input shape init failed.";
if (split_dim_ == 0) {
return InitNoSplitDim(in_shape);
}
return true;
return InitSplitDim(in_shape);
}
private:
bool InitInputSize(const std::vector<size_t> &in_shape) {
protected:
bool InitNoSplitDim(const std::vector<size_t> &in_shape) {
if (in_shape.size() == kCholeskyDefaultShape) {
batch_ = 1;
cho_row_ = in_shape.at(kDim0);
@ -134,7 +107,46 @@ class ScipyCholeskyGpuKernel : public GpuKernel {
return false;
}
// set matrix row or col to be lead dimension
m_ = SizeToInt(cho_row_);
m_ = SizeToInt(in_shape.at(kDim1));
lda_ = m_;
ldb_ = m_;
h_array_.resize(batch_);
InitSizeLists();
return true;
}
bool InitSplitDim(const std::vector<size_t> &in_shape) {
if (in_shape.size() != kCholeskyNormalShape) {
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
return false;
}
cho_row_ = in_shape.at(kDim0);
cho_col_ = in_shape.at(kDim1);
if (cho_row_ != cho_col_) {
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
return false;
}
if (SizeToInt(cho_row_) <= split_dim_) {
use_split_matrix_ = false;
batch_ = 1;
m_ = SizeToInt(in_shape.at(kDim1));
lda_ = m_;
ldb_ = m_;
h_array_.resize(batch_);
InitSizeLists();
return true;
}
use_split_matrix_ = true;
size_t batch = cho_col_ / split_dim_;
res_dim_ = cho_col_ - batch * split_dim_;
if (res_dim_ == 0) {
batch_ = batch;
} else {
batch_ = batch + 1;
}
m_ = split_dim_;
lda_ = m_;
ldb_ = m_;
h_array_.resize(batch_);
@ -143,20 +155,99 @@ class ScipyCholeskyGpuKernel : public GpuKernel {
}
void InitSizeLists() override {
input_size_ = batch_ * m_ * lda_ * unit_size_;
input_size_list_.push_back(input_size_);
output_size_ = batch_ * m_ * lda_ * unit_size_;
output_size_list_.push_back(output_size_);
size_t workspace_size = batch_ * sizeof(T *);
workspace_size_list_.resize(kDim2);
workspace_size_list_[kDim0] = workspace_size;
size_t workspace_size = batch_ * sizeof(pointer);
workspace_size_list_.emplace_back(workspace_size);
workspace_size = batch_ * sizeof(int);
workspace_size_list_[kDim1] = workspace_size;
workspace_size_list_.emplace_back(workspace_size);
size_t input_size;
if (!use_split_matrix_) {
input_size = batch_ * m_ * lda_ * unit_size_;
} else {
input_size = cho_row_ * cho_col_ * unit_size_;
workspace_size = batch_ * m_ * lda_ * unit_size_;
workspace_size_list_.emplace_back(workspace_size);
}
input_size_list_.push_back(input_size);
size_t output_size = batch_ * m_ * lda_ * unit_size_;
output_size_list_.push_back(output_size);
}
bool NoSplitLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
// here all addresses are malloc by cuda, so deal with them as device's address.
auto input1_addr = GetDeviceAddress<T>(inputs, kDim0);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
auto d_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
for (size_t i = 0; i < batch_; i++) {
h_array_[i] = input1_addr + i * lda_ * m_;
}
// copy input's addr to d_array_addr
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
// solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
} else {
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
}
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
// copy results from written input's matrix to output's matrix by up or lower flag.
TriangleMatrixCopy(input1_addr, output_addr, uplo_, output_elements, ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool SplitLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto input1_addr = GetDeviceAddress<T>(inputs, kDim0);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
auto d_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
auto d_batch_input_addr = GetDeviceAddress<T>(workspace, kDim2);
for (size_t i = 0; i < batch_; i++) {
h_array_[i] = d_batch_input_addr + i * lda_ * m_;
}
Eye(batch_ * split_dim_ * split_dim_, split_dim_, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
MatrixSplit(batch_ * split_dim_ * split_dim_, split_dim_, cho_col_, input1_addr, d_batch_input_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
} else {
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
}
TriangleMatrixCopy(d_batch_input_addr, output_addr, uplo_, outputs[0]->size / sizeof(T), ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
private:
size_t unit_size_{sizeof(T)};
size_t cho_row_{0};
size_t cho_col_{0};
@ -164,18 +255,20 @@ class ScipyCholeskyGpuKernel : public GpuKernel {
size_t m_{0};
size_t lda_{0};
size_t ldb_{0};
size_t input_size_{0};
size_t output_size_{0};
int res_dim_{0};
int split_dim_{0};
bool is_null_input_{false};
bool use_split_matrix_{false};
cusolverDnHandle_t handle_{nullptr};
cublasHandle_t blas_handle_{nullptr};
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
std::vector<pointer> h_array_;
bool lower_{false};
std::vector<T *> h_array_{};
std::vector<size_t> input_size_list_{};
std::vector<size_t> output_size_list_{};
std::vector<size_t> workspace_size_list_{};
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_SOLVE_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,7 +17,14 @@
#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskyGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
CholeskySolver,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskySolveGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
CholeskySolver,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CholeskySolveGpuKernel, double)
} // namespace kernel
} // namespace mindspore

View File

@ -20,8 +20,6 @@
#include <cuda_runtime_api.h>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
@ -30,248 +28,150 @@
namespace mindspore {
namespace kernel {
constexpr size_t kCholeskyInputsNum = 1;
constexpr size_t kInputIndex = 0;
constexpr size_t kCholeskyOutputsNum = 1;
constexpr size_t kOutputIndex = 0;
constexpr size_t kCholeskyDefaultShape = 1;
constexpr size_t kCholeskyNormalShape = 2;
constexpr size_t kCholeskyBatchedShape = 3;
template <typename T>
class CholeskyGpuKernel : public GpuKernel {
class CholeskySolveGpuKernel : public GpuKernel {
public:
CholeskyGpuKernel()
: batch_(0),
m_(0),
lda_(0),
ldb_(0),
res_dim_(0),
split_dim_(0),
is_null_input_(false),
use_split_matrix_(false),
height_(0),
width_(0),
handle_(nullptr),
blas_handle_(nullptr) {}
~CholeskyGpuKernel() = default;
using pointer = T *;
CholeskySolveGpuKernel() = default;
~CholeskySolveGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
auto input_a_addr = GetDeviceAddress<T>(inputs, kDim0);
auto input_b_addr = GetDeviceAddress<T>(inputs, kDim1);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
auto d_a_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
auto d_b_array_addr = GetDeviceAddress<pointer>(workspace, kDim1);
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim2);
for (size_t i = 0; i < batch_; i++) {
h_a_array_[i] = input_a_addr + i * lda_ * m_;
h_b_array_[i] = input_b_addr + i * ldb_ * nrhs_;
}
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
auto output_addr = GetDeviceAddress<T>(outputs, 0);
auto d_array_addr = GetDeviceAddress<T *>(workspace, 0);
auto d_identity_addr = GetDeviceAddress<T *>(workspace, 1);
if (!use_split_matrix_) {
auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
for (size_t i = 0; i < batch_; i++) {
h_array_[i] = input1_addr + i * lda_ * m_;
h_identity_[i] = output_addr + i * ldb_ * m_;
CHECK_CUDA_RET_WITH_ERROR(
kernel_node_,
cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data_.data(), sizeof(T) * ldb_ * m_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
}
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_identity_addr, h_identity_.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
TriangleMatrixCopy(input1_addr, output_addr, uplo_, outputs[0]->size / sizeof(T), ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_a_array_addr, h_a_array_.data(), sizeof(pointer) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_b_array_addr, h_b_array_.data(), sizeof(pointer) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
// only support rhs = 1
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
cusolverDnSpotrsBatched(handle_, uplo_, m_, nrhs_, d_a_array_addr, lda_,
d_b_array_addr, ldb_, d_info_array_addr, batch_),
"cusolver cholesky solve batched Fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
cusolverDnDpotrsBatched(handle_, uplo_, m_, nrhs_, d_a_array_addr, lda_,
d_b_array_addr, ldb_, d_info_array_addr, batch_),
"cusolver cholesky solve batched Fail");
} else {
auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
auto d_batch_input_addr = GetDeviceAddress<T>(workspace, 3);
for (size_t i = 0; i < batch_; i++) {
h_array_[i] = d_batch_input_addr + i * lda_ * m_;
h_identity_[i] = output_addr + i * ldb_ * m_;
}
Eye(batch_ * split_dim_ * split_dim_, split_dim_, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
MatrixSplit(batch_ * split_dim_ * split_dim_, split_dim_, width_, input1_addr, d_batch_input_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_identity_addr, h_identity_.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
TriangleMatrixCopy(d_batch_input_addr, output_addr, uplo_, outputs[0]->size / sizeof(T), ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
MS_LOG(EXCEPTION) << "cholesky solve do not support other data type but only float or double, right now.";
}
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
// copy results from written input's matrix to output's matrix.
MatrixCopy(input_b_addr, output_addr, output_elements, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(in_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "For 'CholeskySolveGpuKernel', input is null";
InitSizeLists();
return true;
}
split_dim_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "split_dim"));
if (split_dim_ == 0) {
if (!InitNoSpltDim(in_shape)) {
return false;
}
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
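// cuSolver treats matrices as column-major, while the input tensors are laid out row-major.
// Because Cholesky input a is a triangular matrix, reading its row-major data as column-major
// yields the transpose, so instead of transposing the data we simply swap lower and upper.
// When b has more than one column, a regular (possibly in-place) transpose of b may also be needed.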
if (lower_) {
uplo_ = CUBLAS_FILL_MODE_UPPER;
} else {
if (!InitSpltDim(in_shape)) {
return false;
}
uplo_ = CUBLAS_FILL_MODE_LOWER;
}
return true;
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
auto in_a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kDim0);
auto in_b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kDim1);
if (CHECK_NULL_INPUT(in_a_shape) || CHECK_NULL_INPUT(in_b_shape)) {
MS_LOG(EXCEPTION) << "For 'CholeskySolveGpuKernel', input is null";
}
return InitDim(in_a_shape, in_b_shape);
}
protected:
bool InitNoSpltDim(const std::vector<size_t> &in_shape) {
use_split_matrix_ = false;
if (in_shape.size() == 2) {
void InitSizeLists() override {
size_t input_size = batch_ * m_ * lda_ * unit_size_;
input_size_list_.emplace_back(input_size);
input_size = batch_ * nrhs_ * ldb_ * unit_size_;
input_size_list_.emplace_back(input_size);
size_t workspace_size = batch_ * sizeof(pointer);
workspace_size_list_.emplace_back(workspace_size);
workspace_size_list_.emplace_back(workspace_size);
workspace_size = batch_ * sizeof(int);
workspace_size_list_.emplace_back(workspace_size);
size_t output_size = batch_ * m_ * unit_size_;
output_size_list_.push_back(output_size);
}
private:
bool InitDim(const std::vector<size_t> &in_a_shape, const std::vector<size_t> &in_b_shape) {
if (in_a_shape.size() == kCholeskyDefaultShape) {
batch_ = 1;
if (in_shape[0] != in_shape[1]) {
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
return false;
}
} else if (in_shape.size() == 3) {
batch_ = SizeToInt(in_shape[0]);
if (in_shape[1] != in_shape[2]) {
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
return false;
}
cho_row_ = in_a_shape.at(kDim0);
cho_col_ = cho_row_;
} else if (in_a_shape.size() == kCholeskyNormalShape) {
batch_ = 1;
cho_row_ = in_a_shape.at(kDim0);
cho_col_ = in_a_shape.at(kDim1);
} else if (in_a_shape.size() == kCholeskyBatchedShape) {
batch_ = SizeToInt(in_a_shape.at(kDim0));
cho_row_ = in_a_shape.at(kDim1);
cho_col_ = in_a_shape.at(kDim2);
} else {
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
return false;
}
m_ = SizeToInt(in_shape[1]);
if (cho_row_ != cho_col_) {
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
return false;
}
size_t b_row = in_b_shape.size() == kCholeskyBatchedShape ? in_b_shape.at(kDim1) : in_b_shape.at(kDim0);
if (cho_row_ != b_row) {
MS_LOG(ERROR) << "Cholesky right hand matrix is not equal to left matrix.";
return false;
}
m_ = SizeToInt(in_a_shape.at(kDim1));
lda_ = m_;
ldb_ = m_;
h_array_.resize(batch_);
h_identity_.resize(batch_);
h_identity_data_.resize(m_ * m_);
for (size_t i = 0; i < m_; i++) {
for (size_t j = 0; j < m_; j++) {
if (i == j) {
h_identity_data_[i * m_ + j] = 1;
} else {
h_identity_data_[i * m_ + j] = 0;
}
}
}
h_a_array_.resize(batch_);
h_b_array_.resize(batch_);
InitSizeLists();
return true;
}
bool InitSpltDim(const std::vector<size_t> &in_shape) {
if (in_shape.size() != 2) {
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
return false;
}
height_ = in_shape[0];
width_ = in_shape[1];
if (height_ != width_) {
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
return false;
}
if (SizeToInt(height_) <= split_dim_) {
use_split_matrix_ = false;
batch_ = 1;
m_ = SizeToInt(in_shape[1]);
lda_ = m_;
ldb_ = m_;
h_array_.resize(batch_);
h_identity_.resize(batch_);
h_identity_data_.resize(m_ * m_);
for (size_t i = 0; i < m_; i++) {
for (size_t j = 0; j < m_; j++) {
if (i == j) {
h_identity_data_[i * m_ + j] = 1;
} else {
h_identity_data_[i * m_ + j] = 0;
}
}
}
InitSizeLists();
} else {
use_split_matrix_ = true;
int batch = SizeToInt(in_shape[1]) / split_dim_;
res_dim_ = in_shape[1] - batch * split_dim_;
if (res_dim_ == 0) {
batch_ = batch;
} else {
batch_ = batch + 1;
}
m_ = split_dim_;
lda_ = m_;
ldb_ = m_;
h_array_.resize(batch_);
h_identity_.resize(batch_);
h_identity_data_.resize(m_ * m_);
for (size_t i = 0; i < m_; i++) {
for (size_t j = 0; j < m_; j++) {
if (i == j) {
h_identity_data_[i * m_ + j] = 1;
} else {
h_identity_data_[i * m_ + j] = 0;
}
}
}
InitSizeLists();
}
return true;
}
void InitSizeLists() override {
size_t unit_size = sizeof(T);
size_t input_size;
size_t workspace_size;
if (!use_split_matrix_) {
input_size = batch_ * m_ * lda_ * unit_size;
} else {
input_size = height_ * width_ * unit_size;
workspace_size = batch_ * m_ * lda_ * unit_size;
workspace_size_list_.push_back(workspace_size);
}
input_size_list_.push_back(input_size);
size_t output_size = batch_ * m_ * lda_ * unit_size;
output_size_list_.push_back(output_size);
workspace_size = batch_ * sizeof(T *);
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
workspace_size = batch_ * sizeof(T *);
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
workspace_size = batch_ * sizeof(int);
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
}
private:
size_t batch_;
size_t m_;
size_t lda_;
size_t ldb_;
int res_dim_;
int split_dim_;
bool is_null_input_;
bool use_split_matrix_;
size_t height_;
size_t width_;
cusolverDnHandle_t handle_;
cublasHandle_t blas_handle_;
size_t cho_row_{0};
size_t cho_col_{0};
size_t unit_size_{sizeof(T)};
size_t nrhs_{1};
size_t batch_{0};
size_t m_{0};
size_t lda_{0};
size_t ldb_{0};
cusolverDnHandle_t handle_{nullptr};
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
std::vector<T *> h_array_;
std::vector<T *> h_identity_;
std::vector<T> h_identity_data_;
std::vector<pointer> h_a_array_;
std::vector<pointer> h_b_array_;
bool lower_{false};
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

View File

@ -708,8 +708,14 @@ class _Cholesky(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, split_dim=0):
def __init__(self, lower=False, clean=True, split_dim=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
self.lower = lower
self.add_prim_attr('lower', self.lower)
self.clean = clean
self.add_prim_attr('clean', self.clean)
self.split_dim = split_dim
self.add_prim_attr('split_dim', self.split_dim)
@ -729,7 +735,7 @@ class _Cholesky(PrimitiveWithInfer):
return out_shape
def infer_dtype(self, x1_dtype):
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32, mstype.float64], self.name)
return x1_dtype

View File

@ -16,9 +16,11 @@
from .. import numpy as mnp
from .. import ops
from .ops import SolveTriangular
from .ops import CholeskySolver
from .ops import Cholesky
from ..ops import operations as P
__all__ = ['block_diag', 'solve_triangular', 'inv']
__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve']
def block_diag(*arrs):
@ -191,3 +193,128 @@ def inv(a, overwrite_a=False, check_finite=True):
"""
matrix_inverse = P.MatrixInverse(adjoint=False)
return matrix_inverse(a)
def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
    """
    Compute the Cholesky factorization of a matrix in the form expected by `cho_solve`.

    The result holds the factor of ``A = L L*`` (lower) or ``A = U* U`` (upper) of a Hermitian
    positive-definite matrix `a` and can be passed directly as the first element of the
    `c_and_lower` argument of `cho_solve`.

    .. warning::
        Entries outside the requested triangle are not cleaned and may hold arbitrary data.
        Use `cholesky` when those entries must be zero.

    Note:
        `overwrite_a` and `check_finite` are accepted but not used by this implementation.

    Args:
        a (Tensor): Square matrix of shape (M, M) to be decomposed.
        lower (bool, optional): Whether to compute the lower (True) or upper (False) triangular
            factorization. Default: False (upper-triangular).
        overwrite_a (bool, optional): Whether to overwrite data in `a` (may improve performance).
        check_finite (bool, optional): Whether to check that the input matrix contains only finite
            numbers. Disabling may give a performance gain, but may result in problems
            (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns:
        c (Tensor): Matrix whose upper or lower triangle contains the Cholesky factor of `a`.
            Other parts of the matrix contain arbitrary data.
        lower (bool): Flag indicating whether the factor is in the lower or upper triangle.

    Raises:
        LinAlgError: Raised if decomposition fails.

    Supported Platforms:
        ``CPU`` ``GPU``

    Examples:
        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.linalg import cho_factor
        >>> A = Tensor(onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]]).astype(onp.float32))
        >>> c, low = cho_factor(A)
        >>> c
        [[ 2.9999998   0.99999994  0.3333333   1.6666665 ]
         [ 0.          2.4494896   1.9051585  -0.27216542]
         [ 0.          0.          2.2933078   0.8559527 ]
         [ 0.          0.          0.          1.5541859 ]]
    """
    # clean=False: the unused triangle is left untouched by the primitive.
    factorize = Cholesky(lower=lower, clean=False)
    return factorize(a), lower
def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
    """
    Compute the Cholesky decomposition of a matrix.

    Returns the Cholesky decomposition, :math:`A = L L^*` or :math:`A = U^* U`, of a Hermitian
    positive-definite matrix A. The unused triangle of the result is zeroed (`clean=True`).

    Note:
        `overwrite_a` and `check_finite` are accepted but not used by this implementation.

    Args:
        a (Tensor): Square matrix of shape (M, M) to be decomposed.
        lower (bool, optional): Whether to compute the upper- or lower-triangular Cholesky
            factorization. Default: False (upper-triangular).
        overwrite_a (bool, optional): Whether to overwrite data in `a` (may improve performance).
        check_finite (bool, optional): Whether to check that the input matrix contains only finite
            numbers. Disabling may give a performance gain, but may result in problems
            (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns:
        c (Tensor): Upper- or lower-triangular Cholesky factor of `a`.

    Raises:
        LinAlgError: If decomposition fails.

    Supported Platforms:
        ``CPU`` ``GPU``

    Examples:
        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.linalg import cholesky
        >>> a = Tensor(onp.array([[1, 2], [2, 5]]).astype(onp.float32))
        >>> L = cholesky(a, lower=True)
        >>> L
        [[1., 0.],
         [2., 1.]]
    """
    # The example matrix above is symmetric positive-definite: [[1, 0], [2, 1]] @ [[1, 2], [0, 1]]
    # reproduces it exactly.
    cholesky_net = Cholesky(lower=lower, clean=True)
    return cholesky_net(a)
def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
    """
    Solve the linear equations A x = b, given the Cholesky factorization of A.

    Note:
        `overwrite_b` and `check_finite` are accepted but not used by this implementation.

    Args:
        c_and_lower ((Tensor, bool)): Cholesky factorization of a, as given by `cho_factor`.
        b (Tensor): Right-hand side.
        overwrite_b (bool, optional): Whether to overwrite data in `b` (may improve performance).
        check_finite (bool, optional): Whether to check that the input matrices contain only finite
            numbers. Disabling may give a performance gain, but may result in problems
            (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns:
        x (Tensor): The solution to the system A x = b.

    Supported Platforms:
        ``CPU`` ``GPU``

    Examples:
        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.linalg import cho_factor, cho_solve
        >>> A = Tensor(onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]]).astype(onp.float32))
        >>> b = Tensor(onp.array([1, 1, 1, 1]).astype(onp.float32))
        >>> c, low = cho_factor(A)
        >>> x = cho_solve((c, low), b)
        >>> x
        [-0.01749271,  0.11953353,  0.01166181,  0.1574344 ]
    """
    factor, is_lower = c_and_lower
    solver = CholeskySolver(lower=is_lower)
    return solver(factor, b)

View File

@ -96,3 +96,91 @@ class SolveTriangular(PrimitiveWithInfer):
validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64],
self.name, True)
return x_dtype
class Cholesky(PrimitiveWithInfer):
    """
    Inner API for _Cholesky base class.

    Args:
        lower (bool): Whether the lower (True) or upper (False) triangular factor is requested.
            Default: False.
        clean (bool): Stored as the `clean` primitive attribute. Default: True.
        split_dim (int): Block size used to split the input; 0 disables splitting. Default: 0.
    """

    @prim_attr_register
    def __init__(self, lower=False, clean=True, split_dim=0):
        super().__init__("Cholesky")
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        # The last argument of check_value_type is the primitive name used in error messages,
        # so it must be self.name rather than the attribute being validated.
        validator.check_value_type("lower", lower, [bool], self.name)
        validator.check_value_type("clean", clean, [bool], self.name)
        self.lower = lower
        self.add_prim_attr('lower', self.lower)
        self.clean = clean
        self.add_prim_attr('clean', self.clean)
        self.split_dim = split_dim
        self.add_prim_attr('split_dim', self.split_dim)

    def infer_shape(self, x1_shape):
        """Return the output shape: the input shape, or [batch, rows, cols] when split_dim is set."""
        if self.split_dim == 0:
            return x1_shape
        height = x1_shape[0]
        width = x1_shape[1]
        if height <= self.split_dim:
            # The matrix fits in one block: a single batch holding the full matrix.
            return [1, height, width]
        # Number of split_dim-sized blocks, rounded up when height is not a multiple.
        batch = height // self.split_dim
        if height != batch * self.split_dim:
            batch += 1
        return [batch, self.split_dim, self.split_dim]

    def infer_dtype(self, x1_dtype):
        """Accept float32/float64 input only and return the input dtype unchanged."""
        validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32, mstype.float64], self.name)
        return x1_dtype
class CholeskySolver(PrimitiveWithInfer):
    """
    Solve the linear equations A x = b, given the Cholesky factorization of A.

    Args:
        lower (bool, optional): Whether the supplied factor is the lower (True) or upper (False)
            triangular one, i.e. the flag returned by `cho_factor`. Default: False.

    Inputs:
        - **x** (Tensor) - Cholesky factor of a matrix A of shape :math:`(M, M)`, as produced by
          `cho_factor` in the example below.
        - **b** (Tensor) - A tensor of shape :math:`(M,)` or :math:`(..., M)`.
          Right-hand side matrix in :math:`A x = b`.

    Outputs:
        Tensor, the solution to the system A x = b. It has the shape of `b` and the dtype of `x`.

    Supported Platforms:
        ``CPU`` ``GPU``

    Examples:
        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.ops import CholeskySolver
        >>> from mindspore.scipy.linalg import cho_factor
        >>> A = Tensor(onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]]).astype(onp.float32))
        >>> b = Tensor(onp.array([1.0, 1.0, 1.0, 1.0], dtype=onp.float32))
        >>> c, lower = cho_factor(A)
        >>> cholesky_solver = CholeskySolver(lower=lower)
        >>> x = cholesky_solver(c, b)
    """

    @prim_attr_register
    def __init__(self, lower=False):
        super().__init__(name="CholeskySolver")
        # The last argument of check_value_type is the primitive name used in error messages.
        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'b'], outputs=['y'])

    def __infer__(self, x, b):
        """Infer the output description: shape taken from `b`, dtype taken from `x`."""
        b_shape = b['shape']
        x_dtype = x['dtype']
        return {
            'shape': tuple(b_shape),
            'dtype': x_dtype,
            'value': None
        }

View File

@ -1,161 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import Generic
import mindspore.context as context
import mindspore.nn as nn
from mindspore.ops import PrimitiveWithInfer
from mindspore.ops import prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore import Tensor
import numpy as np
from mindspore.scipy.linalg import cholesky
import scipy as scp
import numpy as np
import pytest
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
np.random.seed(0)
class Cholesky(PrimitiveWithInfer):
    """
    Inner API for Cholesky base class.

    Args:
        lower (bool): Whether the lower triangular factor is requested. Default: False.
        clean (bool): Stored as the `clean` attribute of the primitive. Default: True.
    """

    @prim_attr_register
    def __init__(self, lower=False, clean=True):
        super().__init__(name="Cholesky")
        # The last argument of check_value_type is the primitive name used in error messages,
        # so pass self.name rather than the attribute being validated.
        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
        self.clean = validator.check_value_type("clean", clean, [bool], self.name)
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def __infer__(self, x):
        """Infer the output description: same shape and dtype as the input."""
        x_shape = x['shape']
        x_dtype = x['dtype']
        return {
            'shape': tuple(x_shape),
            'dtype': x_dtype,
            'value': None
        }
class CholeskySolver(PrimitiveWithInfer):
    """
    Inner API for CholeskySolver class.

    Args:
        lower (bool): Whether the supplied factor is lower triangular. Default: False.
    """

    @prim_attr_register
    def __init__(self, lower=False):
        super().__init__(name="CholeskySolver")
        # The last argument of check_value_type is the primitive name used in error messages.
        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
        # The primitive is called with two inputs (see __infer__), so both are named.
        self.init_prim_io_names(inputs=['x', 'b'], outputs=['y'])

    def __infer__(self, x, b):
        """Infer the output description: shape taken from `b`, dtype taken from `x`."""
        b_shape = b['shape']
        x_dtype = x['dtype']
        return {
            'shape': tuple(b_shape),
            'dtype': x_dtype,
            'value': None
        }
class CholeskyNet(nn.Cell):
    """Cell wrapper that applies the Cholesky primitive to its single input."""

    def __init__(self, lower=False, clean=False):
        super().__init__()
        self.cholesky = Cholesky(lower, clean)

    def construct(self, x):
        factor = self.cholesky(x)
        return factor
class CholeskySolverNet(nn.Cell):
    """Cell wrapper that applies the CholeskySolver primitive to a factor and a right-hand side."""

    def __init__(self, lower=False):
        super().__init__()
        self.cholesky_solver = CholeskySolver(lower)

    def construct(self, c, b):
        solution = self.cholesky_solver(c, b)
        return solution
def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
    """
    Compute the Cholesky decomposition of a matrix, to use in cho_solve.

    Returns a tuple of the matrix containing the Cholesky decomposition and the `lower` flag.
    `overwrite_a` and `check_finite` are accepted but not used.
    """
    factorize = CholeskyNet(lower=lower, clean=False)
    return factorize(a), lower
def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
    """
    Compute the Cholesky decomposition of a matrix.

    Returns the Cholesky decomposition computed with `clean=True`.
    `overwrite_a` and `check_finite` are accepted but not used.
    """
    factorize = CholeskyNet(lower=lower, clean=True)
    return factorize(a)
def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
    """Solve the linear equations A x = b, given the Cholesky factorization of A.

    Parameters
    ----------
    c_and_lower : (c, lower) tuple, (array, bool)
        Cholesky factorization of a, as given by cho_factor
    b : array
        Right-hand side
    overwrite_b : bool, optional
        Whether to overwrite data in b (may improve performance); not used by this implementation
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Not used by this implementation.

    Returns
    -------
    x : array
        The solution to the system A x = b

    See also
    --------
    cho_factor : Cholesky factorization of a matrix
    """
    factor, is_lower = c_and_lower
    solver = CholeskySolverNet(lower=is_lower)
    return solver(factor, b)
def test_cholesky():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('lower', [True])
@pytest.mark.parametrize('dtype', [np.float64])
def test_cholesky(lower: bool, dtype: Generic):
"""
Feature: ALL TO ALL
Description: test cases for cholesky [N,N]
Expectation: the result match scipy cholesky
"""
a = np.array([[4, 12, -6], [12, 37, -43], [-16, -43, 98]], dtype=np.float32)
a = np.array([[4, 12, -6], [12, 37, -43], [-16, -43, 98]], dtype=dtype)
tensor_a = Tensor(a)
scp_c_1, _ = scp.linalg.cho_factor(a, lower=True)
mscp_c_1, _ = cho_factor(tensor_a, lower=True)
scp_c_2 = scp.linalg.cholesky(a, lower=True)
mscp_c_2 = cholesky(tensor_a, lower=True)
assert np.allclose(scp_c_1, mscp_c_1.asnumpy())
assert np.allclose(scp_c_2, mscp_c_2.asnumpy())
def test_cholesky_solver():
    """
    Feature: ALL TO ALL
    Description: test cases for cholesky solver [N,N]
    Expectation: the result match scipy cholesky_solve
    """
    matrix = np.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]], dtype=np.float32)
    rhs = np.array([1, 1, 1, 1], dtype=np.float32)
    ms_matrix = Tensor(matrix)
    ms_rhs = Tensor(rhs)

    # Factor with both libraries first; only the returned flags of this step are reused below.
    _, ref_lower = scp.linalg.cho_factor(matrix, lower=False)
    _, ms_lower = cho_factor(ms_matrix, lower=False)

    # The factors that are actually compared and solved with come from `cholesky`.
    ref_factor = scp.linalg.cholesky(matrix, lower=ref_lower)
    ms_factor = cholesky(ms_matrix, lower=ref_lower)
    assert np.allclose(ref_factor, ms_factor.asnumpy())

    ref_solution = scp.linalg.cho_solve((ref_factor, ref_lower), rhs)
    ms_solution = cho_solve((ms_factor, ms_lower), ms_rhs)
    assert np.allclose(ref_solution, ms_solution.asnumpy())

View File

@ -12,99 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import Generic
import mindspore.context as context
from mindspore import Tensor
from mindspore.scipy.linalg import cholesky
import numpy as np
import scipy as scp
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.ops import PrimitiveWithInfer
from mindspore.ops import prim_attr_register
from mindspore._checkparam import Validator as validator
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetCholesky(nn.Cell):
    """Cell wrapper around the P.Cholesky operator built with its default arguments."""

    def __init__(self):
        super().__init__()
        self.cholesky = P.Cholesky()

    def construct(self, x):
        result = self.cholesky(x)
        return result
class ScipyCholesky(PrimitiveWithInfer):
    """
    Inner API for Cholesky base class, registered under the primitive name "PureCholesky".

    Args:
        lower (bool): Whether the lower triangular factor is requested. Default: False.
        clean (bool): Stored as the `clean` attribute of the primitive. Default: False.
    """

    @prim_attr_register
    def __init__(self, lower=False, clean=False):
        super().__init__(name="PureCholesky")
        # The last argument of check_value_type is the primitive name used in error messages,
        # so pass self.name rather than the attribute being validated.
        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
        self.clean = validator.check_value_type("clean", clean, [bool], self.name)
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def __infer__(self, x):
        """Infer the output description: same shape and dtype as the input."""
        x_shape = x['shape']
        x_dtype = x['dtype']
        return {
            'shape': tuple(x_shape),
            'dtype': x_dtype,
            'value': None
        }
class ScipyNetCholesky(nn.Cell):
    """Cell wrapper that applies the ScipyCholesky primitive to its single input."""

    def __init__(self, lower=False, clean=False):
        super().__init__()
        self.cholesky = ScipyCholesky(lower, clean)

    def construct(self, x):
        factor = self.cholesky(x)
        return factor
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cholesky_fp32():
    """
    Feature: ALL TO ALL
    Description: test cases for origin cholesky [N,N]
    Expectation: the result match np cholesky
    """
    net = NetCholesky()
    matrix = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
    actual = net(Tensor(matrix, dtype=mstype.float32))
    reference = np.linalg.cholesky(matrix)
    # Element-wise absolute error must stay below the tolerance everywhere.
    tol = 1e-6
    abs_error = np.abs(actual.asnumpy() - reference)
    assert (abs_error < tol).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scipy_cholesky_fp32():
@pytest.mark.parametrize('lower', [True])
@pytest.mark.parametrize('dtype', [np.float64])
def test_scipy_cholesky(lower: bool, dtype: Generic):
"""
Feature: ALL TO ALL
Description: test cases for new scipy cholesky [N,N]
Expectation: the result match scipy cholesky
"""
a = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
a = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(dtype=dtype)
tensor_a = Tensor(a)
cholesky = ScipyNetCholesky(lower=True, clean=False)
output = cholesky(tensor_a)
cholesky1 = ScipyNetCholesky(lower=False, clean=False)
output1 = cholesky1(tensor_a)
expect = scp.linalg.cholesky(a, lower=True)
expect1 = scp.linalg.cholesky(a, lower=False)
rtol = 1.e-4
atol = 1.e-5
assert np.allclose(expect, output.asnumpy(), rtol=rtol, atol=atol)
assert np.allclose(expect1, output1.asnumpy(), rtol=rtol, atol=atol)
output = cholesky(tensor_a, lower=lower)
expect = scp.linalg.cholesky(a, lower=lower)
assert np.allclose(expect, output.asnumpy())

View File

@ -14,13 +14,19 @@
# ============================================================================
"""st for scipy.linalg."""
from typing import Generic
import pytest
import numpy as onp
import scipy as osp
from mindspore import Tensor
import mindspore.scipy as msp
from .utils import match_array, create_full_rank_matrix
from mindspore import context, Tensor
import mindspore.numpy as mnp
from .utils import match_array, create_full_rank_matrix, create_sym_pos_matrix
onp.random.seed(0)
context.set_context(mode=context.PYNATIVE_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -59,3 +65,77 @@ def test_inv(dtype, shape):
ms_res = msp.linalg.inv(Tensor(x))
scipy_res = onp.linalg.inv(x)
match_array(ms_res.asnumpy(), scipy_res, error=3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('lower', [True, False])
@pytest.mark.parametrize('dtype', [onp.float64])
def test_cholesky(n: int, lower: bool, dtype: Generic):
    """
    Feature: ALL TO ALL
    Description: test cases for cholesky [N,N]
    Expectation: the result match scipy cholesky
    """
    matrix = create_sym_pos_matrix((n, n), dtype)
    expected = osp.linalg.cholesky(matrix, lower=lower)
    actual = msp.linalg.cholesky(Tensor(matrix), lower=lower)
    # Tolerances are spelled out explicitly for the comparison.
    tolerances = {"rtol": 1.e-5, "atol": 1.e-8}
    assert onp.allclose(expected, actual.asnumpy(), **tolerances)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('lower', [True, False])
@pytest.mark.parametrize('dtype', [onp.float64])
def test_cho_factor(n: int, lower: bool, dtype: Generic):
    """
    Feature: ALL TO ALL
    Description: test cases for cho_factor [N,N]
    Expectation: the matrix rebuilt from the returned factor matches the input matrix
    """
    a = create_sym_pos_matrix((n, n), dtype)
    msp_c, _ = msp.linalg.cho_factor(Tensor(a), lower=lower)
    # Only the requested triangle of the result is meaningful, so mask the other one
    # before rebuilding the input as L L^T (lower) or U^T U (upper).
    if lower:
        tri = mnp.tril(msp_c)
        rebuilt = mnp.dot(tri, tri.T)
    else:
        tri = mnp.triu(msp_c)
        rebuilt = mnp.dot(tri.T, tri)
    assert onp.allclose(a, rebuilt.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('lower', [True, False])
@pytest.mark.parametrize('dtype', [onp.float64])
def test_cholesky_solver(n: int, lower: bool, dtype):
    """
    Feature: ALL TO ALL
    Description: test cases for cholesky solver [N,N]
    Expectation: the result match scipy cholesky_solve
    """
    a = create_sym_pos_matrix((n, n), dtype)
    rhs = onp.ones((n, 1), dtype=dtype)
    ms_a = Tensor(a)
    ms_rhs = Tensor(rhs)

    # The flag returned by SciPy is the one forwarded to the MindSpore factorization.
    osp_c, osp_lower = osp.linalg.cho_factor(a, lower=lower)
    msp_c, msp_lower = msp.linalg.cho_factor(ms_a, lower=osp_lower)

    osp_x = osp.linalg.cho_solve((osp_c, osp_lower), rhs)
    msp_x = msp.linalg.cho_solve((msp_c, msp_lower), ms_rhs)

    # Per the original note, the earlier tensor was modified in place, so build a
    # fresh tensor from `a` before checking A x on both sides.
    ms_a = Tensor(a)
    assert onp.allclose(onp.dot(a, osp_x), mnp.dot(ms_a, msp_x).asnumpy())