fix scale-int8 bug

This commit is contained in:
xuanyue 2022-03-12 17:19:10 +08:00
parent cd0a025cea
commit b083ca05b7
2 changed files with 24 additions and 52 deletions

View File

@ -99,72 +99,45 @@ int ScaleInt8CPUKernel::InitScaleOffset() {
}
int ScaleInt8CPUKernel::InitParameter() {
auto in_tensor = in_tensors_.at(0);
auto in_shape = in_tensor->shape();
auto scale_tensor = in_tensors_.at(1);
auto scale_shape = scale_tensor->shape();
auto input0_shape = in_tensors_[FIRST_INPUT]->shape();
auto input1_shape = in_tensors_[SECOND_INPUT]->shape();
if (scale_param_->axis_ < 0) {
scale_param_->axis_ += in_shape.size();
scale_param_->axis_ += input0_shape.size();
}
if (scale_shape.size() + scale_param_->axis_ > in_shape.size()) {
if (input1_shape.size() + scale_param_->axis_ > input0_shape.size()) {
MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
return RET_ERROR;
}
for (size_t i = 0; i < scale_shape.size(); i++) {
if (in_shape[i + scale_param_->axis_] != scale_shape[i]) {
for (size_t i = 0; i < input1_shape.size(); i++) {
if (input0_shape[i + scale_param_->axis_] != input1_shape[i]) {
MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
return RET_ERROR;
}
}
tile_para = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (tile_para == nullptr) {
MS_LOG(ERROR) << "malloc tile parameter failed.";
return RET_ERROR;
tile_para = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
}
size_t input0_size = in_tensors_.at(0)->shape().size();
size_t input1_size = in_tensors_.at(1)->shape().size();
size_t output_size = out_tensors_.at(0)->shape().size();
auto input1_shape = in_tensors_.at(1)->shape();
tile_para->ndim_ = output_size;
// supplement shape of scale tensor with number 1
size_t len = input0_size - scale_param_->axis_;
second_in_shape_ = input1_shape;
if (len != input1_size) {
second_in_shape_.resize(len);
size_t i = 0;
for (; i < input1_size; ++i) {
second_in_shape_.at(i) = input1_shape.at(i);
}
for (; i < len; ++i) {
second_in_shape_.at(i) = 1;
}
input1_size = len;
MS_CHECK_TRUE_MSG(tile_para != nullptr, RET_ERROR, "scale's arithmetic-param is a nullptr.");
auto out_shape = out_tensors_.front()->shape();
tile_para->ndim_ = out_shape.size();
int i = 0;
for (; i < scale_param_->axis_; ++i) {
tile_para->in_shape0_[i] = input0_shape[i];
tile_para->in_shape1_[i] = 1;
tile_para->out_shape_[i] = out_shape[i];
}
if (input0_size == input1_size) {
for (size_t i = 0; i < output_size; i++) {
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
tile_para->in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
}
} else {
MS_CHECK_TRUE_RET(input0_size > input1_size, RET_ERROR);
size_t fill_dim_num = input0_size - input1_size;
int j = 0;
for (size_t i = 0; i < output_size; i++) {
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
if (i < fill_dim_num) {
tile_para->in_shape1_[i] = 1;
} else {
tile_para->in_shape1_[i] = second_in_shape_.at(j++);
}
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
}
for (; i < static_cast<int>(input1_shape.size()) + scale_param_->axis_; ++i) {
tile_para->in_shape0_[i] = input0_shape[i];
tile_para->in_shape1_[i] = input1_shape[i];
tile_para->out_shape_[i] = out_shape[i];
}
for (; i < static_cast<int>(tile_para->ndim_); ++i) {
tile_para->in_shape0_[i] = input0_shape[i];
tile_para->in_shape1_[i] = 1;
tile_para->out_shape_[i] = out_shape[i];
}
return RET_OK;
}

View File

@ -50,7 +50,6 @@ class ScaleInt8CPUKernel : public InnerKernel {
const lite::InnerContext *ctx_ = nullptr;
ScaleParameter *scale_param_ = nullptr;
ArithmeticParameter *tile_para = nullptr;
std::vector<int> second_in_shape_;
int thread_count_ = 1;
int64_t elements_num_ = 0;
int64_t count_unit_ = 0;