From af6ded28e3d457cb8df6778c380ba0f8122023a2 Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Thu, 12 Nov 2020 12:05:25 -0500 Subject: [PATCH] Add Cholesky op at GPU back-end --- .../cuda_impl/triangle_matrix_copy_impl.cu | 54 ++++ .../cuda_impl/triangle_matrix_copy_impl.cuh | 24 ++ .../gpu/math/cholesky_solve_gpu_kernel.cc | 23 ++ .../gpu/math/cholesky_solve_gpu_kernel.h | 247 ++++++++++++++++++ mindspore/ops/operations/__init__.py | 2 +- mindspore/ops/operations/_thor_ops.py | 53 +++- tests/st/ops/gpu/test_cholesky_op.py | 44 ++++ 7 files changed, 444 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_cholesky_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cu new file mode 100644 index 00000000000..e9b520f43d4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cu @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "triangle_matrix_copy_impl.cuh" +template +__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, + const size_t count, const size_t ldb, const size_t m) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + size_t batchIdx = i / (ldb * m); + size_t row = (i - batchIdx * ldb * m) / m; + size_t col = (i - batchIdx * ldb * m) % m; + + // If fill mode is 'CUBLAS_FILL_MODE_UPPER', the upper half of the matrix should be all 0; + // If fill mode is 'CUBLAS_FILL_MODE_LOWER', the lower half of the matrix should be all 0; + if (uplo == CUBLAS_FILL_MODE_UPPER) { + if (col > row) { + output[i] = 0; + } else { + output[i] = input[i]; + } + } else if (uplo == CUBLAS_FILL_MODE_LOWER) { + if (col < row) { + output[i] = 0; + } else { + output[i] = input[i]; + } + } + } +} + +template +void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, + const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream) { + TriangleMatrixCopyKernel<<>>(input, output, uplo, count, ldb, m); + return; +} + +template void TriangleMatrixCopy(const float *input, float *output, cublasFillMode_t uplo, const size_t count, + const size_t ldb, const size_t m, cudaStream_t cuda_stream); +template void TriangleMatrixCopy(const half *input, half *output, cublasFillMode_t uplo, const size_t count, + const size_t ldb, const size_t m, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh new file mode 100644 index 00000000000..98218a30d97 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, + const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc new file mode 100644 index 00000000000..9ef14295688 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.cc @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + CholeskyGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h new file mode 100644 index 00000000000..efc9b976ee7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h @@ -0,0 +1,247 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H +#define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +template +class CholeskyGpuKernel : public GpuKernel { + public: + CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {} + ~CholeskyGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + auto input1_addr = GetDeviceAddress(inputs, 0); + auto output_addr = GetDeviceAddress(outputs, 0); + auto d_array_addr = GetDeviceAddress(workspace, 0); + auto d_identity_addr = GetDeviceAddress(workspace, 1); + if (!use_split_matrix) { + auto d_info_array_addr = GetDeviceAddress(workspace, 2); + for (size_t i = 0; i < batch_; i++) { + h_array[i] = input1_addr + i * lda_ * m_; + h_identity[i] = output_addr + i * ldb_ * m_; + CHECK_CUDA_RET_WITH_ERROR( + cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data.data(), sizeof(T) * ldb_ * m_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + } + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + TriangleMatrixCopy(input1_addr, output_addr, uplo, outputs[0]->size / sizeof(T), ldb_, m_, + reinterpret_cast(stream_ptr)); + } else { + auto d_info_array_addr = GetDeviceAddress(workspace, 2); + auto d_batch_input_addr = GetDeviceAddress(workspace, 3); + for (size_t i = 0; i < batch_; i++) { + h_array[i] = d_batch_input_addr + i * lda_ * m_; + h_identity[i] = output_addr + i * ldb_ * m_; + } + Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast(stream_ptr)); + MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cuda memcopy Fail"); + CHECK_CUSOLVER_RET_WITH_EXCEPT( + cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), + "cusolver cholesky batched Fail"); + TriangleMatrixCopy(d_batch_input_addr, output_addr, uplo, outputs[0]->size / sizeof(T), ldb_, m_, + reinterpret_cast(stream_ptr)); + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle(); + blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + split_dim = static_cast(GetAttr(kernel_node, "split_dim")); + if (split_dim == 0) { + InitNoSpltDim(in_shape); + } else { + InitSpltDim(in_shape); + } + return true; + } + + protected: + void InitNoSpltDim(const std::vector &in_shape) { + use_split_matrix = false; + if (in_shape.size() == 2) { + batch_ = 1; + if (in_shape[0] != in_shape[1]) { + MS_LOG(ERROR) << "Cholesky need square matrix as input."; + } + } else if (in_shape.size() == 3) { + batch_ = SizeToInt(in_shape[0]); + if (in_shape[1] != in_shape[2]) { + MS_LOG(ERROR) << "Cholesky need square matrix as input."; + } + } else { + MS_LOG(ERROR) << "Input Only support Rank 2 OR 3"; + } + + m_ = SizeToInt(in_shape[1]); + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + InitSizeLists(); + } + + void InitSpltDim(const std::vector &in_shape) { + if (in_shape.size() != 2) { + MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2."; + } + height = in_shape[0]; + width = in_shape[1]; + if (height != width) { + MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input."; + } + if (SizeToInt(height) <= split_dim) { + use_split_matrix = false; + batch_ = 1; + m_ = SizeToInt(in_shape[1]); + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + InitSizeLists(); + } else { + use_split_matrix = true; + int batch = SizeToInt(in_shape[1]) / split_dim; + res_dim = in_shape[1] - batch * split_dim; + if (res_dim == 0) { + batch_ = batch; + } else { + batch_ = batch + 1; + } + m_ = split_dim; + lda_ = m_; + ldb_ = m_; + h_array.resize(batch_); + h_identity.resize(batch_); + h_identity_data.resize(m_ * m_); + for (size_t i = 0; i < m_; i++) { + for (size_t j = 0; j < m_; j++) { + if (i == j) { + h_identity_data[i * m_ + j] = 1; + } else { + h_identity_data[i * m_ + j] = 0; + } + } + } + InitSizeLists(); + } + } + + void InitSizeLists() override { + size_t unit_size = sizeof(T); + size_t input_size; + size_t workspace_size; + if (!use_split_matrix) { + input_size = batch_ * m_ * lda_ * unit_size; + } else { + input_size = height * width * unit_size; + workspace_size = batch_ * m_ * lda_ * unit_size; + workspace_size_list_.push_back(workspace_size); + } + input_size_list_.push_back(input_size); + size_t output_size = batch_ * m_ * lda_ * unit_size; + output_size_list_.push_back(output_size); + workspace_size = batch_ * sizeof(T *); + workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size); + workspace_size = batch_ * sizeof(T *); + workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size); + workspace_size = batch_ * sizeof(int); + workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size); + } + + private: + size_t batch_; + size_t m_; + size_t lda_; + size_t ldb_; + int res_dim; + int split_dim; + bool is_null_input_; + bool use_split_matrix; + size_t height; + size_t width; + cusolverDnHandle_t handle_; + cublasHandle_t blas_handle_; + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + std::vector h_array; + std::vector h_identity; + std::vector h_identity_data; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index b6f63f11468..7857a5008e0 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -88,7 +88,7 @@ from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingB from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, CusMatMulCubeDenseRight, - CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, CholeskyTrsm, DetTriangle) + CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle) from .sparse_ops import SparseToDense from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx diff --git a/mindspore/ops/operations/_thor_ops.py b/mindspore/ops/operations/_thor_ops.py index bb3bc3aa5f9..71edf6dcaa7 100644 --- a/mindspore/ops/operations/_thor_ops.py +++ b/mindspore/ops/operations/_thor_ops.py @@ -608,9 +608,42 @@ class UpdateThorGradient(PrimitiveWithInfer): return x2_dtype +class Cholesky(PrimitiveWithInfer): + """ + Inner API for positive-definite matrix Cholesky decomposition GPU backend. + """ + + @prim_attr_register + def __init__(self, split_dim=0): + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + self.split_dim = split_dim + self.add_prim_attr('split_dim', self.split_dim) + + def infer_shape(self, x1_shape): + if self.split_dim != 0: + assert len(x1_shape) == 2 + height = x1_shape[0] + width = x1_shape[1] + assert height == width + if height <= self.split_dim: + out_shape = [1, height, width] + else: + batch = height // self.split_dim + if height != batch * self.split_dim: + batch += 1 + out_shape = [batch, self.split_dim, self.split_dim] + else: + out_shape = x1_shape + return out_shape + + def infer_dtype(self, x1_dtype): + validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name) + return x1_dtype + + class CholeskyTrsm(PrimitiveWithInfer): """ - Inner API for resnet50 THOR GPU backend + Inner API for resnet50 THOR GPU backend. """ @prim_attr_register @@ -643,7 +676,23 @@ class CholeskyTrsm(PrimitiveWithInfer): class DetTriangle(PrimitiveWithInfer): """ - Calculate the determinant of triangle matrices + Calculate the determinant of triangle matrices. + + Args: + fill_mode (tuple): The target shape to broadcast. + + Inputs: + - **input_x** (Tensor) - The input tensor. + + Outputs: + Tensor, with the given `shape` and the same data type as `input_x`. + + Examples: + >>> shape = (2, 3) + >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) + >>> broadcast_to = P.BroadcastTo(shape) + >>> broadcast_to(input_x) + [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]] """ @prim_attr_register diff --git a/tests/st/ops/gpu/test_cholesky_op.py b/tests/st/ops/gpu/test_cholesky_op.py new file mode 100644 index 00000000000..34965cf5564 --- /dev/null +++ b/tests/st/ops/gpu/test_cholesky_op.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class NetCholesky(nn.Cell): + def __init__(self): + super(NetCholesky, self).__init__() + self.cholesky = P.Cholesky() + + def construct(self, x): + return self.cholesky(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cholesky_fp32(): + cholesky = NetCholesky() + x = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32) + output = cholesky(Tensor(x, dtype=mstype.float32)) + expect = np.linalg.cholesky(x) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect) < tol).all()