!48707 [MSLITE] Fix empty tensor problem in relu gpu op

Merge pull request !48707 from zhangyongxian/dev_zhangyongxian_relugpu
This commit is contained in:
i-robot 2023-02-10 11:05:46 +00:00 committed by Gitee
commit e5e8d7bb10
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 5 additions and 0 deletions

View File

@@ -44,6 +44,7 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
return ret; return ret;
} }
auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector()); auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
is_null_input_ = CHECK_NULL_INPUT(inputs[kIndex0]->GetShapeVector());
input_length_ = std::accumulate(x_shape.begin(), x_shape.end(), static_cast<int>(1), std::multiplies<>()); input_length_ = std::accumulate(x_shape.begin(), x_shape.end(), static_cast<int>(1), std::multiplies<>());
return KRET_OK; return KRET_OK;
} }
@@ -51,6 +52,9 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
template <typename T> template <typename T>
bool ReLUFwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool ReLUFwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) { const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) {
return true;
}
T *input = GetDeviceAddress<T>(inputs, 0); T *input = GetDeviceAddress<T>(inputs, 0);
MS_ERROR_IF_NULL_W_RET_VAL(input, false); MS_ERROR_IF_NULL_W_RET_VAL(input, false);
T *output = GetDeviceAddress<T>(outputs, 0); T *output = GetDeviceAddress<T>(outputs, 0);

View File

@@ -56,6 +56,7 @@ class ReLUFwdGpuKernelMod : public NativeGpuKernelMod {
ReLUFwLaunchFunc kernel_func_; ReLUFwLaunchFunc kernel_func_;
static std::vector<std::pair<KernelAttr, ReLUFwLaunchFunc>> func_list_; static std::vector<std::pair<KernelAttr, ReLUFwLaunchFunc>> func_list_;
int input_length_{0}; int input_length_{0};
bool is_null_input_{false};
}; };
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore