!1210 Add exception check for BiasAdd kernel

Merge pull request !1210 from chenweifeng/cudnn_exception
This commit is contained in:
mindspore-ci-bot 2020-05-20 16:02:12 +08:00 committed by Gitee
commit 47275427da
2 changed files with 19 additions and 10 deletions

View File

@ -49,11 +49,15 @@ class BiasAddGpuKernel : public GpuKernel {
T *b_addr = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_, b_addr,
&beta, x_desc_, output_addr),
"cudnnOpTensor Add failed");
try {
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_,
b_addr, &beta, x_desc_, output_addr),
"cudnnOpTensor failed");
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cudnnOpTensor";
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {

View File

@ -64,11 +64,16 @@ class MatMulGpuKernel : public GpuKernel {
auto stride_a = SizeToInt(m_ * k_);
auto stride_b = SizeToInt(k_ * n_);
auto stride_c = SizeToInt(m_ * n_);
CHECK_CUBLAS_RET_WITH_EXCEPT(
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
"cublasSgemm Call Fail");
try {
CHECK_CUBLAS_RET_WITH_EXCEPT(
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
"cublasSgemm Call Fail");
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx";
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {