forked from mindspore-Ecosystem/mindspore
!25608 add GPU trsm 2d matrix support
Merge pull request !25608 from zhujingxuan/master
commit 80bed6c7c5

@@ -47,26 +47,72 @@ class TrsmGpuKernel : public GpuKernel {
     auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
     auto output_addr = GetDeviceAddress<T>(outputs, 0);
 
-    const size_t batch = m_ * n_;
+    // if b is not a vector, solve b in the workspace
+    T *dst = nullptr;
+    if (n_ == 1) {
+      dst = output_addr;
+    } else {
+      dst = GetDeviceAddress<T>(workspace, 0);
+    }
 
-    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(output_addr, inputb_addr, batch * sizeof(T), cudaMemcpyDeviceToDevice,
-                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "cudaMemcpyAsync output_addr failed");
+    if (n_ == 1) {
+      const size_t batch = m_ * n_;
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                                 cudaMemcpyAsync(dst, inputb_addr, batch * sizeof(T), cudaMemcpyDeviceToDevice,
+                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync dst failed");
+    } else {
+      T alpha = 1;
+      T beta = 0;
+      // in order to convert row major matrix b(m x n) to col major matrix b'(m x n),
+      // the following operation is equivalent to:
+      // b' = b.T.reshape(m, n)
+      if constexpr (std::is_same_v<T, float>) {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
+                                     cublasSgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, m_, n_, &alpha, inputb_addr,
+                                                 n_, &beta, inputb_addr, n_, dst, m_),
+                                     "cublas transpose b Fail");
+      } else {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
+                                     cublasDgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, m_, n_, &alpha, inputb_addr,
+                                                 n_, &beta, inputb_addr, n_, dst, m_),
+                                     "cublas transpose b Fail");
+      }
+    }
 
     T alpha = 1;
     if constexpr (std::is_same_v<T, float>) {
       CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
                                    cublasStrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_,
-                                               &alpha, inputA_addr, lda_, output_addr, ldb_),
+                                               &alpha, inputA_addr, lda_, dst, ldb_),
                                    "cublas trsm Fail");
     } else {
       CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
                                    cublasDtrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_,
-                                               &alpha, inputA_addr, lda_, output_addr, ldb_),
+                                               &alpha, inputA_addr, lda_, dst, ldb_),
                                    "cublas trsm Fail");
     }
 
+    // if x is not a vector, do transpose
+    if (n_ != 1) {
+      T alpha = 1;
+      T beta = 0;
+      // in order to convert col major matrix x'(m x n) to row major matrix x(m x n),
+      // the following operation is equivalent to:
+      // x = x'.reshape(n, m).T
+      if constexpr (std::is_same_v<T, float>) {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(
+          kernel_node_,
+          cublasSgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, n_, m_, &alpha, dst, m_, &beta, dst, m_, output_addr, n_),
+          "cublas transpose x Fail");
+      } else {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(
+          kernel_node_,
+          cublasDgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, n_, m_, &alpha, dst, m_, &beta, dst, m_, output_addr, n_),
+          "cublas transpose x Fail");
+      }
+    }
 
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
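For intuition, here is a host-side NumPy sketch of the buffer trick the two geam calls rely on: cuBLAS works on column-major data while MindSpore tensors are row-major, so the kernel transposes b into the workspace, solves there, and transposes the result back. scipy's solve_triangular stands in for cublas trsm; the variable names are illustrative, not from the kernel.

import numpy as np
from scipy.linalg import solve_triangular

m, n = 4, 3
a = np.tril(np.random.random((m, m)) + np.eye(m))  # non-singular lower-triangular A
b = np.random.random((m, n))                       # row-major right-hand side

# geam trick #1 ("b' = b.T.reshape(m, n)"): the column-major buffer that
# cuBLAS expects for b is exactly the row-major buffer of b.T.
colmajor_buf = b.T.reshape(m, n)
assert np.array_equal(colmajor_buf.ravel(), b.ravel(order='F'))

# trsm solves a @ x = b in that column-major buffer, leaving the solution x
# stored in column-major order:
x = solve_triangular(a, b, lower=True)
result_buf = x.ravel(order='F')

# geam trick #2 ("x = x'.reshape(n, m).T"): reading the column-major result
# buffer back as row-major and transposing recovers the m x n answer.
assert np.allclose(result_buf.reshape(n, m).T, x)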
@@ -97,9 +143,8 @@ class TrsmGpuKernel : public GpuKernel {
     if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
       n_ = 1;
     } else {
-      MS_LOG(EXCEPTION) << "b as a matrix is currently not supported.";
+      n_ = b_shape[kDim1];
     }
     m_ = b_shape[kDim0];
 
     lda_ = SizeToInt(m_);
     ldb_ = SizeToInt(m_);
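A minimal Python sketch of the shape handling above, under the assumption that b arrives as (m,), (m, 1), or (m, n); the helper name is hypothetical:

def infer_dims(b_shape):
    # a vector b, or an m x 1 matrix, still takes the single-column path
    if len(b_shape) == 1 or (len(b_shape) == 2 and b_shape[1] == 1):
        n = 1
    else:
        n = b_shape[1]  # matrix b is now supported instead of raising
    m = b_shape[0]
    return m, n  # lda_ and ldb_ are both m

assert infer_dims((5,)) == (5, 1)
assert infer_dims((5, 1)) == (5, 1)
assert infer_dims((5, 3)) == (5, 3)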
@@ -137,8 +182,13 @@ class TrsmGpuKernel : public GpuKernel {
  protected:
   void InitSizeLists() override {
     size_t unit_size = sizeof(T);
-    input_size_list_ = {m_ * m_ * unit_size, m_ * n_ * unit_size};
-    output_size_list_ = {m_ * n_ * unit_size};
+    size_t A_size = m_ * m_ * unit_size;
+    size_t b_size = m_ * n_ * unit_size;
+    input_size_list_ = {A_size, b_size};
+    output_size_list_ = {b_size};
+    if (n_ != 1) {
+      workspace_size_list_ = {b_size};
+    }
   }
 
  private:
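Design note: the workspace entry is registered only when n_ != 1, because a vector right-hand side has identical row-major and column-major layouts, so it can be copied straight into the output buffer and solved in place; only a matrix b needs the extra m_ * n_ * unit_size buffer to hold its column-major copy.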
@@ -118,3 +118,27 @@ def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
     a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
     b = np.random.random(n).astype(dtype)
     match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('shape', [(10, 20)])
+@pytest.mark.parametrize('trans', ["N", "T"])
+@pytest.mark.parametrize('dtype', [np.float32, np.float64])
+@pytest.mark.parametrize('lower', [False, True])
+@pytest.mark.parametrize('unit_diagonal', [False, True])
+def test_matrix(shape: tuple, dtype, lower: bool, unit_diagonal: bool, trans: str):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for [N x N] X [N x M]
+    Expectation: the result matches scipy
+    """
+    if trans == 'T':
+        n, m = shape
+    else:
+        m, n = shape
+    # add identity matrix to make matrix A non-singular
+    a = (np.random.random((m, m)) + np.eye(m)).astype(dtype)
+    b = np.random.random((m, n)).astype(dtype)
+    match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
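The match helper is defined earlier in the test file and is not part of this diff; a plausible minimal version, assuming the mindspore.scipy.linalg.solve_triangular front end is what routes to this kernel, would look like:

import numpy as np
import scipy.linalg
from mindspore import Tensor
from mindspore.scipy.linalg import solve_triangular

def match(a, b, lower, unit_diagonal, trans):
    # host-side reference result from scipy
    expected = scipy.linalg.solve_triangular(a, b, lower=lower,
                                             unit_diagonal=unit_diagonal, trans=trans)
    # result from the MindSpore kernel under test
    output = solve_triangular(Tensor(a), Tensor(b), lower=lower,
                              unit_diagonal=unit_diagonal, trans=trans)
    np.testing.assert_allclose(output.asnumpy(), expected, rtol=1e-5)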
@@ -79,7 +79,7 @@ def match(a, b, lower, unit_diagonal, trans):
 
 
 @pytest.mark.level0
-@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 @pytest.mark.parametrize('n', [10, 20])
 @pytest.mark.parametrize('trans', ["N", "T"])
@@ -99,7 +99,7 @@ def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
 
 
 @pytest.mark.level0
-@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 @pytest.mark.parametrize('n', [10, 20])
 @pytest.mark.parametrize('trans', ["N", "T"])
@@ -116,3 +116,27 @@ def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
     a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
     b = np.random.random(n).astype(dtype)
     match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('shape', [(4, 5), (10, 20)])
+@pytest.mark.parametrize('trans', ["N", "T"])
+@pytest.mark.parametrize('dtype', [np.float32, np.float64])
+@pytest.mark.parametrize('lower', [False, True])
+@pytest.mark.parametrize('unit_diagonal', [False, True])
+def test_matrix(shape: tuple, dtype, lower: bool, unit_diagonal: bool, trans: str):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for [N x N] X [N x M]
+    Expectation: the result matches scipy
+    """
+    if trans == 'T':
+        n, m = shape
+    else:
+        m, n = shape
+    # add identity matrix to make matrix A non-singular
+    a = (np.random.random((m, m)) + np.eye(m)).astype(dtype)
+    b = np.random.random((m, n)).astype(dtype)
+    match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
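With this merge, a 2-D right-hand side solves end to end on GPU; a usage sketch, again assuming the mindspore.scipy.linalg.solve_triangular front end:

import numpy as np
from mindspore import Tensor, context
from mindspore.scipy.linalg import solve_triangular

context.set_context(device_target="GPU")

m, n = 10, 20
a = np.tril(np.random.random((m, m)) + np.eye(m)).astype(np.float32)
b = np.random.random((m, n)).astype(np.float32)  # matrix b, newly supported

x = solve_triangular(Tensor(a), Tensor(b), lower=True)
print(np.allclose(a @ x.asnumpy(), b, atol=1e-4))  # expect True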