!26849 bind stream with handle

Merge pull request !26849 from zhujingxuan/master
This commit is contained in:
i-robot 2021-11-27 02:11:40 +00:00 committed by Gitee
commit b1deeb425d
10 changed files with 26 additions and 1 deletions

View File

@ -51,6 +51,10 @@ class CholeskyGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
if (!use_split_matrix_) {
return NoSplitLaunch(inputs, workspace, outputs, stream_ptr);
}

View File

@ -49,6 +49,8 @@ class CholeskySolveGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
auto input_a_addr = GetDeviceAddress<T>(inputs, kDim0);
auto input_b_addr = GetDeviceAddress<T>(inputs, kDim1);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);

View File

@ -55,6 +55,10 @@ class CholeskyTrsmGpuKernel : public GpuKernel {
if (is_null_input_) {
return true;
}
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
if (!use_split_matrix_) {
LaunchNonSplitMatrix(inputs, workspace, outputs, stream_ptr);
} else {

View File

@ -85,6 +85,10 @@ class EighcGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(cusolver_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
// matrix A, input or output(eigenvector)
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
if (lower_) {

View File

@ -71,6 +71,8 @@ class EighGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(cusolver_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
// matrix A, input or output(eigenvector)
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
// Notice :this is important

View File

@ -47,6 +47,8 @@ class LUGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
auto input_addr = GetDeviceAddress<T>(inputs, kDim0);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
int *piv_output_addr = nullptr;

View File

@ -43,8 +43,9 @@ class MatMulGpuKernel : public GpuKernel {
if (is_null_input_) {
return true;
}
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}

View File

@ -41,6 +41,8 @@ class MatrixInverseGpuKernel : public GpuKernel {
if (is_null_input_) {
return true;
}
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
auto compute_input_addr = GetDeviceAddress<T>(workspace, 0);

View File

@ -43,6 +43,8 @@ class TrsmGpuKernel : public GpuKernel {
if (is_null_input_) {
return true;
}
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
auto inputA_addr = GetDeviceAddress<T>(inputs, 0);
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
auto output_addr = GetDeviceAddress<T>(outputs, 0);

View File

@ -53,6 +53,8 @@ class UpdateThorGradientGpuKernel : public GpuKernel {
if (is_null_input_) {
return true;
}
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
auto input2_addr = GetDeviceAddress<T>(inputs, 1);
auto input3_addr = GetDeviceAddress<T>(inputs, 2);