!29154 add DynamicQuant kernel

Merge pull request !29154 from yeyunpeng2020/dynamic_quant
i-robot 2022-01-17 08:53:31 +00:00 committed by Gitee
commit faf7e26381
11 changed files with 344 additions and 15 deletions

View File

@@ -0,0 +1,28 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/int8/dynamic_quant_int8.h"
// Scans count floats, folding the running min/max into *real_min / *real_max.
// Callers must pre-seed the outputs (e.g. with FLT_MAX / -FLT_MAX).
void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max) {
for (int i = 0; i < count; ++i) {
if (data[i] < *real_min) {
*real_min = data[i];
}
if (data[i] > *real_max) {
*real_max = data[i];
}
}
}
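
For reference, a minimal standalone driver for this helper (the FLT_MAX / -FLT_MAX seeding mirrors what the kernel does in ReSize; the buffer contents are made up for illustration):

#include <cfloat>
#include <cstdio>

// Declared in nnacl/int8/dynamic_quant_int8.h (see above).
extern "C" void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max);

int main() {
  const float data[] = {0.5f, -1.25f, 3.0f, 0.0f};
  float real_min = FLT_MAX;   // seeded so the first element updates both bounds
  float real_max = -FLT_MAX;
  CalculateMinMaxFp32(data, 4, &real_min, &real_max);
  printf("min=%f max=%f\n", real_min, real_max);  // min=-1.250000 max=3.000000
  return 0;
}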

View File

@@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_INT8_DYNAMIC_QUANT_INT8_H_
#define MINDSPORE_NNACL_INT8_DYNAMIC_QUANT_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/power_parameter.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {
#endif
void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_INT8_DYNAMIC_QUANT_INT8_H_

View File

@@ -77,7 +77,7 @@ enum QuantType: int {
PostTraining, // deprecated, use QUANT_ALL instead
QUANT_WEIGHT,
QUANT_ALL,
QUANT_DANAMIC,
QUANT_DYNAMIC,
}
table Primitive {

View File

@@ -0,0 +1,208 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/dynamic_quant.h"
#include <vector>
#include "src/kernel_registry.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "nnacl/dynamic_quant_parameter.h"
#include "nnacl/int8/dynamic_quant_int8.h"
#include "nnacl/int8/quant_dtype_cast_int8.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
using mindspore::lite::RET_PARAM_INVALID;
using mindspore::schema::PrimitiveType_DynamicQuant;
namespace mindspore::kernel {
namespace {
constexpr int kBucketNums = 8;
constexpr int kMinNums = 512;
} // namespace
int DynamicQuantCPUKernel::Prepare() {
auto in_tensor = in_tensors_.front();
CHECK_NULL_RETURN(in_tensor);
auto out_tensor = out_tensors_.front();
CHECK_NULL_RETURN(out_tensor);
auto param = reinterpret_cast<DynamicQuantParameter *>(op_parameter_);
CHECK_NULL_RETURN(param);
src_dtype_ = in_tensor->data_type();
dst_dtype_ = param->dst_type_;
symmetric_ = param->symmetric_;
if (out_tensor->data_type() != dst_dtype_) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int DynamicQuantCPUKernel::ReSize() {
auto in_tensor = in_tensors_.front();
num_unit_ = static_cast<int>(in_tensor->ElementsNum());
if (num_unit_ < kMinNums) {
thread_n_num_ = 1;
} else {
thread_n_num_ = MSMIN(thread_num_, num_unit_);
// Limit to kBucketNums (8) threads so each task owns one min/max bucket.
thread_n_num_ = MSMIN(thread_n_num_, kBucketNums);
}
for (int i = 0; i < kBucketNums; ++i) {
real_min_array_[i] = FLT_MAX;
real_max_array_[i] = -FLT_MAX;  // -FLT_MAX, not FLT_MIN: FLT_MIN is the smallest positive float
}
MS_CHECK_GT(thread_n_num_, 0, RET_ERROR);
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
return RET_OK;
}
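
The split above is a plain ceiling division over elements, capped at kBucketNums so each task owns one min/max slot. A small sketch of how the per-task chunks fall out (UP_DIV and MSMIN are assumed to behave as in nnacl's op_base.h):

#include <algorithm>
#include <cstdio>

int main() {
  const int num_unit = 1000, thread_n_num = 3;
  const int thread_n_stride = (num_unit + thread_n_num - 1) / thread_n_num;  // UP_DIV -> 334
  for (int task_id = 0; task_id < thread_n_num; ++task_id) {
    // The last task picks up the shorter tail, exactly as CalculateMinMax computes it.
    const int count = std::min(thread_n_stride, num_unit - task_id * thread_n_stride);
    printf("task %d: offset=%d count=%d\n", task_id, task_id * thread_n_stride, count);
  }
  return 0;  // counts: 334, 334, 332
}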
int DynamicQuantCPUKernel::CalculateMinMax(int task_id) {
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
}
int thread_offset = task_id * thread_n_stride_;
float *data = float32_ptr_ + thread_offset;
CalculateMinMaxFp32(data, num_unit_thread, &real_min_array_[task_id], &real_max_array_[task_id]);
return RET_OK;
}
int CalculateMinMaxRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
CHECK_NULL_RETURN(cdata);
auto g_kernel = reinterpret_cast<DynamicQuantCPUKernel *>(cdata);
auto ret = g_kernel->CalculateMinMax(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "CalculateMinMaxRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
void DynamicQuantCPUKernel::ReduceMinMaxFp32() {
for (int i = 0; i < kBucketNums; i++) {
if (real_min_array_[i] < real_min_) {
real_min_ = real_min_array_[i];
}
if (real_max_array_[i] > real_max_) {
real_max_ = real_max_array_[i];
}
}
}
void DynamicQuantCPUKernel::CalculateScaleZp() {
double scale = (real_max_ - real_min_) / (INT8_MAX - INT8_MIN);
int zp = 0;
if (!symmetric_) {
zp = std::round(INT8_MIN - real_min_ / scale);
}
this->out_tensors_.front()->quant_params().front().scale = scale;
this->out_tensors_.front()->quant_params().front().zeroPoint = zp;
}
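
This is the usual affine mapping: scale = (real_max - real_min) / (qmax - qmin), and in the asymmetric case zp = round(qmin - real_min / scale), so that real_min lands on qmin. A worked example with assumed bounds (the degenerate real_max == real_min case would give scale 0 and is not handled here):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const double real_min = -1.25, real_max = 3.0;
  const double scale = (real_max - real_min) / (INT8_MAX - INT8_MIN);  // 4.25 / 255 ~= 0.0166667
  const int zp = static_cast<int>(std::round(INT8_MIN - real_min / scale));  // -128 - (-75) = -53
  printf("scale=%f zp=%d\n", scale, zp);
  return 0;
}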
int DynamicQuantCPUKernel::QuantData(int task_id) {
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
}
int thread_offset = task_id * thread_n_stride_;
auto quant_arg = out_tensors_.front()->quant_params().front();
int ret;
TypeId data_type = out_tensors_.front()->data_type();
if (data_type == TypeId::kNumberTypeInt8) {
ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
} else {
MS_LOG(ERROR) << "Data type not supported:" << data_type;
return RET_PARAM_INVALID;
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int QuantDataRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
CHECK_NULL_RETURN(cdata);
auto g_kernel = reinterpret_cast<DynamicQuantCPUKernel *>(cdata);
auto ret = g_kernel->QuantData(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "CalculateMinMaxRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int DynamicQuantCPUKernel::Run() {
int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data());
float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data());
CHECK_NULL_RETURN(int8_ptr_);
CHECK_NULL_RETURN(float32_ptr_);
auto ret = ParallelLaunch(this->ms_context_, CalculateMinMaxRun, this, thread_n_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run error error_code[" << ret << "]";
return RET_ERROR;
}
ReduceMinMaxFp32();
CalculateScaleZp();
ret = ParallelLaunch(this->ms_context_, QuantDataRun, this, thread_n_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run error error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
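
End to end, Run is two data-parallel passes over the same buffer. A single-threaded reference of the whole data path (the clamp-and-round loop stands in for nnacl's DoQuantizeFp32ToInt8 and is an assumption, not its actual implementation; the helper name is hypothetical):

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> DynamicQuantRef(const std::vector<float> &in, bool symmetric) {
  float real_min = FLT_MAX, real_max = -FLT_MAX;
  for (float v : in) {  // pass 1: CalculateMinMaxRun
    real_min = std::min(real_min, v);
    real_max = std::max(real_max, v);
  }
  // Assumes real_max > real_min; otherwise scale would be 0.
  const double scale = (real_max - real_min) / (INT8_MAX - INT8_MIN);
  const int zp = symmetric ? 0 : static_cast<int>(std::round(INT8_MIN - real_min / scale));
  std::vector<int8_t> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {  // pass 2: QuantDataRun
    const int q = static_cast<int>(std::round(in[i] / scale)) + zp;
    out[i] = static_cast<int8_t>(std::min<int>(INT8_MAX, std::max<int>(INT8_MIN, q)));
  }
  return out;
}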
kernel::InnerKernel *DynamicQuantCPUCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
const lite::Context *ctx, const kernel::KernelKey &desc) {
if (parameter == nullptr) {
MS_LOG(ERROR) << "parameter is nullptr.";
return nullptr;
}
auto *kernel =
new (std::nothrow) DynamicQuantCPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel: " << parameter->name_ << "is nullptr.";
free(parameter);
return nullptr;
}
if (inputs.size() != 1) {
MS_LOG(ERROR) << "inputs number should be 1, but " << inputs.size() << " is given.";
delete kernel;  // the kernel owns parameter once constructed
return nullptr;
}
if (outputs.size() != 1) {
MS_LOG(ERROR) << "outputs number should be 1, but " << outputs.size() << " is given.";
delete kernel;
return nullptr;
}
bool support_dtype =
inputs[0]->data_type() == TypeId::kNumberTypeFloat32 && outputs[0]->data_type() == TypeId::kNumberTypeInt8;
if (!support_dtype) {
MS_LOG(ERROR) << "Unsupported data type input:" << inputs.front()->data_type()
<< ", output:" << outputs.front()->data_type();
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DynamicQuant, DynamicQuantCPUCreator)
} // namespace mindspore::kernel

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DYNAMIC_QUANT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DYNAMIC_QUANT_H_
#include <vector>
#include <cfloat>
#include "src/inner_kernel.h"
namespace mindspore::kernel {
class DynamicQuantCPUKernel : public InnerKernel {
public:
DynamicQuantCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx), thread_num_(ctx->thread_num_) {}
~DynamicQuantCPUKernel() override = default;
int Prepare() override;
int ReSize() override;
int Run() override;
int QuantData(int task_id);
int CalculateMinMax(int task_id);
private:
void ReduceMinMaxFp32();
void CalculateScaleZp();
private:
int thread_num_;
int thread_n_num_{0};
int thread_n_stride_{0};
int num_unit_{0};
int8_t *int8_ptr_ = nullptr;
float *float32_ptr_ = nullptr;
float real_min_array_[8];  // one min/max slot per bucket (kBucketNums == 8)
float real_max_array_[8];
float real_min_ = FLT_MAX;  // seeded so ReduceMinMaxFp32 always tightens them
float real_max_ = -FLT_MAX;
int32_t src_dtype_{0};
int32_t dst_dtype_{0};
bool symmetric_ = false;
};
} // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DYNAMIC_QUANT_H_

View File

@@ -73,7 +73,7 @@ int QuantParamParser::ParseBitNum(const CommonQuantString &common_quant_string,
MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [1,8].";
return RET_INPUT_PARAM_INVALID;
}
} else if (common_quant->quant_type == schema::QuantType_QUANT_DANAMIC) {
} else if (common_quant->quant_type == schema::QuantType_QUANT_DYNAMIC) {
if (common_quant->bit_num != kQuantBitNumInt8) {
MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be 8.";
return RET_INPUT_PARAM_INVALID;
@@ -161,7 +161,7 @@ int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::
(*quant_type) = schema::QuantType_QUANT_ALL;
return RET_OK;
} else if (quant_type_str == "DYNAMIC_QUANT") {
(*quant_type) = schema::QuantType_QUANT_DANAMIC;
(*quant_type) = schema::QuantType_QUANT_DYNAMIC;
return RET_OK;
} else if (quant_type_str.empty()) {
(*quant_type) = schema::QuantType_QUANT_NONE;

View File

@@ -36,7 +36,7 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
const std::set<PrimitivePtr> support_dynamic_quant_ops = {
prim::kPrimMatMulFusion,
};
ret = manager.InsertDynamicQuantPass(func_graph, support_dynamic_quant_ops);
ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Insert dynamic quant failed.";
return ret;

View File

@@ -619,7 +619,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
if (activation_target_data_type_ == kNumberTypeInt8 || activation_target_data_type_ == kNumberTypeUInt8) {
// add quant_cast
quant::InsertQuantNodeManager inset_quant_node_pass;
status = inset_quant_node_pass.InsertQuantDtypeCastPass(func_graph);
status = inset_quant_node_pass.InsertQuantDtypeCastNode(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@@ -132,7 +132,7 @@ int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId c
return RET_OK;
}
int InsertQuantNodeManager::InsertQuantDtypeCastPass(const FuncGraphPtr &graph) {
int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto cnodes = graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
@@ -194,11 +194,11 @@ int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
return RET_ERROR;
}
auto quant_param_holder = GetCNodeQuantHolder(primitive);
quant_param_holder->set_quant_type(schema::QuantType_QUANT_DANAMIC);
quant_param_holder->set_quant_type(schema::QuantType_QUANT_DYNAMIC);
return RET_OK;
}
int InsertQuantNodeManager::InsertDynamicQuantPass(const FuncGraphPtr &graph,
int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
const std::set<PrimitivePtr> &support_dynamic_quant_ops) {
MS_ASSERT(graph != nullptr);
auto cnodes = graph->GetOrderedCnodes();

View File

@@ -31,9 +31,9 @@ class InsertQuantNodeManager {
~InsertQuantNodeManager() = default;
int InsertQuantDtypeCastPass(const FuncGraphPtr &graph);
int InsertQuantDtypeCastNode(const FuncGraphPtr &graph);
int InsertDynamicQuantPass(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops);
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops);
private:
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params);

View File

@@ -158,31 +158,30 @@ int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config,
}
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
// quant
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) {
return RET_OK;
}
int status;
SessionModel origin;
if (config->commonQuantParam.is_debug) {
if (config->commonQuantParam.is_debug) { // Back up the fp32 model for debugging
converter::Flags new_flag = *config;
new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
origin = CreateSessionByFuncGraph(old_graph, new_flag, config->commonQuantParam.thread_num);
}
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) { // Full Quantization
status = DoFullQuant(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do full quant failed.";
return status;
}
} else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
} else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) { // Weight Quantization
status = DoWeightQuant(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do weight quant failed.";
return status;
}
} else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_DANAMIC) {
} else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) { // Dynamic Quantization
status = DoDynamicQuant(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do dynamic quant failed.";
@@ -202,6 +201,7 @@ int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags
int QuantizationOptimizer::Run(const mindspore::FuncGraphPtr &func_graph) {
std::set<FuncGraphPtr> all_func_graphs{};
GetFuncGraphs(func_graph, &all_func_graphs);
// Support for multi-subgraph models
for (auto &item : all_func_graphs) {
auto status = DoSingleGraphQuantize(item, flags_);
if (status != RET_OK) {