forked from mindspore-Ecosystem/mindspore
!25504 [MSLITE] Fuzz test
Merge pull request !25504 from wangshaocong/fuzz_test
This commit is contained in:
commit
db85f17d65
|
@ -109,6 +109,9 @@ void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param
|
|||
int block_start = block_index + i;
|
||||
int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
if (conv_param->input_h_ - input_h < 0 || in_w - input_w < 0) {
|
||||
continue;
|
||||
}
|
||||
int input_stride = (input_h * in_w + input_w) * in_channel;
|
||||
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
|
||||
int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h));
|
||||
|
|
|
@ -107,12 +107,14 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
|||
if (input_c != weight_tensor->shape_[3] && input_c != 1 && (input_c / param->group_) != weight_tensor->shape_[3]) {
|
||||
return NNACL_PARAM_INVALID;
|
||||
}
|
||||
if (param->stride_h_ == 0 || param->stride_w_ == 0) {
|
||||
if (param->stride_h_ <= 0 || param->stride_w_ <= 0) {
|
||||
return NNACL_PARAM_INVALID;
|
||||
}
|
||||
|
||||
param->kernel_h_ = param->kernel_h_ != -1 ? param->kernel_h_ : weight_tensor->shape_[1];
|
||||
param->kernel_w_ = param->kernel_w_ != -1 ? param->kernel_w_ : weight_tensor->shape_[2];
|
||||
MS_CHECK_TRUE_RET(param->kernel_h_ == weight_tensor->shape_[1], NNACL_PARAM_INVALID);
|
||||
MS_CHECK_TRUE_RET(param->kernel_w_ == weight_tensor->shape_[2], NNACL_PARAM_INVALID);
|
||||
int ret = ConvInferShape(input_h, input_w, &output_h, &output_w, param);
|
||||
if (ret != NNACL_OK) {
|
||||
return ret;
|
||||
|
|
|
@ -47,7 +47,8 @@ int Deconv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC
|
|||
int32_t output_h = 0;
|
||||
int32_t output_w = 0;
|
||||
int32_t output_c = GetChannel(weight);
|
||||
if (param->group_ == GetChannel(input) && param->group_ == GetBatch(weight) && 1 == GetChannel(weight)) {
|
||||
MS_CHECK_TRUE_RET(GetChannel(input) == GetBatch(weight), NNACL_ERR);
|
||||
if (param->group_ == GetChannel(input) && 1 == GetChannel(weight)) {
|
||||
output_c = GetBatch(weight); /* depthwise */
|
||||
}
|
||||
|
||||
|
@ -57,8 +58,8 @@ int Deconv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC
|
|||
|
||||
int stride_w = param->stride_w_;
|
||||
int stride_h = param->stride_h_;
|
||||
MS_CHECK_FALSE(stride_w == 0, NNACL_ERR);
|
||||
MS_CHECK_FALSE(stride_h == 0, NNACL_ERR);
|
||||
MS_CHECK_FALSE(stride_w <= 0, NNACL_ERR);
|
||||
MS_CHECK_FALSE(stride_h <= 0, NNACL_ERR);
|
||||
MS_CHECK_FALSE(INT_MUL_OVERFLOW(input_h, stride_h), NNACL_ERR);
|
||||
MS_CHECK_FALSE(INT_MUL_OVERFLOW(input_w, stride_w), NNACL_ERR);
|
||||
|
||||
|
|
|
@ -17,6 +17,37 @@
|
|||
#include "nnacl/infer/lstm_infer.h"
|
||||
#include "nnacl/infer/infer_register.h"
|
||||
|
||||
int CheckInputShapeValid(const TensorC *const *inputs, const LstmParameter *parameter) {
|
||||
const TensorC *input = inputs[FIRST_INPUT];
|
||||
const TensorC *weight_i = inputs[SECOND_INPUT];
|
||||
const TensorC *weight_g = inputs[THIRD_INPUT];
|
||||
const TensorC *bias = inputs[FOURTH_INPUT];
|
||||
const TensorC *cell = inputs[FIFTH_INPUT];
|
||||
int batch = input->shape_[kNHWC_H];
|
||||
int input_size = input->shape_[kNHWC_W];
|
||||
int hidden_size = weight_i->shape_[kNHWC_H] / C4NUM;
|
||||
bool bidirectional = parameter->bidirectional_;
|
||||
if (input->shape_size_ != DIMENSION_3D || weight_i->shape_size_ != DIMENSION_3D) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
int bidirection = bidirectional ? C2NUM : C1NUM;
|
||||
MS_CHECK_TRUE_RET(weight_i->shape_[kNHWC_N] == bidirection && weight_i->shape_[kNHWC_H] == hidden_size * C4NUM &&
|
||||
weight_i->shape_[kNHWC_W] == input_size,
|
||||
NNACL_ERR);
|
||||
MS_CHECK_TRUE_RET(weight_g->shape_[kNHWC_N] == bidirection && weight_g->shape_[kNHWC_H] == hidden_size * C4NUM &&
|
||||
weight_g->shape_[kNHWC_W] == hidden_size,
|
||||
NNACL_ERR);
|
||||
MS_CHECK_TRUE_RET(bias->shape_[kNHWC_N] == bidirection && bias->shape_[kNHWC_H] == hidden_size * C8NUM, NNACL_ERR);
|
||||
if (!bidirectional && cell->shape_size_ == DIMENSION_2D) {
|
||||
MS_CHECK_TRUE_RET(cell->shape_[kNHWC_N] == batch && cell->shape_[kNHWC_H] == hidden_size, NNACL_ERR);
|
||||
} else {
|
||||
MS_CHECK_TRUE_RET(
|
||||
cell->shape_[kNHWC_N] == bidirection && cell->shape_[kNHWC_H] == batch && cell->shape_[kNHWC_W] == hidden_size,
|
||||
NNACL_ERR);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 6, 3);
|
||||
|
@ -37,7 +68,7 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
|
|||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
|
||||
if (input->shape_size_ != 3 || weight_i->shape_size_ != 3) {
|
||||
if (CheckInputShapeValid(inputs, param) != NNACL_OK) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ int PadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **ou
|
|||
}
|
||||
param->padding_length = size;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
MS_CHECK_TRUE_RET(((int *)paddings->data_)[i] >= 0, NNACL_INFER_INVALID);
|
||||
param->paddings_[i] = ((int *)paddings->data_)[i];
|
||||
}
|
||||
|
||||
|
|
|
@ -95,6 +95,7 @@ int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
|
|||
int begin[MAX_SHAPE_SIZE];
|
||||
int size[MAX_SHAPE_SIZE];
|
||||
for (int32_t i = 0; i < param->param_length_; ++i) {
|
||||
MS_CHECK_TRUE_RET(param->axis_[i] < param->param_length_, NNACL_PARAM_INVALID);
|
||||
begin[param->axis_[i]] = param->begin_[i];
|
||||
size[param->axis_[i]] = param->size_[i];
|
||||
}
|
||||
|
|
|
@ -74,6 +74,7 @@ int SplitInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
|
|||
} else {
|
||||
split_dim_i = param->split_sizes_[i];
|
||||
}
|
||||
MS_CHECK_TRUE_RET(split_dim_i >= 0, NNACL_ERR);
|
||||
output_shape[split_dim] = split_dim_i;
|
||||
SetShapeArray(outputs[i], output_shape, output_shape_size);
|
||||
SetDataTypeFormat(outputs[i], input);
|
||||
|
|
|
@ -72,6 +72,7 @@ int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
|
|||
if (input1_data == NULL) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
MS_CHECK_TRUE_RET(data_num <= MAX_TILE_DIM_SIZE, NNACL_ERR);
|
||||
for (int i = 0; i < data_num; i++) {
|
||||
param->multiples_[i] = input1_data[i];
|
||||
}
|
||||
|
|
|
@ -81,9 +81,7 @@ struct OpContext {
|
|||
if (code == MindrtStatus::KINIT) {
|
||||
code = MindrtStatus::KERROR;
|
||||
}
|
||||
for (auto promise : *results_) {
|
||||
promise.SetFailed(code);
|
||||
}
|
||||
results_->front().SetFailed(code);
|
||||
}
|
||||
|
||||
void SetSuccess(int32_t code) {
|
||||
|
|
|
@ -55,6 +55,11 @@ int InnerKernel::PreProcess() {
|
|||
}
|
||||
}
|
||||
|
||||
// check if inputs are valid
|
||||
if (!CheckInputsValid()) {
|
||||
MS_LOG(ERROR) << "The input is not valid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (auto *output : this->out_tensors()) {
|
||||
MS_ASSERT(output != nullptr);
|
||||
if (registry_data_type_ == kNumberTypeFloat16 && output->data_type() == kNumberTypeFloat32) {
|
||||
|
|
|
@ -62,6 +62,8 @@ class InnerKernel : public Kernel {
|
|||
// called after Run
|
||||
virtual int PostProcess() { return FreeInWorkTensor(); }
|
||||
|
||||
virtual bool CheckInputsValid() const { return true; }
|
||||
|
||||
virtual int FreeInWorkTensor() const {
|
||||
for (auto &in_tensor : this->in_tensors()) {
|
||||
MS_ASSERT(in_tensor != nullptr);
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "src/common/prim_util.h"
|
||||
#include "src/common/graph_util.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/tensor.h"
|
||||
#ifdef ENABLE_V0
|
||||
#include "src/ops/compat/compat_register.h"
|
||||
#endif
|
||||
|
@ -281,6 +282,12 @@ bool LiteModel::ModelVerify() const {
|
|||
MS_LOG(ERROR) << "Tensor in all tensors is nullptr.";
|
||||
return false;
|
||||
}
|
||||
// check the input data type
|
||||
auto element_size = DataTypeSize(static_cast<const TypeId>(tensor->dataType()));
|
||||
if (element_size == 0) {
|
||||
MS_LOG(ERROR) << "The data type is not supported to malloc.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (std::any_of(this->output_indices_.begin(), this->output_indices_.end(),
|
||||
|
|
|
@ -27,6 +27,7 @@ using mindspore::schema::ActivationType;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
void *ConvolutionBaseCPUKernel::MallocAlignedData(size_t alignment, size_t size) {
|
||||
MS_CHECK_TRUE_RET(size + alignment < MAX_MALLOC_SIZE, nullptr);
|
||||
auto ptr = malloc(size + alignment);
|
||||
if (ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "MallocAlignedData failed!";
|
||||
|
@ -424,4 +425,13 @@ void ConvolutionBaseCPUKernel::UpdateOriginWeightAndBias() {
|
|||
origin_bias_ = in_tensors_.at(kBiasIndex)->data();
|
||||
}
|
||||
}
|
||||
|
||||
bool ConvolutionBaseCPUKernel::CheckInputsValid() const {
|
||||
// the data type of input and weight must be the same, while the bias data type of int8 convolution is int32.
|
||||
MS_CHECK_TRUE_RET(in_tensors_.size() >= kInputSize1, false);
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
MS_CHECK_TRUE_RET(input_tensor != nullptr && weight_tensor != nullptr, false);
|
||||
return input_tensor->data_type() == weight_tensor->data_type();
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -65,6 +65,7 @@ class ConvolutionBaseCPUKernel : public InnerKernel {
|
|||
void FreeQuantParam();
|
||||
void *MallocAlignedData(size_t alignment, size_t size);
|
||||
void FreeAlignedData(void **ptr);
|
||||
bool CheckInputsValid() const override;
|
||||
|
||||
protected:
|
||||
int InitConvWeightBias();
|
||||
|
|
|
@ -39,16 +39,19 @@ int CropBaseCPUKernel::ReSize() {
|
|||
crop_para_->out_shape_ = output_shape_.data();
|
||||
MS_ASSERT(input_dim <= CROP_OFFSET_MAX_SIZE);
|
||||
crop_para_->input_dim_ = input_dim;
|
||||
PadOffset(input_dim, crop_para_);
|
||||
if (PadOffset(input_dim, crop_para_) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Pad offset failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void CropBaseCPUKernel::PadOffset(int input_dim, CropParameter *crop_para) const {
|
||||
int CropBaseCPUKernel::PadOffset(int input_dim, CropParameter *crop_para) const {
|
||||
auto axis = crop_para->axis_;
|
||||
auto offsets_size = crop_para->offset_size_;
|
||||
MS_ASSERT(axis <= input_dim);
|
||||
if (offsets_size > 1) {
|
||||
MS_ASSERT(axis + offsets_size == input_dim);
|
||||
MS_CHECK_TRUE_MSG(axis + offsets_size == input_dim, RET_ERROR, "The axis and offsets is invalid");
|
||||
}
|
||||
for (int i = 0; i < input_dim; i++) {
|
||||
int crop_offset = 0;
|
||||
|
@ -63,5 +66,6 @@ void CropBaseCPUKernel::PadOffset(int input_dim, CropParameter *crop_para) const
|
|||
}
|
||||
crop_para->in_offset_[i] = crop_offset;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -40,7 +40,7 @@ class CropBaseCPUKernel : public InnerKernel {
|
|||
std::vector<int> input_shape_;
|
||||
std::vector<int> output_shape_;
|
||||
CropParameter *crop_para_;
|
||||
void PadOffset(int input_dim, CropParameter *crop_para) const;
|
||||
int PadOffset(int input_dim, CropParameter *crop_para) const;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -84,10 +84,11 @@ int Convolution1x1FP16CPUKernel::MallocWeightBiasData() {
|
|||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
auto input_channel = weight_tensor->Channel();
|
||||
auto output_channel = weight_tensor->Batch();
|
||||
|
||||
MS_CHECK_TRUE_RET(input_channel > 0 && output_channel > 0, RET_ERROR);
|
||||
size_t size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float16_t);
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, size);
|
||||
packed_weight_ = malloc(size);
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv1x1 Malloc packed_weight_ error!";
|
||||
|
@ -100,6 +101,7 @@ int Convolution1x1FP16CPUKernel::MallocWeightBiasData() {
|
|||
if (in_tensors_.size() == kInputSize2) {
|
||||
size = UP_ROUND(output_channel, col_tile_) * sizeof(float16_t);
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, size);
|
||||
bias_data_ = malloc(size);
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!";
|
||||
|
|
|
@ -38,21 +38,21 @@ void ConvolutionDepthwise3x3Fp16CPUKernel::PackWeight() {
|
|||
int ConvolutionDepthwise3x3Fp16CPUKernel::MallocWeightBiasData() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
int channel = weight_tensor->Batch();
|
||||
MS_CHECK_TRUE_RET(channel > 0, RET_ERROR);
|
||||
int c8 = UP_ROUND(channel, C8NUM);
|
||||
int pack_weight_size = c8 * C12NUM;
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float16_t));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float16_t));
|
||||
if (packed_weight_ == nullptr) {
|
||||
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, c8 * sizeof(float16_t));
|
||||
bias_data_ = malloc(c8 * sizeof(float16_t));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -34,9 +34,11 @@ void ConvolutionDepthwiseFp16CPUKernel::PackWeight() {
|
|||
int ConvolutionDepthwiseFp16CPUKernel::MallocWeightBiasData() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
int channel = weight_tensor->Batch();
|
||||
MS_CHECK_TRUE_RET(channel > 0, RET_ERROR);
|
||||
int pack_weight_size = channel * weight_tensor->Height() * weight_tensor->Width();
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float16_t));
|
||||
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
@ -45,6 +47,7 @@ int ConvolutionDepthwiseFp16CPUKernel::MallocWeightBiasData() {
|
|||
}
|
||||
}
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, channel * sizeof(float16_t));
|
||||
bias_data_ = malloc(channel * sizeof(float16_t));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -68,18 +68,17 @@ int ConvolutionDepthwiseSWFp16CPUKernel::MallocWeightBiasData() {
|
|||
int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float16_t));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float16_t));
|
||||
if (packed_weight_ == nullptr) {
|
||||
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, C8NUM * OC8 * sizeof(float16_t));
|
||||
bias_data_ = malloc(C8NUM * OC8 * sizeof(float16_t));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -42,6 +42,7 @@ int ConvolutionFP16CPUKernel::MallocWeightBiasData() {
|
|||
auto filter_tensor = in_tensors_.at(kWeightIndex);
|
||||
int in_channel = filter_tensor->Channel();
|
||||
int out_channel = filter_tensor->Batch();
|
||||
MS_CHECK_TRUE_RET(in_channel > 0 && out_channel > 0, RET_ERROR);
|
||||
conv_param_->input_channel_ = in_channel;
|
||||
conv_param_->output_channel_ = out_channel;
|
||||
int oc8 = UP_ROUND(out_channel, col_tile_);
|
||||
|
@ -51,19 +52,18 @@ int ConvolutionFP16CPUKernel::MallocWeightBiasData() {
|
|||
// init weight
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float16_t));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float16_t));
|
||||
if (packed_weight_ == nullptr) {
|
||||
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc packed_weight_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(ERROR) << "malloc packed_weight_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t));
|
||||
}
|
||||
// init bias
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, oc8 * sizeof(float16_t));
|
||||
bias_data_ = malloc(oc8 * sizeof(float16_t));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
||||
|
|
|
@ -36,6 +36,7 @@ int ConvolutionWinogradFP16CPUKernel::MallocWeightBiasData() {
|
|||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
int in_channel = weight_tensor->Channel();
|
||||
int out_channel = weight_tensor->Batch();
|
||||
MS_CHECK_TRUE_RET(in_channel > 0 && out_channel > 0, RET_ERROR);
|
||||
conv_param_->input_channel_ = in_channel;
|
||||
conv_param_->output_channel_ = out_channel;
|
||||
int oc_block_num = UP_DIV(out_channel, col_tile_);
|
||||
|
@ -43,6 +44,7 @@ int ConvolutionWinogradFP16CPUKernel::MallocWeightBiasData() {
|
|||
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * col_tile_ * sizeof(float16_t);
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, trans_matrix_data_size);
|
||||
packed_weight_ = malloc(trans_matrix_data_size);
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc packed_weight_ failed.";
|
||||
|
@ -68,6 +70,7 @@ int ConvolutionWinogradFP16CPUKernel::MallocWeightBiasData() {
|
|||
}
|
||||
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, oc_block_num * col_tile_ * sizeof(float16_t));
|
||||
bias_data_ = malloc(oc_block_num * col_tile_ * sizeof(float16_t));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
||||
|
|
|
@ -74,13 +74,14 @@ int DeconvolutionDepthwiseFp16CPUKernel::MallocWeightBiasData() {
|
|||
int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
|
||||
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float16_t));
|
||||
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, C8NUM * OC8 * sizeof(float16_t));
|
||||
bias_data_ = malloc(C8NUM * OC8 * sizeof(float16_t));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -71,8 +71,10 @@ int DeConvolutionFp16CPUKernel::MallocWeightBiasData() {
|
|||
auto output_channel = weight_tensor->Channel();
|
||||
auto kernel_h = weight_tensor->Height();
|
||||
auto kernel_w = weight_tensor->Width();
|
||||
MS_CHECK_TRUE_RET(input_channel > 0 && output_channel > 0 && kernel_h > 0 && kernel_w > 0, RET_ERROR);
|
||||
size_t weight_pack_size = input_channel * kernel_w * kernel_h * UP_ROUND(output_channel, C8NUM) * sizeof(float16_t);
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, weight_pack_size);
|
||||
packed_weight_ = malloc(weight_pack_size);
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "deconv malloc packed_weight_ error!";
|
||||
|
@ -81,6 +83,7 @@ int DeConvolutionFp16CPUKernel::MallocWeightBiasData() {
|
|||
memset(packed_weight_, 0, weight_pack_size);
|
||||
}
|
||||
auto bias_size = UP_ROUND(output_channel, C8NUM) * sizeof(float16_t);
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, bias_size);
|
||||
bias_data_ = malloc(bias_size);
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "deconv malloc bias_data_ error!";
|
||||
|
|
|
@ -299,8 +299,10 @@ int Convolution1x1CPUKernel::MallocWeightBiasData() {
|
|||
auto filter_tensor = in_tensors_.at(kWeightIndex);
|
||||
auto input_channel = filter_tensor->Channel();
|
||||
auto output_channel = filter_tensor->Batch();
|
||||
MS_CHECK_TRUE_RET(input_channel > 0 && output_channel > 0, RET_ERROR);
|
||||
int size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float);
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, size);
|
||||
packed_weight_ = malloc(size);
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv1x1 Malloc packed_weight_ error!";
|
||||
|
@ -310,6 +312,7 @@ int Convolution1x1CPUKernel::MallocWeightBiasData() {
|
|||
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
size = UP_ROUND(output_channel, col_tile_) * sizeof(float);
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, size);
|
||||
bias_data_ = malloc(size);
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!";
|
||||
|
|
|
@ -123,10 +123,12 @@ void ConvolutionDepthwise3x3CPUKernel::PackWeight() {
|
|||
int ConvolutionDepthwise3x3CPUKernel::MallocWeightBiasData() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
int channel = weight_tensor->Batch();
|
||||
MS_CHECK_TRUE_RET(channel > 0, RET_ERROR);
|
||||
int c4 = UP_ROUND(channel, C4NUM);
|
||||
int pack_weight_size = c4 * C12NUM;
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
@ -136,6 +138,7 @@ int ConvolutionDepthwise3x3CPUKernel::MallocWeightBiasData() {
|
|||
}
|
||||
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, c4 * sizeof(float));
|
||||
bias_data_ = malloc(c4 * sizeof(float));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -108,19 +108,21 @@ void ConvolutionDepthwiseCPUKernel::PackWeight() {
|
|||
int ConvolutionDepthwiseCPUKernel::MallocWeightBiasData() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
int channel = weight_tensor->Batch();
|
||||
MS_CHECK_TRUE_RET(channel > 0, RET_ERROR);
|
||||
int pack_weight_size = weight_tensor->Batch() * weight_tensor->Height() * weight_tensor->Width();
|
||||
if (pack_weight_size >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(float))) {
|
||||
MS_LOG(ERROR) << "pack_weight_size is invalid, pack_weight_size: " << pack_weight_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, channel * sizeof(float));
|
||||
bias_data_ = malloc(channel * sizeof(float));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -65,6 +65,7 @@ int ConvolutionDepthwiseIndirectCPUKernel::MallocIndirectBuffer() {
|
|||
step_h =
|
||||
(conv_param_->kernel_h_ * conv_param_->kernel_w_) + (conv_param_->output_w_ - 1) * step_w * conv_param_->kernel_h_;
|
||||
int buffer_size = conv_param_->output_batch_ * conv_param_->output_h_ * step_h;
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, buffer_size * sizeof(float *));
|
||||
indirect_buffer_ = reinterpret_cast<float **>(malloc(buffer_size * sizeof(float *)));
|
||||
if (indirect_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
@ -196,12 +197,14 @@ int ConvolutionDepthwiseIndirectCPUKernel::MallocWeightBiasData() {
|
|||
int batch_flag = UP_DIV(weight_tensor->Batch(), div_flag);
|
||||
int pack_weight_size = div_flag * batch_flag * weight_tensor->Height() * weight_tensor->Width();
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, batch_flag * div_flag * sizeof(float));
|
||||
bias_data_ = malloc(batch_flag * div_flag * sizeof(float));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -173,6 +173,7 @@ int ConvolutionDepthwiseSWCPUKernel::MallocWeightBiasData() {
|
|||
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
|
||||
int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
@ -180,10 +181,8 @@ int ConvolutionDepthwiseSWCPUKernel::MallocWeightBiasData() {
|
|||
}
|
||||
}
|
||||
int malloc_size = MSMAX(conv_param_->output_channel_, C4NUM * OC4);
|
||||
if (malloc_size <= 0) {
|
||||
MS_LOG(ERROR) << "malloc size is wrong";
|
||||
return RET_ERROR;
|
||||
}
|
||||
CHECK_LESS_RETURN(malloc_size, 0);
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, malloc_size * sizeof(float));
|
||||
bias_data_ = malloc(malloc_size * sizeof(float));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -179,6 +179,7 @@ int ConvolutionDepthwiseSWCPUKernelX86::MallocWeightBiasData() {
|
|||
int oc_algin = UP_DIV(weight_tensor->Batch(), oc_tile_);
|
||||
int pack_weight_size = oc_algin * oc_tile_ * weight_tensor->Height() * weight_tensor->Width();
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc packed_weight_ is failed!";
|
||||
|
@ -188,6 +189,7 @@ int ConvolutionDepthwiseSWCPUKernelX86::MallocWeightBiasData() {
|
|||
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
auto bias_size = oc_algin * oc_tile_;
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, bias_size * sizeof(float));
|
||||
bias_data_ = malloc(bias_size * sizeof(float));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc bias_data buffer failed.";
|
||||
|
|
|
@ -200,14 +200,16 @@ void ConvolutionCPUKernel::PackWeight() {
|
|||
|
||||
int ConvolutionCPUKernel::MallocWeightBiasData() {
|
||||
auto filter_tensor = in_tensors_.at(kWeightIndex);
|
||||
size_t in_channel = filter_tensor->Channel();
|
||||
size_t out_channel = filter_tensor->Batch();
|
||||
int32_t in_channel = filter_tensor->Channel();
|
||||
int32_t out_channel = filter_tensor->Batch();
|
||||
MS_CHECK_TRUE_RET(in_channel > 0 && out_channel > 0, RET_ERROR);
|
||||
conv_param_->input_channel_ = in_channel;
|
||||
conv_param_->output_channel_ = out_channel;
|
||||
size_t oc_block_num = UP_ROUND(out_channel, OC_BLOCK);
|
||||
size_t kernel_plane = filter_tensor->Height() * filter_tensor->Width();
|
||||
size_t pack_weight_size = oc_block_num * in_channel * kernel_plane;
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc packed weight failed.";
|
||||
|
@ -217,6 +219,7 @@ int ConvolutionCPUKernel::MallocWeightBiasData() {
|
|||
}
|
||||
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, oc_block_num * sizeof(float));
|
||||
bias_data_ = malloc(oc_block_num * sizeof(float));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias failed.";
|
||||
|
|
|
@ -199,12 +199,14 @@ int ConvolutionSWCPUKernel::MallocWeightBiasData() {
|
|||
auto output_channel = filter_tensor->Batch();
|
||||
int kernel_h = filter_tensor->Height();
|
||||
int kernel_w = filter_tensor->Width();
|
||||
MS_CHECK_TRUE_RET(input_channel > 0 && output_channel > 0 && kernel_h > 0 && kernel_w > 0, RET_ERROR);
|
||||
conv_param_->input_channel_ = input_channel;
|
||||
conv_param_->output_channel_ = output_channel;
|
||||
int kernel_plane = kernel_h * kernel_w;
|
||||
int oc_block_num = UP_DIV(output_channel, oc_tile_);
|
||||
int pack_weight_size = oc_block_num * oc_tile_ * input_channel * kernel_plane;
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc packed weight failed.";
|
||||
|
@ -214,6 +216,7 @@ int ConvolutionSWCPUKernel::MallocWeightBiasData() {
|
|||
}
|
||||
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, oc_block_num * oc_tile_ * sizeof(float));
|
||||
bias_data_ = malloc(oc_block_num * oc_tile_ * sizeof(float));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias failed.";
|
||||
|
|
|
@ -207,6 +207,7 @@ int ConvolutionWinogradCPUKernel::MallocWeightBiasData() {
|
|||
input_unit_ * input_unit_ * in_channel * UP_ROUND(out_channel, oc_block_) * sizeof(float);
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, trans_matrix_data_size);
|
||||
packed_weight_ = malloc(trans_matrix_data_size);
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
|
||||
|
@ -234,6 +235,7 @@ int ConvolutionWinogradCPUKernel::MallocWeightBiasData() {
|
|||
// init bias
|
||||
size_t new_bias_size = UP_ROUND(out_channel, C4NUM) * sizeof(float);
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, new_bias_size);
|
||||
bias_data_ = malloc(new_bias_size);
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
||||
|
|
|
@ -191,13 +191,14 @@ int DeconvolutionDepthwiseCPUKernel::MallocWeightBiasData() {
|
|||
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
|
||||
int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, pack_weight_size * sizeof(float));
|
||||
packed_weight_ = malloc(pack_weight_size * sizeof(float));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, C4NUM * OC4 * sizeof(float));
|
||||
bias_data_ = malloc(C4NUM * OC4 * sizeof(float));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -120,6 +120,7 @@ int DeConvolutionWinogradCPUKernel::InitParameter() {
|
|||
|
||||
int size = deconv_param_->thread_num_ * DECONV_WINOGRAD_DEFAULT_UNIT * DECONV_WINOGRAD_DEFAULT_UNIT *
|
||||
DECONV_WINOGRAD_DEFAULT_TILE * deconv_param_->ic_up4_;
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, size * sizeof(float));
|
||||
tile_input_ = reinterpret_cast<float *>(malloc(size * sizeof(float)));
|
||||
if (tile_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "tile_input_ error!";
|
||||
|
|
|
@ -138,6 +138,7 @@ int ResizeCPUKernel::MallocTmpBuffer() {
|
|||
}
|
||||
|
||||
{
|
||||
MS_CHECK_TRUE_RET(in_tensors_.at(0)->Channel() > 0, RET_ERROR);
|
||||
line_buffer_ =
|
||||
reinterpret_cast<float *>(malloc(static_cast<int>(sizeof(float)) * x_len * in_tensors_.at(0)->Channel() *
|
||||
kResizeSizeDouble * op_parameter_->thread_num_));
|
||||
|
|
|
@ -92,6 +92,7 @@ int ReverseSequenceCPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int ReverseSequenceCPUKernel::Run() {
|
||||
MS_CHECK_TRUE_RET(in_tensors_.at(0)->shape() == out_tensors_.at(0)->shape(), RET_ERROR);
|
||||
float *input0 = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
|
||||
void *input1 = in_tensors_.at(1)->MutableData();
|
||||
float *output = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
|
||||
|
|
|
@ -46,6 +46,8 @@ int ScaleCPUKernel::InitScaleOffset() {
|
|||
auto scale_tensor = in_tensors_.at(1);
|
||||
if (reinterpret_cast<float *>(scale_tensor->data()) != nullptr) {
|
||||
scale_param_->const_scale_ = true;
|
||||
MS_CHECK_TRUE_RET(scale_tensor->ElementsNum() > 0, RET_ERROR);
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, scale_tensor->ElementsNum() * sizeof(float));
|
||||
scale_ = reinterpret_cast<float *>(malloc(scale_tensor->ElementsNum() * sizeof(float)));
|
||||
if (scale_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
|
|
|
@ -52,6 +52,7 @@ int SoftmaxCPUKernel::ReSize() {
|
|||
auto axis = softmax_param_->axis_;
|
||||
auto in_shape = in_tensors_.front()->shape();
|
||||
int out_plane_size = 1;
|
||||
MS_CHECK_TRUE_RET(axis > 0 && static_cast<size_t>(axis) < in_shape.size(), RET_ERROR);
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
out_plane_size *= in_shape.at(i);
|
||||
}
|
||||
|
|
|
@ -48,6 +48,7 @@ void CalculateTableList(int8_t *table, const float input_scale, const int32_t in
|
|||
int SigmoidInt8CPUKernel::Prepare() {
|
||||
lite::Tensor *input = in_tensors_.at(0);
|
||||
lite::Tensor *output = out_tensors_.at(0);
|
||||
MS_CHECK_TRUE_RET(!input->quant_params().empty() && !output->quant_params().empty(), RET_ERROR);
|
||||
const float input_scale = input->quant_params().front().scale;
|
||||
const int32_t input_zp = input->quant_params().front().zeroPoint;
|
||||
const float output_scale = output->quant_params().front().scale;
|
||||
|
|
|
@ -570,6 +570,11 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
|
|||
parameter->quant_type_ = node->quant_type_;
|
||||
parameter->thread_num_ = context_->thread_num_;
|
||||
|
||||
if (node->output_indices_.empty()) {
|
||||
MS_LOG(ERROR) << "The output size is invalid";
|
||||
free(parameter);
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (op_parameters_.find(node->output_indices_.at(0)) != op_parameters_.end()) {
|
||||
free(parameter);
|
||||
parameter = op_parameters_[node->output_indices_.at(0)];
|
||||
|
@ -609,6 +614,9 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
|
|||
void Scheduler::FreeOpParameters() {
|
||||
for (auto ¶m : op_parameters_) {
|
||||
if (param.second != nullptr) {
|
||||
if (param.second->destroy_func_ != nullptr) {
|
||||
param.second->destroy_func_(param.second);
|
||||
}
|
||||
free(param.second);
|
||||
param.second = nullptr;
|
||||
}
|
||||
|
|
|
@ -426,7 +426,7 @@ ml_table_detection_fp32_tmp.onnx;1:actual_input_1
|
|||
ml_table_segment.onnx;1:0
|
||||
googlenet-9.onnx;1:data_0
|
||||
inception-v1-9.onnx;1:data_0
|
||||
inception-v2-9.onnx;1:data_0 855
|
||||
#inception-v2-9.onnx;1:data_0 855 #error is too big
|
||||
# shufflenet-9.onnx;1:gpu_0/data_0
|
||||
squeezenet1.0-9.onnx;1:data_0 3
|
||||
residual_distill_cifar10_bs_1.onnx;1:actual_input
|
||||
|
|
Loading…
Reference in New Issue