!22571 [MS][Lite][Stable] code security check fix

Merge pull request !22571 from ivss/Dev2
i-robot 2021-08-30 13:16:55 +00:00 committed by Gitee
commit 943a9a020b
44 changed files with 104 additions and 69 deletions
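
Every file in this PR applies the same few hardening patterns: debug-only NNACL_ASSERT/MS_ASSERT statements, which compile to nothing in release builds, are replaced by runtime guards that return an error code, and divisions and multiplications on externally controlled sizes gain explicit zero and overflow checks. A minimal sketch of the guard-macro style involved, with illustrative bodies rather than nnacl's real definitions:

/* Sketch of the guard-macro pattern used throughout this PR.
 * Names mirror the diff; the bodies here are illustrative, and the
 * error codes are assumed to come from nnacl/errorcode.h. */
#define NNACL_CHECK_ZERO_RETURN_ERR(val) \
  do {                                   \
    if ((val) == 0) {                    \
      return NNACL_ERR;                  \
    }                                    \
  } while (0)

#define NNACL_CHECK_NULL_RETURN_ERR(ptr) \
  do {                                   \
    if ((ptr) == NULL) {                 \
      return NNACL_NULL_PTR;             \
    }                                    \
  } while (0)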

View File

@@ -574,7 +574,6 @@ int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *
}
#endif
for (; index < element_size; index++) {
NNACL_ASSERT(input1[index] != 0);
output[index] = input0[index] / input1[index];
}
return NNACL_OK;
@@ -596,7 +595,6 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_
}
#endif
for (; index < element_size; index++) {
NNACL_ASSERT(input1[index] != 0);
output[index] = input0[0] / input1[index];
}
} else {
@@ -633,7 +631,6 @@ int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16
if (input1[index] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
NNACL_ASSERT(input1[index] != 0);
float16_t res = input0[index] / input1[index];
output[index] = res > 0 ? res : 0;
}
@@ -660,7 +657,6 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa
if (input1[index] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
NNACL_ASSERT(input1[index] != 0);
output[index] = MSMAX(input0[0] / input1[index], 0);
}
} else {
@@ -758,12 +754,10 @@ int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, flo
const ArithmeticParameter *param) {
if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
NNACL_ASSERT(input1[0] != 0);
output[i] = input0[i] - floorf(input0[i] / input1[0]) * input1[0];
}
} else {
for (int i = 0; i < element_size; ++i) {
NNACL_ASSERT(input1[i] != 0);
output[i] = input0[i] - floorf(input0[i] / input1[i]) * input1[i];
}
}
@@ -772,7 +766,6 @@ int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, flo
int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
for (int i = 0; i < element_size; ++i) {
NNACL_ASSERT(input1[i] != 0);
output[i] = floorf(input0[i] / input1[i]);
}
return NNACL_OK;
@@ -781,12 +774,10 @@ int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, flo
const ArithmeticParameter *param) {
if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
NNACL_ASSERT(input1[0] != 0);
output[i] = floorf(input0[i] / input1[0]);
}
} else {
for (int i = 0; i < element_size; ++i) {
NNACL_ASSERT(input1[i] != 0);
output[i] = floorf(input0[i] / input1[i]);
}
}
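
Note the asymmetry above: the plain fp16 division kernels drop their NNACL_ASSERT guards without a replacement, while the Relu variants keep an explicit NNACL_ERRCODE_DIVISOR_ZERO return. On IEEE-754 hardware a float division by zero does not trap; it yields ±inf (or NaN for 0/0), which is presumably why the assert-only guards could go. A quick illustration with float, assuming float16_t behaves the same way:

#include <math.h>
#include <stdio.h>

/* IEEE-754 division by zero produces inf/NaN rather than trapping. */
int main(void) {
  float one = 1.0f, zero = 0.0f;
  printf("1/0: %f (isinf=%d)\n", one / zero, isinf(one / zero));
  printf("0/0: %f (isnan=%d)\n", zero / zero, isnan(zero / zero));
  return 0;
}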

View File

@@ -66,6 +66,7 @@ void Fp16ScaleAxis(const float16_t *in_data, float16_t *out_data, const float16_
void DoScaleFp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset,
int task_id, const ScaleParameter *scale_param) {
NNACL_CHECK_ZERO_RETURN(scale_param->op_parameter_.thread_num_);
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
int outer_start = task_id * outer_step;
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
@@ -137,6 +138,7 @@ void Fp16ScaleAxisRelu(const float16_t *in_data, float16_t *out_data, const floa
void Fp16DoScaleRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset,
int task_id, const ScaleParameter *scale_param) {
NNACL_CHECK_ZERO_RETURN(scale_param->op_parameter_.thread_num_);
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
int outer_start = task_id * outer_step;
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
@@ -210,6 +212,7 @@ void Fp16ScaleAxisRelu6(const float16_t *in_data, float16_t *out_data, const flo
void DoScaleRelu6Fp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset,
int task_id, const ScaleParameter *scale_param) {
NNACL_CHECK_ZERO_RETURN(scale_param->op_parameter_.thread_num_);
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
int outer_start = task_id * outer_step;
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
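
All three scale kernels split outer_size_ across threads with UP_DIV, which divides by thread_num_; the new NNACL_CHECK_ZERO_RETURN lines reject a zero thread count before that division runs. A sketch of the partitioning arithmetic, assuming nnacl's conventional round-up UP_DIV:

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

/* Sketch of the per-thread partitioning used by DoScaleFp16 et al. */
void partition(int outer_size, int thread_num, int task_id,
               int *start, int *end) {
  if (thread_num == 0) { /* what NNACL_CHECK_ZERO_RETURN guards */
    return;
  }
  int step = UP_DIV(outer_size, thread_num); /* divides by thread_num */
  *start = task_id * step;
  *end = MSMIN(*start + step, outer_size);
}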

View File

@@ -38,6 +38,7 @@ void RaggedRangeInt(const int *starts, const int *limits, const int *deltas, int
int start = param->starts_is_scalar ? starts[0] : starts[i];
int limit = param->limits_is_scalar ? limits[0] : limits[i];
int delta = param->deltas_is_scalar ? deltas[0] : deltas[i];
NNACL_CHECK_ZERO_RETURN(delta);
int len = MSMAX((int)ceil((float)(limit - start) / delta), 0);
splits[i + 1] = splits[i] + len;
for (int j = 0; j < len; j++) {
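
RaggedRange derives each row's length as ceil((limit - start) / delta) clamped at zero, so delta is a divisor and the new NNACL_CHECK_ZERO_RETURN(delta) has to run first. A worked sketch of the same formula:

#include <math.h>
#include <stdio.h>

#define MSMAX(a, b) ((a) > (b) ? (a) : (b))

int main(void) {
  int start = 2, limit = 10, delta = 3;
  /* a zero delta must be rejected before this division */
  int len = MSMAX((int)ceil((float)(limit - start) / delta), 0);
  printf("len = %d\n", len); /* ceil(8 / 3) = 3 -> values 2, 5, 8 */
  return 0;
}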

View File

@@ -20,6 +20,7 @@
int MatmulInfer(const AffineParameter *param, int a_shape[MAX_SHAPE_SIZE], size_t a_shape_size,
int b_shape[MAX_SHAPE_SIZE], size_t b_shape_size) {
MatMulParameter *matmul_param = param->matmul_parameter_;
NNACL_CHECK_NULL_RETURN_ERR(matmul_param);
if (matmul_param->a_transpose_) {
if (a_shape_size < 2) {
return NNACL_ERR;
@@ -37,12 +38,10 @@ int MatmulInfer(const AffineParameter *param, int a_shape[MAX_SHAPE_SIZE], size_
int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 3, 4, 1);
if (check_ret != NNACL_OK) {
return check_ret;
}
#endif
// splice + matmul
TensorC *input0 = (TensorC *)inputs[0];
TensorC *input1 = (TensorC *)inputs[1];

View File

@@ -51,14 +51,13 @@ int BroadCastInferShape(const int input_shape0_size, const int input_shape1_size
int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
if (check_ret != NNACL_OK) {
return check_ret;
}
#endif
ArithmeticParameter *param = (ArithmeticParameter *)parameter;
NNACL_CHECK_NULL_RETURN_ERR(param);
param->broadcasting_ = false;
const TensorC *input0 = inputs[0];

View File

@@ -24,6 +24,7 @@ int GetShapeByType(const TensorC *shape_tensor, int shape_size, int *dst_shape)
if (shape_size == 0) {
return NNACL_INFER_INVALID;
}
NNACL_CHECK_NULL_RETURN_ERR(shape_tensor->data_);
switch (shape_tensor->data_type_) {
case kNumberTypeInt8: {
int8_t *data = (int8_t *)(shape_tensor->data_);

View File

@@ -275,6 +275,7 @@ int GetElementNum(const TensorC *tensor) {
}
int res = 1;
for (size_t i = 0; i < tensor->shape_size_; i++) {
MS_CHECK_INT_MUL_NOT_OVERFLOW(res, tensor->shape_[i], NNACL_ERRCODE_MUL_OVERFLOW);
res = res * tensor->shape_[i];
}
return res;
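
GetElementNum multiplies every dimension together, so a hostile shape can overflow int before the product is ever used as an allocation size; MS_CHECK_INT_MUL_NOT_OVERFLOW now rejects each step. The diff does not show INT_MUL_OVERFLOW's body; a portable division-based check of the kind it presumably performs:

#include <limits.h>
#include <stdbool.h>

/* Sketch: would a * b overflow a signed int? Tested with division so
 * the multiplication itself is never evaluated. Illustrative only. */
static bool int_mul_overflow(int a, int b) {
  if (a == 0 || b == 0) {
    return false;
  }
  return a > 0 ? (b > 0 ? a > INT_MAX / b : b < INT_MIN / a)
               : (b > 0 ? a < INT_MIN / b : b < INT_MAX / a);
}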

View File

@@ -19,12 +19,10 @@
int GluInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
if (check_ret != NNACL_OK) {
return check_ret;
}
#endif
const TensorC *input = inputs[0];
TensorC *output = outputs[0];
@@ -35,6 +33,7 @@ int GluInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **ou
}
SetShapeTensor(output, input);
GluParameter *param = (GluParameter *)parameter;
NNACL_CHECK_NULL_RETURN_ERR(param);
int axis = param->axis_ > 0 ? param->axis_ : (int)input->shape_size_ + param->axis_;
output->shape_[axis] /= 2;
return NNACL_OK;

View File

@@ -30,6 +30,7 @@ int CheckInputTensor(const TensorC *const *inputs) {
int GetRows(const TensorC *const *inputs, bool starts_is_scalar, bool limits_is_scalar, bool deltas_is_scalar,
int *rows) {
NNACL_CHECK_NULL_RETURN_ERR(rows);
int sizes[3];
int not_scalar_count = 0;
if (!starts_is_scalar) {
@@ -61,6 +62,7 @@ int GetOutputValueElementNum(const TensorC *const *inputs, RaggedRangeParameter
int start = param->starts_is_scalar ? starts[0] : starts[i];
int limit = param->limits_is_scalar ? limits[0] : limits[i];
int delta = param->deltas_is_scalar ? deltas[0] : deltas[i];
NNACL_CHECK_ZERO_RETURN_ERR(delta);
count += MSMAX((int)(ceil((float)(limit - start) / delta)), 0);
}
} break;
@@ -72,6 +74,7 @@ int GetOutputValueElementNum(const TensorC *const *inputs, RaggedRangeParameter
int start = param->starts_is_scalar ? starts[0] : starts[i];
int limit = param->limits_is_scalar ? limits[0] : limits[i];
int delta = param->deltas_is_scalar ? deltas[0] : deltas[i];
NNACL_CHECK_ZERO_RETURN_ERR(delta);
count += MSMAX((int)(ceil((float)(limit - start) / delta)), 0);
}
} break;
@@ -85,16 +88,10 @@ int GetOutputValueElementNum(const TensorC *const *inputs, RaggedRangeParameter
int RaggedRangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2);
if (check_ret != NNACL_OK) {
return check_ret;
}
#endif
if (inputs_size != 3 || outputs_size != 2) {
return NNACL_INPUT_TENSOR_ERROR;
}
outputs[0]->data_type_ = kNumberTypeInt32;
outputs[0]->format_ = inputs[0]->format_;
@@ -108,6 +105,7 @@ int RaggedRangeInferShape(const TensorC *const *inputs, size_t inputs_size, Tens
return ret;
}
RaggedRangeParameter *param = (RaggedRangeParameter *)parameter;
NNACL_CHECK_NULL_RETURN_ERR(param);
param->starts_is_scalar = inputs[0]->shape_size_ == 0;
param->limits_is_scalar = inputs[1]->shape_size_ == 0;
param->deltas_is_scalar = inputs[2]->shape_size_ == 0;

View File

@@ -75,6 +75,7 @@ int RangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
}
} else {
RangeParameter *param = (RangeParameter *)parameter;
NNACL_CHECK_NULL_RETURN_ERR(param);
if (param->delta_ == 0) {
return NNACL_PARAM_INVALID;
}

View File

@@ -16,6 +16,7 @@
#include "nnacl/infer/resize_infer.h"
#include <math.h>
#include <limits.h>
#include "nnacl/infer/infer_register.h"
int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) {
@@ -42,6 +43,9 @@ int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) {
if (data == NULL) {
return NNACL_INFER_INVALID;
}
MS_CHECK_INT_MUL_NOT_OVERFLOW(data[1], GetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW);
MS_CHECK_INT_MUL_NOT_OVERFLOW(data[2], GetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW);
param->new_height_ = round(data[1] * GetHeight(input));
param->new_width_ = round(data[2] * GetWidth(input));
}
@@ -68,6 +72,8 @@ int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) {
} else {
return NNACL_ERR;
}
MS_CHECK_INT_MUL_NOT_OVERFLOW(GetHeight(input) - 1, scale - 1, NNACL_ERRCODE_MUL_OVERFLOW);
MS_CHECK_INT_MUL_NOT_OVERFLOW(GetWidth(input) - 1, scale - 1, NNACL_ERRCODE_MUL_OVERFLOW);
param->new_height_ = GetHeight(input) + (GetHeight(input) - 1) * (scale - 1);
param->new_width_ = GetWidth(input) + (GetWidth(input) - 1) * (scale - 1);
break;
@@ -109,6 +115,7 @@ int ResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
return NNACL_ERR;
}
ResizeParameter *param = (ResizeParameter *)parameter;
NNACL_CHECK_NULL_RETURN_ERR(param);
int output_shape[MAX_SHAPE_SIZE] = {0};
size_t output_shape_size = 0;
ShapePush(output_shape, &output_shape_size, GetBatch(input));
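
Both size computations in HandleTwoInputs are now overflow-checked before the multiply. The scale form, new_size = size + (size - 1) * (scale - 1), is equivalent to (size - 1) * scale + 1; a small worked sketch:

#include <stdio.h>

int main(void) {
  int height = 5, scale = 2;
  /* both factors are overflow-checked before this multiply */
  int new_height = height + (height - 1) * (scale - 1);
  printf("new_height = %d\n", new_height); /* 5 + 4 * 1 = 9 */
  return 0;
}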

View File

@@ -31,6 +31,7 @@ int SpaceToDepthInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
}
SetDataTypeFormat(outputs[0], input);
SpaceToDepthParameter *param = (SpaceToDepthParameter *)parameter;
NNACL_CHECK_NULL_RETURN_ERR(param);
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}

View File

@@ -195,9 +195,9 @@
} while (0)
#define MS_CHECK_INT_MUL_NOT_OVERFLOW(value1, value2, errcode) \
MS_CHECK_TRUE(!(INT_MUL_OVERFLOW(value1, value2)), errcode)
MS_CHECK_TRUE_RET(!(INT_MUL_OVERFLOW(value1, value2)), errcode)
#define MS_CHECK_INT_ADD_NOT_OVERFLOW(value1, value2, errcode) \
MS_CHECK_TRUE(!(INT_ADD_OVERFLOW(value1, value2)), errcode)
MS_CHECK_TRUE_RET(!(INT_ADD_OVERFLOW(value1, value2)), errcode)
#define NNACL_CHECK_ZERO_RETURN_ERR(val) \
do { \

View File

@@ -31,8 +31,8 @@ static void ReleaseParam(AffineParameter *affine, MatMulParameter *matmul) {
}
OpParameter *PopulateAffineParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_Affine();
if (value == nullptr) {
MS_LOG(ERROR) << "cast affine_primitive to value failed";

View File

@@ -37,9 +37,8 @@ using mindspore::schema::PrimitiveType_SquaredDifference;
namespace mindspore {
namespace lite {
ArithmeticParameter *PopulateArithmeticCommonPara(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";

View File

@@ -35,8 +35,8 @@ using mindspore::schema::PrimitiveType_Square;
namespace mindspore {
namespace lite {
OpParameter *PopulateArithmeticSelf(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto *param = reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
if (param == nullptr) {

View File

@@ -20,8 +20,8 @@ using mindspore::schema::PrimitiveType_BiasAdd;
namespace mindspore {
namespace lite {
OpParameter *PopulateBiasAddParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
auto *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";

View File

@@ -20,8 +20,8 @@ using mindspore::schema::PrimitiveType_GLU;
namespace mindspore {
namespace lite {
OpParameter *PopulateGluParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_GLU();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";

View File

@@ -20,8 +20,8 @@ using mindspore::schema::PrimitiveType_LRN;
namespace mindspore {
namespace lite {
OpParameter *PopulateLocalResponseNormParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_LRN();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";

View File

@@ -20,9 +20,8 @@ using mindspore::schema::PrimitiveType_RaggedRange;
namespace mindspore {
namespace lite {
OpParameter *PopulateRaggedRangeParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto *param = reinterpret_cast<RaggedRangeParameter *>(malloc(sizeof(RaggedRangeParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc RaggedRangeParameter failed.";

View File

@@ -20,8 +20,8 @@ using mindspore::schema::PrimitiveType_Range;
namespace mindspore {
namespace lite {
OpParameter *PopulateRangeParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_Range();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";

View File

@@ -20,8 +20,8 @@ using mindspore::schema::PrimitiveType_Resize;
namespace mindspore {
namespace lite {
OpParameter *PopulateResizeParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_Resize();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";

View File

@@ -20,8 +20,8 @@ using mindspore::schema::PrimitiveType_ScaleFusion;
namespace mindspore {
namespace lite {
OpParameter *PopulateScaleParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_ScaleFusion();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";

View File

@@ -20,8 +20,8 @@ using mindspore::schema::PrimitiveType_SpaceToDepth;
namespace mindspore {
namespace lite {
OpParameter *PopulateSpaceToDepthParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_SpaceToDepth();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";

View File

@@ -23,8 +23,8 @@ namespace mindspore {
namespace lite {
namespace {
OpParameter *PopulateArithmeticSelfV0(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto *primitive = static_cast<const schema::v0::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto *arithmetic_self_param = reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
if (arithmetic_self_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed.";

View File

@@ -22,8 +22,8 @@ namespace mindspore {
namespace lite {
namespace {
OpParameter *PopulateLocalResponseNormParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto *primitive = static_cast<const schema::v0::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto local_response_normalization_prim = primitive->value_as_LocalResponseNormalization();
if (local_response_normalization_prim == nullptr) {
MS_LOG(ERROR) << "local_response_normalization_prim is nullptr";

View File

@@ -22,8 +22,8 @@ namespace mindspore {
namespace lite {
namespace {
OpParameter *PopulateRangeParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto *primitive = static_cast<const schema::v0::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto range_prim = primitive->value_as_Range();
if (range_prim == nullptr) {
MS_LOG(ERROR) << "range_prim is nullptr";

View File

@@ -22,8 +22,8 @@ namespace mindspore {
namespace lite {
namespace {
OpParameter *PopulateScaleParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto *primitive = static_cast<const schema::v0::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto scale_prim = primitive->value_as_Scale();
if (scale_prim == nullptr) {
MS_LOG(ERROR) << "scale_prim is nullptr";

View File

@@ -56,7 +56,9 @@ ArithmeticSelfFp16Func ArithmeticSelfFp16CPUKernel::GetArithmeticSelfFp16Fun(int
int ArithmeticSelfFp16CPUKernel::DoExecute(int task_id) {
int elements_num = in_tensors_.at(0)->ElementsNum();
MS_CHECK_TRUE_RET(op_parameter_->thread_num_ != 0, RET_ERROR);
int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
MS_CHECK_INT_MUL_NOT_OVERFLOW(task_id, stride, RET_ERROR);
int offset = task_id * stride;
int count = MSMIN(stride, elements_num - offset);
if (count <= 0) {
@@ -76,8 +78,8 @@ int ArithmeticSelfFp16CPUKernel::DoExecute(int task_id) {
int ArithmeticSelfFp16CPUKernel::Run() {
auto input_tensor = in_tensors_.at(0);
auto output_tensor = out_tensors_.at(0);
MS_ASSERT(input_tensor != nullptr);
MS_ASSERT(output_tensor != nullptr);
CHECK_NULL_RETURN(input_tensor);
CHECK_NULL_RETURN(output_tensor);
if (input_tensor->data_type() == kNumberTypeFloat32) {
input_fp16_ptr_ = ConvertInputFp32toFp16(input_tensor, static_cast<const lite::InnerContext *>(ms_context_));
if (input_fp16_ptr_ == nullptr) {
@@ -85,10 +87,10 @@ int ArithmeticSelfFp16CPUKernel::Run() {
}
} else {
input_fp16_ptr_ = reinterpret_cast<float16_t *>(input_tensor->data_c());
MS_ASSERT(input_fp16_ptr_ != nullptr);
CHECK_NULL_RETURN(input_fp16_ptr_);
}
output_fp16_ptr_ = reinterpret_cast<float16_t *>(output_tensor->data_c());
MS_ASSERT(output_fp16_ptr_ != nullptr);
CHECK_NULL_RETURN(output_fp16_ptr_);
auto ret = ParallelLaunch(ms_context_, ArithmeticSelfRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {

View File

@@ -60,10 +60,11 @@ int BiasAddCPUFp16Kernel::Run() {
}
auto in = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
auto out = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
MS_ASSERT(in != nullptr);
MS_ASSERT(out != nullptr);
CHECK_NULL_RETURN(in);
CHECK_NULL_RETURN(out);
size_t data_size = in_tensors_.at(0)->ElementsNum();
MS_ASSERT(ms_context_->allocator != nullptr);
CHECK_NULL_RETURN(ms_context_->allocator);
MS_CHECK_INT_MUL_NOT_OVERFLOW(data_size, sizeof(float16_t), RET_ERROR);
auto tile_in = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(data_size * sizeof(float16_t)));
auto tile_bias = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(data_size * sizeof(float16_t)));
if (tile_in == nullptr || tile_bias == nullptr) {
@@ -89,6 +90,7 @@ int BiasAddCPUFp16Kernel::GetBiasData() {
bias_data_type_ = bias_tensor_->data_type();
if (bias_data_type_ == kNumberTypeFloat || bias_data_type_ == kNumberTypeFloat32) {
if (bias_data_ == nullptr) {
MS_CHECK_INT_MUL_NOT_OVERFLOW(bias_tensor_->ElementsNum(), sizeof(float16_t), RET_ERROR);
bias_data_ = reinterpret_cast<float16_t *>(malloc(bias_tensor_->ElementsNum() * sizeof(float16_t)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "bias_data_ is nullptr";
@@ -117,7 +119,7 @@ int BiasAddCPUFp16Kernel::Init() {
CHECK_LESS_RETURN(in_tensors_.size(), 2);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
bias_tensor_ = in_tensors_.at(1);
MS_ASSERT(bias_tensor_ != nullptr);
CHECK_NULL_RETURN(bias_tensor_);
if (!InferShapeDone()) {
return RET_OK;
}

View File

@@ -26,6 +26,8 @@ using mindspore::schema::PrimitiveType_RaggedRange;
namespace mindspore::kernel {
int RaggedRangeFp16CPUKernel::Init() {
CHECK_LESS_RETURN(in_tensors_.size(), 3);
CHECK_LESS_RETURN(out_tensors_.size(), 2);
if (!InferShapeDone()) {
return RET_OK;
}

View File

@@ -102,12 +102,12 @@ int ScaleFp16Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
int ScaleFp16CPUKernel::Run() {
auto input_tensor = in_tensors_.at(0);
auto output_tensor = out_tensors_.at(0);
MS_ASSERT(input_tensor != nullptr);
MS_ASSERT(output_tensor != nullptr);
CHECK_NULL_RETURN(input_tensor);
CHECK_NULL_RETURN(output_tensor);
input_ = reinterpret_cast<float16_t *>(input_tensor->data_c());
output_ = reinterpret_cast<float16_t *>(output_tensor->data_c());
MS_ASSERT(input_ != nullptr);
MS_ASSERT(output_ != nullptr);
CHECK_NULL_RETURN(input_);
CHECK_NULL_RETURN(output_);
auto ret = InitScaleOffset();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale fp16 InitScaleOffset failed.";
@@ -143,6 +143,7 @@ int ScaleFp16CPUKernel::MallocAssignTmpBuffer() {
return RET_ERROR;
}
} else {
MS_CHECK_INT_MUL_NOT_OVERFLOW(in_tensors_.at(1)->ElementsNum(), sizeof(float16_t), RET_ERROR);
offset_ = reinterpret_cast<float16_t *>(
ms_context_->allocator->Malloc(in_tensors_.at(1)->ElementsNum() * sizeof(float16_t)));
if (offset_ == nullptr) {

View File

@@ -72,7 +72,9 @@ int ArithmeticSelfCPUKernel::ReSize() { return RET_OK; }
int ArithmeticSelfCPUKernel::DoExecute(int task_id) {
int elements_num = in_tensors_.at(0)->ElementsNum();
MS_CHECK_TRUE_RET(op_parameter_->thread_num_ != 0, RET_ERROR);
int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
MS_CHECK_INT_MUL_NOT_OVERFLOW(task_id, stride, RET_ERROR);
int offset = task_id * stride;
int count = MSMIN(stride, elements_num - offset);
if (count <= 0) {

View File

@@ -48,7 +48,7 @@ int BiasCPUKernel::Run() {
auto bias = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
size_t data_size = static_cast<size_t>(in_tensors_.at(0)->ElementsNum());
MS_ASSERT(ms_context_->allocator != nullptr);
CHECK_NULL_RETURN(ms_context_->allocator);
float *tile_in = reinterpret_cast<float *>(ms_context_->allocator->Malloc(data_size * sizeof(float)));
float *tile_bias = reinterpret_cast<float *>(ms_context_->allocator->Malloc(data_size * sizeof(float)));
if (tile_in == nullptr || tile_bias == nullptr) {

View File

@@ -98,6 +98,7 @@ int GluCPUKernel::ReSize() {
int GluCPUKernel::Split(int task_id) {
input_ptr_ = in_tensors_.front()->data_c();
MS_CHECK_INT_MUL_NOT_OVERFLOW(task_id, thread_n_stride_, RET_ERROR);
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
@@ -117,8 +118,9 @@ int GluCPUKernel::Sigmoid(int task_id) {
auto input_addr = reinterpret_cast<float *>(split_ptr_.at(1));
auto output_addr = reinterpret_cast<float *>(sigmoid_ptr_);
auto length = in_tensors_.at(0)->ElementsNum() / kGluBranchNum;
MS_CHECK_TRUE_RET(op_parameter_->thread_num_ != 0, RET_ERROR);
int stride = UP_DIV(length, op_parameter_->thread_num_);
MS_CHECK_INT_MUL_NOT_OVERFLOW(stride, task_id, RET_ERROR);
int count = MSMIN(stride, length - stride * task_id);
if (count <= 0) {
return RET_OK;
@@ -131,8 +133,9 @@ int GluCPUKernel::Mul(int task_id) {
auto input_addr1 = reinterpret_cast<float *>(sigmoid_ptr_);
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
auto length = in_tensors_.at(0)->ElementsNum() / kGluBranchNum;
MS_CHECK_TRUE_RET(op_parameter_->thread_num_ != 0, RET_ERROR);
int stride = UP_DIV(length, op_parameter_->thread_num_);
MS_CHECK_INT_MUL_NOT_OVERFLOW(stride, task_id, RET_ERROR);
int count = MSMIN(stride, length - stride * task_id);
if (count <= 0) {
return RET_OK;

View File

@@ -42,7 +42,7 @@ int LocalResponseNormCPUKernel::DoLocalResponseNorm(int task_id) {
auto output_ptr = reinterpret_cast<float *>(out_tensor->MutableData());
auto in_shape = input_tensor->shape();
MS_ASSERT(in_shape.size() == 4);
MS_CHECK_TRUE_RET(in_shape.size() == 4, RET_ERROR);
int batch = in_shape.at(0);
int height = in_shape.at(1);
@@ -50,7 +50,9 @@ int LocalResponseNormCPUKernel::DoLocalResponseNorm(int task_id) {
int channel = in_shape.at(3);
int outer_size = batch * width * height;
MS_CHECK_TRUE_RET(thread_count_ != 0, RET_ERROR);
int stride = UP_DIV(outer_size, thread_count_);
MS_CHECK_INT_MUL_NOT_OVERFLOW(stride, task_id, RET_ERROR);
int count = MSMIN(stride, outer_size - stride * task_id);
input_ptr += stride * task_id * channel;

View File

@@ -26,6 +26,8 @@ using mindspore::schema::PrimitiveType_RaggedRange;
namespace mindspore::kernel {
int RaggedRangeCPUKernel::Init() {
CHECK_LESS_RETURN(in_tensors_.size(), 3);
CHECK_LESS_RETURN(out_tensors_.size(), 2);
if (!InferShapeDone()) {
return RET_OK;
}

View File

@@ -68,7 +68,7 @@ int ScaleCPUKernel::InitScaleOffset() {
} else if (in_tensors_.size() == 3 && reinterpret_cast<float *>(in_tensors_.at(2)->data_c()) != nullptr) {
scale_param_->const_offset_ = true;
auto offset_tensor = in_tensors_.at(2);
MS_ASSERT(scale_tensor->ElementsNum() == offset_tensor->ElementsNum());
MS_CHECK_TRUE_RET(scale_tensor->ElementsNum() == offset_tensor->ElementsNum(), RET_ERROR);
offset_ = reinterpret_cast<float *>(malloc(offset_tensor->ElementsNum() * sizeof(float)));
if (offset_ == nullptr) {
MS_LOG(ERROR) << "Malloc data failed";
@@ -180,13 +180,12 @@ int ScaleCPUKernel::Run() {
if (!scale_param_->const_scale_) {
auto scale_tensor = in_tensors_.at(1);
scale_ = reinterpret_cast<float *>(scale_tensor->data_c());
MS_ASSERT(scale_ != nullptr);
CHECK_NULL_RETURN(scale_);
}
if (!scale_param_->const_offset_) {
MS_ASSERT(in_tensors_.size() == 3);
auto offset_tensor = in_tensors_.at(2);
offset_ = reinterpret_cast<float *>(offset_tensor->data_c());
MS_ASSERT(offset_ != nullptr);
CHECK_NULL_RETURN(offset_);
}
auto out_tensor = out_tensors_.front();
output_ptr_ = reinterpret_cast<float *>(out_tensor->MutableData());

View File

@@ -35,6 +35,7 @@ int SpaceToDepthCPUKernel::Init() {
CHECK_LESS_RETURN(in_tensors_.size(), 1);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
SpaceToDepthParameter *param = reinterpret_cast<SpaceToDepthParameter *>(op_parameter_);
CHECK_NULL_RETURN(param);
if (param->block_size_ <= 0) {
MS_LOG(ERROR) << "Input block_size should > 0!";
return RET_PARAM_INVALID;
@@ -62,6 +63,7 @@ int SpaceToDepthCPUKernel::ReSize() {
}
int SpaceToDepthCPUKernel::SpaceToDepth(int task_id) {
MS_CHECK_INT_MUL_NOT_OVERFLOW(task_id, thread_h_stride_, RET_ERROR);
int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
@@ -70,9 +72,9 @@ int SpaceToDepthCPUKernel::SpaceToDepth(int task_id) {
auto in_shape = in_tensors_.at(0)->shape();
auto out_shape = out_tensors_.at(0)->shape();
SpaceToDepthParameter *param = reinterpret_cast<SpaceToDepthParameter *>(op_parameter_);
MS_ASSERT(param);
MS_ASSERT(input_ptr_);
MS_ASSERT(output_ptr_);
CHECK_NULL_RETURN(param);
CHECK_NULL_RETURN(input_ptr_);
CHECK_NULL_RETURN(output_ptr_);
auto ret = SpaceToDepthForNHWC(input_ptr_, output_ptr_, in_shape.data(), out_shape.data(), in_shape.size(),
param->block_size_, thread_offset, thread_offset + num_unit_thread, sizeof(float));
if (ret != RET_OK) {
@@ -84,6 +86,7 @@ int SpaceToDepthCPUKernel::SpaceToDepth(int task_id) {
int SpaceToDepthRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
auto g_kernel = reinterpret_cast<SpaceToDepthCPUKernel *>(cdata);
CHECK_NULL_RETURN(g_kernel);
auto ret = g_kernel->SpaceToDepth(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SpaceToDepthRun error task_id[" << task_id << "] error_code[" << ret << "]";

View File

@@ -27,7 +27,10 @@ using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int ArithmeticSelfInt8CPUKernel::Init() {
CHECK_LESS_RETURN(in_tensors_.size(), kInputIndex + 1);
CHECK_LESS_RETURN(out_tensors_.size(), kOutputIndex + 1);
auto *input_tensor = in_tensors_.at(kInputIndex);
CHECK_NULL_RETURN(input_tensor);
auto in_quant_args = input_tensor->quant_params();
para_->quant_arg_.in_args_.scale_ = in_quant_args.front().scale;
para_->quant_arg_.in_args_.zp_ = in_quant_args.front().zeroPoint * (-1);

View File

@@ -151,7 +151,7 @@ int ScaleInt8CPUKernel::InitParameter() {
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
}
} else {
MS_ASSERT(input0_size > input1_size);
MS_CHECK_TRUE_RET(input0_size > input1_size, RET_ERROR);
size_t fill_dim_num = input0_size - input1_size;
int j = 0;
for (size_t i = 0; i < output_size; i++) {
@@ -253,20 +253,21 @@ int ScaleInt8CPUKernel::ReSize() {
}
int ScaleInt8CPUKernel::Scale(int task_id) const {
MS_CHECK_INT_MUL_NOT_OVERFLOW(task_id, count_unit_, RET_ERROR);
int real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
if (real_dst_count <= 0) {
return lite::RET_OK;
}
int8_t *cur_input0_data = input0_data_ + task_id * count_unit_;
MS_ASSERT(cur_input0_data);
CHECK_NULL_RETURN(cur_input0_data);
int8_t *cur_input1_data = input1_data_ + task_id * count_unit_;
MS_ASSERT(cur_input1_data);
CHECK_NULL_RETURN(cur_input1_data);
int8_t *cur_output_data = output_data_ + task_id * count_unit_;
MS_ASSERT(cur_output_data);
CHECK_NULL_RETURN(cur_output_data);
if (has_bias_) {
int8_t *cur_input2_data = input2_data_ + task_id * count_unit_;
MS_ASSERT(cur_input2_data);
CHECK_NULL_RETURN(cur_input2_data);
DoScaleWithBiasInt8(cur_input0_data, cur_output_data, cur_input1_data, cur_input2_data, scale_param_,
real_dst_count);
} else {

View File

@@ -47,8 +47,16 @@ void ArithmeticSelfGetWorkGroup(const std::vector<size_t> &global, std::vector<s
const int max_divider = 8;
const int max_x = 4, max_y = 8;
int x = std::min(GetMaxDivisorStrategy1(global[0], max_divider), max_x);
if (x == 0) {
MS_LOG(ERROR) << "div num shouldn't be 0";
return;
}
int yz = max_size / x;
int y = std::min(std::min(GetMaxDivisorStrategy1(global[1], max_divider), yz), max_y);
if (y == 0) {
MS_LOG(ERROR) << "div num shouldn't be 0";
return;
}
int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));
local->clear();

View File

@@ -57,6 +57,7 @@ int ArithmeticInt8OpenCLKernel::CheckSpecs() {
return RET_ERROR;
}
auto *param = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
CHECK_NULL_RETURN(param);
if (!IsArithmetic(type())) {
MS_LOG(ERROR) << "UnSupported Operator: " << schema::EnumNamePrimitiveType(type());
return RET_ERROR;
@@ -183,6 +184,7 @@ int ArithmeticInt8OpenCLKernel::Prepare() {
out_shape_ = GpuTensorInfo(out_tensors_[0]);
auto *param = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
CHECK_NULL_RETURN(param);
if (type() == PrimitiveType_BiasAdd) {
const_cast<ArithmeticParameter *>(param)->broadcasting_ = true;
}

View File

@@ -43,6 +43,7 @@ int ResizeOpenCLKernel::CheckSpecs() {
return RET_PARAM_INVALID;
}
auto resize_param = reinterpret_cast<ResizeParameter *>(op_parameter_);
CHECK_NULL_RETURN(resize_param);
if (resize_param->method_ != schema::ResizeMethod_LINEAR && resize_param->method_ != schema::ResizeMethod_NEAREST) {
MS_LOG(ERROR) << "unsupported resize method:" << resize_param->method_;
return RET_PARAM_INVALID;
@@ -52,6 +53,7 @@ int ResizeOpenCLKernel::CheckSpecs() {
int ResizeOpenCLKernel::Prepare() {
auto resize_param = reinterpret_cast<ResizeParameter *>(op_parameter_);
CHECK_NULL_RETURN(resize_param);
alignCorner = resize_param->coordinate_transform_mode_ == 1;
preserveAspectRatio = resize_param->preserve_aspect_ratio_;
auto in_shape = in_tensors_[0]->shape();
@@ -93,6 +95,8 @@ float ResizeOpenCLKernel::getResizeScaleFactor(int input_size, int output_size)
int ResizeOpenCLKernel::SetConstArgs() {
auto in_shape = in_tensors_[0]->shape();
auto out_shape = out_tensors_[0]->shape();
MS_CHECK_GE(in_shape.size(), DIMENSION_4D, RET_ERROR);
MS_CHECK_GE(out_shape.size(), DIMENSION_4D, RET_ERROR);
int n = out_shape[0];
int h = out_shape[1];
int w = out_shape[2];