open node_parser register
This commit is contained in:
parent
f960f0671f
commit
da45edc746
|
@ -0,0 +1,96 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/registry/parser_context.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "proto/caffe.pb.h"
|
||||
#include "proto/graph.pb.h"
|
||||
#include "schema/schema_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief PrimitiveC defined a base class for storing properties
|
||||
class PrimitiveC;
|
||||
} // namespace ops
|
||||
namespace converter {
|
||||
/// \brief NodeParser defined a base class for parsing node's attributes.
|
||||
class MS_API NodeParser {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
NodeParser() = default;
|
||||
|
||||
/// \brief Destructor.
|
||||
virtual ~NodeParser() = default;
|
||||
|
||||
/// \brief Method to parse node of ONNX.
|
||||
///
|
||||
/// \param[in] onnx_graph Define the onnx graph, which contains all information about the graph.
|
||||
/// \param[in] onnx_node Define the node to be resolved.
|
||||
///
|
||||
/// \return PrimitiveC Attribute storage.
|
||||
virtual ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// \brief Method to parse node of CAFFE.
|
||||
///
|
||||
/// \param[in] proto Define the node which contains attributes.
|
||||
/// \param[in] weight Define the node which contains weight information.
|
||||
///
|
||||
/// \return PrimitiveC Attribute storage.
|
||||
virtual ops::PrimitiveC *Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// \brief Method to parse node of TF.
|
||||
///
|
||||
/// \param[in] tf_op Define the node to be resolved.
|
||||
/// \param[in] tf_node_map Define the all nodes of the graph.
|
||||
/// \param[in] inputs Define the input name, that determines which inputs will be parsed including their order.
|
||||
/// Determined by user.
|
||||
/// \param[in] output_size Define the output num of current node, which need to be determined by user.
|
||||
///
|
||||
/// \return PrimitiveC Attribute storage.
|
||||
virtual ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// \brief Method to parse node of TFLITE
|
||||
///
|
||||
/// \param[in] tflite_op Define the node to be resolved.
|
||||
/// \param[in] tflite_model Define the model, which contains all information abort the graph.
|
||||
///
|
||||
/// \return PrimitiveC Attribute storage.
|
||||
virtual ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
/// \brief NodeParserPtr defined a shared_ptr type.
|
||||
using NodeParserPtr = std::shared_ptr<NodeParser>;
|
||||
} // namespace converter
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_H_
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_
|
||||
|
||||
#include <string>
|
||||
#include "include/registry/node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace registry {
|
||||
/// \brief NodeParserRegistry defined registration of NodeParser.
|
||||
class MS_API NodeParserRegistry {
|
||||
public:
|
||||
/// \brief Constructor of NodeParserRegistry to register NodeParser.
|
||||
///
|
||||
/// \param[in] fmk_type Define the framework.
|
||||
/// \param[in] node_type Define the type of the node to be resolved.
|
||||
/// \param[in] node_parser Define the NodeParser instance to parse the node.
|
||||
NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
|
||||
const converter::NodeParserPtr &node_parser);
|
||||
|
||||
/// \brief Destructor
|
||||
~NodeParserRegistry() = default;
|
||||
|
||||
/// \brief Static method to obtain NodeParser instance of a certain node.
|
||||
///
|
||||
/// \param[in] fmk_type Define the framework.
|
||||
/// \param[in] node_type Define the type of the node to be resolved.
|
||||
///
|
||||
/// \return NodeParser instance.
|
||||
static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type);
|
||||
};
|
||||
|
||||
/// \brief Defined registering macro to register NodeParser instance.
|
||||
///
|
||||
/// \param[in] fmk_type Define the framework.
|
||||
/// \param[in] node_type Define the type of the node to be resolved.
|
||||
/// \param[in] NodeParser instance corresponding with its framework and node type.
|
||||
#define REG_NODE_PARSER(fmk_type, node_type, node_parser) \
|
||||
static mindspore::registry::NodeParserRegistry g_##fmk_type##node_type##ParserReg(fmk_type, #node_type, node_parser);
|
||||
} // namespace registry
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_
|
|
@ -33,7 +33,7 @@ using PassPtr = std::shared_ptr<Pass>;
|
|||
} // namespace opt
|
||||
|
||||
namespace registry {
|
||||
/// \brief PassPosition defined where to plae user's pass.
|
||||
/// \brief PassPosition defined where to place user's pass.
|
||||
enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 };
|
||||
|
||||
/// \brief PassRegistry defined registration of Pass.
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <set>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "include/registry/node_parser_registry.h"
|
||||
#include "tools/converter/parser/caffe/caffe_node_parser_registry.h"
|
||||
#include "tools/converter/parser/caffe/caffe_inspector.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
|
@ -145,18 +146,22 @@ STATUS CaffeModelParser::ConvertLayers() {
|
|||
|
||||
// parse primitive
|
||||
MS_LOG(INFO) << "parse op : " << layer.type();
|
||||
auto node_parser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type());
|
||||
if (node_parser == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(layer.type());
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
continue;
|
||||
ops::PrimitiveC *primitive_c;
|
||||
auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeCaffe, layer.type());
|
||||
if (node_parser != nullptr) {
|
||||
primitive_c = node_parser->Parse(layer, weight);
|
||||
} else {
|
||||
auto node_parser_builtin = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type());
|
||||
if (node_parser_builtin == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(layer.type());
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
continue;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
continue;
|
||||
}
|
||||
primitive_c = node_parser_builtin->Parse(layer, weight);
|
||||
}
|
||||
|
||||
if (status != RET_OK) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto primitive_c = node_parser->Parse(layer, weight);
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "parse node " << layer.name() << " failed.";
|
||||
status = RET_ERROR;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include "include/registry/node_parser_registry.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/protobuf_utils.h"
|
||||
|
@ -236,18 +237,24 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
|
|||
}
|
||||
STATUS status = RET_OK;
|
||||
for (const auto &onnx_node : onnx_graph.node()) {
|
||||
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
|
||||
if (node_parser == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
status = status == RET_OK ? RET_NOT_FIND_OP : status;
|
||||
MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
continue;
|
||||
ops::PrimitiveC *primitive_c;
|
||||
auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeOnnx, onnx_node.op_type());
|
||||
if (node_parser != nullptr) {
|
||||
primitive_c = node_parser->Parse(onnx_graph, onnx_node);
|
||||
} else {
|
||||
auto node_parser_builtin = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
|
||||
if (node_parser_builtin == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
status = status == RET_OK ? RET_NOT_FIND_OP : status;
|
||||
MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(INFO) << "parse op:" << onnx_node.op_type();
|
||||
primitive_c = node_parser_builtin->Parse(onnx_graph, onnx_node);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "parse op:" << onnx_node.op_type();
|
||||
auto primitive_c = node_parser->Parse(onnx_graph, onnx_node);
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
|
||||
status = RET_ERROR;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "tools/converter/parser/tf/tf_model_parser.h"
|
||||
#include <functional>
|
||||
#include <set>
|
||||
#include "include/registry/node_parser_registry.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
|
@ -929,18 +930,23 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
|
|||
return RET_OK;
|
||||
}
|
||||
MS_LOG(INFO) << "parse op : " << op_type;
|
||||
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
|
||||
if (node_parser == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(op_type);
|
||||
MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in "
|
||||
<< func_graph_ptr->get_attr("graph_name")->ToString();
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
|
||||
ops::PrimitiveC *primitive_c;
|
||||
auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeTf, op_type);
|
||||
int output_size;
|
||||
std::vector<std::string> input_names;
|
||||
auto primitiveC = node_parser->Parse(node_def, tf_node_map, &input_names, &output_size);
|
||||
if (primitiveC == nullptr) {
|
||||
if (node_parser != nullptr) {
|
||||
primitive_c = node_parser->Parse(node_def, tf_node_map, &input_names, &output_size);
|
||||
} else {
|
||||
auto node_parser_builtin = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
|
||||
if (node_parser_builtin == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(op_type);
|
||||
MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in "
|
||||
<< func_graph_ptr->get_attr("graph_name")->ToString();
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
primitive_c = node_parser_builtin->Parse(node_def, tf_node_map, &input_names, &output_size);
|
||||
}
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "node " << op_type << " parser failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -948,7 +954,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
|
|||
for (int i = 0; i < output_size; i++) {
|
||||
node_output_num_[node_def.name() + ":" + to_string(i)] = 1;
|
||||
}
|
||||
auto value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitiveC));
|
||||
auto value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c));
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "value_node is nullptr";
|
||||
return RET_ERROR;
|
||||
|
@ -978,7 +984,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
|
|||
return status;
|
||||
}
|
||||
|
||||
status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC);
|
||||
status = ConvertQuantParams(inputs.size() - 1, output_size, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
|
||||
return status;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "include/registry/node_parser_registry.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "src/common/file_utils.h"
|
||||
|
@ -139,27 +140,33 @@ STATUS TfliteModelParser::ConvertOps() {
|
|||
op_idx++;
|
||||
// parse primitive
|
||||
MS_LOG(INFO) << "parse node :" << op_name;
|
||||
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type);
|
||||
if (node_parser == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(op_type);
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
MS_LOG(ERROR) << "Can not find " << op_type << " op parser.";
|
||||
continue;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
continue;
|
||||
ops::PrimitiveC *primitive_c;
|
||||
auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeTflite, op_type);
|
||||
if (node_parser != nullptr) {
|
||||
primitive_c = node_parser->Parse(op, tflite_model_);
|
||||
} else {
|
||||
auto node_parser_builtin = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type);
|
||||
if (node_parser_builtin == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(op_type);
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
MS_LOG(ERROR) << "Can not find " << op_type << " op parser.";
|
||||
continue;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
continue;
|
||||
}
|
||||
primitive_c = node_parser_builtin->Parse(op, tflite_model_);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> op_inputs;
|
||||
auto primitiveC = node_parser->Parse(op, tflite_model_);
|
||||
if (primitiveC != nullptr) {
|
||||
op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitiveC))};
|
||||
if (primitive_c != nullptr) {
|
||||
op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))};
|
||||
} else {
|
||||
MS_LOG(ERROR) << "parse failed for node: " << op_name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
status = ConvertOpQuantParams(op.get(), primitiveC);
|
||||
status = ConvertOpQuantParams(op.get(), primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert " << op_name << " quant param failed.";
|
||||
continue;
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
set(KERNEL_REG_DIR ${TOP_DIR}/mindspore/lite/src/registry)
|
||||
file(GLOB CONVERT_REG_SRC
|
||||
pass_registry.cc
|
||||
model_parser_registry.cc
|
||||
)
|
||||
file(GLOB CONVERT_REG_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
|
||||
file(GLOB KERNEL_REG_SRC ${KERNEL_REG_DIR}/*.cc)
|
||||
set(REG_SRC ${CONVERT_REG_SRC}
|
||||
${KERNEL_REG_SRC}
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "include/registry/node_parser_registry.h"
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace registry {
|
||||
namespace {
|
||||
std::map<converter::FmkType, std::map<std::string, converter::NodeParserPtr>> node_parser_room;
|
||||
std::mutex node_mutex;
|
||||
} // namespace
|
||||
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
|
||||
const converter::NodeParserPtr &node_parser) {
|
||||
std::unique_lock<std::mutex> lock(node_mutex);
|
||||
node_parser_room[fmk_type][node_type] = node_parser;
|
||||
}
|
||||
|
||||
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) {
|
||||
auto iter_level1 = node_parser_room.find(fmk_type);
|
||||
if (iter_level1 == node_parser_room.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto iter_level2 = iter_level1->second.find(node_type);
|
||||
if (iter_level2 == iter_level1->second.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return iter_level2->second;
|
||||
}
|
||||
} // namespace registry
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue