forked from mindspore-Ecosystem/mindspore
!7310 fix space_to_batch_int8 padding bug
Merge pull request !7310 from XianglongZeng/myms
This commit is contained in:
commit
60da54651b
|
@ -16,8 +16,7 @@
|
|||
#include "nnacl/int8/space_to_batch_int8.h"
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
|
||||
void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, int *block_sizes, int *in_shape,
|
||||
int *out_shape) {
|
||||
void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, int *block_sizes, int *in_shape, int *out_shape) {
|
||||
int out_dim0 = out_shape[0];
|
||||
int out_dim1 = out_shape[1];
|
||||
int out_dim2 = out_shape[2];
|
||||
|
@ -46,7 +45,8 @@ void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, int *block_size
|
|||
}
|
||||
}
|
||||
|
||||
void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_shape, int *padding, int *out_shape) {
|
||||
void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_shape, int *padding, int *out_shape,
|
||||
int32_t zp) {
|
||||
int in_h = in_shape[1];
|
||||
int in_w = in_shape[2];
|
||||
int in_c = in_shape[3];
|
||||
|
@ -64,13 +64,13 @@ void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_
|
|||
for (int i = 0; i < in_shape[0]; ++i) {
|
||||
size_t in_offset0 = i * in_strides[0];
|
||||
for (int pad_h_top = 0; pad_h_top < padding[0]; ++pad_h_top) {
|
||||
memset(output + out_offset, 0, ped_h_size);
|
||||
out_offset += ped_h_num;
|
||||
memset(output + out_offset, zp, ped_h_size);
|
||||
out_offset += ped_h_num;
|
||||
}
|
||||
for (int j = 0; j < in_h; ++j) {
|
||||
size_t in_offset1 = in_offset0 + j * in_strides[1];
|
||||
for (int pad_w_left = 0; pad_w_left < padding[2]; ++pad_w_left) {
|
||||
memset(output + out_offset, 0, ped_w_size);
|
||||
memset(output + out_offset, zp, ped_w_size);
|
||||
out_offset += out_c;
|
||||
}
|
||||
for (int k = 0; k < in_w; ++k) {
|
||||
|
@ -79,12 +79,12 @@ void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_
|
|||
out_offset += in_c;
|
||||
}
|
||||
for (int pad_w_right = 0; pad_w_right < padding[3]; ++pad_w_right) {
|
||||
memset(output + out_offset, 0, ped_w_size);
|
||||
memset(output + out_offset, zp, ped_w_size);
|
||||
out_offset += out_c;
|
||||
}
|
||||
}
|
||||
for (int pad_h_bottom = 0; pad_h_bottom < padding[1]; ++pad_h_bottom) {
|
||||
memset(output + out_offset, 0, ped_h_size);
|
||||
memset(output + out_offset, zp, ped_h_size);
|
||||
out_offset += ped_h_num;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, int *block_sizes, int *in_shape, int *out_shape);
|
||||
void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_shape, int *padding, int *out_shape);
|
||||
void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, int *in_shape, int *padding, int *out_shape,
|
||||
int32_t zp);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -36,6 +36,11 @@ int SpaceToBatchInt8CPUKernel::Run() {
|
|||
auto input_ptr = reinterpret_cast<const int8_t *>(input_tensor->MutableData());
|
||||
auto output_ptr = reinterpret_cast<int8_t *>(output_tensor->MutableData());
|
||||
SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
|
||||
if (output_tensor->GetQuantParams().empty()) {
|
||||
MS_LOG(ERROR) << "SpaceToBatchInt8 need quantization parameters which is not found.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto quant_arg = output_tensor->GetQuantParams().front();
|
||||
|
||||
if (param->need_paddings_) {
|
||||
padded_input_ = context_->allocator->Malloc(param->padded_input_element_num * sizeof(int8_t));
|
||||
|
@ -45,7 +50,7 @@ int SpaceToBatchInt8CPUKernel::Run() {
|
|||
}
|
||||
auto padded_input = reinterpret_cast<int8_t *>(padded_input_);
|
||||
DoSpaceToBatchPaddingNHWCInt8(input_ptr, padded_input, param->input_shape_, param->paddings_,
|
||||
param->padded_in_shape_);
|
||||
param->padded_in_shape_, quant_arg.zeroPoint);
|
||||
DoSpaceToBatchNHWCInt8(padded_input, output_ptr, param->block_sizes_, param->padded_in_shape_,
|
||||
param->output_shape_);
|
||||
FreeTmpBuffer();
|
||||
|
|
Loading…
Reference in New Issue