forked from mindspore-Ecosystem/mindspore
!4608 fix concat int8 memory leak
Merge pull request !4608 from zhaozhenlong/lite/issue/concat_int8_mem_leak_fix
This commit is contained in:
commit
60551b1f28
|
@ -46,7 +46,7 @@ class ConcatBaseCPUKernel : public LiteKernel {
|
|||
int thread_count_;
|
||||
int axis_;
|
||||
const Context *ctx_;
|
||||
ConcatParameter *concat_param_;
|
||||
ConcatParameter *concat_param_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -28,9 +28,15 @@ namespace mindspore::kernel {
|
|||
|
||||
int ConcatInt8CPUKernel::Init() {
|
||||
ConcatBaseCPUKernel::Init();
|
||||
concat_param_->input_shapes_ = nullptr;
|
||||
auto input_num = in_tensors_.size();
|
||||
input_data_ = reinterpret_cast<int8_t **>(malloc(sizeof(int8_t *) * input_num));
|
||||
if (input_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Null pointer reference: inputs_array.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
concat_param_->quant_arg_.in_args_ =
|
||||
reinterpret_cast<QuantArg *>(ctx_->allocator->Malloc(sizeof(QuantArg) * input_num));
|
||||
reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg) * input_num));
|
||||
if (concat_param_->quant_arg_.in_args_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->in_quant_args_.";
|
||||
return RET_ERROR;
|
||||
|
@ -61,11 +67,11 @@ int ConcatInt8CPUKernel::ReSize() {
|
|||
return ret;
|
||||
}
|
||||
if (concat_param_->input_shapes_ != nullptr) {
|
||||
ctx_->allocator->Free(concat_param_->input_shapes_);
|
||||
// free(concat_param_->input_shapes_);
|
||||
}
|
||||
auto input_num = in_tensors_.size();
|
||||
concat_param_->input_num_ = input_num;
|
||||
concat_param_->input_shapes_ = reinterpret_cast<const int **>(ctx_->allocator->Malloc(sizeof(int *) * input_num));
|
||||
concat_param_->input_shapes_ = reinterpret_cast<const int **>(malloc(sizeof(int *) * input_num));
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
concat_param_->input_shapes_[i] = reinterpret_cast<const int *>(in_tensors_.at(i)->shape().data());
|
||||
}
|
||||
|
@ -96,11 +102,7 @@ int ConcatInt8CPUKernel::Run() {
|
|||
auto input_num = concat_param_->input_num_;
|
||||
count_unit_ = thread_count_ > 1 ? UP_DIV(before_axis_size, thread_count_) : before_axis_size;
|
||||
concat_param_->count_unit_ = count_unit_;
|
||||
input_data_ = reinterpret_cast<int8_t **>(ctx_->allocator->Malloc(sizeof(int8_t *) * input_num));
|
||||
if (input_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Null pointer reference: inputs_array.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
input_data_[i] = static_cast<int8_t *>(in_tensors_.at(i)->Data());
|
||||
}
|
||||
|
@ -108,10 +110,6 @@ int ConcatInt8CPUKernel::Run() {
|
|||
|
||||
ret = LiteBackendParallelLaunch(ConcatInt8Run, this, thread_count_);
|
||||
|
||||
ctx_->allocator->Free(input_data_);
|
||||
ctx_->allocator->Free(concat_param_->input_shapes_);
|
||||
ctx_->allocator->Free(concat_param_->quant_arg_.in_args_);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,17 @@ class ConcatInt8CPUKernel : public ConcatBaseCPUKernel {
|
|||
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: ConcatBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~ConcatInt8CPUKernel() override {}
|
||||
~ConcatInt8CPUKernel() override {
|
||||
if (input_data_ != nullptr) {
|
||||
free(input_data_);
|
||||
}
|
||||
if (concat_param_->input_shapes_ != nullptr) {
|
||||
free(concat_param_->input_shapes_);
|
||||
}
|
||||
if (concat_param_->quant_arg_.in_args_ != nullptr) {
|
||||
free(concat_param_->quant_arg_.in_args_);
|
||||
}
|
||||
}
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
|
|
|
@ -35,7 +35,7 @@ int Int8ElementFloor(int8_t *input, int8_t *output, int element_size, ArithSelfQ
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output[i] = (output_tmp);
|
||||
output[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -54,7 +54,7 @@ int Int8ElementRound(int8_t *input, int8_t *output, int element_size, ArithSelfQ
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output[i] = (output_tmp);
|
||||
output[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -73,7 +73,7 @@ int Int8ElementCeil(int8_t *input, int8_t *output, int element_size, ArithSelfQu
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output[i] = (output_tmp);
|
||||
output[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -92,7 +92,7 @@ int Int8ElementAbs(int8_t *input, int8_t *output, int element_size, ArithSelfQua
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output[i] = (output_tmp);
|
||||
output[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -111,7 +111,7 @@ int Int8ElementSin(int8_t *input, int8_t *output, int element_size, ArithSelfQua
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output[i] = (output_tmp);
|
||||
output[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -130,7 +130,7 @@ int Int8ElementCos(int8_t *input, int8_t *output, int element_size, ArithSelfQua
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output[i] = (output_tmp);
|
||||
output[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -149,7 +149,7 @@ int Int8ElementLog(int8_t *input, int8_t *output, int element_size, ArithSelfQua
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output[i] = (output_tmp);
|
||||
output[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -172,7 +172,7 @@ int Int8ElementSqrt(int8_t *input, int8_t *output, int element_size, ArithSelfQu
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output[i] = (output_tmp);
|
||||
output[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -195,7 +195,7 @@ int Int8ElementRsqrt(int8_t *input, int8_t *output, int element_size, ArithSelfQ
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output[i] = (output_tmp);
|
||||
output[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -230,6 +230,7 @@ void SquareInt8NEON(int8_t *input_data, int8_t *output_data, int64_t element_siz
|
|||
int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
|
||||
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
|
||||
vst1_s8(output_data, res_u8_n0);
|
||||
output_data += 8;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -253,7 +254,7 @@ int Int8ElementSquare(int8_t *input, int8_t *output, int element_size, ArithSelf
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output[index] = para.output_activation_min_;
|
||||
} else {
|
||||
output[index] = (output_tmp);
|
||||
output[index] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
|
|
@ -22,36 +22,36 @@ void Int8Concat(int8_t **inputs, int8_t *output, ConcatParameter *para, int axis
|
|||
float output_scale = para->quant_arg_.out_args_.scale_;
|
||||
const float output_inverse_scale = 1.f / output_scale;
|
||||
int input_num = para->input_num_;
|
||||
int count_unit_ = para->count_unit_;
|
||||
int after_axis_size = para->after_axis_size;
|
||||
int64_t count_unit_ = para->count_unit_;
|
||||
int64_t after_axis_size = para->after_axis_size;
|
||||
const int *output_shape = para->output_shapes_;
|
||||
int out_copy_size = output_shape[axis] * after_axis_size;
|
||||
QuantArg *input_quant = para->quant_arg_.in_args_;
|
||||
int output_zp = para->quant_arg_.out_args_.zp_;
|
||||
int max_int8 = para->quant_arg_.output_activation_max_;
|
||||
int min_int8 = para->quant_arg_.output_activation_min_;
|
||||
int8_t max_int8 = para->quant_arg_.output_activation_max_;
|
||||
int8_t min_int8 = para->quant_arg_.output_activation_min_;
|
||||
int64_t start = task_id * count_unit_;
|
||||
int64_t end = start + real_dst_count;
|
||||
output += start * out_copy_size;
|
||||
|
||||
for (int k = start; k < end; k++) {
|
||||
for (int i = 0; i < input_num; i++) {
|
||||
const int *input_shape = para->input_shapes_[i];
|
||||
int in_copy_size = input_shape[axis] * after_axis_size;
|
||||
int64_t in_copy_size = input_shape[axis] * after_axis_size;
|
||||
int8_t *input_ptr = inputs[i] + k * in_copy_size;
|
||||
int8_t *output_ptr = output + k * out_copy_size;
|
||||
if (input_quant[i].scale_ == output_scale && input_quant[i].zp_ == output_zp) {
|
||||
memcpy(output_ptr, input_ptr, in_copy_size);
|
||||
memcpy(output, input_ptr, in_copy_size);
|
||||
} else {
|
||||
float scale = input_quant[i].scale_ * output_inverse_scale;
|
||||
float bias = -input_quant[i].zp_ * scale;
|
||||
for (int j = 0; j < in_copy_size; j++) {
|
||||
int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp;
|
||||
if (output_tmp > max_int8) {
|
||||
output_ptr[j] = max_int8;
|
||||
output[j] = max_int8;
|
||||
} else if (output_tmp < min_int8) {
|
||||
output_ptr[j] = min_int8;
|
||||
output[j] = min_int8;
|
||||
} else {
|
||||
output_ptr[j] = (output_tmp);
|
||||
output[j] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ void Crop1D(const int8_t *input, int8_t *output, int task_id, CropParameter *par
|
|||
} else if (output_tmp < para->quant_arg.output_activation_min_) {
|
||||
out_ptr[i] = para->quant_arg.output_activation_min_;
|
||||
} else {
|
||||
out_ptr[i] = output_tmp;
|
||||
out_ptr[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -110,7 +110,7 @@ void Crop2D(const int8_t *input, int8_t *output, int task_id, CropParameter *par
|
|||
} else if (output_tmp < para->quant_arg.output_activation_min_) {
|
||||
out_ptr[i] = para->quant_arg.output_activation_min_;
|
||||
} else {
|
||||
out_ptr[i] = (output_tmp);
|
||||
out_ptr[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -164,7 +164,7 @@ void Crop3D(const int8_t *input, int8_t *output, int task_id, CropParameter *par
|
|||
} else if (output_tmp < para->quant_arg.output_activation_min_) {
|
||||
out_ptr[i] = para->quant_arg.output_activation_min_;
|
||||
} else {
|
||||
out_ptr[i] = (output_tmp);
|
||||
out_ptr[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -225,7 +225,7 @@ void Int8Crop4D(const int8_t *input, int8_t *output, int task_id, CropParameter
|
|||
} else if (output_tmp < para->quant_arg.output_activation_min_) {
|
||||
out_ptr[i] = para->quant_arg.output_activation_min_;
|
||||
} else {
|
||||
out_ptr[i] = (output_tmp);
|
||||
out_ptr[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t
|
|||
} else if (mul_result < para.output_activation_min_) {
|
||||
output_data[index] = para.output_activation_min_;
|
||||
} else {
|
||||
output_data[index] = (mul_result);
|
||||
output_data[index] = (int8_t)mul_result;
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
|
|
@ -33,7 +33,7 @@ void Int8Reshape(int8_t *input_ptr, int8_t *output_ptr, int64_t real_dst_count,
|
|||
} else if (output_tmp < para.output_activation_min_) {
|
||||
output_ptr[i] = para.output_activation_min_;
|
||||
} else {
|
||||
output_ptr[i] = output_tmp;
|
||||
output_ptr[i] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ int Int8DoSplit(int8_t *in_data, int8_t **out_data, const int *input_shape, int
|
|||
} else if (output_tmp < param->quant_arg_.output_activation_min_) {
|
||||
dst[j] = param->quant_arg_.output_activation_min_;
|
||||
} else {
|
||||
dst[j] = output_tmp;
|
||||
dst[j] = (int8_t)output_tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,8 +53,8 @@ typedef struct ConvQuantArg {
|
|||
typedef struct ConcatQuantArg {
|
||||
QuantArg *in_args_;
|
||||
QuantArg out_args_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
int8_t output_activation_min_;
|
||||
int8_t output_activation_max_;
|
||||
} ConcatQuantArg;
|
||||
|
||||
typedef struct SqueezeQuantArg {
|
||||
|
|
Loading…
Reference in New Issue