!4451 add quant param during anf_import

Merge pull request !4451 from yankai10/merge
mindspore-ci-bot, 2020-08-14 17:57:01 +08:00, committed by Gitee
commit 180689c095
13 changed files with 698 additions and 238 deletions
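This change threads the converter's schema::QuantType through AnfImporter::Import() and down into the per-op anf_populaters, which compute and attach QuantParamT vectors to each PrimitiveTValue when the model was trained quantization-aware. A minimal usage sketch (the importer construction is assumed here; only the Import()/GetResult() calls and the quant-type argument come from this diff):

namespace mindspore::lite {
// Hedged sketch: drive any AnfImporter with an explicit quant type, the way
// tools/converter/converter.cc passes flag->quantType at the end of this diff.
FuncGraphPtr ImportWithQuantType(AnfImporter *importer) {
  // QuantType_AwareTrainning (spelling as in the schema) switches the
  // populaters into quant-param population mode.
  if (importer->Import(schema::QuantType_AwareTrainning) != RET_OK) {
    return nullptr;
  }
  return importer->GetResult();
}
}  // namespace mindspore::lite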


@@ -27,7 +27,7 @@
#include "abstract/abstract_value.h"
#include "src/ir/primitive_value.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
#if 0
@@ -159,7 +159,7 @@ void MinnieBuildGraph::FbTest(const GraphDef *graph_def) {
}
#endif
int AnfImporter::Import() {
int AnfImporter::Import(const schema::QuantType &quantType) {
ConverterConstTensor();
auto ret = ConverterCNode();
if (RET_OK != ret) {


@@ -21,6 +21,7 @@
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "base/base.h"
#include "schema/inner/model_generated.h"
namespace mindspore::lite {
class AnfImporter {
@@ -29,7 +30,7 @@ class AnfImporter {
virtual ~AnfImporter() = default;
virtual int Import();
virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE);
virtual FuncGraphPtr GetResult() = 0;
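Defaulting the new parameter to schema::QuantType_QUANT_NONE keeps every existing Import() call site source-compatible; only the converter path at the end of this diff passes flag->quantType explicitly.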


@@ -1,5 +1,6 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* This is the C++ adaptation and derivative work of Myia
* (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
*
@@ -17,101 +18,218 @@
*/
#include "src/common/anf_importer/anf_populater/anf_conv_populater.h"
#include <mindspore/lite/src/ir/tensor.h>
#include <memory>
#include <string>
#include <vector>
#include <memory>
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "src/ir/tensor.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite {
int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
int group = GetValue<int>(prim->GetAttr("group"));
auto primitive = std::make_unique<schema::PrimitiveT>();
if (group > 1) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
void AnfConvPopulater::PopulaterConv2DMultiGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
auto attr = std::make_unique<schema::Conv2DT>();
attr->group = group;
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
attr->channelOut = GetValue<int>(prim->GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
}
void AnfConvPopulater::PopulaterConv2DSingleGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group) {
auto attr = std::make_unique<schema::Conv2DT>();
attr->group = group;
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
attr->channelOut = GetValue<int>(prim->GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
}
void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev,
float *mMin, float *mMax) {
constexpr float qmin = 0;
constexpr float qmax = 255;
*mMin = static_cast<float>((qmin - mean) / stdDev);
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
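CalQuantParam inverts the usual (x - mean) / stdDev input normalization over the uint8 range [0, 255]. For example, with mean = 127.5 and stdDev = 127.5 it yields mMin = (0 - 127.5) / 127.5 = -1.0 and mMax = (255 - 127.5) / 127.5 = 1.0, i.e. the float range the activations spanned before normalization.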
void AnfConvPopulater::PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits");
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
std::vector<schema::QuantParamT> quants;
schema::QuantParamT quantParam;
auto mean = prim->GetAttr("mean");
auto std_dev = prim->GetAttr("std_dev");
if (mean != nullptr && std_dev != nullptr) {
auto meanQuantOaram = GetValue<double>(mean);
double stddevQuantOaram = GetValue<double>(std_dev);
float mMin = 0.0;
float mMax = 0.0;
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
quantParam.min = mMin;
quantParam.max = mMax;
} else {
auto inputMin = prim->GetAttr("input_minq");
auto inputMax = prim->GetAttr("input_maxq");
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(inputMinPtr->Data());
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
narrowRangeQuantParam, numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
quants.clear();
int biasQuantSize = 0;
auto filterMin = prim->GetAttr("filter_minq");
auto filterMax = prim->GetAttr("filter_maxq");
if (filterMin != nullptr && filterMax != nullptr) {
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(filterMinPtr->Data());
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
biasQuantSize = filterMinPtr->DataSize();
for (int i = 0; i < biasQuantSize; ++i) {
quantParam.min = *(minBuf++);
quantParam.max = *(maxBuf++);
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
}
quants.clear();
for (int i = 0; i < biasQuantSize; ++i) {
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale =
vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
quants.clear();
auto outputMin = prim->GetAttr("output_minq");
auto outputMax = prim->GetAttr("output_maxq");
if (outputMin != nullptr && outputMax != nullptr) {
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(outputMinPtr->Data());
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
narrowRangeQuantParam, numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
}
}
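quant::CalQuantizationParams lives in tools/converter/quantizer/quantize_util.h and is not part of this diff; a minimal sketch of the asymmetric affine scheme it is assumed to implement, deriving a scale and zero point from a float [min, max] range and num_bits integer levels:

// Hedged sketch, not the MindSpore implementation: affine quant params
// from a float range, in the shape the populaters above consume.
#include <algorithm>
#include <cmath>
#include <cstdint>

struct AffineQuantParam {
  double scale;
  int32_t zeroPoint;
};

AffineQuantParam CalQuantizationParamsSketch(float mMin, float mMax,
                                             bool narrowRange, int numBits) {
  const int32_t qmin = narrowRange ? 1 : 0;  // narrow range reserves level 0
  const int32_t qmax = (1 << numBits) - 1;   // 255 for 8 bits
  mMin = std::min(mMin, 0.0f);               // the range must bracket 0.0f
  mMax = std::max(mMax, 0.0f);               // so that 0.0f maps exactly
  if (mMax == mMin) {
    return {1.0, qmin};                      // degenerate-range guard
  }
  const double scale = (static_cast<double>(mMax) - mMin) / (qmax - qmin);
  const auto zeroPoint = static_cast<int32_t>(std::round(qmin - mMin / scale));
  return {scale, zeroPoint};
}

With mMin = -1.0, mMax = 1.0 and 8 wide-range bits this gives scale = 2 / 255 ≈ 0.00784 and zeroPoint = 128, the kind of pair the QuantParamT entries above end up holding.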
int AnfConvPopulater::Populate(const PrimitivePtr &prim,
PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
MS_ASSERT(primitiveTValuePtr != nullptr);
auto primitive = std::make_unique<schema::PrimitiveT>();
int group = GetValue<int>(prim->GetAttr("group"));
if (group > 1) {
PopulaterConv2DMultiGroup(prim, primitive, group);
} else {
PopulaterConv2DSingleGroup(prim, primitive, group);
}
primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTrainning) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
}
return 0;
}
AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater());
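The file-scope AnfNodePopulaterRegistrar object above registers the populater during static initialization, so it is available by the time BuildCNodeForFuncGraph looks ops up by name. A sketch of the idiom (GetNodePopulater is used later in this diff; the map storage and the SetNodePopulater name are assumptions):

// Hedged sketch of the static-registration pattern behind the registrars.
#include <map>
#include <string>

class AnfNodePopulater;  // populater interface, declared elsewhere

class AnfNodePopulaterRegistry {
 public:
  static AnfNodePopulaterRegistry *GetInstance() {
    static AnfNodePopulaterRegistry instance;  // lazy, thread-safe singleton
    return &instance;
  }
  void SetNodePopulater(const std::string &name, AnfNodePopulater *populater) {
    populaters_[name] = populater;
  }
  AnfNodePopulater *GetNodePopulater(const std::string &name) {
    auto iter = populaters_.find(name);
    return iter == populaters_.end() ? nullptr : iter->second;
  }

 private:
  std::map<std::string, AnfNodePopulater *> populaters_;
};

struct AnfNodePopulaterRegistrar {
  AnfNodePopulaterRegistrar(const std::string &name,
                            AnfNodePopulater *populater) {
    AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, populater);
  }
};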


@@ -1,5 +1,6 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* This is the C++ adaptation and derivative work of Myia
* (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
*
@@ -18,8 +19,9 @@
#ifndef MINDSPORE_ANF_CONV_PARSER_H
#define MINDSPORE_ANF_CONV_PARSER_H
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
namespace mindspore::lite {
class AnfConvPopulater : public AnfNodePopulater {
public:
@@ -27,6 +29,18 @@ class AnfConvPopulater : public AnfNodePopulater {
~AnfConvPopulater() override = default;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
private:
void PopulaterConv2DMultiGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
void PopulaterConv2DSingleGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
void PopulaterQuantParam(const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
};
} // namespace mindspore::lite


@@ -14,15 +14,113 @@
* limitations under the License.
*/
#include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include <string>
#include <vector>
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "src/ir/tensor.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite {
int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean,
const double &stdDev, float *mMin,
float *mMax) {
constexpr float qmin = 0;
constexpr float qmax = 255;
*mMin = static_cast<float>((qmin - mean) / stdDev);
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void AnfDepwiseconv2DPopulater::PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits");
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
std::vector<schema::QuantParamT> quants;
schema::QuantParamT quantParam;
auto mean = prim->GetAttr("mean");
auto std_dev = prim->GetAttr("std_dev");
if (mean != nullptr && std_dev != nullptr) {
auto meanQuantOaram = GetValue<double>(mean);
double stddevQuantOaram = GetValue<double>(std_dev);
float mMin = 0.0;
float mMax = 0.0;
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
quantParam.min = mMin;
quantParam.max = mMax;
} else {
auto inputMin = prim->GetAttr("input_minq");
auto inputMax = prim->GetAttr("input_maxq");
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(inputMinPtr->Data());
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
narrowRangeQuantParam, numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
quants.clear();
int biasQuantSize = 0;
auto filterMin = prim->GetAttr("filter_minq");
auto filterMax = prim->GetAttr("filter_maxq");
if (filterMin != nullptr && filterMax != nullptr) {
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(filterMinPtr->Data());
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
biasQuantSize = filterMinPtr->DataSize();
for (int i = 0; i < biasQuantSize; ++i) {
quantParam.min = *(minBuf++);
quantParam.max = *(maxBuf++);
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
}
quants.clear();
for (int i = 0; i < biasQuantSize; ++i) {
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale =
vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
quants.clear();
auto outputMin = prim->GetAttr("output_minq");
auto outputMax = prim->GetAttr("output_maxq");
if (outputMin != nullptr && outputMax != nullptr) {
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(outputMinPtr->Data());
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
narrowRangeQuantParam, numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
}
}
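Note the bias entries built in the loop above: min, max and zeroPoint stay zero and scale is set to vecQuantParam[0][0].scale * vecQuantParam[1][i].scale, the product of the input scale and the per-channel weight scale. That is the standard int32 bias convention, bias_q = round(bias_f / (s_input * s_weight)).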
int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim,
PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
@@ -36,9 +134,9 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pads"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
@@ -73,10 +171,13 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
auto abstractBase = paramNode->abstract();
MS_ASSERT(abstractBase != nullptr);
if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
auto abstractTensor =
utils::cast<abstract::AbstractTensorPtr>(abstractBase);
MS_ASSERT(abstractTensor != nullptr);
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
auto dims =
utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())
->shape();
attr->channelIn = dims[kAnfPopulaterOne];
}
}
@@ -86,8 +187,16 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType()) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
}
return 0;
}
AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dPopulater(
"DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater(
"DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
} // namespace mindspore::lite


@@ -15,8 +15,9 @@
*/
#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
namespace mindspore::lite {
class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
public:
@@ -24,6 +25,11 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
~AnfDepwiseconv2DPopulater() override = default;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
private:
void PopulaterQuantParam(const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
};
} // namespace mindspore::lite


@@ -14,14 +14,98 @@
* limitations under the License.
*/
#include "src/common/anf_importer/anf_populater/anf_matmul_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include <vector>
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "src/ir/tensor.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite {
int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev,
float *mMin, float *mMax) {
constexpr float qmin = 0;
constexpr float qmax = 255;
*mMin = static_cast<float>((qmin - mean) / stdDev);
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void AnfMatmulPopulater::PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits");
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
std::vector<schema::QuantParamT> quants;
schema::QuantParamT quantParam;
auto mean = prim->GetAttr("mean");
auto std_dev = prim->GetAttr("std_dev");
if (mean != nullptr && std_dev != nullptr) {
auto meanQuantOaram = GetValue<double>(mean);
double stddevQuantOaram = GetValue<double>(std_dev);
float mMin = 0.0;
float mMax = 0.0;
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
quantParam.min = mMin;
quantParam.max = mMax;
} else {
auto inputMin = prim->GetAttr("input_minq");
auto inputMax = prim->GetAttr("input_maxq");
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(inputMinPtr->Data());
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
narrowRangeQuantParam, numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
quants.clear();
auto filterMin = prim->GetAttr("filter_minq");
auto filterMax = prim->GetAttr("filter_maxq");
if (filterMin != nullptr && filterMax != nullptr) {
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(filterMinPtr->Data());
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
for (int i = 0; i < filterMinPtr->DataSize(); ++i) {
quantParam.min = *(minBuf++);
quantParam.max = *(maxBuf++);
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
}
quants.clear();
auto outputMin = prim->GetAttr("output_minq");
auto outputMax = prim->GetAttr("output_maxq");
if (outputMin != nullptr && outputMax != nullptr) {
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
float *minBuf = static_cast<float *>(outputMinPtr->Data());
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
narrowRangeQuantParam, numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
}
}
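Unlike the conv populaters, the MatMul variant above emits no bias quant vector: vecQuantParam carries only the input, weight and output entries.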
int AnfMatmulPopulater::Populate(const PrimitivePtr &prim,
PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::MatMulT>();
@@ -32,8 +116,16 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType()) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
}
return 0;
}
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater());
AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", new AnfMatmulPopulater());
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul",
new AnfMatmulPopulater());
AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul",
new AnfMatmulPopulater());
} // namespace mindspore::lite


@@ -24,6 +24,11 @@ class AnfMatmulPopulater : public AnfNodePopulater {
~AnfMatmulPopulater() override = default;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
private:
void PopulaterQuantParam(const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
};
} // namespace mindspore::lite


@@ -28,18 +28,18 @@
#include <unordered_map>
#include <vector>
#include "schema/inner/model_generated.h"
#include "frontend/operator/ops.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "schema/inner/model_generated.h"
#include "securec/include/securec.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "src/ir/tensor.h"
#include "src/param_value_lite.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "utils/log_adapter.h"
#include "securec/include/securec.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
using string = std::string;
using int32 = int32_t;
@@ -60,16 +60,24 @@ enum ParseForm : int {
};
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}};
{"type", FORM_PARSE_TYPE},
{"scalar", FORM_PARSE_SCALAR},
{"tensor", FORM_PARSE_TENSOR}};
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool},
{onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16},
{onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64},
{onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16},
{onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64},
{onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
{onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
};
#if 0
@@ -189,15 +197,16 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, a
return {};
}
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
if (attr_tensor.type##_data_size() == 1) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
return MakeValue<valuetype>(value); \
} else { \
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
} \
return {}; \
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype( \
const onnx::TensorProto &attr_tensor) { \
if (attr_tensor.type##_data_size() == 1) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
return MakeValue<valuetype>(value); \
} else { \
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
} \
return {}; \
}
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
@@ -643,20 +652,21 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
}
#else
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
const onnx::TensorProto &attr_tensor) { \
MS_EXCEPTION_IF_NULL(prim); \
std::vector<ValuePtr> attr_value_vec; \
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
} \
if (attr_value_vec.size() == 1) { \
prim->AddAttr(attr_name, attr_value_vec[0]); \
} else { \
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
} \
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
void ParseAttrInScalar_##type##_##valuetype( \
const PrimitivePtr &prim, const std::string &attr_name, \
const onnx::TensorProto &attr_tensor) { \
MS_EXCEPTION_IF_NULL(prim); \
std::vector<ValuePtr> attr_value_vec; \
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
} \
if (attr_value_vec.size() == 1) { \
prim->AddAttr(attr_name, attr_value_vec[0]); \
} else { \
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
} \
}
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
@@ -667,8 +677,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
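Expanded by hand for readability, the (double, double) instantiation of the macro above defines:

// Exact expansion of PARSE_ONNXATTR_IN_SCALAR_FORM(double, double); the
// other instantiations differ only in the proto accessor and value type.
void ParseAttrInScalar_double_double(const PrimitivePtr &prim,
                                     const std::string &attr_name,
                                     const onnx::TensorProto &attr_tensor) {
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<ValuePtr> attr_value_vec;
  for (int i = 0; i < attr_tensor.double_data_size(); ++i) {
    auto value = static_cast<double>(attr_tensor.double_data(i));
    attr_value_vec.push_back(MakeValue<double>(value));
  }
  if (attr_value_vec.size() == 1) {
    prim->AddAttr(attr_name, attr_value_vec[0]);
  } else {
    prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec));
  }
}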
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
const onnx::ValueInfoProto &value_proto) {
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(
const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) {
MS_EXCEPTION_IF_NULL(node);
if (!value_proto.has_type() || !value_proto.has_name()) {
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
@@ -691,24 +701,30 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
shape.push_back(tensor_shape.dim(i).dim_value());
}
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) {
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) ==
kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!";
return false;
}
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
auto type_ptr =
TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]);
auto abstract_tensor =
std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
node->set_abstract(abstract_tensor);
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
tensor::Tensor *tensor_info = new tensor::Tensor(
kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
tensor_info->MallocData();
const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()];
const onnx::TensorProto initialize_proto =
default_para_map_[value_proto.name()];
std::string initial_data = initialize_proto.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
MS_EXCEPTION_IF_NULL(tensor_data_buf);
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size());
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(),
initial_data.data(), initial_data.size());
if (EOK != ret) {
MS_LOG(ERROR) << "memcpy_s error";
return false;
@@ -724,15 +740,18 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
return true;
}
bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto) {
bool AnfImporterFromProtobuf::ImportParametersForGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size();
MS_LOG(INFO) << "Parameters had default paramerer size is: "
<< importProto.initializer_size();
for (int i = 0; i < importProto.initializer_size(); ++i) {
const onnx::TensorProto &initializer_proto = importProto.initializer(i);
if (!initializer_proto.has_name()) {
MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i;
MS_LOG(ERROR)
<< "initializer vector of onnx GraphProto has no name at index: "
<< i;
return false;
}
default_para_map_[initializer_proto.name()] = initializer_proto;
@@ -741,7 +760,8 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu
MS_LOG(INFO) << "all parameters size: " << importProto.input_size();
for (int i = 0; i < importProto.input_size(); ++i) {
const onnx::ValueInfoProto &input_proto = importProto.input(i);
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(),
input_proto)) {
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
return false;
}
@@ -749,20 +769,25 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu
return true;
}
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(
const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
const int attr_tensor_type = attr_tensor.data_type();
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
if (kDefaultValueSwitchMap.find(attr_tensor_type) ==
kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:"
<< attr_tensor_type;
return false;
}
prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
prim->AddAttr(attr_name,
TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
return true;
}
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(
const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
const int attr_tensor_type = attr_tensor.data_type();
switch (attr_tensor_type) {
@@ -796,20 +821,59 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &pr
break;
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: "
<< attr_tensor_type;
return false;
}
return true;
}
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(
const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
MS_LOG(ERROR) << "parse attr type don't support attr type is tensor";
return false;
const int attr_tensor_type = attr_tensor.data_type();
const std::string &tensor_buf = attr_tensor.raw_data();
std::vector<int> shape;
auto ret = EOK;
if (attr_tensor.dims_size() != 0) {
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(
kDefaultValueSwitchMap[attr_tensor_type], shape);
tensor_info->MallocData();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(),
tensor_buf.size());
prim->set_attr(attr_name, MakeValue(tensor_info));
} else {
if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) {
size_t data_size = sizeof(double);
double attr_value = 0.0;
ret = memcpy_s(&attr_value, data_size, tensor_buf.data(),
tensor_buf.size());
prim->set_attr(attr_name, MakeValue<double>(attr_value));
} else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) {
size_t data_size = sizeof(int64_t);
int32_t attr_value = 0;
ret = memcpy_s(&attr_value, data_size, tensor_buf.data(),
tensor_buf.size());
prim->set_attr(attr_name, MakeValue<int32_t>(attr_value));
} else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) {
size_t data_size = sizeof(bool);
bool attr_value = false;
ret = memcpy_s(&attr_value, data_size, tensor_buf.data(),
tensor_buf.size());
prim->set_attr(attr_name, MakeValue<bool>(attr_value));
}
}
return ret == EOK;
}
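Where this function previously just logged "parse attr type don't support attr type is tensor" and returned false, it now materializes dims-carrying attr tensors as lite tensor values on the primitive, and decodes dimensionless DOUBLE, INT64 and BOOL payloads as scalar attrs (the INT64 branch narrows the value into an int32_t).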
bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
bool AnfImporterFromProtobuf::GetAttrValueForCNode(
const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
MS_EXCEPTION_IF_NULL(prim);
const std::string &attr_name = attr_proto.name();
if (!attr_proto.has_ref_attr_name()) {
@@ -833,18 +897,20 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
return false;
}
}
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
std::vector<int> shape;
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(
kDefaultValueSwitchMap[attr_tensor_type], shape);
tensor_info->MallocData();
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size());
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(),
tensor_buf.size());
if (EOK != ret) {
MS_LOG(ERROR) << "memcpy_s error";
return false;
@@ -852,14 +918,15 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
auto new_value_node = NewValueNode(MakeValue(tensor_info));
MS_EXCEPTION_IF_NULL(new_value_node);
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
auto abstract_tensor =
std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
new_value_node->set_abstract(abstract_tensor);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
ValuePtr value_ptr = nullptr;
switch (attr_tensor_type) {
@@ -871,7 +938,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<int32>>(add_data);
value_ptr = MakeValue<std::vector<int32> >(add_data);
}
break;
}
@@ -884,7 +951,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<float>>(add_data);
value_ptr = MakeValue<std::vector<float> >(add_data);
}
break;
}
@@ -894,7 +961,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val
break;
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: "
<< attr_tensor_type;
return false;
}
auto new_value_node = NewValueNode(value_ptr);
@@ -905,23 +973,28 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
if (kDefaultValueSwitchMap.find(attr_tensor_type) ==
kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR)
<< "Obtain ValueNode attr in type-form has not support input type: "
<< attr_tensor_type;
return false;
}
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
auto new_value_node =
NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
abstract::AbstractTypePtr abs_type =
std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
new_value_node->set_abstract(abs_type);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name,
const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(
const std::string &ref_attr_name, const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
switch (kParseTypeSwitchMap[ref_attr_name]) {
case FORM_PARSE_SCALAR: {
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
@@ -933,12 +1006,14 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_at
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
}
default:
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name";
MS_LOG(ERROR)
<< "parse ValueNode value don't support input of ref_attr_name";
return false;
}
}
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(
const onnx::NodeProto &node_proto) {
const std::string &value_node_name = node_proto.output(0);
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
if (!attr_proto.has_ref_attr_name()) {
@@ -951,20 +1026,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
}
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(
const onnx::AttributeProto &attr_proto) {
std::vector<int> shape_vec;
const onnx::TensorProto &attr_tensor = attr_proto.t();
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape_vec.push_back(attr_tensor.dims(i));
}
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
auto abstract_tensor =
std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
MS_EXCEPTION_IF_NULL(abstract_tensor);
return abstract_tensor;
}
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::NodeProto &node_proto) {
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto,
const schema::QuantType &quantType) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (!node_proto.has_op_type()) {
MS_LOG(ERROR) << "Get CNode op_type failed!";
@@ -1004,20 +1082,24 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
MS_LOG(ERROR) << node_name << " input " << i << input_name
<< "can't find in nodes have parsed";
return nullptr;
}
inputs.push_back(anfnode_build_map_[input_name]);
}
std::string opType = prim->name();
auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
auto node_parser =
AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
if (node_parser == nullptr) {
MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
return nullptr;
}
auto primitiveT = std::make_unique<schema::PrimitiveT>();
// auto * primitiveTValue = new PrimitiveTValue(primitiveT.release());
std::shared_ptr<PrimitiveTValue> primitiveTValuePtr = std::make_shared<PrimitiveTValue>(primitiveT.release());
std::shared_ptr<PrimitiveTValue> primitiveTValuePtr =
std::make_shared<PrimitiveTValue>(primitiveT.release());
primitiveTValuePtr->SetQuantType(quantType);
node_parser->Populate(prim, primitiveTValuePtr.get(), inputs);
MS_ASSERT(primitiveTValuePtr != nullptr);
inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr));
@@ -1048,8 +1130,9 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
return cnode_ptr;
}
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_EXCEPTION_IF_NULL(cnode_ptr);
std::vector<AnfNodePtr> inputs;
@@ -1064,7 +1147,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
}
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
maketuple_ptr->set_abstract(
std::make_shared<abstract::AbstractTuple>(elem));
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(maketuple_ptr);
@@ -1077,11 +1161,14 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
const onnx::TypeProto &output_typeproto = output_node.type();
int output_type = output_typeproto.tensor_type().elem_type();
std::vector<int> output_shape;
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) {
output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value());
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size();
++i) {
output_shape.push_back(
output_typeproto.tensor_type().shape().dim(i).dim_value());
}
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
auto abstract_tensor =
std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
@@ -1095,8 +1182,9 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
return true;
}
bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto) {
bool AnfImporterFromProtobuf::ImportNodesForGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const schema::QuantType &quantType) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
CNodePtr cnode_ptr = nullptr;
@@ -1110,7 +1198,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
}
continue;
}
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType);
if (cnode_ptr == nullptr) {
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
return false;
@@ -1122,7 +1210,9 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
}
#endif
bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
bool AnfImporterFromProtobuf::BuildFuncGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const schema::QuantType &quantType) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
MS_EXCEPTION_IF_NULL(debug_info_ptr);
@@ -1135,10 +1225,11 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph
if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
return false;
}
return ImportNodesForGraph(outputFuncGraph, importProto);
return ImportNodesForGraph(outputFuncGraph, importProto, quantType);
}
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(
const onnx::ModelProto &model_proto) {
if (!model_proto.has_producer_name()) {
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
return false;
@@ -1159,14 +1250,14 @@ bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mo
return true;
}
int AnfImporterFromProtobuf::Import() {
int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
MS_EXCEPTION_IF_NULL(dstGraph);
if (!ParseModelConfigureInfo(*onnx_model_)) {
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
}
const onnx::GraphProto &graphBuild = onnx_model_->graph();
if (!BuildFuncGraph(dstGraph, graphBuild)) {
if (!BuildFuncGraph(dstGraph, graphBuild, quantType)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
func_graph_ = nullptr;
return RET_ERROR;
@@ -1176,7 +1267,8 @@ int AnfImporterFromProtobuf::Import() {
return RET_OK;
}
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(
const std::string &model_path) {
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) {
MS_LOG(ERROR) << "open file failed.";


@@ -17,20 +17,21 @@
#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
#include <string>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "src/common/anf_importer/anf_importer.h"
#include "abstract/abstract_value.h"
#include "src/common/anf_importer/anf_importer.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
namespace mindspore::lite {
class AnfImporterFromProtobuf : public AnfImporter {
public:
explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph)
: onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {}
explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model,
FuncGraphPtr func_graph)
: onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {}
~AnfImporterFromProtobuf() override = default;
@@ -38,15 +39,17 @@ class AnfImporterFromProtobuf : public AnfImporter {
FuncGraphPtr GetResult() override;
int Import() override;
int Import(const schema::QuantType &quantType =
schema::QuantType_QUANT_NONE) override;
private:
void ConverterConstTensor() override {};
int ConverterCNode() override {};
void AddReturnCNode() override {};
void ConverterConstTensor() override{};
int ConverterCNode() override{};
void AddReturnCNode() override{};
bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto);
const onnx::GraphProto &importProto,
const schema::QuantType &quantType);
#if 0
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto);
@@ -78,31 +81,46 @@ class AnfImporterFromProtobuf : public AnfImporter {
std::unordered_map<std::string, abstract::AbstractTensorPtr>
GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
#else
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto);
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto,
const schema::QuantType &quantType);
bool BuildParameterForFuncGraph(const ParameterPtr &node,
const onnx::ValueInfoProto &value_proto);
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::NodeProto &node_proto,
const schema::QuantType &quantType);
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto,
const CNodePtr &cnode_ptr);
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
bool GetAttrValueForCNode(const PrimitivePtr &prim,
const onnx::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim,
const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim,
const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInTensorForm(const string &value_node_name,
const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
bool ObtainValueNodeInScalarForm(const string &value_node_name,
const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const string &ref_attr_name,
const std::string &value_node_name,
const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
bool ObtainValueNodeInTypeForm(const string &value_node_name,
const onnx::TensorProto &attr_tensor);
abstract::AbstractTensorPtr GetAbstractForCNode(
const onnx::AttributeProto &attr_proto);
#endif
private:
std::string producer_name_;
int model_version_{};
@@ -115,4 +133,3 @@ class AnfImporterFromProtobuf : public AnfImporter {
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_


@@ -46,6 +46,9 @@ class PrimitiveTValue : public Value {
}
}
void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {
}
void AddInputQuantParam(schema::QuantParamT quant_param) {
this->input_quant_param_.emplace_back(quant_param);
}
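As populated earlier in this diff, the outer vector handed to SetInputQuantParam indexes tensors in population order (activation input, then weights, then bias in the conv populaters, then output), and each inner vector holds one QuantParamT per channel, or a single entry for per-tensor quantization.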


@@ -73,6 +73,9 @@ class Tensor : public mindspore::tensor::MetaTensor {
size_t Size() const {
size_t size = 0;
switch (this->data_type_) {
case kNumberTypeFloat64:
size = sizeof(double);
break;
case kNumberTypeFloat:
case kNumberTypeFloat32:
size = sizeof(float);

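The new kNumberTypeFloat64 case in Tensor::Size() above appears to back the DOUBLE-typed attr tensors that ObtainCNodeAttrInTensorForm can now create, whose byte size is queried through Size() during import.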

@@ -71,7 +71,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
FuncGraphPtr graph = nullptr;
if (flag->fmk == converter::FmkType_MS) {
MS_ASSERT(nullptr != modelImporter);
modelImporter->Import();
modelImporter->Import(flag->quantType);
graph = modelImporter->GetResult();
} else {
MS_ASSERT(nullptr != modelParser);