draw out model_parser basic class

This commit is contained in:
xuanyue 2021-09-15 17:08:40 +08:00
parent c8b9d45abc
commit 6f738d2e4a
20 changed files with 53 additions and 53 deletions

View File

@ -334,8 +334,6 @@ elseif(WIN32)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
@ -477,8 +475,6 @@ else()
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema

View File

@ -237,13 +237,13 @@ build_lite() {
compile_nnie_script=${BASEPATH}/mindspore/lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh
cd ${BASEPATH}/../
if [[ "${local_lite_platform}" == "x86_64" ]]; then
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master -j $THREAD_NUM
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master_dev -j $THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "compile x86_64 for nnie failed."
exit 1
fi
elif [[ "${local_lite_platform}" == "arm32" ]]; then
sh ${compile_nnie_script} -I arm32 -b nnie_3516_master -j $THREAD_NUM
sh ${compile_nnie_script} -I arm32 -b nnie_3516_master_dev -j $THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "compile arm32 for nnie failed."
exit 1

View File

@ -14,41 +14,34 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_H
#include <google/protobuf/message.h>
#include <string>
#include <memory>
#include "schema/inner/model_generated.h"
#include "ir/anf.h"
#include "api/ir/func_graph.h"
#include "include/registry/model_parser_registry.h"
#include "utils/log_adapter.h"
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_
namespace mindspore::converter {
class ModelParser {
#include "api/ir/func_graph.h"
#include "include/registry/parser_context.h"
namespace mindspore {
namespace converter {
/// \brief ModelParser defined a base class to parse model.
class MS_API ModelParser {
public:
/// \brief Constructor.
ModelParser() = default;
/// \brief Destructor.
virtual ~ModelParser() = default;
/// \brief Method to parse model, which must be onnx/caffe/tf/tflite.
///
/// \param[in] flags Define the basic parameters when converting, which defined in parser_context.h.
///
/// \return FuncGraph Pointer, which contains all information about the model.
virtual api::FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
protected:
api::FuncGraphPtr res_graph_ = nullptr;
};
} // namespace converter
} // namespace mindspore
typedef ModelParser *(*ModelParserCreator)();
template <class T>
ModelParser *LiteModelParserCreator() {
auto *parser = new (std::nothrow) T();
if (parser == nullptr) {
MS_LOG(ERROR) << "new model parser failed";
return nullptr;
}
return parser;
}
} // namespace mindspore::converter
#endif
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_

View File

@ -23,6 +23,9 @@
using mindspore::converter::FmkType;
namespace mindspore {
namespace converter {
class ModelParser;
} // namespace converter
namespace registry {
/// \brief ModelParserCreator defined function pointer to get a ModelParser class.
typedef converter::ModelParser *(*ModelParserCreator)();

View File

@ -39,9 +39,6 @@ struct MS_API ConverterParameters {
std::string weight_file;
std::map<std::string, std::string> attrs;
};
/// \brief ModelParser defined a base class of model parser
class MS_API ModelParser;
} // namespace converter
} // namespace mindspore

View File

@ -20,9 +20,9 @@
#include <map>
#include <string>
#include <vector>
#include "include/registry/model_parser.h"
#include "include/registry/model_parser_registry.h"
#include "ut/tools/converter/registry/node_parser_test.h"
#include "tools/converter/model_parser.h"
namespace mindspore {
class ModelParserTest : public converter::ModelParser {

View File

@ -24,7 +24,6 @@
#include "ops/fusion/add_fusion.h"
#include "ops/addn.h"
#include "ops/custom.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/optimizer_manager.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ut/tools/converter/registry/model_parser_test.h"

View File

@ -19,9 +19,9 @@
#include <memory>
#include <string>
#include "include/registry/model_parser.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/graphdef_transform.h"
#include "tools/converter/model_parser.h"
#include "include/registry/model_parser_registry.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/anf_transform.h"

View File

@ -576,5 +576,5 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
}
return layer.name();
}
REG_MODEL_PARSER(kFmkTypeCaffe, converter::LiteModelParserCreator<CaffeModelParser>)
REG_MODEL_PARSER(kFmkTypeCaffe, LiteModelParserCreator<CaffeModelParser>)
} // namespace mindspore::lite

View File

@ -21,7 +21,7 @@
#include <memory>
#include <set>
#include <unordered_map>
#include "tools/converter/model_parser.h"
#include "include/registry/model_parser.h"
#include "include/registry/model_parser_registry.h"
#include "proto/caffe.pb.h"
#include "ops/primitive_c.h"

View File

@ -1248,6 +1248,6 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const
return RET_OK;
}
REG_MODEL_PARSER(kFmkTypeOnnx, converter::LiteModelParserCreator<OnnxModelParser>)
REG_MODEL_PARSER(kFmkTypeOnnx, LiteModelParserCreator<OnnxModelParser>)
} // namespace lite
} // namespace mindspore

View File

@ -27,7 +27,7 @@
#include <set>
#include <unordered_map>
#include "securec/include/securec.h"
#include "tools/converter/model_parser.h"
#include "include/registry/model_parser.h"
#include "include/registry/model_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "proto/onnx.pb.h"

