forked from mindspore-Ecosystem/mindspore
!14180 [MS][LITE] Support mindir aware quant model
From: @cjh9368 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
48a5748354
|
@ -123,8 +123,10 @@ enum NNACLQuantType {
|
|||
QuantType_AwareTraining = 1,
|
||||
QuantType_WeightQuant = 2,
|
||||
QuantType_PostTraining = 3,
|
||||
QuantType_QUANT_WEIGHT = 4,
|
||||
QuantType_QUANT_ALL = 5,
|
||||
QuantType_MIN = QuantType_QUANT_NONE,
|
||||
QuantType_MAX = QuantType_PostTraining
|
||||
QuantType_MAX = QuantType_QUANT_ALL
|
||||
};
|
||||
|
||||
typedef struct vvector {
|
||||
|
|
|
@ -26,7 +26,7 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
|||
const TensorC *indices = inputs[1];
|
||||
TensorC *output = outputs[0];
|
||||
output->data_type_ = input->data_type_;
|
||||
if (parameter->quant_type_ == QuantType_WeightQuant) {
|
||||
if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
|
||||
output->data_type_ = kNumberTypeFloat32;
|
||||
}
|
||||
output->format_ = input->format_;
|
||||
|
|
|
@ -56,9 +56,11 @@ table Tensor {
|
|||
|
||||
enum QuantType: int {
|
||||
QUANT_NONE,
|
||||
AwareTraining,
|
||||
WeightQuant,
|
||||
PostTraining
|
||||
AwareTraining, // deprecated, use QUANT_ALL instead
|
||||
WeightQuant, // deprecated, use QUANT_WEIGHT instead
|
||||
PostTraining, // deprecated, use QUANT_ALL instead
|
||||
QUANT_WEIGHT,
|
||||
QUANT_ALL
|
||||
}
|
||||
|
||||
table Primitive {
|
||||
|
|
|
@ -65,6 +65,11 @@ class LiteModel : public Model {
|
|||
auto c_node = meta_graph.nodes()->template GetAs<U>(i);
|
||||
node->primitive_ = c_node->primitive();
|
||||
node->quant_type_ = c_node->quantType();
|
||||
if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) {
|
||||
node->quant_type_ = schema::QuantType_QUANT_ALL;
|
||||
} else if (node->quant_type_ == schema::QuantType_WeightQuant) {
|
||||
node->quant_type_ = schema::QuantType_QUANT_WEIGHT;
|
||||
}
|
||||
node->name_ = c_node->name()->c_str();
|
||||
auto count = c_node->inputIndex()->size();
|
||||
for (uint32_t j = 0; j < count; ++j) {
|
||||
|
|
|
@ -430,7 +430,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
|||
MS_ASSERT(node != nullptr);
|
||||
// why we need this
|
||||
TypeId data_type =
|
||||
(node->quant_type_ == schema::QuantType_WeightQuant) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors);
|
||||
(node->quant_type_ == schema::QuantType_QUANT_WEIGHT) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors);
|
||||
OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
|
||||
if (op_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not find OpParameter!type: " << PrimitiveTypeName(GetPrimitiveType(node->primitive_));
|
||||
|
|
|
@ -669,6 +669,9 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano
|
|||
MS_LOG(ERROR) << "schema tensor format is wrong, " << schema_tensor->format;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (primitive_c->GetAttr(opt::kWeightFormat) != nullptr) {
|
||||
schema_tensor->format = static_cast<schema::Format>(GetValue<int64_t>(primitive_c->GetAttr(opt::kWeightFormat)));
|
||||
}
|
||||
schema_tensor->name = param_node->name();
|
||||
ShapeVector shape_vector;
|
||||
TypeId data_type;
|
||||
|
@ -704,9 +707,6 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano
|
|||
}
|
||||
}
|
||||
}
|
||||
if (primitive_c->GetAttr(opt::kWeightFormat) != nullptr) {
|
||||
schema_tensor->format = static_cast<schema::Format>(GetValue<int64_t>(primitive_c->GetAttr(opt::kWeightFormat)));
|
||||
}
|
||||
schema_tensor->name = input_name;
|
||||
QuantParamHolderPtr quant_param_holder = primitive_c->GetAttr("quant_params") == nullptr
|
||||
? nullptr
|
||||
|
@ -722,9 +722,8 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano
|
|||
}
|
||||
|
||||
int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
|
||||
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
|
||||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
|
||||
int ret;
|
||||
const std::shared_ptr<Value> &value, const std::shared_ptr<PrimitiveC> &primitive,
|
||||
schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
|
||||
auto valueAbstract = value_node->abstract();
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
|
||||
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
|
||||
|
@ -742,15 +741,30 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<s
|
|||
(*schema_tensor)->nodeType = NodeType_ValueNode;
|
||||
auto data = value->cast<tensor::TensorPtr>();
|
||||
(*schema_tensor)->data.resize(data->Size());
|
||||
ret = memcpy_s((*schema_tensor)->data.data(), data->Size(), data->data_c(), data->Size());
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s error.";
|
||||
(*schema_tensor)->format = schema::Format_NHWC;
|
||||
|
||||
(*schema_tensor)->format = GetFormatByFmk(meta_graphT->fmkType);
|
||||
if ((*schema_tensor)->format != schema::Format_NHWC && (*schema_tensor)->format != schema::Format_NCHW) {
|
||||
MS_LOG(ERROR) << "schema tensor format is wrong, " << (*schema_tensor)->format;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// process weight tensor
|
||||
if (data->Size() > 0) {
|
||||
if (memcpy_s((*schema_tensor)->data.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (primitive->GetAttr(opt::kWeightFormat) != nullptr) {
|
||||
(*schema_tensor)->format = static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(opt::kWeightFormat)));
|
||||
}
|
||||
}
|
||||
|
||||
node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
|
||||
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
|
||||
meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
|
||||
return ret;
|
||||
return RET_OK;
|
||||
}
|
||||
int AnfExporter::ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
|
||||
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
|
||||
|
@ -877,6 +891,7 @@ int AnfExporter::ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_p
|
|||
}
|
||||
|
||||
int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
|
||||
const std::shared_ptr<PrimitiveC> &primitive,
|
||||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
schema::CNodeT *output_cnode) {
|
||||
auto value_node = input_anode->cast<ValueNodePtr>();
|
||||
|
@ -888,7 +903,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
|
|||
schema_tensor->name = value_node->fullname_with_scope();
|
||||
}
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
ret = ProcessTensor(value_node, &schema_tensor, value, output_cnode, meta_graphT);
|
||||
ret = ProcessTensor(value_node, &schema_tensor, value, primitive, output_cnode, meta_graphT);
|
||||
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
|
||||
ret = ProcessInt32OrInt64Imm(value_node, &schema_tensor, value, output_cnode, meta_graphT);
|
||||
} else if (value->isa<mindspore::BoolImm>()) {
|
||||
|
@ -945,7 +960,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
|
|||
is_graph_input = true;
|
||||
}
|
||||
} else if (input_node->isa<ValueNode>()) {
|
||||
auto ret = ConvertInputValueNode(input_node, meta_graphT, fb_node);
|
||||
auto ret = ConvertInputValueNode(input_node, primitive_c, meta_graphT, fb_node);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvertInputValueNode failed";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -51,11 +51,11 @@ class AnfExporter {
|
|||
int ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode);
|
||||
int ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode, const std::shared_ptr<PrimitiveC> &primitive,
|
||||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
|
||||
int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
|
||||
int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode, const std::shared_ptr<PrimitiveC> &primitive,
|
||||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
|
||||
int ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
|
||||
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
|
||||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
|
||||
const std::shared_ptr<Value> &value, const std::shared_ptr<PrimitiveC> &primitive,
|
||||
schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
|
||||
int ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
|
||||
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
|
||||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
|
||||
|
|
|
@ -16,14 +16,19 @@
|
|||
|
||||
#include "tools/common/node_util.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "src/common/common.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "src/runtime/infer_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr size_t kInitialSize = 1024;
|
||||
|
||||
static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveType_Conv2DBackpropFilterFusion,
|
||||
schema::PrimitiveType_Conv2DBackpropInputFusion,
|
||||
schema::PrimitiveType_AvgPoolGrad,
|
||||
|
@ -350,6 +355,33 @@ static int Convert2CKHW(int srcFormat) {
|
|||
return -1;
|
||||
}
|
||||
|
||||
STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs) {
|
||||
flatbuffers::FlatBufferBuilder fbb(kInitialSize);
|
||||
auto prim = ConvertToPrimitive(node.primitive.get(), &fbb);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "get primitive failed.";
|
||||
fbb.Clear();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR);
|
||||
if (parameter_gen == nullptr) {
|
||||
fbb.Clear();
|
||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto parameter = parameter_gen(prim);
|
||||
if (parameter == nullptr) {
|
||||
fbb.Clear();
|
||||
MS_LOG(ERROR) << "parameter is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
parameter->infer_flag_ = true;
|
||||
auto ret = KernelInferShape(inputs, outputs, parameter);
|
||||
fbb.Clear();
|
||||
free(parameter);
|
||||
return ret;
|
||||
}
|
||||
|
||||
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
|
|
|
@ -44,6 +44,8 @@ int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema:
|
|||
using STATUS = int;
|
||||
STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
|
||||
|
||||
STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs);
|
||||
|
||||
inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) { return cNodeT.primitive->value.type; }
|
||||
|
||||
inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) {
|
||||
|
|
|
@ -205,13 +205,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer quantNodeOptimizer;
|
||||
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
|
||||
if (dTypeTransPass == nullptr) {
|
||||
MS_LOG(ERROR) << "new dTypeTransPass failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
|
||||
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
|
||||
|
@ -221,6 +214,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
return status;
|
||||
}
|
||||
auto old_nodes2 = GetGraphNodes();
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) InferQuantParamPass());
|
||||
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
|
||||
if (dTypeTransPass == nullptr) {
|
||||
MS_LOG(ERROR) << "new dTypeTransPass failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
|
||||
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
|
||||
quantNodeOptimizer.AddPass(dTypeTransPass);
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "src/common/common.h"
|
||||
|
@ -36,17 +38,18 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
|
|||
return status;
|
||||
}
|
||||
|
||||
status = DoNodeInoutDTypeTrans(graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
status = DoModelOutputDTypeTrans(graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
status = DoNodeInoutDTypeTrans(graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
|
||||
return status;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -128,48 +131,103 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
// insert transNode before and after existNode
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) || (*iter)->quantType != QuantType_AwareTraining) {
|
||||
continue;
|
||||
}
|
||||
auto nodeName = (*iter)->name;
|
||||
if ((*iter)->inputIndex.empty()) {
|
||||
MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
|
||||
return RET_ERROR;
|
||||
}
|
||||
STATUS status;
|
||||
// insert pre
|
||||
for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
|
||||
MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
|
||||
auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
|
||||
if (preTensor->dataType != TypeId::kNumberTypeInt8) {
|
||||
continue;
|
||||
}
|
||||
iter = InsertDTypeTransNode(graph, iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
|
||||
STATUS DTypeTransPass::InsetDTypeTransNodeForWrongDtypeQuantOp(schema::MetaGraphT *graph, NodeIter *iter) {
|
||||
auto node_name = (**iter)->name;
|
||||
auto status = RET_OK;
|
||||
// insert fp32 to int8 before
|
||||
for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
|
||||
auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
|
||||
if (pre_tensor->dataType == kNumberTypeFloat32 && !pre_tensor->quantParams.empty() &&
|
||||
pre_tensor->quantParams.front()->inited) {
|
||||
*iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeFloat32, kNumberTypeInt8, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
|
||||
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
// insert post
|
||||
for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
|
||||
auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i));
|
||||
if (postTensor->dataType != TypeId::kNumberTypeInt8) {
|
||||
continue;
|
||||
}
|
||||
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeUInt8, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
(*iter)->quantType = QuantType_QUANT_NONE;
|
||||
}
|
||||
|
||||
// insert int8 to fp32 after
|
||||
for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
|
||||
auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
|
||||
if (post_tensor->dataType == kNumberTypeFloat32 && !post_tensor->quantParams.empty() &&
|
||||
post_tensor->quantParams.front()->inited) {
|
||||
*iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS DTypeTransPass::InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter) {
|
||||
auto node_name = (**iter)->name;
|
||||
auto status = RET_OK;
|
||||
// insert int8 to fp32 before
|
||||
for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
|
||||
auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
|
||||
if (pre_tensor->dataType == kNumberTypeInt8 && !pre_tensor->quantParams.empty() &&
|
||||
pre_tensor->quantParams.front()->inited) {
|
||||
*iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// insert fp32 to int8 after
|
||||
for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
|
||||
auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
|
||||
if (post_tensor->dataType == kNumberTypeInt8 && !post_tensor->quantParams.empty() &&
|
||||
post_tensor->quantParams.front()->inited) {
|
||||
*iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto node_name = (*iter)->name;
|
||||
if ((*iter)->inputIndex.empty()) {
|
||||
MS_LOG(ERROR) << "Op " << node_name.c_str() << " should have " << kMinInputNum << " input tensor at least";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if ((*iter)->primitive->value.type == schema::PrimitiveType_QuantDTypeCast ||
|
||||
(*iter)->primitive->value.type == schema::PrimitiveType_Cast) {
|
||||
continue;
|
||||
}
|
||||
|
||||
STATUS status = RET_OK;
|
||||
// quant_type is quant_all, but inputs/outputs are float32
|
||||
if ((*iter)->quantType == QuantType_QUANT_ALL) {
|
||||
status = InsetDTypeTransNodeForWrongDtypeQuantOp(graph, &iter);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
|
||||
return status;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// quant_type is quant_none, but inputs/outputs have quant params and dtype is int8, which means this int8 op is not
|
||||
// supported yet
|
||||
if ((*iter)->quantType == QuantType_QUANT_NONE) {
|
||||
status = InsetDTypeTransNodeForUnsupportedInt8Op(graph, &iter);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -46,6 +46,11 @@ class DTypeTransPass : public GraphPass {
|
|||
STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph);
|
||||
|
||||
STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);
|
||||
|
||||
STATUS InsetDTypeTransNodeForWrongDtypeQuantOp(schema::MetaGraphT *graph, NodeIter *iter);
|
||||
|
||||
STATUS InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter);
|
||||
|
||||
NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
|
||||
int32_t inputDataType, int32_t outputDataType, STATUS *errorCode);
|
||||
|
||||
|
|
|
@ -18,75 +18,32 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/converter/quantizer/calc_quant_param.h"
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
#include "tools/common/node_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto *quantParamRegister = QuantParamCalcRegister::GetInstance();
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "graph is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto &node = *iter;
|
||||
MS_ASSERT(node != nullptr);
|
||||
if (node->quantType == schema::QuantType_WeightQuant) {
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "node is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
// make sure reenter node will keep the node status
|
||||
if (node->quantType != schema::QuantType_QUANT_NONE) {
|
||||
continue;
|
||||
}
|
||||
DetermineNodeQuantType(*graph, node.get());
|
||||
if (node->quantType == schema::QuantType_AwareTraining) {
|
||||
continue;
|
||||
}
|
||||
if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
|
||||
MS_ASSERT(false);
|
||||
}
|
||||
auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
|
||||
if (quantParamCalcer == nullptr) {
|
||||
MS_LOG(DEBUG) << "Can not find QuantParamCalcer for " << node->name.c_str()
|
||||
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
|
||||
node->quantType = schema::QuantType_QUANT_NONE;
|
||||
} else {
|
||||
auto status = quantParamCalcer->Calc(graph, *node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(DEBUG) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
|
||||
node->quantType = schema::QuantType_QUANT_NONE;
|
||||
} else {
|
||||
node->quantType = schema::QuantType_AwareTraining;
|
||||
}
|
||||
}
|
||||
|
||||
auto quant_helper = QuantHelperRegister::GetInstance()->GetQuantHelper(node->primitive->value.type);
|
||||
|
||||
quant_helper->NodeQuantPreprocess(graph, node.get());
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
bool canQuant = true;
|
||||
for (auto &inputTensorIdx : cnode->inputIndex) {
|
||||
MS_ASSERT(graph.allTensors.size() > inputTensorIdx);
|
||||
auto &inTensor = graph.allTensors.at(inputTensorIdx);
|
||||
MS_ASSERT(inTensor != nullptr);
|
||||
if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr ||
|
||||
!inTensor->quantParams.front()->inited) {
|
||||
canQuant = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &outTensorIdx : cnode->outputIndex) {
|
||||
MS_ASSERT(graph.allTensors.size() > outTensorIdx);
|
||||
auto &outTensor = graph.allTensors.at(outTensorIdx);
|
||||
MS_ASSERT(outTensor != nullptr);
|
||||
if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
|
||||
!outTensor->quantParams.front()->inited) {
|
||||
canQuant = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (canQuant) {
|
||||
cnode->quantType = schema::QuantType_AwareTraining;
|
||||
} else {
|
||||
cnode->quantType = schema::QuantType_QUANT_NONE;
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -24,14 +24,11 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
class InferQuantParamPass : public GraphPass {
|
||||
public:
|
||||
InferQuantParamPass() {}
|
||||
InferQuantParamPass() = default;
|
||||
|
||||
~InferQuantParamPass() override = default;
|
||||
|
||||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
private:
|
||||
void DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
namespace {
|
||||
constexpr int DEFAULT_DIM_VALUE = -1;
|
||||
constexpr size_t INITIAL_SIZE = 1024;
|
||||
constexpr size_t kInitialSize = 1024;
|
||||
|
||||
void FreeTensors(std::vector<Tensor *> input_tensors, std::vector<Tensor *> output_tensors) {
|
||||
for (auto &tensor : input_tensors) {
|
||||
|
@ -118,7 +118,7 @@ std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::ve
|
|||
|
||||
STATUS NodeInferShape(const std::unique_ptr<schema::CNodeT> &node, const std::vector<Tensor *> &inputs,
|
||||
std::vector<Tensor *> *outputs) {
|
||||
flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
|
||||
flatbuffers::FlatBufferBuilder fbb(kInitialSize);
|
||||
auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "get primitive failed.";
|
||||
|
|
|
@ -5,6 +5,7 @@ include_directories(${3RD_DIR}/opencv/build/include/opencv4)
|
|||
|
||||
file(GLOB QUANTIZER
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quant_helper/*
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
|
||||
|
|
|
@ -829,7 +829,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|||
<< "'s output quant_params size: " << input_primitive_quant_holder->output_quant_params().size()
|
||||
<< ", but index: " << index;
|
||||
}
|
||||
primitive_quant_holder->set_quant_type(schema::QuantType_PostTraining);
|
||||
primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
|
||||
continue;
|
||||
} else if (op_type == ops::kNameConv2DFusion || op_type == ops::kNameConv2dTransposeFusion ||
|
||||
op_type == ops::kNameFullConnection || op_type == ops::kNameLayerNormFusion) {
|
||||
|
@ -870,7 +870,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|||
output_min_max.min = info->min;
|
||||
|
||||
DoQuantOutput(output_scale, output_zp, &output_min_max, primitive);
|
||||
primitive_quant_holder->set_quant_type(schema::QuantType_PostTraining);
|
||||
primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
|
|
|
@ -26,7 +26,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector
|
|||
auto prim_c = std::make_shared<ops::QuantDTypeCast>();
|
||||
prim_c->Init(src_type, dst_type);
|
||||
auto quant_params_holder = std::make_shared<QuantParamHolder>();
|
||||
quant_params_holder->set_quant_type(schema::QuantType_PostTraining);
|
||||
quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
|
||||
for (auto &quant_param : quant_params) {
|
||||
std::vector<schema::QuantParamT> quant_params_in = {quant_param};
|
||||
quant_params_holder->AddInputQuantParam(quant_params_in);
|
||||
|
@ -80,7 +80,7 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) {
|
|||
|
||||
if (curnode_quant_type != input_cnode_quant_type) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
if (curnode_quant_type == schema::QuantType_PostTraining &&
|
||||
if (curnode_quant_type == schema::QuantType_QUANT_ALL &&
|
||||
input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
|
||||
if (primitive_quant_param_holder->input_quant_params().size() < i) {
|
||||
MS_LOG(ERROR) << "quant param is invalid.";
|
||||
|
@ -89,7 +89,7 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) {
|
|||
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
|
||||
primitive_quant_param_holder->input_quant_params()[i - 1]);
|
||||
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
|
||||
input_cnode_quant_type == schema::QuantType_PostTraining) {
|
||||
input_cnode_quant_type == schema::QuantType_QUANT_ALL) {
|
||||
auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c);
|
||||
value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
|
||||
input_primitive_quant_param_holder->output_quant_params().front());
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h"
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
static constexpr size_t kBiasAddSize = 2;
|
||||
STATUS BiasAddQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
|
||||
const mindspore::schema::CNodeT &node) {
|
||||
if (node.inputIndex.size() == kBiasAddSize) {
|
||||
auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAddSize - 1));
|
||||
for (auto &quantParam : bias_tensor->quantParams) {
|
||||
quantParam->dstDtype = TypeId::kNumberTypeInt32;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
namespace mindspore::lite {
|
||||
// Propagator for BiasAdd nodes: retargets the bias tensor's quant params to
// int32, the accumulator dtype used for quantized addition.
class BiasAddQuantParamPropogator : public QuantParamPropogator {
 public:
  STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h"
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "tools/common/tensor_util.h"
|
||||
namespace mindspore::lite {
|
||||
// Propagates quant params through a data-carrying node (pool/reshape/...):
// if no input has inited params, copy them back from the first output; if no
// output has inited params, copy the first input's params onto every output.
// Returns RET_ERROR when the missing side cannot be derived.
STATUS CarryDataQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) {
  UpdateQuantParamsNum(*graph, node);

  // refresh in_tensor quant_params by out_tensor quant_params
  if (input_inited_quant_params_ < 1) {
    auto &out_tensor = graph->allTensors.at(node.outputIndex.at(0));
    auto out_quant_param = GetTensorQuantParam(out_tensor);
    MS_ASSERT(out_quant_param != nullptr);
    if (out_quant_param == nullptr || !out_quant_param->inited) {
      MS_LOG(DEBUG) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name;
      return RET_ERROR;
    }
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
    auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0));
    MS_ASSERT(in_tensor != nullptr);
    auto in_quant_param = GetTensorQuantParam(in_tensor);
    if (in_quant_param != nullptr && !in_quant_param->inited) {
      in_tensor->quantParams.front() = std::move(out_quant_param);
    }
  }

  // refresh out_tensor quant_params by in_tensor quant_params
  if (output_inited_quant_params_ < 1) {
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
    auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0));
    MS_ASSERT(in_tensor != nullptr);
    auto in_quant_param = GetTensorQuantParam(in_tensor);
    if (in_quant_param == nullptr || !in_quant_param->inited) {
      MS_LOG(DEBUG) << "Can not determine outputTensor quantParam from inputTensor for node %s" << node.name;
      return RET_ERROR;
    }
    for (unsigned int i : node.outputIndex) {
      // fixed: `i` already IS a tensor index; the old assert re-indexed
      // node.outputIndex with it.
      MS_ASSERT(graph->allTensors.size() > i);
      auto &out_tensor = graph->allTensors.at(i);
      MS_ASSERT(out_tensor != nullptr);
      auto out_quant_param = GetTensorQuantParam(out_tensor);
      // fixed: make a fresh copy per output. Moving `in_quant_param` itself
      // (as before) left it null after the first output, so every further
      // output silently received an empty quant param.
      auto param_copy = std::make_unique<QuantParamT>(*in_quant_param);
      if (out_quant_param == nullptr) {
        out_tensor->quantParams.emplace_back(std::move(param_copy));
        continue;
      }
      if (out_quant_param->inited) {
        continue;
      }
      out_tensor->quantParams.front() = std::move(param_copy);
    }
  }
  return RET_OK;
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
namespace mindspore::lite {
|
||||
// Propagator for data-carrying ops (pool/resize/reshape/transpose/...) whose
// output value range equals the input's: copies quant params from output to
// input, or from input to all outputs, whichever side is missing.
class CarryDataQuantParamPropogator : public QuantParamPropogator {
 public:
  STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h"
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "tools/common/tensor_util.h"
|
||||
namespace mindspore::lite {
|
||||
// A data-carrying node is fully quantized when both its first input tensor
// and its first output tensor carry nothing but initialized quant params.
// On success the node is tagged QUANT_ALL.
bool CarryDataQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  MS_ASSERT(node->inputIndex.size() >= 1);
  MS_ASSERT(node->outputIndex.size() >= 1);

  // A tensor qualifies when it has at least one quant param and all are inited.
  auto fully_quantized = [](const auto &tensor) {
    return !tensor->quantParams.empty() &&
           std::all_of(tensor->quantParams.begin(), tensor->quantParams.end(),
                       [](const std::unique_ptr<QuantParamT> &param) { return param->inited; });
  };

  // check first in tensor
  if (!fully_quantized(graph.allTensors.at(node->inputIndex.at(0)))) {
    return false;
  }
  // check first out tensor
  if (!fully_quantized(graph.allTensors.at(node->outputIndex.at(0)))) {
    return false;
  }
  node->quantType = schema::QuantType_QUANT_ALL;
  return true;
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_TYPE_DETERMINER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_TYPE_DETERMINER_H
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
namespace mindspore::lite {
|
||||
// Tags a data-carrying node QUANT_ALL when the first input tensor and the
// first output tensor both have fully initialized quant params.
class CarryDataQuantTypeDeterminer : public QuantTypeDeterminer {
 public:
  bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_TYPE_DETERMINER_H
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h"
|
||||
#include <cfloat>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
// Derives the Concat output tensor's quant params by merging the min/max
// ranges of all its input tensors. Requires every input to already carry
// initialized quant params; otherwise RET_ERROR is returned.
STATUS ConcatQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
                                                        const mindspore::schema::CNodeT &node) {
  UpdateQuantParamsNum(*graph, node);

  if (input_inited_quant_params_ != node.inputIndex.size()) {
    MS_LOG(DEBUG) << "Can not determine concat inputTensor quantParam, node " << node.name;
    return RET_ERROR;
  }

  if (output_inited_quant_params_ != 1) {
    MS_ASSERT(output_inited_quant_params_ == 0);
    float min_min = FLT_MAX;
    // fixed: FLT_MIN is the smallest *positive* float; seeding the running
    // max with it broke merging of all-negative ranges. Use -FLT_MAX.
    float max_max = -FLT_MAX;
    bool narrow_range = false;
    int num_bits = -1;
    for (size_t i = 0; i < node.inputIndex.size(); i++) {
      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
      // fixed: index allTensors with the node's input index, not the loop
      // counter (the assert above already checks the correct index).
      auto &in_tensor = graph->allTensors.at(node.inputIndex.at(i));
      MS_ASSERT(in_tensor != nullptr);
      auto in_quant_param = GetTensorQuantParam(in_tensor);
      if (in_quant_param == nullptr || !in_quant_param->inited) {
        return RET_ERROR;
      }
      if (num_bits == -1) {
        narrow_range = in_quant_param->narrowRange;
        num_bits = in_quant_param->numBits;
      } else {
        // All inputs are expected to share one quant configuration.
        // fixed: the asserts referenced a nonexistent `quantParam` variable.
        MS_ASSERT(narrow_range == in_quant_param->narrowRange);
        MS_ASSERT(num_bits == in_quant_param->numBits);
      }
      if (min_min > in_quant_param->min) {
        min_min = in_quant_param->min;
      }
      if (max_max < in_quant_param->max) {
        max_max = in_quant_param->max;
      }
    }

    MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
    auto &out_tensor = graph->allTensors.at(node.outputIndex.front());
    MS_ASSERT(out_tensor != nullptr);
    auto out_quant_param = std::make_unique<QuantParamT>();

    auto status = quant::CalQuantizationParams(out_quant_param.get(), min_min, max_max, narrow_range, num_bits);
    if (status != RET_OK) {
      MS_LOG(DEBUG) << "in aware quantization run CalQuantizationParams failed!";
      return RET_ERROR;
    }
    out_tensor->quantParams.emplace_back(std::move(out_quant_param));
    output_inited_quant_params_++;
  }

  return RET_OK;
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONCAT_QUANT_PARAM_PROPOGATOR_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONCAT_QUANT_PARAM_PROPOGATOR_H
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
namespace mindspore::lite {
|
||||
// Derives the Concat output tensor's quant params by merging the min/max
// ranges of all (already quantized) input tensors.
class ConcatQuantParamPropogator : public QuantParamPropogator {
 public:
  STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONCAT_QUANT_PARAM_PROPOGATOR_H
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
namespace mindspore::lite {
|
||||
static constexpr size_t kBiasAdd = 3;
|
||||
|
||||
// Fixes up conv-like nodes that carry a bias (3 inputs): the bias tensor's
// quant params are retargeted to int32, the accumulator dtype for quantized
// convolution.
STATUS ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
                                                      const mindspore::schema::CNodeT &node) {
  if (node.inputIndex.size() != kBiasAdd) {
    // No bias input — nothing to adjust.
    return RET_OK;
  }
  const auto bias_index = node.inputIndex.at(kBiasAdd - 1);
  auto &bias_tensor = graph->allTensors.at(bias_index);
  for (auto &param : bias_tensor->quantParams) {
    param->dstDtype = TypeId::kNumberTypeInt32;
  }
  return RET_OK;
}
|
||||
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_PARAM_PROPOGATOR_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_PARAM_PROPOGATOR_H
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
namespace mindspore::lite {
|
||||
// Propagator for conv-like nodes: when a bias input is present, its quant
// params are retargeted to int32 (the quantized accumulator dtype).
class ConvQuantParamPropogator : public QuantParamPropogator {
 public:
  STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_PARAM_PROPOGATOR_H
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
|
||||
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
namespace mindspore::lite {
|
||||
static constexpr size_t kInputIndex = 0;
|
||||
static constexpr size_t kWeightIndex = 1;
|
||||
|
||||
// Tags conv/matmul nodes QUANT_WEIGHT when only the weight tensor — not the
// activation input — carries an initialized quant param.
bool ConvQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
                                                   mindspore::schema::CNodeT *node) {
  MS_ASSERT(node->inputIndex.size() >= 2);
  // True when the tensor's first quant param exists and is initialized.
  auto has_inited_param = [](const auto &tensor) {
    return !tensor->quantParams.empty() && tensor->quantParams.front()->inited;
  };
  auto &activation = graph.allTensors.at(node->inputIndex.at(kInputIndex));
  auto &weight = graph.allTensors.at(node->inputIndex.at(kWeightIndex));
  if (!has_inited_param(activation) && has_inited_param(weight)) {
    node->quantType = schema::QuantType_QUANT_WEIGHT;
    return true;
  }
  return false;
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_TYPE_DETERMINER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_TYPE_DETERMINER_H
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
namespace mindspore::lite {
|
||||
// Tags conv/matmul nodes QUANT_WEIGHT when only the weight tensor (not the
// activation input) carries an initialized quant param.
class ConvQuantTypeDeterminer : public QuantTypeDeterminer {
 public:
  bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_TYPE_DETERMINER_H
|
|
@ -0,0 +1,23 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
||||
// Unconditionally reports the node as fully quantizable; neither the graph
// nor the node is inspected (note: unlike other determiners, this one does
// not set node->quantType itself).
bool DefaultQuantAllQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  (void)graph;  // intentionally unused
  (void)node;   // intentionally unused
  return true;
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
// Always answers "yes" to DetermineQuantAll; registered for ops such as
// QuantDTypeCast that are treated as fully quantized unconditionally.
class DefaultQuantAllQuantTypeDeterminer : public QuantTypeDeterminer {
 public:
  bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
||||
// Declares a node QUANT_ALL as soon as every *input* tensor has initialized
// quant params; outputs are not required (e.g. DetectionPostProcess).
bool OnlyNeedInputsQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  UpdateQuantParamsNum(graph, *node);
  const bool all_inputs_inited = input_inited_quant_params_ == node->inputIndex.size();
  if (!all_inputs_inited) {
    return false;
  }
  node->quantType = schema::QuantType_QUANT_ALL;
  return true;
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
namespace mindspore::lite {
|
||||
// Tags a node QUANT_ALL when all of its inputs carry initialized quant
// params; output quant params are not required.
class OnlyNeedInputsQuantTypeDeterminer : public QuantTypeDeterminer {
 public:
  bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H
|
|
@ -0,0 +1,144 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h"
|
||||
#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h"
|
||||
#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h"
|
||||
#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h"
|
||||
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
|
||||
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
|
||||
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
|
||||
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
// Recounts, for `node`, how many input and output tensors already carry
// fully initialized quant params, storing the results in
// input_inited_quant_params_ / output_inited_quant_params_.
void QuantNodeBase::UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node) {
  // A tensor counts when it has at least one quant param and all are inited.
  auto count_inited_tensors = [&graph](const auto &tensor_indices) {
    size_t inited_count = 0;
    for (auto index : tensor_indices) {
      auto &tensor = graph.allTensors.at(index);
      if (tensor->quantParams.empty()) {
        continue;
      }
      bool all_inited = std::all_of(tensor->quantParams.begin(), tensor->quantParams.end(),
                                    [](const std::unique_ptr<schema::QuantParamT> &param) { return param->inited; });
      if (all_inited) {
        inited_count++;
      }
    }
    return inited_count;
  };
  input_inited_quant_params_ = count_inited_tensors(node.inputIndex);
  output_inited_quant_params_ = count_inited_tensors(node.outputIndex);
}
|
||||
|
||||
// Default full-quantization check: an already-decided node keeps its
// decision; otherwise the node is QUANT_ALL iff every input and output
// tensor carries initialized quant params.
bool QuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  if (node->quantType != schema::QuantType_QUANT_NONE) {
    return node->quantType == schema::QuantType_QUANT_ALL;
  }

  UpdateQuantParamsNum(graph, *node);
  const bool inputs_ready = input_inited_quant_params_ == node->inputIndex.size();
  const bool outputs_ready = output_inited_quant_params_ == node->outputIndex.size();
  if (!(inputs_ready && outputs_ready)) {
    return false;
  }
  node->quantType = schema::QuantType_QUANT_ALL;
  return true;
}
|
||||
// Default weight-quant check: a node is weight-quantized only if an earlier
// pass already tagged it QUANT_WEIGHT. `graph` is unused here.
bool QuantTypeDeterminer::DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  return node->quantType == schema::QuantType_QUANT_WEIGHT;
}
|
||||
|
||||
// Per-node quant preprocessing: weight-quant nodes are left as-is;
// otherwise quant params are propagated and the node is re-examined for
// full quantization.
void QuantNodeHelper::NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node) {
  if (quant_type_determiner_->DetermineQuantWeight(*graph, node)) {
    return;
  }
  // NOTE(review): the STATUS returned by PropogateQuantParams is discarded —
  // presumably a propagation failure just leaves the node un-quantized;
  // confirm that this is intended.
  quant_param_propogator_->PropogateQuantParams(graph, *node);
  quant_type_determiner_->DetermineQuantAll(*graph, node);
}
|
||||
|
||||
// Meyers-singleton accessor; initialization of `instance` is thread-safe
// since C++11.
QuantHelperRegister *QuantHelperRegister::GetInstance() {
  static QuantHelperRegister instance;
  return &instance;
}
|
||||
|
||||
// Returns the helper registered for `op_type`, falling back to the generic
// helper registered under PrimitiveType_NONE when no specific one exists.
QuantNodeHelper *QuantHelperRegister::GetQuantHelper(schema::PrimitiveType op_type) {
  const auto it = register_map_.find(op_type);
  return it != register_map_.end() ? it->second : register_map_[schema::PrimitiveType_NONE];
}
|
||||
|
||||
// Builds the op-type -> helper table. Propogators/determiners are shared
// between entries; the QuantNodeHelper objects themselves are owned by this
// register and released in the destructor.
QuantHelperRegister::QuantHelperRegister() {
  auto base_propogator = std::make_shared<QuantParamPropogator>();
  auto base_determiner = std::make_shared<QuantTypeDeterminer>();
  auto bias_add_propogator = std::make_shared<BiasAddQuantParamPropogator>();
  auto carry_data_propogator = std::make_shared<CarryDataQuantParamPropogator>();
  auto carry_data_determiner = std::make_shared<CarryDataQuantTypeDeterminer>();
  auto concat_propogator = std::make_shared<ConcatQuantParamPropogator>();
  auto conv_propogator = std::make_shared<ConvQuantParamPropogator>();
  auto conv_determiner = std::make_shared<ConvQuantTypeDeterminer>();
  auto default_quant_all_determiner = std::make_shared<DefaultQuantAllQuantTypeDeterminer>();
  auto only_need_inputs_determiner = std::make_shared<OnlyNeedInputsQuantTypeDeterminer>();

  auto add = [this](schema::PrimitiveType type, const auto &propogator, const auto &determiner) {
    register_map_[type] = new QuantNodeHelper(propogator, determiner);
  };

  add(schema::PrimitiveType_BiasAdd, bias_add_propogator, base_determiner);

  // Data-carrying ops: quant params flow through unchanged.
  for (auto type : {schema::PrimitiveType_MaxPoolFusion, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape,
                    schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Transpose,
                    schema::PrimitiveType_PadFusion, schema::PrimitiveType_ReduceFusion,
                    schema::PrimitiveType_Gather}) {
    add(type, carry_data_propogator, carry_data_determiner);
  }

  add(schema::PrimitiveType_Concat, concat_propogator, base_determiner);

  add(schema::PrimitiveType_Conv2DFusion, conv_propogator, conv_determiner);
  add(schema::PrimitiveType_MatMul, conv_propogator, conv_determiner);

  add(schema::PrimitiveType_QuantDTypeCast, base_propogator, default_quant_all_determiner);
  add(schema::PrimitiveType_DetectionPostProcess, base_propogator, only_need_inputs_determiner);
  add(schema::PrimitiveType_NONE, base_propogator, base_determiner);
}
|
||||
|
||||
// Releases the helper objects heap-allocated in the constructor.
QuantHelperRegister::~QuantHelperRegister() {
  for (auto &entry : register_map_) {
    delete entry.second;
  }
  register_map_.clear();
}
|
||||
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_NODE_PREPROCESSOR_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_NODE_PREPROCESSOR_H
|
||||
|
||||
#include <memory>
#include <unordered_map>
#include <utility>
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
// Shared bookkeeping base for the quant preprocessing helpers below.
// NOTE(review): judging by the names, the two counters track how many of a
// node's input/output tensors already carry initialized quant params —
// confirm against the UpdateQuantParamsNum definition in the .cc file.
class QuantNodeBase {
 public:
  // Refreshes the two counters below for `node` within `graph`.
  // (Defined in the corresponding .cc file.)
  void UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node);

 protected:
  size_t input_inited_quant_params_ = 0;   // inputs with initialized quant params
  size_t output_inited_quant_params_ = 0;  // outputs with initialized quant params
};
|
||||
|
||||
class QuantParamPropogator : public QuantNodeBase {
|
||||
public:
|
||||
virtual STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) { return RET_OK; }
|
||||
};
|
||||
|
||||
class QuantTypeDeterminer : public QuantNodeBase {
|
||||
public:
|
||||
virtual bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node);
|
||||
virtual bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node);
|
||||
};
|
||||
|
||||
class QuantNodeHelper {
|
||||
public:
|
||||
void NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node);
|
||||
QuantNodeHelper(std::shared_ptr<QuantParamPropogator> quant_param_propogator,
|
||||
std::shared_ptr<QuantTypeDeterminer> quant_type_determiner) {
|
||||
quant_param_propogator_ = quant_param_propogator;
|
||||
quant_type_determiner_ = quant_type_determiner;
|
||||
}
|
||||
virtual ~QuantNodeHelper() = default;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<QuantParamPropogator> quant_param_propogator_;
|
||||
std::shared_ptr<QuantTypeDeterminer> quant_type_determiner_;
|
||||
};
|
||||
|
||||
class QuantHelperRegister {
|
||||
public:
|
||||
virtual ~QuantHelperRegister();
|
||||
QuantNodeHelper *GetQuantHelper(schema::PrimitiveType op_type);
|
||||
static QuantHelperRegister *GetInstance();
|
||||
|
||||
private:
|
||||
QuantHelperRegister();
|
||||
std::unordered_map<schema::PrimitiveType, QuantNodeHelper *> register_map_;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_NODE_PREPROCESSOR_H
|
|
@ -73,7 +73,7 @@ STATUS WeightQuantizer::SetAbstract(const tensor::TensorPtr &tensor_info, const
|
|||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||
abstract_tensor->element()->set_type(TypeIdToType(type_id_));
|
||||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||
quant_param_holder->set_quant_type(schema::QuantType_WeightQuant);
|
||||
quant_param_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -296,6 +296,7 @@ bool MindirAdjustPass::Run(const FuncGraphPtr &graph) {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
MS_ASSERT(graph != nullptr);
|
||||
graph_ = graph;
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
int status = lite::RET_OK;
|
||||
bool success_flag = true;
|
||||
|
|
|
@ -31,15 +31,17 @@ class MindirAdjustPass : public Pass {
|
|||
~MindirAdjustPass() override = default;
|
||||
void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
|
||||
void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
|
||||
int ValueNodeInt64Convert(AnfNodePtr anf_node);
|
||||
void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; }
|
||||
int ComputeQuantParams(AnfNodePtr anf_node);
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
protected:
|
||||
int ValueNodeInt64Convert(AnfNodePtr anf_node);
|
||||
int ComputeQuantParams(AnfNodePtr anf_node);
|
||||
|
||||
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
|
||||
FmkType fmk_type_ = FmkType::FmkType_MS;
|
||||
bool train_flag_ = false;
|
||||
FuncGraphPtr graph_;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
|
||||
|
|
|
@ -26,7 +26,9 @@ using mindspore::lite::converter::FmkType_TF;
|
|||
using mindspore::lite::converter::FmkType_TFLITE;
|
||||
using mindspore::schema::QuantType_AwareTraining;
|
||||
using mindspore::schema::QuantType_PostTraining;
|
||||
using mindspore::schema::QuantType_QUANT_ALL;
|
||||
using mindspore::schema::QuantType_QUANT_NONE;
|
||||
using mindspore::schema::QuantType_QUANT_WEIGHT;
|
||||
using mindspore::schema::QuantType_WeightQuant;
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
|
@ -122,9 +124,10 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const CNodePtr &conv_node,
|
|||
}
|
||||
auto weight_node = conv_node->input(kConvWeightIndex);
|
||||
switch (this->quant_type) {
|
||||
case QuantType_AwareTraining:
|
||||
case QuantType_PostTraining:
|
||||
case QuantType_WeightQuant:
|
||||
case QuantType_PostTraining:
|
||||
case QuantType_QUANT_WEIGHT:
|
||||
case QuantType_QUANT_ALL:
|
||||
case QuantType_QUANT_NONE: {
|
||||
// sum up from current ms quant models
|
||||
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
|
||||
|
|
Loading…
Reference in New Issue