From dbf296dba62ca103d9d46f6186350895e3de244e Mon Sep 17 00:00:00 2001
From: xuanyue
Date: Thu, 26 Nov 2020 19:15:55 +0800
Subject: [PATCH] mindir reconstruct

---
 mindspore/lite/src/ops/primitive_c.h          |   1 +
 mindspore/lite/test/CMakeLists.txt            |   1 +
 .../lite/tools/anf_importer/anf_importer.cc   |   4 +-
 .../lite/tools/anf_importer/anf_importer.h    |   3 +-
 ...from_protobuf.cc => import_from_mindir.cc} |  86 +++++-----
 ...t_from_protobuf.h => import_from_mindir.h} |   9 +-
 mindspore/lite/tools/converter/CMakeLists.txt |   1 +
 .../lite/tools/converter/anf_transform.cc     |  13 ++
 mindspore/lite/tools/converter/converter.cc   |  17 +-
 .../lite/tools/optimizer/common/gllo_utils.cc |  32 ++--
 .../lite/tools/optimizer/common/gllo_utils.h  |   2 +
 .../optimizer/graph/mindir_adjust_pass.cc     | 147 ++++++++++++++++++
 .../optimizer/graph/mindir_adjust_pass.h      |  44 ++++++
 13 files changed, 282 insertions(+), 78 deletions(-)
 rename mindspore/lite/tools/anf_importer/{import_from_protobuf.cc => import_from_mindir.cc} (91%)
 rename mindspore/lite/tools/anf_importer/{import_from_protobuf.h => import_from_mindir.h} (91%)
 create mode 100644 mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
 create mode 100644 mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h

diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h
index 5b5e49c49d5..77678578408 100644
--- a/mindspore/lite/src/ops/primitive_c.h
+++ b/mindspore/lite/src/ops/primitive_c.h
@@ -228,6 +228,7 @@ class PrimitiveC {
   bool infer_flag_ = true;
   schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
 };
+using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
 typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);
 #endif
 typedef OpParameter *(*ParameterCreator)(const PrimitiveC *primitive);
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 82b989a1f80..4472c49d0af 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -203,6 +203,7 @@ if(ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
+        ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
         )
 endif()
 ### train
diff --git a/mindspore/lite/tools/anf_importer/anf_importer.cc b/mindspore/lite/tools/anf_importer/anf_importer.cc
index 789a007b557..88a8f7c0579 100644
--- a/mindspore/lite/tools/anf_importer/anf_importer.cc
+++ b/mindspore/lite/tools/anf_importer/anf_importer.cc
@@ -14,15 +14,15 @@
  * limitations under the License.
  */
-#include
 #include "tools/anf_importer/anf_importer.h"
