[MSLITE] Fix rank 0 problems in relu gpu op

This commit is contained in:
张勇贤 2023-02-10 14:47:33 +08:00
parent ce53db6836
commit f3b26faab1
2 changed files with 5 additions and 0 deletions

View File

@ -44,6 +44,7 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
return ret;
}
auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
is_null_input_ = CHECK_NULL_INPUT(inputs[kIndex0]->GetShapeVector());
input_length_ = std::accumulate(x_shape.begin(), x_shape.end(), static_cast<int>(1), std::multiplies<>());
return KRET_OK;
}
@ -51,6 +52,9 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
template <typename T>
bool ReLUFwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) {
return true;
}
T *input = GetDeviceAddress<T>(inputs, 0);
MS_ERROR_IF_NULL_W_RET_VAL(input, false);
T *output = GetDeviceAddress<T>(outputs, 0);

View File

@ -56,6 +56,7 @@ class ReLUFwdGpuKernelMod : public NativeGpuKernelMod {
ReLUFwLaunchFunc kernel_func_;
static std::vector<std::pair<KernelAttr, ReLUFwLaunchFunc>> func_list_;
int input_length_{0};
bool is_null_input_{false};
};
} // namespace kernel
} // namespace mindspore