fuzz problem fix

This commit is contained in:
greatpan 2022-10-08 17:01:44 +08:00
parent 3c7851e714
commit 8464f516e2
3 changed files with 9 additions and 2 deletions

View File

@ -37,8 +37,11 @@ int RangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
}
int shape_size = 0;
if (inputs_size == C3NUM) {
if ((inputs[FIRST_INPUT]->data_ == NULL) || (inputs[SECOND_INPUT]->data_ == NULL) ||
(inputs[THIRD_INPUT]->data_ == NULL)) {
MS_CHECK_FALSE(inputs[FIRST_INPUT]->data_ == NULL, NNACL_INFER_INVALID);
MS_CHECK_FALSE(inputs[SECOND_INPUT]->data_ == NULL, NNACL_INFER_INVALID);
MS_CHECK_FALSE(inputs[THIRD_INPUT]->data_ == NULL, NNACL_INFER_INVALID);
if ((inputs[FIRST_INPUT]->data_type_ != inputs[SECOND_INPUT]->data_type_) ||
(inputs[FIRST_INPUT]->data_type_ != inputs[THIRD_INPUT]->data_type_)) {
return NNACL_INFER_INVALID;
}
if (GetElementNum(inputs[SECOND_INPUT]) < 1 || GetElementNum(inputs[THIRD_INPUT]) < 1) {

View File

@ -42,6 +42,7 @@ int InstanceNormCPUKernel::ReSize() {
MS_CHECK_INT_MUL_NOT_OVERFLOW(in_tensor->Height(), in_tensor->Width(), RET_ERROR);
param_->inner_size_ = in_tensor->Height() * in_tensor->Width();
param_->channel_ = in_tensor->Channel();
CHECK_LESS_RETURN(static_cast<int64_t>(in_tensors_.at(THIRD_INPUT)->Size()), param_->channel_);
param_->op_parameter_.thread_num_ = MSMIN(UP_DIV(param_->channel_, C8NUM), op_parameter_->thread_num_);
return RET_OK;
}

View File

@ -160,6 +160,7 @@ int WhereCPUKernel::RunWithTripleInputs() {
int condition_nums = condition->ElementsNum();
int x_num = x->ElementsNum();
int y_num = y->ElementsNum();
int out_num = out_tensors_.front()->ElementsNum();
condition_ = reinterpret_cast<bool *>(condition->data());
CHECK_NULL_RETURN(condition_);
@ -174,6 +175,8 @@ int WhereCPUKernel::RunWithTripleInputs() {
where_param_->y_num_ = y_num;
where_param_->max_num_ = num_max;
CHECK_LESS_RETURN(out_num, num_max);
if (((condition_nums != 1) && (condition_nums != num_max)) || ((x_num != 1) && (x_num != num_max)) ||
((y_num != 1) && (y_num != num_max))) {
MS_LOG(ERROR) << "The length of three inputs are not equal to 1 or length of output, which is unacceptable";