forked from mindspore-Ecosystem/mindspore
!29269 add MatmulDynamicInt8CPUKernel
Merge pull request !29269 from yeyunpeng2020/dynamic_quant
commit 8d40dc2caa
@@ -331,6 +331,35 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
}
#endif

void DynamicMatmulInt8Opt(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, int deep16,
                          float input_scale, int input_zp, const float *filter_scale, size_t stride) {
  /*
   * row4x16-major * row16x4-major => (fp32)row-major
   * support activation per-layer asymmetric && weight per-channel symmetric
   */
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int r4div = r / C4NUM, r4mod = r % C4NUM;
      int c4div = c / C4NUM, c4mod = c % C4NUM;
      size_t ci = r * stride + c;
      double value = 0;
      for (int d = 0; d < deep16; d++) {
        int d16div = d / C16NUM, d16mod = d % C16NUM;
        size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
        size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
        int32_t value_1 = a[ai] * b[bi];
        int32_t value_3 = input_zp * b[bi];
        value += input_scale * filter_scale[c] * (value_1 - value_3);
      }
      if (bias != NULL) {
        value += bias[c];
      }
      dst[ci] = value;
    }
  }
  return;
}
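// The inner loop above computes input_scale * filter_scale[c] * (a*b - input_zp*b),
// i.e. the dequantization identity (input_scale * (a - input_zp)) * (filter_scale[c] * b):
// a dequantized asymmetric activation times a symmetrically dequantized weight.
// Below is a minimal standalone sketch (hypothetical helper; kC4 = 4 and kC16 = 16,
// matching nnacl's C4NUM/C16NUM) checking that the row4x16 packed offset used for `ai`
// is a bijection on a tile-aligned toy matrix. It is a self-check, not the commit's code.
#include <cstdint>
#include <cstdio>

constexpr int kC4 = 4;
constexpr int kC16 = 16;

// Packed offset of element (r, d) for a matrix padded to 4-row by 16-deep tiles.
size_t PackedIndexRow4x16(int r, int d, int deep16) {
  int r4div = r / kC4, r4mod = r % kC4;
  int d16div = d / kC16, d16mod = d % kC16;
  return static_cast<size_t>(r4div) * deep16 * kC4 + d16div * kC4 * kC16 + r4mod * kC16 + d16mod;
}

int main() {
  const int row = 4, deep16 = 32;  // toy sizes, already tile-aligned
  int8_t src[row * deep16], packed[row * deep16];
  for (int i = 0; i < row * deep16; ++i) src[i] = static_cast<int8_t>(i);
  for (int r = 0; r < row; ++r) {
    for (int d = 0; d < deep16; ++d) {
      packed[PackedIndexRow4x16(r, d, deep16)] = src[r * deep16 + d];
    }
  }
  // If two (r, d) pairs collided, the earlier value would be overwritten and detected here.
  for (int r = 0; r < row; ++r) {
    for (int d = 0; d < deep16; ++d) {
      if (packed[PackedIndexRow4x16(r, d, deep16)] != src[r * deep16 + d]) return 1;
    }
  }
  printf("row4x16 pack/unpack round-trip OK\n");
  return 0;
}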

void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                      size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift,
                      const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini,
@@ -46,6 +46,8 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
                   const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
                   const int32_t *filter_zp);

void DynamicMatmulInt8Opt(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, int deep16,
                          float input_scale, int input_zp, const float *filter_scale, size_t stride);
/* 8x4 4x8 -> 8x8 */
/* optimize conv */
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
@@ -78,4 +78,11 @@ typedef struct MatmulQuantParameter {
  int32_t *quant_multiplier_;
} MatmulQuantParameter;

typedef struct MatmulDynamicQuantParameter {
  float input_scale_;
  int32_t input_zp_;
  float *filter_scale_;
  int32_t *filter_zp_;
} MatmulDynamicQuantParameter;

#endif  // MINDSPORE_NNACL_MATMUL_H_
@@ -32,6 +32,7 @@ using mindspore::schema::PrimitiveType_DynamicQuant;
namespace mindspore::kernel {
namespace {
constexpr int kBucketNums = 8;
constexpr int k8Bit = 8;
constexpr int kMinNums = 512;
}  // namespace
int DynamicQuantCPUKernel::Prepare() {
@@ -106,16 +107,22 @@ void DynamicQuantCPUKernel::ReduceMinMaxFp32() {
      real_max_ = real_max_array_[i];
    }
  }
  return;
}

void DynamicQuantCPUKernel::CalculateScaleZp() {
  lite::LiteQuantParam quant_parm;
  double scale = (real_max_ - real_min_) / (INT8_MAX - INT8_MIN);
  int zp = 0;
  if (!symmetric_) {
    zp = static_cast<int>(std::round(INT8_MIN - real_min_ / scale));
  }
  quant_parm.scale = scale;
  quant_parm.zeroPoint = zp;
  quant_parm.bitNum = k8Bit;
  quant_parm.inited = true;
  this->out_tensors_.front()->AddQuantParam(quant_parm);
  return;
}
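// A quick standalone check of the asymmetric scale/zero-point math above (a sketch,
// not the kernel's code): the real range [-1.0, 3.0] should map onto the full int8 range.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  double real_min = -1.0, real_max = 3.0;
  double scale = (real_max - real_min) / (INT8_MAX - INT8_MIN);        // 4 / 255 ~= 0.01569
  int zp = static_cast<int>(std::round(INT8_MIN - real_min / scale));  // round(-128 + 63.75) = -64
  auto quantize = [&](double x) { return static_cast<int>(std::round(x / scale)) + zp; };
  // Prints q(min) = -128 and q(max) = 127: the observed range fills int8 exactly.
  printf("scale=%f zp=%d q(min)=%d q(max)=%d\n", scale, zp, quantize(real_min), quantize(real_max));
  return 0;
}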

int DynamicQuantCPUKernel::QuantData(int task_id) {
@@ -196,7 +203,7 @@ kernel::InnerKernel *DynamicQuantCPUCreator(const std::vector<lite::Tensor *> &i
  }
  bool support_dtype =
    inputs[0]->data_type() == TypeId::kNumberTypeFloat32 && outputs[0]->data_type() == TypeId::kNumberTypeInt8;
  if (!support_dtype) {
    MS_LOG(ERROR) << "Unsupported data type input:" << inputs.front()->data_type()
                  << ", output:" << outputs.front()->data_type();
    return nullptr;
@@ -0,0 +1,347 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/int8/matmul_dynamic_int8.h"
#include "src/runtime/kernel/arm/int8/opt_op_handler.h"
#include "nnacl/int8/matmul_int8.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;

namespace mindspore::kernel {
namespace {
constexpr int kHasBiasSize = 3;
constexpr int kMinInputSize = 2;
constexpr int kOutputSize = 1;
constexpr int kSize1 = 1;
constexpr int kSize2 = 2;
}  // namespace

int MatmulDynamicInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  CHECK_NULL_RETURN(cdata);
  auto op = reinterpret_cast<MatmulDynamicInt8CPUKernel *>(cdata);
  auto ret = op->RunImpl(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "MatmulDynamicInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
    return ret;
  }
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
  int stride = thread_stride_ * col_tile_;
  int cur_stride = task_id * stride;
  int res_stride = param_->col_ - cur_stride;
  int cur_oc = MSMIN(stride, res_stride);
  if (cur_oc <= 0) {
    return RET_OK;
  }
  // Offset the bias and per-channel scales to this task's first output column so that
  // DynamicMatmulInt8Opt's local column index stays consistent with b and dst.
  const float *cur_bias = (fp32_bias_ptr_ == nullptr) ? nullptr : fp32_bias_ptr_ + cur_stride;
  DynamicMatmulInt8Opt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, cur_bias,
                       batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_,
                       quant_param_->input_scale_, quant_param_->input_zp_,
                       quant_param_->filter_scale_ + cur_stride, param_->col_);
  return RET_OK;
}
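// Worked numbers for the column slicing above (illustrative values, not from the commit):
// with col_ = 100, col_tile_ = 4 and four threads, ResizeParameter() below yields
// col_align_ = 100 and UP_DIV(100, 4) = 25 tile columns, so thread_stride_ = UP_DIV(25, 4) = 7
// and stride = 7 * 4 = 28. Tasks 0..3 then cover columns [0,28), [28,56), [56,84) and
// [84,100), the last clipped by cur_oc = MSMIN(28, 16) = 16.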

MatmulDynamicInt8CPUKernel::~MatmulDynamicInt8CPUKernel() {
  FreeQuantParam();
  FreeTmpBuffer();
  // CopyBias() owns an fp32 copy of the bias; release it with the kernel.
  if (fp32_bias_ptr_ != nullptr) {
    free(fp32_bias_ptr_);
    fp32_bias_ptr_ = nullptr;
  }
}

void MatmulDynamicInt8CPUKernel::FreeQuantParam() {
  if (quant_param_ != nullptr) {
    if (quant_param_->filter_scale_ != nullptr) {
      free(quant_param_->filter_scale_);
      quant_param_->filter_scale_ = nullptr;
    }
    if (quant_param_->filter_zp_ != nullptr) {
      free(quant_param_->filter_zp_);
      quant_param_->filter_zp_ = nullptr;
    }
    free(quant_param_);
    quant_param_ = nullptr;
  }
}

int MatmulDynamicInt8CPUKernel::MallocQuantParam() {
  auto weight_tensor = in_tensors_.at(kWeightIndex);
  auto weight_quant_params = weight_tensor->quant_params();
  auto w_shape = weight_tensor->shape();
  MS_CHECK_TRUE_MSG(weight_tensor->shape().size() >= DIMENSION_2D, lite::RET_ERROR, "weight dims should >=2");
  int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
  filter_per_channel_ = (weight_quant_params.size() > 1);
  channel_num_ = filter_per_channel_ ? col : 1;

  quant_param_ = reinterpret_cast<MatmulDynamicQuantParameter *>(malloc(sizeof(MatmulDynamicQuantParameter)));
  if (quant_param_ == nullptr) {
    MS_LOG(ERROR) << "Malloc MatmulDynamicQuantParameter for Matmul int8 op failed!";
    return RET_ERROR;
  }
  memset(quant_param_, 0, sizeof(MatmulDynamicQuantParameter));
  quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
  CHECK_NULL_RETURN(quant_param_->filter_scale_);
  memset(quant_param_->filter_scale_, 0, channel_num_ * sizeof(float));
  quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
  CHECK_NULL_RETURN(quant_param_->filter_zp_);
  memset(quant_param_->filter_zp_, 0, channel_num_ * sizeof(int32_t));
  return RET_OK;
}
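// Sizing example for the rule above (illustrative): a transposed weight with shape
// [col, deep] = [8, 16] carrying 8 per-channel quant params gives channel_num_ = 8,
// one filter_scale_/filter_zp_ slot per output channel; a single per-tensor param
// gives channel_num_ = 1 and one shared slot.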

int MatmulDynamicInt8CPUKernel::InitFilterQuantParam() {
  auto weight_tensor = in_tensors_.at(kWeightIndex);
  auto weight_quant_params = weight_tensor->quant_params();
  MS_CHECK_TRUE_RET(static_cast<int>(weight_quant_params.size()) == channel_num_, RET_ERROR);
  for (int i = 0; i < channel_num_; i++) {
    quant_param_->filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale);
    quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint;
  }
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::InitInputQuantParam() {
  auto in_quant_params = in_tensors_.at(kInputIndex)->quant_params();
  if (in_quant_params.empty()) {
    MS_LOG(ERROR) << "invalid in quant param";
    return RET_ERROR;
  }
  quant_param_->input_zp_ = in_quant_params.front().zeroPoint;
  quant_param_->input_scale_ = static_cast<float>(in_quant_params.front().scale);
  return RET_OK;
}

void MatmulDynamicInt8CPUKernel::InitParameter() {
  param_->a_const_ = (in_tensors_[kInputIndex]->data() != nullptr);
  param_->b_const_ = (in_tensors_[kWeightIndex]->data() != nullptr);
#ifdef ENABLE_ARM32
  row_tile_ = C4NUM;
  col_tile_ = C2NUM;
  deep_tile_ = C16NUM;
#elif defined(ENABLE_ARM64)
  support_sdot_ = mindspore::lite::IsSupportSDot();
  row_tile_ = C4NUM;
  if (support_sdot_) {
    col_tile_ = C16NUM;
    deep_tile_ = C4NUM;
  } else {
    col_tile_ = C4NUM;
    deep_tile_ = C16NUM;
  }
#else
  row_tile_ = C4NUM;
  col_tile_ = C4NUM;
  deep_tile_ = C16NUM;
#endif
  if (param_->a_transpose_) {
    a_pack_func_ = RowMajor2Col16x4MajorInt8;
  } else {
    a_pack_func_ = RowMajor2Row16x4MajorInt8;
  }
  if (param_->b_transpose_) {
#ifdef ENABLE_ARM32
    b_pack_func_ = RowMajor2Row2x16MajorInt8;
#elif defined(ENABLE_ARM64)
    if (support_sdot_) {
      b_pack_func_ = RowMajor2Row4x16MajorInt8;
    } else {
      b_pack_func_ = RowMajor2Row16x4MajorInt8;
    }
#else
    b_pack_func_ = RowMajor2Row16x4MajorInt8;
#endif
  } else {
#ifdef ENABLE_ARM32
    b_pack_func_ = RowMajor2Col16x2MajorInt8;
#elif defined(ENABLE_ARM64)
    if (support_sdot_) {
      b_pack_func_ = RowMajor2Col4x16MajorInt8;
    } else {
      b_pack_func_ = RowMajor2Col16x4MajorInt8;
    }
#else
    b_pack_func_ = RowMajor2Col16x4MajorInt8;
#endif
  }
  return;
}

void MatmulDynamicInt8CPUKernel::ResizeParameter() {
  param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
  param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
  param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);

  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
  return;
}

void MatmulDynamicInt8CPUKernel::FreeTmpBuffer() {
  if (pack_a_ptr_ != nullptr) {
    free(pack_a_ptr_);
    pack_a_ptr_ = nullptr;
  }
  if (pack_b_ptr_ != nullptr) {
    free(pack_b_ptr_);
    pack_b_ptr_ = nullptr;
  }
  return;
}

int MatmulDynamicInt8CPUKernel::TransferB() {
  auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data());
  CHECK_NULL_RETURN(weight_data);
  for (int i = 0; i < param_->batch; i++) {
    auto current_weight = weight_data + i * param_->deep_ * param_->col_;
    auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
    CHECK_NULL_RETURN(b_pack_func_);
    if (param_->b_transpose_) {
      b_pack_func_(current_weight, current_b_pack, param_->col_, param_->deep_);
    } else {
      b_pack_func_(current_weight, current_b_pack, param_->deep_, param_->col_);
    }
  }
  return RET_OK;
}
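// Batch-stride example for the two offsets above (illustrative): with deep_ = 10,
// col_ = 6 and 4x16 tiles, col_align_ = 8 and deep_align_ = 16, so each batch advances
// 10 * 6 = 60 int8s in the raw weight but 8 * 16 = 128 int8s in the packed buffer.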

int MatmulDynamicInt8CPUKernel::InitTmpBuffer() {
  pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_align_ * sizeof(int8_t)));
  if (pack_a_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_ERROR;
  }
  pack_b_ptr_ =
    reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
  if (pack_b_ptr_ == nullptr) {
    FreeTmpBuffer();
    return RET_ERROR;
  }
  memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
  memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::CopyBias() {
  if (in_tensors_.size() == kHasBiasSize) {
    auto bias_tensor = in_tensors_[kBiasIndex];
    CHECK_NULL_RETURN(bias_tensor->data());
    // Own a copy of the fp32 bias; freed in the destructor.
    fp32_bias_ptr_ = reinterpret_cast<float *>(malloc(bias_tensor->ElementsNum() * sizeof(float)));
    if (fp32_bias_ptr_ == nullptr) {
      MS_LOG(ERROR) << "Memory allocation failed";
      FreeTmpBuffer();
      return RET_MEMORY_FAILED;
    }
    memcpy(fp32_bias_ptr_, bias_tensor->data(), bias_tensor->ElementsNum() * sizeof(float));
  } else {
    fp32_bias_ptr_ = nullptr;
  }
  return RET_OK;
}

int MatmulDynamicInt8CPUKernel::Prepare() {
  CHECK_LESS_RETURN(in_tensors_.size(), kMinInputSize);
  CHECK_LESS_RETURN(out_tensors_.size(), kOutputSize);
  InitParameter();

  auto ret = MallocQuantParam();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }

  ret = InitFilterQuantParam();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }

  ret = CopyBias();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }

  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int MatmulDynamicInt8CPUKernel::ReSize() {
  int batch = 1;
  auto x_shape = in_tensors_.at(0)->shape();
  auto o_shape = out_tensors_.at(0)->shape();
  MS_ASSERT(x_shape.size() >= kSize2);
  for (size_t i = 0; i < x_shape.size() - kSize2; ++i) {
    batch *= x_shape[i];
  }
  param_->batch = batch;
  MS_ASSERT(o_shape.size() >= kSize2);
  param_->row_ = o_shape[o_shape.size() - kSize2];
  param_->col_ = o_shape[o_shape.size() - kSize1];
  param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1];

  FreeTmpBuffer();

  ResizeParameter();

  auto ret = InitTmpBuffer();
  if (ret != RET_OK) {
    FreeQuantParam();
    return ret;
  }
  if (param_->b_const_) {
    ret = TransferB();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "TransferB failed.";
      return ret;
    }
  }
  return RET_OK;
}
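// Shape example for the batch computation above (illustrative): x_shape = {2, 3, 5, 7}
// gives batch = 2 * 3 = 6; row_ and col_ come from the output's last two dims, and with
// a_transpose_ false, deep_ = x_shape[3] = 7.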

int MatmulDynamicInt8CPUKernel::Run() {
  auto ret = InitInputQuantParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init input quant param failed.";
    return ret;
  }
  if (!param_->b_const_) {
    ret = TransferB();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "TransferB failed.";
      return ret;
    }
  }
  auto *a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data());
  auto *c_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data());
  CHECK_NULL_RETURN(a_ptr);
  CHECK_NULL_RETURN(c_ptr);
  for (int i = 0; i < param_->batch; i++) {
    auto current_src_a = a_ptr + i * param_->row_ * param_->deep_;
    if (param_->a_transpose_) {
      MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR);
      a_pack_func_(current_src_a, pack_a_ptr_, param_->deep_, param_->row_);
    } else {
      MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR);
      a_pack_func_(current_src_a, pack_a_ptr_, param_->row_, param_->deep_);
    }

    batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
    batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;

    ret = ParallelLaunch(this->ms_context_, MatmulDynamicInt8Run, this, thread_count_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "MatmulDynamicInt8Run error: [" << ret << "]";
      return ret;
    }
  }
  return RET_OK;
}
}  // namespace mindspore::kernel

@@ -0,0 +1,90 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_INT8_H_

#include <vector>
#include "include/errorcode.h"
#include "include/context.h"
#include "src/inner_kernel.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/common_func.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/int8/common_func_int8.h"
#include "nnacl/int8/matmul_int8.h"

namespace mindspore::kernel {
class MatmulDynamicInt8CPUKernel : public InnerKernel {
  typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);

 public:
  MatmulDynamicInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                             const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : InnerKernel(parameter, inputs, outputs, ctx) {
    param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
  }
  ~MatmulDynamicInt8CPUKernel() override;
  int Prepare() override;
  int ReSize() override;
  int Run() override;

 public:
  int RunImpl(int task_id);
#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && (!defined(MACHINE_LINUX_ARM64))
  int RunArm64Sdot();
  int Arm64SdotImpl(int task_id);
  int Arm64SdotPre(int task_id);
#endif

 private:
  void InitParameter();
  void ResizeParameter();
  int CopyBias();
  int InitTmpBuffer();
  void FreeTmpBuffer();
  int TransferA();
  int TransferB();

  int MallocQuantParam();
  int InitInputQuantParam();
  int InitFilterQuantParam();
  void FreeQuantParam();

 private:
  MatMulParameter *param_ = nullptr;
  MatmulDynamicQuantParameter *quant_param_ = nullptr;
  int thread_count_ = 1;
  int thread_stride_ = 0;
  int8_t *pack_a_ptr_ = nullptr;
  int8_t *pack_b_ptr_ = nullptr;
  float *fp32_bias_ptr_ = nullptr;
  bool filter_per_channel_ = true;
  int8_t *batch_input_ptr_ = nullptr;
  int8_t *batch_weight_ptr_ = nullptr;
  int8_t *batch_b_ptr_ = nullptr;
  float *batch_c_ptr_ = nullptr;
  int row_tile_ = C4NUM;
  int col_tile_ = C4NUM;
  int deep_tile_ = C16NUM;
  int channel_num_ = 0;
  bool support_sdot_ = false;
  PackFunc a_pack_func_{nullptr};
  PackFunc b_pack_func_{nullptr};
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_DYNAMIC_INT8_H_
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/int8/matmul_int8.h"
|
||||
#include "src/runtime/kernel/arm/int8/matmul_dynamic_int8.h"
|
||||
#include "nnacl/int8/matmul_int8.h"
|
||||
#include "nnacl/common_func.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@@ -66,5 +67,33 @@ int MatmulInt8CPUKernel::ReSize() {
  return RET_OK;
}

kernel::InnerKernel *MatmulInt8CPUKernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
                                                const lite::Context *ctx, const kernel::KernelKey &desc) {
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "parameter is nullptr.";
    return nullptr;
  }

  InnerKernel *kernel = nullptr;
  if (parameter->quant_type_ == schema::QuantType_QUANT_ALL) {
    kernel =
      new (std::nothrow) MatmulInt8CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
  } else if (parameter->quant_type_ == schema::QuantType_QUANT_DYNAMIC) {
    kernel = new (std::nothrow)
      MatmulDynamicInt8CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
  } else {
    MS_LOG(ERROR) << "kernel: " << parameter->name_ << " is unsupported quant type:" << parameter->quant_type_;
    free(parameter);
    return nullptr;
  }
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel: " << parameter->name_ << " is nullptr.";
    free(parameter);
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMulFusion, MatmulInt8CPUKernelCreator)
}  // namespace mindspore::kernel