diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc index 5a0dc32a710..71aa3515d59 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc @@ -189,7 +189,8 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() { size *= input_shape[j]; } } - int32_t *buffer = reinterpret_cast(malloc(size * sizeof(int32_t))); + MS_ASSERT(context_->allocator != nullptr); + int32_t *buffer = reinterpret_cast(context_->allocator->Malloc(size * sizeof(int32_t))); if (buffer == nullptr) { MS_LOG(ERROR) << "Malloc data failed."; return RET_ERROR; @@ -199,7 +200,7 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() { } auto input = in_tensors_.at(0); - begin_src_data_ = reinterpret_cast(malloc(sizeof(int32_t) * input->ElementsNum())); + begin_src_data_ = reinterpret_cast(context_->allocator->Malloc(sizeof(int32_t) * input->ElementsNum())); if (begin_src_data_ == nullptr) { return RET_NULL_PTR; } @@ -210,6 +211,32 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() { return RET_OK; } +void ReduceInt8CPUKernel::FreeTmpBuffer() { + for (auto buffer : data_buffers_) { + if (buffer != nullptr) { + MS_ASSERT(context_->allocator != nullptr); + context_->allocator->Free(buffer); + buffer = nullptr; + } + } + data_buffers_.clear(); + + if (begin_src_data_ != nullptr) { + MS_ASSERT(context_->allocator != nullptr); + context_->allocator->Free(begin_src_data_); + begin_src_data_ = nullptr; + } +} + +int ReduceInt8CPUKernel::ReSize() { + FreeTmpBuffer(); + auto ret = MallocTmpBuffer(); + if (ret != RET_OK) { + FreeTmpBuffer(); + } + return ret; +} + int ReduceInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { auto reduce = reinterpret_cast(cdata); auto error_code = reduce->CallReduceUnit(task_id); @@ -261,6 +288,7 @@ int ReduceInt8CPUKernel::Run() { axis_size_ = tmp_shape_[axis]; auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_); if (error_code != RET_OK) { + FreeTmpBuffer(); MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]"; return RET_ERROR; } @@ -298,14 +326,11 @@ int ReduceInt8CPUKernel::Run() { auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]"; + FreeTmpBuffer(); return RET_ERROR; } - if (begin_src_data_ != nullptr) { - free(begin_src_data_); - begin_src_data_ = nullptr; - } - + FreeTmpBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h index 800c8d69b37..e4bbdb3e54e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h @@ -40,13 +40,6 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel { const mindspore::lite::PrimitiveC *primitive) : ReduceBaseCPUKernel(param, inputs, outputs, ctx, primitive) {} ~ReduceInt8CPUKernel() { - for (auto i = 0; i < data_buffers_.size(); i++) { - int32_t *buffer = data_buffers_[i]; - if (buffer != nullptr) { - free(buffer); - buffer = nullptr; - } - } for (auto qm : mean_multipliers_) { delete qm; qm = nullptr; @@ -64,7 +57,7 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel { } int Init() override; - int ReSize() override { return 0; }; + int ReSize() override; int Run() override; int CallReduceUnit(int task_id); int ReduceLastAxis(int task_id); @@ -74,6 +67,7 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel { private: int MallocTmpBuffer(); + void FreeTmpBuffer(); int CalculateQuantArgs(); private: