forked from mindspore-Ecosystem/mindspore
!9047 [lite]mindir reconstruct compatibility
From: @xu_anyue Reviewed-by: Signed-off-by:
This commit is contained in:
commit
61d717032c
|
@ -228,6 +228,7 @@ class PrimitiveC {
|
||||||
bool infer_flag_ = true;
|
bool infer_flag_ = true;
|
||||||
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
|
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
|
||||||
};
|
};
|
||||||
|
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
|
||||||
typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);
|
typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);
|
||||||
#endif
|
#endif
|
||||||
typedef OpParameter *(*ParameterCreator)(const PrimitiveC *primitive);
|
typedef OpParameter *(*ParameterCreator)(const PrimitiveC *primitive);
|
||||||
|
|
|
@ -203,6 +203,7 @@ if(ENABLE_CONVERTER)
|
||||||
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
|
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
|
||||||
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
|
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
|
||||||
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
|
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
|
||||||
|
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
### train
|
### train
|
||||||
|
|
|
@ -14,15 +14,15 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <utility>
|
|
||||||
#include "tools/anf_importer/anf_importer.h"
|
#include "tools/anf_importer/anf_importer.h"
|
||||||
|
#include <utility>
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "ir/dtype.h"
|
#include "ir/dtype.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
int AnfImporter::Import(const schema::QuantType &quantType) {
|
int AnfImporter::Import(const converter::Flags *flag) {
|
||||||
auto ret = ConverterConstTensor();
|
auto ret = ConverterConstTensor();
|
||||||
if (RET_OK != ret) {
|
if (RET_OK != ret) {
|
||||||
MS_LOG(ERROR) << "ConverterConstTensor failed " << ret;
|
MS_LOG(ERROR) << "ConverterConstTensor failed " << ret;
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "base/base.h"
|
#include "base/base.h"
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
|
#include "tools/converter/converter_flags.h"
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
class AnfImporter {
|
class AnfImporter {
|
||||||
|
@ -30,7 +31,7 @@ class AnfImporter {
|
||||||
|
|
||||||
virtual ~AnfImporter() = default;
|
virtual ~AnfImporter() = default;
|
||||||
|
|
||||||
virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE);
|
virtual int Import(const converter::Flags *flag = nullptr);
|
||||||
|
|
||||||
virtual FuncGraphPtr GetResult() = 0;
|
virtual FuncGraphPtr GetResult() = 0;
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "tools/anf_importer/import_from_protobuf.h"
|
#include "tools/anf_importer/import_from_mindir.h"
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -36,6 +36,7 @@
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
#include "tools/common/protobuf_utils.h"
|
#include "tools/common/protobuf_utils.h"
|
||||||
#include "tools/common/graph_util.h"
|
#include "tools/common/graph_util.h"
|
||||||
|
#include "load_mindir/load_model.h"
|
||||||
|
|
||||||
using string = std::string;
|
using string = std::string;
|
||||||
using int32 = int32_t;
|
using int32 = int32_t;
|
||||||
|
@ -199,8 +200,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
|
||||||
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
|
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
|
||||||
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
||||||
|
|
||||||
int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
|
int AnfImporterFromMindir::BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||||
const onnx::ValueInfoProto &value_proto) {
|
const onnx::ValueInfoProto &value_proto) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
@ -274,8 +275,8 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
int AnfImporterFromMindir::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||||
const onnx::GraphProto &importProto) {
|
const onnx::GraphProto &importProto) {
|
||||||
if (outputFuncGraph == nullptr) {
|
if (outputFuncGraph == nullptr) {
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
@ -303,8 +304,8 @@ int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &output
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
bool AnfImporterFromMindir::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
const onnx::TensorProto &attr_tensor) {
|
const onnx::TensorProto &attr_tensor) {
|
||||||
if (prim == nullptr) {
|
if (prim == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -317,7 +318,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
|
ValuePtr AnfImporterFromMindir::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
|
||||||
const int attr_tensor_type = attr_tensor.data_type();
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
switch (attr_tensor_type) {
|
switch (attr_tensor_type) {
|
||||||
case onnx::TensorProto_DataType_STRING: {
|
case onnx::TensorProto_DataType_STRING: {
|
||||||
|
@ -347,8 +348,8 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
bool AnfImporterFromMindir::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
const onnx::TensorProto &attr_tensor) {
|
const onnx::TensorProto &attr_tensor) {
|
||||||
if (prim == nullptr) {
|
if (prim == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -405,7 +406,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
|
||||||
return ret == EOK;
|
return ret == EOK;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
bool AnfImporterFromMindir::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
||||||
if (prim == nullptr) {
|
if (prim == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -460,8 +461,8 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
bool AnfImporterFromMindir::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||||
const onnx::TensorProto &attr_tensor) {
|
const onnx::TensorProto &attr_tensor) {
|
||||||
const int attr_tensor_type = attr_tensor.data_type();
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
std::vector<int> shape;
|
std::vector<int> shape;
|
||||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||||
|
@ -501,8 +502,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
bool AnfImporterFromMindir::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||||
const onnx::TensorProto &attr_tensor) {
|
const onnx::TensorProto &attr_tensor) {
|
||||||
const int attr_tensor_type = attr_tensor.data_type();
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||||
|
@ -515,8 +516,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name,
|
bool AnfImporterFromMindir::GetAttrValueForValueNode(const std::string &value_node_name,
|
||||||
const onnx::AttributeProto &attr_proto) {
|
const onnx::AttributeProto &attr_proto) {
|
||||||
if (!attr_proto.has_ref_attr_name()) {
|
if (!attr_proto.has_ref_attr_name()) {
|
||||||
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
||||||
return false;
|
return false;
|
||||||
|
@ -572,7 +573,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
|
bool AnfImporterFromMindir::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
|
||||||
const std::string &value_node_name = node_proto.output(0);
|
const std::string &value_node_name = node_proto.output(0);
|
||||||
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
|
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
|
||||||
if (!attr_proto.has_ref_attr_name()) {
|
if (!attr_proto.has_ref_attr_name()) {
|
||||||
|
@ -582,7 +583,7 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
|
||||||
return GetAttrValueForValueNode(value_node_name, attr_proto);
|
return GetAttrValueForValueNode(value_node_name, attr_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProtobuf::GetAbstractForCNode(
|
std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromMindir::GetAbstractForCNode(
|
||||||
const onnx::AttributeProto &attr_proto) {
|
const onnx::AttributeProto &attr_proto) {
|
||||||
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
|
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
|
||||||
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
||||||
|
@ -601,9 +602,9 @@ std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProt
|
||||||
return kv;
|
return kv;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
CNodePtr AnfImporterFromMindir::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||||
const onnx::NodeProto &node_proto,
|
const onnx::NodeProto &node_proto,
|
||||||
const schema::QuantType &quantType) {
|
const schema::QuantType &quantType) {
|
||||||
static bool interrupt = false;
|
static bool interrupt = false;
|
||||||
if (outputFuncGraph == nullptr) {
|
if (outputFuncGraph == nullptr) {
|
||||||
MS_LOG(ERROR) << "output funcgraph is nullptr";
|
MS_LOG(ERROR) << "output funcgraph is nullptr";
|
||||||
|
@ -685,8 +686,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
||||||
return cnode_ptr;
|
return cnode_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
bool AnfImporterFromMindir::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||||
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
|
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
|
||||||
if (outputFuncGraph == nullptr || cnode_ptr == nullptr) {
|
if (outputFuncGraph == nullptr || cnode_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "output funcgraph or cnode is nullptr";
|
MS_LOG(ERROR) << "output funcgraph or cnode is nullptr";
|
||||||
return false;
|
return false;
|
||||||
|
@ -765,9 +766,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
int AnfImporterFromMindir::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||||
const onnx::GraphProto &importProto,
|
const schema::QuantType &quantType) {
|
||||||
const schema::QuantType &quantType) {
|
|
||||||
if (outputFuncGraph == nullptr) {
|
if (outputFuncGraph == nullptr) {
|
||||||
MS_LOG(ERROR) << "funcgraph is nullptr";
|
MS_LOG(ERROR) << "funcgraph is nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -809,8 +809,8 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
int AnfImporterFromMindir::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||||
const schema::QuantType &quantType) {
|
const schema::QuantType &quantType) {
|
||||||
if (outputFuncGraph == nullptr) {
|
if (outputFuncGraph == nullptr) {
|
||||||
MS_LOG(ERROR) << "fundgraph is nullptr";
|
MS_LOG(ERROR) << "fundgraph is nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -833,7 +833,7 @@ int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||||
return ImportNodesForGraph(outputFuncGraph, importProto, quantType);
|
return ImportNodesForGraph(outputFuncGraph, importProto, quantType);
|
||||||
}
|
}
|
||||||
|
|
||||||
int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
|
int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
|
||||||
if (!model_proto.has_producer_name()) {
|
if (!model_proto.has_producer_name()) {
|
||||||
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
|
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
|
||||||
return RET_GRAPH_FILE_ERR;
|
return RET_GRAPH_FILE_ERR;
|
||||||
|
@ -854,7 +854,17 @@ int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mod
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
int AnfImporterFromMindir::Import(const converter::Flags *flag) {
|
||||||
|
onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
|
||||||
|
if (onnx_model_ == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";
|
||||||
|
func_graph_ = LoadMindIR(flag->modelFile);
|
||||||
|
if (func_graph_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file.";
|
||||||
|
return RET_GRAPH_FILE_ERR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
|
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
|
||||||
if (dstGraph == nullptr) {
|
if (dstGraph == nullptr) {
|
||||||
MS_LOG(ERROR) << "funcgraph is nullptr";
|
MS_LOG(ERROR) << "funcgraph is nullptr";
|
||||||
|
@ -865,10 +875,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
||||||
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
if (onnx_model_ == nullptr) {
|
auto quantType = flag->quantType;
|
||||||
MS_LOG(ERROR) << "onnx_model_ is nullptr";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
const onnx::GraphProto &graphBuild = onnx_model_->graph();
|
const onnx::GraphProto &graphBuild = onnx_model_->graph();
|
||||||
status = BuildFuncGraph(dstGraph, graphBuild, quantType);
|
status = BuildFuncGraph(dstGraph, graphBuild, quantType);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
|
@ -881,25 +888,22 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
|
onnx::ModelProto *AnfImporterFromMindir::ReadOnnxFromBinary(const std::string &model_path) {
|
||||||
auto onnx_model = new (std::nothrow) onnx::ModelProto;
|
auto onnx_model = new (std::nothrow) onnx::ModelProto;
|
||||||
if (onnx_model == nullptr) {
|
if (onnx_model == nullptr) {
|
||||||
MS_LOG(ERROR) << "New onnx ModelProto failed!";
|
MS_LOG(ERROR) << "New onnx ModelProto failed!";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (RET_OK != ValidateFileStr(model_path, ".mindir")) {
|
if (RET_OK != ValidateFileStr(model_path, ".mindir")) {
|
||||||
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir";
|
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID);
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) {
|
if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
|
MS_LOG(ERROR) << "Read onnx model file failed, which is not a matched onnx model";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return onnx_model;
|
return onnx_model;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; }
|
FuncGraphPtr AnfImporterFromMindir::GetResult() { return this->func_graph_; }
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
|
@ -29,18 +29,17 @@
|
||||||
#include "abstract/abstract_value.h"
|
#include "abstract/abstract_value.h"
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
class AnfImporterFromProtobuf : public AnfImporter {
|
class AnfImporterFromMindir : public AnfImporter {
|
||||||
public:
|
public:
|
||||||
AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph)
|
AnfImporterFromMindir() = default;
|
||||||
: onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {}
|
|
||||||
|
|
||||||
~AnfImporterFromProtobuf() override = default;
|
~AnfImporterFromMindir() override { delete onnx_model_; }
|
||||||
|
|
||||||
static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path);
|
static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path);
|
||||||
|
|
||||||
FuncGraphPtr GetResult() override;
|
FuncGraphPtr GetResult() override;
|
||||||
|
|
||||||
int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override;
|
int Import(const converter::Flags *flag) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int ConverterConstTensor() override { return RET_ERROR; };
|
int ConverterConstTensor() override { return RET_ERROR; };
|
|
@ -57,6 +57,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
../optimizer/graph/identity_remove_pass.cc
|
../optimizer/graph/identity_remove_pass.cc
|
||||||
../optimizer/graph/infershape_pass.cc
|
../optimizer/graph/infershape_pass.cc
|
||||||
../optimizer/graph/slice_prepose_pass.cc
|
../optimizer/graph/slice_prepose_pass.cc
|
||||||
|
../optimizer/graph/mindir_adjust_pass.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
add_subdirectory(../anf_importer anf_importer)
|
add_subdirectory(../anf_importer anf_importer)
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
|
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
|
||||||
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
|
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
|
||||||
#include "tools/optimizer/fusion/conv_conv_fusion.h"
|
#include "tools/optimizer/fusion/conv_conv_fusion.h"
|
||||||
|
#include "tools/optimizer/graph/mindir_adjust_pass.h"
|
||||||
#include "tools/optimizer/graph/identity_remove_pass.h"
|
#include "tools/optimizer/graph/identity_remove_pass.h"
|
||||||
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
|
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
|
||||||
#include "tools/optimizer/graph/weight_format_transform_pass.h"
|
#include "tools/optimizer/graph/weight_format_transform_pass.h"
|
||||||
|
@ -61,6 +62,18 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
||||||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
||||||
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
|
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
|
||||||
|
|
||||||
|
// mindir pre adjustment
|
||||||
|
if (config->fmk == converter::FmkType_MS) {
|
||||||
|
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
|
||||||
|
mindir_adjust_pass->SetFmkType(config->fmk);
|
||||||
|
mindir_adjust_pass->SetQuantType(config->quantType);
|
||||||
|
if (!mindir_adjust_pass->Run(old_graph)) {
|
||||||
|
MS_LOG(ERROR) << "mindir adjust failed.";
|
||||||
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// for now - trainning is not supporting fuse operations
|
// for now - trainning is not supporting fuse operations
|
||||||
if (!config->trainModel) {
|
if (!config->trainModel) {
|
||||||
// remove quantdtype when awaretraining
|
// remove quantdtype when awaretraining
|
||||||
|
|
|
@ -30,7 +30,7 @@
|
||||||
#include "parser/onnx/onnx_converter.h"
|
#include "parser/onnx/onnx_converter.h"
|
||||||
#include "parser/tf/tf_converter.h"
|
#include "parser/tf/tf_converter.h"
|
||||||
#include "tools/anf_exporter/anf_exporter.h"
|
#include "tools/anf_exporter/anf_exporter.h"
|
||||||
#include "tools/anf_importer/import_from_protobuf.h"
|
#include "tools/anf_importer/import_from_mindir.h"
|
||||||
#include "proto/onnx.pb.h"
|
#include "proto/onnx.pb.h"
|
||||||
#include "tools/converter/quantizer/post_training_quantizer.h"
|
#include "tools/converter/quantizer/post_training_quantizer.h"
|
||||||
#include "tools/converter/quantizer/quant_cast.h"
|
#include "tools/converter/quantizer/quant_cast.h"
|
||||||
|
@ -54,9 +54,7 @@ Converter::~Converter() {
|
||||||
|
|
||||||
class MindsporeImporter : public Converter {
|
class MindsporeImporter : public Converter {
|
||||||
public:
|
public:
|
||||||
MindsporeImporter(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) {
|
MindsporeImporter() { modelImporter = new AnfImporterFromMindir(); }
|
||||||
modelImporter = new AnfImporterFromProtobuf(onnx_model, std::move(func_graph));
|
|
||||||
}
|
|
||||||
|
|
||||||
~MindsporeImporter() override = default;
|
~MindsporeImporter() override = default;
|
||||||
};
|
};
|
||||||
|
@ -66,7 +64,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
||||||
FuncGraphPtr graph = nullptr;
|
FuncGraphPtr graph = nullptr;
|
||||||
if (flag->fmk == converter::FmkType_MS) {
|
if (flag->fmk == converter::FmkType_MS) {
|
||||||
MS_ASSERT(nullptr != modelImporter);
|
MS_ASSERT(nullptr != modelImporter);
|
||||||
int status = modelImporter->Import(flag->quantType);
|
int status = modelImporter->Import(flag);
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
graph = modelImporter->GetResult();
|
graph = modelImporter->GetResult();
|
||||||
} else {
|
} else {
|
||||||
|
@ -127,15 +125,8 @@ int RunConverter(int argc, const char **argv) {
|
||||||
MetaGraphT *fb_graph = nullptr;
|
MetaGraphT *fb_graph = nullptr;
|
||||||
switch (flags->fmk) {
|
switch (flags->fmk) {
|
||||||
case FmkType::FmkType_MS: {
|
case FmkType::FmkType_MS: {
|
||||||
auto graph = std::make_shared<FuncGraph>();
|
MindsporeImporter mindsporeImporter;
|
||||||
auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile);
|
|
||||||
if (onnx_graph == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Read MINDIR from binary return nullptr";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
MindsporeImporter mindsporeImporter(onnx_graph, graph);
|
|
||||||
fb_graph = mindsporeImporter.Convert(flags.get());
|
fb_graph = mindsporeImporter.Convert(flags.get());
|
||||||
delete onnx_graph;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case FmkType::FmkType_CAFFE: {
|
case FmkType::FmkType_CAFFE: {
|
||||||
|
|
|
@ -26,22 +26,6 @@ namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr auto kAnfPrimitiveIndex = 0;
|
constexpr auto kAnfPrimitiveIndex = 0;
|
||||||
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
|
|
||||||
if (node == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!node->isa<CNode>()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
if (cnode == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsRealKernel(const AnfNodePtr &node) {
|
bool IsRealKernel(const AnfNodePtr &node) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
@ -136,6 +120,22 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
|
||||||
|
if (node == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!node->isa<CNode>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
|
||||||
|
}
|
||||||
|
|
||||||
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
||||||
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
||||||
auto a_node = utils::cast<AnfNodePtr>(a);
|
auto a_node = utils::cast<AnfNodePtr>(a);
|
||||||
|
|
|
@ -34,6 +34,8 @@ using mindspore::lite::RET_OK;
|
||||||
using mindspore::lite::STATUS;
|
using mindspore::lite::STATUS;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
|
||||||
|
|
||||||
bool IsRealCNodeKernel(const AnfNodePtr &node);
|
bool IsRealCNodeKernel(const AnfNodePtr &node);
|
||||||
|
|
||||||
bool IsGraphKernel(const AnfNodePtr &node);
|
bool IsGraphKernel(const AnfNodePtr &node);
|
||||||
|
|
|
@ -0,0 +1,147 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/optimizer/graph/mindir_adjust_pass.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "src/ops/primitive_c.h"
|
||||||
|
#include "tools/converter/quantizer/quant_cast.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "src/tensor.h"
|
||||||
|
|
||||||
|
using mindspore::lite::PrimitiveC;
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
int MindirAdjustPass::ParameterNodeConvert(AnfNodePtr anf_node) {
|
||||||
|
if (!utils::isa<ParameterPtr>(anf_node)) {
|
||||||
|
MS_LOG(INFO) << "only parameter node need to convert tensor.";
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
auto param_node = anf_node->cast<ParameterPtr>();
|
||||||
|
if (!param_node->has_default()) {
|
||||||
|
MS_LOG(INFO) << "this is graph input, don't need to convert.";
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
if (utils::isa<ParamValueLitePtr>(param_node->default_param())) {
|
||||||
|
MS_LOG(INFO) << "the tensor has been a paramvalueLite.";
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||||
|
if (param_value == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "fail to new a ParamValueLite.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
param_node->set_name(param_node->debug_info()->name());
|
||||||
|
auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
|
||||||
|
if (tensor_info == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the node is not a tensor::TensorPtr.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
param_value->set_tensor_size(tensor_info->Size());
|
||||||
|
param_value->set_tensor_type(tensor_info->data_type());
|
||||||
|
auto tensor_shape = tensor_info->shape();
|
||||||
|
std::vector<int> shape;
|
||||||
|
std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape),
|
||||||
|
[](int64_t value) { return static_cast<int>(value); });
|
||||||
|
param_value->set_tensor_shape(shape);
|
||||||
|
auto *tensor = new (std::nothrow) lite::Tensor(tensor_info->data_type(), shape);
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new a lite::tensor failed, get a nullptr.";
|
||||||
|
return lite::RET_MEMORY_FAILED;
|
||||||
|
}
|
||||||
|
auto *tensor_data_buf = tensor->MutableData();
|
||||||
|
if (tensor_data_buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc tensor data failed.";
|
||||||
|
delete tensor;
|
||||||
|
return lite::RET_MEMORY_FAILED;
|
||||||
|
}
|
||||||
|
if (memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s error.";
|
||||||
|
delete tensor;
|
||||||
|
return lite::RET_MEMORY_FAILED;
|
||||||
|
}
|
||||||
|
tensor->set_data(nullptr);
|
||||||
|
param_value->set_tensor_addr(tensor_data_buf);
|
||||||
|
param_node->set_default_param(param_value);
|
||||||
|
delete tensor;
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) {
|
||||||
|
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||||
|
MS_LOG(INFO) << "only cnode need to convert primitive.";
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
auto cnode = anf_node->cast<CNodePtr>();
|
||||||
|
if (cnode->inputs().empty() || cnode->input(0) == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the cnode is invalid.";
|
||||||
|
return lite::RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||||
|
if (value_node == nullptr || value_node->value() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value node is invalid.";
|
||||||
|
return lite::RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
if (utils::isa<PrimitiveCPtr>(value_node->value())) {
|
||||||
|
MS_LOG(INFO) << "the value has been primitiveC.";
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
auto primitive = value_node->value()->cast<PrimitivePtr>();
|
||||||
|
if (primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the value is not primitive.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto inputs = cnode->inputs();
|
||||||
|
inputs.erase(inputs.begin());
|
||||||
|
if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
|
||||||
|
auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_);
|
||||||
|
if (primitive_c == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope();
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
value_node->set_value(primitive_c);
|
||||||
|
} else {
|
||||||
|
auto primitiveT = std::make_unique<schema::PrimitiveT>();
|
||||||
|
primitiveT->value.type = (CheckPrimitiveType(anf_node, prim::kPrimReturn) ? schema::PrimitiveType_Return
|
||||||
|
: schema::PrimitiveType_MakeTuple);
|
||||||
|
value_node->set_value(std::make_shared<PrimitiveC>(primitiveT.release()));
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MindirAdjustPass::Run(const FuncGraphPtr &graph) {
|
||||||
|
if (this->fmk_type_ != lite::converter::FmkType_MS) {
|
||||||
|
MS_LOG(INFO) << "The framework type of model should be mindir.";
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
MS_ASSERT(graph != nullptr);
|
||||||
|
auto node_list = TopoSort(graph->get_return());
|
||||||
|
int status = lite::RET_OK;
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
if (utils::isa<ParameterPtr>(node)) {
|
||||||
|
status = ParameterNodeConvert(node);
|
||||||
|
} else if (utils::isa<CNodePtr>(node)) {
|
||||||
|
status = PrimitiveConvert(node);
|
||||||
|
}
|
||||||
|
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "backend/optimizer/common/pass.h"
|
||||||
|
#include "tools/converter/converter_flags.h"
|
||||||
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
#include "src/param_value_lite.h"
|
||||||
|
|
||||||
|
using mindspore::lite::converter::FmkType;
|
||||||
|
using mindspore::schema::QuantType;
|
||||||
|
namespace mindspore::opt {
|
||||||
|
class MindirAdjustPass : public Pass {
|
||||||
|
public:
|
||||||
|
MindirAdjustPass() : Pass("mindir_adjust_pass") {}
|
||||||
|
~MindirAdjustPass() override = default;
|
||||||
|
void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
|
||||||
|
void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
|
||||||
|
int ParameterNodeConvert(AnfNodePtr anf_node);
|
||||||
|
int PrimitiveConvert(AnfNodePtr anf_node);
|
||||||
|
bool Run(const FuncGraphPtr &graph) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
|
||||||
|
FmkType fmk_type_ = FmkType::FmkType_MS;
|
||||||
|
};
|
||||||
|
} // namespace mindspore::opt
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
|
Loading…
Reference in New Issue