draw out model_parser basic class
This commit is contained in:
parent
c8b9d45abc
commit
6f738d2e4a
|
@ -334,8 +334,6 @@ elseif(WIN32)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||
|
@ -477,8 +475,6 @@ else()
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||
|
|
|
@ -237,13 +237,13 @@ build_lite() {
|
|||
compile_nnie_script=${BASEPATH}/mindspore/lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh
|
||||
cd ${BASEPATH}/../
|
||||
if [[ "${local_lite_platform}" == "x86_64" ]]; then
|
||||
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master -j $THREAD_NUM
|
||||
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master_dev -j $THREAD_NUM
|
||||
if [[ $? -ne 0 ]]; then
|
||||
echo "compile x86_64 for nnie failed."
|
||||
exit 1
|
||||
fi
|
||||
elif [[ "${local_lite_platform}" == "arm32" ]]; then
|
||||
sh ${compile_nnie_script} -I arm32 -b nnie_3516_master -j $THREAD_NUM
|
||||
sh ${compile_nnie_script} -I arm32 -b nnie_3516_master_dev -j $THREAD_NUM
|
||||
if [[ $? -ne 0 ]]; then
|
||||
echo "compile arm32 for nnie failed."
|
||||
exit 1
|
||||
|
|
|
@ -14,41 +14,34 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_H
|
||||
#include <google/protobuf/message.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "ir/anf.h"
|
||||
#include "api/ir/func_graph.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_
|
||||
|
||||
namespace mindspore::converter {
|
||||
class ModelParser {
|
||||
#include "api/ir/func_graph.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace converter {
|
||||
/// \brief ModelParser defined a base class to parse model.
|
||||
class MS_API ModelParser {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
ModelParser() = default;
|
||||
|
||||
/// \brief Destructor.
|
||||
virtual ~ModelParser() = default;
|
||||
|
||||
/// \brief Method to parse model, which must be onnx/caffe/tf/tflite.
|
||||
///
|
||||
/// \param[in] flags Define the basic parameters when converting, which defined in parser_context.h.
|
||||
///
|
||||
/// \return FuncGraph Pointer, which contains all information about the model.
|
||||
virtual api::FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
|
||||
|
||||
protected:
|
||||
api::FuncGraphPtr res_graph_ = nullptr;
|
||||
};
|
||||
} // namespace converter
|
||||
} // namespace mindspore
|
||||
|
||||
typedef ModelParser *(*ModelParserCreator)();
|
||||
|
||||
template <class T>
|
||||
ModelParser *LiteModelParserCreator() {
|
||||
auto *parser = new (std::nothrow) T();
|
||||
if (parser == nullptr) {
|
||||
MS_LOG(ERROR) << "new model parser failed";
|
||||
return nullptr;
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
} // namespace mindspore::converter
|
||||
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_
|
|
@ -23,6 +23,9 @@
|
|||
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace converter {
|
||||
class ModelParser;
|
||||
} // namespace converter
|
||||
namespace registry {
|
||||
/// \brief ModelParserCreator defined function pointer to get a ModelParser class.
|
||||
typedef converter::ModelParser *(*ModelParserCreator)();
|
||||
|
|
|
@ -39,9 +39,6 @@ struct MS_API ConverterParameters {
|
|||
std::string weight_file;
|
||||
std::map<std::string, std::string> attrs;
|
||||
};
|
||||
|
||||
/// \brief ModelParser defined a base class of model parser
|
||||
class MS_API ModelParser;
|
||||
} // namespace converter
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/registry/model_parser.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "ut/tools/converter/registry/node_parser_test.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ModelParserTest : public converter::ModelParser {
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
#include "ops/fusion/add_fusion.h"
|
||||
#include "ops/addn.h"
|
||||
#include "ops/custom.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/optimizer_manager.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
|
|
|
@ -19,9 +19,9 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "include/registry/model_parser.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
|
|
|
@ -576,5 +576,5 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
|
|||
}
|
||||
return layer.name();
|
||||
}
|
||||
REG_MODEL_PARSER(kFmkTypeCaffe, converter::LiteModelParserCreator<CaffeModelParser>)
|
||||
REG_MODEL_PARSER(kFmkTypeCaffe, LiteModelParserCreator<CaffeModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <memory>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "include/registry/model_parser.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "proto/caffe.pb.h"
|
||||
#include "ops/primitive_c.h"
|
||||
|
|
|
@ -1248,6 +1248,6 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(kFmkTypeOnnx, converter::LiteModelParserCreator<OnnxModelParser>)
|
||||
REG_MODEL_PARSER(kFmkTypeOnnx, LiteModelParserCreator<OnnxModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "include/registry/model_parser.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "include/registry/model_parser.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
@ -39,6 +40,16 @@ int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_nod
|
|||
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
|
||||
int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
|
||||
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
|
||||
|
||||
template <class T>
|
||||
converter::ModelParser *LiteModelParserCreator() {
|
||||
auto *parser = new (std::nothrow) T();
|
||||
if (parser == nullptr) {
|
||||
MS_LOG(ERROR) << "new model parser failed";
|
||||
return nullptr;
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -1158,6 +1158,6 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(kFmkTypeTf, converter::LiteModelParserCreator<TFModelParser>)
|
||||
REG_MODEL_PARSER(kFmkTypeTf, LiteModelParserCreator<TFModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
#include "schema/inner/model_generated.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "include/registry/model_parser.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
|
|
|
@ -734,5 +734,5 @@ int TfliteModelParser::Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_g
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(kFmkTypeTflite, converter::LiteModelParserCreator<TfliteModelParser>)
|
||||
REG_MODEL_PARSER(kFmkTypeTflite, LiteModelParserCreator<TfliteModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include <set>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "include/registry/model_parser.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
|
|
@ -68,7 +68,7 @@ STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPt
|
|||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 1; i < joint_fullconnect_size + 1; i++) {
|
||||
auto tensor_addr = GetInputAddr(stack_node->input(i), 2);
|
||||
auto tensor_addr = GetInputAddr(stack_node->input(i), kInputIndexTwo);
|
||||
if (tensor_addr == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor addr nullptr";
|
||||
return RET_ERROR;
|
||||
|
@ -178,7 +178,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
|
|||
if (IsMarkedTrainOp(fullconnect_cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_RET(fullconnect_cnode->inputs().size() == 3, nullptr);
|
||||
MS_CHECK_TRUE_RET(fullconnect_cnode->inputs().size() == kInputSizeThree, nullptr);
|
||||
auto left_slice_node = fullconnect_cnode->input(1);
|
||||
auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
|
||||
if (IsMarkedTrainOp(left_slice_cnode)) {
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr int kOffsetTwo = 2;
|
||||
constexpr size_t kCondNodesNum = 12;
|
||||
constexpr size_t kCondCNodesNum = 4;
|
||||
constexpr size_t kBodyNodesNum = 69;
|
||||
|
@ -162,7 +163,7 @@ const VectorRef TfBidirectionGruFusion::DefineFowardPattern() const {
|
|||
MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
|
||||
auto fw_while = VectorRef({is_while, fw_vars_[0], fw_vars_[1], is_param5, fw_stride, is_param6, fw_reserve,
|
||||
fw_init_state_, fw_min, fw_from_tensor, input_length_});
|
||||
fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end());
|
||||
fw_while.insert(fw_while.end(), fw_vars_.begin() + kOffsetTwo, fw_vars_.end());
|
||||
auto is_var1 = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
|
||||
fw_while.emplace_back(is_var1);
|
||||
|
@ -232,7 +233,7 @@ const VectorRef TfBidirectionGruFusion::DefinebackwardPattern() const {
|
|||
MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
|
||||
auto bw_while = VectorRef({is_while, bw_vars_[0], bw_vars_[1], is_param5, bw_stride, is_param6, bw_reserve,
|
||||
bw_init_state_, bw_min, bw_from_tensor, input_length_});
|
||||
bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end());
|
||||
bw_while.insert(bw_while.end(), bw_vars_.begin() + kOffsetTwo, bw_vars_.end());
|
||||
auto is_var2 = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
|
||||
bw_while.emplace_back(is_var2);
|
||||
|
@ -400,7 +401,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto fw_cand_kernel_shape = fw_cand_kernel_value->shape();
|
||||
if (fw_cand_kernel_shape.size() != 2) {
|
||||
if (fw_cand_kernel_shape.size() != kInputSizeTwo) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto bw_cand_kernel_value = GetDefaultTensorInfo(bw_cand_kernel_anf);
|
||||
|
@ -408,7 +409,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto bw_cand_kernel_shape = bw_cand_kernel_value->shape();
|
||||
if (bw_cand_kernel_shape.size() != 2) {
|
||||
if (bw_cand_kernel_shape.size() != kInputSizeTwo) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (fw_cand_kernel_shape != bw_cand_kernel_shape) {
|
||||
|
|
|
@ -45,7 +45,7 @@ void SetConvAttr(const PrimitivePtr &prim, const std::vector<int64_t> &kernel_si
|
|||
STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (cnode->size() < kInputSizeThree) {
|
||||
MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << cnode->size() - 1;
|
||||
MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << (cnode->size() - 1);
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto weight = cnode->input(kInputIndexTwo);
|
||||
|
|
Loading…
Reference in New Issue