forked from OSSInnovation/mindspore
!6345 [MS][LITE]add determination of malloc/new
Merge pull request !6345 from fuzhiye/tmp
This commit is contained in:
commit
a217deada9
|
@ -20,7 +20,7 @@
|
|||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
int ROIPooling(float *in_ptr, float *out_ptr, float *roi, int tid, ROIPoolingParameter *param) {
|
||||
int ROIPooling(float *in_ptr, float *out_ptr, float *roi, float *max_c, int tid, ROIPoolingParameter *param) {
|
||||
int num_rois = param->output_n_;
|
||||
int units = UP_DIV(num_rois, param->thread_num_);
|
||||
int roi_st = tid * units;
|
||||
|
@ -37,11 +37,9 @@ int ROIPooling(float *in_ptr, float *out_ptr, float *roi, int tid, ROIPoolingPar
|
|||
int pooled_width = param->pooledW_;
|
||||
const int roi_stride = 5;
|
||||
int roi_ind_st = roi_st * roi_stride;
|
||||
float *max_c = malloc(channels_ * sizeof(float));
|
||||
for (int i = roi_st; i < roi_end; ++i) {
|
||||
int roi_batch_ind = (int)roi[roi_ind_st]; // batch_index
|
||||
if (roi_batch_ind >= batch_size) {
|
||||
free(max_c);
|
||||
return NNACL_ERRCODE_INDEX_OUT_OF_RANGE;
|
||||
}
|
||||
int roi_start_h = (int)roundf(roi[roi_ind_st + 1] * scale); // top-left x1
|
||||
|
@ -93,6 +91,5 @@ int ROIPooling(float *in_ptr, float *out_ptr, float *roi, int tid, ROIPoolingPar
|
|||
}
|
||||
roi_ind_st += roi_stride;
|
||||
}
|
||||
free(max_c);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ typedef struct ROIPoolingParameter {
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int ROIPooling(float *in_ptr, float *out_ptr, float *roi, int tid, ROIPoolingParameter *param);
|
||||
int ROIPooling(float *in_ptr, float *out_ptr, float *roi, float *max_c, int tid, ROIPoolingParameter *param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
#include "nnacl/int8/leaky_relu_int8.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int DoLeakReluInt8(int8_t *inputs, int8_t *output_ptr, LeakyReluQuantArg *quant_prelu_parm, int task_id) {
|
||||
int DoLeakReluInt8(int8_t *inputs, int8_t *output_ptr, LeakyReluQuantArg *quant_prelu_parm, QuantArg *input_quant,
|
||||
int task_id) {
|
||||
if (quant_prelu_parm == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
|
@ -26,10 +27,6 @@ int DoLeakReluInt8(int8_t *inputs, int8_t *output_ptr, LeakyReluQuantArg *quant_
|
|||
const float output_inverse_scale = 1.f / output_scale;
|
||||
int output_dim = quant_prelu_parm->input_dim_;
|
||||
|
||||
QuantArg *input_quant = malloc(sizeof(QuantArg)*output_dim);
|
||||
if (input_quant == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
for (int i = 0; i < output_dim; i++) {
|
||||
input_quant[i].scale_ = quant_prelu_parm->quant_arg.in_args_.scale_;
|
||||
input_quant[i].zp_ = quant_prelu_parm->quant_arg.in_args_.zp_;
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int DoLeakReluInt8(int8_t *inputs, int8_t *output_ptr, LeakyReluQuantArg *quant_Prelu_parm, int task_id);
|
||||
int DoLeakReluInt8(int8_t *inputs, int8_t *output_ptr, LeakyReluQuantArg *quant_Prelu_parm, QuantArg *input_quant,
|
||||
int task_id);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -16,7 +16,8 @@
|
|||
#include "nnacl/minimal_filtering_generator.h"
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include <stdlib.h>
|
||||
#include "nnacl/winograd_utils.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
void Polynomial(float *interval, float *m, int degree) {
|
||||
for (int i = 0; i < degree; ++i) {
|
||||
|
@ -53,9 +54,13 @@ void ResidueMatrix(float *interval, float *b, int row, int col) {
|
|||
b[len - 1] = 1;
|
||||
}
|
||||
|
||||
void LT(float *poly_array, float *matrix_lt, int n) {
|
||||
float *coefficient_array = (float *)malloc(n * sizeof(float));
|
||||
float *poly = (float *)malloc(n * sizeof(float));
|
||||
int LT(float *poly_array, float *matrix_lt, int n) {
|
||||
if (n > MAX_LEN) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
float coefficient_array[MAX_LEN]; // n
|
||||
float poly[MAX_LEN]; // n
|
||||
|
||||
Polynomial(poly_array, poly, n);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
// get coefficient
|
||||
|
@ -79,8 +84,7 @@ void LT(float *poly_array, float *matrix_lt, int n) {
|
|||
matrix_lt[setp + l] = coefficient_array[l] / poly[i];
|
||||
}
|
||||
} // matrix L row loop
|
||||
free(coefficient_array);
|
||||
free(poly);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
void T(float *poly_array, float *matrix_t, int n) {
|
||||
|
@ -99,20 +103,22 @@ void T(float *poly_array, float *matrix_t, int n) {
|
|||
}
|
||||
}
|
||||
|
||||
void B(float *poly_array, float *matrix_b, int in_unit) {
|
||||
int B(float *poly_array, float *matrix_b, int in_unit) {
|
||||
memset(matrix_b, 0, in_unit * in_unit * sizeof(float));
|
||||
int n = in_unit - 1;
|
||||
float *matrix_l = (float *)malloc(n * n * sizeof(float));
|
||||
float *matrix_lt = (float *)malloc(n * n * sizeof(float));
|
||||
float *matrix_t = (float *)malloc(n * in_unit * sizeof(float));
|
||||
if ((n * n) > MAX_LEN || (n * in_unit) > MAX_LEN) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
float matrix_l[MAX_LEN]; // n * n
|
||||
float matrix_lt[MAX_LEN]; // n * n
|
||||
float matrix_t[MAX_LEN]; // n * in_unit
|
||||
|
||||
T(poly_array, matrix_t, n);
|
||||
LT(poly_array, matrix_lt, n);
|
||||
MatrixTranspose(matrix_lt, matrix_l, n, n);
|
||||
MatrixMultiply(matrix_l, matrix_t, matrix_b, n, n, in_unit);
|
||||
matrix_b[in_unit * in_unit - 1] = 1;
|
||||
free(matrix_l);
|
||||
free(matrix_lt);
|
||||
free(matrix_t);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
void GenerateIntervalArray(float *array, float interval, int degree) {
|
||||
|
@ -146,16 +152,19 @@ void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_
|
|||
}
|
||||
}
|
||||
|
||||
void CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g,
|
||||
float *matrix_gt, float coefficient, int out_unit, int filter_size) {
|
||||
int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g,
|
||||
float *matrix_gt, float coefficient, int out_unit, int filter_size) {
|
||||
int in_unit = out_unit + filter_size - 1;
|
||||
int degree = in_unit - 1;
|
||||
float *polynomial_m = malloc(degree * sizeof(float));
|
||||
float *diagonal_matrix = malloc(in_unit * in_unit * sizeof(float));
|
||||
float *inverse_diagonal_matrix = malloc(in_unit * in_unit * sizeof(float));
|
||||
if (degree > MAX_LEN || (in_unit * in_unit) > MAX_LEN || degree > MAX_LEN || (in_unit * filter_size) > MAX_LEN) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
float polynomial_m[MAX_LEN]; // degree
|
||||
float diagonal_matrix[MAX_LEN]; // input_unit * input_unit
|
||||
float inverse_diagonal_matrix[MAX_LEN]; // input_unit * input_unit
|
||||
|
||||
// get diagonal matrix
|
||||
float *interval = malloc(degree * sizeof(float));
|
||||
float interval[MAX_LEN]; // degree
|
||||
GenerateIntervalArray(interval, coefficient, degree);
|
||||
Polynomial(interval, polynomial_m, degree);
|
||||
DiagonalPlusMatrix(polynomial_m, diagonal_matrix, degree);
|
||||
|
@ -185,17 +194,12 @@ void CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *m
|
|||
MatrixTranspose(matrix_bt, matrix_b, in_unit, in_unit);
|
||||
|
||||
// get matrix G && GT
|
||||
float *tmp_g = malloc(in_unit * filter_size * sizeof(float));
|
||||
float tmp_g[MAX_LEN]; // in_unit * filter_size
|
||||
ResidueMatrix(interval, matrix_g, in_unit, filter_size);
|
||||
MatrixTranspose(matrix_g, tmp_g, in_unit, filter_size);
|
||||
MatrixMultiply(tmp_g, inverse_diagonal_matrix, matrix_gt, filter_size, in_unit, in_unit);
|
||||
MatrixTranspose(matrix_gt, matrix_g, filter_size, in_unit);
|
||||
|
||||
free(interval);
|
||||
free(polynomial_m);
|
||||
free(diagonal_matrix);
|
||||
free(inverse_diagonal_matrix);
|
||||
free(tmp_g);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
|
|
|
@ -30,11 +30,11 @@ void DiagonalPlusMatrix(float *matrix, float *diagonal_matrix, int degree);
|
|||
|
||||
void ResidueMatrix(float *interval, float *b, int row, int col);
|
||||
|
||||
void LT(float *poly_array, float *matrix_lt, int n);
|
||||
int LT(float *poly_array, float *matrix_lt, int n);
|
||||
|
||||
void T(float *poly_array, float *matrix_t, int n);
|
||||
|
||||
void B(float *poly_array, float *matrix_b, int in_unit);
|
||||
int B(float *poly_array, float *matrix_b, int in_unit);
|
||||
|
||||
void GenerateIntervalArray(float *array, float interval, int degree);
|
||||
|
||||
|
@ -42,7 +42,7 @@ void MatrixTranspose(float *matrix, float *trans_matrix, int row, int col);
|
|||
|
||||
void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n);
|
||||
|
||||
void CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g,
|
||||
int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g,
|
||||
float *matrix_gt, float coefficient, int out_unit, int filter_size);
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
|
|
|
@ -73,6 +73,10 @@ int ArgMinMaxBaseCPUKernel::Run() {
|
|||
|
||||
auto in_tensor = in_tensors_.at(0)->shape();
|
||||
auto shape = reinterpret_cast<int *>(malloc(in_tensor.size() * sizeof(int)));
|
||||
if (shape == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc shape failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(shape, in_tensor.data(), in_tensor.size() * sizeof(int));
|
||||
|
||||
auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::ActivationType;
|
||||
using mindspore::schema::PadMode;
|
||||
|
@ -193,17 +194,17 @@ int ConvolutionBaseCPUKernel::MallocQuantParam() {
|
|||
conv_quant_arg_->input_quant_args_ = reinterpret_cast<QuantArg *>(malloc(input_arg_num * sizeof(QuantArg)));
|
||||
if (conv_quant_arg_->input_quant_args_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc input_quant_args_ failed.";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
conv_quant_arg_->filter_quant_args_ = reinterpret_cast<QuantArg *>(malloc(filter_arg_num * sizeof(QuantArg)));
|
||||
if (conv_quant_arg_->filter_quant_args_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc filter_quant_args_ failed.";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
conv_quant_arg_->output_quant_args_ = reinterpret_cast<QuantArg *>(malloc(output_arg_num * sizeof(QuantArg)));
|
||||
if (conv_quant_arg_->output_quant_args_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc output_quant_args_ failed.";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -261,11 +262,35 @@ int ConvolutionBaseCPUKernel::SetQuantMultiplier() {
|
|||
weight_arg_num = conv_quant_arg_->filter_arg_num_;
|
||||
}
|
||||
conv_quant_arg_->real_multiplier_ = reinterpret_cast<double *>(malloc(weight_arg_num * sizeof(double)));
|
||||
if (conv_quant_arg_->real_multiplier_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->real_multiplier_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
conv_quant_arg_->left_shift_ = reinterpret_cast<int32_t *>(malloc(weight_arg_num * sizeof(int32_t)));
|
||||
if (conv_quant_arg_->left_shift_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->left_shift_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
conv_quant_arg_->right_shift_ = reinterpret_cast<int32_t *>(malloc(weight_arg_num * sizeof(int32_t)));
|
||||
if (conv_quant_arg_->right_shift_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->right_shift_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
conv_quant_arg_->quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(weight_arg_num * sizeof(int32_t)));
|
||||
if (conv_quant_arg_->quant_multiplier_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->quant_multiplier_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
conv_quant_arg_->out_act_min_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
|
||||
if (conv_quant_arg_->out_act_min_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->out_act_min_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
conv_quant_arg_->out_act_max_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
|
||||
if (conv_quant_arg_->out_act_max_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->out_act_max_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
for (int i = 0; i < weight_arg_num; ++i) {
|
||||
const double in_scale =
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Pooling;
|
||||
|
||||
|
@ -31,8 +32,20 @@ namespace mindspore::kernel {
|
|||
int PoolingBaseCPUKernel::SetQuantParam() {
|
||||
// per tensor init
|
||||
pooling_quant_arg_ = reinterpret_cast<QuantArg **>(malloc(2 * sizeof(QuantArg *)));
|
||||
if (pooling_quant_arg_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc pooling_quant_arg failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
pooling_quant_arg_[0] = reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg)));
|
||||
if (pooling_quant_arg_[0] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc pooling_quant_arg[0] failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
pooling_quant_arg_[1] = reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg)));
|
||||
if (pooling_quant_arg_[1] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc pooling_quant_arg[1] failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
auto *input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto in_quant_arg = input_tensor->GetQuantParams();
|
||||
auto *out_tensor = out_tensors_.at(kOutputIndex);
|
||||
|
|
|
@ -24,6 +24,10 @@ float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext
|
|||
if (data_type == kNumberTypeFloat32) {
|
||||
auto ele_num = input->ElementsNum();
|
||||
fp16_data = reinterpret_cast<float16_t *>(ctx->allocator->Malloc(ele_num * sizeof(float16_t)));
|
||||
if (fp16_data == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc fp16_data failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ori_data = reinterpret_cast<float *>(input->MutableData());
|
||||
Float32ToFloat16(ori_data, fp16_data, ele_num);
|
||||
} else {
|
||||
|
@ -38,6 +42,10 @@ float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx)
|
|||
if (data_type == kNumberTypeFloat32) {
|
||||
auto ele_num = output->ElementsNum();
|
||||
fp16_data = reinterpret_cast<float16_t *>(ctx->allocator->Malloc(ele_num * sizeof(float16_t)));
|
||||
if (fp16_data == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc fp16_data failed.";
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
fp16_data = reinterpret_cast<float16_t *>(output->MutableData());
|
||||
}
|
||||
|
|
|
@ -41,6 +41,10 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara
|
|||
|
||||
size_t tmp_size = oC8 * C8NUM * iC8 * C8NUM * kernel_plane * sizeof(float16_t);
|
||||
auto tmp_addr = reinterpret_cast<float16_t *>(malloc(tmp_size));
|
||||
if (tmp_addr == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_addr failed.";
|
||||
return;
|
||||
}
|
||||
memset(tmp_addr, 0, tmp_size);
|
||||
|
||||
PackWeightToC4Fp16(origin_weight, tmp_addr, conv_param);
|
||||
|
|
|
@ -45,15 +45,35 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
|
|||
int oc_block_num = UP_DIV(channel_out, oc_block);
|
||||
|
||||
auto matrix_g_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
|
||||
if (matrix_g_data_fp16 == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc matrix_g_data_fp16 failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
|
||||
if (matrix_gt_data_fp16 == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc matrix_gt_data_fp16 failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
Float32ToFloat16(matrix_g, matrix_g_data_fp16, input_unit_ * kernel_unit_);
|
||||
Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit_ * kernel_unit_);
|
||||
|
||||
// trans_filter = G*g*GT (g represents weight_data)
|
||||
// separate into two steps ===> tmp = G*g ===> out = tmp * GT
|
||||
auto tmp_weight_data = reinterpret_cast<float16_t *>(malloc(kernel_unit_ * kernel_unit_ * sizeof(float16_t)));
|
||||
if (tmp_weight_data == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_weight_data failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tmp_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
|
||||
if (tmp_data == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_data failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto trans_out_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * input_unit_ * sizeof(float16_t)));
|
||||
if (trans_out_data == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc trans_out_data failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int> shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block};
|
||||
std::vector<int> strides;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
|
@ -180,7 +200,15 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
|
|||
}
|
||||
memset(trans_weight_, 0, trans_matrix_data_size);
|
||||
auto *matrix_g = reinterpret_cast<float *>(malloc(input_unit_ * kernel_unit_ * sizeof(float)));
|
||||
if (matrix_g == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc matrix_g failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto matrix_gt = reinterpret_cast<float *>(malloc(input_unit_ * kernel_unit_ * sizeof(float)));
|
||||
if (matrix_gt == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc matrix_gt failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = MallocTransformMatrices();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Malloc transform matrices failed.";
|
||||
|
@ -191,7 +219,11 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
|
|||
float matrix_at[MAX_LEN];
|
||||
float matrix_b[MAX_LEN];
|
||||
float matrix_bt[MAX_LEN];
|
||||
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f, output_unit_, kernel_unit_);
|
||||
ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f, output_unit_, kernel_unit_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
|
||||
return ret;
|
||||
}
|
||||
Float32ToFloat16(matrix_a, matrix_a_, input_unit_ * output_unit_);
|
||||
Float32ToFloat16(matrix_at, matrix_at_, input_unit_ * output_unit_);
|
||||
Float32ToFloat16(matrix_b, matrix_b_, input_unit_ * input_unit_);
|
||||
|
|
|
@ -117,6 +117,10 @@ int MatmulFP16CPUKernel::ReSize() {
|
|||
if (out_tensors_[0]->data_type() == kNumberTypeFloat32) {
|
||||
output_ptr_ = reinterpret_cast<float16_t *>(
|
||||
ctx_->allocator->Malloc(params_->batch * params_->row_ * params_->col_ * sizeof(float16_t)));
|
||||
if (output_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc output_ptr_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -83,6 +83,10 @@ int SplitFp16CPUKernel::Run() {
|
|||
if (in_tensor->data_type() == kNumberTypeFloat32) {
|
||||
input_ptr_ =
|
||||
reinterpret_cast<float16_t *>(context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float16_t)));
|
||||
if (input_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc input_ptr_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
Float32ToFloat16(reinterpret_cast<float *>(in_tensor->MutableData()), input_ptr_, in_tensor->ElementsNum());
|
||||
} else {
|
||||
input_ptr_ = reinterpret_cast<float16_t *>(in_tensor->MutableData());
|
||||
|
@ -91,6 +95,10 @@ int SplitFp16CPUKernel::Run() {
|
|||
if (in_tensor->data_type() == kNumberTypeFloat32) {
|
||||
output_ptr_[i] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(out_tensors_.at(i)->ElementsNum() * sizeof(float16_t)));
|
||||
if (output_ptr_[i] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc output_ptr_[" << i << "]" << " failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
Float32ToFloat16(reinterpret_cast<float *>(out_tensors_.at(i)->MutableData()), output_ptr_[i],
|
||||
out_tensors_.at(i)->ElementsNum());
|
||||
} else {
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Conv2D;
|
||||
|
||||
|
@ -40,8 +41,20 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da
|
|||
// trans_filter = G*g*GT (g represents weight_data)
|
||||
// separate into two steps ===> tmp = G*g ===> out = tmp * GT
|
||||
auto tmp_weight_data = reinterpret_cast<float *>(malloc(kernel_unit_ * kernel_unit_ * sizeof(float)));
|
||||
if (tmp_weight_data == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_weight_data failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
auto tmp_data = reinterpret_cast<float *>(malloc(input_unit_ * kernel_unit_ * sizeof(float)));
|
||||
if (tmp_data == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_data failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
auto trans_out_data = reinterpret_cast<float *>(malloc(input_unit_ * input_unit_ * sizeof(float)));
|
||||
if (trans_out_data == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc trans_out_data failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
std::vector<int> shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block};
|
||||
std::vector<int> strides;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
|
@ -110,9 +123,8 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
|
|||
trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size));
|
||||
if (trans_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
memset(trans_weight_, 0, trans_matrix_data_size);
|
||||
|
||||
float matrix_g[64];
|
||||
|
@ -121,10 +133,14 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
|
|||
float matrix_at[64];
|
||||
float matrix_b[64];
|
||||
float matrix_bt[64];
|
||||
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 1.0f, output_unit_, kernel_unit_);
|
||||
|
||||
auto ret =
|
||||
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 1.0f, output_unit_, kernel_unit_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
|
||||
return ret;
|
||||
}
|
||||
auto weight_data = reinterpret_cast<float *>(filter_tensor->MutableData());
|
||||
auto ret = WinogradFilterTransform(weight_data, matrix_g, matrix_gt, oc_block);
|
||||
ret = WinogradFilterTransform(weight_data, matrix_g, matrix_gt, oc_block);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "winograd filter transfrom failed.";
|
||||
return ret;
|
||||
|
@ -133,6 +149,10 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
|
|||
// init bias
|
||||
size_t new_bias_size = oc4 * C4NUM * sizeof(float);
|
||||
bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
memset(bias_data_, 0, new_bias_size);
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
auto ori_bias_addr = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
|
||||
|
@ -162,14 +182,14 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
|
|||
nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
|
||||
if (nhwc4_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
size_t tile_buffer_size = thread_count_ * tile_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
|
||||
trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
|
||||
if (trans_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc trans_input_ failed.";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
gemm_out_ = reinterpret_cast<float *>(
|
||||
|
@ -186,14 +206,14 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
|
|||
output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float)));
|
||||
if (tmp_out_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_out_data_ failed.";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
tmp_data_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float)));
|
||||
if (tmp_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_data_ failed.";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
col_buffer_ =
|
||||
|
|
|
@ -88,13 +88,27 @@ int DetectionPostProcessCPUKernel::Run() {
|
|||
parameter->nms_candidate_ = context_->allocator->Malloc(num_boxes * sizeof(uint8_t));
|
||||
parameter->selected_ = context_->allocator->Malloc(num_boxes * sizeof(int));
|
||||
parameter->score_with_class_ = context_->allocator->Malloc(num_boxes * sizeof(ScoreWithIndex));
|
||||
if (!parameter->decoded_boxes_ || !parameter->nms_candidate_ || !parameter->selected_ ||
|
||||
!parameter->score_with_class_) {
|
||||
MS_LOG(ERROR) << "malloc parameter->decoded_boxes_ || parameter->nms_candidate_ || parameter->selected_ || "
|
||||
"parameter->score_with_class_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (parameter->use_regular_nms_) {
|
||||
parameter->score_with_class_all_ =
|
||||
context_->allocator->Malloc((num_boxes + parameter->max_detections_) * sizeof(ScoreWithIndex));
|
||||
parameter->indexes_ = context_->allocator->Malloc((num_boxes + parameter->max_detections_) * sizeof(int));
|
||||
if (!parameter->score_with_class_all_ || !parameter->indexes_) {
|
||||
MS_LOG(ERROR) << "malloc parameter->score_with_class_all_ || parameter->indexes_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
parameter->score_with_class_all_ =
|
||||
context_->allocator->Malloc((num_boxes * parameter->num_classes_) * sizeof(ScoreWithIndex));
|
||||
if (!parameter->score_with_class_all_) {
|
||||
MS_LOG(ERROR) << "malloc parameter->score_with_class_all_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
DetectionPostProcess(num_boxes, num_classes_with_bg, input_boxes, input_scores, parameter->anchors_, output_boxes,
|
||||
output_classes, output_scores, output_num, parameter);
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_ROIPooling;
|
||||
|
||||
|
@ -37,6 +38,10 @@ int ROIPoolingCPUKernel::Init() {
|
|||
}
|
||||
|
||||
int ROIPoolingCPUKernel::ReSize() {
|
||||
if (max_c_ != nullptr) {
|
||||
free(max_c_);
|
||||
max_c_ = nullptr;
|
||||
}
|
||||
auto in_shape = in_tensors_.front()->shape();
|
||||
auto out_shape = out_tensors_.front()->shape();
|
||||
int ndims = in_shape.size();
|
||||
|
@ -60,11 +65,16 @@ int ROIPoolingCPUKernel::ReSize() {
|
|||
param_->out_strides_[i] = out_shape[i + 1] * param_->out_strides_[i + 1];
|
||||
}
|
||||
param_->thread_num_ = MSMIN(param_->op_parameter_.thread_num_, out_shape[0]);
|
||||
max_c_ = reinterpret_cast<float *>(malloc(param_->input_c_ * sizeof(float)));
|
||||
if (max_c_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc max_c failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ROIPoolingCPUKernel::DoExecute(int task_id) {
|
||||
auto ret = ROIPooling(in_ptr_, out_ptr_, roi_ptr_, task_id, param_);
|
||||
auto ret = ROIPooling(in_ptr_, out_ptr_, roi_ptr_, max_c_, task_id, param_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ROIPooling Execute error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return ret;
|
||||
|
|
|
@ -29,7 +29,12 @@ class ROIPoolingCPUKernel : public LiteKernel {
|
|||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
param_ = reinterpret_cast<ROIPoolingParameter *>(parameter);
|
||||
}
|
||||
~ROIPoolingCPUKernel() override = default;
|
||||
~ROIPoolingCPUKernel() override {
|
||||
if (max_c_ != nullptr) {
|
||||
free(max_c_);
|
||||
max_c_ = nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
|
@ -40,6 +45,7 @@ class ROIPoolingCPUKernel : public LiteKernel {
|
|||
float *in_ptr_;
|
||||
float *out_ptr_;
|
||||
float *roi_ptr_;
|
||||
float *max_c_ = nullptr;
|
||||
ROIPoolingParameter *param_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -60,7 +60,15 @@ int TransposeCPUKernel::ReSize() {
|
|||
free(this->out_shape_);
|
||||
}
|
||||
in_shape_ = reinterpret_cast<int *>(malloc(in_shape.size() * sizeof(int)));
|
||||
if (in_shape_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc in_shape_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
out_shape_ = reinterpret_cast<int *>(malloc(out_shape.size() * sizeof(int)));
|
||||
if (out_shape_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc out_shape_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(in_shape_, in_shape.data(), in_shape.size() * sizeof(int));
|
||||
memcpy(out_shape_, out_shape.data(), in_shape.size() * sizeof(int));
|
||||
return RET_OK;
|
||||
|
|
|
@ -91,6 +91,10 @@ int QuantizedAddCPUKernel::Run() {
|
|||
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
|
||||
input0_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
input1_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
if (!input0_data_ || !input1_data_) {
|
||||
MS_LOG(ERROR) << "malloc input0_data_ || input1_data_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
TileDimensionsUint8(static_cast<uint8_t *>(in_tensors_.at(0)->MutableData()),
|
||||
static_cast<uint8_t *>(in_tensors_.at(1)->MutableData()),
|
||||
|
|
|
@ -68,9 +68,18 @@ int ConcatInt8CPUKernel::ReSize() {
|
|||
auto input_num = in_tensors_.size();
|
||||
concat_param_->input_num_ = input_num;
|
||||
concat_param_->input_shapes_ = reinterpret_cast<const int **>(malloc(sizeof(int *) * input_num));
|
||||
if (concat_param_->input_shapes_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc concat_param_->input_shapes_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto in_shape = in_tensors_.at(i)->shape();
|
||||
concat_param_->input_shapes_[i] = reinterpret_cast<int *>(malloc(in_shape.size() * sizeof(int)));
|
||||
if (concat_param_->input_shapes_[i] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc concat_param_->input_shapes_[" << i << "]"
|
||||
<< " failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(reinterpret_cast<void *>(const_cast<int *>(concat_param_->input_shapes_[i])), in_shape.data(),
|
||||
sizeof(int) * in_shape.size());
|
||||
}
|
||||
|
@ -85,6 +94,10 @@ int ConcatInt8CPUKernel::ReSize() {
|
|||
auto out_shape = output_tensor->shape();
|
||||
size_t output_dim = out_shape.size();
|
||||
concat_param_->output_shapes_ = reinterpret_cast<int *>(malloc(output_dim * sizeof(int)));
|
||||
if (concat_param_->output_shapes_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc concat_param_->output_shapes_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(reinterpret_cast<void *>(const_cast<int *>(concat_param_->output_shapes_)), output_tensor->shape().data(),
|
||||
sizeof(int) * output_dim);
|
||||
|
||||
|
|
|
@ -153,7 +153,15 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
|
|||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto channel = conv_param_->input_channel_;
|
||||
input_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
|
||||
if (input_scale_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc input_sacle_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
input_zp_ = reinterpret_cast<int8_t *>(malloc(channel * sizeof(int8_t)));
|
||||
if (input_zp_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc input_zp_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (input_tensor->GetQuantParams().size() == kPerTensor) {
|
||||
for (int i = 0; i < channel; i++) {
|
||||
auto input_quant_arg = input_tensor->GetQuantParams().front();
|
||||
|
@ -170,7 +178,15 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
|
|||
|
||||
auto output_tensor = out_tensors_.at(kOutputIndex);
|
||||
output_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
|
||||
if (output_scale_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc output_scale_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
output_zp_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
|
||||
if (output_zp_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc output_zp_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (output_tensor->GetQuantParams().size() == kPerTensor) {
|
||||
for (int i = 0; i < channel; i++) {
|
||||
auto output_quant_arg = output_tensor->GetQuantParams().front();
|
||||
|
@ -186,13 +202,41 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
|
|||
}
|
||||
|
||||
conv_quant_arg_->real_multiplier_ = reinterpret_cast<double *>(malloc(channel * sizeof(double)));
|
||||
if (conv_quant_arg_->real_multiplier_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->real_multiplier_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_quant_arg_->left_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
|
||||
if (conv_quant_arg_->left_shift_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->left_shift_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_quant_arg_->right_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
|
||||
if (conv_quant_arg_->right_shift_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->right_shift_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_quant_arg_->quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
|
||||
if (conv_quant_arg_->quant_multiplier_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->quant_multiplier_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_quant_arg_->out_act_min_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
|
||||
if (conv_quant_arg_->out_act_min_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->out_act_min_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_quant_arg_->out_act_max_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
|
||||
if (conv_quant_arg_->out_act_max_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc conv_quant_arg_->out_act_max_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
weight_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
|
||||
if (weight_scale_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc weight_scale_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
if (weight_tensor->GetQuantParams().size() == kPerTensor) {
|
||||
for (int i = 0; i < channel; i++) {
|
||||
|
|
|
@ -85,6 +85,10 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
|
|||
}
|
||||
memset(packed_weight_, 0, pack_weight_size);
|
||||
auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * output_channel));
|
||||
if (weight_sum == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc weight_sum failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (int i = 0; i < output_channel; i++) weight_sum[i] = 0;
|
||||
PackWeightInt8(origin_weight, conv_param_, packed_weight_, weight_sum);
|
||||
|
||||
|
@ -192,6 +196,10 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
|
|||
}
|
||||
memset(packed_weight_, 0, pack_weight_size);
|
||||
auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * output_channel));
|
||||
if (weight_sum == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc weight_sum failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (int i = 0; i < output_channel; i++) weight_sum[i] = 0;
|
||||
PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum);
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
@ -74,6 +75,10 @@ LeakyReluInt8CPUKernel::~LeakyReluInt8CPUKernel() {
|
|||
free(quant_prelu_parm_.slope_);
|
||||
quant_prelu_parm_.slope_ = nullptr;
|
||||
}
|
||||
if (input_quant_ != nullptr) {
|
||||
free(input_quant_);
|
||||
input_quant_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
int LeakyReluInt8CPUKernel::ReSize() {
|
||||
|
@ -81,10 +86,18 @@ int LeakyReluInt8CPUKernel::ReSize() {
|
|||
auto *out_tensor = out_tensors_.at(kOutputIndex);
|
||||
auto input_dim = input_tensor->shape().size();
|
||||
MS_ASSERT(input_dim <= CROP_OFFSET_MAX_SIZE);
|
||||
if (input_quant_ != nullptr) {
|
||||
free(input_quant_);
|
||||
input_quant_ = nullptr;
|
||||
}
|
||||
quant_prelu_parm_.input_dim_ = input_dim;
|
||||
quant_prelu_parm_.element_num = in_tensors_[0]->Size();
|
||||
quant_prelu_parm_.in_shape_ = input_tensor->shape().data();
|
||||
quant_prelu_parm_.out_shape_ = out_tensor->shape().data();
|
||||
input_quant_ = static_cast<QuantArg *>(malloc(sizeof(QuantArg) * input_dim));
|
||||
if (input_quant_ == nullptr) {
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -106,7 +119,7 @@ int LeakyReluInt8CPUKernel::DoExecute(int task_id) {
|
|||
auto out_tensor = out_tensors_.at(kOutputIndex);
|
||||
int8_t *input_data = reinterpret_cast<int8_t *>(input_tensor->MutableData());
|
||||
int8_t *output_data = reinterpret_cast<int8_t *>(out_tensor->MutableData());
|
||||
auto ret = DoLeakReluInt8(input_data, output_data, &quant_prelu_parm_, task_id);
|
||||
auto ret = DoLeakReluInt8(input_data, output_data, &quant_prelu_parm_, input_quant_, task_id);
|
||||
if (ret != NNACL_OK) {
|
||||
MS_LOG(ERROR) << "DoLeakReluInt8 failed";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -39,6 +39,7 @@ class LeakyReluInt8CPUKernel : public LeakyReluBaseCPUKernel {
|
|||
|
||||
private:
|
||||
LeakyReluQuantArg quant_prelu_parm_;
|
||||
QuantArg *input_quant_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -76,7 +76,10 @@ int MulInt8CPUKernel::Run() {
|
|||
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
|
||||
input0_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
input1_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
|
||||
if (!input0_data_ || !input1_data_) {
|
||||
MS_LOG(ERROR) << "malloc input0_data_ || input1_data_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ArithmeticParameter tile_para;
|
||||
tile_para.ndim_ = out_tensors_.at(0)->shape().size();
|
||||
for (size_t i = 0; i < tile_para.ndim_; i++) {
|
||||
|
|
|
@ -34,6 +34,10 @@ int SqueezeInt8CPUKernel::Init() {
|
|||
return init_ret;
|
||||
}
|
||||
quant_Squeeze_parm_ = new (std::nothrow) SqueezeQuantArg;
|
||||
if (quant_Squeeze_parm_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new quant_Squeeze_parm_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_num = in_tensors_.size();
|
||||
quant_Squeeze_parm_->input_num_ = input_num;
|
||||
quant_Squeeze_parm_->input_sizes_ = reinterpret_cast<int *>(malloc(sizeof(int) * input_num));
|
||||
|
@ -115,6 +119,10 @@ int SqueezeInt8CPUKernel::ReSize() {
|
|||
quant_Squeeze_parm_->output_size_ = output_size;
|
||||
|
||||
quant_Squeeze_parm_->output_shape_ = new int[output_size];
|
||||
if (quant_Squeeze_parm_->output_shape_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new quant_Squeeze_parm_->output_shape_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
::memcpy(quant_Squeeze_parm_->output_shape_, output_shape.data(), sizeof(int) * output_size);
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -127,9 +135,18 @@ int SqueezeInt8CPUKernel::Run() {
|
|||
}
|
||||
auto input_dim = quant_Squeeze_parm_->input_num_;
|
||||
int8_t **inputs_array = reinterpret_cast<int8_t **>(malloc(sizeof(int8_t *) * input_dim));
|
||||
if (inputs_array == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc inputs_array failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < input_dim; i++) {
|
||||
auto input_size = quant_Squeeze_parm_->input_sizes_[i];
|
||||
inputs_array[i] = reinterpret_cast<int8_t *>(malloc(sizeof(int8_t) * input_size));
|
||||
if (inputs_array[i] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc inputs_array[" << i << "]"
|
||||
<< " failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_type = in_tensors_[i]->data_type();
|
||||
if (input_type == kNumberTypeUInt8) {
|
||||
uint8_t *input_tmp = reinterpret_cast<uint8_t *>(in_tensors_[i]->MutableData());
|
||||
|
|
Loading…
Reference in New Issue