forked from mindspore-Ecosystem/mindspore
reduce int8 add ReSize
This commit is contained in:
parent
5686315199
commit
f7649750ee
|
@ -189,7 +189,8 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
|
|||
size *= input_shape[j];
|
||||
}
|
||||
}
|
||||
int32_t *buffer = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t)));
|
||||
MS_ASSERT(context_->allocator != nullptr);
|
||||
int32_t *buffer = reinterpret_cast<int32_t *>(context_->allocator->Malloc(size * sizeof(int32_t)));
|
||||
if (buffer == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc data failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -199,7 +200,7 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
|
|||
}
|
||||
|
||||
auto input = in_tensors_.at(0);
|
||||
begin_src_data_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * input->ElementsNum()));
|
||||
begin_src_data_ = reinterpret_cast<int32_t *>(context_->allocator->Malloc(sizeof(int32_t) * input->ElementsNum()));
|
||||
if (begin_src_data_ == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
@ -210,6 +211,32 @@ int ReduceInt8CPUKernel::MallocTmpBuffer() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void ReduceInt8CPUKernel::FreeTmpBuffer() {
|
||||
for (auto buffer : data_buffers_) {
|
||||
if (buffer != nullptr) {
|
||||
MS_ASSERT(context_->allocator != nullptr);
|
||||
context_->allocator->Free(buffer);
|
||||
buffer = nullptr;
|
||||
}
|
||||
}
|
||||
data_buffers_.clear();
|
||||
|
||||
if (begin_src_data_ != nullptr) {
|
||||
MS_ASSERT(context_->allocator != nullptr);
|
||||
context_->allocator->Free(begin_src_data_);
|
||||
begin_src_data_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
int ReduceInt8CPUKernel::ReSize() {
|
||||
FreeTmpBuffer();
|
||||
auto ret = MallocTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
FreeTmpBuffer();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ReduceInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||
auto reduce = reinterpret_cast<ReduceInt8CPUKernel *>(cdata);
|
||||
auto error_code = reduce->CallReduceUnit(task_id);
|
||||
|
@ -261,6 +288,7 @@ int ReduceInt8CPUKernel::Run() {
|
|||
axis_size_ = tmp_shape_[axis];
|
||||
auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_);
|
||||
if (error_code != RET_OK) {
|
||||
FreeTmpBuffer();
|
||||
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -298,14 +326,11 @@ int ReduceInt8CPUKernel::Run() {
|
|||
auto error_code = LiteBackendParallelLaunch(ReduceInt8Impl, this, context_->thread_num_);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (begin_src_data_ != nullptr) {
|
||||
free(begin_src_data_);
|
||||
begin_src_data_ = nullptr;
|
||||
}
|
||||
|
||||
FreeTmpBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -40,13 +40,6 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {
|
|||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: ReduceBaseCPUKernel(param, inputs, outputs, ctx, primitive) {}
|
||||
~ReduceInt8CPUKernel() {
|
||||
for (auto i = 0; i < data_buffers_.size(); i++) {
|
||||
int32_t *buffer = data_buffers_[i];
|
||||
if (buffer != nullptr) {
|
||||
free(buffer);
|
||||
buffer = nullptr;
|
||||
}
|
||||
}
|
||||
for (auto qm : mean_multipliers_) {
|
||||
delete qm;
|
||||
qm = nullptr;
|
||||
|
@ -64,7 +57,7 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {
|
|||
}
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override { return 0; };
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int CallReduceUnit(int task_id);
|
||||
int ReduceLastAxis(int task_id);
|
||||
|
@ -74,6 +67,7 @@ class ReduceInt8CPUKernel : public ReduceBaseCPUKernel {
|
|||
|
||||
private:
|
||||
int MallocTmpBuffer();
|
||||
void FreeTmpBuffer();
|
||||
int CalculateQuantArgs();
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue