edit FmkType member name

This commit is contained in:
xuanyue 2021-08-17 15:08:42 +08:00
parent f5003a840a
commit feec548150
50 changed files with 111 additions and 111 deletions

View File

@ -26,11 +26,11 @@ namespace mindspore {
namespace converter {
/// \brief FmkType defined frameworks which converter tool supports.
enum MS_API FmkType : int {
FmkType_TF = 0,
FmkType_CAFFE = 1,
FmkType_ONNX = 2,
FmkType_MS = 3,
FmkType_TFLITE = 4,
kFmkTypeTf = 0,
kFmkTypeCaffe = 1,
kFmkTypeOnnx = 2,
kFmkTypeMs = 3,
kFmkTypeTflite = 4,
};
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
@ -42,7 +42,7 @@ struct MS_API ConverterParameters {
std::map<std::string, std::string> attrs;
};
/// \brief ModelParser defined a model parser
/// \brief ModelParser defined a base class of model parser
class MS_API ModelParser;
} // namespace converter
} // namespace mindspore

View File

@ -20,7 +20,7 @@
#include "tools/optimizer/common/gllo_utils.h"
using mindspore::converter::ConverterParameters;
using mindspore::converter::FmkType_CAFFE;
using mindspore::converter::kFmkTypeCaffe;
namespace mindspore {
class ModelParserRegistryTest : public mindspore::CommonTest {
public:
@ -33,9 +33,9 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
ASSERT_NE(add_parser, nullptr);
auto proposal_parser = node_parser_reg->GetNodeParser("proposal");
ASSERT_NE(proposal_parser, nullptr);
REG_MODEL_PARSER(FmkType_CAFFE,
REG_MODEL_PARSER(kFmkTypeCaffe,
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
auto model_parser = registry::ModelParserRegistry::GetModelParser(kFmkTypeCaffe);
ASSERT_NE(model_parser, nullptr);
ConverterParameters converter_parameters;
auto func_graph = model_parser->Parse(converter_parameters);

View File

@ -29,15 +29,15 @@
#include "ut/tools/converter/registry/model_parser_test.h"
using mindspore::converter::ConverterParameters;
using mindspore::converter::FmkType_CAFFE;
using mindspore::converter::kFmkTypeCaffe;
using mindspore::registry::POSITION_BEGIN;
namespace mindspore {
class PassRegistryTest : public mindspore::CommonTest {
public:
PassRegistryTest() = default;
void SetUp() override {
REG_MODEL_PARSER(FmkType_CAFFE, TestModelParserCreator);
auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
REG_MODEL_PARSER(kFmkTypeCaffe, TestModelParserCreator);
auto model_parser = registry::ModelParserRegistry::GetModelParser(kFmkTypeCaffe);
if (model_parser == nullptr) {
return;
}

View File

@ -77,12 +77,12 @@ STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, Shap
}
int GetFormatByFmk(int32_t fmk_type) {
switch (fmk_type) {
case converter::FmkType_ONNX:
case converter::FmkType_CAFFE:
case converter::FmkType_MS:
case converter::kFmkTypeOnnx:
case converter::kFmkTypeCaffe:
case converter::kFmkTypeMs:
return mindspore::NCHW;
case converter::FmkType_TF:
case converter::FmkType_TFLITE:
case converter::kFmkTypeTf:
case converter::kFmkTypeTflite:
return mindspore::NHWC;
default:
return -1;

View File

@ -115,7 +115,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
}
if (config->fmk == converter::FmkType_MS) {
if (config->fmk == converter::kFmkTypeMs) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
if (remove_unused_cast_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
@ -195,8 +195,8 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
if (config->fmk == converter::FmkType_TFLITE || config->fmk == converter::FmkType_TF ||
config->fmk == converter::FmkType_ONNX) {
if (config->fmk == converter::kFmkTypeTflite || config->fmk == converter::kFmkTypeTf ||
config->fmk == converter::kFmkTypeOnnx) {
graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
}
auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();

View File

@ -42,7 +42,7 @@ void InitConverterParameters(const converter::Flags &flag, converter::ConverterP
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
FuncGraphPtr func_graph = nullptr;
if (flag.fmk == converter::FmkType::FmkType_MS) {
if (flag.fmk == converter::FmkType::kFmkTypeMs) {
kernel::PopulateTrainParameters();
MindsporeImporter ms_import;
func_graph = ms_import.ImportMindIR(flag);

View File

@ -114,21 +114,21 @@ int Flags::InitInputOutputDataType() {
int Flags::InitFmk() {
if (this->fmkIn == "CAFFE") {
this->fmk = FmkType_CAFFE;
this->fmk = kFmkTypeCaffe;
} else if (this->fmkIn == "MINDIR") {
this->fmk = FmkType_MS;
this->fmk = kFmkTypeMs;
} else if (this->fmkIn == "TFLITE") {
this->fmk = FmkType_TFLITE;
this->fmk = kFmkTypeTflite;
} else if (this->fmkIn == "ONNX") {
this->fmk = FmkType_ONNX;
this->fmk = kFmkTypeOnnx;
} else if (this->fmkIn == "TF") {
this->fmk = FmkType_TF;
this->fmk = kFmkTypeTf;
} else {
std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl;
return RET_INPUT_PARAM_INVALID;
}
if (this->fmk != FmkType_CAFFE && !weightFile.empty()) {
if (this->fmk != kFmkTypeCaffe && !weightFile.empty()) {
std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag" << std::endl;
return RET_INPUT_PARAM_INVALID;
}
@ -196,7 +196,7 @@ int Flags::InitTrainModel() {
}
if (this->trainModel) {
if (this->fmk != FmkType_MS) {
if (this->fmk != kFmkTypeMs) {
std::cerr << "INPUT ILLEGAL: train model converter supporting only MINDIR format" << std::endl;
return RET_INPUT_PARAM_INVALID;
}

View File

@ -199,8 +199,8 @@ STATUS ExportModel(const FuncGraphPtr &graph) {
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
if (flags->fmk == converter::FmkType_TFLITE || flags->fmk == converter::FmkType_TF ||
flags->fmk == converter::FmkType_ONNX) {
if (flags->fmk == converter::kFmkTypeTflite || flags->fmk == converter::kFmkTypeTf ||
flags->fmk == converter::kFmkTypeOnnx) {
graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
}
optimizer->AddPassManager(graph_pm);

View File

@ -74,7 +74,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer format_trans_optimizer;
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
if (!ctx.trainModel && ctx.fmk != converter::kFmkTypeOnnx) {
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
}
@ -117,7 +117,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
// quantization
if (ctx.fmk != converter::FmkType_TF) {
if (ctx.fmk != converter::kFmkTypeTf) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer tensor_quant_optimizer;
@ -134,7 +134,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
// quantization
if (ctx.fmk != converter::FmkType_TF) {
if (ctx.fmk != converter::kFmkTypeTf) {
// init old node indices
Optimizer quant_node_optimizer;
quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());

View File

@ -257,7 +257,7 @@ int MindirAdjust::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
}
bool MindirAdjust::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type_ != converter::FmkType_MS) {
if (this->fmk_type_ != converter::kFmkTypeMs) {
MS_LOG(INFO) << "The framework type of model should be mindir.";
return lite::RET_OK;
}

View File

@ -38,7 +38,7 @@ class MindirAdjust {
int ComputeQuantParams(AnfNodePtr anf_node);
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
FmkType fmk_type_ = FmkType::FmkType_MS;
FmkType fmk_type_ = FmkType::kFmkTypeMs;
bool train_flag_ = false;
};
} // namespace mindspore::lite

View File

@ -113,14 +113,14 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
return nullptr;
}
func_graph->set_attr("graph_name", MakeValue("main_graph"));
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeMs)));
STATUS status;
if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
MS_LOG(ERROR) << "Mindir2AnfAdjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_MS, flag.trainModel, flag.quantType);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel, flag.quantType);
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;

View File

@ -519,7 +519,7 @@ int MoveAttrMapResizeGrad(const CNodePtr &cnode) {
} // namespace
bool PrimitiveAdjust::Run(const FuncGraphPtr &func_graphs) {
if (this->fmk_type_ != converter::FmkType_MS) {
if (this->fmk_type_ != converter::kFmkTypeMs) {
MS_LOG(INFO) << "The framework type of model should be mindir.";
return lite::RET_OK;
}

View File

@ -71,7 +71,7 @@ class PrimitiveAdjust {
bool Run(const FuncGraphPtr &func_graph);
protected:
FmkType fmk_type_ = FmkType::FmkType_MS;
FmkType fmk_type_ = FmkType::kFmkTypeMs;
};
} // namespace lite
} // namespace mindspore

View File

@ -240,7 +240,7 @@ STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, BNWeight
MS_ASSERT(graph->allTensors.size() > bnNode->inputIndex.at(1));
auto bnWeightTensorIdxes = bnNode->inputIndex;
bnWeightTensorIdxes.erase(bnWeightTensorIdxes.begin());
if (fmkType == converter::FmkType_CAFFE) {
if (fmkType == converter::kFmkTypeCaffe) {
bnWeightTensors->meanTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_MEAN_INDEX]).get();
bnWeightTensors->varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_VARIANCE_INDEX]).get();
} else {
@ -258,7 +258,7 @@ STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, BNWeight
MS_LOG(ERROR) << "BatchNorm's variance tensor is nullptr";
return RET_ERROR;
}
if (fmkType == converter::FmkType_CAFFE) {
if (fmkType == converter::kFmkTypeCaffe) {
auto scaleTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_SCALE_INDEX]).get();
// calibrate mean and variance
float scale_factor_data = (reinterpret_cast<float *>(scaleTensor->data.data()))[0];

View File

@ -61,7 +61,7 @@ class BatchNormConvertScalePass : public GraphPass {
float *transBias = nullptr;
std::unique_ptr<TensorT> newScaleWeightTensor = nullptr;
std::unique_ptr<TensorT> newScaleBiasTensor = nullptr;
converter::FmkType fmkType = converter::FmkType_TF;
converter::FmkType fmkType = converter::kFmkTypeTf;
};
} // namespace lite
} // namespace mindspore

View File

@ -30,7 +30,7 @@
#include "tools/converter/converter_flags.h"
#include "src/common/string_util.h"
using mindspore::converter::FmkType_TF;
using mindspore::converter::kFmkTypeTf;
namespace mindspore {
namespace lite {
namespace {

View File

@ -26,7 +26,7 @@
#include "tools/converter/optimizer.h"
#include "tools/converter/converter_flags.h"
using mindspore::converter::FmkType_TF;
using mindspore::converter::kFmkTypeTf;
using mindspore::schema::TensorT;
namespace mindspore {
namespace lite {
@ -59,7 +59,7 @@ class InferShapePass : public GraphPass {
void InitInferTensor(MetaGraphT *graph);
int InferSubgraph(const int &subgraph_index, MetaGraphT *graph);
converter::FmkType fmk_type_ = FmkType_TF;
converter::FmkType fmk_type_ = kFmkTypeTf;
std::vector<InferTensor> tensors_ = {};
};
} // namespace lite

View File

@ -33,7 +33,7 @@
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/parser/unify_format.h"
using mindspore::converter::FmkType_CAFFE;
using mindspore::converter::kFmkTypeCaffe;
namespace mindspore::lite {
namespace {
namespace {
@ -104,7 +104,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
return nullptr;
}
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeCaffe)));
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
@ -112,7 +112,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_CAFFE, false, quant_type_);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeCaffe, false, quant_type_);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
@ -555,5 +555,5 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
}
return layer.name();
}
REG_MODEL_PARSER(FmkType_CAFFE, converter::LiteModelParserCreator<CaffeModelParser>)
REG_MODEL_PARSER(kFmkTypeCaffe, converter::LiteModelParserCreator<CaffeModelParser>)
} // namespace mindspore::lite

View File

@ -151,13 +151,13 @@ int ReplaceLstmNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func
auto lstm_weight_node = GetRealLstmWeightNode(func_graph, lstm_cnode, kWeightIndex);
lite::DataInfo data_info;
if (lstm_weight_node->isa<Parameter>()) {
auto ret = FetchDataFromParameterNode(lstm_cnode, kWeightIndex, converter::FmkType_MS, false, &data_info);
auto ret = FetchDataFromParameterNode(lstm_cnode, kWeightIndex, converter::kFmkTypeMs, false, &data_info);
if (ret != RET_OK) {
MS_LOG(ERROR) << "parse const node failed.";
return RET_ERROR;
}
} else if (lstm_weight_node->isa<ValueNode>()) {
auto ret = FetchDataFromValueNode(lstm_cnode, kWeightIndex, converter::FmkType_MS, false, &data_info);
auto ret = FetchDataFromValueNode(lstm_cnode, kWeightIndex, converter::kFmkTypeMs, false, &data_info);
if (ret != RET_OK) {
MS_LOG(ERROR) << "parse const node failed.";
return RET_ERROR;

View File

@ -37,7 +37,7 @@
#include "ops/transpose.h"
#include "tools/converter/parser/unify_format.h"
using mindspore::converter::FmkType_ONNX;
using mindspore::converter::kFmkTypeOnnx;
namespace mindspore {
namespace lite {
namespace {
@ -79,10 +79,10 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
static auto root_func_manager = Manage(res_graph_);
for (auto &subgraph : all_subgraphs_) {
subgraph->set_manager(root_func_manager);
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
}
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
@ -95,7 +95,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_ONNX, false, quant_type_);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeOnnx, false, quant_type_);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
@ -118,7 +118,7 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
}
OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
onnx_root_graph_ = onnx_model_.graph();
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
return RET_OK;
}
STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
@ -1253,6 +1253,6 @@ int OnnxModelParser::Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graph
return RET_OK;
}
REG_MODEL_PARSER(FmkType_ONNX, converter::LiteModelParserCreator<OnnxModelParser>)
REG_MODEL_PARSER(kFmkTypeOnnx, converter::LiteModelParserCreator<OnnxModelParser>)
} // namespace lite
} // namespace mindspore

View File

@ -133,7 +133,7 @@ STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::
}
FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, converter::FmkType_TF);
auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, converter::kFmkTypeTf);
if (graph == nullptr) {
MS_LOG(ERROR) << "new graph Partial Node return nullptr";
return nullptr;

View File

@ -367,7 +367,7 @@ STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
STATUS FunctionalizeWhile::BuildCondGraph() {
cond_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_cond";
cond_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, converter::FmkType_TF);
cond_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, converter::kFmkTypeTf);
if (cond_sub_func_graph_ == nullptr) {
MS_LOG(ERROR) << "new cond_sub_func_graph_ return nullptr";
return RET_NULL_PTR;
@ -522,7 +522,7 @@ STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
STATUS FunctionalizeWhile::BuildBodyGraph() {
body_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_body";
body_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, converter::FmkType_TF);
body_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, converter::kFmkTypeTf);
if (body_sub_func_graph_ == nullptr) {
MS_LOG(ERROR) << "new body_sub_func_graph_ return nullptr";
return RET_NULL_PTR;

View File

@ -35,7 +35,7 @@
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/unify_format.h"
using mindspore::converter::FmkType_TF;
using mindspore::converter::kFmkTypeTf;
namespace mindspore {
namespace lite {
namespace {
@ -515,7 +515,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
return nullptr;
}
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
@ -576,7 +576,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_TF, false, quant_type_);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTf, false, quant_type_);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
@ -727,7 +727,7 @@ STATUS TFModelParser::ConvertSubgraph() {
FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>();
sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name));
sub_func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
sub_func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map;
std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map;
@ -1136,6 +1136,6 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
return RET_OK;
}
REG_MODEL_PARSER(FmkType_TF, converter::LiteModelParserCreator<TFModelParser>)
REG_MODEL_PARSER(kFmkTypeTf, converter::LiteModelParserCreator<TFModelParser>)
} // namespace lite
} // namespace mindspore

View File

@ -32,7 +32,7 @@
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/unify_format.h"
using mindspore::converter::FmkType_TFLITE;
using mindspore::converter::kFmkTypeTflite;
namespace mindspore::lite {
namespace {
constexpr size_t kConvWeightIndex = 2;
@ -69,7 +69,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
return nullptr;
}
res_graph_ = std::make_shared<FuncGraph>();
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE)));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTflite)));
auto status = ConvertGraphInputs();
if (status != RET_OK) {
@ -105,7 +105,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_TFLITE, false, quant_type_);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTflite, false, quant_type_);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
@ -546,5 +546,5 @@ int TfliteModelParser::Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_g
return RET_OK;
}
REG_MODEL_PARSER(FmkType_TFLITE, converter::LiteModelParserCreator<TfliteModelParser>)
REG_MODEL_PARSER(kFmkTypeTflite, converter::LiteModelParserCreator<TfliteModelParser>)
} // namespace mindspore::lite

View File

@ -176,13 +176,13 @@ STATUS UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::Tra
MS_ASSERT(prim != nullptr);
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
if (fmk_type_ == converter::FmkType_TFLITE) {
if (fmk_type_ == converter::kFmkTypeTflite) {
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
return lite::RET_OK;
}
trans_info->pre_ = opt::kNHWC2NCHW;
trans_info->post_ = opt::kNCHW2NHWC;
} else if (fmk_type_ == converter::FmkType_TF) {
} else if (fmk_type_ == converter::kFmkTypeTf) {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
trans_info->pre_ = opt::kNCHW2NHWC;
trans_info->post_ = opt::kNHWC2NCHW;
@ -193,7 +193,7 @@ STATUS UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::Tra
}
} else {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
if (fmk_type_ == converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
if (fmk_type_ == converter::kFmkTypeOnnx && prim->GetAttr(ops::kFormat) != nullptr &&
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
return lite::RET_OK;
}
@ -213,10 +213,10 @@ void UnifyFormatToNHWC::SetSensitiveOps() {
bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) {
MS_ASSERT(func_graph != nullptr);
if (fmk_type_ == converter::FmkType_TF || fmk_type_ == converter::FmkType_TFLITE) {
if (fmk_type_ == converter::kFmkTypeTf || fmk_type_ == converter::kFmkTypeTflite) {
return false;
}
if (func_graph->get_inputs().size() == 1 && fmk_type_ == converter::FmkType_ONNX &&
if (func_graph->get_inputs().size() == 1 && fmk_type_ == converter::kFmkTypeOnnx &&
shape[opt::kInputIndexThree] == kInputChannal && shape[1] == -1) {
return false;
}
@ -230,11 +230,11 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode,
MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
*dst_format = schema::Format_KHWC;
std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::QuantType, schema::Format *)>>
decide_functions = {{converter::FmkType_MS, DecideMINDIRConvWeightSrcFormat},
{converter::FmkType_TF, DecideTFConvWeightSrcFormat},
{converter::FmkType_TFLITE, DecideTFLITEConvWeightSrcFormat},
{converter::FmkType_CAFFE, DecideCAFFEConvWeightSrcFormat},
{converter::FmkType_ONNX, DecideONNXConvWeightSrcFormat}};
decide_functions = {{converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat},
{converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
{converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
{converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
auto iter = decide_functions.find(fmk_type_);
if (iter == decide_functions.end()) {
MS_LOG(ERROR) << "current fmk don't support, please check.";

View File

@ -24,7 +24,7 @@ namespace mindspore {
namespace lite {
class UnifyFormatToNHWC : public opt::ToFormatBase {
public:
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false,
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
schema::QuantType quant_type = schema::QuantType_QUANT_NONE)
: ToFormatBase(fmk_type, train_flag), quant_type_(quant_type) {}
~UnifyFormatToNHWC() override = default;

View File

@ -25,7 +25,7 @@ std::map<FmkType, ModelParserCreator> model_parser_room;
} // namespace
ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) {
if (fmk < converter::FmkType_TF || fmk > converter::FmkType_TFLITE) {
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
return;
}

View File

@ -136,9 +136,9 @@ STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
lite::DataInfo data_info;
int status;
if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::FmkType_MS, false, &data_info);
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
} else {
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::FmkType_MS, false, &data_info);
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
}
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "fetch transpose perm data failed.";

View File

@ -172,8 +172,8 @@ bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePt
}
auto splited_axis = split_info->axis;
// need to check
if (split_info->fmk_type == FmkType::FmkType_CAFFE ||
split_info->fmk_type == FmkType::FmkType_ONNX) { // NHWC -> NCHW
if (split_info->fmk_type == FmkType::kFmkTypeCaffe ||
split_info->fmk_type == FmkType::kFmkTypeOnnx) { // NHWC -> NCHW
splited_axis += 1;
}

View File

@ -32,7 +32,7 @@ namespace mindspore {
namespace opt {
class ToFormatBase : public Pass {
public:
explicit ToFormatBase(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false,
explicit ToFormatBase(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
std::string pass_name = "to_format_base")
: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {}
~ToFormatBase() override = default;
@ -56,7 +56,7 @@ class ToFormatBase : public Pass {
virtual bool DecideWhetherInferShapeForNewNode() { return true; }
virtual STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
schema::Format *dst_format) = 0;
FmkType fmk_type_{converter::FmkType_MS};
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
mindspore::Format format_{mindspore::NHWC};
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};

View File

@ -23,7 +23,7 @@ namespace mindspore {
namespace opt {
class ToNCHWFormat : public ToFormatBase {
public:
explicit ToNCHWFormat(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
explicit ToNCHWFormat(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: ToFormatBase(fmk_type, train_flag, "to_nchw_format") {
format_ = mindspore::NCHW;
}

View File

@ -23,7 +23,7 @@ namespace mindspore {
namespace opt {
class ToNHWCFormat : public ToFormatBase {
public:
explicit ToNHWCFormat(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
explicit ToNHWCFormat(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: ToFormatBase(fmk_type, train_flag, "to_nhwc_format") {}
~ToNHWCFormat() = default;

View File

@ -31,7 +31,7 @@ namespace mindspore {
namespace opt {
class ConstFoldPass : public PatternProcessPass {
public:
explicit ConstFoldPass(converter::FmkType fmk_type = converter::FmkType_MS, bool multigraph = true)
explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool multigraph = true)
: PatternProcessPass("constfold_pass", multigraph), fmk_type_(fmk_type) {
context_ = std::make_shared<lite::InnerContext>();
context_->Init();
@ -41,7 +41,7 @@ class ConstFoldPass : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
converter::FmkType fmk_type_{converter::FmkType_MS};
converter::FmkType fmk_type_{converter::kFmkTypeMs};
std::shared_ptr<lite::InnerContext> context_{nullptr};
std::shared_ptr<mindspore::Context> ms_context_{nullptr};
};

View File

@ -37,7 +37,7 @@ class ConvTransformFusion : public PatternProcessPass {
void SetFmkType(FmkType type) { this->fmk_type_ = type; }
private:
FmkType fmk_type_ = converter::FmkType_TF;
FmkType fmk_type_ = converter::kFmkTypeTf;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_

View File

@ -199,7 +199,7 @@ int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::ve
*after_fg = std::make_shared<FuncGraph>();
auto manager = main_fg->manager();
manager->AddFuncGraph(*after_fg);
(*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
(*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
(*after_fg)->set_attr("graph_name", MakeValue(aim_cnode->fullname_with_scope() + "_after_fg"));
(*after_fg)->set_manager(main_fg->manager());

View File

@ -33,7 +33,7 @@ namespace mindspore {
namespace opt {
class DecreaseTransposeAlgo : public Pass {
public:
explicit DecreaseTransposeAlgo(FmkType fmk_type = FmkType::FmkType_MS, bool train_flag = false)
explicit DecreaseTransposeAlgo(FmkType fmk_type = FmkType::kFmkTypeMs, bool train_flag = false)
: Pass("DecreaseTransposeAlgo"), fmk_type_(fmk_type), train_flag_(train_flag) {}
~DecreaseTransposeAlgo() override = default;
void Init(FmkType fmk_type, bool train_flag) {
@ -63,7 +63,7 @@ class DecreaseTransposeAlgo : public Pass {
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type);
FmkType fmk_type_{converter::FmkType_MS};
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
NodeInferShape node_infer_shape_;
TransposeStrategy transpose_strategy_;

View File

@ -27,7 +27,7 @@ namespace mindspore {
namespace opt {
class InferShapePass : public Pass {
public:
explicit InferShapePass(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
explicit InferShapePass(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: Pass("infer_shape"), fmk_type_(fmk_type), train_flag_(train_flag) {}
~InferShapePass() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
@ -40,7 +40,7 @@ class InferShapePass : public Pass {
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void ResetSubGraphInput();
FmkType fmk_type_{converter::FmkType_MS};
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
std::map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;

View File

@ -45,7 +45,7 @@ void FreeTensors(std::vector<lite::Tensor *> *tensors) {
void RectifyFormat(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
MS_ASSERT(cnode != nullptr);
if (fmk_type != converter::FmkType_ONNX) {
if (fmk_type != converter::kFmkTypeOnnx) {
return;
}
for (auto &input : inputs) {

View File

@ -32,7 +32,7 @@ namespace mindspore {
namespace opt {
class NodeInferShape {
public:
explicit NodeInferShape(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
explicit NodeInferShape(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: fmk_type_(fmk_type), train_flag_(train_flag) {}
virtual ~NodeInferShape() = default;
void Init(FmkType fmk_type, bool train_flag) {
@ -54,7 +54,7 @@ class NodeInferShape {
STATUS SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs, int status);
abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor);
abstract::AbstractBasePtr ConvertTensorListToAbstract(lite::Tensor *tensor);
FmkType fmk_type_{converter::FmkType_MS};
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
};
} // namespace opt

View File

@ -261,13 +261,13 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
auto padding_node = cnode->input(kInputIndexTwo);
lite::DataInfo data_info;
if (utils::isa<Parameter>(padding_node)) {
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::FmkType_MS, false, &data_info);
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, &data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from parameter node failed.";
return lite::RET_ERROR;
}
} else if (utils::isa<ValueNode>(padding_node)) {
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::FmkType_MS, false, &data_info);
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, &data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from value node failed.";
return lite::RET_ERROR;

View File

@ -1411,7 +1411,7 @@ bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slic
}
bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
if (fmk_type != converter::FmkType_TF && fmk_type != converter::FmkType_TFLITE) {
if (fmk_type != converter::kFmkTypeTf && fmk_type != converter::kFmkTypeTflite) {
MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
return false;
}

View File

@ -95,7 +95,7 @@ class SlicePreposePass : public Pass {
static bool MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices);
private:
FmkType fmk_type = converter::FmkType_ONNX;
FmkType fmk_type = converter::kFmkTypeOnnx;
};
} // namespace mindspore::opt

View File

@ -58,7 +58,7 @@ class TransposeStrategy {
void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
const std::vector<int> &axes, FormatTransNodeType trans_type);
std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type);
FmkType fmk_type_{converter::FmkType_MS};
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
NodeInferShape node_infer_shape_;
};

View File

@ -22,7 +22,7 @@ constexpr size_t kCastInputNum = 3;
void RemoveUnusedCastOpPass::SetFmkType(FmkType type) { this->fmk_type = type; }
bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type != converter::FmkType_MS) {
if (this->fmk_type != converter::kFmkTypeMs) {
MS_LOG(ERROR) << "The framework type of model should be mindspore.";
return RET_ERROR;
}

View File

@ -30,7 +30,7 @@ class RemoveUnusedCastOpPass : public Pass {
bool Run(const FuncGraphPtr &graph) override;
private:
FmkType fmk_type = converter::FmkType_TF;
FmkType fmk_type = converter::kFmkTypeTf;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_

View File

@ -57,7 +57,7 @@ std::vector<int> GetTransposePerm(const CNodePtr &node) {
}
bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type != converter::FmkType_ONNX) {
if (this->fmk_type != converter::kFmkTypeOnnx) {
MS_LOG(ERROR) << "The framework type of model should be onnx.";
return RET_ERROR;
}

View File

@ -30,7 +30,7 @@ class RemoveUnusedTransposeOpPass : public Pass {
bool Run(const FuncGraphPtr &graph) override;
private:
FmkType fmk_type = converter::FmkType_TF;
FmkType fmk_type = converter::kFmkTypeTf;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_

View File

@ -30,7 +30,7 @@ constexpr int kAnfPopulaterInputNumTwo = 2;
lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
if (fmk_type_ != converter::FmkType_TF) {
if (fmk_type_ != converter::kFmkTypeTf) {
return lite::RET_OK;
}
auto conv = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode->input(0));

View File

@ -33,7 +33,7 @@ class UpdateConv2DParamPass : public Pass {
void SetFmkType(FmkType fmk_type) { this->fmk_type_ = fmk_type; }
private:
FmkType fmk_type_ = converter::FmkType_ONNX;
FmkType fmk_type_ = converter::kFmkTypeOnnx;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_