!29154 add DynamicQuant kernel
Merge pull request !29154 from yeyunpeng2020/dynamic_quant
commit faf7e26381
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl/int8/dynamic_quant_int8.h"
+
+void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max) {
+  for (int i = 0; i < count; ++i) {
+    if (data[i] < *real_min) {
+      *real_min = data[i];
+    }
+    if (data[i] > *real_max) {
+      *real_max = data[i];
+    }
+  }
+}
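For context: CalculateMinMaxFp32 only tightens the bounds it is given, so the caller must seed *real_min/*real_max before the first call (the kernel below does this in ReSize). A minimal standalone sketch, illustrative rather than part of the patch:

    #include <cfloat>
    #include "nnacl/int8/dynamic_quant_int8.h"

    float real_min = FLT_MAX;   // seed high so any element replaces it
    float real_max = -FLT_MAX;  // note: -FLT_MAX, not FLT_MIN (the smallest positive float)
    const float data[] = {0.25f, -1.5f, 3.0f};
    CalculateMinMaxFp32(data, 3, &real_min, &real_max);
    // real_min == -1.5f, real_max == 3.0f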
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_INT8_DYNAMIC_QUANT_INT8_H_
+#define MINDSPORE_NNACL_INT8_DYNAMIC_QUANT_INT8_H_
+
+#include "nnacl/op_base.h"
+#include "nnacl/power_parameter.h"
+#include "nnacl/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_INT8_DYNAMIC_QUANT_INT8_H_
@@ -77,7 +77,7 @@ enum QuantType: int {
     PostTraining, // deprecated, use QUANT_ALL instead
     QUANT_WEIGHT,
     QUANT_ALL,
-    QUANT_DANAMIC,
+    QUANT_DYNAMIC,
 }
 
 table Primitive {
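Note that the fix replaces the misspelled enumerator in place rather than appending a new one, so QUANT_DYNAMIC keeps the ordinal slot that QUANT_DANAMIC occupied and already-serialized flatbuffer models stay compatible; only source-level references need updating, which is what the remaining hunks below do.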
@@ -0,0 +1,209 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/arm/int8/dynamic_quant.h"
+#include <vector>
+#include "src/kernel_registry.h"
+#include "schema/model_generated.h"
+#include "include/errorcode.h"
+#include "nnacl/dynamic_quant_parameter.h"
+#include "nnacl/int8/dynamic_quant_int8.h"
+#include "nnacl/int8/quant_dtype_cast_int8.h"
+
+using mindspore::kernel::KERNEL_ARCH;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NULL_PTR;
+using mindspore::lite::RET_OK;
+using mindspore::lite::RET_PARAM_INVALID;
+using mindspore::schema::PrimitiveType_DynamicQuant;
+namespace mindspore::kernel {
+namespace {
+constexpr int kBucketNums = 8;
+constexpr int kMinNums = 512;
+}  // namespace
+int DynamicQuantCPUKernel::Prepare() {
+  auto in_tensor = in_tensors_.front();
+  CHECK_NULL_RETURN(in_tensor);
+  auto out_tensor = out_tensors_.front();
+  CHECK_NULL_RETURN(out_tensor);
+  auto param = reinterpret_cast<DynamicQuantParameter *>(op_parameter_);
+  CHECK_NULL_RETURN(param);
+  src_dtype_ = in_tensor->data_type();
+  dst_dtype_ = param->dst_type_;
+  symmetric_ = param->symmetric_;
+  if (out_tensor->data_type() != dst_dtype_) {
+    MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+    return RET_ERROR;
+  }
+
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int DynamicQuantCPUKernel::ReSize() {
+  auto in_tensor = in_tensors_.front();
+  num_unit_ = static_cast<int>(in_tensor->ElementsNum());
+  if (num_unit_ < kMinNums) {
+    thread_n_num_ = 1;
+  } else {
+    thread_n_num_ = MSMIN(thread_num_, num_unit_);
+    // Limit to at most kBucketNums (8) threads.
+    thread_n_num_ = MSMIN(thread_n_num_, kBucketNums);
+  }
+  for (int i = 0; i < kBucketNums; ++i) {
+    real_min_array_[i] = FLT_MAX;
+    real_max_array_[i] = -FLT_MAX;  // lowest float; FLT_MIN is the smallest positive value
+  }
+  MS_CHECK_GT(thread_n_num_, 0, RET_ERROR);
+  thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
+  return RET_OK;
+}
+
+int DynamicQuantCPUKernel::CalculateMinMax(int task_id) {
+  int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
+  if (num_unit_thread <= 0) {
+    return RET_OK;
+  }
+  int thread_offset = task_id * thread_n_stride_;
+  float *data = float32_ptr_ + thread_offset;
+
+  CalculateMinMaxFp32(data, num_unit_thread, &real_min_array_[task_id], &real_max_array_[task_id]);
+  return RET_OK;
+}
+
+int CalculateMinMaxRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
+  CHECK_NULL_RETURN(cdata);
+  auto g_kernel = reinterpret_cast<DynamicQuantCPUKernel *>(cdata);
+  auto ret = g_kernel->CalculateMinMax(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "CalculateMinMaxRun error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+void DynamicQuantCPUKernel::ReduceMinMaxFp32() {
+  for (int i = 0; i < kBucketNums; i++) {
+    if (real_min_array_[i] < real_min_) {
+      real_min_ = real_min_array_[i];
+    }
+    if (real_max_array_[i] > real_max_) {
+      real_max_ = real_max_array_[i];
+    }
+  }
+}
+
+void DynamicQuantCPUKernel::CalculateScaleZp() {
+  double scale = (real_max_ - real_min_) / (INT8_MAX - INT8_MIN);
+  int zp = 0;
+  if (!symmetric_) {
+    zp = std::round(INT8_MIN - real_min_ / scale);
+  }
+  this->out_tensors_.front()->quant_params().front().scale = scale;
+  this->out_tensors_.front()->quant_params().front().zeroPoint = zp;
+}
+
+int DynamicQuantCPUKernel::QuantData(int task_id) {
+  int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
+  if (num_unit_thread <= 0) {
+    return RET_OK;
+  }
+  int thread_offset = task_id * thread_n_stride_;
+  auto quant_arg = out_tensors_.front()->quant_params().front();
+  int ret;
+  TypeId data_type = out_tensors_.front()->data_type();
+  if (data_type == TypeId::kNumberTypeInt8) {
+    ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
+                               quant_arg.zeroPoint, num_unit_thread, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
+  } else {
+    MS_LOG(ERROR) << "Data type not supported: " << data_type;
+    return RET_PARAM_INVALID;
+  }
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int QuantDataRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
+  CHECK_NULL_RETURN(cdata);
+  auto g_kernel = reinterpret_cast<DynamicQuantCPUKernel *>(cdata);
+  auto ret = g_kernel->QuantData(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "QuantDataRun error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int DynamicQuantCPUKernel::Run() {
+  int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data());
+  float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data());
+  CHECK_NULL_RETURN(int8_ptr_);
+  CHECK_NULL_RETURN(float32_ptr_);
+  auto ret = ParallelLaunch(this->ms_context_, CalculateMinMaxRun, this, thread_n_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Run error, error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  ReduceMinMaxFp32();
+  CalculateScaleZp();
+  ret = ParallelLaunch(this->ms_context_, QuantDataRun, this, thread_n_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Run error, error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+kernel::InnerKernel *DynamicQuantCPUCreator(const std::vector<lite::Tensor *> &inputs,
+                                            const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
+                                            const lite::Context *ctx, const kernel::KernelKey &desc) {
+  if (parameter == nullptr) {
+    MS_LOG(ERROR) << "parameter is nullptr.";
+    return nullptr;
+  }
+  // Validate inputs/outputs before allocating the kernel so early returns cannot leak it.
+  if (inputs.size() != 1) {
+    MS_LOG(ERROR) << "inputs number should be 1, but " << inputs.size() << " is given.";
+    return nullptr;
+  }
+  if (outputs.size() != 1) {
+    MS_LOG(ERROR) << "outputs number should be 1, but " << outputs.size() << " is given.";
+    return nullptr;
+  }
+  bool support_dtype =
+    inputs[0]->data_type() == TypeId::kNumberTypeFloat32 && outputs[0]->data_type() == TypeId::kNumberTypeInt8;
+  if (!support_dtype) {
+    MS_LOG(ERROR) << "Unsupported data type input:" << inputs.front()->data_type()
+                  << ", output:" << outputs.front()->data_type();
+    return nullptr;
+  }
+  auto *kernel =
+    new (std::nothrow) DynamicQuantCPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "kernel: " << parameter->name_ << " is nullptr.";
+    free(parameter);
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DynamicQuant, DynamicQuantCPUCreator)
+}  // namespace mindspore::kernel
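CalculateScaleZp above implements the standard asymmetric (affine) min/max scheme: scale = (real_max - real_min) / 255, and for the asymmetric case zp = round(INT8_MIN - real_min / scale), so real 0 lands near an integer grid point. A small standalone check of the arithmetic, illustrative only and not part of the patch:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
      // Suppose the range reduced over all buckets is [-0.5, 3.5].
      double real_min = -0.5, real_max = 3.5;
      double scale = (real_max - real_min) / (INT8_MAX - INT8_MIN);        // 4.0 / 255 ~= 0.01569
      int zp = static_cast<int>(std::round(INT8_MIN - real_min / scale));  // round(-128 + 31.875) == -96
      printf("scale=%f zp=%d\n", scale, zp);                               // quantize: q = round(x / scale) + zp
      return 0;
    }

Note the division assumes real_max_ > real_min_; a constant input tensor would make scale zero, which the kernel does not guard against here.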
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DYNAMIC_QUANT_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DYNAMIC_QUANT_H_
+
+#include <vector>
+#include <cfloat>
+#include "src/inner_kernel.h"
+
+namespace mindspore::kernel {
+class DynamicQuantCPUKernel : public InnerKernel {
+ public:
+  DynamicQuantCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                        const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+      : InnerKernel(parameter, inputs, outputs, ctx), thread_num_(ctx->thread_num_) {}
+  ~DynamicQuantCPUKernel() override = default;
+
+  int Prepare() override;
+  int ReSize() override;
+  int Run() override;
+
+  int QuantData(int task_id);
+  int CalculateMinMax(int task_id);
+
+ private:
+  void ReduceMinMaxFp32();
+  void CalculateScaleZp();
+
+ private:
+  int thread_num_;
+  int thread_n_num_{0};
+  int thread_n_stride_{0};
+  int num_unit_{0};
+  int8_t *int8_ptr_ = nullptr;
+  float *float32_ptr_ = nullptr;
+
+  float real_min_array_[8];
+  float real_max_array_[8];
+  float real_min_ = FLT_MAX;   // seeded so ReduceMinMaxFp32 reads defined values
+  float real_max_ = -FLT_MAX;
+  int32_t src_dtype_{0};
+  int32_t dst_dtype_{0};
+  bool symmetric_ = false;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DYNAMIC_QUANT_H_
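ReSize splits num_unit_ elements into at most kBucketNums (8) contiguous stripes, one per task, so each task owns its own slot in real_min_array_/real_max_array_ and no locking is needed; tensors smaller than kMinNums (512) stay single-threaded. A quick illustration of the striding arithmetic, assuming nnacl's UP_DIV is ceiling division:

    #include <algorithm>

    int num_unit = 1002, thread_n_num = 8;
    int stride = (num_unit + thread_n_num - 1) / thread_n_num;    // UP_DIV: 126
    for (int task_id = 0; task_id < thread_n_num; ++task_id) {
      int count = std::min(stride, num_unit - task_id * stride);  // tasks 0..6 get 126, task 7 gets 120
      if (count <= 0) break;                                      // trailing tasks may get nothing
      // this task scans [task_id * stride, task_id * stride + count)
    }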
@@ -73,7 +73,7 @@ int QuantParamParser::ParseBitNum(const CommonQuantString &common_quant_string,
       MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [1,8].";
       return RET_INPUT_PARAM_INVALID;
     }
-  } else if (common_quant->quant_type == schema::QuantType_QUANT_DANAMIC) {
+  } else if (common_quant->quant_type == schema::QuantType_QUANT_DYNAMIC) {
     if (common_quant->bit_num != kQuantBitNumInt8) {
       MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be 8.";
       return RET_INPUT_PARAM_INVALID;
@@ -161,7 +161,7 @@ int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::
     (*quant_type) = schema::QuantType_QUANT_ALL;
     return RET_OK;
   } else if (quant_type_str == "DYNAMIC_QUANT") {
-    (*quant_type) = schema::QuantType_QUANT_DANAMIC;
+    (*quant_type) = schema::QuantType_QUANT_DYNAMIC;
     return RET_OK;
   } else if (quant_type_str.empty()) {
     (*quant_type) = schema::QuantType_QUANT_NONE;
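For orientation: ParseQuantType maps the user-facing string to the schema enum, so dynamic quantization is requested by setting the converter's quant_type option to DYNAMIC_QUANT, with bit_num effectively fixed at 8 by the check in the previous hunk. The exact configuration-file spelling is assumed here from the parser's field names and is not confirmed by this diff.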
@@ -36,7 +36,7 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   const std::set<PrimitivePtr> support_dynamic_quant_ops = {
     prim::kPrimMatMulFusion,
   };
-  ret = manager.InsertDynamicQuantPass(func_graph, support_dynamic_quant_ops);
+  ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Insert dynamic quant failed.";
     return ret;
@@ -619,7 +619,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   if (activation_target_data_type_ == kNumberTypeInt8 || activation_target_data_type_ == kNumberTypeUInt8) {
     // add quant_cast
     quant::InsertQuantNodeManager inset_quant_node_pass;
-    status = inset_quant_node_pass.InsertQuantDtypeCastPass(func_graph);
+    status = inset_quant_node_pass.InsertQuantDtypeCastNode(func_graph);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "add QuantCast error";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -132,7 +132,7 @@ int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId c
   return RET_OK;
 }
 
-int InsertQuantNodeManager::InsertQuantDtypeCastPass(const FuncGraphPtr &graph) {
+int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph) {
   MS_ASSERT(graph != nullptr);
   auto cnodes = graph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
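The renames in this and the following hunks (InsertQuantDtypeCastPass to InsertQuantDtypeCastNode, InsertDynamicQuantPass to InsertDynamicQuantNode) track what these methods actually do: they insert cast/quant nodes into the graph rather than register an optimizer pass.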
@@ -194,11 +194,11 @@ int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
     return RET_ERROR;
   }
   auto quant_param_holder = GetCNodeQuantHolder(primitive);
-  quant_param_holder->set_quant_type(schema::QuantType_QUANT_DANAMIC);
+  quant_param_holder->set_quant_type(schema::QuantType_QUANT_DYNAMIC);
   return RET_OK;
 }
 
-int InsertQuantNodeManager::InsertDynamicQuantPass(const FuncGraphPtr &graph,
+int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
                                                    const std::set<PrimitivePtr> &support_dynamic_quant_ops) {
   MS_ASSERT(graph != nullptr);
   auto cnodes = graph->GetOrderedCnodes();
@@ -31,9 +31,9 @@ class InsertQuantNodeManager {
 
   ~InsertQuantNodeManager() = default;
 
-  int InsertQuantDtypeCastPass(const FuncGraphPtr &graph);
+  int InsertQuantDtypeCastNode(const FuncGraphPtr &graph);
 
-  int InsertDynamicQuantPass(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops);
+  int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops);
 
  private:
   ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params);
@@ -158,31 +158,30 @@ int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config,
 }
 
 int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
-  // quant
   if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) {
     return RET_OK;
   }
   int status;
 
   SessionModel origin;
-  if (config->commonQuantParam.is_debug) {
+  if (config->commonQuantParam.is_debug) {  // Back up the fp32 model for debugging
     converter::Flags new_flag = *config;
     new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
     origin = CreateSessionByFuncGraph(old_graph, new_flag, config->commonQuantParam.thread_num);
   }
-  if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
+  if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {  // Full Quantization
     status = DoFullQuant(old_graph, config);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Do full quant failed.";
       return status;
     }
-  } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
+  } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {  // Weight Quantization
     status = DoWeightQuant(old_graph, config);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Do weight quant failed.";
       return status;
     }
-  } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_DANAMIC) {
+  } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) {  // Dynamic Quantization
     status = DoDynamicQuant(old_graph, config);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Do dynamic quant failed.";
@@ -202,6 +201,7 @@ int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags
 int QuantizationOptimizer::Run(const mindspore::FuncGraphPtr &func_graph) {
   std::set<FuncGraphPtr> all_func_graphs{};
   GetFuncGraphs(func_graph, &all_func_graphs);
+  // Support for multi-subgraph models
   for (auto &item : all_func_graphs) {
     auto status = DoSingleGraphQuantize(item, flags_);
     if (status != RET_OK) {