!7310 fix space_to_batch_int8 padding bug

Merge pull request !7310 from XianglongZeng/myms
Author: mindspore-ci-bot, 2020-10-15 09:40:15 +08:00; committed by Gitee
Commit: 60da54651b
3 changed files with 16 additions and 10 deletions
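
In asymmetric int8 quantization the real value 0.0 maps to the tensor's zero point, not to the byte 0, so filling padded regions with literal zeros silently injects the real value -scale * zp at every padded element. This patch threads the output tensor's zero point (zp) into the padding routine and memsets with it instead of 0. A minimal sketch of the arithmetic, using an illustrative dequantize helper that is not part of the patch:

#include <stdint.h>
#include <stdio.h>

/* Affine int8 quantization: real = scale * (q - zp).
 * Hypothetical helper for illustration; not part of nnacl. */
static float dequantize(int8_t q, float scale, int32_t zp) {
  return scale * (float)((int32_t)q - zp);
}

int main(void) {
  const float scale = 0.5f;
  const int32_t zp = 10;
  /* Padding with the literal byte 0 injects a spurious real value... */
  printf("pad with 0:  %.1f\n", dequantize(0, scale, zp));          /* -5.0 */
  /* ...while padding with the zero point yields the intended real 0.0. */
  printf("pad with zp: %.1f\n", dequantize((int8_t)zp, scale, zp)); /*  0.0 */
  return 0;
}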

File: nnacl/int8/space_to_batch_int8.c

@@ -16,8 +16,7 @@
 #include "nnacl/int8/space_to_batch_int8.h"
 #include "nnacl/arithmetic_common.h"
-void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, int *block_sizes, int *in_shape,
-                            int *out_shape) {
+void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, int *block_sizes, int *in_shape, int *out_shape) {
   int out_dim0 = out_shape[0];
   int out_dim1 = out_shape[1];
   int out_dim2 = out_shape[2];
@@ -46,7 +45,8 @@ void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, int *block_size
   }
 }
-void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_shape, int *padding, int *out_shape) {
+void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_shape, int *padding, int *out_shape,
+                                   int32_t zp) {
   int in_h = in_shape[1];
   int in_w = in_shape[2];
   int in_c = in_shape[3];
@@ -64,13 +64,13 @@ void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_
   for (int i = 0; i < in_shape[0]; ++i) {
     size_t in_offset0 = i * in_strides[0];
     for (int pad_h_top = 0; pad_h_top < padding[0]; ++pad_h_top) {
-      memset(output + out_offset, 0, ped_h_size);
-      out_offset += ped_h_num;
+      memset(output + out_offset, zp, ped_h_size);
+      out_offset += ped_h_num;
     }
     for (int j = 0; j < in_h; ++j) {
       size_t in_offset1 = in_offset0 + j * in_strides[1];
       for (int pad_w_left = 0; pad_w_left < padding[2]; ++pad_w_left) {
-        memset(output + out_offset, 0, ped_w_size);
+        memset(output + out_offset, zp, ped_w_size);
         out_offset += out_c;
       }
       for (int k = 0; k < in_w; ++k) {
@@ -79,12 +79,12 @@ void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_
         out_offset += in_c;
       }
       for (int pad_w_right = 0; pad_w_right < padding[3]; ++pad_w_right) {
-        memset(output + out_offset, 0, ped_w_size);
+        memset(output + out_offset, zp, ped_w_size);
         out_offset += out_c;
       }
     }
     for (int pad_h_bottom = 0; pad_h_bottom < padding[1]; ++pad_h_bottom) {
-      memset(output + out_offset, 0, ped_h_size);
+      memset(output + out_offset, zp, ped_h_size);
       out_offset += ped_h_num;
     }
   }
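
For context, a hedged sketch of how the patched routine might be called. The shapes and values below are made up, and the padding layout {top, bottom, left, right} is inferred from the loops above; the caller is assumed to allocate an output buffer covering the padded NHWC volume.

#include <stdint.h>
#include "nnacl/int8/space_to_batch_int8.h"

int main(void) {
  /* 1x2x2x1 NHWC input, padded by one pixel on every side to 1x4x4x1. */
  int8_t input[4] = {1, 2, 3, 4};
  int8_t output[16];
  int in_shape[4] = {1, 2, 2, 1};
  int out_shape[4] = {1, 4, 4, 1};
  int padding[4] = {1, 1, 1, 1}; /* {top, bottom, left, right}, per the loops above */
  int32_t zp = 10;               /* border bytes become 10 (the zero point), not 0 */
  DoSpaceToBatchPaddingNHWCInt8(input, output, in_shape, padding, out_shape, zp);
  return 0;
}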

File: nnacl/int8/space_to_batch_int8.h

@@ -22,7 +22,8 @@
 extern "C" {
 #endif
 void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, int *block_sizes, int *in_shape, int *out_shape);
-void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_shape, int *padding, int *out_shape);
+void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_shape, int *padding, int *out_shape,
+                                   int32_t zp);
 #ifdef __cplusplus
 }
 #endif

File: space_to_batch_int8.cc (SpaceToBatchInt8CPUKernel)

@@ -36,6 +36,11 @@ int SpaceToBatchInt8CPUKernel::Run() {
   auto input_ptr = reinterpret_cast<const int8_t *>(input_tensor->MutableData());
   auto output_ptr = reinterpret_cast<int8_t *>(output_tensor->MutableData());
   SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
+  if (output_tensor->GetQuantParams().empty()) {
+    MS_LOG(ERROR) << "SpaceToBatchInt8 need quantization parameters which is not found.";
+    return RET_ERROR;
+  }
+  auto quant_arg = output_tensor->GetQuantParams().front();
   if (param->need_paddings_) {
     padded_input_ = context_->allocator->Malloc(param->padded_input_element_num * sizeof(int8_t));
@@ -45,7 +50,7 @@ int SpaceToBatchInt8CPUKernel::Run() {
     }
     auto padded_input = reinterpret_cast<int8_t *>(padded_input_);
     DoSpaceToBatchPaddingNHWCInt8(input_ptr, padded_input, param->input_shape_, param->paddings_,
-                                  param->padded_in_shape_);
+                                  param->padded_in_shape_, quant_arg.zeroPoint);
     DoSpaceToBatchNHWCInt8(padded_input, output_ptr, param->block_sizes_, param->padded_in_shape_,
                            param->output_shape_);
     FreeTmpBuffer();
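
One detail worth noting about the memset calls in the padding routine: memset converts its int argument to unsigned char, so passing the int32_t zero point writes the intended byte pattern as long as zp lies in the int8 range, which holds for int8 quantization. A small self-contained check, using only standard C and nothing from the repo:

#include <assert.h>
#include <stdint.h>
#include <string.h>

int main(void) {
  int8_t buf[4];
  int32_t zp = -3;              /* negative zero points are representable in int8 */
  memset(buf, zp, sizeof(buf)); /* zp is converted to unsigned char (0xFD) */
  assert(buf[0] == (int8_t)zp); /* reading the byte back as int8_t recovers -3 */
  return 0;
}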