forked from mindspore-Ecosystem/mindspore
!9047 [lite]mindir reconstruct compatibility
From: @xu_anyue Reviewed-by: Signed-off-by:
This commit is contained in:
commit
61d717032c
|
@ -228,6 +228,7 @@ class PrimitiveC {
|
|||
bool infer_flag_ = true;
|
||||
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
|
||||
};
|
||||
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
|
||||
typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);
|
||||
#endif
|
||||
typedef OpParameter *(*ParameterCreator)(const PrimitiveC *primitive);
|
||||
|
|
|
@ -203,6 +203,7 @@ if(ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
|
||||
)
|
||||
endif()
|
||||
### train
|
||||
|
|
|
@ -14,15 +14,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <utility>
|
||||
#include "tools/anf_importer/anf_importer.h"
|
||||
#include <utility>
|
||||
#include "schema/model_generated.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int AnfImporter::Import(const schema::QuantType &quantType) {
|
||||
int AnfImporter::Import(const converter::Flags *flag) {
|
||||
auto ret = ConverterConstTensor();
|
||||
if (RET_OK != ret) {
|
||||
MS_LOG(ERROR) << "ConverterConstTensor failed " << ret;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "base/base.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class AnfImporter {
|
||||
|
@ -30,7 +31,7 @@ class AnfImporter {
|
|||
|
||||
virtual ~AnfImporter() = default;
|
||||
|
||||
virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE);
|
||||
virtual int Import(const converter::Flags *flag = nullptr);
|
||||
|
||||
virtual FuncGraphPtr GetResult() = 0;
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/anf_importer/import_from_protobuf.h"
|
||||
#include "tools/anf_importer/import_from_mindir.h"
|
||||
#include <unistd.h>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
@ -36,6 +36,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/protobuf_utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
|
||||
using string = std::string;
|
||||
using int32 = int32_t;
|
||||
|
@ -199,8 +200,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
|
|||
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
||||
|
||||
int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||
const onnx::ValueInfoProto &value_proto) {
|
||||
int AnfImporterFromMindir::BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||
const onnx::ValueInfoProto &value_proto) {
|
||||
if (node == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
@ -274,8 +275,8 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto) {
|
||||
int AnfImporterFromMindir::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto) {
|
||||
if (outputFuncGraph == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
@ -303,8 +304,8 @@ int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &output
|
|||
return status;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromMindir::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
@ -317,7 +318,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim
|
|||
return true;
|
||||
}
|
||||
|
||||
ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
|
||||
ValuePtr AnfImporterFromMindir::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
switch (attr_tensor_type) {
|
||||
case onnx::TensorProto_DataType_STRING: {
|
||||
|
@ -347,8 +348,8 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor
|
|||
}
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromMindir::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
@ -405,7 +406,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
|
|||
return ret == EOK;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
||||
bool AnfImporterFromMindir::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
@ -460,8 +461,8 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromMindir::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
std::vector<int> shape;
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
|
@ -501,8 +502,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
bool AnfImporterFromMindir::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||
|
@ -515,8 +516,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name,
|
||||
const onnx::AttributeProto &attr_proto) {
|
||||
bool AnfImporterFromMindir::GetAttrValueForValueNode(const std::string &value_node_name,
|
||||
const onnx::AttributeProto &attr_proto) {
|
||||
if (!attr_proto.has_ref_attr_name()) {
|
||||
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
||||
return false;
|
||||
|
@ -572,7 +573,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
|
||||
bool AnfImporterFromMindir::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
|
||||
const std::string &value_node_name = node_proto.output(0);
|
||||
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
|
||||
if (!attr_proto.has_ref_attr_name()) {
|
||||
|
@ -582,7 +583,7 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
|
|||
return GetAttrValueForValueNode(value_node_name, attr_proto);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProtobuf::GetAbstractForCNode(
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromMindir::GetAbstractForCNode(
|
||||
const onnx::AttributeProto &attr_proto) {
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
|
||||
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
||||
|
@ -601,9 +602,9 @@ std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProt
|
|||
return kv;
|
||||
}
|
||||
|
||||
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::NodeProto &node_proto,
|
||||
const schema::QuantType &quantType) {
|
||||
CNodePtr AnfImporterFromMindir::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::NodeProto &node_proto,
|
||||
const schema::QuantType &quantType) {
|
||||
static bool interrupt = false;
|
||||
if (outputFuncGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "output funcgraph is nullptr";
|
||||
|
@ -685,8 +686,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
return cnode_ptr;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
|
||||
bool AnfImporterFromMindir::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
|
||||
if (outputFuncGraph == nullptr || cnode_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "output funcgraph or cnode is nullptr";
|
||||
return false;
|
||||
|
@ -765,9 +766,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|||
return true;
|
||||
}
|
||||
|
||||
int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
int AnfImporterFromMindir::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
if (outputFuncGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "funcgraph is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -809,8 +809,8 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
|
|||
return status;
|
||||
}
|
||||
|
||||
int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
int AnfImporterFromMindir::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
if (outputFuncGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "fundgraph is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -833,7 +833,7 @@ int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
|||
return ImportNodesForGraph(outputFuncGraph, importProto, quantType);
|
||||
}
|
||||
|
||||
int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
|
||||
int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
|
||||
if (!model_proto.has_producer_name()) {
|
||||
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
|
||||
return RET_GRAPH_FILE_ERR;
|
||||
|
@ -854,7 +854,17 @@ int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mod
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
||||
int AnfImporterFromMindir::Import(const converter::Flags *flag) {
|
||||
onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
|
||||
if (onnx_model_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";
|
||||
func_graph_ = LoadMindIR(flag->modelFile);
|
||||
if (func_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file.";
|
||||
return RET_GRAPH_FILE_ERR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
|
||||
if (dstGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "funcgraph is nullptr";
|
||||
|
@ -865,10 +875,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
|||
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||
return status;
|
||||
}
|
||||
if (onnx_model_ == nullptr) {
|
||||
MS_LOG(ERROR) << "onnx_model_ is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto quantType = flag->quantType;
|
||||
const onnx::GraphProto &graphBuild = onnx_model_->graph();
|
||||
status = BuildFuncGraph(dstGraph, graphBuild, quantType);
|
||||
if (status != RET_OK) {
|
||||
|
@ -881,25 +888,22 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
|
||||
onnx::ModelProto *AnfImporterFromMindir::ReadOnnxFromBinary(const std::string &model_path) {
|
||||
auto onnx_model = new (std::nothrow) onnx::ModelProto;
|
||||
if (onnx_model == nullptr) {
|
||||
MS_LOG(ERROR) << "New onnx ModelProto failed!";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
if (RET_OK != ValidateFileStr(model_path, ".mindir")) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID);
|
||||
return nullptr;
|
||||
}
|
||||
if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
MS_LOG(ERROR) << "Read onnx model file failed, which is not a matched onnx model";
|
||||
return nullptr;
|
||||
}
|
||||
return onnx_model;
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; }
|
||||
FuncGraphPtr AnfImporterFromMindir::GetResult() { return this->func_graph_; }
|
||||
} // namespace mindspore::lite
|
|
@ -29,18 +29,17 @@
|
|||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class AnfImporterFromProtobuf : public AnfImporter {
|
||||
class AnfImporterFromMindir : public AnfImporter {
|
||||
public:
|
||||
AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph)
|
||||
: onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {}
|
||||
AnfImporterFromMindir() = default;
|
||||
|
||||
~AnfImporterFromProtobuf() override = default;
|
||||
~AnfImporterFromMindir() override { delete onnx_model_; }
|
||||
|
||||
static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path);
|
||||
|
||||
FuncGraphPtr GetResult() override;
|
||||
|
||||
int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override;
|
||||
int Import(const converter::Flags *flag) override;
|
||||
|
||||
private:
|
||||
int ConverterConstTensor() override { return RET_ERROR; };
|
|
@ -57,6 +57,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/graph/identity_remove_pass.cc
|
||||
../optimizer/graph/infershape_pass.cc
|
||||
../optimizer/graph/slice_prepose_pass.cc
|
||||
../optimizer/graph/mindir_adjust_pass.cc
|
||||
)
|
||||
|
||||
add_subdirectory(../anf_importer anf_importer)
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
|
||||
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
|
||||
#include "tools/optimizer/fusion/conv_conv_fusion.h"
|
||||
#include "tools/optimizer/graph/mindir_adjust_pass.h"
|
||||
#include "tools/optimizer/graph/identity_remove_pass.h"
|
||||
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
|
||||
#include "tools/optimizer/graph/weight_format_transform_pass.h"
|
||||
|
@ -61,6 +62,18 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
||||
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
|
||||
|
||||
// mindir pre adjustment
|
||||
if (config->fmk == converter::FmkType_MS) {
|
||||
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
|
||||
mindir_adjust_pass->SetFmkType(config->fmk);
|
||||
mindir_adjust_pass->SetQuantType(config->quantType);
|
||||
if (!mindir_adjust_pass->Run(old_graph)) {
|
||||
MS_LOG(ERROR) << "mindir adjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// for now - trainning is not supporting fuse operations
|
||||
if (!config->trainModel) {
|
||||
// remove quantdtype when awaretraining
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
#include "parser/onnx/onnx_converter.h"
|
||||
#include "parser/tf/tf_converter.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "tools/anf_importer/import_from_protobuf.h"
|
||||
#include "tools/anf_importer/import_from_mindir.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "tools/converter/quantizer/post_training_quantizer.h"
|
||||
#include "tools/converter/quantizer/quant_cast.h"
|
||||
|
@ -54,9 +54,7 @@ Converter::~Converter() {
|
|||
|
||||
class MindsporeImporter : public Converter {
|
||||
public:
|
||||
MindsporeImporter(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) {
|
||||
modelImporter = new AnfImporterFromProtobuf(onnx_model, std::move(func_graph));
|
||||
}
|
||||
MindsporeImporter() { modelImporter = new AnfImporterFromMindir(); }
|
||||
|
||||
~MindsporeImporter() override = default;
|
||||
};
|
||||
|
@ -66,7 +64,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
|||
FuncGraphPtr graph = nullptr;
|
||||
if (flag->fmk == converter::FmkType_MS) {
|
||||
MS_ASSERT(nullptr != modelImporter);
|
||||
int status = modelImporter->Import(flag->quantType);
|
||||
int status = modelImporter->Import(flag);
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
graph = modelImporter->GetResult();
|
||||
} else {
|
||||
|
@ -127,15 +125,8 @@ int RunConverter(int argc, const char **argv) {
|
|||
MetaGraphT *fb_graph = nullptr;
|
||||
switch (flags->fmk) {
|
||||
case FmkType::FmkType_MS: {
|
||||
auto graph = std::make_shared<FuncGraph>();
|
||||
auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile);
|
||||
if (onnx_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Read MINDIR from binary return nullptr";
|
||||
break;
|
||||
}
|
||||
MindsporeImporter mindsporeImporter(onnx_graph, graph);
|
||||
MindsporeImporter mindsporeImporter;
|
||||
fb_graph = mindsporeImporter.Convert(flags.get());
|
||||
delete onnx_graph;
|
||||
break;
|
||||
}
|
||||
case FmkType::FmkType_CAFFE: {
|
||||
|
|
|
@ -26,22 +26,6 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kAnfPrimitiveIndex = 0;
|
||||
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
|
||||
if (node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
|
||||
}
|
||||
|
||||
bool IsRealKernel(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
|
@ -136,6 +120,22 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
|
||||
if (node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
|
||||
}
|
||||
|
||||
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
||||
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
||||
auto a_node = utils::cast<AnfNodePtr>(a);
|
||||
|
|
|
@ -34,6 +34,8 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::lite::STATUS;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
|
||||
|
||||
bool IsRealCNodeKernel(const AnfNodePtr &node);
|
||||
|
||||
bool IsGraphKernel(const AnfNodePtr &node);
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/optimizer/graph/mindir_adjust_pass.h"
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "tools/converter/quantizer/quant_cast.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/tensor.h"
|
||||
|
||||
using mindspore::lite::PrimitiveC;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
int MindirAdjustPass::ParameterNodeConvert(AnfNodePtr anf_node) {
|
||||
if (!utils::isa<ParameterPtr>(anf_node)) {
|
||||
MS_LOG(INFO) << "only parameter node need to convert tensor.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto param_node = anf_node->cast<ParameterPtr>();
|
||||
if (!param_node->has_default()) {
|
||||
MS_LOG(INFO) << "this is graph input, don't need to convert.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
if (utils::isa<ParamValueLitePtr>(param_node->default_param())) {
|
||||
MS_LOG(INFO) << "the tensor has been a paramvalueLite.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||
if (param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "fail to new a ParamValueLite.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
param_node->set_name(param_node->debug_info()->name());
|
||||
auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "the node is not a tensor::TensorPtr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
param_value->set_tensor_size(tensor_info->Size());
|
||||
param_value->set_tensor_type(tensor_info->data_type());
|
||||
auto tensor_shape = tensor_info->shape();
|
||||
std::vector<int> shape;
|
||||
std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape),
|
||||
[](int64_t value) { return static_cast<int>(value); });
|
||||
param_value->set_tensor_shape(shape);
|
||||
auto *tensor = new (std::nothrow) lite::Tensor(tensor_info->data_type(), shape);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new a lite::tensor failed, get a nullptr.";
|
||||
return lite::RET_MEMORY_FAILED;
|
||||
}
|
||||
auto *tensor_data_buf = tensor->MutableData();
|
||||
if (tensor_data_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tensor data failed.";
|
||||
delete tensor;
|
||||
return lite::RET_MEMORY_FAILED;
|
||||
}
|
||||
if (memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s error.";
|
||||
delete tensor;
|
||||
return lite::RET_MEMORY_FAILED;
|
||||
}
|
||||
tensor->set_data(nullptr);
|
||||
param_value->set_tensor_addr(tensor_data_buf);
|
||||
param_node->set_default_param(param_value);
|
||||
delete tensor;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) {
|
||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||
MS_LOG(INFO) << "only cnode need to convert primitive.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
if (cnode->inputs().empty() || cnode->input(0) == nullptr) {
|
||||
MS_LOG(ERROR) << "the cnode is invalid.";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr || value_node->value() == nullptr) {
|
||||
MS_LOG(ERROR) << "value node is invalid.";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
if (utils::isa<PrimitiveCPtr>(value_node->value())) {
|
||||
MS_LOG(INFO) << "the value has been primitiveC.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto primitive = value_node->value()->cast<PrimitivePtr>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "the value is not primitive.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
inputs.erase(inputs.begin());
|
||||
if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
|
||||
auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_);
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
value_node->set_value(primitive_c);
|
||||
} else {
|
||||
auto primitiveT = std::make_unique<schema::PrimitiveT>();
|
||||
primitiveT->value.type = (CheckPrimitiveType(anf_node, prim::kPrimReturn) ? schema::PrimitiveType_Return
|
||||
: schema::PrimitiveType_MakeTuple);
|
||||
value_node->set_value(std::make_shared<PrimitiveC>(primitiveT.release()));
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool MindirAdjustPass::Run(const FuncGraphPtr &graph) {
|
||||
if (this->fmk_type_ != lite::converter::FmkType_MS) {
|
||||
MS_LOG(INFO) << "The framework type of model should be mindir.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
int status = lite::RET_OK;
|
||||
for (auto &node : node_list) {
|
||||
if (utils::isa<ParameterPtr>(node)) {
|
||||
status = ParameterNodeConvert(node);
|
||||
} else if (utils::isa<CNodePtr>(node)) {
|
||||
status = PrimitiveConvert(node);
|
||||
}
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
|
||||
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "src/param_value_lite.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::schema::QuantType;
|
||||
namespace mindspore::opt {
|
||||
class MindirAdjustPass : public Pass {
|
||||
public:
|
||||
MindirAdjustPass() : Pass("mindir_adjust_pass") {}
|
||||
~MindirAdjustPass() override = default;
|
||||
void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
|
||||
void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
|
||||
int ParameterNodeConvert(AnfNodePtr anf_node);
|
||||
int PrimitiveConvert(AnfNodePtr anf_node);
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
protected:
|
||||
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
|
||||
FmkType fmk_type_ = FmkType::FmkType_MS;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
|
Loading…
Reference in New Issue