!35274 fix operator batch_to_space int8

Merge pull request !35274 from liyan2022/dev_r1.8
This commit is contained in:
i-robot 2022-06-02 09:46:13 +00:00 committed by Gitee
commit 1267c54604
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 64 additions and 16 deletions

View File

@ -37,6 +37,31 @@ BatchToSpaceInt8CPUKernel::~BatchToSpaceInt8CPUKernel() {
}
}
int BatchToSpaceInt8CPUKernel::Processinput() {
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
CHECK_NULL_RETURN(in_tensors_[DIMENSION_1D]);
CHECK_NULL_RETURN(in_tensors_[DIMENSION_2D]);
auto block_shape_data = in_tensors_[DIMENSION_1D]->data();
auto crops_data = in_tensors_[DIMENSION_2D]->data();
CHECK_NULL_RETURN(block_shape_data);
CHECK_NULL_RETURN(crops_data);
auto block_shape = static_cast<int *>(block_shape_data);
auto crops = static_cast<int *>(crops_data);
CHECK_LESS_RETURN(in_tensors_[DIMENSION_1D]->ElementsNum(), BATCH_TO_SPACE_BLOCK_SHAPE_SIZE);
CHECK_LESS_RETURN(in_tensors_[DIMENSION_2D]->ElementsNum(), COMM_SHAPE_SIZE);
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
block_shape_[i] = block_shape[i];
}
no_crop_ = true;
for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
crops_[i] = crops[i];
if (crops_[i] != 0) {
no_crop_ = false;
}
}
return RET_OK;
}
int BatchToSpaceInt8CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_1D);
CHECK_LESS_RETURN(out_tensors_.size(), DIMENSION_1D);
@ -82,27 +107,46 @@ int BatchToSpaceInt8CPUKernel::Run() {
int8_t *output_data = reinterpret_cast<int8_t *>(output->data());
auto in_shape = input->shape();
auto out_shape = output->shape();
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON &&
in_quant_arg_->zp_ == out_quant_arg_->zp_) {
if (param->no_crop_) {
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
sizeof(int8_t));
} else {
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
sizeof(int8_t));
if (in_tensors_.size() == 1) {
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
CHECK_NULL_RETURN(param);
block_shape_[DIMENSION_0D] = param->block_shape_[DIMENSION_0D];
block_shape_[DIMENSION_1D] = param->block_shape_[DIMENSION_1D];
for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
crops_[i] = param->crops_[i];
if (crops_[i] != 0) {
no_crop_ = false;
}
}
} else {
if (param->no_crop_) {
BatchToSpaceNoCropForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
in_quant_arg_, out_quant_arg_);
} else {
BatchToSpaceForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
param->crops_, in_quant_arg_, out_quant_arg_);
no_crop_ = param->no_crop_;
} else if (in_tensors_.size() == 3) {
auto ret = Processinput();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Processinput failed in BatchToSpace.";
return ret;
}
}
if (in_tensors_.size() == 1 || in_tensors_.size() == 3) {
if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON &&
in_quant_arg_->zp_ == out_quant_arg_->zp_) {
if (no_crop_) {
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, sizeof(int8_t));
} else {
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_,
sizeof(int8_t));
}
} else {
if (no_crop_) {
BatchToSpaceNoCropForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], block_shape_,
in_quant_arg_, out_quant_arg_);
} else {
BatchToSpaceForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_,
in_quant_arg_, out_quant_arg_);
}
}
}
return RET_OK;
}

View File

@ -35,10 +35,14 @@ class BatchToSpaceInt8CPUKernel : public LiteKernel {
int Prepare() override;
int ReSize() override;
int Run() override;
int Processinput();
private:
QuantArg *in_quant_arg_ = nullptr;
QuantArg *out_quant_arg_ = nullptr;
int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE] = {0};
int32_t crops_[COMM_SHAPE_SIZE] = {0};
bool no_crop_ = false;
};
} // namespace mindspore::kernel