forked from mindspore-Ecosystem/mindspore
!1210 Add exception check for BiasAdd kernel
Merge pull request !1210 from chenweifeng/cudnn_exception
This commit is contained in:
commit
47275427da
|
@ -49,11 +49,15 @@ class BiasAddGpuKernel : public GpuKernel {
|
|||
T *b_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_, b_addr,
|
||||
&beta, x_desc_, output_addr),
|
||||
"cudnnOpTensor Add failed");
|
||||
try {
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_,
|
||||
b_addr, &beta, x_desc_, output_addr),
|
||||
"cudnnOpTensor failed");
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cudnnOpTensor";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
|
|
|
@ -64,11 +64,16 @@ class MatMulGpuKernel : public GpuKernel {
|
|||
auto stride_a = SizeToInt(m_ * k_);
|
||||
auto stride_b = SizeToInt(k_ * n_);
|
||||
auto stride_c = SizeToInt(m_ * n_);
|
||||
CHECK_CUBLAS_RET_WITH_EXCEPT(
|
||||
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
|
||||
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
|
||||
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
|
||||
"cublasSgemm Call Fail");
|
||||
|
||||
try {
|
||||
CHECK_CUBLAS_RET_WITH_EXCEPT(
|
||||
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
|
||||
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
|
||||
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
|
||||
"cublasSgemm Call Fail");
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
|
|
Loading…
Reference in New Issue