From f3b26faab1936a55f987a4a6e6c6c9d557854b3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=8B=87=E8=B4=A4?= Date: Fri, 10 Feb 2023 14:47:33 +0800 Subject: [PATCH] [MSLITE] Fix rank 0 problems in relu gpu op --- .../ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.cc | 4 ++++ mindspore/ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.h | 1 + 2 files changed, 5 insertions(+) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.cc index 973ddfdc4e0..e8980be1caa 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.cc @@ -44,6 +44,7 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std: return ret; } auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector()); + is_null_input_ = CHECK_NULL_INPUT(inputs[kIndex0]->GetShapeVector()); input_length_ = std::accumulate(x_shape.begin(), x_shape.end(), static_cast(1), std::multiplies<>()); return KRET_OK; } @@ -51,6 +52,9 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std: template bool ReLUFwdGpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &, const std::vector &outputs, void *stream_ptr) { + if (is_null_input_) { + return true; + } T *input = GetDeviceAddress(inputs, 0); MS_ERROR_IF_NULL_W_RET_VAL(input, false); T *output = GetDeviceAddress(outputs, 0); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.h index 28c93a1a93c..30fff8ffb64 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/relu_gpu_kernel.h @@ -56,6 +56,7 @@ class ReLUFwdGpuKernelMod : public NativeGpuKernelMod { ReLUFwLaunchFunc kernel_func_; static std::vector> func_list_; int input_length_{0}; + bool is_null_input_{false}; }; } // namespace kernel } // namespace mindspore