+#include
 #include "schema/model_generated.h"
 #include "ir/dtype.h"
 #include "include/errorcode.h"
 #include "schema/inner/model_generated.h"
 
 namespace mindspore {
 namespace lite {
-int AnfImporter::Import(const schema::QuantType &quantType) {
+int AnfImporter::Import(const converter::Flags *flag) {
   auto ret = ConverterConstTensor();
   if (RET_OK != ret) {
     MS_LOG(ERROR) << "ConverterConstTensor failed " << ret;
diff --git a/mindspore/lite/tools/anf_importer/anf_importer.h b/mindspore/lite/tools/anf_importer/anf_importer.h
index de15ead15fc..5d55b665f82 100644
--- a/mindspore/lite/tools/anf_importer/anf_importer.h
+++ b/mindspore/lite/tools/anf_importer/anf_importer.h
@@ -22,6 +22,7 @@
 #include "ir/anf.h"
 #include "base/base.h"
 #include "schema/inner/model_generated.h"
+#include "tools/converter/converter_flags.h"
 
 namespace mindspore::lite {
 class AnfImporter {
@@ -30,7 +31,7 @@ class AnfImporter {
 
   virtual ~AnfImporter() = default;
 
-  virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE);
+  virtual int Import(const converter::Flags *flag = nullptr);
 
   virtual FuncGraphPtr GetResult() = 0;
 
diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_mindir.cc
similarity index 91%
rename from mindspore/lite/tools/anf_importer/import_from_protobuf.cc
rename to mindspore/lite/tools/anf_importer/import_from_mindir.cc
index c0f3ba4b9bb..3e8e17d93eb 100644
--- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc
+++ b/mindspore/lite/tools/anf_importer/import_from_mindir.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "tools/anf_importer/import_from_protobuf.h"
+#include "tools/anf_importer/import_from_mindir.h"
 #include
 #include
 #include
@@ -36,6 +36,7 @@
 #include "src/common/log_adapter.h"
 #include "tools/common/protobuf_utils.h"
 #include "tools/common/graph_util.h"
+#include "load_mindir/load_model.h"
 
 using string = std::string;
 using int32 = int32_t;
@@ -199,8 +200,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
 PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
 PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
 
-int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
-                                                        const onnx::ValueInfoProto &value_proto) {
+int AnfImporterFromMindir::BuildParameterForFuncGraph(const ParameterPtr &node,
+                                                      const onnx::ValueInfoProto &value_proto) {
   if (node == nullptr) {
     return RET_NULL_PTR;
   }
@@ -274,8 +275,8 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
   return RET_OK;
 }
 
-int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
-                                                      const onnx::GraphProto &importProto) {
+int AnfImporterFromMindir::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
+                                                    const onnx::GraphProto &importProto) {
   if (outputFuncGraph == nullptr) {
     return RET_NULL_PTR;
   }
@@ -303,8 +304,8 @@ int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &output
   return status;
 }
 
-bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
-                                                        const onnx::TensorProto &attr_tensor) {
+bool AnfImporterFromMindir::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
+                                                      const onnx::TensorProto &attr_tensor) {
   if (prim == nullptr) {
     return false;
   }
@@ -317,7 +318,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim
   return true;
 }
 
-ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
+ValuePtr AnfImporterFromMindir::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
   const int attr_tensor_type = attr_tensor.data_type();
   switch (attr_tensor_type) {
     case onnx::TensorProto_DataType_STRING: {
@@ -347,8 +348,8 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor
   }
 }
 
-bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
-                                                          const onnx::TensorProto &attr_tensor) {
+bool AnfImporterFromMindir::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
+                                                        const onnx::TensorProto &attr_tensor) {
   if (prim == nullptr) {
     return false;
   }
@@ -405,7 +406,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
   return ret == EOK;
 }
 
-bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
+bool AnfImporterFromMindir::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
   if (prim == nullptr) {
     return false;
   }
@@ -460,8 +461,8 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
   return true;
 }
 
-bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
-                                                          const onnx::TensorProto &attr_tensor) {
+bool AnfImporterFromMindir::ObtainValueNodeInTensorForm(const std::string &value_node_name,
+                                                        const onnx::TensorProto &attr_tensor) {
   const int attr_tensor_type = attr_tensor.data_type();
   std::vector shape;
   for (int i = 0; i < attr_tensor.dims_size(); ++i) {
@@ -501,8 +502,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
   return true;
 }
 
-bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
-                                                        const onnx::TensorProto &attr_tensor) {
+bool AnfImporterFromMindir::ObtainValueNodeInTypeForm(const std::string &value_node_name,
+                                                      const onnx::TensorProto &attr_tensor) {
   const int attr_tensor_type = attr_tensor.data_type();
   if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
     MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
@@ -515,8 +516,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value
   return true;
 }
 
-bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name,
-                                                       const onnx::AttributeProto &attr_proto) {
+bool AnfImporterFromMindir::GetAttrValueForValueNode(const std::string &value_node_name,
+                                                     const onnx::AttributeProto &attr_proto) {
   if (!attr_proto.has_ref_attr_name()) {
     MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
     return false;
   }
@@ -572,7 +573,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_
   return true;
 }
 
-bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
+bool AnfImporterFromMindir::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
   const std::string &value_node_name = node_proto.output(0);
   const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
   if (!attr_proto.has_ref_attr_name()) {
@@ -582,7 +583,7 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
   return GetAttrValueForValueNode(value_node_name, attr_proto);
 }
 
-std::unordered_map AnfImporterFromProtobuf::GetAbstractForCNode(
+std::unordered_map AnfImporterFromMindir::GetAbstractForCNode(
   const onnx::AttributeProto &attr_proto) {
   std::unordered_map kv;
   for (int i = 0; i < attr_proto.tensors_size(); i++) {
@@ -601,9 +602,9 @@ std::unordered_map AnfImporterFromProt
   return kv;
 }
 
-CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
-                                                         const onnx::NodeProto &node_proto,
-                                                         const schema::QuantType &quantType) {
+CNodePtr AnfImporterFromMindir::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
+                                                       const onnx::NodeProto &node_proto,
+                                                       const schema::QuantType &quantType) {
   static bool interrupt = false;
   if (outputFuncGraph == nullptr) {
     MS_LOG(ERROR) << "output funcgraph is nullptr";
@@ -685,8 +686,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
   return cnode_ptr;
 }
 
-bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
-                                                      const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
+bool AnfImporterFromMindir::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
+                                                    const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
   if (outputFuncGraph == nullptr || cnode_ptr == nullptr) {
     MS_LOG(ERROR) << "output funcgraph or cnode is nullptr";
     return false;
   }
@@ -765,9 +766,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
   return true;
 }
 
-int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
-                                                 const onnx::GraphProto &importProto,
-                                                 const schema::QuantType &quantType) {
+int AnfImporterFromMindir::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
+                                               const schema::QuantType &quantType) {
   if (outputFuncGraph == nullptr) {
     MS_LOG(ERROR) << "funcgraph is nullptr";
     return RET_NULL_PTR;
   }
@@ -809,8 +809,8 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
   return status;
 }
 
-int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
-                                            const schema::QuantType &quantType) {
+int AnfImporterFromMindir::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
+                                          const schema::QuantType &quantType) {
   if (outputFuncGraph == nullptr) {
     MS_LOG(ERROR) << "fundgraph is nullptr";
     return RET_NULL_PTR;
   }
@@ -833,7 +833,7 @@ int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph,
   return ImportNodesForGraph(outputFuncGraph, importProto, quantType);
 }
 
-int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
+int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
   if (!model_proto.has_producer_name()) {
     MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
     return RET_GRAPH_FILE_ERR;
   }
@@ -854,7 +854,17 @@ int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mod
   return RET_OK;
 }
 
-int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
+int AnfImporterFromMindir::Import(const converter::Flags *flag) {
+  onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
+  if (onnx_model_ == nullptr) {
+    MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";
+    func_graph_ = LoadMindIR(flag->modelFile);
+    if (func_graph_ == nullptr) {
+      MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file.";
+      return RET_GRAPH_FILE_ERR;
+    }
+    return RET_OK;
+  }
   FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
   if (dstGraph == nullptr) {
     MS_LOG(ERROR) << "funcgraph is nullptr";
"funcgraph is nullptr"; @@ -865,10 +875,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; return status; } - if (onnx_model_ == nullptr) { - MS_LOG(ERROR) << "onnx_model_ is nullptr"; - return RET_NULL_PTR; - } + auto quantType = flag->quantType; const onnx::GraphProto &graphBuild = onnx_model_->graph(); status = BuildFuncGraph(dstGraph, graphBuild, quantType); if (status != RET_OK) { @@ -881,25 +888,22 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { return RET_OK; } -onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { +onnx::ModelProto *AnfImporterFromMindir::ReadOnnxFromBinary(const std::string &model_path) { auto onnx_model = new (std::nothrow) onnx::ModelProto; if (onnx_model == nullptr) { MS_LOG(ERROR) << "New onnx ModelProto failed!"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); return nullptr; } if (RET_OK != ValidateFileStr(model_path, ".mindir")) { MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID); return nullptr; } if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { - MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); + MS_LOG(ERROR) << "Read onnx model file failed, which is not a matched onnx model"; return nullptr; } return onnx_model; } -FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; } +FuncGraphPtr AnfImporterFromMindir::GetResult() { return this->func_graph_; } } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.h b/mindspore/lite/tools/anf_importer/import_from_mindir.h similarity index 91% rename from mindspore/lite/tools/anf_importer/import_from_protobuf.h rename to mindspore/lite/tools/anf_importer/import_from_mindir.h index b12503e4a14..f743f473abf 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.h +++ b/mindspore/lite/tools/anf_importer/import_from_mindir.h @@ -29,18 +29,17 @@ #include "abstract/abstract_value.h" namespace mindspore::lite { -class AnfImporterFromProtobuf : public AnfImporter { +class AnfImporterFromMindir : public AnfImporter { public: - AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) - : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} + AnfImporterFromMindir() = default; - ~AnfImporterFromProtobuf() override = default; + ~AnfImporterFromMindir() override { delete onnx_model_; } static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path); FuncGraphPtr GetResult() override; - int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override; + int Import(const converter::Flags *flag) override; private: int ConverterConstTensor() override { return RET_ERROR; }; diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 6953b135b67..6ba901da03a 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -57,6 +57,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/identity_remove_pass.cc ../optimizer/graph/infershape_pass.cc ../optimizer/graph/slice_prepose_pass.cc + ../optimizer/graph/mindir_adjust_pass.cc ) 
 
 add_subdirectory(../anf_importer anf_importer)
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 876389d20b6..e23592d8a15 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -29,6 +29,7 @@
 #include "tools/optimizer/fusion/batchmatmul_fusion.h"
 #include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
 #include "tools/optimizer/fusion/conv_conv_fusion.h"
