diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/batch_to_space_int8.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/batch_to_space_int8.cc index 520e0f11a20..9778ec2c32c 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/batch_to_space_int8.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/batch_to_space_int8.cc @@ -37,6 +37,31 @@ BatchToSpaceInt8CPUKernel::~BatchToSpaceInt8CPUKernel() { } } +int BatchToSpaceInt8CPUKernel::Processinput() { + CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D); + CHECK_NULL_RETURN(in_tensors_[DIMENSION_1D]); + CHECK_NULL_RETURN(in_tensors_[DIMENSION_2D]); + auto block_shape_data = in_tensors_[DIMENSION_1D]->data(); + auto crops_data = in_tensors_[DIMENSION_2D]->data(); + CHECK_NULL_RETURN(block_shape_data); + CHECK_NULL_RETURN(crops_data); + auto block_shape = static_cast(block_shape_data); + auto crops = static_cast(crops_data); + CHECK_LESS_RETURN(in_tensors_[DIMENSION_1D]->ElementsNum(), BATCH_TO_SPACE_BLOCK_SHAPE_SIZE); + CHECK_LESS_RETURN(in_tensors_[DIMENSION_2D]->ElementsNum(), COMM_SHAPE_SIZE); + for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { + block_shape_[i] = block_shape[i]; + } + no_crop_ = true; + for (int i = 0; i < COMM_SHAPE_SIZE; ++i) { + crops_[i] = crops[i]; + if (crops_[i] != 0) { + no_crop_ = false; + } + } + return RET_OK; +} + int BatchToSpaceInt8CPUKernel::Prepare() { CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_1D); CHECK_LESS_RETURN(out_tensors_.size(), DIMENSION_1D); @@ -82,27 +107,46 @@ int BatchToSpaceInt8CPUKernel::Run() { int8_t *output_data = reinterpret_cast(output->data()); auto in_shape = input->shape(); auto out_shape = output->shape(); - BatchToSpaceParameter *param = reinterpret_cast(this->op_parameter_); - if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON && - in_quant_arg_->zp_ == out_quant_arg_->zp_) { - if (param->no_crop_) { - BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, - sizeof(int8_t)); - } else { - BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_, - sizeof(int8_t)); + if (in_tensors_.size() == 1) { + BatchToSpaceParameter *param = reinterpret_cast(this->op_parameter_); + CHECK_NULL_RETURN(param); + block_shape_[DIMENSION_0D] = param->block_shape_[DIMENSION_0D]; + block_shape_[DIMENSION_1D] = param->block_shape_[DIMENSION_1D]; + for (int i = 0; i < COMM_SHAPE_SIZE; ++i) { + crops_[i] = param->crops_[i]; + if (crops_[i] != 0) { + no_crop_ = false; + } } - } else { - if (param->no_crop_) { - BatchToSpaceNoCropForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, - in_quant_arg_, out_quant_arg_); - } else { - BatchToSpaceForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, - param->crops_, in_quant_arg_, out_quant_arg_); + no_crop_ = param->no_crop_; + } else if (in_tensors_.size() == 3) { + auto ret = Processinput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Processinput failed in BatchToSpace."; + return ret; } } + if (in_tensors_.size() == 1 || in_tensors_.size() == 3) { + if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON && + in_quant_arg_->zp_ == out_quant_arg_->zp_) { + if (no_crop_) { + BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, sizeof(int8_t)); + } else { + BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_, + sizeof(int8_t)); + } + } else { + if (no_crop_) { + BatchToSpaceNoCropForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, + in_quant_arg_, out_quant_arg_); + } else { + BatchToSpaceForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_, + in_quant_arg_, out_quant_arg_); + } + } + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/batch_to_space_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/batch_to_space_int8.h index e59b3e7ff75..132fcb067e6 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/batch_to_space_int8.h +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/batch_to_space_int8.h @@ -35,10 +35,14 @@ class BatchToSpaceInt8CPUKernel : public LiteKernel { int Prepare() override; int ReSize() override; int Run() override; + int Processinput(); private: QuantArg *in_quant_arg_ = nullptr; QuantArg *out_quant_arg_ = nullptr; + int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE] = {0}; + int32_t crops_[COMM_SHAPE_SIZE] = {0}; + bool no_crop_ = false; }; } // namespace mindspore::kernel