forked from mindspore-Ecosystem/mindspore
!35274 fix operator batch_to_space int8
Merge pull request !35274 from liyan2022/dev_r1.8
This commit is contained in:
commit
1267c54604
|
@ -37,6 +37,31 @@ BatchToSpaceInt8CPUKernel::~BatchToSpaceInt8CPUKernel() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int BatchToSpaceInt8CPUKernel::Processinput() {
|
||||||
|
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
|
||||||
|
CHECK_NULL_RETURN(in_tensors_[DIMENSION_1D]);
|
||||||
|
CHECK_NULL_RETURN(in_tensors_[DIMENSION_2D]);
|
||||||
|
auto block_shape_data = in_tensors_[DIMENSION_1D]->data();
|
||||||
|
auto crops_data = in_tensors_[DIMENSION_2D]->data();
|
||||||
|
CHECK_NULL_RETURN(block_shape_data);
|
||||||
|
CHECK_NULL_RETURN(crops_data);
|
||||||
|
auto block_shape = static_cast<int *>(block_shape_data);
|
||||||
|
auto crops = static_cast<int *>(crops_data);
|
||||||
|
CHECK_LESS_RETURN(in_tensors_[DIMENSION_1D]->ElementsNum(), BATCH_TO_SPACE_BLOCK_SHAPE_SIZE);
|
||||||
|
CHECK_LESS_RETURN(in_tensors_[DIMENSION_2D]->ElementsNum(), COMM_SHAPE_SIZE);
|
||||||
|
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
|
||||||
|
block_shape_[i] = block_shape[i];
|
||||||
|
}
|
||||||
|
no_crop_ = true;
|
||||||
|
for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
|
||||||
|
crops_[i] = crops[i];
|
||||||
|
if (crops_[i] != 0) {
|
||||||
|
no_crop_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int BatchToSpaceInt8CPUKernel::Prepare() {
|
int BatchToSpaceInt8CPUKernel::Prepare() {
|
||||||
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_1D);
|
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_1D);
|
||||||
CHECK_LESS_RETURN(out_tensors_.size(), DIMENSION_1D);
|
CHECK_LESS_RETURN(out_tensors_.size(), DIMENSION_1D);
|
||||||
|
@ -82,27 +107,46 @@ int BatchToSpaceInt8CPUKernel::Run() {
|
||||||
int8_t *output_data = reinterpret_cast<int8_t *>(output->data());
|
int8_t *output_data = reinterpret_cast<int8_t *>(output->data());
|
||||||
auto in_shape = input->shape();
|
auto in_shape = input->shape();
|
||||||
auto out_shape = output->shape();
|
auto out_shape = output->shape();
|
||||||
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
|
|
||||||
|
|
||||||
|
if (in_tensors_.size() == 1) {
|
||||||
|
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
|
||||||
|
CHECK_NULL_RETURN(param);
|
||||||
|
block_shape_[DIMENSION_0D] = param->block_shape_[DIMENSION_0D];
|
||||||
|
block_shape_[DIMENSION_1D] = param->block_shape_[DIMENSION_1D];
|
||||||
|
for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
|
||||||
|
crops_[i] = param->crops_[i];
|
||||||
|
if (crops_[i] != 0) {
|
||||||
|
no_crop_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
no_crop_ = param->no_crop_;
|
||||||
|
} else if (in_tensors_.size() == 3) {
|
||||||
|
auto ret = Processinput();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Processinput failed in BatchToSpace.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (in_tensors_.size() == 1 || in_tensors_.size() == 3) {
|
||||||
if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON &&
|
if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON &&
|
||||||
in_quant_arg_->zp_ == out_quant_arg_->zp_) {
|
in_quant_arg_->zp_ == out_quant_arg_->zp_) {
|
||||||
if (param->no_crop_) {
|
if (no_crop_) {
|
||||||
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
|
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, sizeof(int8_t));
|
||||||
sizeof(int8_t));
|
|
||||||
} else {
|
} else {
|
||||||
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
|
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_,
|
||||||
sizeof(int8_t));
|
sizeof(int8_t));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (param->no_crop_) {
|
if (no_crop_) {
|
||||||
BatchToSpaceNoCropForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
|
BatchToSpaceNoCropForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], block_shape_,
|
||||||
in_quant_arg_, out_quant_arg_);
|
in_quant_arg_, out_quant_arg_);
|
||||||
} else {
|
} else {
|
||||||
BatchToSpaceForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
|
BatchToSpaceForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_,
|
||||||
param->crops_, in_quant_arg_, out_quant_arg_);
|
in_quant_arg_, out_quant_arg_);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,10 +35,14 @@ class BatchToSpaceInt8CPUKernel : public LiteKernel {
|
||||||
int Prepare() override;
|
int Prepare() override;
|
||||||
int ReSize() override;
|
int ReSize() override;
|
||||||
int Run() override;
|
int Run() override;
|
||||||
|
int Processinput();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
QuantArg *in_quant_arg_ = nullptr;
|
QuantArg *in_quant_arg_ = nullptr;
|
||||||
QuantArg *out_quant_arg_ = nullptr;
|
QuantArg *out_quant_arg_ = nullptr;
|
||||||
|
int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE] = {0};
|
||||||
|
int32_t crops_[COMM_SHAPE_SIZE] = {0};
|
||||||
|
bool no_crop_ = false;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue