!27537 Fixes core dump issue of ScatterAdd and ScatterSub

Merge pull request !27537 from huangbo/master_1209
This commit is contained in:
i-robot 2021-12-15 02:06:16 +00:00 committed by Gitee
commit b16e6bed20
2 changed files with 8 additions and 0 deletions

View File

@ -52,6 +52,7 @@ void ScatterArithmeticCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
<< "', the dimension of 'input_x' should be greater than or equal to 1, but got "
<< input_shape.size() << ".";
}
input_shape_0 = SizeToInt(input_shape[0]);
input_size_ = 1;
inner_size_ = 1;
if (input_shape.empty()) {
@ -92,6 +93,9 @@ bool ScatterArithmeticCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr>
template <typename T>
void ScatterArithmeticCPUKernel<T>::ScatterAdd(T *input, const int *indices, const T *updates) const {
for (size_t i = 0; i < indices_size_; i++) {
if (indices[i] >= input_shape_0) {
continue;
}
auto base_index_updates = i * inner_size_;
auto base_index_input = indices[i] * inner_size_;
for (size_t j = 0; j < inner_size_; j++) {
@ -103,6 +107,9 @@ void ScatterArithmeticCPUKernel<T>::ScatterAdd(T *input, const int *indices, con
template <typename T>
void ScatterArithmeticCPUKernel<T>::ScatterSub(T *input, const int *indices, const T *updates) const {
for (size_t i = 0; i < indices_size_; i++) {
if (indices[i] >= input_shape_0) {
continue;
}
auto base_index_updates = i * inner_size_;
auto base_index_input = indices[i] * inner_size_;
for (size_t j = 0; j < inner_size_; j++) {

View File

@ -49,6 +49,7 @@ class ScatterArithmeticCPUKernel : public CPUKernel {
using TypeComputeFunc = std::function<void(ScatterArithmeticCPUKernel *, T *, const int *, const T *)>;
TypeComputeFunc compute_func_;
int input_shape_0{0};
size_t input_size_{0};
size_t inner_size_{0};
size_t indices_size_{0};