forked from mindspore-Ecosystem/mindspore
!26849 bind stream with handle
Merge pull request !26849 from zhujingxuan/master
This commit is contained in:
commit
b1deeb425d
|
@ -51,6 +51,10 @@ class CholeskyGpuKernel : public GpuKernel {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cusolverDnSetStream failed");
|
||||
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cublasSetStream failed");
|
||||
if (!use_split_matrix_) {
|
||||
return NoSplitLaunch(inputs, workspace, outputs, stream_ptr);
|
||||
}
|
||||
|
|
|
@ -49,6 +49,8 @@ class CholeskySolveGpuKernel : public GpuKernel {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cusolverDnSetStream failed");
|
||||
auto input_a_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
auto input_b_addr = GetDeviceAddress<T>(inputs, kDim1);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||
|
|
|
@ -55,6 +55,10 @@ class CholeskyTrsmGpuKernel : public GpuKernel {
|
|||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cublasSetStream failed");
|
||||
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cusolverDnSetStream failed");
|
||||
if (!use_split_matrix_) {
|
||||
LaunchNonSplitMatrix(inputs, workspace, outputs, stream_ptr);
|
||||
} else {
|
||||
|
|
|
@ -85,6 +85,10 @@ class EighcGpuKernel : public GpuKernel {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cublasSetStream failed");
|
||||
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(cusolver_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cusolverDnSetStream failed");
|
||||
// matrix A, input or output(eigenvector)
|
||||
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
if (lower_) {
|
||||
|
|
|
@ -71,6 +71,8 @@ class EighGpuKernel : public GpuKernel {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(cusolver_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cusolverDnSetStream failed");
|
||||
// matrix A, input or output(eigenvector)
|
||||
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
// Notice :this is important
|
||||
|
|
|
@ -47,6 +47,8 @@ class LUGpuKernel : public GpuKernel {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cusolverDnSetStream failed");
|
||||
auto input_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||
int *piv_output_addr = nullptr;
|
||||
|
|
|
@ -43,8 +43,9 @@ class MatMulGpuKernel : public GpuKernel {
|
|||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cublasSetStream failed");
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
VARIABLE_NOT_USED(stream_ptr);
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -41,6 +41,8 @@ class MatrixInverseGpuKernel : public GpuKernel {
|
|||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cublasSetStream failed");
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
auto compute_input_addr = GetDeviceAddress<T>(workspace, 0);
|
||||
|
|
|
@ -43,6 +43,8 @@ class TrsmGpuKernel : public GpuKernel {
|
|||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cublasSetStream failed");
|
||||
auto inputA_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
|
|
@ -53,6 +53,8 @@ class UpdateThorGradientGpuKernel : public GpuKernel {
|
|||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cublasSetStream failed");
|
||||
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
auto input2_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
auto input3_addr = GetDeviceAddress<T>(inputs, 2);
|
||||
|
|
Loading…
Reference in New Issue