open node_parser register

This commit is contained in:
xuanyue 2021-08-23 11:50:17 +08:00
parent f960f0671f
commit da45edc746
9 changed files with 273 additions and 51 deletions

View File

@ -0,0 +1,96 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_H_
#define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "include/registry/parser_context.h"
#include "proto/onnx.pb.h"
#include "proto/caffe.pb.h"
#include "proto/graph.pb.h"
#include "schema/schema_generated.h"
namespace mindspore {
namespace ops {
/// \brief PrimitiveC defined a base class for storing properties
class PrimitiveC;
} // namespace ops
namespace converter {
/// \brief NodeParser defined a base class for parsing node's attributes.
class MS_API NodeParser {
public:
/// \brief Constructor.
NodeParser() = default;
/// \brief Destructor.
virtual ~NodeParser() = default;
/// \brief Method to parse node of ONNX.
///
/// \param[in] onnx_graph Define the onnx graph, which contains all information about the graph.
/// \param[in] onnx_node Define the node to be resolved.
///
/// \return PrimitiveC Attribute storage.
virtual ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
return nullptr;
}
/// \brief Method to parse node of CAFFE.
///
/// \param[in] proto Define the node which contains attributes.
/// \param[in] weight Define the node which contains weight information.
///
/// \return PrimitiveC Attribute storage.
virtual ops::PrimitiveC *Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
return nullptr;
}
/// \brief Method to parse node of TF.
///
/// \param[in] tf_op Define the node to be resolved.
/// \param[in] tf_node_map Define the all nodes of the graph.
/// \param[in] inputs Define the input name, that determines which inputs will be parsed including their order.
/// Determined by user.
/// \param[in] output_size Define the output num of current node, which need to be determined by user.
///
/// \return PrimitiveC Attribute storage.
virtual ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
return nullptr;
}
/// \brief Method to parse node of TFLITE
///
/// \param[in] tflite_op Define the node to be resolved.
/// \param[in] tflite_model Define the model, which contains all information abort the graph.
///
/// \return PrimitiveC Attribute storage.
virtual ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
return nullptr;
}
};
/// \brief NodeParserPtr defined a shared_ptr type.
using NodeParserPtr = std::shared_ptr<NodeParser>;
} // namespace converter
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_H_

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_
#define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_
#include <string>
#include "include/registry/node_parser.h"
namespace mindspore {
namespace registry {
/// \brief NodeParserRegistry defined registration of NodeParser.
class MS_API NodeParserRegistry {
public:
/// \brief Constructor of NodeParserRegistry to register NodeParser.
///
/// \param[in] fmk_type Define the framework.
/// \param[in] node_type Define the type of the node to be resolved.
/// \param[in] node_parser Define the NodeParser instance to parse the node.
NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
const converter::NodeParserPtr &node_parser);
/// \brief Destructor
~NodeParserRegistry() = default;
/// \brief Static method to obtain NodeParser instance of a certain node.
///
/// \param[in] fmk_type Define the framework.
/// \param[in] node_type Define the type of the node to be resolved.
///
/// \return NodeParser instance.
static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type);
};
/// \brief Defined registering macro to register NodeParser instance.
///
/// \param[in] fmk_type Define the framework.
/// \param[in] node_type Define the type of the node to be resolved.
/// \param[in] NodeParser instance corresponding with its framework and node type.
#define REG_NODE_PARSER(fmk_type, node_type, node_parser) \
static mindspore::registry::NodeParserRegistry g_##fmk_type##node_type##ParserReg(fmk_type, #node_type, node_parser);
} // namespace registry
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_

View File

@ -33,7 +33,7 @@ using PassPtr = std::shared_ptr<Pass>;
} // namespace opt
namespace registry {
/// \brief PassPosition defined where to plae user's pass.
/// \brief PassPosition defined where to place user's pass.
enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 };
/// \brief PassRegistry defined registration of Pass.

View File

@ -19,6 +19,7 @@
#include <set>
#include <memory>
#include <algorithm>
#include "include/registry/node_parser_registry.h"
#include "tools/converter/parser/caffe/caffe_node_parser_registry.h"
#include "tools/converter/parser/caffe/caffe_inspector.h"
#include "tools/common/graph_util.h"
@ -145,18 +146,22 @@ STATUS CaffeModelParser::ConvertLayers() {
// parse primitive
MS_LOG(INFO) << "parse op : " << layer.type();
auto node_parser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type());
if (node_parser == nullptr) {
NotSupportOp::GetInstance()->InsertOp(layer.type());
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
ops::PrimitiveC *primitive_c;
auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeCaffe, layer.type());
if (node_parser != nullptr) {
primitive_c = node_parser->Parse(layer, weight);
} else {
auto node_parser_builtin = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type());
if (node_parser_builtin == nullptr) {
NotSupportOp::GetInstance()->InsertOp(layer.type());
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
}
if (status != RET_OK) {
continue;
}
primitive_c = node_parser_builtin->Parse(layer, weight);
}
if (status != RET_OK) {
continue;
}
auto primitive_c = node_parser->Parse(layer, weight);
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "parse node " << layer.name() << " failed.";
status = RET_ERROR;

View File