View File

@ -19,6 +19,7 @@
#include <set>
#include <vector>
#include "include/registry/model_parser.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/common/log_adapter.h"
@ -39,6 +40,16 @@ int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_nod
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
template <class T>
converter::ModelParser *LiteModelParserCreator() {
auto *parser = new (std::nothrow) T();
if (parser == nullptr) {
MS_LOG(ERROR) << "new model parser failed";
return nullptr;
}
return parser;
}
} // namespace lite
} // namespace mindspore

View File

@ -1158,6 +1158,6 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
return RET_OK;
}
REG_MODEL_PARSER(kFmkTypeTf, converter::LiteModelParserCreator<TFModelParser>)
REG_MODEL_PARSER(kFmkTypeTf, LiteModelParserCreator<TFModelParser>)
} // namespace lite
} // namespace mindspore

View File

@ -29,7 +29,7 @@
#include "schema/inner/model_generated.h"
#include "securec/include/securec.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/model_parser.h"
#include "include/registry/model_parser.h"
#include "include/registry/model_parser_registry.h"
#include "ops/primitive_c.h"

View File

@ -734,5 +734,5 @@ int TfliteModelParser::Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_g
return RET_OK;
}
REG_MODEL_PARSER(kFmkTypeTflite, converter::LiteModelParserCreator<TfliteModelParser>)
REG_MODEL_PARSER(kFmkTypeTflite, LiteModelParserCreator<TfliteModelParser>)
} // namespace mindspore::lite

View File

@ -23,7 +23,7 @@
#include <set>
#include <map>
#include <utility>
#include "tools/converter/model_parser.h"
#include "include/registry/model_parser.h"
#include "include/registry/model_parser_registry.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
#include "tools/common/tensor_util.h"

View File

@ -68,7 +68,7 @@ STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPt
return RET_ERROR;
}
for (size_t i = 1; i < joint_fullconnect_size + 1; i++) {
auto tensor_addr = GetInputAddr(stack_node->input(i), 2);
auto tensor_addr = GetInputAddr(stack_node->input(i), kInputIndexTwo);
if (tensor_addr == nullptr) {
MS_LOG(ERROR) << "input tensor addr nullptr";
return RET_ERROR;
@ -178,7 +178,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
if (IsMarkedTrainOp(fullconnect_cnode)) {
return nullptr;
}
MS_CHECK_TRUE_RET(fullconnect_cnode->inputs().size() == 3, nullptr);
MS_CHECK_TRUE_RET(fullconnect_cnode->inputs().size() == kInputSizeThree, nullptr);
auto left_slice_node = fullconnect_cnode->input(1);
auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
if (IsMarkedTrainOp(left_slice_cnode)) {

View File

@ -31,6 +31,7 @@
namespace mindspore {
namespace opt {
namespace {
constexpr int kOffsetTwo = 2;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 69;
@ -162,7 +163,7 @@ const VectorRef TfBidirectionGruFusion::DefineFowardPattern() const {
MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
auto fw_while = VectorRef({is_while, fw_vars_[0], fw_vars_[1], is_param5, fw_stride, is_param6, fw_reserve,
fw_init_state_, fw_min, fw_from_tensor, input_length_});
fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end());
fw_while.insert(fw_while.end(), fw_vars_.begin() + kOffsetTwo, fw_vars_.end());
auto is_var1 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
fw_while.emplace_back(is_var1);
@ -232,7 +233,7 @@ const VectorRef TfBidirectionGruFusion::DefinebackwardPattern() const {
MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
auto bw_while = VectorRef({is_while, bw_vars_[0], bw_vars_[1], is_param5, bw_stride, is_param6, bw_reserve,
bw_init_state_, bw_min, bw_from_tensor, input_length_});
bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end());
bw_while.insert(bw_while.end(), bw_vars_.begin() + kOffsetTwo, bw_vars_.end());
auto is_var2 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
bw_while.emplace_back(is_var2);
@ -400,7 +401,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k
return RET_ERROR;
}
auto fw_cand_kernel_shape = fw_cand_kernel_value->shape();
if (fw_cand_kernel_shape.size() != 2) {
if (fw_cand_kernel_shape.size() != kInputSizeTwo) {
return RET_ERROR;
}
auto bw_cand_kernel_value = GetDefaultTensorInfo(bw_cand_kernel_anf);
@ -408,7 +409,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k
return RET_ERROR;
}
auto bw_cand_kernel_shape = bw_cand_kernel_value->shape();
if (bw_cand_kernel_shape.size() != 2) {
if (bw_cand_kernel_shape.size() != kInputSizeTwo) {
return RET_ERROR;
}
if (fw_cand_kernel_shape != bw_cand_kernel_shape) {

View File

@ -45,7 +45,7 @@ void SetConvAttr(const PrimitivePtr &prim, const std::vector<int64_t> &kernel_si
STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
if (cnode->size() < kInputSizeThree) {
MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << cnode->size() - 1;
MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << (cnode->size() - 1);
return lite::RET_ERROR;
}
auto weight = cnode->input(kInputIndexTwo);