From 4757565686902bc101516959735d8e3a645bb892 Mon Sep 17 00:00:00 2001 From: yankai Date: Thu, 13 Aug 2020 16:28:41 +0800 Subject: [PATCH] add quant param --- .../src/common/anf_importer/anf_importer.cc | 4 +- .../src/common/anf_importer/anf_importer.h | 3 +- .../anf_populater/anf_conv_populater.cc | 290 ++++++++++++----- .../anf_populater/anf_conv_populater.h | 18 +- .../anf_depthwiseconv2d_populater.cc | 131 +++++++- .../anf_depthwiseconv2d_populater.h | 8 +- .../anf_populater/anf_matmul_populater.cc | 102 +++++- .../anf_populater/anf_matmul_populater.h | 5 + .../anf_importer/import_from_protobuf.cc | 298 ++++++++++++------ .../anf_importer/import_from_protobuf.h | 69 ++-- mindspore/lite/src/ir/primitive_t_value.h | 3 + mindspore/lite/src/ir/tensor.h | 3 + mindspore/lite/tools/converter/converter.cc | 2 +- 13 files changed, 698 insertions(+), 238 deletions(-) diff --git a/mindspore/lite/src/common/anf_importer/anf_importer.cc b/mindspore/lite/src/common/anf_importer/anf_importer.cc index eb9f84eca33..2921f9422ba 100644 --- a/mindspore/lite/src/common/anf_importer/anf_importer.cc +++ b/mindspore/lite/src/common/anf_importer/anf_importer.cc @@ -27,7 +27,7 @@ #include "abstract/abstract_value.h" #include "src/ir/primitive_value.h" #include "include/errorcode.h" - +#include "schema/inner/model_generated.h" namespace mindspore { namespace lite { #if 0 @@ -159,7 +159,7 @@ void MinnieBuildGraph::FbTest(const GraphDef *graph_def) { } #endif -int AnfImporter::Import() { +int AnfImporter::Import(const schema::QuantType &quantType) { ConverterConstTensor(); auto ret = ConverterCNode(); if (RET_OK != ret) { diff --git a/mindspore/lite/src/common/anf_importer/anf_importer.h b/mindspore/lite/src/common/anf_importer/anf_importer.h index 3281294f409..87e0edd3dc7 100644 --- a/mindspore/lite/src/common/anf_importer/anf_importer.h +++ b/mindspore/lite/src/common/anf_importer/anf_importer.h @@ -21,6 +21,7 @@ #include "ir/func_graph.h" #include "ir/anf.h" #include "base/base.h" +#include "schema/inner/model_generated.h" namespace mindspore::lite { class AnfImporter { @@ -29,7 +30,7 @@ class AnfImporter { virtual ~AnfImporter() = default; - virtual int Import(); + virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE); virtual FuncGraphPtr GetResult() = 0; diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc index c662f832b82..29407f9c8e4 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc @@ -1,5 +1,6 @@ /** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * This is the C++ adaptation and derivative work of Myia + * (https://github.com/mila-iqia/myia/). 
* * Copyright 2019 Huawei Technologies Co., Ltd * @@ -17,101 +18,218 @@ */ #include "src/common/anf_importer/anf_populater/anf_conv_populater.h" + +#include + +#include #include #include -#include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" + #include "ir/func_graph.h" #include "ir/primitive.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "src/ir/tensor.h" +#include "tools/converter/quantizer/quantize_util.h" namespace mindspore::lite { -int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, - const std::vector &inputs) { - int group = GetValue(prim->GetAttr("group")); - auto primitive = std::make_unique(); - if (group > 1) { - auto attr = std::make_unique(); - auto format = GetValue(prim->GetAttr("data_format")); - if (format == "NCHW") { - attr->format = schema::Format_NCHW; - } else if (format == "NHWC") { - attr->format = schema::Format_NHWC; - } else { - attr->format = schema::Format_NUM_OF_FORMAT; - } - auto pad_list = GetValue>(prim->GetAttr("pad_list")); - attr->padUp = pad_list[0]; - attr->padDown = pad_list[1]; - attr->padLeft = pad_list[2]; - attr->padRight = pad_list[3]; - - auto dilation = GetValue>(prim->GetAttr("dilation")); - attr->dilateH = dilation[0]; - attr->dilateW = dilation[1]; - - auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); - attr->kernelH = kernel_size[0]; - attr->kernelW = kernel_size[1]; - - auto stride = GetValue>(prim->GetAttr("stride")); - attr->strideH = stride[2]; - attr->strideW = stride[3]; - - auto pad_mode = GetValue(prim->GetAttr("pad_mode")); - if (pad_mode == "valid") { - attr->padMode = schema::PadMode_VALID; - } else if (pad_mode == "same") { - attr->padMode = schema::PadMode_SAME; - } else { - attr->padMode = schema::PadMode_NOTSET; - } - - primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; - primitive->value.value = attr.release(); +void AnfConvPopulater::PopulaterConv2DMultiGroup( + const PrimitivePtr &prim, + const std::unique_ptr &primitive, const int &group) { + auto attr = std::make_unique(); + auto format = GetValue(prim->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; } else { - auto attr = std::make_unique(); - attr->group = group; - auto format = GetValue(prim->GetAttr("data_format")); - if (format == "NCHW") { - attr->format = schema::Format_NCHW; - } else if (format == "NHWC") { - attr->format = schema::Format_NHWC; - } else { - attr->format = schema::Format_NUM_OF_FORMAT; - } - auto pad_list = GetValue>(prim->GetAttr("pad_list")); - attr->padUp = pad_list[0]; - attr->padDown = pad_list[1]; - attr->padLeft = pad_list[2]; - attr->padRight = pad_list[3]; - - auto dilation = GetValue>(prim->GetAttr("dilation")); - attr->dilateH = dilation[0]; - attr->dilateW = dilation[1]; - - auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); - attr->kernelH = kernel_size[0]; - attr->kernelW = kernel_size[1]; - - auto stride = GetValue>(prim->GetAttr("stride")); - attr->strideH = stride[2]; - attr->strideW = stride[3]; - - attr->channelOut = GetValue(prim->GetAttr("out_channel")); - - auto pad_mode = GetValue(prim->GetAttr("pad_mode")); - if (pad_mode == "valid") { - attr->padMode = schema::PadMode_VALID; - } else if (pad_mode == "same") { - attr->padMode = schema::PadMode_SAME; - } else { - attr->padMode = schema::PadMode_NOTSET; - } - primitive->value.type = schema::PrimitiveType_Conv2D; - 
primitive->value.value = attr.release(); + attr->format = schema::Format_NUM_OF_FORMAT; } + auto pad_list = GetValue>(prim->GetAttr("pad_list")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(prim->GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(prim->GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + auto pad_mode = GetValue(prim->GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + primitive->value.value = attr.release(); +} + +void AnfConvPopulater::PopulaterConv2DSingleGroup( + const PrimitivePtr &prim, + const std::unique_ptr &primitive, const int &group) { + auto attr = std::make_unique(); + attr->group = group; + auto format = GetValue(prim->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(prim->GetAttr("pad_list")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(prim->GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(prim->GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + attr->channelOut = GetValue(prim->GetAttr("out_channel")); + + auto pad_mode = GetValue(prim->GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + primitive->value.type = schema::PrimitiveType_Conv2D; + primitive->value.value = attr.release(); +} + +void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, + float *mMin, float *mMax) { + constexpr float qmin = 0; + constexpr float qmax = 255; + *mMin = static_cast((qmin - mean) / stdDev); + *mMax = static_cast((qmax - mean) / stdDev); +} + +void AnfConvPopulater::PopulaterQuantParam( + const PrimitivePtr &prim, + std::vector> *vecQuantParam) { + auto narrow_range = prim->GetAttr("narrow_range"); + bool narrowRangeQuantParam = GetValue(narrow_range); + auto num_bits = prim->GetAttr("num_bits"); + int32_t numbitsRangeQuantParam = GetValue(num_bits); + + std::vector quants; + schema::QuantParamT quantParam; + auto mean = prim->GetAttr("mean"); + auto std_dev = prim->GetAttr("std_dev"); + if (mean != nullptr && std_dev != nullptr) { + auto meanQuantOaram = GetValue(mean); + double stddevQuantOaram = GetValue(std_dev); + float mMin = 0.0; + float mMax = 0.0; + CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); + quantParam.min = mMin; + quantParam.max = mMax; + } else { + auto inputMin = prim->GetAttr("input_minq"); + auto inputMax = prim->GetAttr("input_maxq"); + auto inputMinPtr 
= inputMin->cast(); + auto inputMaxPtr = inputMax->cast(); + float *minBuf = static_cast(inputMinPtr->Data()); + float *maxBuf = static_cast(inputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + } + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, + narrowRangeQuantParam, numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecQuantParam->emplace_back(quants); + + quants.clear(); + int biasQuantSize = 0; + auto filterMin = prim->GetAttr("filter_minq"); + auto filterMax = prim->GetAttr("filter_maxq"); + if (filterMin != nullptr && filterMax != nullptr) { + auto filterMinPtr = filterMin->cast(); + auto filterMaxPtr = filterMax->cast(); + float *minBuf = static_cast(filterMinPtr->Data()); + float *maxBuf = static_cast(filterMaxPtr->Data()); + biasQuantSize = filterMinPtr->DataSize(); + for (int i = 0; i < biasQuantSize; ++i) { + quantParam.min = *(minBuf++); + quantParam.max = *(maxBuf++); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, + narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + } + vecQuantParam->emplace_back(quants); + } + + quants.clear(); + for (int i = 0; i < biasQuantSize; ++i) { + quantParam.min = 0.0; + quantParam.max = 0.0; + quantParam.zeroPoint = 0; + quantParam.scale = + vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; + quants.emplace_back(quantParam); + } + vecQuantParam->emplace_back(quants); + + quants.clear(); + auto outputMin = prim->GetAttr("output_minq"); + auto outputMax = prim->GetAttr("output_maxq"); + if (outputMin != nullptr && outputMax != nullptr) { + auto outputMinPtr = outputMin->cast(); + auto outputMaxPtr = outputMax->cast(); + float *minBuf = static_cast(outputMinPtr->Data()); + float *maxBuf = static_cast(outputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, + narrowRangeQuantParam, numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecQuantParam->emplace_back(quants); + } +} + +int AnfConvPopulater::Populate(const PrimitivePtr &prim, + PrimitiveTValue *primitiveTValuePtr, + const std::vector &inputs) { MS_ASSERT(primitiveTValuePtr != nullptr); + auto primitive = std::make_unique(); + + int group = GetValue(prim->GetAttr("group")); + if (group > 1) { + PopulaterConv2DMultiGroup(prim, primitive, group); + } else { + PopulaterConv2DSingleGroup(prim, primitive, group); + } primitiveTValuePtr->SetPrimitiveT(primitive.release()); + if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTrainning) { + std::vector> vecQuantParam; + PopulaterQuantParam(prim, &vecQuantParam); + primitiveTValuePtr->SetInputQuantParam(vecQuantParam); + } return 0; } AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater()); diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h index 5614f4c7cc4..eb2905a8bbc 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h @@ -1,5 +1,6 @@ /** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * This is the C++ adaptation and derivative work of Myia + * (https://github.com/mila-iqia/myia/). 
* * Copyright 2019 Huawei Technologies Co., Ltd * @@ -18,8 +19,9 @@ #ifndef MINDSPORE_ANF_CONV_PARSER_H #define MINDSPORE_ANF_CONV_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include +#include +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" namespace mindspore::lite { class AnfConvPopulater : public AnfNodePopulater { public: @@ -27,6 +29,18 @@ class AnfConvPopulater : public AnfNodePopulater { ~AnfConvPopulater() override = default; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) override; + + private: + void PopulaterConv2DMultiGroup( + const PrimitivePtr &prim, + const std::unique_ptr &primitive, const int &group); + void PopulaterConv2DSingleGroup( + const PrimitivePtr &prim, + const std::unique_ptr &primitive, const int &group); + void PopulaterQuantParam(const PrimitivePtr &prim, + std::vector> *vecQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, + float *mMax); }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc index b13bc6c8225..9db2805947f 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc @@ -14,15 +14,113 @@ * limitations under the License. */ #include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" -#include -#include + #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include +#include + #include "ir/func_graph.h" #include "ir/primitive.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "src/ir/tensor.h" +#include "tools/converter/quantizer/quantize_util.h" namespace mindspore::lite { -int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, + const double &stdDev, float *mMin, + float *mMax) { + constexpr float qmin = 0; + constexpr float qmax = 255; + *mMin = static_cast((qmin - mean) / stdDev); + *mMax = static_cast((qmax - mean) / stdDev); +} + +void AnfDepwiseconv2DPopulater::PopulaterQuantParam( + const PrimitivePtr &prim, + std::vector> *vecQuantParam) { + auto narrow_range = prim->GetAttr("narrow_range"); + bool narrowRangeQuantParam = GetValue(narrow_range); + auto num_bits = prim->GetAttr("num_bits"); + int32_t numbitsRangeQuantParam = GetValue(num_bits); + + std::vector quants; + schema::QuantParamT quantParam; + auto mean = prim->GetAttr("mean"); + auto std_dev = prim->GetAttr("std_dev"); + if (mean != nullptr && std_dev != nullptr) { + auto meanQuantOaram = GetValue(mean); + double stddevQuantOaram = GetValue(std_dev); + float mMin = 0.0; + float mMax = 0.0; + CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); + quantParam.min = mMin; + quantParam.max = mMax; + } else { + auto inputMin = prim->GetAttr("input_minq"); + auto inputMax = prim->GetAttr("input_maxq"); + auto inputMinPtr = inputMin->cast(); + auto inputMaxPtr = inputMax->cast(); + float *minBuf = static_cast(inputMinPtr->Data()); + float *maxBuf = static_cast(inputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + } + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, + 
narrowRangeQuantParam, numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecQuantParam->emplace_back(quants); + + quants.clear(); + int biasQuantSize = 0; + auto filterMin = prim->GetAttr("filter_minq"); + auto filterMax = prim->GetAttr("filter_maxq"); + if (filterMin != nullptr && filterMax != nullptr) { + auto filterMinPtr = filterMin->cast(); + auto filterMaxPtr = filterMax->cast(); + float *minBuf = static_cast(filterMinPtr->Data()); + float *maxBuf = static_cast(filterMaxPtr->Data()); + biasQuantSize = filterMinPtr->DataSize(); + for (int i = 0; i < biasQuantSize; ++i) { + quantParam.min = *(minBuf++); + quantParam.max = *(maxBuf++); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, + narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + } + vecQuantParam->emplace_back(quants); + } + + quants.clear(); + for (int i = 0; i < biasQuantSize; ++i) { + quantParam.min = 0.0; + quantParam.max = 0.0; + quantParam.zeroPoint = 0; + quantParam.scale = + vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; + quants.emplace_back(quantParam); + } + vecQuantParam->emplace_back(quants); + + quants.clear(); + auto outputMin = prim->GetAttr("output_minq"); + auto outputMax = prim->GetAttr("output_maxq"); + if (outputMin != nullptr && outputMax != nullptr) { + auto outputMinPtr = outputMin->cast(); + auto outputMaxPtr = outputMax->cast(); + float *minBuf = static_cast(outputMinPtr->Data()); + float *maxBuf = static_cast(outputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, + narrowRangeQuantParam, numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecQuantParam->emplace_back(quants); + } +} + +int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, + PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -36,9 +134,9 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu attr->format = schema::Format_NUM_OF_FORMAT; } auto pad_list = GetValue>(prim->GetAttr("pads")); - attr->padUp = pad_list[0]; - attr->padDown = pad_list[1]; - attr->padLeft = pad_list[2]; + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; attr->padRight = pad_list[3]; auto dilation = GetValue>(prim->GetAttr("dilation")); @@ -73,10 +171,13 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu auto abstractBase = paramNode->abstract(); MS_ASSERT(abstractBase != nullptr); if (utils::isa(abstractBase)) { - auto abstractTensor = utils::cast(abstractBase); + auto abstractTensor = + utils::cast(abstractBase); MS_ASSERT(abstractTensor != nullptr); if (utils::isa(abstractTensor->BuildShape())) { - auto dims = utils::cast(abstractTensor->BuildShape())->shape(); + auto dims = + utils::cast(abstractTensor->BuildShape()) + ->shape(); attr->channelIn = dims[kAnfPopulaterOne]; } } @@ -86,8 +187,16 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu primitive->value.value = attr.release(); MS_ASSERT(primitiveTValuePtr != nullptr); primitiveTValuePtr->SetPrimitiveT(primitive.release()); + + if (primitiveTValuePtr->GetQuantType()) { + std::vector> vecQuantParam; + PopulaterQuantParam(prim, &vecQuantParam); + primitiveTValuePtr->SetInputQuantParam(vecQuantParam); + } return 0; } -AnfNodePopulaterRegistrar 
anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); -AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); +AnfNodePopulaterRegistrar anfdepthwise2dPopulater( + "DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); +AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater( + "DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h index c9b63e710db..6377ea372f7 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h @@ -15,8 +15,9 @@ */ #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" #include + +#include "src/common/anf_importer/anf_populater/anf_node_populater.h" namespace mindspore::lite { class AnfDepwiseconv2DPopulater : public AnfNodePopulater { public: @@ -24,6 +25,11 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater { ~AnfDepwiseconv2DPopulater() override = default; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) override; + private: + void PopulaterQuantParam(const PrimitivePtr &prim, + std::vector> *vecQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, + float *mMax); }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc index 109f7dea7ae..b6bb8908562 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc @@ -14,14 +14,98 @@ * limitations under the License. 
*/ #include "src/common/anf_importer/anf_populater/anf_matmul_populater.h" -#include + #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include + #include "ir/func_graph.h" #include "ir/primitive.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "src/ir/tensor.h" +#include "tools/converter/quantizer/quantize_util.h" namespace mindspore::lite { -int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev, + float *mMin, float *mMax) { + constexpr float qmin = 0; + constexpr float qmax = 255; + *mMin = static_cast((qmin - mean) / stdDev); + *mMax = static_cast((qmax - mean) / stdDev); +} + +void AnfMatmulPopulater::PopulaterQuantParam( + const PrimitivePtr &prim, + std::vector> *vecQuantParam) { + auto narrow_range = prim->GetAttr("narrow_range"); + bool narrowRangeQuantParam = GetValue(narrow_range); + auto num_bits = prim->GetAttr("num_bits"); + int32_t numbitsRangeQuantParam = GetValue(num_bits); + + std::vector quants; + schema::QuantParamT quantParam; + auto mean = prim->GetAttr("mean"); + auto std_dev = prim->GetAttr("std_dev"); + if (mean != nullptr && std_dev != nullptr) { + auto meanQuantOaram = GetValue(mean); + double stddevQuantOaram = GetValue(std_dev); + float mMin = 0.0; + float mMax = 0.0; + CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); + quantParam.min = mMin; + quantParam.max = mMax; + } else { + auto inputMin = prim->GetAttr("input_minq"); + auto inputMax = prim->GetAttr("input_maxq"); + auto inputMinPtr = inputMin->cast(); + auto inputMaxPtr = inputMax->cast(); + float *minBuf = static_cast(inputMinPtr->Data()); + float *maxBuf = static_cast(inputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + } + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, + narrowRangeQuantParam, numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecQuantParam->emplace_back(quants); + + quants.clear(); + auto filterMin = prim->GetAttr("filter_minq"); + auto filterMax = prim->GetAttr("filter_maxq"); + if (filterMin != nullptr && filterMax != nullptr) { + auto filterMinPtr = filterMin->cast(); + auto filterMaxPtr = filterMax->cast(); + float *minBuf = static_cast(filterMinPtr->Data()); + float *maxBuf = static_cast(filterMaxPtr->Data()); + for (int i = 0; i < filterMinPtr->DataSize(); ++i) { + quantParam.min = *(minBuf++); + quantParam.max = *(maxBuf++); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, + narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + } + vecQuantParam->emplace_back(quants); + } + + quants.clear(); + auto outputMin = prim->GetAttr("output_minq"); + auto outputMax = prim->GetAttr("output_maxq"); + if (outputMin != nullptr && outputMax != nullptr) { + auto outputMinPtr = outputMin->cast(); + auto outputMaxPtr = outputMax->cast(); + float *minBuf = static_cast(outputMinPtr->Data()); + float *maxBuf = static_cast(outputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, + narrowRangeQuantParam, numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecQuantParam->emplace_back(quants); + } +} + +int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, + PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) { auto primitive = 
std::make_unique(); auto attr = std::make_unique(); @@ -32,8 +116,16 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim primitive->value.value = attr.release(); MS_ASSERT(primitiveTValuePtr != nullptr); primitiveTValuePtr->SetPrimitiveT(primitive.release()); + if (primitiveTValuePtr->GetQuantType()) { + std::vector> vecQuantParam; + PopulaterQuantParam(prim, &vecQuantParam); + primitiveTValuePtr->SetInputQuantParam(vecQuantParam); + } + return 0; } -AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); -AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", new AnfMatmulPopulater()); +AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", + new AnfMatmulPopulater()); +AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", + new AnfMatmulPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h b/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h index 651b41c9d75..3ce23f5389f 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h @@ -24,6 +24,11 @@ class AnfMatmulPopulater : public AnfNodePopulater { ~AnfMatmulPopulater() override = default; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) override; + private: + void PopulaterQuantParam(const PrimitivePtr &prim, + std::vector> *vecQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, + float *mMax); }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc index db11cdcaab0..9bd7a41e284 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc @@ -28,18 +28,18 @@ #include #include -#include "schema/inner/model_generated.h" #include "frontend/operator/ops.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "include/errorcode.h" #include "ir/anf.h" #include "ir/func_graph.h" +#include "schema/inner/model_generated.h" +#include "securec/include/securec.h" +#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "src/ir/tensor.h" #include "src/param_value_lite.h" #include "tools/converter/parser/onnx/onnx.pb.h" #include "utils/log_adapter.h" -#include "securec/include/securec.h" -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" using string = std::string; using int32 = int32_t; @@ -60,16 +60,24 @@ enum ParseForm : int { }; static std::map kParseTypeSwitchMap{ - {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; + {"type", FORM_PARSE_TYPE}, + {"scalar", FORM_PARSE_SCALAR}, + {"tensor", FORM_PARSE_TENSOR}}; static std::unordered_map kDefaultValueSwitchMap{ - {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, - {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, - {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, - {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, - {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, 
{onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, - {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, - {onnx::TensorProto_DataType_STRING, kObjectTypeString}, + {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, + {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, + {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, + {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, + {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, + {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, + {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, + {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, + {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, + {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, + {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, + {onnx::TensorProto_DataType_STRING, kObjectTypeString}, }; #if 0 @@ -189,15 +197,16 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map(attr_tensor.type##_data(0)); \ - return MakeValue(value); \ - } else { \ - MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ - } \ - return {}; \ +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + ValuePtr ParseAttrInScalar_##type##_##valuetype( \ + const onnx::TensorProto &attr_tensor) { \ + if (attr_tensor.type##_data_size() == 1) { \ + auto value = static_cast(attr_tensor.type##_data(0)); \ + return MakeValue(value); \ + } else { \ + MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ + } \ + return {}; \ } PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) @@ -643,20 +652,21 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc } #else -#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ - void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ - const onnx::TensorProto &attr_tensor) { \ - MS_EXCEPTION_IF_NULL(prim); \ - std::vector attr_value_vec; \ - for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ - auto value = static_cast(attr_tensor.type##_data(i)); \ - attr_value_vec.push_back(MakeValue(value)); \ - } \ - if (attr_value_vec.size() == 1) { \ - prim->AddAttr(attr_name, attr_value_vec[0]); \ - } else { \ - prim->AddAttr(attr_name, std::make_shared(attr_value_vec)); \ - } \ +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + void ParseAttrInScalar_##type##_##valuetype( \ + const PrimitivePtr &prim, const std::string &attr_name, \ + const onnx::TensorProto &attr_tensor) { \ + MS_EXCEPTION_IF_NULL(prim); \ + std::vector attr_value_vec; \ + for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ + auto value = static_cast(attr_tensor.type##_data(i)); \ + attr_value_vec.push_back(MakeValue(value)); \ + } \ + if (attr_value_vec.size() == 1) { \ + prim->AddAttr(attr_name, attr_value_vec[0]); \ + } else { \ + prim->AddAttr(attr_name, std::make_shared(attr_value_vec)); \ + } \ } PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) @@ -667,8 +677,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) -bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, - const onnx::ValueInfoProto &value_proto) { +bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( + const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) { MS_EXCEPTION_IF_NULL(node); if (!value_proto.has_type() || !value_proto.has_name()) 
{ MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; @@ -691,24 +701,30 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod shape.push_back(tensor_shape.dim(i).dim_value()); } - if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { + if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == + kDefaultValueSwitchMap.end()) { MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; return false; } - auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); - auto abstract_tensor = std::make_shared(type_ptr, shape); + auto type_ptr = + TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); + auto abstract_tensor = + std::make_shared(type_ptr, shape); node->set_abstract(abstract_tensor); if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { - tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); + tensor::Tensor *tensor_info = new tensor::Tensor( + kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); MS_EXCEPTION_IF_NULL(tensor_info); tensor_info->MallocData(); - const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; + const onnx::TensorProto initialize_proto = + default_para_map_[value_proto.name()]; std::string initial_data = initialize_proto.raw_data(); auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); MS_EXCEPTION_IF_NULL(tensor_data_buf); - auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), + initial_data.data(), initial_data.size()); if (EOK != ret) { MS_LOG(ERROR) << "memcpy_s error"; return false; @@ -724,15 +740,18 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod return true; } -bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto) { +bool AnfImporterFromProtobuf::ImportParametersForGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { MS_EXCEPTION_IF_NULL(outputFuncGraph); - MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); + MS_LOG(INFO) << "Parameters had default paramerer size is: " + << importProto.initializer_size(); for (int i = 0; i < importProto.initializer_size(); ++i) { const onnx::TensorProto &initializer_proto = importProto.initializer(i); if (!initializer_proto.has_name()) { - MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; + MS_LOG(ERROR) + << "initializer vector of onnx GraphProto has no name at index: " + << i; return false; } default_para_map_[initializer_proto.name()] = initializer_proto; @@ -741,7 +760,8 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); for (int i = 0; i < importProto.input_size(); ++i) { const onnx::ValueInfoProto &input_proto = importProto.input(i); - if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { + if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), + input_proto)) { MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; return false; } @@ -749,20 +769,25 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu return true; } -bool 
AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
-                                                        const onnx::TensorProto &attr_tensor) {
+bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(
+    const PrimitivePtr &prim, const std::string &attr_name,
+    const onnx::TensorProto &attr_tensor) {
   MS_EXCEPTION_IF_NULL(prim);
   const int attr_tensor_type = attr_tensor.data_type();
-  if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
-    MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
+  if (kDefaultValueSwitchMap.find(attr_tensor_type) ==
+      kDefaultValueSwitchMap.end()) {
+    MS_LOG(ERROR) << "Obtain attr in type-form does not support input type: "
+                  << attr_tensor_type;
     return false;
   }
-  prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
+  prim->AddAttr(attr_name,
+                TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
   return true;
 }

-bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
-                                                          const onnx::TensorProto &attr_tensor) {
+bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(
+    const PrimitivePtr &prim, const std::string &attr_name,
+    const onnx::TensorProto &attr_tensor) {
   MS_EXCEPTION_IF_NULL(prim);
   const int attr_tensor_type = attr_tensor.data_type();
   switch (attr_tensor_type) {
@@ -796,20 +821,59 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &pr
       break;
     }
     default:
-      MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
+      MS_LOG(ERROR) << "Obtain attr in scalar-form does not support input type: "
+                    << attr_tensor_type;
       return false;
   }
   return true;
 }

-bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
-                                                          const onnx::TensorProto &attr_tensor) {
+bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(
+    const PrimitivePtr &prim, const std::string &attr_name,
+    const onnx::TensorProto &attr_tensor) {
   MS_EXCEPTION_IF_NULL(prim);
-  MS_LOG(ERROR) << "parse attr type don't support attr type is tensor";
-  return false;
+  const int attr_tensor_type = attr_tensor.data_type();
+  const std::string &tensor_buf = attr_tensor.raw_data();
+  std::vector<int> shape;
+  auto ret = EOK;
+  if (attr_tensor.dims_size() != 0) {
+    // Dimensioned data: copy the raw buffer into a newly allocated tensor
+    // and attach it as a tensor-valued attribute.
+    for (int i = 0; i < attr_tensor.dims_size(); ++i) {
+      shape.push_back(attr_tensor.dims(i));
+    }
+    tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(
+        kDefaultValueSwitchMap[attr_tensor_type], shape);
+    tensor_info->MallocData();
+    auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
+    ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(),
+                   tensor_buf.size());
+    prim->set_attr(attr_name, MakeValue(tensor_info));
+  } else {
+    // Zero-rank data: decode the raw buffer as a single scalar.
+    if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) {
+      size_t data_size = sizeof(double);
+      double attr_value = 0.0;
+      ret = memcpy_s(&attr_value, data_size, tensor_buf.data(),
+                     tensor_buf.size());
+      prim->set_attr(attr_name, MakeValue(attr_value));
+    } else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) {
+      size_t data_size = sizeof(int64_t);
+      int64_t attr_value = 0;  // must be 64-bit: data_size bytes are copied in
+      ret = memcpy_s(&attr_value, data_size, tensor_buf.data(),
+                     tensor_buf.size());
+      prim->set_attr(attr_name, MakeValue(attr_value));
+    } else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) {
+      size_t data_size = sizeof(bool);
+      bool attr_value = false;
+      ret = memcpy_s(&attr_value, data_size, tensor_buf.data(),
+                     tensor_buf.size());
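+      // As in the branches above, the memcpy_s status is folded into `ret`
+      // and checked once at the end of the function, so a short or failed
+      // copy fails the whole attribute import instead of silently truncating.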
+ prim->set_attr(attr_name, MakeValue(attr_value)); + } + } + + return ret == EOK; } -bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { +bool AnfImporterFromProtobuf::GetAttrValueForCNode( + const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { MS_EXCEPTION_IF_NULL(prim); const std::string &attr_name = attr_proto.name(); if (!attr_proto.has_ref_attr_name()) { @@ -833,18 +897,20 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con return false; } } -bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm( + const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); std::vector shape; for (int i = 0; i < attr_tensor.dims_size(); ++i) { shape.push_back(attr_tensor.dims(i)); } - tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); + tensor::TensorPtr tensor_info = std::make_shared( + kDefaultValueSwitchMap[attr_tensor_type], shape); tensor_info->MallocData(); const std::string &tensor_buf = attr_tensor.raw_data(); auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); - auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), + tensor_buf.size()); if (EOK != ret) { MS_LOG(ERROR) << "memcpy_s error"; return false; @@ -852,14 +918,15 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val auto new_value_node = NewValueNode(MakeValue(tensor_info)); MS_EXCEPTION_IF_NULL(new_value_node); auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); - auto abstract_tensor = std::make_shared(type_ptr, shape); + auto abstract_tensor = + std::make_shared(type_ptr, shape); new_value_node->set_abstract(abstract_tensor); anfnode_build_map_[value_node_name] = new_value_node; return true; } -bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( + const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); ValuePtr value_ptr = nullptr; switch (attr_tensor_type) { @@ -871,7 +938,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val if (add_data.size() == 1) { value_ptr = MakeValue(add_data[0]); } else if (!add_data.empty()) { - value_ptr = MakeValue>(add_data); + value_ptr = MakeValue >(add_data); } break; } @@ -884,7 +951,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val if (add_data.size() == 1) { value_ptr = MakeValue(add_data[0]); } else if (!add_data.empty()) { - value_ptr = MakeValue>(add_data); + value_ptr = MakeValue >(add_data); } break; } @@ -894,7 +961,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val break; } default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " + << attr_tensor_type; return false; } auto new_value_node = NewValueNode(value_ptr); @@ -905,23 +973,28 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const 
std::string &val return true; } -bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm( + const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); - if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; + if (kDefaultValueSwitchMap.find(attr_tensor_type) == + kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) + << "Obtain ValueNode attr in type-form has not support input type: " + << attr_tensor_type; return false; } - auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); - abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); + auto new_value_node = + NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + abstract::AbstractTypePtr abs_type = + std::make_shared(std::make_shared()); new_value_node->set_abstract(abs_type); anfnode_build_map_[value_node_name] = new_value_node; return true; } -bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name, - const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::GetAttrValueForValueNode( + const std::string &ref_attr_name, const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { switch (kParseTypeSwitchMap[ref_attr_name]) { case FORM_PARSE_SCALAR: { return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); @@ -933,12 +1006,14 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_at return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); } default: - MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; + MS_LOG(ERROR) + << "parse ValueNode value don't support input of ref_attr_name"; return false; } } -bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { +bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph( + const onnx::NodeProto &node_proto) { const std::string &value_node_name = node_proto.output(0); const onnx::AttributeProto &attr_proto = node_proto.attribute(0); if (!attr_proto.has_ref_attr_name()) { @@ -951,20 +1026,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto & return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); } -abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { +abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode( + const onnx::AttributeProto &attr_proto) { std::vector shape_vec; const onnx::TensorProto &attr_tensor = attr_proto.t(); for (int i = 0; i < attr_tensor.dims_size(); ++i) { shape_vec.push_back(attr_tensor.dims(i)); } auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); - auto abstract_tensor = std::make_shared(type_ptr, shape_vec); + auto abstract_tensor = + std::make_shared(type_ptr, shape_vec); MS_EXCEPTION_IF_NULL(abstract_tensor); return abstract_tensor; } -CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::NodeProto &node_proto) { +CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, + 
const schema::QuantType &quantType) { MS_EXCEPTION_IF_NULL(outputFuncGraph); if (!node_proto.has_op_type()) { MS_LOG(ERROR) << "Get CNode op_type failed!"; @@ -1004,20 +1082,24 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out for (int i = 0; i < node_proto.input_size(); ++i) { const std::string &input_name = node_proto.input(i); if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { - MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; + MS_LOG(ERROR) << node_name << " input " << i << input_name + << "can't find in nodes have parsed"; return nullptr; } inputs.push_back(anfnode_build_map_[input_name]); } std::string opType = prim->name(); - auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); + auto node_parser = + AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); if (node_parser == nullptr) { MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; return nullptr; } auto primitiveT = std::make_unique(); // auto * primitiveTValue = new PrimitiveTValue(primitiveT.release()); - std::shared_ptr primitiveTValuePtr = std::make_shared(primitiveT.release()); + std::shared_ptr primitiveTValuePtr = + std::make_shared(primitiveT.release()); + primitiveTValuePtr->SetQuantType(quantType); node_parser->Populate(prim, primitiveTValuePtr.get(), inputs); MS_ASSERT(primitiveTValuePtr != nullptr); inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr)); @@ -1048,8 +1130,9 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out return cnode_ptr; } -bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { +bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr) { MS_EXCEPTION_IF_NULL(outputFuncGraph); MS_EXCEPTION_IF_NULL(cnode_ptr); std::vector inputs; @@ -1064,7 +1147,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output elem.push_back(anfnode_build_map_[out_tuple]->abstract()); } auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); - maketuple_ptr->set_abstract(std::make_shared(elem)); + maketuple_ptr->set_abstract( + std::make_shared(elem)); inputs.clear(); inputs.push_back(NewValueNode(prim::kPrimReturn)); inputs.push_back(maketuple_ptr); @@ -1077,11 +1161,14 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output const onnx::TypeProto &output_typeproto = output_node.type(); int output_type = output_typeproto.tensor_type().elem_type(); std::vector output_shape; - for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { - output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); + for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); + ++i) { + output_shape.push_back( + output_typeproto.tensor_type().shape().dim(i).dim_value()); } auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); - auto abstract_tensor = std::make_shared(type_ptr, output_shape); + auto abstract_tensor = + std::make_shared(type_ptr, output_shape); inputs.clear(); inputs.push_back(NewValueNode(prim::kPrimReturn)); @@ -1095,8 +1182,9 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output return true; } -bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, 
- const onnx::GraphProto &importProto) { +bool AnfImporterFromProtobuf::ImportNodesForGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + const schema::QuantType &quantType) { MS_EXCEPTION_IF_NULL(outputFuncGraph); MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); CNodePtr cnode_ptr = nullptr; @@ -1110,7 +1198,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc } continue; } - cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); + cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType); if (cnode_ptr == nullptr) { MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; return false; @@ -1122,7 +1210,9 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc } #endif -bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { +bool AnfImporterFromProtobuf::BuildFuncGraph( + const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + const schema::QuantType &quantType) { MS_EXCEPTION_IF_NULL(outputFuncGraph); GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); MS_EXCEPTION_IF_NULL(debug_info_ptr); @@ -1135,10 +1225,11 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph if (!ImportParametersForGraph(outputFuncGraph, importProto)) { return false; } - return ImportNodesForGraph(outputFuncGraph, importProto); + return ImportNodesForGraph(outputFuncGraph, importProto, quantType); } -bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { +bool AnfImporterFromProtobuf::ParseModelConfigureInfo( + const onnx::ModelProto &model_proto) { if (!model_proto.has_producer_name()) { MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; return false; @@ -1159,14 +1250,14 @@ bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mo return true; } -int AnfImporterFromProtobuf::Import() { +int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { FuncGraphPtr dstGraph = std::make_shared(); MS_EXCEPTION_IF_NULL(dstGraph); if (!ParseModelConfigureInfo(*onnx_model_)) { MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; } const onnx::GraphProto &graphBuild = onnx_model_->graph(); - if (!BuildFuncGraph(dstGraph, graphBuild)) { + if (!BuildFuncGraph(dstGraph, graphBuild, quantType)) { MS_LOG(ERROR) << "Build funcgraph failed!"; func_graph_ = nullptr; return RET_ERROR; @@ -1176,7 +1267,8 @@ int AnfImporterFromProtobuf::Import() { return RET_OK; } -onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { +onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary( + const std::string &model_path) { std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { MS_LOG(ERROR) << "open file failed."; diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h index 24502ffa90d..e7064fab39a 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h @@ -17,20 +17,21 @@ #ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ -#include #include +#include #include #include -#include 
"tools/converter/parser/onnx/onnx.pb.h" -#include "src/common/anf_importer/anf_importer.h" #include "abstract/abstract_value.h" +#include "src/common/anf_importer/anf_importer.h" +#include "tools/converter/parser/onnx/onnx.pb.h" namespace mindspore::lite { class AnfImporterFromProtobuf : public AnfImporter { public: - explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) - : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} + explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, + FuncGraphPtr func_graph) + : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} ~AnfImporterFromProtobuf() override = default; @@ -38,15 +39,17 @@ class AnfImporterFromProtobuf : public AnfImporter { FuncGraphPtr GetResult() override; - int Import() override; + int Import(const schema::QuantType &quantType = + schema::QuantType_QUANT_NONE) override; private: - void ConverterConstTensor() override {}; - int ConverterCNode() override {}; - void AddReturnCNode() override {}; + void ConverterConstTensor() override{}; + int ConverterCNode() override{}; + void AddReturnCNode() override{}; bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto); + const onnx::GraphProto &importProto, + const schema::QuantType &quantType); #if 0 bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); @@ -78,31 +81,46 @@ class AnfImporterFromProtobuf : public AnfImporter { std::unordered_map GetAbstractForCNode(const onnx::AttributeProto &attr_proto); #else - bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); - bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); - bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); - CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto); - bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto); + bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto, + const schema::QuantType &quantType); + bool BuildParameterForFuncGraph(const ParameterPtr &node, + const onnx::ValueInfoProto &value_proto); + CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::NodeProto &node_proto, + const schema::QuantType &quantType); + bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr); - bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); - bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, + bool GetAttrValueForCNode(const PrimitivePtr &prim, + const onnx::AttributeProto &attr_proto); + bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, + const std::string &attr_name, const onnx::TensorProto &attr_tensor); - bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, + bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, + const std::string &attr_name, const onnx::TensorProto &attr_tensor); - bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, + bool 
ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
+                                   const std::string &attr_name,
                                    const onnx::TensorProto &attr_tensor);
   bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);

-  bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
+  bool ObtainValueNodeInTensorForm(const string &value_node_name,
+                                   const onnx::TensorProto &attr_tensor);

-  bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
-  bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
+  bool ObtainValueNodeInScalarForm(const string &value_node_name,
+                                   const onnx::TensorProto &attr_tensor);
+  bool GetAttrValueForValueNode(const string &ref_attr_name,
+                                const std::string &value_node_name,
                                 const onnx::TensorProto &attr_tensor);
-  bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
-  abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
+  bool ObtainValueNodeInTypeForm(const string &value_node_name,
+                                 const onnx::TensorProto &attr_tensor);
+  abstract::AbstractTensorPtr GetAbstractForCNode(
+      const onnx::AttributeProto &attr_proto);
 #endif

- private:
   std::string producer_name_;
   int model_version_{};
@@ -115,4 +133,3 @@ class AnfImporterFromProtobuf : public AnfImporter {
 }  // namespace mindspore::lite
 
 #endif  // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
-
diff --git a/mindspore/lite/src/ir/primitive_t_value.h b/mindspore/lite/src/ir/primitive_t_value.h
index b13f4606eb1..7de7250b231 100644
--- a/mindspore/lite/src/ir/primitive_t_value.h
+++ b/mindspore/lite/src/ir/primitive_t_value.h
@@ -46,6 +46,9 @@ class PrimitiveTValue : public Value {
     }
   }
 
+  void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {
+    // Forward into the same storage AddInputQuantParam() uses; an empty body
+    // here would silently drop the quant params the populaters just computed.
+    for (auto &quant_params : vec_quant_param) {
+      for (auto &quant_param : quant_params) {
+        this->input_quant_param_.emplace_back(quant_param);
+      }
+    }
+  }
+
   void AddInputQuantParam(schema::QuantParamT quant_param) {
     this->input_quant_param_.emplace_back(quant_param);
   }
diff --git a/mindspore/lite/src/ir/tensor.h b/mindspore/lite/src/ir/tensor.h
index 3585633c0ab..6dc21c46132 100644
--- a/mindspore/lite/src/ir/tensor.h
+++ b/mindspore/lite/src/ir/tensor.h
@@ -73,6 +73,9 @@ class Tensor : public mindspore::tensor::MetaTensor {
   size_t Size() const {
     size_t size = 0;
     switch (this->data_type_) {
+      case kNumberTypeFloat64:
+        size = sizeof(double);
+        break;
       case kNumberTypeFloat:
       case kNumberTypeFloat32:
         size = sizeof(float);
diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
index f7a3905c02d..828cfefc032 100644
--- a/mindspore/lite/tools/converter/converter.cc
+++ b/mindspore/lite/tools/converter/converter.cc
@@ -71,7 +71,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
   FuncGraphPtr graph = nullptr;
   if (flag->fmk == converter::FmkType_MS) {
     MS_ASSERT(nullptr != modelImporter);
-    modelImporter->Import();
+    modelImporter->Import(flag->quantType);
     graph = modelImporter->GetResult();
   } else {
     MS_ASSERT(nullptr != modelParser);
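-- 

Reviewer note: each populater derives its QuantParamT entries from recorded
(min, max) ranges via quant::CalQuantizationParams. CalQuantParam recovers the
input range by pushing the uint8 endpoints (0 and 255) back through the
recorded mean/std_dev preprocessing, and the bias entries reuse
scale(input) * scale(weight) with zeroPoint 0, as the populater code above
shows. For orientation, here is a minimal sketch of the asymmetric uint8
mapping this implies (a hypothetical helper, not the converter's actual
implementation; the narrow_range and num_bits handling is omitted):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Map a float range [min, max] onto quantized uint8 values in [0, 255].
void CalcScaleZeroPoint(float min, float max, float *scale, int32_t *zero_point) {
  min = std::min(min, 0.0f);  // widen the range so that real 0.0 stays
  max = std::max(max, 0.0f);  // exactly representable after quantization
  if (max == min) {           // degenerate range: avoid dividing by zero
    *scale = 1.0f;
    *zero_point = 0;
    return;
  }
  *scale = (max - min) / 255.0f;  // size of one quantization step
  *zero_point = static_cast<int32_t>(std::round(-min / *scale));
}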