reduce int8 add ReSize

This commit is contained in:
zhaozhenlong 2020-08-19 15:33:29 +08:00
parent 5686315199
commit f7649750ee
2 changed files with 34 additions and 15 deletions

View File

@ -189,7 +189,8 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
size *= input_shape[j];
}
}
int32_t *buffer = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t)));
MS_ASSERT(context_->allocator != nullptr);
int32_t *buffer = reinterpret_cast<int32_t *>(context_->allocator->Malloc(size * sizeof(int32_t)));
if (buffer == nullptr) {
MS_LOG(ERROR) << "Malloc data failed.";
return RET_ERROR;
@ -199,7 +200,7 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
}
auto input = in_tensors_.at(0);
begin_src_data_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * input->ElementsNum()));
begin_src_data_ = reinterpret_cast<int32_t *>(context_->allocator->Malloc(sizeof(int32_t) * input->ElementsNum()));
if (begin_src_data_ == nullptr) {
return RET_NULL_PTR;
}
@ -210,6 +211,32 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
return RET_OK;
}
void ReduceInt8CPUKernel::FreeTmpBuffer() {
for (auto buffer : data_buffers_) {
if (buffer != nullptr) {
MS_ASSERT(context_->allocator != nullptr);
context_->allocator->Free(buffer);
buffer = nullptr;
}
}
data_buffers_.clear();
if (begin_src_data_ != nullptr) {
MS_ASSERT(context_->allocator != nullptr);
context_->allocator->Free(begin_src_data_);
begin_src_data_ = nullptr;
}
}
int ReduceInt8CPUKernel::ReSize() {
FreeTmpBuffer();
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
}
return ret;
}
int ReduceInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto reduce = reinterpret_cast<ReduceInt8CPUKernel *>(cdata);
auto error_code = reduce->CallReduceUnit(task_id);
@ -261,6 +288,7 @@ int ReduceInt8CPUKernel::Run() {
axis_size_ = tmp_shape_[axis];
auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_);
if (error_code != RET_OK) {
FreeTmpBuffer();
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
return RET_ERROR;
}
@ -298,14 +326,11 @@ int ReduceInt8CPUKernel::Run() {
auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}
if (begin_src_data_ != nullptr) {
free(begin_src_data_);
begin_src_data_ = nullptr;
}
FreeTmpBuffer();
return RET_OK;
}

View File

@ -40,13 +40,6 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {
const mindspore::lite::PrimitiveC *primitive)
: ReduceBaseCPUKernel(param, inputs, outputs, ctx, primitive) {}
~ReduceInt8CPUKernel() {
for (auto i = 0; i < data_buffers_.size(); i++) {
int32_t *buffer = data_buffers_[i];
if (buffer != nullptr) {
free(buffer);
buffer = nullptr;
}
}
for (auto qm : mean_multipliers_) {
delete qm;
qm = nullptr;
@ -64,7 +57,7 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {
}
int Init() override;
int ReSize() override { return 0; };
int ReSize() override;
int Run() override;
int CallReduceUnit(int task_id);
int ReduceLastAxis(int task_id);
@ -74,6 +67,7 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {
private:
int MallocTmpBuffer();
void FreeTmpBuffer();
int CalculateQuantArgs();
private: