forked from mindspore-Ecosystem/mindspore
!6733 scale int8
Merge pull request !6733 from zhaozhenlong/lite/op/scale_int8
Commit: dc1368f09d

@@ -0,0 +1,94 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/int8/scale_int8.h"
#include "nnacl/quantization/fixed_point.h"

void ScaleInnerInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int outer_start, int outer_end,
                    int axis_size, int inner_size, const ScaleParameter *scale_param) {
  for (int out = outer_start; out < outer_end; out++) {
    int out_offset = out * axis_size * inner_size;
    for (int i = 0; i < axis_size; i++) {
      int axis_offset = out_offset + i * inner_size;
      int in_index = 0;

      for (; in_index < inner_size; in_index++) {
        int in_offset = axis_offset + in_index;
        int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
        int input_mul_scale =
          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
                                scale_param->scale_mul_arg_.multiplier_),
                              scale_param->scale_mul_arg_.right_shift_);
        int tmp = input_mul_scale + scale_param->output_zp_;
        tmp = tmp > INT8_MAX ? INT8_MAX : tmp;
        tmp = tmp < INT8_MIN ? INT8_MIN : tmp;
        out_data[in_offset] = tmp;
      }
    }
  }
}

void ScaleInnerWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
                            int outer_start, int outer_end, int axis_size, int inner_size,
                            const ScaleParameter *scale_param) {
  for (int out = outer_start; out < outer_end; out++) {
    int out_offset = out * axis_size * inner_size;
    for (int i = 0; i < axis_size; i++) {
      int axis_offset = out_offset + i * inner_size;
      int in_index = 0;

      for (; in_index < inner_size; in_index++) {
        int in_offset = axis_offset + in_index;
        int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
        int input_mul_scale =
          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
                                scale_param->scale_mul_arg_.multiplier_),
                              scale_param->scale_mul_arg_.right_shift_);
        int tmp_bias = offset[i] - scale_param->offset_zp_;
        int bias = RoundingDivideByPOT(
          SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_),
                                            scale_param->offset_mul_arg_.multiplier_),
          scale_param->scale_mul_arg_.right_shift_);
        int tmp = input_mul_scale + bias + scale_param->output_zp_;
        tmp = tmp > INT8_MAX ? INT8_MAX : tmp;
        tmp = tmp < INT8_MIN ? INT8_MIN : tmp;
        out_data[in_offset] = tmp;
      }
    }
  }
}

void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
                 const ScaleParameter *scale_param) {
  int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
  int outer_start = task_id * outer_step;
  int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);

  ScaleInnerInt8(in_data, out_data, scale, outer_start, outer_end, scale_param->axis_size_, scale_param->inner_size_,
                 scale_param);
}

void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
                         int task_id, const ScaleParameter *scale_param) {
  int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
  int outer_start = task_id * outer_step;
  int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);

  ScaleInnerWithBiasInt8(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_,
                         scale_param->inner_size_, scale_param);
}
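
Note: the inner loops above requantize with the fixed-point helpers SaturatingRoundingDoublingHighMul and RoundingDivideByPOT, driven by the multiplier/shift pairs stored in ScaleParameter. For intuition, a minimal floating-point reference of what one element should produce is sketched below; the function name and argument list are illustrative only and are not part of this patch. Also note that the committed ScaleInnerWithBiasInt8 divides the bias term by scale_mul_arg_.right_shift_ rather than offset_mul_arg_.right_shift_; the sketch does not reproduce that.

#include <math.h>
#include <stdint.h>

/* Hypothetical per-element reference (not part of the patch): dequantize the
 * input, scale and offset with their {scale, zero point} pairs, compute
 * out = in * scale + offset in real arithmetic, then requantize to int8. */
static int8_t ScaleRefInt8(int8_t in, int8_t sc, int8_t off, float in_s, int in_zp, float sc_s, int sc_zp,
                           float off_s, int off_zp, float out_s, int out_zp) {
  float real = (in - in_zp) * in_s * (sc - sc_zp) * sc_s + (off - off_zp) * off_s;
  int q = (int)roundf(real / out_s) + out_zp;
  return (int8_t)(q > INT8_MAX ? INT8_MAX : (q < INT8_MIN ? INT8_MIN : q));
}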

@@ -0,0 +1,33 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_NNACL_SCALE_INT8_H_
#define MINDSPORE_LITE_NNACL_SCALE_INT8_H_

#include "nnacl/op_base.h"
#include "nnacl/scale.h"
#ifdef __cplusplus
extern "C" {
#endif
void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
                 const ScaleParameter *scale_param);
void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
                         int task_id, const ScaleParameter *scale_param);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_NNACL_SCALE_INT8_H_

@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_LITE_NNACL_SCALE_H_
 #define MINDSPORE_LITE_NNACL_SCALE_H_
 
+#include <mindspore/lite/nnacl/quantization/quantize.h>
 #include "nnacl/op_base.h"
 typedef struct ScaleParameter {
   OpParameter op_parameter_;
@@ -25,6 +26,13 @@ typedef struct ScaleParameter {
   int inner_size_;
   int axis_;
   bool const_scale_;
+  bool const_offset_;
+  QuantMulArg scale_mul_arg_;
+  QuantMulArg offset_mul_arg_;
+  int input_zp_;
+  int scale_zp_;
+  int offset_zp_;
+  int output_zp_;
 } ScaleParameter;
 
 #endif  // MINDSPORE_LITE_NNACL_SCALE_H_

@@ -43,9 +43,13 @@ ScaleCPUKernel::~ScaleCPUKernel() {
 
 int ScaleCPUKernel::InitScaleOffset() {
   auto scale_tensor = in_tensors_.at(1);
-  float *scale_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
+  float *scale_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
   if (scale_ptr != nullptr) {
     scale_param_->const_scale_ = true;
+    if (scale_ != nullptr) {
+      free(scale_);
+      scale_ = nullptr;
+    }
     scale_ = reinterpret_cast<float *>(malloc(scale_tensor->ElementsNum() * sizeof(float)));
     if (scale_ == nullptr) {
       MS_LOG(ERROR) << "Malloc buffer failed.";
@@ -57,6 +61,10 @@ int ScaleCPUKernel::InitScaleOffset() {
     scale_ = nullptr;
   }
 
+  if (offset_ != nullptr) {
+    free(offset_);
+    offset_ = nullptr;
+  }
   offset_ = reinterpret_cast<float *>(malloc(scale_param_->axis_size_ * sizeof(float)));
   if (offset_ == nullptr) {
     MS_LOG(ERROR) << "Malloc buffer failed.";
@@ -65,7 +73,9 @@ int ScaleCPUKernel::InitScaleOffset() {
   memset(offset_, 0, scale_param_->axis_size_ * sizeof(float));
   if (in_tensors_.size() == 3) {
     auto offset_tensor = in_tensors_.at(2);
-    memcpy(offset_, offset_tensor->MutableData(), offset_tensor->ElementsNum() * sizeof(float));
+    if (offset_tensor->data_c() != nullptr) {
+      memcpy(offset_, offset_tensor->data_c(), offset_tensor->ElementsNum() * sizeof(float));
+    }
   }
   return RET_OK;
 }
@@ -159,6 +169,10 @@ int ScaleCPUKernel::Run() {
     auto scale_tensor = in_tensors_[1];
     scale_ = reinterpret_cast<float *>(scale_tensor->MutableData());
   }
+  if (offset_ == nullptr) {
+    auto offset_tensor = in_tensors_.at(2);
+    memcpy(offset_, offset_tensor->data_c(), offset_tensor->ElementsNum() * sizeof(float));
+  }
   auto out_tensor = out_tensors_.front();
   output_ptr_ = reinterpret_cast<float *>(out_tensor->MutableData());
 

@@ -0,0 +1,267 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/int8/scale_int8.h"
#include <string.h>
#include <vector>
#include "nnacl/int8/scale_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Scale;

namespace mindspore::kernel {
namespace {
constexpr size_t kScaleInputsSize = 2;
constexpr size_t kScaleBiasInputsSize = 3;
}  // namespace
ScaleInt8CPUKernel::~ScaleInt8CPUKernel() {
  if (scale_param_->const_scale_) {
    if (scale_ != nullptr) {
      free(scale_);
      scale_ = nullptr;
    }
  }
  if (has_bias_ && scale_param_->const_offset_) {
    if (offset_ != nullptr) {
      free(offset_);
      offset_ = nullptr;
    }
  }
}

int ScaleInt8CPUKernel::InitScaleOffset() {
  auto scale_tensor = in_tensors_.at(1);
  int8_t *scale_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
  if (scale_ptr != nullptr) {
    scale_param_->const_scale_ = true;
    if (scale_ != nullptr) {
      free(scale_);
      scale_ = nullptr;
    }
    scale_ = reinterpret_cast<int8_t *>(malloc(scale_tensor->ElementsNum() * sizeof(int8_t)));
    if (scale_ == nullptr) {
      MS_LOG(ERROR) << "Malloc buffer failed.";
      return RET_ERROR;
    }
    memcpy(scale_, scale_ptr, scale_tensor->ElementsNum() * sizeof(int8_t));
  } else {
    scale_param_->const_scale_ = false;
    scale_ = nullptr;
  }

  if (in_tensors_.size() == 3) {
    has_bias_ = true;
    auto offset_tensor = in_tensors_.at(2);
    int8_t *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
    if (offset_ptr != nullptr) {
      scale_param_->const_offset_ = true;
      if (offset_ != nullptr) {
        free(offset_);
        offset_ = nullptr;
      }
      offset_ = reinterpret_cast<int8_t *>(malloc(offset_tensor->ElementsNum() * sizeof(int8_t)));
      if (offset_ == nullptr) {
        MS_LOG(ERROR) << "Malloc buffer failed.";
        return RET_ERROR;
      }
      memcpy(offset_, offset_ptr, offset_tensor->ElementsNum() * sizeof(int8_t));
    } else {
      scale_param_->const_offset_ = false;
      offset_ = nullptr;
    }
  } else {
    has_bias_ = false;
  }
  return RET_OK;
}

int ScaleInt8CPUKernel::InitParameter() {
  auto in_tensor = in_tensors_.at(0);
  auto in_shape = in_tensor->shape();
  auto scale_tensor = in_tensors_.at(1);
  auto scale_shape = scale_tensor->shape();

  if (scale_param_->axis_ < 0) {
    scale_param_->axis_ = scale_param_->axis_ + in_shape.size();
  }
  if (scale_shape.size() + scale_param_->axis_ > in_shape.size()) {
    MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
    return RET_ERROR;
  }
  scale_param_->outer_size_ = 1;
  scale_param_->axis_size_ = 1;
  scale_param_->inner_size_ = 1;
  for (int i = 0; i < scale_param_->axis_; i++) {
    scale_param_->outer_size_ *= in_shape[i];
  }
  for (size_t i = 0; i < scale_shape.size(); i++) {
    if (in_shape[i + scale_param_->axis_] != scale_shape[i]) {
      MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
      return RET_ERROR;
    }
    scale_param_->axis_size_ *= in_shape[i + scale_param_->axis_];
  }
  for (size_t i = scale_param_->axis_ + scale_shape.size(); i < in_shape.size(); i++) {
    scale_param_->inner_size_ *= in_shape[i];
  }
  scale_param_->op_parameter_.thread_num_ = MSMIN(scale_param_->op_parameter_.thread_num_, scale_param_->outer_size_);
  return RET_OK;
}

int ScaleInt8CPUKernel::InitQuantArgs() {
  auto input = in_tensors_.at(0);
  auto scale = in_tensors_.at(1);
  auto output = out_tensors_.at(0);
  auto input_scale = input->GetQuantParams().front().scale;
  auto scale_scale = scale->GetQuantParams().front().scale;
  auto output_scale = output->GetQuantParams().front().scale;
  scale_param_->input_zp_ = input->GetQuantParams().front().zeroPoint;
  scale_param_->scale_zp_ = scale->GetQuantParams().front().zeroPoint;
  scale_param_->output_zp_ = output->GetQuantParams().front().zeroPoint;

  // (in * scale + offset) / output
  const double input_output_multiplier = input_scale * scale_scale / output_scale;
  int shift;
  QuantizeMultiplier(input_output_multiplier, &scale_param_->scale_mul_arg_.multiplier_, &shift);
  scale_param_->scale_mul_arg_.left_shift_ = shift > 0 ? shift : 0;
  scale_param_->scale_mul_arg_.right_shift_ = shift < 0 ? -shift : 0;

  if (in_tensors_.size() == kScaleBiasInputsSize) {
    auto offset = in_tensors_.at(2);
    auto offset_scale = offset->GetQuantParams().front().scale;
    scale_param_->offset_zp_ = offset->GetQuantParams().front().zeroPoint;

    const double offset_multiplier = offset_scale / output_scale;
    QuantizeMultiplier(offset_multiplier, &scale_param_->offset_mul_arg_.multiplier_, &shift);
    scale_param_->offset_mul_arg_.left_shift_ = shift > 0 ? shift : 0;
    scale_param_->offset_mul_arg_.right_shift_ = shift < 0 ? -shift : 0;
  }
  return RET_OK;
}

int ScaleInt8CPUKernel::Init() {
  if (in_tensors_.size() < kScaleInputsSize || in_tensors_.size() > kScaleBiasInputsSize) {
    MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << in_tensors_.size() << " is given.";
    return RET_ERROR;
  }

  if (!InferShapeDone()) {
    return RET_OK;
  }

  ReSize();
  return RET_OK;
}

int ScaleInt8CPUKernel::ReSize() {
  auto ret = InitParameter();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Scale int8 InitParameter failed.";
    return RET_ERROR;
  }

  ret = InitScaleOffset();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Scale int8 InitScaleOffset failed.";
    return RET_ERROR;
  }

  ret = InitQuantArgs();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Scale int8 InitQuantArgs failed.";
    return ret;
  }
  return RET_OK;
}

int ScaleInt8CPUKernel::Scale(int task_id) {
  if (has_bias_) {
    DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_);
  } else {
    DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_);
  }

  return RET_OK;
}

int ScaleRunInt8(void *cdata, int task_id) {
  auto scale = reinterpret_cast<ScaleInt8CPUKernel *>(cdata);
  auto ret = scale->Scale(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ScaleRunInt8 error task_id[" << task_id << "] error_code[" << ret << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int ScaleInt8CPUKernel::Run() {
  auto ret = Prepare();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed, ret: " << ret;
    return ret;
  }
  auto in_tensor = in_tensors_.front();
  input_ptr_ = reinterpret_cast<int8_t *>(in_tensor->data_c());
  if (scale_ == nullptr) {
    auto scale_tensor = in_tensors_[1];
    scale_ = reinterpret_cast<int8_t *>(scale_tensor->data_c());
  }
  if (has_bias_ && !scale_param_->const_offset_) {
    offset_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
  }
  auto out_tensor = out_tensors_.front();
  output_ptr_ = reinterpret_cast<int8_t *>(out_tensor->data_c());

  ret = ParallelLaunch(this->context_->thread_pool_, ScaleRunInt8, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
    return RET_ERROR;
  }
  return RET_OK;
}
kernel::LiteKernel *CpuScaleInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                              const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
                                              const lite::InnerContext *ctx, const kernel::KernelKey &desc,
                                              const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(desc.type == schema::PrimitiveType_Scale);
  if (opParameter == nullptr) {
    MS_LOG(ERROR) << "opParameter is nullptr";
    return nullptr;
  }
  auto *kernel = new (std::nothrow) ScaleInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "New kernel fails.";
    return nullptr;
  }

  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Scale, CpuScaleInt8KernelCreator)
}  // namespace mindspore::kernel
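
Note: InitQuantArgs above hands the real rescaling ratios (input_scale * scale_scale / output_scale and offset_scale / output_scale) to QuantizeMultiplier, which yields the Q31 multiplier and the left/right shifts consumed by the C kernel. The sketch below shows one conventional way such a decomposition can be built; it is an assumption for illustration, not the nnacl QuantizeMultiplier implementation.

#include <math.h>
#include <stdint.h>

/* Illustrative decomposition (assumed, not the nnacl code): represent a real
 * multiplier as q31 * 2^shift / 2^31 with q31 in [2^30, 2^31), then split the
 * shift into the left/right fields used by the int8 scale kernel. */
static void DecomposeMultiplier(double real_multiplier, int32_t *q31, int *left_shift, int *right_shift) {
  int shift = 0;
  double q = frexp(real_multiplier, &shift); /* real = q * 2^shift, q in [0.5, 1) */
  int64_t q_fixed = (int64_t)llround(q * (1ll << 31));
  if (q_fixed == (1ll << 31)) { /* rounding pushed q up to 1.0: renormalize */
    q_fixed /= 2;
    ++shift;
  }
  *q31 = (int32_t)q_fixed;
  *left_shift = shift > 0 ? shift : 0;
  *right_shift = shift < 0 ? -shift : 0;
}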

@@ -0,0 +1,56 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SCALE_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SCALE_INT8_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/scale.h"
#include "nnacl/quantization/quantize.h"

namespace mindspore::kernel {

class ScaleInt8CPUKernel : public LiteKernel {
 public:
  ScaleInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                     const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    scale_param_ = reinterpret_cast<ScaleParameter *>(op_parameter_);
  }
  ~ScaleInt8CPUKernel() override;

  int Init() override;
  int ReSize() override;
  int Run() override;
  int InitParameter();
  int InitScaleOffset();
  int Scale(int task_id);

 private:
  int8_t *input_ptr_ = nullptr;
  int8_t *scale_ = nullptr;
  int8_t *offset_ = nullptr;
  int8_t *output_ptr_ = nullptr;
  bool has_bias_;
  ScaleParameter *scale_param_;

  int InitQuantArgs();
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SCALE_INT8_H_

@@ -0,0 +1,177 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <memory>
#include "schema/inner/model_generated.h"
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/tensor.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "nnacl/int8/scale_int8.h"

namespace mindspore {
using mindspore::lite::QuantArg;
using mindspore::lite::Tensor;

class TestScaleInt8 : public mindspore::CommonTest {
 public:
  TestScaleInt8() = default;
  void Prepare(const std::vector<int> &in_shape, int8_t *input_data, const std::vector<int> &scale_shape,
               int8_t *scale_data, const std::vector<int> &bias_shape, int8_t *bias_data,
               const std::vector<int> &out_shape, int8_t *output_data, int axis, bool has_bias);
  void TearDown() override;

 public:
  int thread_num_ = 1;

  ScaleParameter param_ = {};
  Tensor in_tensor_;
  Tensor scale_tensor_;
  Tensor bias_tensor_;
  Tensor out_tensor_;
  std::vector<Tensor *> inputs;
  std::vector<Tensor *> outputs = {&out_tensor_};
  kernel::KernelKey desc_ = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Scale};
  kernel::KernelCreator creator_ = nullptr;
  lite::InnerContext ctx_ = lite::InnerContext();
  kernel::LiteKernel *kernel_ = nullptr;
  const QuantArg quant_in_ = {0.005f, 5};
  const QuantArg quant_scale_ = {0.1f, 1};
  const QuantArg quant_bias_ = {0.002f, 2};
  const QuantArg quant_out_ = {0.01f, 1};
  float err_tol_ = 0.05;
};

void TestScaleInt8::TearDown() {
  in_tensor_.SetData(nullptr);
  scale_tensor_.SetData(nullptr);
  bias_tensor_.SetData(nullptr);
  out_tensor_.SetData(nullptr);
}

void TestScaleInt8::Prepare(const std::vector<int> &in_shape, int8_t *input_data, const std::vector<int> &scale_shape,
                            int8_t *scale_data, const std::vector<int> &bias_shape, int8_t *bias_data,
                            const std::vector<int> &out_shape, int8_t *output_data, int axis, bool has_bias) {
  in_tensor_.set_data_type(kNumberTypeInt8);
  in_tensor_.set_shape(in_shape);
  in_tensor_.SetData(input_data);
  in_tensor_.AddQuantParam(quant_in_);
  scale_tensor_.set_data_type(kNumberTypeInt8);
  scale_tensor_.set_shape(scale_shape);
  scale_tensor_.SetData(scale_data);
  scale_tensor_.AddQuantParam(quant_scale_);

  inputs.clear();
  inputs.emplace_back(&in_tensor_);
  inputs.emplace_back(&scale_tensor_);
  if (has_bias) {
    bias_tensor_.set_data_type(kNumberTypeInt8);
    bias_tensor_.set_shape(bias_shape);
    bias_tensor_.SetData(bias_data);
    bias_tensor_.AddQuantParam(quant_bias_);
    inputs.emplace_back(&bias_tensor_);
  }

  out_tensor_.set_data_type(kNumberTypeInt8);
  out_tensor_.set_shape(out_shape);
  out_tensor_.SetData(output_data);
  out_tensor_.AddQuantParam(quant_out_);

  param_.axis_ = axis;
  creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc_);

  ctx_.thread_num_ = thread_num_;
  ASSERT_EQ(lite::RET_OK, ctx_.Init());
  kernel_ = creator_(inputs, outputs, reinterpret_cast<OpParameter *>(&param_), &ctx_, desc_, nullptr);
}

TEST_F(TestScaleInt8, scale1) {
  /* 1 2 2 1 NHWC */
  int8_t input_data[96] = {0, 1, 2, 3};
  int8_t scale_data[4] = {2, 2, 2, 2};
  int8_t bias_data[4] = {3, 3, 3, 3};
  int8_t out_data[4] = {0};
  bool has_bias = true;

  int axis = 1;
  std::vector<int> input_shape = {1, 2, 2, 1};
  std::vector<int> scale_shape = {2, 2, 1};
  std::vector<int> bias_shape = {2, 2, 1};
  std::vector<int> output_shape = {1, 2, 2, 1};
  int output_size = 4;
  int8_t correct[] = {1, 1, 1, 1};

  thread_num_ = 2;
  Prepare(input_shape, input_data, scale_shape, scale_data, bias_shape, bias_data, output_shape, out_data, axis,
          has_bias);
  auto ret = kernel_->Run();
  EXPECT_EQ(0, ret);

  err_tol_ = 0.01;
  CompareOutputInt8(out_data, correct, output_size, err_tol_);
}
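
Note: as a quick sanity check on scale1 (treating each QuantArg as {scale, zero point}, which is an assumption about the field order): input 0 dequantizes to (0 - 5) * 0.005 = -0.025, scale 2 to (2 - 1) * 0.1 = 0.1, and bias 3 to (3 - 2) * 0.002 = 0.002, so the real result is -0.025 * 0.1 + 0.002 = -0.0005; requantizing gives round(-0.0005 / 0.01) + 1 = 1, matching correct[0]. The other three elements land on 1 the same way. The standalone check below mirrors that arithmetic; it is illustrative, not part of the test suite.

#include <math.h>
#include <stdio.h>

/* Worked check of scale1's first output element; constants copied from the test above. */
int main(void) {
  float in = (0 - 5) * 0.005f;           /* dequantized input  = -0.025 */
  float sc = (2 - 1) * 0.1f;             /* dequantized scale  =  0.1   */
  float bias = (3 - 2) * 0.002f;         /* dequantized offset =  0.002 */
  float real = in * sc + bias;           /* -0.0005 */
  int q = (int)roundf(real / 0.01f) + 1; /* requantize: round(-0.05) + 1 = 1 */
  printf("expected out[0] = %d\n", q);
  return 0;
}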

TEST_F(TestScaleInt8, scale2) {
  /* 1 2 2 1 NHWC */
  int8_t input_data[96] = {0, 10, 20, 30};
  int8_t scale_data[4] = {2, 2, 2, 2};
  int8_t bias_data[4] = {3, 3, 3, 3};
  int8_t out_data[4] = {0};
  bool has_bias = true;

  int axis = 1;
  std::vector<int> input_shape = {1, 2, 2, 1};
  std::vector<int> scale_shape = {2, 2, 1};
  std::vector<int> bias_shape = {2, 2, 1};
  std::vector<int> output_shape = {1, 2, 2, 1};
  int output_size = 4;
  int8_t correct[] = {1, 1, 2, 2};

  thread_num_ = 2;
  Prepare(input_shape, input_data, scale_shape, scale_data, bias_shape, bias_data, output_shape, out_data, axis,
          has_bias);
  auto ret = kernel_->Run();
  EXPECT_EQ(0, ret);

  err_tol_ = 0.01;
  CompareOutputInt8(out_data, correct, output_size, err_tol_);
}

TEST_F(TestScaleInt8, scale3) {
  /* 1 2 2 1 NHWC */
  int8_t input_data[96] = {0, 90, 100, 120};
  int8_t scale_data[4] = {2, 2, 2, 2};
  int8_t bias_data[4] = {3, 3, 3, 3};
  int8_t out_data[4] = {0};
  bool has_bias = false;

  int axis = 1;
  std::vector<int> input_shape = {1, 2, 2, 1};
  std::vector<int> scale_shape = {2, 2, 1};
  std::vector<int> bias_shape = {2, 2, 1};
  std::vector<int> output_shape = {1, 2, 2, 1};
  int output_size = 4;
  int8_t correct[] = {1, 5, 6, 7};

  thread_num_ = 2;
  Prepare(input_shape, input_data, scale_shape, scale_data, bias_shape, bias_data, output_shape, out_data, axis,
          has_bias);
  auto ret = kernel_->Run();
  EXPECT_EQ(0, ret);

  err_tol_ = 0.01;
  CompareOutputInt8(out_data, correct, output_size, err_tol_);
}
}  // namespace mindspore