!25608 add GPU trsm 2d matrix support

Merge pull request !25608 from zhujingxuan/master
This commit is contained in:
i-robot 2021-10-30 08:44:24 +00:00 committed by Gitee
commit 80bed6c7c5
3 changed files with 111 additions and 13 deletions

View File

@@ -47,26 +47,72 @@ class TrsmGpuKernel : public GpuKernel {
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
auto output_addr = GetDeviceAddress<T>(outputs, 0);
const size_t batch = m_ * n_;
// if b is not a vector, solve b in the workspace
T *dst = nullptr;
if (n_ == 1) {
dst = output_addr;
} else {
dst = GetDeviceAddress<T>(workspace, 0);
}
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(output_addr, inputb_addr, batch * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output_addr failed");
if (n_ == 1) {
const size_t batch = m_ * n_;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(dst, inputb_addr, batch * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync dst failed");
} else {
T alpha = 1;
T beta = 0;
// in order to convert row major matrix b(m x n) to col major matrix b'(m x n),
// the following operation is equivalent to:
// b' = b.T.reshape(m, n)
if constexpr (std::is_same_v<T, float>) {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasSgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, m_, n_, &alpha, inputb_addr,
n_, &beta, inputb_addr, n_, dst, m_),
"cublas transpose b Fail");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasDgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, m_, n_, &alpha, inputb_addr,
n_, &beta, inputb_addr, n_, dst, m_),
"cublas transpose b Fail");
}
}
T alpha = 1;
if constexpr (std::is_same_v<T, float>) {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasStrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_,
&alpha, inputA_addr, lda_, output_addr, ldb_),
&alpha, inputA_addr, lda_, dst, ldb_),
"cublas trsm Fail");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasDtrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_,
&alpha, inputA_addr, lda_, output_addr, ldb_),
&alpha, inputA_addr, lda_, dst, ldb_),
"cublas trsm Fail");
}
// if x is not a vector, do transpose
if (n_ != 1) {
T alpha = 1;
T beta = 0;
// in order to convert col major matrix x'(m x n) to row major matrix x'(m x n),
// the following operation is equivalent to:
// x = x'.reshape(n, m).T
if constexpr (std::is_same_v<T, float>) {
CHECK_CUBLAS_RET_WITH_EXCEPT(
kernel_node_,
cublasSgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, n_, m_, &alpha, dst, m_, &beta, dst, m_, output_addr, n_),
"cublas transpose x Fail");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT(
kernel_node_,
cublasDgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, n_, m_, &alpha, dst, m_, &beta, dst, m_, output_addr, n_),
"cublas transpose x Fail");
}
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
@@ -97,9 +143,8 @@ class TrsmGpuKernel : public GpuKernel {
if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
n_ = 1;
} else {
MS_LOG(EXCEPTION) << "b as a matrix is currently not supported.";
n_ = b_shape[kDim1];
}
m_ = b_shape[kDim0];
lda_ = SizeToInt(m_);
ldb_ = SizeToInt(m_);
@@ -137,8 +182,13 @@ class TrsmGpuKernel : public GpuKernel {
protected:
void InitSizeLists() override {
size_t unit_size = sizeof(T);
input_size_list_ = {m_ * m_ * unit_size, m_ * n_ * unit_size};
output_size_list_ = {m_ * n_ * unit_size};
size_t A_size = m_ * m_ * unit_size;
size_t b_size = m_ * n_ * unit_size;
input_size_list_ = {A_size, b_size};
output_size_list_ = {b_size};
if (n_ != 1) {
workspace_size_list_ = {b_size};
}
}
private:

View File

@@ -118,3 +118,27 @@ def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
b = np.random.random(n).astype(dtype)
match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(10, 20)])
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False, True])
def test_matrix(shape: tuple, dtype, lower: bool, unit_diagonal: bool, trans: str):
    """
    Feature: ALL TO ALL
    Description: test cases for a 2D right-hand side: A [m x m] X b [m x n]
    Expectation: the result match scipy
    """
    # A is always square (m x m); swapping the shape entries for trans == 'T'
    # just exercises both orientations of b's (rows, cols) across the settings.
    if trans == 'T':
        n, m = shape
    else:
        m, n = shape
    # add Identity matrix to make matrix A non-singular
    a = (np.random.random((m, m)) + np.eye(m)).astype(dtype)
    b = np.random.random((m, n)).astype(dtype)
    # presumably `match` compares against scipy.linalg.solve_triangular — defined earlier in this file
    match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)

View File

@@ -79,7 +79,7 @@ def match(a, b, lower, unit_diagonal, trans):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('trans', ["N", "T"])
@@ -99,7 +99,7 @@ def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('trans', ["N", "T"])
@@ -116,3 +116,27 @@ def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
b = np.random.random(n).astype(dtype)
match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(4, 5), (10, 20)])
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False, True])
def test_matrix(shape: tuple, dtype, lower: bool, unit_diagonal: bool, trans: str):
    """
    Feature: ALL TO ALL
    Description: test cases for a 2D right-hand side: A [m x m] X b [m x n]
    Expectation: the result match scipy
    """
    # A is always square (m x m); swapping the shape entries for trans == 'T'
    # just exercises both orientations of b's (rows, cols) across the settings.
    if trans == 'T':
        n, m = shape
    else:
        m, n = shape
    # add Identity matrix to make matrix A non-singular
    a = (np.random.random((m, m)) + np.eye(m)).astype(dtype)
    b = np.random.random((m, n)).astype(dtype)
    # presumably `match` compares against scipy.linalg.solve_triangular — defined earlier in this file
    match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)