diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.cc index 97c35a130a2..5a0a9c01c5e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.cc @@ -46,8 +46,6 @@ int ScatterNdUpdateCPUKernel::ReSize() { auto input = in_tensors_.at(kScatterUpdateInputIndex); auto indices = in_tensors_.at(kScatterIndicesIndex); auto update = in_tensors_.at(kScatterUpdateIndex); - auto output = out_tensors_.front(); - output_ptr_ = output->data(); // check indices shape int input_rank = static_cast(input->shape().size()); @@ -87,15 +85,32 @@ int ScatterNdUpdateCPUKernel::ReSize() { param_->num_unit *= update_shape.at(i); } - int *indices_ptr = reinterpret_cast(indices->MutableData()); - MS_ASSERT(indices_ptr != nullptr); + auto indices_ptr = indices->data(); + if (indices_ptr == nullptr) { + return RET_OK; + } output_unit_offsets_.clear(); - for (int i = 0; i < param_->num_unit; i++) { - int tmp_stride = 0; - for (int j = 0; j < indice_unit_rank; j++) { - tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size; + if (indices->data_type() == kNumberTypeInt || indices->data_type() == kNumberTypeInt32) { + auto indices_data = reinterpret_cast(indices_ptr); + for (int i = 0; i < param_->num_unit; i++) { + int tmp_stride = 0; + for (int j = 0; j < indice_unit_rank; j++) { + tmp_stride += indices_data[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size; + } + output_unit_offsets_.push_back(tmp_stride); } - output_unit_offsets_.push_back(tmp_stride); + } else if (indices->data_type() == kNumberTypeInt64) { + auto indices_data = reinterpret_cast(indices_ptr); + for (int i = 0; i < param_->num_unit; i++) { + int tmp_stride = 0; + for (int j = 0; j < indice_unit_rank; j++) { + tmp_stride += indices_data[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size; + } + output_unit_offsets_.push_back(tmp_stride); + } + } else { + MS_LOG(ERROR) << "Unsupported data type for indices tensor."; + return RET_ERROR; } return RET_OK; } @@ -126,7 +141,7 @@ int ScatterNdUpdateCPUKernel::Run() { auto in_tensor = in_tensors().front(); auto out_tensor = out_tensors().front(); if (in_tensor->allocator() == nullptr || in_tensor->allocator() != out_tensor->allocator() || - op_parameter_->is_train_session_) { + in_tensor->own_data() == false || op_parameter_->is_train_session_) { memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size()); } else { out_tensor->FreeData(); @@ -134,7 +149,6 @@ int ScatterNdUpdateCPUKernel::Run() { in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count()); out_tensor->set_data(in_tensor->data()); out_tensor->set_own_data(in_tensor->own_data()); - output_ptr_ = out_tensor->data(); } auto indices = in_tensors_.at(kScatterIndicesIndex); if (!indices->IsConst() && ReSize() != RET_OK) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h index 67b3dd875c5..45e84edf356 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h @@ -39,7 +39,6 @@ class ScatterNdUpdateCPUKernel : public InnerKernel { private: ScatterNDParameter *param_ = nullptr; - void *output_ptr_ = nullptr; std::vector output_unit_offsets_; }; } // namespace mindspore::kernel