!14180 [MS][LITE] Support mindir aware quant model

From: @cjh9368
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-04-08 16:31:55 +08:00 committed by Gitee
commit 48a5748354
40 changed files with 1008 additions and 144 deletions

View File

@ -123,8 +123,10 @@ enum NNACLQuantType {
QuantType_AwareTraining = 1,
QuantType_WeightQuant = 2,
QuantType_PostTraining = 3,
QuantType_QUANT_WEIGHT = 4,
QuantType_QUANT_ALL = 5,
QuantType_MIN = QuantType_QUANT_NONE,
QuantType_MAX = QuantType_PostTraining
QuantType_MAX = QuantType_QUANT_ALL
};
typedef struct vvector {

View File

@ -26,7 +26,7 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
const TensorC *indices = inputs[1];
TensorC *output = outputs[0];
output->data_type_ = input->data_type_;
if (parameter->quant_type_ == QuantType_WeightQuant) {
if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
output->data_type_ = kNumberTypeFloat32;
}
output->format_ = input->format_;

View File

@ -56,9 +56,11 @@ table Tensor {
enum QuantType: int {
QUANT_NONE,
AwareTraining,
WeightQuant,
PostTraining
AwareTraining, // deprecated, use QUANT_ALL instead
WeightQuant, // deprecated, use QUANT_WEIGHT instead
PostTraining, // deprecated, use QUANT_ALL instead
QUANT_WEIGHT,
QUANT_ALL
}
table Primitive {

View File

@ -65,6 +65,11 @@ class LiteModel : public Model {
auto c_node = meta_graph.nodes()->template GetAs<U>(i);
node->primitive_ = c_node->primitive();
node->quant_type_ = c_node->quantType();
if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) {
node->quant_type_ = schema::QuantType_QUANT_ALL;
} else if (node->quant_type_ == schema::QuantType_WeightQuant) {
node->quant_type_ = schema::QuantType_QUANT_WEIGHT;
}
node->name_ = c_node->name()->c_str();
auto count = c_node->inputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {

View File

@ -430,7 +430,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
MS_ASSERT(node != nullptr);
// why we need this
TypeId data_type =
(node->quant_type_ == schema::QuantType_WeightQuant) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors);
(node->quant_type_ == schema::QuantType_QUANT_WEIGHT) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors);
OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Can not find OpParameter!type: " << PrimitiveTypeName(GetPrimitiveType(node->primitive_));

View File

@ -669,6 +669,9 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano
MS_LOG(ERROR) << "schema tensor format is wrong, " << schema_tensor->format;
return RET_ERROR;
}
if (primitive_c->GetAttr(opt::kWeightFormat) != nullptr) {
schema_tensor->format = static_cast<schema::Format>(GetValue<int64_t>(primitive_c->GetAttr(opt::kWeightFormat)));
}
schema_tensor->name = param_node->name();
ShapeVector shape_vector;
TypeId data_type;
@ -704,9 +707,6 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano
}
}
}
if (primitive_c->GetAttr(opt::kWeightFormat) != nullptr) {
schema_tensor->format = static_cast<schema::Format>(GetValue<int64_t>(primitive_c->GetAttr(opt::kWeightFormat)));
}
schema_tensor->name = input_name;
QuantParamHolderPtr quant_param_holder = primitive_c->GetAttr("quant_params") == nullptr
? nullptr
@ -722,9 +722,8 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano
}
int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
int ret;
const std::shared_ptr<Value> &value, const std::shared_ptr<PrimitiveC> &primitive,
schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
auto valueAbstract = value_node->abstract();
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
@ -742,15 +741,30 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<s
(*schema_tensor)->nodeType = NodeType_ValueNode;
auto data = value->cast<tensor::TensorPtr>();
(*schema_tensor)->data.resize(data->Size());
ret = memcpy_s((*schema_tensor)->data.data(), data->Size(), data->data_c(), data->Size());
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s error.";
(*schema_tensor)->format = schema::Format_NHWC;
(*schema_tensor)->format = GetFormatByFmk(meta_graphT->fmkType);
if ((*schema_tensor)->format != schema::Format_NHWC && (*schema_tensor)->format != schema::Format_NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << (*schema_tensor)->format;
return RET_ERROR;
}
// process weight tensor
if (data->Size() > 0) {
if (memcpy_s((*schema_tensor)->data.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
MS_LOG(ERROR) << "memcpy_s error.";
return RET_ERROR;
}
if (primitive->GetAttr(opt::kWeightFormat) != nullptr) {
(*schema_tensor)->format = static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(opt::kWeightFormat)));
}
}
node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
return ret;
return RET_OK;
}
int AnfExporter::ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
@ -877,6 +891,7 @@ int AnfExporter::ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_p
}
int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *output_cnode) {
auto value_node = input_anode->cast<ValueNodePtr>();
@ -888,7 +903,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
schema_tensor->name = value_node->fullname_with_scope();
}
if (value->isa<tensor::Tensor>()) {
ret = ProcessTensor(value_node, &schema_tensor, value, output_cnode, meta_graphT);
ret = ProcessTensor(value_node, &schema_tensor, value, primitive, output_cnode, meta_graphT);
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
ret = ProcessInt32OrInt64Imm(value_node, &schema_tensor, value, output_cnode, meta_graphT);
} else if (value->isa<mindspore::BoolImm>()) {
@ -945,7 +960,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
is_graph_input = true;
}
} else if (input_node->isa<ValueNode>()) {
auto ret = ConvertInputValueNode(input_node, meta_graphT, fb_node);
auto ret = ConvertInputValueNode(input_node, primitive_c, meta_graphT, fb_node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvertInputValueNode failed";
return RET_ERROR;

View File

@ -51,11 +51,11 @@ class AnfExporter {
int ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode);
int ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode, const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode, const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
int ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
const std::shared_ptr<Value> &value, const std::shared_ptr<PrimitiveC> &primitive,
schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);

View File

@ -16,14 +16,19 @@
#include "tools/common/node_util.h"
#include <memory>
#include <set>
#include <vector>
#include "src/ops/populate/populate_register.h"
#include "src/common/common.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"
#include "src/runtime/infer_manager.h"
namespace mindspore {
namespace lite {
constexpr size_t kInitialSize = 1024;
static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveType_Conv2DBackpropFilterFusion,
schema::PrimitiveType_Conv2DBackpropInputFusion,
schema::PrimitiveType_AvgPoolGrad,
@ -350,6 +355,33 @@ static int Convert2CKHW(int srcFormat) {
return -1;
}
// Runs single-node shape inference for a schema-level CNodeT: re-serializes the
// node's primitive into a flatbuffer, looks up and populates the matching
// OpParameter, then delegates to KernelInferShape on the given lite tensors.
// Returns the status of KernelInferShape, or RET_ERROR on any setup failure.
// NOTE(review): "Shpae" is a typo for "Shape", but the name matches the
// declaration in the header (see GetCNodeTType's header), so it is kept as-is.
STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs) {
flatbuffers::FlatBufferBuilder fbb(kInitialSize);
// Convert the offline (T-object) primitive to a current-schema flatbuffer so
// the registered parameter creators can consume it.
auto prim = ConvertToPrimitive(node.primitive.get(), &fbb);
if (prim == nullptr) {
MS_LOG(ERROR) << "get primitive failed.";
fbb.Clear();
return RET_ERROR;
}
auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR);
if (parameter_gen == nullptr) {
fbb.Clear();
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
return RET_ERROR;
}
auto parameter = parameter_gen(prim);
if (parameter == nullptr) {
fbb.Clear();
MS_LOG(ERROR) << "parameter is nullptr.";
return RET_ERROR;
}
// Request actual shape inference (not just a placeholder pass).
parameter->infer_flag_ = true;
auto ret = KernelInferShape(inputs, outputs, parameter);
fbb.Clear();
// OpParameter is released with free(): the populate creators presumably
// allocate it C-style (malloc) — do not change this to delete.
free(parameter);
return ret;
}
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
if (tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";

View File

@ -44,6 +44,8 @@ int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema:
using STATUS = int;
STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs);
inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) { return cNodeT.primitive->value.type; }
inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) {

View File

@ -205,13 +205,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_MEMORY_FAILED;
}
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
@ -221,6 +214,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
return status;
}
auto old_nodes2 = GetGraphNodes();
quantNodeOptimizer.AddPass(new (std::nothrow) InferQuantParamPass());
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_MEMORY_FAILED;
}
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());