+#include "tools/optimizer/graph/mindir_adjust_pass.h"
 #include "tools/optimizer/graph/identity_remove_pass.h"
 #include "tools/optimizer/graph/weight_format_hardcode_pass.h"
 #include "tools/optimizer/graph/weight_format_transform_pass.h"
@@ -61,6 +62,18 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
   auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
   auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
 
+  // mindir pre adjustment
+  if (config->fmk == converter::FmkType_MS) {
+    auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
+    mindir_adjust_pass->SetFmkType(config->fmk);
+    mindir_adjust_pass->SetQuantType(config->quantType);
+    if (!mindir_adjust_pass->Run(old_graph)) {
+      MS_LOG(ERROR) << "mindir adjust failed.";
+      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
+      return nullptr;
+    }
+  }
+
   // for now - trainning is not supporting fuse operations
   if (!config->trainModel) {
     // remove quantdtype when awaretraining
diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
index 12a8a9269e4..0be34cc56ec 100644
--- a/mindspore/lite/tools/converter/converter.cc
+++ b/mindspore/lite/tools/converter/converter.cc
@@ -30,7 +30,7 @@
 #include "parser/onnx/onnx_converter.h"
 #include "parser/tf/tf_converter.h"
 #include "tools/anf_exporter/anf_exporter.h"
-#include "tools/anf_importer/import_from_protobuf.h"
+#include "tools/anf_importer/import_from_mindir.h"
 #include "proto/onnx.pb.h"
 #include "tools/converter/quantizer/post_training_quantizer.h"
 #include "tools/converter/quantizer/quant_cast.h"
@@ -54,9 +54,7 @@ Converter::~Converter() {
 
 class MindsporeImporter : public Converter {
  public:
-  MindsporeImporter(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) {
-    modelImporter = new AnfImporterFromProtobuf(onnx_model, std::move(func_graph));
-  }
+  MindsporeImporter() { modelImporter = new AnfImporterFromMindir(); }
 
   ~MindsporeImporter() override = default;
 };
@@ -66,7 +64,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
   FuncGraphPtr graph = nullptr;
   if (flag->fmk == converter::FmkType_MS) {
     MS_ASSERT(nullptr != modelImporter);
-    int status = modelImporter->Import(flag->quantType);
+    int status = modelImporter->Import(flag);
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
     graph = modelImporter->GetResult();
   } else {
@@ -127,15 +125,8 @@ int RunConverter(int argc, const char **argv) {
   MetaGraphT *fb_graph = nullptr;
   switch (flags->fmk) {
     case FmkType::FmkType_MS: {
-      auto graph = std::make_shared<FuncGraph>();
-      auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile);
-      if (onnx_graph == nullptr) {
-        MS_LOG(ERROR) << "Read MINDIR from binary return nullptr";
-        break;
-      }
-      MindsporeImporter mindsporeImporter(onnx_graph, graph);
+      MindsporeImporter mindsporeImporter;
       fb_graph = mindsporeImporter.Convert(flags.get());
-      delete onnx_graph;
       break;
    }
    case FmkType::FmkType_CAFFE: {
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index 0ef05434943..397220d5845 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -26,22 +26,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
 constexpr auto kAnfPrimitiveIndex = 0;
-bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
-  if (node == nullptr) {
-    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-    return false;
-  }
-  if (!node->isa<CNode>()) {
-    return false;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  if (cnode == nullptr) {
-    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-    return false;
-  }
-  return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
-}
-
 bool IsRealKernel(const AnfNodePtr &node) {
   if (node == nullptr) {
     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@@ -136,6 +120,22 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
   }
 }  // namespace
 
+bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
+  if (node == nullptr) {
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return false;
+  }
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode == nullptr) {
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return false;
+  }
+  return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
+}
+
 bool AnfEqual(const BaseRef &a, const BaseRef &b) {
   if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
     auto a_node = utils::cast<AnfNodePtr>(a);
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h
index 070a0cea316..0dcf89b9ff0 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.h
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h
@@ -34,6 +34,8 @@ using mindspore::lite::RET_OK;
 using mindspore::lite::STATUS;
 namespace mindspore {
 namespace opt {
+bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
+
 bool IsRealCNodeKernel(const AnfNodePtr &node);
 
 bool IsGraphKernel(const AnfNodePtr &node);
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
new file mode 100644
index 00000000000..d127cdc4b75
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
@@ -0,0 +1,147 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/graph/mindir_adjust_pass.h"
+#include
+#include
+#include
+
+#include "src/ops/primitive_c.h"
+#include "tools/converter/quantizer/quant_cast.h"
+#include "src/common/log_adapter.h"
+#include "src/tensor.h"
+
+using mindspore::lite::PrimitiveC;
+namespace mindspore {
+namespace opt {
+int MindirAdjustPass::ParameterNodeConvert(AnfNodePtr anf_node) {
+  if (!utils::isa<ParameterPtr>(anf_node)) {
+    MS_LOG(INFO) << "only parameter node need to convert tensor.";
+    return lite::RET_NO_CHANGE;
+  }
+  auto param_node = anf_node->cast<ParameterPtr>();
+  if (!param_node->has_default()) {
+    MS_LOG(INFO) << "this is graph input, don't need to convert.";
+    return lite::RET_NO_CHANGE;
+  }
+  if (utils::isa<ParamValueLitePtr>(param_node->default_param())) {
+    MS_LOG(INFO) << "the tensor has been a paramvalueLite.";
+    return lite::RET_NO_CHANGE;
+  }
+  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
+  if (param_value == nullptr) {
+    MS_LOG(ERROR) << "fail to new a ParamValueLite.";
+    return lite::RET_ERROR;
+  }
+  param_node->set_name(param_node->debug_info()->name());
+  auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << "the node is not a tensor::TensorPtr.";
+    return lite::RET_ERROR;
+  }
+  param_value->set_tensor_size(tensor_info->Size());
+  param_value->set_tensor_type(tensor_info->data_type());
+  auto tensor_shape = tensor_info->shape();
+  std::vector<int> shape;
+  std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape),
+                 [](int64_t value) { return static_cast<int>(value); });
+  param_value->set_tensor_shape(shape);
+  auto *tensor = new (std::nothrow) lite::Tensor(tensor_info->data_type(), shape);
+  if (tensor == nullptr) {
+    MS_LOG(ERROR) << "new a lite::tensor failed, get a nullptr.";
+    return lite::RET_MEMORY_FAILED;
+  }
+  auto *tensor_data_buf = tensor->MutableData();
+  if (tensor_data_buf == nullptr) {
+    MS_LOG(ERROR) << "malloc tensor data failed.";
+    delete tensor;
+    return lite::RET_MEMORY_FAILED;
+  }
+  if (memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
+    MS_LOG(ERROR) << "memcpy_s error.";
+    delete tensor;
+    return lite::RET_MEMORY_FAILED;
+  }
+  tensor->set_data(nullptr);
+  param_value->set_tensor_addr(tensor_data_buf);
+  param_node->set_default_param(param_value);
+  delete tensor;
+  return lite::RET_OK;
+}
+
+int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) {
+  if (!utils::isa<CNodePtr>(anf_node)) {
+    MS_LOG(INFO) << "only cnode need to convert primitive.";
+    return lite::RET_NO_CHANGE;
+  }
+  auto cnode = anf_node->cast<CNodePtr>();
+  if (cnode->inputs().empty() || cnode->input(0) == nullptr) {
+    MS_LOG(ERROR) << "the cnode is invalid.";
+    return lite::RET_NULL_PTR;
+  }
+  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
+  if (value_node == nullptr || value_node->value() == nullptr) {
+    MS_LOG(ERROR) << "value node is invalid.";
+    return lite::RET_NULL_PTR;
+  }
+  if (utils::isa<PrimitiveCPtr>(value_node->value())) {
+    MS_LOG(INFO) << "the value has been primitiveC.";
+    return lite::RET_NO_CHANGE;
+  }
+  auto primitive = value_node->value()->cast<PrimitivePtr>();
+  if (primitive == nullptr) {
+    MS_LOG(ERROR) << "the value is not primitive.";
+    return lite::RET_ERROR;
+  }
+  auto inputs = cnode->inputs();
+  inputs.erase(inputs.begin());
+  if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
+    auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_);
+    if (primitive_c == nullptr) {
+      MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
+    value_node->set_value(primitive_c);
+  } else {
+    auto primitiveT = std::make_unique<schema::PrimitiveT>();
+    primitiveT->value.type = (CheckPrimitiveType(anf_node, prim::kPrimReturn) ? schema::PrimitiveType_Return
+                                                                              : schema::PrimitiveType_MakeTuple);
+    value_node->set_value(std::make_shared<PrimitiveC>(primitiveT.release()));
+  }
+  return lite::RET_OK;
+}
+
+bool MindirAdjustPass::Run(const FuncGraphPtr &graph) {
+  if (this->fmk_type_ != lite::converter::FmkType_MS) {
+    MS_LOG(INFO) << "The framework type of model should be mindir.";
+    return lite::RET_OK;
+  }
+  MS_ASSERT(graph != nullptr);
+  auto node_list = TopoSort(graph->get_return());
+  int status = lite::RET_OK;
+  for (auto &node : node_list) {
+    if (utils::isa<ParameterPtr>(node)) {
+      status = ParameterNodeConvert(node);
+    } else if (utils::isa<CNodePtr>(node)) {
+      status = PrimitiveConvert(node);
+    }
+    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
+      return false;
+    }
+  }
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h
new file mode 100644
index 00000000000..77ac864dab7
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
+
+#include
+#include "backend/optimizer/common/pass.h"
+#include "tools/converter/converter_flags.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "src/param_value_lite.h"
+
+using mindspore::lite::converter::FmkType;
+using mindspore::schema::QuantType;
+namespace mindspore::opt {
+class MindirAdjustPass : public Pass {
+ public:
+  MindirAdjustPass() : Pass("mindir_adjust_pass") {}
+  ~MindirAdjustPass() override = default;
+  void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
+  void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
+  int ParameterNodeConvert(AnfNodePtr anf_node);
+  int PrimitiveConvert(AnfNodePtr anf_node);
+  bool Run(const FuncGraphPtr &graph) override;
+
+ protected:
+  QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
+  FmkType fmk_type_ = FmkType::FmkType_MS;
+};
+}  // namespace mindspore::opt
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_