forked from mindspore-Ecosystem/mindspore
!26849 bind stream with handle
Merge pull request !26849 from zhujingxuan/master
This commit is contained in:
commit
b1deeb425d
|
@ -51,6 +51,10 @@ class CholeskyGpuKernel : public GpuKernel {
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
|
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cusolverDnSetStream failed");
|
||||||
|
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cublasSetStream failed");
|
||||||
if (!use_split_matrix_) {
|
if (!use_split_matrix_) {
|
||||||
return NoSplitLaunch(inputs, workspace, outputs, stream_ptr);
|
return NoSplitLaunch(inputs, workspace, outputs, stream_ptr);
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,6 +49,8 @@ class CholeskySolveGpuKernel : public GpuKernel {
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
|
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cusolverDnSetStream failed");
|
||||||
auto input_a_addr = GetDeviceAddress<T>(inputs, kDim0);
|
auto input_a_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||||
auto input_b_addr = GetDeviceAddress<T>(inputs, kDim1);
|
auto input_b_addr = GetDeviceAddress<T>(inputs, kDim1);
|
||||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||||
|
|
|
@ -55,6 +55,10 @@ class CholeskyTrsmGpuKernel : public GpuKernel {
|
||||||
if (is_null_input_) {
|
if (is_null_input_) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cublasSetStream failed");
|
||||||
|
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cusolverDnSetStream failed");
|
||||||
if (!use_split_matrix_) {
|
if (!use_split_matrix_) {
|
||||||
LaunchNonSplitMatrix(inputs, workspace, outputs, stream_ptr);
|
LaunchNonSplitMatrix(inputs, workspace, outputs, stream_ptr);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -85,6 +85,10 @@ class EighcGpuKernel : public GpuKernel {
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
|
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cublasSetStream failed");
|
||||||
|
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(cusolver_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cusolverDnSetStream failed");
|
||||||
// matrix A, input or output(eigenvector)
|
// matrix A, input or output(eigenvector)
|
||||||
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
|
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||||
if (lower_) {
|
if (lower_) {
|
||||||
|
|
|
@ -71,6 +71,8 @@ class EighGpuKernel : public GpuKernel {
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
|
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(cusolver_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cusolverDnSetStream failed");
|
||||||
// matrix A, input or output(eigenvector)
|
// matrix A, input or output(eigenvector)
|
||||||
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
|
auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||||
// Notice :this is important
|
// Notice :this is important
|
||||||
|
|
|
@ -47,6 +47,8 @@ class LUGpuKernel : public GpuKernel {
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
|
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cusolverDnSetStream failed");
|
||||||
auto input_addr = GetDeviceAddress<T>(inputs, kDim0);
|
auto input_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||||
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||||
int *piv_output_addr = nullptr;
|
int *piv_output_addr = nullptr;
|
||||||
|
|
|
@ -43,8 +43,9 @@ class MatMulGpuKernel : public GpuKernel {
|
||||||
if (is_null_input_) {
|
if (is_null_input_) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cublasSetStream failed");
|
||||||
VARIABLE_NOT_USED(workspace);
|
VARIABLE_NOT_USED(workspace);
|
||||||
VARIABLE_NOT_USED(stream_ptr);
|
|
||||||
if (is_null_input_) {
|
if (is_null_input_) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,8 @@ class MatrixInverseGpuKernel : public GpuKernel {
|
||||||
if (is_null_input_) {
|
if (is_null_input_) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cublasSetStream failed");
|
||||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||||
auto compute_input_addr = GetDeviceAddress<T>(workspace, 0);
|
auto compute_input_addr = GetDeviceAddress<T>(workspace, 0);
|
||||||
|
|
|
@ -43,6 +43,8 @@ class TrsmGpuKernel : public GpuKernel {
|
||||||
if (is_null_input_) {
|
if (is_null_input_) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cublasSetStream failed");
|
||||||
auto inputA_addr = GetDeviceAddress<T>(inputs, 0);
|
auto inputA_addr = GetDeviceAddress<T>(inputs, 0);
|
||||||
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
|
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
|
||||||
auto output_addr = GetDeviceAddress<T>(outputs, 0);
|
auto output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||||
|
|
|
@ -53,6 +53,8 @@ class UpdateThorGradientGpuKernel : public GpuKernel {
|
||||||
if (is_null_input_) {
|
if (is_null_input_) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cublasSetStream failed");
|
||||||
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
|
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
|
||||||
auto input2_addr = GetDeviceAddress<T>(inputs, 1);
|
auto input2_addr = GetDeviceAddress<T>(inputs, 1);
|
||||||
auto input3_addr = GetDeviceAddress<T>(inputs, 2);
|
auto input3_addr = GetDeviceAddress<T>(inputs, 2);
|
||||||
|
|
Loading…
Reference in New Issue