View File

@ -17,6 +17,8 @@
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include <string>
#include <set>
#include <vector>
#include <unordered_map>
#include "tools/common/node_util.h"
#include "tools/converter/converter_context.h"
#include "src/common/common.h"
@ -36,17 +38,18 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
return status;
}
status = DoNodeInoutDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
return status;
}
status = DoModelOutputDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
return status;
}
status = DoNodeInoutDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
return status;
}
return RET_OK;
}
@ -128,48 +131,103 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
return RET_OK;
}
STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// insert transNode before and after existNode
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) || (*iter)->quantType != QuantType_AwareTraining) {
continue;
}
auto nodeName = (*iter)->name;
if ((*iter)->inputIndex.empty()) {
MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
return RET_ERROR;
}
STATUS status;
// insert pre
for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
if (preTensor->dataType != TypeId::kNumberTypeInt8) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
STATUS DTypeTransPass::InsetDTypeTransNodeForWrongDtypeQuantOp(schema::MetaGraphT *graph, NodeIter *iter) {
auto node_name = (**iter)->name;
auto status = RET_OK;
// insert fp32 to int8 before
for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
if (pre_tensor->dataType == kNumberTypeFloat32 && !pre_tensor->quantParams.empty() &&
pre_tensor->quantParams.front()->inited) {
*iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeFloat32, kNumberTypeInt8, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
return RET_ERROR;
}
}
// insert post
for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i));
if (postTensor->dataType != TypeId::kNumberTypeInt8) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeUInt8, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
return RET_ERROR;
}
}
(*iter)->quantType = QuantType_QUANT_NONE;
}
// insert int8 to fp32 after
for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
if (post_tensor->dataType == kNumberTypeFloat32 && !post_tensor->quantParams.empty() &&
post_tensor->quantParams.front()->inited) {
*iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";
return RET_ERROR;
}
}
}
return RET_OK;
}
// For a node that is NOT quantized (QuantType_QUANT_NONE) but whose in/out
// tensors are int8 with inited quant params — i.e. an int8 op the runtime does
// not support yet — inserts dtype-cast nodes so the op itself runs in fp32.
// `iter` is updated in place as nodes are inserted. Returns RET_ERROR if any
// insertion fails.
// NOTE(review): "Inset" looks like a typo for "Insert", but it matches the
// header declaration, so the name is kept.
STATUS DTypeTransPass::InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter) {
auto node_name = (**iter)->name;
auto status = RET_OK;
// insert int8 to fp32 before
for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
// Only convert inputs that are genuinely quantized (int8 + inited params).
if (pre_tensor->dataType == kNumberTypeInt8 && !pre_tensor->quantParams.empty() &&
pre_tensor->quantParams.front()->inited) {
*iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";
return RET_ERROR;
}
}
}
// insert fp32 to int8 after
// NOTE(review): the comment says fp32->int8 but the call passes
// (kNumberTypeInt8, kNumberTypeFloat32); confirm how InsertDTypeTransNode
// interprets its dtype arguments for kAfter placement.
for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
if (post_tensor->dataType == kNumberTypeInt8 && !post_tensor->quantParams.empty() &&
post_tensor->quantParams.front()->inited) {
*iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
if (status != RET_OK) {
// NOTE(review): message says "before" for a kAfter insertion — likely
// copy-paste; confirm and fix wording in a follow-up.
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
return RET_ERROR;
}
}
}
return RET_OK;
}
// Walks every node in the graph and inserts input/output dtype-cast nodes in
// two situations:
//  1. QuantType_QUANT_ALL nodes fed by fp32 tensors (wrong-dtype quant op);
//  2. QuantType_QUANT_NONE nodes surrounded by int8 tensors with inited quant
//     params (an int8 op that is not supported yet, so it must run in fp32).
// Cast-like ops (QuantDTypeCast, Cast) are skipped — they already change dtype.
STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto node_name = (*iter)->name;
if ((*iter)->inputIndex.empty()) {
MS_LOG(ERROR) << "Op " << node_name.c_str() << " should have " << kMinInputNum << " input tensor at least";
return RET_ERROR;
}
// dtype-conversion ops never need extra casts around them.
if ((*iter)->primitive->value.type == schema::PrimitiveType_QuantDTypeCast ||
(*iter)->primitive->value.type == schema::PrimitiveType_Cast) {
continue;
}
STATUS status = RET_OK;
// quant_type is quant_all, but inputs/outputs are float32
if ((*iter)->quantType == QuantType_QUANT_ALL) {
status = InsetDTypeTransNodeForWrongDtypeQuantOp(graph, &iter);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
return status;
}
continue;
}
// quant_type is quant_none, but inputs/outputs have quant params and dtype is int8, which means this int8 op is not
// supported yet
if ((*iter)->quantType == QuantType_QUANT_NONE) {
status = InsetDTypeTransNodeForUnsupportedInt8Op(graph, &iter);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
return status;
}
}
}
return RET_OK;
}

View File

@ -46,6 +46,11 @@ class DTypeTransPass : public GraphPass {
STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph);
STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);
STATUS InsetDTypeTransNodeForWrongDtypeQuantOp(schema::MetaGraphT *graph, NodeIter *iter);
STATUS InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter);
NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
int32_t inputDataType, int32_t outputDataType, STATUS *errorCode);

View File

@ -18,75 +18,32 @@
#include <vector>
#include <memory>
#include "src/common/utils.h"
#include "tools/converter/quantizer/calc_quant_param.h"
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
#include "tools/common/node_util.h"
namespace mindspore::lite {
STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
auto *quantParamRegister = QuantParamCalcRegister::GetInstance();
if (graph == nullptr) {
MS_LOG(ERROR) << "graph is null";
return RET_NULL_PTR;
}
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
MS_ASSERT(node != nullptr);
if (node->quantType == schema::QuantType_WeightQuant) {
if (node == nullptr) {
MS_LOG(ERROR) << "node is null";
return RET_NULL_PTR;
}
// make sure reenter node will keep the node status
if (node->quantType != schema::QuantType_QUANT_NONE) {
continue;
}
DetermineNodeQuantType(*graph, node.get());
if (node->quantType == schema::QuantType_AwareTraining) {
continue;
}
if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
MS_ASSERT(false);
}
auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
if (quantParamCalcer == nullptr) {
MS_LOG(DEBUG) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
node->quantType = schema::QuantType_QUANT_NONE;
} else {
auto status = quantParamCalcer->Calc(graph, *node);
if (status != RET_OK) {
MS_LOG(DEBUG) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE;
} else {
node->quantType = schema::QuantType_AwareTraining;
}
}
auto quant_helper = QuantHelperRegister::GetInstance()->GetQuantHelper(node->primitive->value.type);
quant_helper->NodeQuantPreprocess(graph, node.get());
}
return RET_OK;
}
void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(cnode != nullptr);
bool canQuant = true;
for (auto &inputTensorIdx : cnode->inputIndex) {
MS_ASSERT(graph.allTensors.size() > inputTensorIdx);
auto &inTensor = graph.allTensors.at(inputTensorIdx);
MS_ASSERT(inTensor != nullptr);
if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr ||
!inTensor->quantParams.front()->inited) {
canQuant = false;
break;
}
}
for (auto &outTensorIdx : cnode->outputIndex) {
MS_ASSERT(graph.allTensors.size() > outTensorIdx);
auto &outTensor = graph.allTensors.at(outTensorIdx);
MS_ASSERT(outTensor != nullptr);
if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
!outTensor->quantParams.front()->inited) {
canQuant = false;
break;
}
}
if (canQuant) {
cnode->quantType = schema::QuantType_AwareTraining;
} else {
cnode->quantType = schema::QuantType_QUANT_NONE;
}
}
} // namespace mindspore::lite

View File

@ -24,14 +24,11 @@ namespace mindspore {
namespace lite {
class InferQuantParamPass : public GraphPass {
public:
InferQuantParamPass() {}
InferQuantParamPass() = default;
~InferQuantParamPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
private:
void DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode);
};
} // namespace lite
} // namespace mindspore

