!4469 change param_t_ value quant params

Merge pull request !4469 from cjh9368/aware_quant
mindspore-ci-bot 2020-08-15 18:00:06 +08:00 committed by Gitee
commit ed9e62c760
16 changed files with 64 additions and 70 deletions
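In short: across the 16 files, this commit (a) renames the misspelled enum value AwareTrainning to AwareTraining, and (b) widens PrimitiveTValue's quant-param storage from one schema::QuantParamT per input/output tensor to one std::vector<schema::QuantParamT> per tensor, making room for per-channel quantization parameters. Below is a minimal sketch of the shape of that API change; the types are simplified stand-ins, not the real MindSpore headers (QuantParamT here keeps only the two fields the diff's log statements mention).

// Stand-in types for illustration only; the real schema::QuantParamT has
// more fields (min, max, narrowRange, inited, numBits, ...).
#include <iostream>
#include <utility>
#include <vector>

struct QuantParamT {
  double scale = 1.0;
  int zeroPoint = 0;
};

class PrimitiveTValue {
 public:
  // After this commit: one vector of params per input tensor, so a
  // per-channel quantized tensor can carry one QuantParamT per channel.
  void AddInputQuantParam(std::vector<QuantParamT> quant_param) {
    input_quant_param_.emplace_back(std::move(quant_param));
  }
  std::vector<std::vector<QuantParamT>> GetInputQuantParams() const {
    return input_quant_param_;
  }

 private:
  // Before this commit: std::vector<QuantParamT> input_quant_param_;
  std::vector<std::vector<QuantParamT>> input_quant_param_;
};

int main() {
  PrimitiveTValue prim;
  prim.AddInputQuantParam({{0.02, 128}});                       // per-tensor
  prim.AddInputQuantParam({{0.01, 0}, {0.015, 0}, {0.02, 0}});  // per-channel
  std::cout << prim.GetInputQuantParams()[0].front().scale << std::endl;  // 0.02
  return 0;
}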

View File

@@ -197,7 +197,7 @@ union PrimitiveType {
enum QuantType: int {
QUANT_NONE,
- AwareTrainning,
+ AwareTraining,
WeightQuant,
PostTraining
}

View File

@@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
// add quant param
node->quantType = primitiveT_value->GetQuantType();
- if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) {
+ if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTraining) {
MS_LOG(INFO) << "node: " << node->name << " add QuantParam";
// activation
auto input_quant_params = primitiveT_value->GetInputQuantParams();
@@ -202,14 +202,12 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
auto activate_index = node->inputIndex[i];
auto tensor_input = metaGraphT->allTensors[activate_index].get();
if (tensor_input->quantParams.empty()) {
- std::unique_ptr<schema::QuantParamT> input_quant_param =
- std::make_unique<schema::QuantParamT>(input_quant_params[i]);
- MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
- << " zp: " << input_quant_param->zeroPoint;
- tensor_input->quantParams.emplace_back(std::move(input_quant_param));
- if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
- primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
- tensor_input->dataType = kNumberTypeInt8;
- }
+ for (auto input_quant_param : input_quant_params[i]) {
+ std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
+ std::make_unique<schema::QuantParamT>(input_quant_param);
+ MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param_ptr->scale
+ << " zp: " << input_quant_param_ptr->zeroPoint;
+ tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
+ }
}
}
@@ -221,15 +219,18 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
if (output_quant_params.empty()) {
MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
} else {
+ for (auto output_quant_param : output_quant_params[0]) {
if (tensor_output->quantParams.empty()) {
- std::unique_ptr<schema::QuantParamT> output_quant_param =
- std::make_unique<schema::QuantParamT>(output_quant_params[0]);
- MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
- << " zp: " << output_quant_param->zeroPoint;
- tensor_output->quantParams.emplace_back(std::move(output_quant_param));
+ std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
+ std::make_unique<schema::QuantParamT>(output_quant_param);
+ MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << output_quant_param_ptr->scale
+ << " zp: " << output_quant_param_ptr->zeroPoint;
+ tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
}
}
- if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
+ }
+ if (node->quantType != schema::QuantType_AwareTraining &&
+ !(node_type == schema::PrimitiveType_QuantDTypeCast &&
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) {
tensor_output->dataType = kNumberTypeInt8;
}
@@ -322,18 +323,6 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
paramTensor->nodeType = schema::NodeType_ValueNode;
paramTensor->data.resize(paramValue->tensor_size());
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
- for (auto &ite : paramValue->quant_param()) {
- auto quantPar = std::make_unique<schema::QuantParamT>();
- quantPar->scale = ite->scale;
- quantPar->zeroPoint = ite->zeroPoint;
- quantPar->min = ite->min;
- quantPar->max = ite->max;
- quantPar->narrowRange = ite->narrowRange;
- quantPar->inited = ite->inited;
- quantPar->numBits = ite->numBits;
- paramTensor->quantParams.emplace_back(std::move(quantPar));
- paramTensor->dataType = paramValue->tensor_type();
- }
}
nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size();
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
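The two Export hunks above replace the single-param write with a loop that copies every element of the tensor's param vector. A condensed sketch of the new loop shape, again with hypothetical stand-in types rather than the real headers:

#include <memory>
#include <vector>

struct QuantParamT { double scale = 1.0; int zeroPoint = 0; };
struct TensorT { std::vector<std::unique_ptr<QuantParamT>> quantParams; };

// Condensed from the updated input handling: one owned copy per param, so a
// per-channel tensor keeps all of its params instead of only the first.
void CopyQuantParams(const std::vector<QuantParamT> &params, TensorT *tensor) {
  if (!tensor->quantParams.empty()) {
    return;  // tensor already annotated
  }
  for (const auto &param : params) {
    tensor->quantParams.emplace_back(std::make_unique<QuantParamT>(param));
  }
}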

View File

@@ -225,7 +225,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim,
PopulaterConv2DSingleGroup(prim, primitive, group);
}
primitiveTValuePtr->SetPrimitiveT(primitive.release());
- if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTrainning) {
+ if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);

View File

@@ -89,13 +89,15 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
}
auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release());
// add quant parameter
- if (cNode->quantType == schema::QuantType_AwareTrainning) {
+ if (cNode->quantType == schema::QuantType_AwareTraining) {
primTValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) {
- primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
+ std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
+ primTValue->AddInputQuantParam(quant_params);
}
for (int index : cNode->outputIndex) {
- primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
+ std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
+ primTValue->AddOutputQuantParam(quant_params);
}
}
cNode->primitive = nullptr;

View File

@@ -49,17 +49,17 @@ class PrimitiveTValue : public Value {
void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {
}
- void AddInputQuantParam(schema::QuantParamT quant_param) {
+ void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
this->input_quant_param_.emplace_back(quant_param);
}
- std::vector<schema::QuantParamT> GetInputQuantParams() const {
+ std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const {
return input_quant_param_;
}
- void AddOutputQuantParam(schema::QuantParamT quant_param) {
+ void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) {
this->output_quant_param_.emplace_back(quant_param);
}
- std::vector<schema::QuantParamT> GetOutputQuantParams() const {
+ std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const {
return output_quant_param_;
}
@@ -69,8 +69,8 @@ class PrimitiveTValue : public Value {
protected:
schema::PrimitiveT *primitive = nullptr;
- std::vector<schema::QuantParamT> input_quant_param_;
- std::vector<schema::QuantParamT> output_quant_param_;
+ std::vector<std::vector<schema::QuantParamT>> input_quant_param_;
+ std::vector<std::vector<schema::QuantParamT>> output_quant_param_;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
};
} // namespace mindspore::lite
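With the storage now a vector of vectors, single-param call sites adapt in two mirror-image ways, both visible in the hunks that follow: wrap one QuantParamT in a one-element vector before calling Add*QuantParam, and call .front() where only a per-tensor value is expected. A minimal sketch of both patterns, with the same simplified stand-in types as above:

#include <cstddef>
#include <vector>

struct QuantParamT { double scale = 1.0; int zeroPoint = 0; };

// Write side, as in the updated ConverterCNode and DoQuantInput/DoQuantOutput:
std::vector<QuantParamT> WrapSingle(const QuantParamT &param) {
  return {param};  // one-element vector per tensor
}

// Read side, as in the updated DoBiasQuant, which now reads
// quant_params[i].front().scale instead of quant_params[i].scale:
double FirstScale(const std::vector<std::vector<QuantParamT>> &quant_params,
                  std::size_t i) {
  return quant_params[i].front().scale;
}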

View File

@@ -131,7 +131,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) {
auto type = flags->quantType;
switch (type) {
- case mindspore::schema::QuantType_AwareTrainning: {
+ case mindspore::schema::QuantType_AwareTraining: {
// mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean));
break;
}

View File

@@ -31,7 +31,7 @@ Flags::Flags() {
"Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
AddFlag(&Flags::inferenceType, "inferenceType",
"Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT");
- AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTrainning | WeightQuant | PostTraining", "");
+ AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", "");
AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT");
AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127");
@@ -98,8 +98,8 @@ int Flags::Init(int argc, const char **argv) {
std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag";
return 1;
}
if (this->quantTypeIn == "AwareTrainning") {
this->quantType = QuantType_AwareTrainning;
if (this->quantTypeIn == "AwareTraining") {
this->quantType = QuantType_AwareTraining;
} else if (this->quantTypeIn == "WeightQuant") {
this->quantType = QuantType_WeightQuant;
} else if (this->quantTypeIn == "PostTraining") {
@@ -107,7 +107,7 @@ int Flags::Init(int argc, const char **argv) {
} else if (this->quantTypeIn.empty()) {
this->quantType = QuantType_QUANT_NONE;
} else {
std::cerr << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining";
std::cerr << "INPUT ILLEGAL: quantType must be AwareTraining|WeightQuant|PostTraining";
return 1;
}

View File

@@ -27,7 +27,7 @@ namespace lite {
using mindspore::schema::QuantType;
using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_NONE;
- using mindspore::schema::QuantType_AwareTrainning;
+ using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_WeightQuant;
using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_PostTraining;

View File

@@ -68,8 +68,8 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _
void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {
auto type = flags->quantType;
switch (type) {
- case QuantType::QuantType_AwareTrainning: {
- MS_LOG(INFO) << "create AwareTrainningQuantizer!";
+ case QuantType::QuantType_AwareTraining: {
+ MS_LOG(INFO) << "create AwareTrainingQuantizer!";
fbQuantizer =
std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean);
break;
@@ -146,7 +146,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
return status;
}
if (!(this->graphDefT->fmkType == converter::FmkType_TF &&
- this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) {
+ this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) {
status = mQuantizer->GenerateQuantParam();
if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateQuantParam failed";
@@ -173,7 +173,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
- // if (ctx.quantType == QuantType_AwareTrainning) {
+ // if (ctx.quantType == QuantType_AwareTraining) {
// formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass());
// }
status = formatTransOptimizer.Run(graphDefT);
@@ -193,7 +193,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
// insert quantNode and deQuantNode
- if (ctx.quantType == QuantType_AwareTrainning) {
+ if (ctx.quantType == QuantType_AwareTraining) {
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {

View File

@@ -136,7 +136,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// insert transNode before and after existNode
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
- if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) {
+ if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
continue;
}
auto &node = *iter;
@@ -208,7 +208,7 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
transNode->primitive = std::make_unique<schema::PrimitiveT>();
transNode->primitive->value.value = quantDTypeCastParam;
transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
- transNode->quantType = QuantType_AwareTrainning;
+ transNode->quantType = QuantType_AwareTraining;
if (nodeType == kInt8ToFP32) {
quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32;

View File

@@ -103,7 +103,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
FormatTransNodeType beforeNodeType, afterNodeType;
if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc
- // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use
+ // if (quantType == QuantType_AwareTraining) { // AwareTraining op use
// nhwc
// if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only
// support nhwc
@@ -120,7 +120,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
// beforeNodeType = kNCHW2NHWC;
// afterNodeType = kNHWC2NCHW;
} else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw
- // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc
+ // if (quantType == QuantType_AwareTraining) { // AwareTraining op use nhwc
// if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc
// continue;
// }

View File

@@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status;
return status;
}
- if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) {
+ if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_PostTraining) {
status = QuantDataFormatTrans(graphNode);
if (status != 0) {
MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
@@ -96,7 +96,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
return 0;
} else if (fmkType == converter::FmkType_MS) {
switch (node->quantType) {
- case QuantType_AwareTrainning: {
+ case QuantType_AwareTraining: {
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) {
weightTensor->format = schema::Format_HWCK;
} else {
@@ -123,7 +123,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
return 0;
} else if (fmkType == converter::FmkType_TF) {
switch (node->quantType) {
- case QuantType_AwareTrainning: {
+ case QuantType_AwareTraining: {
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) {
weightTensor->format = schema::Format_HWCK;
} else {
@@ -148,7 +148,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
} else if (fmkType == converter::FmkType_TFLITE) {
switch (node->quantType) {
case QuantType_QUANT_NONE:
- case QuantType_AwareTrainning:
+ case QuantType_AwareTraining:
case QuantType_PostTraining: {
if (opType == schema::PrimitiveType_Conv2D) {
weightTensor->format = schema::Format_KHWC;
@@ -170,7 +170,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
return 0;
} else if (fmkType == converter::FmkType_ONNX) {
switch (node->quantType) {
- case QuantType_AwareTrainning: {
+ case QuantType_AwareTraining: {
// sum up from current onnx quant models
if (opType == schema::PrimitiveType_Conv2D) {
weightTensor->format = schema::Format_KHWC;

View File

@@ -314,7 +314,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
}
}
if (findQuantParams == needQuantParams) {
- dst_op->quantType = schema::QuantType_AwareTrainning;
+ dst_op->quantType = schema::QuantType_AwareTraining;
} else {
dst_op->quantType = schema::QuantType_QUANT_NONE;
}

View File

@@ -324,7 +324,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE;
} else {
- node->quantType = schema::QuantType_AwareTrainning;
+ node->quantType = schema::QuantType_AwareTraining;
}
}
}
@@ -337,7 +337,7 @@ STATUS AwareQuantizer::DoQuantize() {
if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
continue;
}
- if (node->quantType != schema::QuantType_AwareTrainning) {
+ if (node->quantType != schema::QuantType_AwareTraining) {
continue;
}
STATUS status;
@@ -584,7 +584,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
}
}
if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
- node->quantType = schema::QuantType_AwareTrainning;
+ node->quantType = schema::QuantType_AwareTraining;
} else {
node->quantType = schema::QuantType_QUANT_NONE;
}

View File

@@ -509,7 +509,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M
quant_param.min = max_min->min;
quant_param.numBits = bit_num;
quant_param.narrowRange = false;
- lite_primitive->AddInputQuantParam(quant_param);
+ std::vector<schema::QuantParamT> quant_params = {quant_param};
+ lite_primitive->AddInputQuantParam(quant_params);
// p->AddAttr("quant_input_dataType", MakeValue((int)DataType_DT_FLOAT));
return RET_OK;
}
@@ -526,7 +527,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
quant_param.min = max_min->min;
quant_param.numBits = bit_num;
quant_param.narrowRange = false;
- lite_primitive->AddOutputQuantParam(quant_param);
+ std::vector<schema::QuantParamT> quant_params = {quant_param};
+ lite_primitive->AddOutputQuantParam(quant_params);
// p->AddAttr("quant_output_dataType", MakeValue((int)DataType_DT_FLOAT));
return RET_OK;
}
@@ -569,7 +571,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input
auto quant_params = input->GetInputQuantParams();
size_t sizeX = quant_params.size();
for (size_t i = 0; i < sizeX; i++) {
- input_scales.emplace_back(quant_params[i].scale);
+ input_scales.emplace_back(quant_params[i].front().scale);
}
size_t sizeY = weight_param->quant_param().size();
if (sizeX != sizeY) {

View File

@@ -31,7 +31,8 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector
auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release());
primTValue->SetQuantType(schema::QuantType_PostTraining);
for (auto &quant_param : quant_params) {
- primTValue->AddInputQuantParam(quant_param);
+ std::vector<schema::QuantParamT> quant_params_in = {quant_param};
+ primTValue->AddInputQuantParam(quant_params_in);
}
return NewValueNode(primTValue);
}
@@ -53,7 +54,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
if (first) {
if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
auto value_node =
- NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams());
+ NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front());
std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
auto quant_cast_cnode = graph->NewCNode(op_inputs);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
@@ -84,11 +85,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
if (curnode_quant_type == schema::QuantType_PostTraining &&
input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
- primitiveT_value->GetInputQuantParams());
+ primitiveT_value->GetInputQuantParams().front());
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
input_cnode_quant_type == schema::QuantType_PostTraining) {
value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
- input_cnode_primitiveT_value->GetInputQuantParams());
+ input_cnode_primitiveT_value->GetInputQuantParams().front());
}
if (value_node == nullptr) {
MS_LOG(WARNING) << "value_node is null! "