From a125654fbcb06a80e56682feca3afbcd8305795a Mon Sep 17 00:00:00 2001 From: z00512249 Date: Tue, 19 Oct 2021 16:01:34 +0800 Subject: [PATCH] add cholesky && lu factorization for gpu backend --- .../cuda_impl/triangle_matrix_copy_impl.cu | 68 ++++++- .../cuda_impl/triangle_matrix_copy_impl.cuh | 11 +- .../gpu/math/cholesky_gpu_kernel.cc | 26 +++ .../gpu/math/cholesky_gpu_kernel.h | 181 ++++++++++++++++++ .../kernel_compiler/gpu/math/lu_gpu_kernel.cc | 36 ++++ .../kernel_compiler/gpu/math/lu_gpu_kernel.h | 175 +++++++++++++++++ tests/st/ops/gpu/test_cholesky_op.py | 68 ++++++- tests/st/ops/gpu/test_lu_op.py | 94 +++++++++ 8 files changed, 652 insertions(+), 7 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_lu_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cu index e9b520f43d4..c91aa0fbca7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cu @@ -16,8 +16,8 @@ #include "triangle_matrix_copy_impl.cuh" template -__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, - const size_t count, const size_t ldb, const size_t m) { +__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count, + const size_t ldb, const size_t m) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { size_t batchIdx = i / (ldb * m); size_t row = (i - batchIdx * ldb * m) / m; @@ -42,8 +42,42 @@ __global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMo } template -void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, - const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream) { +__global__ void ScipyTriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count, + const size_t ldb, const size_t m) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + size_t batchIdx = i / (ldb * m); + size_t row = (i - batchIdx * ldb * m) / m; + size_t col = (i - batchIdx * ldb * m) % m; + + // If fill mode is 'CUBLAS_FILL_MODE_LOWER', the upper half of the matrix should be all 0; + // If fill mode is 'CUBLAS_FILL_MODE_UPPER', the lower half of the matrix should be all 0; + // special case, only upper triangle data is correct, so copy up to lower, when lower case. + if (uplo == CUBLAS_FILL_MODE_UPPER) { + if (col < row) { + output[i] = 0; + } else { + output[i] = input[i]; + } + } else if (uplo == CUBLAS_FILL_MODE_LOWER) { + if (col > row) { + output[i] = 0; + } else { + output[row * m + col] = input[col * m + row]; + } + } + } +} + +template +__global__ void MatrixCopyKernel(const T *input, T *output, const size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = input[i]; + } +} + +template +void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb, + const size_t m, cudaStream_t cuda_stream) { TriangleMatrixCopyKernel<<>>(input, output, uplo, count, ldb, m); return; } @@ -52,3 +86,29 @@ template void TriangleMatrixCopy(const float *input, float *output, cubla const size_t ldb, const size_t m, cudaStream_t cuda_stream); template void TriangleMatrixCopy(const half *input, half *output, cublasFillMode_t uplo, const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream); + +template +void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb, + const size_t m, cudaStream_t cuda_stream) { + ScipyTriangleMatrixCopyKernel<<>>(input, output, uplo, count, ldb, m); + return; +} + +template +void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) { + MatrixCopyKernel<<>>(input, output, count); + return; +} + +template void ScipyTriangleMatrixCopy(const float *input, float *output, cublasFillMode_t uplo, + const size_t count, const size_t ldb, const size_t m, + cudaStream_t cuda_stream); +template void ScipyTriangleMatrixCopy(const half *input, half *output, cublasFillMode_t uplo, const size_t count, + const size_t ldb, const size_t m, cudaStream_t cuda_stream); +template void ScipyTriangleMatrixCopy(const double *input, double *output, cublasFillMode_t uplo, + const size_t count, const size_t ldb, const size_t m, + cudaStream_t cuda_stream); + +template void MatrixCopy(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); +template void MatrixCopy(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); +template void MatrixCopy(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh index 98218a30d97..8758c5d79f1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh @@ -19,6 +19,13 @@ #include "runtime/device/gpu/cuda_common.h" template -void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, - const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream); +void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb, + const size_t m, cudaStream_t cuda_stream); + +template +void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb, + const size_t m, cudaStream_t cuda_stream); + +template +void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.cc new file mode 100644 index 00000000000..9cd900b8fcb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ScipyCholeskyGpuKernel, float) + +MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + ScipyCholeskyGpuKernel, double) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.h new file mode 100644 index 00000000000..8c1f7459c3a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.h @@ -0,0 +1,181 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kCholeskyInputsNum = 1; +constexpr size_t kInputIndex = 0; +constexpr size_t kCholeskyOutputsNum = 1; +constexpr size_t kOutputIndex = 0; +constexpr size_t kCholeskyDefaultShape = 1; +constexpr size_t kCholeskyNormalShape = 2; +constexpr size_t kCholeskyBatchedShape = 3; + +template +class ScipyCholeskyGpuKernel : public GpuKernel { + public: + ScipyCholeskyGpuKernel() = default; + ~ScipyCholeskyGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + // here all addresses are malloc by cuda, so deal with them as device's address. + auto input1_addr = GetDeviceAddress(inputs, kDim0); + auto output_addr = GetDeviceAddress(outputs, kDim0); + + auto d_array_addr = GetDeviceAddress(workspace, kDim0); + auto d_info_array_addr = GetDeviceAddress(workspace, kDim1); + + for (size_t i = 0; i < batch_; i++) { + h_array_[i] = input1_addr + i * lda_ * m_; + } + + // 5. copy input's addr to d_array_addr + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + + // 6. solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix. + if constexpr (std::is_same_v) { + CHECK_CUSOLVER_RET_WITH_EXCEPT( + kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + } else if constexpr (std::is_same_v) { + CHECK_CUSOLVER_RET_WITH_EXCEPT( + kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + } else { + MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now."; + } + size_t output_elements = outputs.at(kDim0)->size / unit_size_; + // 7. copy results from written input's matrix to output's matrix by up or lower flag. + ScipyTriangleMatrixCopy(input1_addr, output_addr, uplo_, output_elements, ldb_, m_, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + lower_ = static_cast(GetAttr(kernel_node, "lower")); + if (lower_) { + uplo_ = CUBLAS_FILL_MODE_LOWER; + } else { + uplo_ = CUBLAS_FILL_MODE_UPPER; + } + // 1. get CuSolver Dense matrix handler + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle(); + // 2. get Cublas handler + blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + + // 3. check input shape not null + bool is_null_input = CHECK_NULL_INPUT(in_shape); + if (is_null_input) { + MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input is null"; + } + // 4. calculate input size + if (!InitInputSize(in_shape)) { + MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input shape init failed."; + } + return true; + } + + private: + bool InitInputSize(const std::vector &in_shape) { + if (in_shape.size() == kCholeskyDefaultShape) { + batch_ = 1; + cho_row_ = in_shape.at(kDim0); + cho_col_ = cho_row_; + } else if (in_shape.size() == kCholeskyNormalShape) { + batch_ = 1; + cho_row_ = in_shape.at(kDim0); + cho_col_ = in_shape.at(kDim1); + } else if (in_shape.size() == kCholeskyBatchedShape) { + batch_ = SizeToInt(in_shape.at(kDim0)); + cho_row_ = in_shape.at(kDim1); + cho_col_ = in_shape.at(kDim2); + } else { + MS_LOG(ERROR) << "Input Only support Rank 2 OR 3"; + return false; + } + if (cho_row_ != cho_col_) { + MS_LOG(ERROR) << "Cholesky need square matrix as input."; + return false; + } + // set matrix row or col to be lead dimension + m_ = SizeToInt(cho_row_); + lda_ = m_; + ldb_ = m_; + h_array_.resize(batch_); + InitSizeLists(); + return true; + } + + void InitSizeLists() override { + input_size_ = batch_ * m_ * lda_ * unit_size_; + input_size_list_.push_back(input_size_); + + output_size_ = batch_ * m_ * lda_ * unit_size_; + output_size_list_.push_back(output_size_); + + size_t workspace_size = batch_ * sizeof(T *); + workspace_size_list_.resize(kDim2); + workspace_size_list_[kDim0] = workspace_size; + + workspace_size = batch_ * sizeof(int); + workspace_size_list_[kDim1] = workspace_size; + } + + size_t unit_size_{sizeof(T)}; + size_t cho_row_{0}; + size_t cho_col_{0}; + size_t batch_{0}; + size_t m_{0}; + size_t lda_{0}; + size_t ldb_{0}; + size_t input_size_{0}; + size_t output_size_{0}; + cusolverDnHandle_t handle_{nullptr}; + cublasHandle_t blas_handle_{nullptr}; + cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER; + bool lower_{false}; + std::vector h_array_{}; + std::vector input_size_list_{}; + std::vector output_size_list_{}; + std::vector workspace_size_list_{}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_SOLVE_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.cc new file mode 100644 index 00000000000..fe100558dcd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/lu_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LU, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + LUGpuKernel, float) + +MS_REG_GPU_KERNEL_ONE(LU, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + LUGpuKernel, double) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h new file mode 100644 index 00000000000..70bd8696b14 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h @@ -0,0 +1,175 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "utils/convert_utils.h" +#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh" + +namespace mindspore { +namespace kernel { +constexpr size_t kLuInputsNum = 1; +constexpr size_t kInputIndex = 0; +constexpr size_t kLuOutputsNum = 1; +constexpr size_t kOutputIndex = 0; +constexpr size_t kLuDefaultShape = 1; +constexpr size_t kLuNormalShape = 2; + +template +class LUGpuKernel : public GpuKernel { + public: + LUGpuKernel() = default; + ~LUGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + auto input_addr = GetDeviceAddress(inputs, kDim0); + auto output_addr = GetDeviceAddress(outputs, kDim0); + int *piv_output_addr = nullptr; + if (pivot_on_) { + piv_output_addr = GetDeviceAddress(outputs, kDim1); + } + + auto info_output_addr = GetDeviceAddress(outputs, kDim2); + + // 4. query working space of getrf + if constexpr (std::is_same_v) { + CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_, + cusolverDnSgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_), + "cusolver query lu work size fail"); + + if (cudaMalloc(reinterpret_cast(&d_work_), unit_size_ * lwork_) != cudaSuccess) { + MS_LOG(EXCEPTION) << "cusolver malloc work size fail"; + } + + CHECK_CUSOLVER_RET_WITH_EXCEPT( + kernel_node_, cusolverDnSgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr), + "cusolver lu fail"); + + } else if constexpr (std::is_same_v) { + CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_, + cusolverDnDgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_), + "cusolver query lu work size fail"); + // 5. malloc device working space of getrf + + if (cudaMalloc(reinterpret_cast(&d_work_), unit_size_ * lwork_) != cudaSuccess) { + MS_LOG(EXCEPTION) << "cusolver malloc work size fail"; + } + + // 6. solve to lu factorization according to cuSolver api, outputs have been written to input's matrix. + CHECK_CUSOLVER_RET_WITH_EXCEPT( + kernel_node_, cusolverDnDgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr), + "cusolver lu fail"); + } else { + MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now."; + } + // 7. copy results from written input's matrix to output's matrix. + // if (cudaMemcpy(output_addr, input_addr, lda_ * m_ * unit_size_, cudaMemcpyDeviceToDevice) != cudaSuccess) { + // MS_LOG(EXCEPTION) << "memcpy lu output fail."; + // } + MatrixCopy(input_addr, output_addr, lda_ * m_, reinterpret_cast(stream_ptr)); + if (d_work_) { + cudaFree(d_work_); + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + // 1. get CuSolver Dense matrix handler + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle(); + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + // 2. check input shape not null + bool is_null_input = CHECK_NULL_INPUT(in_shape); + if (is_null_input) { + MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input is null"; + } + // 3. calculate input size + if (!InitInputSize(in_shape)) { + MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input shape init failed."; + } + return true; + } + + private: + bool InitInputSize(const std::vector &in_shape) { + if (in_shape.size() == kLuDefaultShape) { + lu_row_ = in_shape.at(kDim0); + lu_col_ = lu_row_; + } else if (in_shape.size() == kLuNormalShape) { + lu_row_ = in_shape.at(kDim0); + lu_col_ = in_shape.at(kDim1); + } else { + MS_LOG(ERROR) << "Input Only support Rank 1 OR 2"; + return false; + } + if (lu_row_ != lu_col_) { + MS_LOG(ERROR) << "Cholesky need square matrix as input."; + return false; + } + // set matrix row or col to be lead dimension + m_ = SizeToInt(lu_row_); + lda_ = m_; + ldb_ = m_; + InitSizeLists(); + return true; + } + + void InitSizeLists() override { + size_t input_size = lda_ * m_ * unit_size_; + input_size_list_.push_back(input_size); + + size_t output_size = lda_ * m_ * unit_size_; + size_t output_piv_size = 0; + size_t output_info_size = sizeof(int); + if (pivot_on_) { + output_piv_size = m_ * sizeof(int); + } + output_size_list_.resize(kDim3); + output_size_list_[kDim0] = output_size; + output_size_list_[kDim1] = output_piv_size; + output_size_list_[kDim2] = output_info_size; + } + + size_t unit_size_{sizeof(T)}; + size_t lu_row_{0}; + size_t lu_col_{0}; + size_t m_{0}; + size_t lda_{0}; + size_t ldb_{0}; + int lwork_{0}; + bool pivot_on_{true}; + T *d_work_{nullptr}; + cusolverDnHandle_t handle_{nullptr}; + std::vector input_size_list_{}; + std::vector output_size_list_{}; + std::vector workspace_size_list_{}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_SOLVE_GPU_KERNEL_H_ diff --git a/tests/st/ops/gpu/test_cholesky_op.py b/tests/st/ops/gpu/test_cholesky_op.py index 34965cf5564..1ccf555de95 100644 --- a/tests/st/ops/gpu/test_cholesky_op.py +++ b/tests/st/ops/gpu/test_cholesky_op.py @@ -12,17 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - import numpy as np +import scipy as scp import pytest import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P from mindspore.common import dtype as mstype +from mindspore.ops import PrimitiveWithInfer +from mindspore.ops import prim_attr_register +from mindspore._checkparam import Validator as validator context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + class NetCholesky(nn.Cell): def __init__(self): super(NetCholesky, self).__init__() @@ -32,13 +36,75 @@ class NetCholesky(nn.Cell): return self.cholesky(x) +class ScipyCholesky(PrimitiveWithInfer): + """ + Inner API for Cholesky base class. + """ + + @prim_attr_register + def __init__(self, lower=False, clean=False): + super().__init__(name="PureCholesky") + self.lower = validator.check_value_type("lower", lower, [bool], self.lower) + self.clean = validator.check_value_type("clean", clean, [bool], self.clean) + self.init_prim_io_names(inputs=['x'], outputs=['y']) + + def __infer__(self, x): + x_shape = x['shape'] + x_dtype = x['dtype'] + return { + 'shape': tuple(x_shape), + 'dtype': x_dtype, + 'value': None + } + + +class ScipyNetCholesky(nn.Cell): + def __init__(self, lower=False, clean=False): + super(ScipyNetCholesky, self).__init__() + self.cholesky = ScipyCholesky(lower, clean) + + def construct(self, x): + return self.cholesky(x) + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_cholesky_fp32(): + """ + Feature: ALL TO ALL + Description: test cases for origin cholesky [N,N] + Expectation: the result match np cholesky + """ cholesky = NetCholesky() x = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32) output = cholesky(Tensor(x, dtype=mstype.float32)) expect = np.linalg.cholesky(x) tol = 1e-6 assert (np.abs(output.asnumpy() - expect) < tol).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_scipy_cholesky_fp32(): + """ + Feature: ALL TO ALL + Description: test cases for new scipy cholesky [N,N] + Expectation: the result match scipy cholesky + """ + a = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32) + tensor_a = Tensor(a) + cholesky = ScipyNetCholesky(lower=True, clean=False) + output = cholesky(tensor_a) + + cholesky1 = ScipyNetCholesky(lower=False, clean=False) + output1 = cholesky1(tensor_a) + + expect = scp.linalg.cholesky(a, lower=True) + expect1 = scp.linalg.cholesky(a, lower=False) + + rtol = 1.e-4 + atol = 1.e-5 + assert np.allclose(expect, output.asnumpy(), rtol=rtol, atol=atol) + assert np.allclose(expect1, output1.asnumpy(), rtol=rtol, atol=atol) diff --git a/tests/st/ops/gpu/test_lu_op.py b/tests/st/ops/gpu/test_lu_op.py new file mode 100644 index 00000000000..a43a1608d9a --- /dev/null +++ b/tests/st/ops/gpu/test_lu_op.py @@ -0,0 +1,94 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from typing import Generic +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.numpy as mnp +import mindspore.common.dtype as mstype +from mindspore.ops import PrimitiveWithInfer +from mindspore.ops import prim_attr_register +import scipy as scp +import numpy as np +import pytest + +context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + +class LU(PrimitiveWithInfer): + """ + LU decomposition with partial pivoting + P.A = L.U + """ + + @prim_attr_register + def __init__(self): + super().__init__(name="LU") + self.init_prim_io_names(inputs=['x'], outputs=['lu', 'pivots', 'permutation']) + + def __infer__(self, x): + x_shape = list(x['shape']) + x_dtype = x['dtype'] + pivots_shape = [] + permutation_shape = [] + ndim = len(x_shape) + if ndim == 0: + pivots_shape = x_shape + permutation_shape = x_shape + elif ndim == 1: + pivots_shape = x_shape[:-1] + # permutation_shape = x_shape[:-1] + else: + pivots_shape = x_shape[-2:-1] + # permutation_shape = x_shape[-2:-1] + + output = { + 'shape': (x_shape, pivots_shape, permutation_shape), + 'dtype': (x_dtype, mstype.int32, mstype.int32), + 'value': None + } + return output + + +class LuNet(nn.Cell): + def __init__(self): + super(LuNet, self).__init__() + self.lu = LU() + + def construct(self, a): + return self.lu(a) + + +@pytest.mark.platform_x86_gpu +@pytest.mark.parametrize('n', [10, 20]) +@pytest.mark.parametrize('dtype', [np.float32, np.float64]) +def test_lu_net(n: int, dtype: Generic): + """ + Feature: ALL To ALL + Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1] + Expectation: the result match to scipy + """ + a = (np.random.random((n, n)) + np.eye(n)).astype(dtype) + expect, _ = scp.linalg.lu_factor(a) + mscp_lu_net = LuNet() + # mindspore tensor is row major but gpu cusolver is col major, so we should transpose it. + tensor_a = Tensor(a) + tensor_a = mnp.transpose(tensor_a) + output, _, _ = mscp_lu_net(tensor_a) + # mindspore tensor is row major but gpu cusolver is col major, so we should transpose it. + output = mnp.transpose(output) + rtol = 1.e-4 + atol = 1.e-5 + assert np.allclose(expect, output.asnumpy(), rtol=rtol, atol=atol)