forked from mindspore-Ecosystem/mindspore
!48707 [MSLITE] Fix empty tensor problem in relu gpu op
Merge pull request !48707 from zhangyongxian/dev_zhangyongxian_relugpu
This commit is contained in:
commit
e5e8d7bb10
|
@ -44,6 +44,7 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
|
auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
|
||||||
|
is_null_input_ = CHECK_NULL_INPUT(inputs[kIndex0]->GetShapeVector());
|
||||||
input_length_ = std::accumulate(x_shape.begin(), x_shape.end(), static_cast<int>(1), std::multiplies<>());
|
input_length_ = std::accumulate(x_shape.begin(), x_shape.end(), static_cast<int>(1), std::multiplies<>());
|
||||||
return KRET_OK;
|
return KRET_OK;
|
||||||
}
|
}
|
||||||
|
@ -51,6 +52,9 @@ int ReLUFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool ReLUFwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool ReLUFwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||||
|
if (is_null_input_) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(input, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(input, false);
|
||||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||||
|
|
|
@ -56,6 +56,7 @@ class ReLUFwdGpuKernelMod : public NativeGpuKernelMod {
|
||||||
ReLUFwLaunchFunc kernel_func_;
|
ReLUFwLaunchFunc kernel_func_;
|
||||||
static std::vector<std::pair<KernelAttr, ReLUFwLaunchFunc>> func_list_;
|
static std::vector<std::pair<KernelAttr, ReLUFwLaunchFunc>> func_list_;
|
||||||
int input_length_{0};
|
int input_length_{0};
|
||||||
|
bool is_null_input_{false};
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue