!26672 Use GPU mem Allocator and workspace instead of self allocator

Merge pull request !26672 from wuwenbing/master
This commit is contained in:
i-robot 2021-11-25 01:49:14 +00:00 committed by Gitee
commit 953acc0335
2 changed files with 42 additions and 49 deletions

View File

@ -86,34 +86,30 @@ class EighcGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
// matrix A, input or output(eigenvector)
auto inout_A_addr = GetDeviceAddress<T>(inputs, 0);
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
if (lower_) {
uplo_ = CUBLAS_FILL_MODE_LOWER;
} else {
uplo_ = CUBLAS_FILL_MODE_UPPER;
}
size_t lda_ = m_;
int *devInfo = nullptr;
cudaMalloc(reinterpret_cast<void **>(&devInfo), sizeof(int));
T *d_work = nullptr;
auto output_w_addr = GetDeviceAddress<T>(outputs, 0);
auto output_w_addr = GetDeviceAddress<T>(outputs, kDim0);
// output eigenvector
auto output_v_addr = GetDeviceAddress<T>(outputs, 1);
auto output_v_addr = GetDeviceAddress<T>(outputs, kDim1);
int *devInfo = GetDeviceAddress<int>(workspace, kDim0);
// temp output eigenvalues real scalar
auto w_w_addr = GetDeviceAddress<D>(workspace, 0);
auto w_w_c_addr = GetDeviceAddress<T>(workspace, 1);
auto w_w_addr = GetDeviceAddress<D>(workspace, kDim1);
auto w_w_c_addr = GetDeviceAddress<T>(workspace, kDim2);
// temp eigenvector before transpose
auto w_v_addr = GetDeviceAddress<T>(workspace, 2);
auto w_v_addr = GetDeviceAddress<T>(workspace, kDim3);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(output_v_addr, inout_A_addr, m_ * m_ * sizeof(T),
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"copy input matrix failed");
size_t input_shape[kShape2dDims] = {m_, m_};
size_t input_axis[kShape2dDims] = {1, 0};
size_t *dev_input_shape = nullptr;
cudaMalloc(reinterpret_cast<void **>(&dev_input_shape), kShape2dDims * sizeof(size_t));
size_t *dev_input_axis = nullptr;
cudaMalloc(reinterpret_cast<void **>(&dev_input_axis), kShape2dDims * sizeof(size_t));
size_t *dev_input_shape = GetDeviceAddress<size_t>(workspace, kDim4);
size_t *dev_input_axis = GetDeviceAddress<size_t>(workspace, kDim5);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
@ -126,18 +122,22 @@ class EighcGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
int lwork = 0;
void *d_work = nullptr;
if constexpr (std::is_same_v<T, Complex<float>>) {
cusolverDnCheevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(output_v_addr),
lda_, w_w_addr, &lwork);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork),
"cal eigenvalues workspace failed");
cusolverDnCheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(w_v_addr), lda_, w_w_addr,
reinterpret_cast<cuComplex *>(d_work), lwork, devInfo);
} else {
cusolverDnZheevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_,
reinterpret_cast<cuDoubleComplex *>(output_v_addr), lda_, w_w_addr, &lwork);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork),
"cal eigenvalues workspace failed");
}
d_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(sizeof(T) * lwork);
if (!d_work) {
MS_LOG(EXCEPTION) << "GPU memory alloca failed.";
}
if constexpr (std::is_same_v<T, Complex<float>>) {
cusolverDnCheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(w_v_addr), lda_, w_w_addr,
reinterpret_cast<cuComplex *>(d_work), lwork, devInfo);
} else {
cusolverDnZheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuDoubleComplex *>(w_v_addr), lda_,
w_w_addr, reinterpret_cast<cuDoubleComplex *>(d_work), lwork, devInfo);
}
@ -145,28 +145,17 @@ class EighcGpuKernel : public GpuKernel {
cudaMemcpyAsync(w_w_c_addr, w_w_addr, m_ * sizeof(D), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"copy eigenvalue from workspace to host failed");
// convert real scalar to complex
RealToComplex(m_, reinterpret_cast<D *>(w_w_c_addr), reinterpret_cast<D *>(output_w_addr),
reinterpret_cast<cudaStream_t>(stream_ptr));
CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (dev_input_shape) {
cudaFree(dev_input_shape);
}
if (dev_input_axis) {
cudaFree(dev_input_axis);
}
// convert real scalar to complex
if (d_work) {
cudaFree(d_work);
}
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work);
int info_gpu = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&info_gpu, devInfo, sizeof(int), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"copy eigenvalues to outpu failed");
if (devInfo) {
cudaFree(devInfo);
}
if (info_gpu != 0) {
MS_LOG_EXCEPTION << kernel_name_ << " launch gpu kernel fail for dtype:" << dtype_;
}
@ -180,11 +169,16 @@ class EighcGpuKernel : public GpuKernel {
// eigenvalues, cuda output original real scalar, should covert to complex<ft32/64>
output_size_list_.push_back(m_ * sizeof(T));
output_size_list_.push_back(m_ * m_ * sizeof(T));
// result
workspace_size_list_.push_back(sizeof(int));
// for temp original eigenvalue real scalar
workspace_size_list_.push_back(m_ * sizeof(D));
// for temp pre-transpose complex mitrx
workspace_size_list_.push_back(m_ * sizeof(T));
workspace_size_list_.push_back(m_ * m_ * sizeof(T));
// transpose scalar workspace
workspace_size_list_.push_back(kShape2dDims * sizeof(size_t));
workspace_size_list_.push_back(kShape2dDims * sizeof(size_t));
}
size_t m_{1};

