add cholesky && lu factorization for gpu backend
This commit is contained in:
parent
5b87e64557
commit
a125654fbc
|
@ -16,8 +16,8 @@
|
|||
|
||||
#include "triangle_matrix_copy_impl.cuh"
|
||||
template <typename T>
|
||||
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo,
|
||||
const size_t count, const size_t ldb, const size_t m) {
|
||||
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count,
|
||||
const size_t ldb, const size_t m) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
size_t batchIdx = i / (ldb * m);
|
||||
size_t row = (i - batchIdx * ldb * m) / m;
|
||||
|
@ -42,8 +42,42 @@ __global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMo
|
|||
}
|
||||
|
||||
// Extracts the requested triangle of each factorized matrix into `output` and zero-fills the opposite triangle.
//
// `input` is treated as a sequence of matrices of `ldb * m` elements each, addressed row-major with row length `m`.
// The buffer is expected to have been factorized in place by cuSOLVER potrf, which addresses the same memory
// column-major. Consequently the triangle cuSOLVER wrote for `uplo` is the transpose of the triangle this kernel
// has to emit, so kept elements are always read from the mirrored position (col, row) of the same matrix.
//
// Launch layout: 1-D grid of 1-D blocks; any grid size is valid because each thread strides over `count`.
//
// input  : factorized matrices (read only).
// output : destination buffer with at least `count` elements; must not alias `input`.
// uplo   : CUBLAS_FILL_MODE_UPPER keeps col >= row, CUBLAS_FILL_MODE_LOWER keeps col <= row.
//          Any other value leaves `output` untouched.
// count  : total number of elements to process across all matrices.
// ldb, m : leading dimension and order of one matrix (the visible caller passes ldb == m).
template <typename T>
__global__ void ScipyTriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count,
                                              const size_t ldb, const size_t m) {
  const size_t matrix_size = ldb * m;
  // Widen before multiplying so the index arithmetic cannot wrap in 32 bits for very large inputs.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < count; i += stride) {
    // First element of the matrix that element i belongs to.
    const size_t batch_offset = (i / matrix_size) * matrix_size;
    const size_t local = i - batch_offset;
    const size_t row = local / m;
    const size_t col = local % m;
    // Mirrored element inside the SAME matrix. Without batch_offset every batch after the first would read
    // (and previously also write) inside batch 0.
    const size_t mirrored = batch_offset + col * m + row;

    if (uplo == CUBLAS_FILL_MODE_UPPER) {
      if (col < row) {
        output[i] = 0;
      } else {
        // The factor lives in the row-major lower triangle of `input`; the row-major upper triangle still
        // holds the unfactorized data, so copying input[i] here would return the original matrix entries.
        output[i] = input[mirrored];
      }
    } else if (uplo == CUBLAS_FILL_MODE_LOWER) {
      if (col > row) {
        output[i] = 0;
      } else {
        // The factor lives in the row-major upper triangle of `input`; read it transposed.
        output[i] = input[mirrored];
      }
    }
  }
}
|
||||
|
||||
// Element-wise copy of `count` values from `input` to `output`.
// Each thread starts at its flat global index and advances by the total number of launched threads,
// so the kernel covers `count` elements for any 1-D launch configuration.
template <typename T>
__global__ void MatrixCopyKernel(const T *input, T *output, const size_t count) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < count) {
    output[idx] = input[idx];
    idx += blockDim.x * gridDim.x;
  }
}
|
||||
|
||||
// Host-side launcher for TriangleMatrixCopyKernel.
// The grid and block sizes come from GET_BLOCKS(count) and GET_THREADS; the kernel is enqueued on `cuda_stream`
// with no dynamic shared memory. The call does not synchronize or check for launch errors.
template <typename T>
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
                        const size_t m, cudaStream_t cuda_stream) {
  const auto grid_size = GET_BLOCKS(count);
  const auto block_size = GET_THREADS;
  TriangleMatrixCopyKernel<<<grid_size, block_size, 0, cuda_stream>>>(input, output, uplo, count, ldb, m);
}
|
||||
|
@ -52,3 +86,29 @@ template void TriangleMatrixCopy<float>(const float *input, float *output, cubla
|
|||
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
||||
template void TriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
|
||||
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
||||
|
||||
// Host-side launcher for ScipyTriangleMatrixCopyKernel.
// The grid and block sizes come from GET_BLOCKS(count) and GET_THREADS; the kernel is enqueued on `cuda_stream`
// with no dynamic shared memory. The call does not synchronize or check for launch errors.
template <typename T>
void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
                             const size_t m, cudaStream_t cuda_stream) {
  const auto grid_size = GET_BLOCKS(count);
  const auto block_size = GET_THREADS;
  ScipyTriangleMatrixCopyKernel<<<grid_size, block_size, 0, cuda_stream>>>(input, output, uplo, count, ldb, m);
}
|
||||
|
||||
// Host-side launcher for MatrixCopyKernel: enqueues a device-side copy of `count` elements from `input`
// to `output` on `cuda_stream`. The call does not synchronize or check for launch errors.
template <typename T>
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  const auto grid_size = GET_BLOCKS(count);
  const auto block_size = GET_THREADS;
  MatrixCopyKernel<<<grid_size, block_size, 0, cuda_stream>>>(input, output, count);
}
|
||||
|
||||
template void ScipyTriangleMatrixCopy<float>(const float *input, float *output, cublasFillMode_t uplo,
|
||||
const size_t count, const size_t ldb, const size_t m,
|
||||
cudaStream_t cuda_stream);
|
||||
template void ScipyTriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
|
||||
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
||||
template void ScipyTriangleMatrixCopy<double>(const double *input, double *output, cublasFillMode_t uplo,
|
||||
const size_t count, const size_t ldb, const size_t m,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template void MatrixCopy<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void MatrixCopy<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void MatrixCopy<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -19,6 +19,13 @@
|
|||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo,
|
||||
const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
||||
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
|
||||
const size_t m, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
|
||||
const size_t m, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ScipyCholeskyGpuKernel, float)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ScipyCholeskyGpuKernel, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,181 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
#include "utils/convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t kCholeskyInputsNum = 1;
|
||||
constexpr size_t kInputIndex = 0;
|
||||
constexpr size_t kCholeskyOutputsNum = 1;
|
||||
constexpr size_t kOutputIndex = 0;
|
||||
constexpr size_t kCholeskyDefaultShape = 1;
|
||||
constexpr size_t kCholeskyNormalShape = 2;
|
||||
constexpr size_t kCholeskyBatchedShape = 3;
|
||||
|
||||
template <typename T>
|
||||
class ScipyCholeskyGpuKernel : public GpuKernel {
|
||||
public:
|
||||
ScipyCholeskyGpuKernel() = default;
|
||||
~ScipyCholeskyGpuKernel() = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
// here all addresses are malloc by cuda, so deal with them as device's address.
|
||||
auto input1_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||
|
||||
auto d_array_addr = GetDeviceAddress<T *>(workspace, kDim0);
|
||||
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
|
||||
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
h_array_[i] = input1_addr + i * lda_ * m_;
|
||||
}
|
||||
|
||||
// 5. copy input's addr to d_array_addr
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
|
||||
// 6. solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
|
||||
}
|
||||
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
|
||||
// 7. copy results from written input's matrix to output's matrix by up or lower flag.
|
||||
ScipyTriangleMatrixCopy(input1_addr, output_addr, uplo_, output_elements, ldb_, m_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, "lower"));
|
||||
if (lower_) {
|
||||
uplo_ = CUBLAS_FILL_MODE_LOWER;
|
||||
} else {
|
||||
uplo_ = CUBLAS_FILL_MODE_UPPER;
|
||||
}
|
||||
// 1. get CuSolver Dense matrix handler
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
|
||||
// 2. get Cublas handler
|
||||
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
|
||||
|
||||
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
||||
// 3. check input shape not null
|
||||
bool is_null_input = CHECK_NULL_INPUT(in_shape);
|
||||
if (is_null_input) {
|
||||
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input is null";
|
||||
}
|
||||
// 4. calculate input size
|
||||
if (!InitInputSize(in_shape)) {
|
||||
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input shape init failed.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
bool InitInputSize(const std::vector<size_t> &in_shape) {
|
||||
if (in_shape.size() == kCholeskyDefaultShape) {
|
||||
batch_ = 1;
|
||||
cho_row_ = in_shape.at(kDim0);
|
||||
cho_col_ = cho_row_;
|
||||
} else if (in_shape.size() == kCholeskyNormalShape) {
|
||||
batch_ = 1;
|
||||
cho_row_ = in_shape.at(kDim0);
|
||||
cho_col_ = in_shape.at(kDim1);
|
||||
} else if (in_shape.size() == kCholeskyBatchedShape) {
|
||||
batch_ = SizeToInt(in_shape.at(kDim0));
|
||||
cho_row_ = in_shape.at(kDim1);
|
||||
cho_col_ = in_shape.at(kDim2);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
|
||||
return false;
|
||||
}
|
||||
if (cho_row_ != cho_col_) {
|
||||
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
||||
return false;
|
||||
}
|
||||
// set matrix row or col to be lead dimension
|
||||
m_ = SizeToInt(cho_row_);
|
||||
lda_ = m_;
|
||||
ldb_ = m_;
|
||||
h_array_.resize(batch_);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
input_size_ = batch_ * m_ * lda_ * unit_size_;
|
||||
input_size_list_.push_back(input_size_);
|
||||
|
||||
output_size_ = batch_ * m_ * lda_ * unit_size_;
|
||||
output_size_list_.push_back(output_size_);
|
||||
|
||||
size_t workspace_size = batch_ * sizeof(T *);
|
||||
workspace_size_list_.resize(kDim2);
|
||||
workspace_size_list_[kDim0] = workspace_size;
|
||||
|
||||
workspace_size = batch_ * sizeof(int);
|
||||
workspace_size_list_[kDim1] = workspace_size;
|
||||
}
|
||||
|
||||
size_t unit_size_{sizeof(T)};
|
||||
size_t cho_row_{0};
|
||||
size_t cho_col_{0};
|
||||
size_t batch_{0};
|
||||
size_t m_{0};
|
||||
size_t lda_{0};
|
||||
size_t ldb_{0};
|
||||
size_t input_size_{0};
|
||||
size_t output_size_{0};
|
||||
cusolverDnHandle_t handle_{nullptr};
|
||||
cublasHandle_t blas_handle_{nullptr};
|
||||
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
|
||||
bool lower_{false};
|
||||
std::vector<T *> h_array_{};
|
||||
std::vector<size_t> input_size_list_{};
|
||||
std::vector<size_t> output_size_list_{};
|
||||
std::vector<size_t> workspace_size_list_{};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/lu_gpu_kernel.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(LU,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
LUGpuKernel, float)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(LU,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
LUGpuKernel, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,175 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Include guard for the LU kernel header. It must differ from the guard of cholesky_gpu_kernel.h; sharing the
// macro would make whichever of the two headers is included second expand to nothing.
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_GPU_KERNEL_H_
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t kLuInputsNum = 1;
|
||||
constexpr size_t kInputIndex = 0;
|
||||
constexpr size_t kLuOutputsNum = 1;
|
||||
constexpr size_t kOutputIndex = 0;
|
||||
constexpr size_t kLuDefaultShape = 1;
|
||||
constexpr size_t kLuNormalShape = 2;
|
||||
|
||||
template <typename T>
|
||||
class LUGpuKernel : public GpuKernel {
|
||||
public:
|
||||
LUGpuKernel() = default;
|
||||
~LUGpuKernel() = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
auto input_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||
int *piv_output_addr = nullptr;
|
||||
if (pivot_on_) {
|
||||
piv_output_addr = GetDeviceAddress<int>(outputs, kDim1);
|
||||
}
|
||||
|
||||
auto info_output_addr = GetDeviceAddress<int>(outputs, kDim2);
|
||||
|
||||
// 4. query working space of getrf
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
|
||||
cusolverDnSgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_),
|
||||
"cusolver query lu work size fail");
|
||||
|
||||
if (cudaMalloc(reinterpret_cast<void **>(&d_work_), unit_size_ * lwork_) != cudaSuccess) {
|
||||
MS_LOG(EXCEPTION) << "cusolver malloc work size fail";
|
||||
}
|
||||
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnSgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr),
|
||||
"cusolver lu fail");
|
||||
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
|
||||
cusolverDnDgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_),
|
||||
"cusolver query lu work size fail");
|
||||
// 5. malloc device working space of getrf
|
||||
|
||||
if (cudaMalloc(reinterpret_cast<void **>(&d_work_), unit_size_ * lwork_) != cudaSuccess) {
|
||||
MS_LOG(EXCEPTION) << "cusolver malloc work size fail";
|
||||
}
|
||||
|
||||
// 6. solve to lu factorization according to cuSolver api, outputs have been written to input's matrix.
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnDgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr),
|
||||
"cusolver lu fail");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
|
||||
}
|
||||
// 7. copy results from written input's matrix to output's matrix.
|
||||
// if (cudaMemcpy(output_addr, input_addr, lda_ * m_ * unit_size_, cudaMemcpyDeviceToDevice) != cudaSuccess) {
|
||||
// MS_LOG(EXCEPTION) << "memcpy lu output fail.";
|
||||
// }
|
||||
MatrixCopy(input_addr, output_addr, lda_ * m_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
if (d_work_) {
|
||||
cudaFree(d_work_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
// 1. get CuSolver Dense matrix handler
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
|
||||
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
// 2. check input shape not null
|
||||
bool is_null_input = CHECK_NULL_INPUT(in_shape);
|
||||
if (is_null_input) {
|
||||
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input is null";
|
||||
}
|
||||
// 3. calculate input size
|
||||
if (!InitInputSize(in_shape)) {
|
||||
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input shape init failed.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
bool InitInputSize(const std::vector<size_t> &in_shape) {
|
||||
if (in_shape.size() == kLuDefaultShape) {
|
||||
lu_row_ = in_shape.at(kDim0);
|
||||
lu_col_ = lu_row_;
|
||||
} else if (in_shape.size() == kLuNormalShape) {
|
||||
lu_row_ = in_shape.at(kDim0);
|
||||
lu_col_ = in_shape.at(kDim1);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Input Only support Rank 1 OR 2";
|
||||
return false;
|
||||
}
|
||||
if (lu_row_ != lu_col_) {
|
||||
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
||||
return false;
|
||||
}
|
||||
// set matrix row or col to be lead dimension
|
||||
m_ = SizeToInt(lu_row_);
|
||||
lda_ = m_;
|
||||
ldb_ = m_;
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
size_t input_size = lda_ * m_ * unit_size_;
|
||||
input_size_list_.push_back(input_size);
|
||||
|
||||
size_t output_size = lda_ * m_ * unit_size_;
|
||||
size_t output_piv_size = 0;
|
||||
size_t output_info_size = sizeof(int);
|
||||
if (pivot_on_) {
|
||||
output_piv_size = m_ * sizeof(int);
|
||||
}
|
||||
output_size_list_.resize(kDim3);
|
||||
output_size_list_[kDim0] = output_size;
|
||||
output_size_list_[kDim1] = output_piv_size;
|
||||
output_size_list_[kDim2] = output_info_size;
|
||||
}
|
||||
|
||||
size_t unit_size_{sizeof(T)};
|
||||
size_t lu_row_{0};
|
||||
size_t lu_col_{0};
|
||||
size_t m_{0};
|
||||
size_t lda_{0};
|
||||
size_t ldb_{0};
|
||||
int lwork_{0};
|
||||
bool pivot_on_{true};
|
||||
T *d_work_{nullptr};
|
||||
cusolverDnHandle_t handle_{nullptr};
|
||||
std::vector<size_t> input_size_list_{};
|
||||
std::vector<size_t> output_size_list_{};
|
||||
std::vector<size_t> workspace_size_list_{};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_SOLVE_GPU_KERNEL_H_
|
|
@ -12,17 +12,21 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import scipy as scp
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import PrimitiveWithInfer
|
||||
from mindspore.ops import prim_attr_register
|
||||
from mindspore._checkparam import Validator as validator
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
class NetCholesky(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetCholesky, self).__init__()
|
||||
|
@ -32,13 +36,75 @@ class NetCholesky(nn.Cell):
|
|||
return self.cholesky(x)
|
||||
|
||||
|
||||
class ScipyCholesky(PrimitiveWithInfer):
    """
    Inner API for the scipy-style Cholesky primitive used by the tests in this file.

    Args:
        lower (bool): Whether the lower triangular factor is requested. Default: False.
        clean (bool): Stored as a primitive attribute. Default: False.

    Inputs:
        - **x** - The matrix to factorize.

    Outputs:
        - **y** - Tensor with the same shape and dtype as `x`.
    """

    @prim_attr_register
    def __init__(self, lower=False, clean=False):
        # The primitive name has to match the name the GPU kernel is registered under (ScipyCholesky).
        super().__init__(name="ScipyCholesky")
        # The last argument of check_value_type is the primitive name used in the error message.
        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
        self.clean = validator.check_value_type("clean", clean, [bool], self.name)
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def __infer__(self, x):
        """Infer the output: same shape and dtype as the input, no constant value."""
        x_shape = x['shape']
        x_dtype = x['dtype']
        return {
            'shape': tuple(x_shape),
            'dtype': x_dtype,
            'value': None
        }
|
||||
|
||||
|
||||
class ScipyNetCholesky(nn.Cell):
    """Cell that applies the ScipyCholesky primitive to its single input."""

    def __init__(self, lower=False, clean=False):
        super(ScipyNetCholesky, self).__init__()
        # Both flags are forwarded positionally to the primitive.
        self.cholesky = ScipyCholesky(lower, clean)

    def construct(self, x):
        factor = self.cholesky(x)
        return factor
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cholesky_fp32():
    """
    Feature: ALL TO ALL
    Description: test cases for origin cholesky [N,N]
    Expectation: the result match np cholesky
    """
    matrix = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
    net = NetCholesky()
    actual = net(Tensor(matrix, dtype=mstype.float32)).asnumpy()
    reference = np.linalg.cholesky(matrix)
    tol = 1e-6
    # Every element has to agree with the numpy reference within the absolute tolerance.
    assert np.all(np.abs(actual - reference) < tol)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scipy_cholesky_fp32():
    """
    Feature: ALL TO ALL
    Description: test cases for new scipy cholesky [N,N]
    Expectation: the result match scipy cholesky
    """
    # Import the submodule explicitly: a bare `import scipy` is not guaranteed to expose `scipy.linalg`.
    from scipy.linalg import cholesky as scipy_cholesky

    a = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
    tensor_a = Tensor(a)
    rtol = 1.e-4
    atol = 1.e-5

    # Check the lower factor first, then the upper one, each against the scipy reference.
    for lower in (True, False):
        cholesky = ScipyNetCholesky(lower=lower, clean=False)
        output = cholesky(tensor_a)
        expect = scipy_cholesky(a, lower=lower)
        assert np.allclose(expect, output.asnumpy(), rtol=rtol, atol=atol)
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from typing import Generic
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.numpy as mnp
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import PrimitiveWithInfer
|
||||
from mindspore.ops import prim_attr_register
|
||||
import scipy as scp
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
|
||||
class LU(PrimitiveWithInfer):
    """
    LU decomposition with partial pivoting
    P.A = L.U
    """

    @prim_attr_register
    def __init__(self):
        super().__init__(name="LU")
        self.init_prim_io_names(inputs=['x'], outputs=['lu', 'pivots', 'permutation'])

    def __infer__(self, x):
        """Infer shapes and dtypes of the three outputs from the input description `x`."""
        lu_shape = list(x['shape'])
        rank = len(lu_shape)
        if rank == 0:
            # A rank-0 input passes its (empty) shape on to both integer outputs.
            pivots_shape = lu_shape
            permutation_shape = lu_shape
        else:
            # The third output keeps an empty shape for every input of rank >= 1.
            permutation_shape = []
            # Rank 1 drops the last dimension; higher ranks keep only the second-to-last dimension.
            pivots_shape = lu_shape[:-1] if rank == 1 else lu_shape[-2:-1]

        return {
            'shape': (lu_shape, pivots_shape, permutation_shape),
            'dtype': (x['dtype'], mstype.int32, mstype.int32),
            'value': None
        }
|
||||
|
||||
|
||||
class LuNet(nn.Cell):
    """Cell that applies the LU primitive to its single input and returns all of its outputs."""

    def __init__(self):
        super(LuNet, self).__init__()
        self.lu = LU()

    def construct(self, a):
        results = self.lu(a)
        return results
|
||||
|
||||
|
||||
@pytest.mark.platform_x86_gpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_lu_net(n: int, dtype: Generic):
    """
    Feature: ALL To ALL
    Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
    Expectation: the result match to scipy
    """
    # Import the submodule explicitly: a bare `import scipy` is not guaranteed to expose `scipy.linalg`.
    from scipy.linalg import lu_factor

    # A locally seeded generator keeps the input reproducible without touching numpy's global random state.
    rng = np.random.RandomState(n)
    a = (rng.random_sample((n, n)) + np.eye(n)).astype(dtype)
    expect, _ = lu_factor(a)
    mscp_lu_net = LuNet()
    # mindspore tensor is row major but gpu cusolver is col major, so we should transpose it.
    tensor_a = Tensor(a)
    tensor_a = mnp.transpose(tensor_a)
    output, _, _ = mscp_lu_net(tensor_a)
    # mindspore tensor is row major but gpu cusolver is col major, so we should transpose it.
    output = mnp.transpose(output)
    rtol = 1.e-4
    atol = 1.e-5
    assert np.allclose(expect, output.asnumpy(), rtol=rtol, atol=atol)
|
Loading…
Reference in New Issue