@ -21,6 +21,7 @@
#include <vector>
#include <unordered_map>
#include <utility>
#include "include/registry/node_parser_registry.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h"
@ -236,18 +237,24 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
}
STATUS status = RET_OK;
for (const auto &onnx_node : onnx_graph.node()) {
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
if (node_parser == nullptr) {
NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
status = status == RET_OK ? RET_NOT_FIND_OP : status;
MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
}
if (status != RET_OK) {
continue;
ops::PrimitiveC *primitive_c;
auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeOnnx, onnx_node.op_type());
if (node_parser != nullptr) {
primitive_c = node_parser->Parse(onnx_graph, onnx_node);
} else {
auto node_parser_builtin = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
if (node_parser_builtin == nullptr) {
NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
status = status == RET_OK ? RET_NOT_FIND_OP : status;
MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
}
if (status != RET_OK) {
continue;
}
MS_LOG(INFO) << "parse op:" << onnx_node.op_type();
primitive_c = node_parser_builtin->Parse(onnx_graph, onnx_node);
}
MS_LOG(INFO) << "parse op:" << onnx_node.op_type();
auto primitive_c = node_parser->Parse(onnx_graph, onnx_node);
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
status = RET_ERROR;

View File

@ -18,6 +18,7 @@
#include "tools/converter/parser/tf/tf_model_parser.h"
#include <functional>
#include <set>
#include "include/registry/node_parser_registry.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
@ -929,18 +930,23 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
return RET_OK;
}
MS_LOG(INFO) << "parse op : " << op_type;
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) {
NotSupportOp::GetInstance()->InsertOp(op_type);
MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in "
<< func_graph_ptr->get_attr("graph_name")->ToString();
return RET_NOT_FIND_OP;
}
ops::PrimitiveC *primitive_c;
auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeTf, op_type);
int output_size;
std::vector<std::string> input_names;
auto primitiveC = node_parser->Parse(node_def, tf_node_map, &input_names, &output_size);
if (primitiveC == nullptr) {
if (node_parser != nullptr) {
primitive_c = node_parser->Parse(node_def, tf_node_map, &input_names, &output_size);
} else {
auto node_parser_builtin = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser_builtin == nullptr) {
NotSupportOp::GetInstance()->InsertOp(op_type);
MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in "
<< func_graph_ptr->get_attr("graph_name")->ToString();
return RET_NOT_FIND_OP;
}
primitive_c = node_parser_builtin->Parse(node_def, tf_node_map, &input_names, &output_size);
}
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "node " << op_type << " parser failed";
return RET_ERROR;
}
@ -948,7 +954,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
for (int i = 0; i < output_size; i++) {
node_output_num_[node_def.name() + ":" + to_string(i)] = 1;
}
auto value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitiveC));
auto value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c));
if (value_node == nullptr) {
MS_LOG(ERROR) << "value_node is nullptr";
return RET_ERROR;
@ -978,7 +984,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
return status;
}
status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC);
status = ConvertQuantParams(inputs.size() - 1, output_size, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
return status;

View File

@ -20,6 +20,7 @@
#include <memory>
#include <algorithm>
#include <utility>
#include "include/registry/node_parser_registry.h"
#include "ops/primitive_c.h"
#include "ir/func_graph.h"
#include "src/common/file_utils.h"
@ -139,27 +140,33 @@ STATUS TfliteModelParser::ConvertOps() {
op_idx++;
// parse primitive
MS_LOG(INFO) << "parse node :" << op_name;
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type);
if (node_parser == nullptr) {
NotSupportOp::GetInstance()->InsertOp(op_type);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
MS_LOG(ERROR) << "Can not find " << op_type << " op parser.";
continue;
}
if (status != RET_OK) {
continue;
ops::PrimitiveC *primitive_c;
auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeTflite, op_type);
if (node_parser != nullptr) {
primitive_c = node_parser->Parse(op, tflite_model_);
} else {
auto node_parser_builtin = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type);
if (node_parser_builtin == nullptr) {
NotSupportOp::GetInstance()->InsertOp(op_type);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
MS_LOG(ERROR) << "Can not find " << op_type << " op parser.";
continue;
}
if (status != RET_OK) {
continue;
}
primitive_c = node_parser_builtin->Parse(op, tflite_model_);
}
std::vector<AnfNodePtr> op_inputs;
auto primitiveC = node_parser->Parse(op, tflite_model_);
if (primitiveC != nullptr) {
op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitiveC))};
if (primitive_c != nullptr) {
op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))};
} else {
MS_LOG(ERROR) << "parse failed for node: " << op_name;
return RET_ERROR;
}
status = ConvertOpQuantParams(op.get(), primitiveC);
status = ConvertOpQuantParams(op.get(), primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << op_name << " quant param failed.";
continue;

View File

@ -1,8 +1,5 @@
set(KERNEL_REG_DIR ${TOP_DIR}/mindspore/lite/src/registry)
file(GLOB CONVERT_REG_SRC
pass_registry.cc
model_parser_registry.cc
)
file(GLOB CONVERT_REG_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
file(GLOB KERNEL_REG_SRC ${KERNEL_REG_DIR}/*.cc)
set(REG_SRC ${CONVERT_REG_SRC}
${KERNEL_REG_SRC}

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/registry/node_parser_registry.h"
#include <map>
#include <mutex>
#include <string>
namespace mindspore {
namespace registry {
namespace {
std::map<converter::FmkType, std::map<std::string, converter::NodeParserPtr>> node_parser_room;
std::mutex node_mutex;
} // namespace
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
const converter::NodeParserPtr &node_parser) {
std::unique_lock<std::mutex> lock(node_mutex);
node_parser_room[fmk_type][node_type] = node_parser;
}
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) {
auto iter_level1 = node_parser_room.find(fmk_type);
if (iter_level1 == node_parser_room.end()) {
return nullptr;
}
auto iter_level2 = iter_level1->second.find(node_type);
if (iter_level2 == iter_level1->second.end()) {
return nullptr;
}
return iter_level2->second;
}
} // namespace registry
} // namespace mindspore