View File

@ -83,7 +83,8 @@ class EighGpuKernel : public GpuKernel {
}
auto output_addr = GetDeviceAddress<T>(outputs, kDim0); // output eigenvalues
auto output_v_addr = GetDeviceAddress<T>(outputs, kDim1); // output eigenvalues
auto w_v_addr = GetDeviceAddress<T>(workspace, kDim0); // temp eigenvector before transpose
int *devInfo = GetDeviceAddress<int>(workspace, kDim0);
auto w_v_addr = GetDeviceAddress<T>(workspace, kDim1); // temp eigenvector before transpose
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(w_v_addr, inout_A_addr, m_ * m_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
@ -95,38 +96,31 @@ class EighGpuKernel : public GpuKernel {
} else if constexpr (std::is_same_v<T, double>) {
cusolverDnDsyevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_, inout_A_addr, lda_, output_addr, &lwork);
}
int *devInfo = nullptr;
cudaMalloc(reinterpret_cast<void **>(&devInfo), sizeof(int));
T *d_work = nullptr;
cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork);
void *d_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(sizeof(T) * lwork);
if constexpr (std::is_same_v<T, float>) {
cusolverDnSsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, d_work, lwork, devInfo);
cusolverDnSsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, reinterpret_cast<T *>(d_work),
lwork, devInfo);
} else if constexpr (std::is_same_v<T, double>) {
cusolverDnDsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, d_work, lwork, devInfo);
cusolverDnDsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, reinterpret_cast<T *>(d_work),
lwork, devInfo);
}
size_t input_shape[kShape2dDims] = {m_, m_};
size_t input_axis[kShape2dDims] = {1, 0};
size_t *dev_input_shape = nullptr;
cudaMalloc(reinterpret_cast<void **>(&dev_input_shape), kShape2dDims * sizeof(size_t));
size_t *dev_input_axis = nullptr;
cudaMalloc(reinterpret_cast<void **>(&dev_input_axis), kShape2dDims * sizeof(size_t));
size_t *dev_input_shape = GetDeviceAddress<size_t>(workspace, kDim2);
size_t *dev_input_axis = GetDeviceAddress<size_t>(workspace, kDim3);
cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (d_work) {
cudaFree(d_work);
}
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work);
int info_gpu = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&info_gpu, devInfo, sizeof(int), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"copy to device result failed");
if (devInfo) {
cudaFree(devInfo);
}
if (info_gpu != 0) {
MS_LOG_EXCEPTION << kernel_name_ << " launch gpu kernel fail for dtype:" << dtype_;
}
@ -141,7 +135,12 @@ class EighGpuKernel : public GpuKernel {
output_size_list_.push_back(m_ * sizeof(T));
// eigenvector
output_size_list_.push_back(m_ * m_ * sizeof(T));
// result
workspace_size_list_.push_back(sizeof(int));
workspace_size_list_.push_back(m_ * m_ * sizeof(T));
// transpose scalar workspace
workspace_size_list_.push_back(kShape2dDims * sizeof(size_t));
workspace_size_list_.push_back(kShape2dDims * sizeof(size_t));
}
size_t m_{1};