forked from mindspore-Ecosystem/mindspore
fix op batch_to_space int8
This commit is contained in:
parent
df9b900c04
commit
f4122956ab
|
@ -37,6 +37,31 @@ BatchToSpaceInt8CPUKernel::~BatchToSpaceInt8CPUKernel() {
|
|||
}
|
||||
}
|
||||
|
||||
int BatchToSpaceInt8CPUKernel::Processinput() {
|
||||
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
|
||||
CHECK_NULL_RETURN(in_tensors_[DIMENSION_1D]);
|
||||
CHECK_NULL_RETURN(in_tensors_[DIMENSION_2D]);
|
||||
auto block_shape_data = in_tensors_[DIMENSION_1D]->data();
|
||||
auto crops_data = in_tensors_[DIMENSION_2D]->data();
|
||||
CHECK_NULL_RETURN(block_shape_data);
|
||||
CHECK_NULL_RETURN(crops_data);
|
||||
auto block_shape = static_cast<int *>(block_shape_data);
|
||||
auto crops = static_cast<int *>(crops_data);
|
||||
CHECK_LESS_RETURN(in_tensors_[DIMENSION_1D]->ElementsNum(), BATCH_TO_SPACE_BLOCK_SHAPE_SIZE);
|
||||
CHECK_LESS_RETURN(in_tensors_[DIMENSION_2D]->ElementsNum(), COMM_SHAPE_SIZE);
|
||||
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
|
||||
block_shape_[i] = block_shape[i];
|
||||
}
|
||||
no_crop_ = true;
|
||||
for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
|
||||
crops_[i] = crops[i];
|
||||
if (crops_[i] != 0) {
|
||||
no_crop_ = false;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int BatchToSpaceInt8CPUKernel::Prepare() {
|
||||
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_1D);
|
||||
CHECK_LESS_RETURN(out_tensors_.size(), DIMENSION_1D);
|
||||
|
@ -82,27 +107,46 @@ int BatchToSpaceInt8CPUKernel::Run() {
|
|||
int8_t *output_data = reinterpret_cast<int8_t *>(output->data());
|
||||
auto in_shape = input->shape();
|
||||
auto out_shape = output->shape();
|
||||
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
|
||||
|
||||
if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON &&
|
||||
in_quant_arg_->zp_ == out_quant_arg_->zp_) {
|
||||
if (param->no_crop_) {
|
||||
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
|
||||
sizeof(int8_t));
|
||||
} else {
|
||||
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
|
||||
sizeof(int8_t));
|
||||
if (in_tensors_.size() == 1) {
|
||||
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
|
||||
CHECK_NULL_RETURN(param);
|
||||
block_shape_[DIMENSION_0D] = param->block_shape_[DIMENSION_0D];
|
||||
block_shape_[DIMENSION_1D] = param->block_shape_[DIMENSION_1D];
|
||||
for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
|
||||
crops_[i] = param->crops_[i];
|
||||
if (crops_[i] != 0) {
|
||||
no_crop_ = false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (param->no_crop_) {
|
||||
BatchToSpaceNoCropForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
|
||||
in_quant_arg_, out_quant_arg_);
|
||||
} else {
|
||||
BatchToSpaceForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
|
||||
param->crops_, in_quant_arg_, out_quant_arg_);
|
||||
no_crop_ = param->no_crop_;
|
||||
} else if (in_tensors_.size() == 3) {
|
||||
auto ret = Processinput();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Processinput failed in BatchToSpace.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
if (in_tensors_.size() == 1 || in_tensors_.size() == 3) {
|
||||
if (std::abs(in_quant_arg_->scale_ - out_quant_arg_->scale_) < FLT_EPSILON &&
|
||||
in_quant_arg_->zp_ == out_quant_arg_->zp_) {
|
||||
if (no_crop_) {
|
||||
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, sizeof(int8_t));
|
||||
} else {
|
||||
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_,
|
||||
sizeof(int8_t));
|
||||
}
|
||||
} else {
|
||||
if (no_crop_) {
|
||||
BatchToSpaceNoCropForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], block_shape_,
|
||||
in_quant_arg_, out_quant_arg_);
|
||||
} else {
|
||||
BatchToSpaceForNHWCInt8(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_,
|
||||
in_quant_arg_, out_quant_arg_);
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -35,10 +35,14 @@ class BatchToSpaceInt8CPUKernel : public LiteKernel {
|
|||
int Prepare() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int Processinput();
|
||||
|
||||
private:
|
||||
QuantArg *in_quant_arg_ = nullptr;
|
||||
QuantArg *out_quant_arg_ = nullptr;
|
||||
int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE] = {0};
|
||||
int32_t crops_[COMM_SHAPE_SIZE] = {0};
|
||||
bool no_crop_ = false;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
Loading…
Reference in New Issue