forked from mindspore-Ecosystem/mindspore
!4451 add quant param during anf_import
Merge pull request !4451 from yankai10/merge
This commit is contained in:
commit
180689c095
|
@ -27,7 +27,7 @@
|
|||
#include "abstract/abstract_value.h"
|
||||
#include "src/ir/primitive_value.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
#include "schema/inner/model_generated.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#if 0
|
||||
|
@ -159,7 +159,7 @@ void MinnieBuildGraph::FbTest(const GraphDef *graph_def) {
|
|||
}
|
||||
#endif
|
||||
|
||||
int AnfImporter::Import() {
|
||||
int AnfImporter::Import(const schema::QuantType &quantType) {
|
||||
ConverterConstTensor();
|
||||
auto ret = ConverterCNode();
|
||||
if (RET_OK != ret) {
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
#include "base/base.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class AnfImporter {
|
||||
|
@ -29,7 +30,7 @@ class AnfImporter {
|
|||
|
||||
virtual ~AnfImporter() = default;
|
||||
|
||||
virtual int Import();
|
||||
virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE);
|
||||
|
||||
virtual FuncGraphPtr GetResult() = 0;
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
* This is the C++ adaptation and derivative work of Myia
|
||||
* (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
|
@ -17,101 +18,218 @@
|
|||
*/
|
||||
|
||||
#include "src/common/anf_importer/anf_populater/anf_conv_populater.h"
|
||||
|
||||
#include <mindspore/lite/src/ir/tensor.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
|
||||
const std::vector<AnfNodePtr> &inputs) {
|
||||
int group = GetValue<int>(prim->GetAttr("group"));
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (group > 1) {
|
||||
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
|
||||
if (format == "NCHW") {
|
||||
attr->format = schema::Format_NCHW;
|
||||
} else if (format == "NHWC") {
|
||||
attr->format = schema::Format_NHWC;
|
||||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
} else if (pad_mode == "same") {
|
||||
attr->padMode = schema::PadMode_SAME;
|
||||
} else {
|
||||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
primitive->value.value = attr.release();
|
||||
void AnfConvPopulater::PopulaterConv2DMultiGroup(
|
||||
const PrimitivePtr &prim,
|
||||
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group) {
|
||||
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
|
||||
if (format == "NCHW") {
|
||||
attr->format = schema::Format_NCHW;
|
||||
} else if (format == "NHWC") {
|
||||
attr->format = schema::Format_NHWC;
|
||||
} else {
|
||||
auto attr = std::make_unique<schema::Conv2DT>();
|
||||
attr->group = group;
|
||||
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
|
||||
if (format == "NCHW") {
|
||||
attr->format = schema::Format_NCHW;
|
||||
} else if (format == "NHWC") {
|
||||
attr->format = schema::Format_NHWC;
|
||||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
attr->channelOut = GetValue<int>(prim->GetAttr("out_channel"));
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
} else if (pad_mode == "same") {
|
||||
attr->padMode = schema::PadMode_SAME;
|
||||
} else {
|
||||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
primitive->value.value = attr.release();
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
} else if (pad_mode == "same") {
|
||||
attr->padMode = schema::PadMode_SAME;
|
||||
} else {
|
||||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
void AnfConvPopulater::PopulaterConv2DSingleGroup(
|
||||
const PrimitivePtr &prim,
|
||||
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group) {
|
||||
auto attr = std::make_unique<schema::Conv2DT>();
|
||||
attr->group = group;
|
||||
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
|
||||
if (format == "NCHW") {
|
||||
attr->format = schema::Format_NCHW;
|
||||
} else if (format == "NHWC") {
|
||||
attr->format = schema::Format_NHWC;
|
||||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
attr->channelOut = GetValue<int>(prim->GetAttr("out_channel"));
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
} else if (pad_mode == "same") {
|
||||
attr->padMode = schema::PadMode_SAME;
|
||||
} else {
|
||||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev,
|
||||
float *mMin, float *mMax) {
|
||||
constexpr float qmin = 0;
|
||||
constexpr float qmax = 255;
|
||||
*mMin = static_cast<float>((qmin - mean) / stdDev);
|
||||
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
||||
}
|
||||
|
||||
void AnfConvPopulater::PopulaterQuantParam(
|
||||
const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
|
||||
auto narrow_range = prim->GetAttr("narrow_range");
|
||||
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
||||
auto num_bits = prim->GetAttr("num_bits");
|
||||
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
|
||||
|
||||
std::vector<schema::QuantParamT> quants;
|
||||
schema::QuantParamT quantParam;
|
||||
auto mean = prim->GetAttr("mean");
|
||||
auto std_dev = prim->GetAttr("std_dev");
|
||||
if (mean != nullptr && std_dev != nullptr) {
|
||||
auto meanQuantOaram = GetValue<double>(mean);
|
||||
double stddevQuantOaram = GetValue<double>(std_dev);
|
||||
float mMin = 0.0;
|
||||
float mMax = 0.0;
|
||||
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
|
||||
quantParam.min = mMin;
|
||||
quantParam.max = mMax;
|
||||
} else {
|
||||
auto inputMin = prim->GetAttr("input_minq");
|
||||
auto inputMax = prim->GetAttr("input_maxq");
|
||||
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(inputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
}
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
|
||||
narrowRangeQuantParam, numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
int biasQuantSize = 0;
|
||||
auto filterMin = prim->GetAttr("filter_minq");
|
||||
auto filterMax = prim->GetAttr("filter_maxq");
|
||||
if (filterMin != nullptr && filterMax != nullptr) {
|
||||
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
|
||||
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(filterMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
|
||||
biasQuantSize = filterMinPtr->DataSize();
|
||||
for (int i = 0; i < biasQuantSize; ++i) {
|
||||
quantParam.min = *(minBuf++);
|
||||
quantParam.max = *(maxBuf++);
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
|
||||
narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
for (int i = 0; i < biasQuantSize; ++i) {
|
||||
quantParam.min = 0.0;
|
||||
quantParam.max = 0.0;
|
||||
quantParam.zeroPoint = 0;
|
||||
quantParam.scale =
|
||||
vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
auto outputMin = prim->GetAttr("output_minq");
|
||||
auto outputMax = prim->GetAttr("output_maxq");
|
||||
if (outputMin != nullptr && outputMax != nullptr) {
|
||||
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(outputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
|
||||
narrowRangeQuantParam, numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
}
|
||||
}
|
||||
|
||||
int AnfConvPopulater::Populate(const PrimitivePtr &prim,
|
||||
PrimitiveTValue *primitiveTValuePtr,
|
||||
const std::vector<AnfNodePtr> &inputs) {
|
||||
MS_ASSERT(primitiveTValuePtr != nullptr);
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
|
||||
int group = GetValue<int>(prim->GetAttr("group"));
|
||||
if (group > 1) {
|
||||
PopulaterConv2DMultiGroup(prim, primitive, group);
|
||||
} else {
|
||||
PopulaterConv2DSingleGroup(prim, primitive, group);
|
||||
}
|
||||
primitiveTValuePtr->SetPrimitiveT(primitive.release());
|
||||
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTrainning) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
|
||||
PopulaterQuantParam(prim, &vecQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater());
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
* This is the C++ adaptation and derivative work of Myia
|
||||
* (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
|
@ -18,8 +19,9 @@
|
|||
|
||||
#ifndef MINDSPORE_ANF_CONV_PARSER_H
|
||||
#define MINDSPORE_ANF_CONV_PARSER_H
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
|
||||
namespace mindspore::lite {
|
||||
class AnfConvPopulater : public AnfNodePopulater {
|
||||
public:
|
||||
|
@ -27,6 +29,18 @@ class AnfConvPopulater : public AnfNodePopulater {
|
|||
~AnfConvPopulater() override = default;
|
||||
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
|
||||
const std::vector<AnfNodePtr> &inputs) override;
|
||||
|
||||
private:
|
||||
void PopulaterConv2DMultiGroup(
|
||||
const PrimitivePtr &prim,
|
||||
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
|
||||
void PopulaterConv2DSingleGroup(
|
||||
const PrimitivePtr &prim,
|
||||
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
|
||||
void PopulaterQuantParam(const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
|
||||
float *mMax);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -14,15 +14,113 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include <memory>
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
|
||||
void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean,
|
||||
const double &stdDev, float *mMin,
|
||||
float *mMax) {
|
||||
constexpr float qmin = 0;
|
||||
constexpr float qmax = 255;
|
||||
*mMin = static_cast<float>((qmin - mean) / stdDev);
|
||||
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
||||
}
|
||||
|
||||
void AnfDepwiseconv2DPopulater::PopulaterQuantParam(
|
||||
const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
|
||||
auto narrow_range = prim->GetAttr("narrow_range");
|
||||
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
||||
auto num_bits = prim->GetAttr("num_bits");
|
||||
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
|
||||
|
||||
std::vector<schema::QuantParamT> quants;
|
||||
schema::QuantParamT quantParam;
|
||||
auto mean = prim->GetAttr("mean");
|
||||
auto std_dev = prim->GetAttr("std_dev");
|
||||
if (mean != nullptr && std_dev != nullptr) {
|
||||
auto meanQuantOaram = GetValue<double>(mean);
|
||||
double stddevQuantOaram = GetValue<double>(std_dev);
|
||||
float mMin = 0.0;
|
||||
float mMax = 0.0;
|
||||
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
|
||||
quantParam.min = mMin;
|
||||
quantParam.max = mMax;
|
||||
} else {
|
||||
auto inputMin = prim->GetAttr("input_minq");
|
||||
auto inputMax = prim->GetAttr("input_maxq");
|
||||
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(inputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
}
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
|
||||
narrowRangeQuantParam, numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
int biasQuantSize = 0;
|
||||
auto filterMin = prim->GetAttr("filter_minq");
|
||||
auto filterMax = prim->GetAttr("filter_maxq");
|
||||
if (filterMin != nullptr && filterMax != nullptr) {
|
||||
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
|
||||
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(filterMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
|
||||
biasQuantSize = filterMinPtr->DataSize();
|
||||
for (int i = 0; i < biasQuantSize; ++i) {
|
||||
quantParam.min = *(minBuf++);
|
||||
quantParam.max = *(maxBuf++);
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
|
||||
narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
for (int i = 0; i < biasQuantSize; ++i) {
|
||||
quantParam.min = 0.0;
|
||||
quantParam.max = 0.0;
|
||||
quantParam.zeroPoint = 0;
|
||||
quantParam.scale =
|
||||
vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
auto outputMin = prim->GetAttr("output_minq");
|
||||
auto outputMax = prim->GetAttr("output_maxq");
|
||||
if (outputMin != nullptr && outputMax != nullptr) {
|
||||
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(outputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
|
||||
narrowRangeQuantParam, numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
}
|
||||
}
|
||||
|
||||
int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim,
|
||||
PrimitiveTValue *primitiveTValuePtr,
|
||||
const std::vector<AnfNodePtr> &inputs) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
|
@ -36,9 +134,9 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
|
|||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pads"));
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
|
||||
|
@ -73,10 +171,13 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
|
|||
auto abstractBase = paramNode->abstract();
|
||||
MS_ASSERT(abstractBase != nullptr);
|
||||
if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
||||
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
||||
auto abstractTensor =
|
||||
utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
||||
MS_ASSERT(abstractTensor != nullptr);
|
||||
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
|
||||
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
||||
auto dims =
|
||||
utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())
|
||||
->shape();
|
||||
attr->channelIn = dims[kAnfPopulaterOne];
|
||||
}
|
||||
}
|
||||
|
@ -86,8 +187,16 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
|
|||
primitive->value.value = attr.release();
|
||||
MS_ASSERT(primitiveTValuePtr != nullptr);
|
||||
primitiveTValuePtr->SetPrimitiveT(primitive.release());
|
||||
|
||||
if (primitiveTValuePtr->GetQuantType()) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
|
||||
PopulaterQuantParam(prim, &vecQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
|
||||
AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
|
||||
AnfNodePopulaterRegistrar anfdepthwise2dPopulater(
|
||||
"DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
|
||||
AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater(
|
||||
"DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -15,8 +15,9 @@
|
|||
*/
|
||||
#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
|
||||
#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
|
||||
#include <vector>
|
||||
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
|
||||
namespace mindspore::lite {
|
||||
class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
|
||||
public:
|
||||
|
@ -24,6 +25,11 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
|
|||
~AnfDepwiseconv2DPopulater() override = default;
|
||||
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
|
||||
const std::vector<AnfNodePtr> &inputs) override;
|
||||
private:
|
||||
void PopulaterQuantParam(const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
|
||||
float *mMax);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -14,14 +14,98 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/common/anf_importer/anf_populater/anf_matmul_populater.h"
|
||||
#include <vector>
|
||||
|
||||
#include <memory>
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
|
||||
#include <vector>
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
|
||||
void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev,
|
||||
float *mMin, float *mMax) {
|
||||
constexpr float qmin = 0;
|
||||
constexpr float qmax = 255;
|
||||
*mMin = static_cast<float>((qmin - mean) / stdDev);
|
||||
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
||||
}
|
||||
|
||||
void AnfMatmulPopulater::PopulaterQuantParam(
|
||||
const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
|
||||
auto narrow_range = prim->GetAttr("narrow_range");
|
||||
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
||||
auto num_bits = prim->GetAttr("num_bits");
|
||||
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
|
||||
|
||||
std::vector<schema::QuantParamT> quants;
|
||||
schema::QuantParamT quantParam;
|
||||
auto mean = prim->GetAttr("mean");
|
||||
auto std_dev = prim->GetAttr("std_dev");
|
||||
if (mean != nullptr && std_dev != nullptr) {
|
||||
auto meanQuantOaram = GetValue<double>(mean);
|
||||
double stddevQuantOaram = GetValue<double>(std_dev);
|
||||
float mMin = 0.0;
|
||||
float mMax = 0.0;
|
||||
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
|
||||
quantParam.min = mMin;
|
||||
quantParam.max = mMax;
|
||||
} else {
|
||||
auto inputMin = prim->GetAttr("input_minq");
|
||||
auto inputMax = prim->GetAttr("input_maxq");
|
||||
auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(inputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(inputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
}
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
|
||||
narrowRangeQuantParam, numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
|
||||
quants.clear();
|
||||
auto filterMin = prim->GetAttr("filter_minq");
|
||||
auto filterMax = prim->GetAttr("filter_maxq");
|
||||
if (filterMin != nullptr && filterMax != nullptr) {
|
||||
auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>();
|
||||
auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(filterMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(filterMaxPtr->Data());
|
||||
for (int i = 0; i < filterMinPtr->DataSize(); ++i) {
|
||||
quantParam.min = *(minBuf++);
|
||||
quantParam.max = *(maxBuf++);
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
|
||||
narrowRangeQuantParam,
|
||||
numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
}
|
||||
vecQuantParam->emplace_back(quants);
|
||||
}
|
||||
|
||||
quants.clear();
|
||||
auto outputMin = prim->GetAttr("output_minq");
|
||||
auto outputMax = prim->GetAttr("output_maxq");
|
||||
if (outputMin != nullptr && outputMax != nullptr) {
|
||||
auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>();
|
||||
auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>();
|
||||
float *minBuf = static_cast<float *>(outputMinPtr->Data());
|
||||
float *maxBuf = static_cast<float *>(outputMaxPtr->Data());
|
||||
quantParam.min = *minBuf;
|
||||
quantParam.max = *maxBuf;
|
||||
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max,
|
||||
narrowRangeQuantParam, numbitsRangeQuantParam);
|
||||
quants.emplace_back(quantParam);
|
||||
vecQuantParam->emplace_back(quants);
|
||||
}
|
||||
}
|
||||
|
||||
int AnfMatmulPopulater::Populate(const PrimitivePtr &prim,
|
||||
PrimitiveTValue *primitiveTValuePtr,
|
||||
const std::vector<AnfNodePtr> &inputs) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
auto attr = std::make_unique<schema::MatMulT>();
|
||||
|
@ -32,8 +116,16 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim
|
|||
primitive->value.value = attr.release();
|
||||
MS_ASSERT(primitiveTValuePtr != nullptr);
|
||||
primitiveTValuePtr->SetPrimitiveT(primitive.release());
|
||||
if (primitiveTValuePtr->GetQuantType()) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
|
||||
PopulaterQuantParam(prim, &vecQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater());
|
||||
AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", new AnfMatmulPopulater());
|
||||
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul",
|
||||
new AnfMatmulPopulater());
|
||||
AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul",
|
||||
new AnfMatmulPopulater());
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -24,6 +24,11 @@ class AnfMatmulPopulater : public AnfNodePopulater {
|
|||
~AnfMatmulPopulater() override = default;
|
||||
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
|
||||
const std::vector<AnfNodePtr> &inputs) override;
|
||||
private:
|
||||
void PopulaterQuantParam(const PrimitivePtr &prim,
|
||||
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
|
||||
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
|
||||
float *mMax);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -28,18 +28,18 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "tools/converter/parser/onnx/onnx.pb.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
|
||||
|
||||
using string = std::string;
|
||||
using int32 = int32_t;
|
||||
|
@ -60,16 +60,24 @@ enum ParseForm : int {
|
|||
};
|
||||
|
||||
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
|
||||
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}};
|
||||
{"type", FORM_PARSE_TYPE},
|
||||
{"scalar", FORM_PARSE_SCALAR},
|
||||
{"tensor", FORM_PARSE_TENSOR}};
|
||||
|
||||
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
|
||||
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
|
||||
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
|
||||
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
|
||||
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
|
||||
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
|
||||
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
|
||||
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
||||
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool},
|
||||
{onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
|
||||
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16},
|
||||
{onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
|
||||
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64},
|
||||
{onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
|
||||
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16},
|
||||
{onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
|
||||
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64},
|
||||
{onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
|
||||
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
|
||||
{onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
|
||||
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
||||
};
|
||||
|
||||
#if 0
|
||||
|
@ -189,15 +197,16 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, a
|
|||
return {};
|
||||
}
|
||||
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
|
||||
if (attr_tensor.type##_data_size() == 1) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
|
||||
return MakeValue<valuetype>(value); \
|
||||
} else { \
|
||||
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
|
||||
} \
|
||||
return {}; \
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
ValuePtr ParseAttrInScalar_##type##_##valuetype( \
|
||||
const onnx::TensorProto &attr_tensor) { \
|
||||
if (attr_tensor.type##_data_size() == 1) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
|
||||
return MakeValue<valuetype>(value); \
|
||||
} else { \
|
||||
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
|
||||
} \
|
||||
return {}; \
|
||||
}
|
||||
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
|
||||
|
@ -643,20 +652,21 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
|
|||
}
|
||||
#else
|
||||
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
|
||||
const onnx::TensorProto &attr_tensor) { \
|
||||
MS_EXCEPTION_IF_NULL(prim); \
|
||||
std::vector<ValuePtr> attr_value_vec; \
|
||||
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
|
||||
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
|
||||
} \
|
||||
if (attr_value_vec.size() == 1) { \
|
||||
prim->AddAttr(attr_name, attr_value_vec[0]); \
|
||||
} else { \
|
||||
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
|
||||
} \
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
void ParseAttrInScalar_##type##_##valuetype( \
|
||||
const PrimitivePtr &prim, const std::string &attr_name, \
|
||||
const onnx::TensorProto &attr_tensor) { \
|
||||
MS_EXCEPTION_IF_NULL(prim); \
|
||||
std::vector<ValuePtr> attr_value_vec; \
|
||||
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
|
||||
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
|
||||
} \
|
||||
if (attr_value_vec.size() == 1) { \
|
||||
prim->AddAttr(attr_name, attr_value_vec[0]); \
|
||||
} else { \
|
||||
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
|
||||
} \
|
||||
}
|
||||
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
|
||||
|
@ -667,8 +677,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
|
|||
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||
const onnx::ValueInfoProto &value_proto) {
|
||||
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(
|
||||
const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!value_proto.has_type() || !value_proto.has_name()) {
|
||||
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
|
||||
|
@ -691,24 +701,30 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
|
|||
shape.push_back(tensor_shape.dim(i).dim_value());
|
||||
}
|
||||
|
||||
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) {
|
||||
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) ==
|
||||
kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]);
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
||||
auto type_ptr =
|
||||
TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]);
|
||||
auto abstract_tensor =
|
||||
std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
||||
node->set_abstract(abstract_tensor);
|
||||
|
||||
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
|
||||
tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
|
||||
tensor::Tensor *tensor_info = new tensor::Tensor(
|
||||
kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
tensor_info->MallocData();
|
||||
const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()];
|
||||
const onnx::TensorProto initialize_proto =
|
||||
default_para_map_[value_proto.name()];
|
||||
std::string initial_data = initialize_proto.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
|
||||
MS_EXCEPTION_IF_NULL(tensor_data_buf);
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(),
|
||||
initial_data.data(), initial_data.size());
|
||||
if (EOK != ret) {
|
||||
MS_LOG(ERROR) << "memcpy_s error";
|
||||
return false;
|
||||
|
@ -724,15 +740,18 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto) {
|
||||
bool AnfImporterFromProtobuf::ImportParametersForGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size();
|
||||
MS_LOG(INFO) << "Parameters had default paramerer size is: "
|
||||
<< importProto.initializer_size();
|
||||
|
||||
for (int i = 0; i < importProto.initializer_size(); ++i) {
|
||||
const onnx::TensorProto &initializer_proto = importProto.initializer(i);
|
||||
if (!initializer_proto.has_name()) {
|
||||
MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i;
|
||||
MS_LOG(ERROR)
|
||||
<< "initializer vector of onnx GraphProto has no name at index: "
|
||||
<< i;
|
||||
return false;
|
||||
}
|
||||
default_para_map_[initializer_proto.name()] = initializer_proto;
|
||||
|
@ -741,7 +760,8 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu
|
|||
MS_LOG(INFO) << "all parameters size: " << importProto.input_size();
|
||||
for (int i = 0; i < importProto.input_size(); ++i) {
|
||||
const onnx::ValueInfoProto &input_proto = importProto.input(i);
|
||||
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
|
||||
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(),
|
||||
input_proto)) {
|
||||
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
|
||||
return false;
|
||||
}
|
||||
|
@ -749,20 +769,25 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(
|
||||
const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) ==
|
||||
kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:"
|
||||
<< attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
prim->AddAttr(attr_name,
|
||||
TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(
|
||||
const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
switch (attr_tensor_type) {
|
||||
|
@ -796,20 +821,59 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &pr
|
|||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: "
|
||||
<< attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(
|
||||
const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_LOG(ERROR) << "parse attr type don't support attr type is tensor";
|
||||
return false;
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
std::vector<int> shape;
|
||||
auto ret = EOK;
|
||||
if (attr_tensor.dims_size() != 0) {
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(
|
||||
kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
tensor_info->MallocData();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
|
||||
ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(),
|
||||
tensor_buf.size());
|
||||
prim->set_attr(attr_name, MakeValue(tensor_info));
|
||||
} else {
|
||||
if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) {
|
||||
size_t data_size = sizeof(double);
|
||||
double attr_value = 0.0;
|
||||
ret = memcpy_s(&attr_value, data_size, tensor_buf.data(),
|
||||
tensor_buf.size());
|
||||
prim->set_attr(attr_name, MakeValue<double>(attr_value));
|
||||
} else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) {
|
||||
size_t data_size = sizeof(int64_t);
|
||||
int32_t attr_value = 0;
|
||||
ret = memcpy_s(&attr_value, data_size, tensor_buf.data(),
|
||||
tensor_buf.size());
|
||||
prim->set_attr(attr_name, MakeValue<int32_t>(attr_value));
|
||||
} else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) {
|
||||
size_t data_size = sizeof(bool);
|
||||
bool attr_value = false;
|
||||
ret = memcpy_s(&attr_value, data_size, tensor_buf.data(),
|
||||
tensor_buf.size());
|
||||
prim->set_attr(attr_name, MakeValue<bool>(attr_value));
|
||||
}
|
||||
}
|
||||
|
||||
return ret == EOK;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForCNode(
|
||||
const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const std::string &attr_name = attr_proto.name();
|
||||
if (!attr_proto.has_ref_attr_name()) {
|
||||
|
@ -833,18 +897,20 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
|
|||
return false;
|
||||
}
|
||||
}
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(
|
||||
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
std::vector<int> shape;
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(
|
||||
kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
tensor_info->MallocData();
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(),
|
||||
tensor_buf.size());
|
||||
if (EOK != ret) {
|
||||
MS_LOG(ERROR) << "memcpy_s error";
|
||||
return false;
|
||||
|
@ -852,14 +918,15 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
|
|||
auto new_value_node = NewValueNode(MakeValue(tensor_info));
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
||||
auto abstract_tensor =
|
||||
std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
||||
new_value_node->set_abstract(abstract_tensor);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(
|
||||
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ValuePtr value_ptr = nullptr;
|
||||
switch (attr_tensor_type) {
|
||||
|
@ -871,7 +938,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val
|
|||
if (add_data.size() == 1) {
|
||||
value_ptr = MakeValue(add_data[0]);
|
||||
} else if (!add_data.empty()) {
|
||||
value_ptr = MakeValue<std::vector<int32>>(add_data);
|
||||
value_ptr = MakeValue<std::vector<int32> >(add_data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -884,7 +951,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val
|
|||
if (add_data.size() == 1) {
|
||||
value_ptr = MakeValue(add_data[0]);
|
||||
} else if (!add_data.empty()) {
|
||||
value_ptr = MakeValue<std::vector<float>>(add_data);
|
||||
value_ptr = MakeValue<std::vector<float> >(add_data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -894,7 +961,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val
|
|||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: "
|
||||
<< attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node = NewValueNode(value_ptr);
|
||||
|
@ -905,23 +973,28 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(
|
||||
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) ==
|
||||
kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR)
|
||||
<< "Obtain ValueNode attr in type-form has not support input type: "
|
||||
<< attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
|
||||
auto new_value_node =
|
||||
NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
abstract::AbstractTypePtr abs_type =
|
||||
std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
|
||||
new_value_node->set_abstract(abs_type);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name,
|
||||
const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(
|
||||
const std::string &ref_attr_name, const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
switch (kParseTypeSwitchMap[ref_attr_name]) {
|
||||
case FORM_PARSE_SCALAR: {
|
||||
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
|
||||
|
@ -933,12 +1006,14 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_at
|
|||
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name";
|
||||
MS_LOG(ERROR)
|
||||
<< "parse ValueNode value don't support input of ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
|
||||
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(
|
||||
const onnx::NodeProto &node_proto) {
|
||||
const std::string &value_node_name = node_proto.output(0);
|
||||
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
|
||||
if (!attr_proto.has_ref_attr_name()) {
|
||||
|
@ -951,20 +1026,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
|
|||
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
|
||||
}
|
||||
|
||||
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
|
||||
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(
|
||||
const onnx::AttributeProto &attr_proto) {
|
||||
std::vector<int> shape_vec;
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape_vec.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
|
||||
auto abstract_tensor =
|
||||
std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tensor);
|
||||
return abstract_tensor;
|
||||
}
|
||||
|
||||
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::NodeProto &node_proto) {
|
||||
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto,
|
||||
const schema::QuantType &quantType) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
if (!node_proto.has_op_type()) {
|
||||
MS_LOG(ERROR) << "Get CNode op_type failed!";
|
||||
|
@ -1004,20 +1082,24 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
for (int i = 0; i < node_proto.input_size(); ++i) {
|
||||
const std::string &input_name = node_proto.input(i);
|
||||
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
|
||||
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
|
||||
MS_LOG(ERROR) << node_name << " input " << i << input_name
|
||||
<< "can't find in nodes have parsed";
|
||||
return nullptr;
|
||||
}
|
||||
inputs.push_back(anfnode_build_map_[input_name]);
|
||||
}
|
||||
std::string opType = prim->name();
|
||||
auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
|
||||
auto node_parser =
|
||||
AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
|
||||
if (node_parser == nullptr) {
|
||||
MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
|
||||
return nullptr;
|
||||
}
|
||||
auto primitiveT = std::make_unique<schema::PrimitiveT>();
|
||||
// auto * primitiveTValue = new PrimitiveTValue(primitiveT.release());
|
||||
std::shared_ptr<PrimitiveTValue> primitiveTValuePtr = std::make_shared<PrimitiveTValue>(primitiveT.release());
|
||||
std::shared_ptr<PrimitiveTValue> primitiveTValuePtr =
|
||||
std::make_shared<PrimitiveTValue>(primitiveT.release());
|
||||
primitiveTValuePtr->SetQuantType(quantType);
|
||||
node_parser->Populate(prim, primitiveTValuePtr.get(), inputs);
|
||||
MS_ASSERT(primitiveTValuePtr != nullptr);
|
||||
inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr));
|
||||
|
@ -1048,8 +1130,9 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
return cnode_ptr;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
|
||||
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
|
@ -1064,7 +1147,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|||
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
|
||||
}
|
||||
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
maketuple_ptr->set_abstract(
|
||||
std::make_shared<abstract::AbstractTuple>(elem));
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
||||
inputs.push_back(maketuple_ptr);
|
||||
|
@ -1077,11 +1161,14 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|||
const onnx::TypeProto &output_typeproto = output_node.type();
|
||||
int output_type = output_typeproto.tensor_type().elem_type();
|
||||
std::vector<int> output_shape;
|
||||
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) {
|
||||
output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value());
|
||||
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size();
|
||||
++i) {
|
||||
output_shape.push_back(
|
||||
output_typeproto.tensor_type().shape().dim(i).dim_value());
|
||||
}
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]);
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
|
||||
auto abstract_tensor =
|
||||
std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
|
||||
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
||||
|
@ -1095,8 +1182,9 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto) {
|
||||
bool AnfImporterFromProtobuf::ImportNodesForGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
|
||||
CNodePtr cnode_ptr = nullptr;
|
||||
|
@ -1110,7 +1198,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
|
|||
}
|
||||
continue;
|
||||
}
|
||||
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
|
||||
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType);
|
||||
if (cnode_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
|
||||
return false;
|
||||
|
@ -1122,7 +1210,9 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
|
|||
}
|
||||
#endif
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||
bool AnfImporterFromProtobuf::BuildFuncGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
|
||||
MS_EXCEPTION_IF_NULL(debug_info_ptr);
|
||||
|
@ -1135,10 +1225,11 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph
|
|||
if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
|
||||
return false;
|
||||
}
|
||||
return ImportNodesForGraph(outputFuncGraph, importProto);
|
||||
return ImportNodesForGraph(outputFuncGraph, importProto, quantType);
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
|
||||
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(
|
||||
const onnx::ModelProto &model_proto) {
|
||||
if (!model_proto.has_producer_name()) {
|
||||
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
|
||||
return false;
|
||||
|
@ -1159,14 +1250,14 @@ bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mo
|
|||
return true;
|
||||
}
|
||||
|
||||
int AnfImporterFromProtobuf::Import() {
|
||||
int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
||||
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
|
||||
MS_EXCEPTION_IF_NULL(dstGraph);
|
||||
if (!ParseModelConfigureInfo(*onnx_model_)) {
|
||||
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||
}
|
||||
const onnx::GraphProto &graphBuild = onnx_model_->graph();
|
||||
if (!BuildFuncGraph(dstGraph, graphBuild)) {
|
||||
if (!BuildFuncGraph(dstGraph, graphBuild, quantType)) {
|
||||
MS_LOG(ERROR) << "Build funcgraph failed!";
|
||||
func_graph_ = nullptr;
|
||||
return RET_ERROR;
|
||||
|
@ -1176,7 +1267,8 @@ int AnfImporterFromProtobuf::Import() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
|
||||
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(
|
||||
const std::string &model_path) {
|
||||
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
|
||||
if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) {
|
||||
MS_LOG(ERROR) << "open file failed.";
|
||||
|
|
|
@ -17,20 +17,21 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
|
||||
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx.pb.h"
|
||||
#include "src/common/anf_importer/anf_importer.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "src/common/anf_importer/anf_importer.h"
|
||||
#include "tools/converter/parser/onnx/onnx.pb.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class AnfImporterFromProtobuf : public AnfImporter {
|
||||
public:
|
||||
explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph)
|
||||
: onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {}
|
||||
explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model,
|
||||
FuncGraphPtr func_graph)
|
||||
: onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {}
|
||||
|
||||
~AnfImporterFromProtobuf() override = default;
|
||||
|
||||
|
@ -38,15 +39,17 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
|
||||
FuncGraphPtr GetResult() override;
|
||||
|
||||
int Import() override;
|
||||
int Import(const schema::QuantType &quantType =
|
||||
schema::QuantType_QUANT_NONE) override;
|
||||
|
||||
private:
|
||||
void ConverterConstTensor() override {};
|
||||
int ConverterCNode() override {};
|
||||
void AddReturnCNode() override {};
|
||||
void ConverterConstTensor() override{};
|
||||
int ConverterCNode() override{};
|
||||
void AddReturnCNode() override{};
|
||||
bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto);
|
||||
const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType);
|
||||
#if 0
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto);
|
||||
|
@ -78,31 +81,46 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
std::unordered_map<std::string, abstract::AbstractTensorPtr>
|
||||
GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
#else
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
|
||||
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||
const onnx::ValueInfoProto &value_proto);
|
||||
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::NodeProto &node_proto,
|
||||
const schema::QuantType &quantType);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto,
|
||||
const CNodePtr &cnode_ptr);
|
||||
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
bool GetAttrValueForCNode(const PrimitivePtr &prim,
|
||||
const onnx::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim,
|
||||
const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim,
|
||||
const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
|
||||
const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
|
||||
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainValueNodeInTensorForm(const string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
|
||||
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
|
||||
bool ObtainValueNodeInScalarForm(const string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool GetAttrValueForValueNode(const string &ref_attr_name,
|
||||
const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
bool ObtainValueNodeInTypeForm(const string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
abstract::AbstractTensorPtr GetAbstractForCNode(
|
||||
const onnx::AttributeProto &attr_proto);
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
private:
|
||||
std::string producer_name_;
|
||||
int model_version_{};
|
||||
|
@ -115,4 +133,3 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_
|
||||
|
||||
|
|
|
@ -46,6 +46,9 @@ class PrimitiveTValue : public Value {
|
|||
}
|
||||
}
|
||||
|
||||
void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {
|
||||
}
|
||||
|
||||
void AddInputQuantParam(schema::QuantParamT quant_param) {
|
||||
this->input_quant_param_.emplace_back(quant_param);
|
||||
}
|
||||
|
|
|
@ -73,6 +73,9 @@ class Tensor : public mindspore::tensor::MetaTensor {
|
|||
size_t Size() const {
|
||||
size_t size = 0;
|
||||
switch (this->data_type_) {
|
||||
case kNumberTypeFloat64:
|
||||
size = sizeof(double);
|
||||
break;
|
||||
case kNumberTypeFloat:
|
||||
case kNumberTypeFloat32:
|
||||
size = sizeof(float);
|
||||
|
|
|
@ -71,7 +71,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
|||
FuncGraphPtr graph = nullptr;
|
||||
if (flag->fmk == converter::FmkType_MS) {
|
||||
MS_ASSERT(nullptr != modelImporter);
|
||||
modelImporter->Import();
|
||||
modelImporter->Import(flag->quantType);
|
||||
graph = modelImporter->GetResult();
|
||||
} else {
|
||||
MS_ASSERT(nullptr != modelParser);
|
||||
|
|
Loading…
Reference in New Issue