View File

@ -31,7 +31,7 @@ namespace mindspore {
namespace lite {
namespace {
constexpr int DEFAULT_DIM_VALUE = -1;
constexpr size_t INITIAL_SIZE = 1024;
constexpr size_t kInitialSize = 1024;
void FreeTensors(std::vector<Tensor *> input_tensors, std::vector<Tensor *> output_tensors) {
for (auto &tensor : input_tensors) {
@ -118,7 +118,7 @@ std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::ve
STATUS NodeInferShape(const std::unique_ptr<schema::CNodeT> &node, const std::vector<Tensor *> &inputs,
std::vector<Tensor *> *outputs) {
flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
flatbuffers::FlatBufferBuilder fbb(kInitialSize);
auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
if (prim == nullptr) {
MS_LOG(ERROR) << "get primitive failed.";

View File

@ -5,6 +5,7 @@ include_directories(${3RD_DIR}/opencv/build/include/opencv4)
file(GLOB QUANTIZER
${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_helper/*
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc

View File

@ -829,7 +829,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
<< "'s output quant_params size: " << input_primitive_quant_holder->output_quant_params().size()
<< ", but index: " << index;
}
primitive_quant_holder->set_quant_type(schema::QuantType_PostTraining);
primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
continue;
} else if (op_type == ops::kNameConv2DFusion || op_type == ops::kNameConv2dTransposeFusion ||
op_type == ops::kNameFullConnection || op_type == ops::kNameLayerNormFusion) {
@ -870,7 +870,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
output_min_max.min = info->min;
DoQuantOutput(output_scale, output_zp, &output_min_max, primitive);
primitive_quant_holder->set_quant_type(schema::QuantType_PostTraining);
primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
}
}
return RET_OK;

View File

@ -26,7 +26,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector
auto prim_c = std::make_shared<ops::QuantDTypeCast>();
prim_c->Init(src_type, dst_type);
auto quant_params_holder = std::make_shared<QuantParamHolder>();
quant_params_holder->set_quant_type(schema::QuantType_PostTraining);
quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
for (auto &quant_param : quant_params) {
std::vector<schema::QuantParamT> quant_params_in = {quant_param};
quant_params_holder->AddInputQuantParam(quant_params_in);
@ -80,7 +80,7 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) {
if (curnode_quant_type != input_cnode_quant_type) {
ValueNodePtr value_node = nullptr;
if (curnode_quant_type == schema::QuantType_PostTraining &&
if (curnode_quant_type == schema::QuantType_QUANT_ALL &&
input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
if (primitive_quant_param_holder->input_quant_params().size() < i) {
MS_LOG(ERROR) << "quant param is invalid.";
@ -89,7 +89,7 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) {
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
primitive_quant_param_holder->input_quant_params()[i - 1]);
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
input_cnode_quant_type == schema::QuantType_PostTraining) {
input_cnode_quant_type == schema::QuantType_QUANT_ALL) {
auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c);
value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
input_primitive_quant_param_holder->output_quant_params().front());

View File

@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h"
#include "mindspore/core/ir/dtype/type_id.h"
namespace mindspore::lite {
// A bias-add style node has exactly two inputs: the data tensor and the bias.
static constexpr size_t kBiasAddSize = 2;
// Rewrites the bias tensor's quant params so their destination dtype is int32
// — presumably because quantized kernels accumulate the bias in int32 rather
// than int8 (TODO confirm against the int8 kernel implementations).
// No-op for nodes that do not have exactly two inputs.
STATUS BiasAddQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
const mindspore::schema::CNodeT &node) {
if (node.inputIndex.size() == kBiasAddSize) {
// inputIndex[kBiasAddSize - 1] (i.e. the second input) is the bias tensor.
auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAddSize - 1));
for (auto &quantParam : bias_tensor->quantParams) {
quantParam->dstDtype = TypeId::kNumberTypeInt32;
}
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
// Quant-param propagator for bias-add style nodes: forces the bias tensor's
// quant params to target int32 (see the matching .cc for details).
// NOTE(review): "Propogator" is a misspelling of "Propagator", but it is part
// of the public class name used throughout these files, so it is kept.
class BiasAddQuantParamPropogator : public QuantParamPropogator {
public:
STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H

View File

@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h"
#include <utility>
#include <memory>
#include "tools/common/tensor_util.h"
namespace mindspore::lite {
// Propagates quant params through a "carry data" op (one whose outputs carry
// the same values as its input, e.g. reshape/transpose-like nodes): when one
// side of the node lacks inited quant params, copy them from the other side.
// Returns RET_ERROR when the missing side cannot be derived.
STATUS CarryDataQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) {
  UpdateQuantParamsNum(*graph, node);
  // refresh in_tensor quant_params by out_tensor quant_params
  if (input_inited_quant_params_ < 1) {
    auto &out_tensor = graph->allTensors.at(node.outputIndex.at(0));
    auto out_quant_param = GetTensorQuantParam(out_tensor);
    if (out_quant_param == nullptr || !out_quant_param->inited) {
      MS_LOG(DEBUG) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name;
      return RET_ERROR;
    }
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
    auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0));
    MS_ASSERT(in_tensor != nullptr);
    auto in_quant_param = GetTensorQuantParam(in_tensor);
    // Only overwrite a placeholder (existing but un-inited) quant param.
    if (in_quant_param != nullptr && !in_quant_param->inited) {
      in_tensor->quantParams.front() = std::move(out_quant_param);
    }
  }
  // refresh out_tensor quant_params by in_tensor quant_params
  if (output_inited_quant_params_ < 1) {
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
    auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0));
    MS_ASSERT(in_tensor != nullptr);
    auto in_quant_param = GetTensorQuantParam(in_tensor);
    if (in_quant_param == nullptr || !in_quant_param->inited) {
      // Fixed: removed the stray printf-style "%s" from this stream-style log.
      MS_LOG(DEBUG) << "Can not determine outputTensor quantParam from inputTensor for node " << node.name;
      return RET_ERROR;
    }
    for (unsigned int out_tensor_idx : node.outputIndex) {
      // Fixed: the original asserted `allTensors.size() > node.outputIndex.at(i)`,
      // re-indexing outputIndex with a tensor id; the loop variable already IS
      // the index into allTensors.
      MS_ASSERT(graph->allTensors.size() > out_tensor_idx);
      auto &out_tensor = graph->allTensors.at(out_tensor_idx);
      MS_ASSERT(out_tensor != nullptr);
      auto out_quant_param = GetTensorQuantParam(out_tensor);
      // Fixed: give every output its own copy. The original std::move()d
      // in_quant_param inside this loop, so the 2nd and later outputs received
      // a null quant param.
      auto quant_param_copy = std::make_unique<schema::QuantParamT>(*in_quant_param);
      if (out_quant_param == nullptr) {
        out_tensor->quantParams.emplace_back(std::move(quant_param_copy));
        continue;
      }
      if (out_quant_param->inited) {
        continue;
      }
      out_tensor->quantParams.front() = std::move(quant_param_copy);
    }
  }
  return RET_OK;
}
} // namespace mindspore::lite

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
// Quant-param propagator for data-carrying nodes: fills in missing input
// quant params from the output tensor, and vice versa (see the matching .cc).
// NOTE(review): "Propogator" is a misspelling of "Propagator", but it is part
// of the public class name used throughout these files, so it is kept.
class CarryDataQuantParamPropogator : public QuantParamPropogator {
public:
STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h"
#include <utility>
#include <memory>
#include "tools/common/tensor_util.h"
namespace mindspore::lite {
// A carry-data node is fully quantized (QUANT_ALL) only when both its first
// input tensor and its first output tensor already carry quant params that are
// all inited.
bool CarryDataQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  MS_ASSERT(node->inputIndex.size() >= 1);
  MS_ASSERT(node->outputIndex.size() >= 1);
  // True when the tensor owns at least one quant param and none is uninited.
  auto fully_quantized = [](const auto &tensor) {
    if (tensor->quantParams.empty()) {
      return false;
    }
    return !std::any_of(tensor->quantParams.begin(), tensor->quantParams.end(),
                        [](const auto &quant_param) { return !quant_param->inited; });
  };
  const auto &first_input = graph.allTensors.at(node->inputIndex.at(0));
  if (!fully_quantized(first_input)) {
    return false;
  }
  const auto &first_output = graph.allTensors.at(node->outputIndex.at(0));
  if (!fully_quantized(first_output)) {
    return false;
  }
  node->quantType = schema::QuantType_QUANT_ALL;
  return true;
}
}  // namespace mindspore::lite

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_TYPE_DETERMINER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_TYPE_DETERMINER_H
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
// Quant-type determiner for carry-data ops (pool/reshape/transpose-like):
// only promotes a node to QUANT_ALL when its first input and first output
// tensors both carry fully inited quant params.
class CarryDataQuantTypeDeterminer : public QuantTypeDeterminer {
 public:
  bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
}  // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_TYPE_DETERMINER_H

View File

@ -0,0 +1,79 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h"
#include <cfloat>
#include <memory>
#include <utility>
#include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite {
// Derives the output quant param of a Concat node from its inputs: every
// input must already carry an inited quant param; the output range is the
// union [min over inputs, max over inputs] of the input ranges.
STATUS ConcatQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
                                                        const mindspore::schema::CNodeT &node) {
  UpdateQuantParamsNum(*graph, node);
  if (input_inited_quant_params_ != node.inputIndex.size()) {
    MS_LOG(DEBUG) << "Can not determine concat inputTensor quantParam, node " << node.name;
    return RET_ERROR;
  }
  if (output_inited_quant_params_ != 1) {
    MS_ASSERT(output_inited_quant_params_ == 0);
    float min_min = FLT_MAX;
    // Start the running maximum at -FLT_MAX: FLT_MIN is the smallest positive
    // normal float, which would corrupt ranges that are entirely negative.
    float max_max = -FLT_MAX;
    bool narrow_range = false;
    int num_bits = -1;
    for (size_t i = 0; i < node.inputIndex.size(); i++) {
      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
      // Index allTensors with the node's input index, not the loop counter.
      auto &in_tensor = graph->allTensors.at(node.inputIndex.at(i));
      MS_ASSERT(in_tensor != nullptr);
      auto in_quant_param = GetTensorQuantParam(in_tensor);
      if (in_quant_param == nullptr || !in_quant_param->inited) {
        return RET_ERROR;
      }
      if (num_bits == -1) {
        narrow_range = in_quant_param->narrowRange;
        num_bits = in_quant_param->numBits;
      } else {
        // All concat inputs are expected to share the same quant config.
        MS_ASSERT(narrow_range == in_quant_param->narrowRange);
        MS_ASSERT(num_bits == in_quant_param->numBits);
      }
      if (min_min > in_quant_param->min) {
        min_min = in_quant_param->min;
      }
      if (max_max < in_quant_param->max) {
        max_max = in_quant_param->max;
      }
    }
    MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
    auto &out_tensor = graph->allTensors.at(node.outputIndex.front());
    MS_ASSERT(out_tensor != nullptr);
    auto out_quant_param = std::make_unique<QuantParamT>();
    auto status = quant::CalQuantizationParams(out_quant_param.get(), min_min, max_max, narrow_range, num_bits);
    if (status != RET_OK) {
      MS_LOG(DEBUG) << "in aware quantization run CalQuantizationParams failed!";
      return RET_ERROR;
    }
    out_tensor->quantParams.emplace_back(std::move(out_quant_param));
    output_inited_quant_params_++;
  }
  return RET_OK;
}
}  // namespace mindspore::lite

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONCAT_QUANT_PARAM_PROPOGATOR_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONCAT_QUANT_PARAM_PROPOGATOR_H
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
// Propagator for Concat: when all inputs carry inited quant params but the
// output does not, synthesizes the output quant param from the combined
// min/max range of the inputs.
class ConcatQuantParamPropogator : public QuantParamPropogator {
 public:
  // Returns RET_OK on success; RET_ERROR when any input quant param is
  // missing or uninited.
  STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
}  // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONCAT_QUANT_PARAM_PROPOGATOR_H

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
#include "mindspore/core/ir/dtype/type_id.h"
namespace mindspore::lite {
// A conv-like node carries a bias tensor only when it has exactly three inputs
// (activation, weight, bias).
static constexpr size_t kBiasAdd = 3;
STATUS ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
                                                      const mindspore::schema::CNodeT &node) {
  if (node.inputIndex.size() != kBiasAdd) {
    // No bias input present — nothing to propagate.
    return RET_OK;
  }
  // The bias is the last input; its quantized storage type is int32.
  auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAdd - 1));
  for (auto &quant_param : bias_tensor->quantParams) {
    quant_param->dstDtype = TypeId::kNumberTypeInt32;
  }
  return RET_OK;
}
}  // namespace mindspore::lite

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_PARAM_PROPOGATOR_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_PARAM_PROPOGATOR_H
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
// Propagator for conv-like nodes (Conv2DFusion, MatMul per QuantHelperRegister):
// when a bias input is present, forces its quant params' dstDtype to int32.
class ConvQuantParamPropogator : public QuantParamPropogator {
 public:
  STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
}  // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_PARAM_PROPOGATOR_H

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/core/ir/dtype/type_id.h"
namespace mindspore::lite {
static constexpr size_t kInputIndex = 0;
static constexpr size_t kWeightIndex = 1;
// A conv-like node is weight-only quantized when its weight tensor carries an
// inited quant param while its activation input does not.
bool ConvQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
                                                   mindspore::schema::CNodeT *node) {
  MS_ASSERT(node->inputIndex.size() >= 2);
  auto &activation = graph.allTensors.at(node->inputIndex.at(kInputIndex));
  auto &weight = graph.allTensors.at(node->inputIndex.at(kWeightIndex));
  const bool activation_quantized = !activation->quantParams.empty() && activation->quantParams.front()->inited;
  const bool weight_quantized = !weight->quantParams.empty() && weight->quantParams.front()->inited;
  if (activation_quantized || !weight_quantized) {
    return false;
  }
  node->quantType = schema::QuantType_QUANT_WEIGHT;
  return true;
}
}  // namespace mindspore::lite

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_TYPE_DETERMINER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_TYPE_DETERMINER_H
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
// Determiner for conv-like nodes: tags a node QUANT_WEIGHT when only the
// weight tensor (not the activation input) carries inited quant params.
class ConvQuantTypeDeterminer : public QuantTypeDeterminer {
 public:
  bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
}  // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_TYPE_DETERMINER_H

View File

@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
namespace mindspore::lite {
// Unconditionally reports the node as fully quantizable, regardless of its
// tensors' quant params. Registered for QuantDTypeCast in QuantHelperRegister.
// NOTE(review): unlike the other determiners, this does NOT set
// node->quantType to QUANT_ALL — confirm whether that is intended.
bool DefaultQuantAllQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  return true;
}
}  // namespace mindspore::lite

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
// Determiner that always answers "fully quantizable"; used for ops (e.g.
// QuantDTypeCast) whose quant status does not depend on tensor quant params.
class DefaultQuantAllQuantTypeDeterminer : public QuantTypeDeterminer {
 public:
  bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
}  // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
namespace mindspore::lite {
// Marks a node QUANT_ALL as soon as every input tensor carries fully inited
// quant params; output tensors are deliberately not checked.
bool OnlyNeedInputsQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  UpdateQuantParamsNum(graph, *node);
  if (input_inited_quant_params_ != node->inputIndex.size()) {
    return false;
  }
  node->quantType = schema::QuantType_QUANT_ALL;
  return true;
}
}  // namespace mindspore::lite

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
// Determiner that requires only the INPUT tensors to have inited quant params
// to tag a node QUANT_ALL (used for DetectionPostProcess, whose outputs are
// not quantized — see QuantHelperRegister).
class OnlyNeedInputsQuantTypeDeterminer : public QuantTypeDeterminer {
 public:
  bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
}  // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H

View File

@ -0,0 +1,144 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
#include <unordered_map>
#include <memory>
#include "mindspore/core/utils/log_adapter.h"
#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h"
#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h"
#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h"
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
namespace mindspore::lite {
void QuantNodeBase::UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node) {
  // Counts how many of the referenced tensors carry a non-empty quant-param
  // list in which every entry is already inited.
  auto count_fully_inited = [&graph](const auto &tensor_indices) {
    size_t inited_count = 0;
    for (auto index : tensor_indices) {
      auto &tensor = graph.allTensors.at(index);
      if (tensor->quantParams.empty()) {
        continue;
      }
      bool all_inited =
        !std::any_of(tensor->quantParams.begin(), tensor->quantParams.end(),
                     [](const std::unique_ptr<schema::QuantParamT> &quant_param) { return !quant_param->inited; });
      if (all_inited) {
        inited_count++;
      }
    }
    return inited_count;
  };
  input_inited_quant_params_ = count_fully_inited(node.inputIndex);
  output_inited_quant_params_ = count_fully_inited(node.outputIndex);
}
bool QuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  // Honor a quant type that was already decided upstream.
  if (node->quantType != schema::QuantType_QUANT_NONE) {
    return node->quantType == schema::QuantType_QUANT_ALL;
  }
  UpdateQuantParamsNum(graph, *node);
  const bool inputs_ready = input_inited_quant_params_ == node->inputIndex.size();
  const bool outputs_ready = output_inited_quant_params_ == node->outputIndex.size();
  if (!inputs_ready || !outputs_ready) {
    return false;
  }
  node->quantType = schema::QuantType_QUANT_ALL;
  return true;
}
// A node counts as weight-quantized only if an earlier pass already tagged it
// QUANT_WEIGHT; the base determiner never promotes a node by itself.
bool QuantTypeDeterminer::DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) {
  return node->quantType == schema::QuantType_QUANT_WEIGHT;
}
// Per-node quant preprocessing pipeline: weight-only-quant nodes are left as
// is; for everything else, quant params are propagated first and the node's
// quant type is then (re)determined from the updated tensors.
void QuantNodeHelper::NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node) {
  if (quant_type_determiner_->DetermineQuantWeight(*graph, node)) {
    return;
  }
  // NOTE(review): the propagator's STATUS is ignored — propagation looks
  // best-effort (a failure just leaves the quant type undetermined); confirm.
  quant_param_propogator_->PropogateQuantParams(graph, *node);
  quant_type_determiner_->DetermineQuantAll(*graph, node);
}
// Meyers singleton: the instance is constructed on first use and is
// thread-safe to initialize since C++11.
QuantHelperRegister *QuantHelperRegister::GetInstance() {
  static QuantHelperRegister instance;
  return &instance;
}
QuantNodeHelper *QuantHelperRegister::GetQuantHelper(schema::PrimitiveType op_type) {
  // Fall back to the generic helper registered under PrimitiveType_NONE when
  // no op-specific helper exists.
  const auto entry = register_map_.find(op_type);
  return entry != register_map_.end() ? entry->second : register_map_[schema::PrimitiveType_NONE];
}
QuantHelperRegister::QuantHelperRegister() {
  // Shared strategy objects; the helpers below mix and match them per op type.
  auto base_propogator = std::make_shared<QuantParamPropogator>();
  auto base_determiner = std::make_shared<QuantTypeDeterminer>();
  auto bias_add_propogator = std::make_shared<BiasAddQuantParamPropogator>();
  auto carry_data_propogator = std::make_shared<CarryDataQuantParamPropogator>();
  auto carry_data_determiner = std::make_shared<CarryDataQuantTypeDeterminer>();
  auto concat_propogator = std::make_shared<ConcatQuantParamPropogator>();
  auto conv_propogator = std::make_shared<ConvQuantParamPropogator>();
  auto conv_determiner = std::make_shared<ConvQuantTypeDeterminer>();
  auto default_quant_all_determiner = std::make_shared<DefaultQuantAllQuantTypeDeterminer>();
  auto only_need_inputs_determiner = std::make_shared<OnlyNeedInputsQuantTypeDeterminer>();
  // The register owns each QuantNodeHelper (deleted in the destructor).
  register_map_[schema::PrimitiveType_BiasAdd] = new QuantNodeHelper(bias_add_propogator, base_determiner);
  // Carry-data ops: quant params flow through unchanged.
  register_map_[schema::PrimitiveType_MaxPoolFusion] =
    new QuantNodeHelper(carry_data_propogator, carry_data_determiner);
  register_map_[schema::PrimitiveType_Resize] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner);
  register_map_[schema::PrimitiveType_Reshape] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner);
  register_map_[schema::PrimitiveType_StridedSlice] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner);
  register_map_[schema::PrimitiveType_Transpose] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner);
  register_map_[schema::PrimitiveType_PadFusion] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner);
  register_map_[schema::PrimitiveType_ReduceFusion] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner);
  register_map_[schema::PrimitiveType_Gather] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner);
  register_map_[schema::PrimitiveType_Concat] = new QuantNodeHelper(concat_propogator, base_determiner);
  // Conv-like ops support weight-only quantization.
  register_map_[schema::PrimitiveType_Conv2DFusion] = new QuantNodeHelper(conv_propogator, conv_determiner);
  register_map_[schema::PrimitiveType_MatMul] = new QuantNodeHelper(conv_propogator, conv_determiner);
  register_map_[schema::PrimitiveType_QuantDTypeCast] =
    new QuantNodeHelper(base_propogator, default_quant_all_determiner);
  register_map_[schema::PrimitiveType_DetectionPostProcess] =
    new QuantNodeHelper(base_propogator, only_need_inputs_determiner);
  // Catch-all entry used by GetQuantHelper for unregistered op types.
  register_map_[schema::PrimitiveType_NONE] = new QuantNodeHelper(base_propogator, base_determiner);
}
QuantHelperRegister::~QuantHelperRegister() {
  // The register owns every helper created in the constructor.
  for (auto &entry : register_map_) {
    delete entry.second;
    entry.second = nullptr;
  }
  register_map_.clear();
}
} // namespace mindspore::lite

View File

@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_NODE_PREPROCESSOR_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_NODE_PREPROCESSOR_H
#include <memory>
#include <unordered_map>
#include <utility>
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore::lite {
// Shared base for quant helpers: scans a node's tensors and caches how many
// inputs/outputs already carry fully inited quant params.
class QuantNodeBase {
 public:
  // Recomputes input_inited_quant_params_ / output_inited_quant_params_ for
  // the given node. A tensor counts only if its quantParams list is non-empty
  // and every entry has inited == true.
  void UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node);

 protected:
  size_t input_inited_quant_params_ = 0;
  size_t output_inited_quant_params_ = 0;
};
// Strategy interface: fills in missing tensor quant params for one node.
// The base implementation is a no-op that always succeeds.
class QuantParamPropogator : public QuantNodeBase {
 public:
  // Polymorphic base used through base-class handles: needs a virtual dtor.
  virtual ~QuantParamPropogator() = default;
  virtual STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) { return RET_OK; }
};
// Strategy interface: decides a node's quant type (QUANT_ALL / QUANT_WEIGHT)
// from the quant params carried by its tensors.
class QuantTypeDeterminer : public QuantNodeBase {
 public:
  // Polymorphic base used through base-class handles: needs a virtual dtor.
  virtual ~QuantTypeDeterminer() = default;
  virtual bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node);
  virtual bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node);
};
// Pairs a quant-param propagator with a quant-type determiner and runs them
// on a node (see NodeQuantPreprocess in the .cc).
class QuantNodeHelper {
 public:
  void NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node);
  // Take the strategies by value and move them: initializes the members
  // directly instead of default-constructing and then assigning.
  QuantNodeHelper(std::shared_ptr<QuantParamPropogator> quant_param_propogator,
                  std::shared_ptr<QuantTypeDeterminer> quant_type_determiner)
      : quant_param_propogator_(std::move(quant_param_propogator)),
        quant_type_determiner_(std::move(quant_type_determiner)) {}
  virtual ~QuantNodeHelper() = default;

 protected:
  std::shared_ptr<QuantParamPropogator> quant_param_propogator_;
  std::shared_ptr<QuantTypeDeterminer> quant_type_determiner_;
};
// Singleton registry mapping a primitive type to its quant helper. It OWNS
// the QuantNodeHelper raw pointers (deleted in the destructor), so copying
// must be forbidden — a default copy would double-delete the helpers.
class QuantHelperRegister {
 public:
  virtual ~QuantHelperRegister();
  QuantHelperRegister(const QuantHelperRegister &) = delete;
  QuantHelperRegister &operator=(const QuantHelperRegister &) = delete;
  // Returns the helper for op_type, or the PrimitiveType_NONE fallback.
  QuantNodeHelper *GetQuantHelper(schema::PrimitiveType op_type);
  static QuantHelperRegister *GetInstance();

 private:
  QuantHelperRegister();
  std::unordered_map<schema::PrimitiveType, QuantNodeHelper *> register_map_;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_NODE_PREPROCESSOR_H

View File

@ -73,7 +73,7 @@ STATUS WeightQuantizer::SetAbstract(const tensor::TensorPtr &tensor_info, const
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
abstract_tensor->element()->set_type(TypeIdToType(type_id_));
auto quant_param_holder = GetCNodeQuantHolder(primitive);
quant_param_holder->set_quant_type(schema::QuantType_WeightQuant);
quant_param_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
return RET_OK;
}

View File

@ -296,6 +296,7 @@ bool MindirAdjustPass::Run(const FuncGraphPtr &graph) {
return lite::RET_OK;
}
MS_ASSERT(graph != nullptr);
graph_ = graph;
auto node_list = TopoSort(graph->get_return());
int status = lite::RET_OK;
bool success_flag = true;

View File

@ -31,15 +31,17 @@ class MindirAdjustPass : public Pass {
~MindirAdjustPass() override = default;
void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
int ValueNodeInt64Convert(AnfNodePtr anf_node);
void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; }
int ComputeQuantParams(AnfNodePtr anf_node);
bool Run(const FuncGraphPtr &graph) override;
protected:
int ValueNodeInt64Convert(AnfNodePtr anf_node);
int ComputeQuantParams(AnfNodePtr anf_node);
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
FmkType fmk_type_ = FmkType::FmkType_MS;
bool train_flag_ = false;
FuncGraphPtr graph_;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_

View File

@ -26,7 +26,9 @@ using mindspore::lite::converter::FmkType_TF;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_ALL;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_QUANT_WEIGHT;
using mindspore::schema::QuantType_WeightQuant;
namespace mindspore::opt {
namespace {
@ -122,9 +124,10 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const CNodePtr &conv_node,
}
auto weight_node = conv_node->input(kConvWeightIndex);
switch (this->quant_type) {
case QuantType_AwareTraining:
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_PostTraining:
case QuantType_QUANT_WEIGHT:
case QuantType_QUANT_ALL:
case QuantType_QUANT_NONE: {
// sum up from current ms quant models
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));