!9047 [lite]mindir reconstruct compatibility

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2020-12-04 11:15:33 +08:00 committed by Gitee
commit 61d717032c
13 changed files with 282 additions and 78 deletions

View File

@@ -228,6 +228,7 @@ class PrimitiveC {
bool infer_flag_ = true;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
};
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);
#endif
typedef OpParameter *(*ParameterCreator)(const PrimitiveC *primitive);
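The hunk above adds a shared-pointer alias next to the existing raw-pointer creator typedef: factories hand out raw PrimitiveC pointers, and callers wrap them for ownership. A standalone sketch of that pattern, assuming hypothetical stand-ins (DummyPrimitive and the registry are not MindSpore Lite code):

#include <memory>
#include <string>
#include <unordered_map>

struct DummyPrimitive {
  explicit DummyPrimitive(int type) : type_(type) {}
  int type_;
};

using DummyPrimitivePtr = std::shared_ptr<DummyPrimitive>;  // mirrors PrimitiveCPtr
typedef DummyPrimitive *(*DummyCreator)(int type);          // mirrors PrimitiveCCreator

DummyPrimitive *CreateDummy(int type) { return new DummyPrimitive(type); }

int main() {
  std::unordered_map<std::string, DummyCreator> registry{{"Conv2D", &CreateDummy}};
  // The shared_ptr alias lets callers take ownership of the raw creator result.
  DummyPrimitivePtr prim(registry["Conv2D"](42));
  return prim->type_ == 42 ? 0 : 1;
}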

View File

@@ -203,6 +203,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
)
endif()
### train

View File

@@ -14,15 +14,15 @@
* limitations under the License.
*/
#include <utility>
#include "tools/anf_importer/anf_importer.h"
#include <utility>
#include "schema/model_generated.h"
#include "ir/dtype.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
int AnfImporter::Import(const schema::QuantType &quantType) {
int AnfImporter::Import(const converter::Flags *flag) {
auto ret = ConverterConstTensor();
if (RET_OK != ret) {
MS_LOG(ERROR) << "ConverterConstTensor failed " << ret;

View File

@@ -22,6 +22,7 @@
#include "ir/anf.h"
#include "base/base.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
namespace mindspore::lite {
class AnfImporter {
@@ -30,7 +31,7 @@ class AnfImporter {
virtual ~AnfImporter() = default;
virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE);
virtual int Import(const converter::Flags *flag = nullptr);
virtual FuncGraphPtr GetResult() = 0;
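The interface change replaces the schema::QuantType parameter with a converter::Flags pointer defaulting to nullptr. A minimal self-contained sketch of that shape, with Flags, Importer and MindirImporter as simplified stand-ins; note the nullptr default is bound at the base declaration, so overrides should null-check:

#include <iostream>
#include <string>

struct Flags {  // stand-in for converter::Flags
  std::string modelFile;
  int quantType = 0;
};

class Importer {  // stand-in for AnfImporter
 public:
  virtual ~Importer() = default;
  virtual int Import(const Flags *flag = nullptr) = 0;
};

class MindirImporter : public Importer {  // stand-in for AnfImporterFromMindir
 public:
  int Import(const Flags *flag) override {
    if (flag == nullptr) return -1;  // the base default is nullptr, so check it
    return flag->quantType;
  }
};

int main() {
  Flags flags;
  flags.quantType = 1;
  MindirImporter importer;
  Importer *base = &importer;
  std::cout << base->Import(&flags) << "\n";  // 1
  std::cout << base->Import() << "\n";        // default arg comes from the base declaration: -1
}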

View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "tools/anf_importer/import_from_protobuf.h"
#include "tools/anf_importer/import_from_mindir.h"
#include <unistd.h>
#include <map>
#include <memory>
@@ -36,6 +36,7 @@
#include "src/common/log_adapter.h"
#include "tools/common/protobuf_utils.h"
#include "tools/common/graph_util.h"
#include "load_mindir/load_model.h"
using string = std::string;
using int32 = int32_t;
@@ -199,8 +200,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
const onnx::ValueInfoProto &value_proto) {
int AnfImporterFromMindir::BuildParameterForFuncGraph(const ParameterPtr &node,
const onnx::ValueInfoProto &value_proto) {
if (node == nullptr) {
return RET_NULL_PTR;
}
@@ -274,8 +275,8 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
return RET_OK;
}
int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto) {
int AnfImporterFromMindir::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto) {
if (outputFuncGraph == nullptr) {
return RET_NULL_PTR;
}
@@ -303,8 +304,8 @@ int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &output
return status;
}
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromMindir::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
if (prim == nullptr) {
return false;
}
@@ -317,7 +318,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim
return true;
}
ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
ValuePtr AnfImporterFromMindir::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
switch (attr_tensor_type) {
case onnx::TensorProto_DataType_STRING: {
@@ -347,8 +348,8 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor
}
}
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromMindir::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
if (prim == nullptr) {
return false;
}
@@ -405,7 +406,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
return ret == EOK;
}
bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
bool AnfImporterFromMindir::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
if (prim == nullptr) {
return false;
}
@@ -460,8 +461,8 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromMindir::ObtainValueNodeInTensorForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
std::vector<int> shape;
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
@@ -501,8 +502,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
bool AnfImporterFromMindir::ObtainValueNodeInTypeForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
@@ -515,8 +516,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value
return true;
}
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name,
const onnx::AttributeProto &attr_proto) {
bool AnfImporterFromMindir::GetAttrValueForValueNode(const std::string &value_node_name,
const onnx::AttributeProto &attr_proto) {
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
return false;
@@ -572,7 +573,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_
return true;
}
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
bool AnfImporterFromMindir::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
const std::string &value_node_name = node_proto.output(0);
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
if (!attr_proto.has_ref_attr_name()) {
@@ -582,7 +583,7 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
return GetAttrValueForValueNode(value_node_name, attr_proto);
}
std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProtobuf::GetAbstractForCNode(
std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromMindir::GetAbstractForCNode(
const onnx::AttributeProto &attr_proto) {
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
@@ -601,9 +602,9 @@ std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProt
return kv;
}
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::NodeProto &node_proto,
const schema::QuantType &quantType) {
CNodePtr AnfImporterFromMindir::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::NodeProto &node_proto,
const schema::QuantType &quantType) {
static bool interrupt = false;
if (outputFuncGraph == nullptr) {
MS_LOG(ERROR) << "output funcgraph is nullptr";
@@ -685,8 +686,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
return cnode_ptr;
}
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
bool AnfImporterFromMindir::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
if (outputFuncGraph == nullptr || cnode_ptr == nullptr) {
MS_LOG(ERROR) << "output funcgraph or cnode is nullptr";
return false;
@@ -765,9 +766,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
return true;
}
int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto,
const schema::QuantType &quantType) {
int AnfImporterFromMindir::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const schema::QuantType &quantType) {
if (outputFuncGraph == nullptr) {
MS_LOG(ERROR) << "funcgraph is nullptr";
return RET_NULL_PTR;
@@ -809,8 +809,8 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
return status;
}
int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const schema::QuantType &quantType) {
int AnfImporterFromMindir::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const schema::QuantType &quantType) {
if (outputFuncGraph == nullptr) {
MS_LOG(ERROR) << "fundgraph is nullptr";
return RET_NULL_PTR;
@@ -833,7 +833,7 @@ int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph,
return ImportNodesForGraph(outputFuncGraph, importProto, quantType);
}
int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
if (!model_proto.has_producer_name()) {
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
return RET_GRAPH_FILE_ERR;
@@ -854,7 +854,17 @@ int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mod
return RET_OK;
}
int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
int AnfImporterFromMindir::Import(const converter::Flags *flag) {
onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
if (onnx_model_ == nullptr) {
MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";
func_graph_ = LoadMindIR(flag->modelFile);
if (func_graph_ == nullptr) {
MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file.";
return RET_GRAPH_FILE_ERR;
}
return RET_OK;
}
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
if (dstGraph == nullptr) {
MS_LOG(ERROR) << "funcgraph is nullptr";
@@ -865,10 +875,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
return status;
}
if (onnx_model_ == nullptr) {
MS_LOG(ERROR) << "onnx_model_ is nullptr";
return RET_NULL_PTR;
}
auto quantType = flag->quantType;
const onnx::GraphProto &graphBuild = onnx_model_->graph();
status = BuildFuncGraph(dstGraph, graphBuild, quantType);
if (status != RET_OK) {
@@ -881,25 +888,22 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
return RET_OK;
}
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
onnx::ModelProto *AnfImporterFromMindir::ReadOnnxFromBinary(const std::string &model_path) {
auto onnx_model = new (std::nothrow) onnx::ModelProto;
if (onnx_model == nullptr) {
MS_LOG(ERROR) << "New onnx ModelProto failed!";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
return nullptr;
}
if (RET_OK != ValidateFileStr(model_path, ".mindir")) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID);
return nullptr;
}
if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) {
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
MS_LOG(ERROR) << "Read onnx model file failed, which is not a matched onnx model";
return nullptr;
}
return onnx_model;
}
FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; }
FuncGraphPtr AnfImporterFromMindir::GetResult() { return this->func_graph_; }
} // namespace mindspore::lite
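The reworked Import() above is the heart of the compatibility fix: first try the old protobuf-based parser, and only fall back to the new core loader when that fails. A standalone sketch of the control flow, with ParseOldMindir and LoadNewMindir as hypothetical stand-ins for ReadOnnxFromBinary and LoadMindIR:

#include <iostream>
#include <memory>
#include <string>

struct Graph {};  // stand-in for FuncGraph

std::unique_ptr<Graph> ParseOldMindir(const std::string &path) {
  return nullptr;  // simulate: the file is not an old protobuf-based mindir
}
std::unique_ptr<Graph> LoadNewMindir(const std::string &path) {
  return std::make_unique<Graph>();
}

int ImportModel(const std::string &path) {
  if (auto old_model = ParseOldMindir(path)) {
    // Old format: rebuild the FuncGraph node by node (the BuildFuncGraph path).
    return 0;
  }
  // New format: the core loader returns a ready-made graph.
  auto graph = LoadNewMindir(path);
  if (graph == nullptr) {
    std::cerr << "file matches neither the old nor the new mindir format\n";
    return -1;  // maps to RET_GRAPH_FILE_ERR in the real code
  }
  return 0;
}

int main() { return ImportModel("model.mindir"); }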

View File

@@ -29,18 +29,17 @@
#include "abstract/abstract_value.h"
namespace mindspore::lite {
class AnfImporterFromProtobuf : public AnfImporter {
class AnfImporterFromMindir : public AnfImporter {
public:
AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph)
: onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {}
AnfImporterFromMindir() = default;
~AnfImporterFromProtobuf() override = default;
~AnfImporterFromMindir() override { delete onnx_model_; }
static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path);
FuncGraphPtr GetResult() override;
int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override;
int Import(const converter::Flags *flag) override;
private:
int ConverterConstTensor() override { return RET_ERROR; };
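The header change shifts ownership: the importer used to borrow an onnx::ModelProto injected through its constructor, and now allocates onnx_model_ itself inside Import() and deletes it in the destructor. A minimal sketch of that ownership pattern, with Model and OwningImporter as stand-ins:

#include <new>
#include <string>

struct Model {};  // stand-in for onnx::ModelProto

class OwningImporter {
 public:
  OwningImporter() = default;
  ~OwningImporter() { delete model_; }  // releases what Import() allocated
  OwningImporter(const OwningImporter &) = delete;
  OwningImporter &operator=(const OwningImporter &) = delete;

  int Import(const std::string &path) {
    model_ = new (std::nothrow) Model;  // allocated internally now, not injected by the caller
    return model_ == nullptr ? -1 : 0;
  }

 private:
  Model *model_ = nullptr;
};

int main() {
  OwningImporter importer;
  return importer.Import("model.mindir");
}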

View File

@@ -57,6 +57,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/identity_remove_pass.cc
../optimizer/graph/infershape_pass.cc
../optimizer/graph/slice_prepose_pass.cc
../optimizer/graph/mindir_adjust_pass.cc
)
add_subdirectory(../anf_importer anf_importer)

View File

@@ -29,6 +29,7 @@
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include "tools/optimizer/fusion/conv_conv_fusion.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/identity_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
@@ -61,6 +62,18 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
// mindir pre adjustment
if (config->fmk == converter::FmkType_MS) {
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
mindir_adjust_pass->SetFmkType(config->fmk);
mindir_adjust_pass->SetQuantType(config->quantType);
if (!mindir_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "mindir adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
}
// for now - trainning is not supporting fuse operations
if (!config->trainModel) {
// remove quantdtype when awaretraining
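The new block gates MindirAdjustPass on the framework type, so only mindir inputs pay for the pre-adjustment before the shared pass managers run. A standalone sketch of that gating pattern; every type here is a simplified stand-in for FuncGraph, opt::Pass, and converter::FmkType:

#include <iostream>
#include <memory>

enum FmkType { FmkType_MS, FmkType_ONNX };  // stand-in for converter::FmkType

struct Graph {};  // stand-in for FuncGraph

struct Pass {  // stand-in for opt::Pass
  virtual ~Pass() = default;
  virtual bool Run(const std::shared_ptr<Graph> &graph) = 0;
};

struct MindirAdjust : Pass {
  bool Run(const std::shared_ptr<Graph> &) override { return true; }
};

std::shared_ptr<Graph> Transform(const std::shared_ptr<Graph> &graph, FmkType fmk) {
  if (fmk == FmkType_MS) {  // only mindir models need the pre-adjustment
    MindirAdjust pass;
    if (!pass.Run(graph)) {
      std::cerr << "mindir adjust failed\n";
      return nullptr;  // abort before the shared fusion pass managers run
    }
  }
  // ...the regular pass managers would run here for every framework...
  return graph;
}

int main() { return Transform(std::make_shared<Graph>(), FmkType_MS) ? 0 : 1; }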

View File

@@ -30,7 +30,7 @@
#include "parser/onnx/onnx_converter.h"
#include "parser/tf/tf_converter.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/anf_importer/import_from_protobuf.h"
#include "tools/anf_importer/import_from_mindir.h"
#include "proto/onnx.pb.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
@@ -54,9 +54,7 @@ Converter::~Converter() {
class MindsporeImporter : public Converter {
public:
MindsporeImporter(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) {
modelImporter = new AnfImporterFromProtobuf(onnx_model, std::move(func_graph));
}
MindsporeImporter() { modelImporter = new AnfImporterFromMindir(); }
~MindsporeImporter() override = default;
};
@@ -66,7 +64,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
FuncGraphPtr graph = nullptr;
if (flag->fmk == converter::FmkType_MS) {
MS_ASSERT(nullptr != modelImporter);
int status = modelImporter->Import(flag->quantType);
int status = modelImporter->Import(flag);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
graph = modelImporter->GetResult();
} else {
@@ -127,15 +125,8 @@ int RunConverter(int argc, const char **argv) {
MetaGraphT *fb_graph = nullptr;
switch (flags->fmk) {
case FmkType::FmkType_MS: {
auto graph = std::make_shared<FuncGraph>();
auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile);
if (onnx_graph == nullptr) {
MS_LOG(ERROR) << "Read MINDIR from binary return nullptr";
break;
}
MindsporeImporter mindsporeImporter(onnx_graph, graph);
MindsporeImporter mindsporeImporter;
fb_graph = mindsporeImporter.Convert(flags.get());
delete onnx_graph;
break;
}
case FmkType::FmkType_CAFFE: {

View File

@@ -26,22 +26,6 @@ namespace mindspore {
namespace opt {
namespace {
constexpr auto kAnfPrimitiveIndex = 0;
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
if (node == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
}
bool IsRealKernel(const AnfNodePtr &node) {
if (node == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@@ -136,6 +120,22 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
}
} // namespace
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
if (node == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
}
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a);
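Moving CheckPrimitiveType out of the anonymous namespace (paired with the declaration added to gllo_utils.h in the next file) gives it external linkage, which is what lets mindir_adjust_pass.cc call it. A standalone sketch of the linkage difference:

#include <iostream>

namespace {
// Internal linkage: callable only inside this translation unit, which is why
// the helper was previously unreachable from other .cc files.
bool HiddenHelper() { return true; }
}  // namespace

// External linkage: any file that sees a declaration (e.g. via a header)
// can call this.
bool VisibleHelper() { return HiddenHelper(); }

int main() { std::cout << VisibleHelper() << "\n"; }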

View File

@@ -34,6 +34,8 @@ using mindspore::lite::RET_OK;
using mindspore::lite::STATUS;
namespace mindspore {
namespace opt {
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
bool IsRealCNodeKernel(const AnfNodePtr &node);
bool IsGraphKernel(const AnfNodePtr &node);

View File

@@ -0,0 +1,147 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include <algorithm>
#include <vector>
#include <memory>
#include "src/ops/primitive_c.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "src/common/log_adapter.h"
#include "src/tensor.h"
using mindspore::lite::PrimitiveC;
namespace mindspore {
namespace opt {
int MindirAdjustPass::ParameterNodeConvert(AnfNodePtr anf_node) {
if (!utils::isa<ParameterPtr>(anf_node)) {
MS_LOG(INFO) << "only parameter node need to convert tensor.";
return lite::RET_NO_CHANGE;
}
auto param_node = anf_node->cast<ParameterPtr>();
if (!param_node->has_default()) {
MS_LOG(INFO) << "this is graph input, don't need to convert.";
return lite::RET_NO_CHANGE;
}
if (utils::isa<ParamValueLitePtr>(param_node->default_param())) {
MS_LOG(INFO) << "the tensor has been a paramvalueLite.";
return lite::RET_NO_CHANGE;
}
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
if (param_value == nullptr) {
MS_LOG(ERROR) << "fail to new a ParamValueLite.";
return lite::RET_ERROR;
}
param_node->set_name(param_node->debug_info()->name());
auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "the node is not a tensor::TensorPtr.";
return lite::RET_ERROR;
}
param_value->set_tensor_size(tensor_info->Size());
param_value->set_tensor_type(tensor_info->data_type());
auto tensor_shape = tensor_info->shape();
std::vector<int> shape;
std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape),
[](int64_t value) { return static_cast<int>(value); });
param_value->set_tensor_shape(shape);
auto *tensor = new (std::nothrow) lite::Tensor(tensor_info->data_type(), shape);
if (tensor == nullptr) {
MS_LOG(ERROR) << "new a lite::tensor failed, get a nullptr.";
return lite::RET_MEMORY_FAILED;
}
auto *tensor_data_buf = tensor->MutableData();
if (tensor_data_buf == nullptr) {
MS_LOG(ERROR) << "malloc tensor data failed.";
delete tensor;
return lite::RET_MEMORY_FAILED;
}
if (memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
MS_LOG(ERROR) << "memcpy_s error.";
delete tensor;
return lite::RET_MEMORY_FAILED;
}
tensor->set_data(nullptr);
param_value->set_tensor_addr(tensor_data_buf);
param_node->set_default_param(param_value);
delete tensor;
return lite::RET_OK;
}
int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) {
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(INFO) << "only cnode need to convert primitive.";
return lite::RET_NO_CHANGE;
}
auto cnode = anf_node->cast<CNodePtr>();
if (cnode->inputs().empty() || cnode->input(0) == nullptr) {
MS_LOG(ERROR) << "the cnode is invalid.";
return lite::RET_NULL_PTR;
}
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
if (value_node == nullptr || value_node->value() == nullptr) {
MS_LOG(ERROR) << "value node is invalid.";
return lite::RET_NULL_PTR;
}
if (utils::isa<PrimitiveCPtr>(value_node->value())) {
MS_LOG(INFO) << "the value has been primitiveC.";
return lite::RET_NO_CHANGE;
}
auto primitive = value_node->value()->cast<PrimitivePtr>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "the value is not primitive.";
return lite::RET_ERROR;
}
auto inputs = cnode->inputs();
inputs.erase(inputs.begin());
if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_);
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
value_node->set_value(primitive_c);
} else {
auto primitiveT = std::make_unique<schema::PrimitiveT>();
primitiveT->value.type = (CheckPrimitiveType(anf_node, prim::kPrimReturn) ? schema::PrimitiveType_Return
: schema::PrimitiveType_MakeTuple);
value_node->set_value(std::make_shared<PrimitiveC>(primitiveT.release()));
}
return lite::RET_OK;
}
bool MindirAdjustPass::Run(const FuncGraphPtr &graph) {
if (this->fmk_type_ != lite::converter::FmkType_MS) {
MS_LOG(INFO) << "The framework type of model should be mindir.";
return lite::RET_OK;
}
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
int status = lite::RET_OK;
for (auto &node : node_list) {
if (utils::isa<ParameterPtr>(node)) {
status = ParameterNodeConvert(node);
} else if (utils::isa<CNodePtr>(node)) {
status = PrimitiveConvert(node);
}
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
return false;
}
}
return true;
}
} // namespace opt
} // namespace mindspore
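Two details in ParameterNodeConvert are worth noting. MindIR tensor shapes carry int64_t dimensions while lite tensors use int, so each dimension is narrowed with std::transform (a dimension beyond INT_MAX would silently wrap). And the temporary lite::Tensor exists only to allocate the buffer: after memcpy_s, set_data(nullptr) detaches the buffer so the ParamValueLite keeps it while the tensor itself is deleted. A standalone sketch of the shape narrowing:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
  std::vector<int64_t> tensor_shape{1, 224, 224, 3};  // MindIR dims are int64_t
  std::vector<int> shape;                             // lite tensors use int
  std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape),
                 [](int64_t value) { return static_cast<int>(value); });
  for (int dim : shape) std::cout << dim << ' ';  // prints: 1 224 224 3
  std::cout << '\n';
  return 0;
}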

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
#include <string>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/param_value_lite.h"
using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
class MindirAdjustPass : public Pass {
public:
MindirAdjustPass() : Pass("mindir_adjust_pass") {}
~MindirAdjustPass() override = default;
void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
int ParameterNodeConvert(AnfNodePtr anf_node);
int PrimitiveConvert(AnfNodePtr anf_node);
bool Run(const FuncGraphPtr &graph) override;
protected:
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
FmkType fmk_type_ = FmkType::FmkType_MS;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
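A condensed usage sketch of the pass declared above, mirroring the wiring in the anf_transform.cc hunk earlier; it compiles only inside the converter tree since it needs the real headers:

#include <memory>
#include "tools/optimizer/graph/mindir_adjust_pass.h"

bool AdjustMindirGraph(const mindspore::FuncGraphPtr &graph) {
  auto pass = std::make_shared<mindspore::opt::MindirAdjustPass>();
  pass->SetFmkType(mindspore::lite::converter::FmkType_MS);
  pass->SetQuantType(mindspore::schema::QuantType_QUANT_NONE);
  return pass->Run(graph);  // false if any parameter or primitive failed to convert
}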