forked from mindspore-Ecosystem/mindspore
[MSLITE] Fix rank 0 problems in relu gpu op
This commit is contained in:
parent
ce53db6836
commit
f3b26faab1
|
@ -44,6 +44,7 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
|||
return ret;
|
||||
}
|
||||
auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
|
||||
is_null_input_ = CHECK_NULL_INPUT(inputs[kIndex0]->GetShapeVector());
|
||||
input_length_ = std::accumulate(x_shape.begin(), x_shape.end(), static_cast<int>(1), std::multiplies<>());
|
||||
return KRET_OK;
|
||||
}
|
||||
|
@ -51,6 +52,9 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
|||
template <typename T>
|
||||
bool ReLUFwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(input, false);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
|
|
|
@ -56,6 +56,7 @@ class ReLUFwdGpuKernelMod : public NativeGpuKernelMod {
|
|||
ReLUFwLaunchFunc kernel_func_;
|
||||
static std::vector<std::pair<KernelAttr, ReLUFwLaunchFunc>> func_list_;
|
||||
int input_length_{0};
|
||||
bool is_null_input_{false};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue