!8595 Add Cholesky op at GPU back-end

From: @peixu_ren
Reviewed-by: @zichun_ye,@wilfchen
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-02 10:10:24 +08:00 committed by Gitee
commit eb7cbd49ac
7 changed files with 444 additions and 3 deletions

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "triangle_matrix_copy_impl.cuh"
template <typename T>
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
size_t col = (i - batchIdx * ldb * m) % m;
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the upper half of the matrix should be all 0;
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the lower half of the matrix should be all 0;
if (uplo == CUBLAS_FILL_MODE_UPPER) {
if (col > row) {
output[i] = 0;
} else {
output[i] = input[i];
}
} else if (uplo == CUBLAS_FILL_MODE_LOWER) {
if (col < row) {
output[i] = 0;
} else {
output[i] = input[i];
}
}
}
}
template <typename T>
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream) {
TriangleMatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, uplo, count, ldb, m);
return;
}
template void TriangleMatrixCopy<float>(const float *input, float *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template void TriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);

View File

@ -0,0 +1,24 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_

View File

@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskyGpuKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,247 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
#define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/convert_utils.h"
namespace mindspore {
namespace kernel {
template <typename T>
class CholeskyGpuKernel : public GpuKernel {
public:
CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {}
~CholeskyGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
auto output_addr = GetDeviceAddress<T>(outputs, 0);
auto d_array_addr = GetDeviceAddress<T *>(workspace, 0);
auto d_identity_addr = GetDeviceAddress<T *>(workspace, 1);
if (!use_split_matrix) {
auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
for (size_t i = 0; i < batch_; i++) {
h_array[i] = input1_addr + i * lda_ * m_;
h_identity[i] = output_addr + i * ldb_ * m_;
CHECK_CUDA_RET_WITH_ERROR(
cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data.data(), sizeof(T) * ldb_ * m_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
}
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUSOLVER_RET_WITH_EXCEPT(
cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
TriangleMatrixCopy(input1_addr, output_addr, uplo, outputs[0]->size / sizeof(T), ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
auto d_batch_input_addr = GetDeviceAddress<T>(workspace, 3);
for (size_t i = 0; i < batch_; i++) {
h_array[i] = d_batch_input_addr + i * lda_ * m_;
h_identity[i] = output_addr + i * ldb_ * m_;
}
Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUSOLVER_RET_WITH_EXCEPT(
cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
TriangleMatrixCopy(d_batch_input_addr, output_addr, uplo, outputs[0]->size / sizeof(T), ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
split_dim = static_cast<int>(GetAttr<int64_t>(kernel_node, "split_dim"));
if (split_dim == 0) {
InitNoSpltDim(in_shape);
} else {
InitSpltDim(in_shape);
}
return true;
}
protected:
void InitNoSpltDim(const std::vector<size_t> &in_shape) {
use_split_matrix = false;
if (in_shape.size() == 2) {
batch_ = 1;
if (in_shape[0] != in_shape[1]) {
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
}
} else if (in_shape.size() == 3) {
batch_ = SizeToInt(in_shape[0]);
if (in_shape[1] != in_shape[2]) {
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
}
} else {
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
}
m_ = SizeToInt(in_shape[1]);
lda_ = m_;
ldb_ = m_;
h_array.resize(batch_);
h_identity.resize(batch_);
h_identity_data.resize(m_ * m_);
for (size_t i = 0; i < m_; i++) {
for (size_t j = 0; j < m_; j++) {
if (i == j) {
h_identity_data[i * m_ + j] = 1;
} else {
h_identity_data[i * m_ + j] = 0;
}
}
}
InitSizeLists();
}
void InitSpltDim(const std::vector<size_t> &in_shape) {
if (in_shape.size() != 2) {
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
}
height = in_shape[0];
width = in_shape[1];
if (height != width) {
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
}
if (SizeToInt(height) <= split_dim) {
use_split_matrix = false;
batch_ = 1;
m_ = SizeToInt(in_shape[1]);
lda_ = m_;
ldb_ = m_;
h_array.resize(batch_);
h_identity.resize(batch_);
h_identity_data.resize(m_ * m_);
for (size_t i = 0; i < m_; i++) {
for (size_t j = 0; j < m_; j++) {
if (i == j) {
h_identity_data[i * m_ + j] = 1;
} else {
h_identity_data[i * m_ + j] = 0;
}
}
}
InitSizeLists();
} else {
use_split_matrix = true;
int batch = SizeToInt(in_shape[1]) / split_dim;
res_dim = in_shape[1] - batch * split_dim;
if (res_dim == 0) {
batch_ = batch;
} else {
batch_ = batch + 1;
}
m_ = split_dim;
lda_ = m_;
ldb_ = m_;
h_array.resize(batch_);
h_identity.resize(batch_);
h_identity_data.resize(m_ * m_);
for (size_t i = 0; i < m_; i++) {
for (size_t j = 0; j < m_; j++) {
if (i == j) {
h_identity_data[i * m_ + j] = 1;
} else {
h_identity_data[i * m_ + j] = 0;
}
}
}
InitSizeLists();
}
}
void InitSizeLists() override {
size_t unit_size = sizeof(T);
size_t input_size;
size_t workspace_size;
if (!use_split_matrix) {
input_size = batch_ * m_ * lda_ * unit_size;
} else {
input_size = height * width * unit_size;
workspace_size = batch_ * m_ * lda_ * unit_size;
workspace_size_list_.push_back(workspace_size);
}
input_size_list_.push_back(input_size);
size_t output_size = batch_ * m_ * lda_ * unit_size;
output_size_list_.push_back(output_size);
workspace_size = batch_ * sizeof(T *);
workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
workspace_size = batch_ * sizeof(T *);
workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
workspace_size = batch_ * sizeof(int);
workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
}
private:
size_t batch_;
size_t m_;
size_t lda_;
size_t ldb_;
int res_dim;
int split_dim;
bool is_null_input_;
bool use_split_matrix;
size_t height;
size_t width;
cusolverDnHandle_t handle_;
cublasHandle_t blas_handle_;
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
std::vector<T *> h_array;
std::vector<T *> h_identity;
std::vector<T> h_identity_data;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif

View File

@ -88,7 +88,7 @@ from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingB
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight,
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, CholeskyTrsm, DetTriangle)
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle)
from .sparse_ops import SparseToDense
from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx

View File

@ -608,9 +608,42 @@ class UpdateThorGradient(PrimitiveWithInfer):
return x2_dtype
class Cholesky(PrimitiveWithInfer):
"""
Inner API for positive-definite matrix Cholesky decomposition GPU backend.
"""
@prim_attr_register
def __init__(self, split_dim=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
self.split_dim = split_dim
self.add_prim_attr('split_dim', self.split_dim)
def infer_shape(self, x1_shape):
if self.split_dim != 0:
assert len(x1_shape) == 2
height = x1_shape[0]
width = x1_shape[1]
assert height == width
if height <= self.split_dim:
out_shape = [1, height, width]
else:
batch = height // self.split_dim
if height != batch * self.split_dim:
batch += 1
out_shape = [batch, self.split_dim, self.split_dim]
else:
out_shape = x1_shape
return out_shape
def infer_dtype(self, x1_dtype):
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
return x1_dtype
class CholeskyTrsm(PrimitiveWithInfer):
"""
Inner API for resnet50 THOR GPU backend
Inner API for resnet50 THOR GPU backend.
"""
@prim_attr_register
@ -643,7 +676,23 @@ class CholeskyTrsm(PrimitiveWithInfer):
class DetTriangle(PrimitiveWithInfer):
"""
Calculate the determinant of triangle matrices
Calculate the determinant of triangle matrices.
Args:
fill_mode (tuple): The target shape to broadcast.
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, with the given `shape` and the same data type as `input_x`.
Examples:
>>> shape = (2, 3)
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
>>> broadcast_to = P.BroadcastTo(shape)
>>> broadcast_to(input_x)
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
"""
@prim_attr_register

View File

@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetCholesky(nn.Cell):
def __init__(self):
super(NetCholesky, self).__init__()
self.cholesky = P.Cholesky()
def construct(self, x):
return self.cholesky(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cholesky_fp32():
cholesky = NetCholesky()
x = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
output = cholesky(Tensor(x, dtype=mstype.float32))
expect = np.linalg.cholesky(x)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect) < tol).all()