diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc index 3a2a77497f2..0894ba69cc1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc @@ -37,6 +37,10 @@ int GatherCPUKernel::Init() { return ReSize(); } +GatherCPUKernel::~GatherCPUKernel() { + context_->allocator->Free(indices_data_); +} + int GatherCPUKernel::ReSize() { return RET_OK; } int GatherCPUKernel::DoGather(int task_id) { @@ -45,7 +49,6 @@ int GatherCPUKernel::DoGather(int task_id) { auto out_tensor = out_tensors_.at(0); auto input_ptr = reinterpret_cast(input_tensor->Data()); - auto indices_ptr = reinterpret_cast(indices_tensor->Data()); auto output_ptr = reinterpret_cast(out_tensor->Data()); auto input_int32 = reinterpret_cast(input_tensor->Data()); @@ -57,13 +60,6 @@ int GatherCPUKernel::DoGather(int task_id) { auto axis = (reinterpret_cast(op_parameter_))->axis_; const int limit = in_shape[axis]; - for (int i = 0; i < indices_element_size; ++i) { - indices_data_[i] = static_cast(indices_ptr[i]); - if (indices_data_[i] >= limit) { - MS_LOG(ERROR) << " indice data: " << indices_data_[i] << " is not in [ 0, " << limit - 1 << " ]"; - return RET_ERROR; - } - } int outer_size = 1, inner_size = 1; for (int i = 0; i < axis; ++i) { @@ -106,12 +102,23 @@ int GatherCPUKernel::Run() { } auto indices_tensor = in_tensors_.at(1); - indices_data_ = reinterpret_cast(context_->allocator->Malloc(indices_tensor->ElementsNum() * sizeof(int))); + indices_data_ = reinterpret_cast(context_->allocator->Malloc(indices_tensor->Size())); if (indices_data_ == nullptr) { MS_LOG(ERROR) << "Memory allocation failed"; - context_->allocator->Free(indices_data_); return RET_ERROR; } + auto in_shape = in_tensors_.at(0)->shape(); + int indices_element_size = indices_tensor->ElementsNum(); + auto axis = (reinterpret_cast(op_parameter_))->axis_;; + auto indices_ptr = reinterpret_cast(indices_tensor->Data()); + const int limit = in_shape[axis]; + for (int i = 0; i < indices_element_size; ++i) { + indices_data_[i] = static_cast(indices_ptr[i]); + if (indices_data_[i] >= limit) { + MS_LOG(ERROR) << " indice data: " << indices_data_[i] << " is not in [ 0, " << limit - 1 << " ]"; + return RET_ERROR; + } + } int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, GatherRun, this, op_parameter_->thread_num_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]"; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h index 7af9703e541..b492c4c179b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h @@ -28,7 +28,7 @@ class GatherCPUKernel : public LiteKernel { const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~GatherCPUKernel() override = default; + ~GatherCPUKernel() override; int Init() override; int ReSize() override;