forked from mindspore-Ecosystem/mindspore
edit FmkType member name
This commit is contained in:
parent
f5003a840a
commit
feec548150
|
@ -26,11 +26,11 @@ namespace mindspore {
|
|||
namespace converter {
|
||||
/// \brief FmkType defined frameworks which converter tool supports.
|
||||
enum MS_API FmkType : int {
|
||||
FmkType_TF = 0,
|
||||
FmkType_CAFFE = 1,
|
||||
FmkType_ONNX = 2,
|
||||
FmkType_MS = 3,
|
||||
FmkType_TFLITE = 4,
|
||||
kFmkTypeTf = 0,
|
||||
kFmkTypeCaffe = 1,
|
||||
kFmkTypeOnnx = 2,
|
||||
kFmkTypeMs = 3,
|
||||
kFmkTypeTflite = 4,
|
||||
};
|
||||
|
||||
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
|
||||
|
@ -42,7 +42,7 @@ struct MS_API ConverterParameters {
|
|||
std::map<std::string, std::string> attrs;
|
||||
};
|
||||
|
||||
/// \brief ModelParser defined a model parser
|
||||
/// \brief ModelParser defined a base class of model parser
|
||||
class MS_API ModelParser;
|
||||
} // namespace converter
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::converter::ConverterParameters;
|
||||
using mindspore::converter::FmkType_CAFFE;
|
||||
using mindspore::converter::kFmkTypeCaffe;
|
||||
namespace mindspore {
|
||||
class ModelParserRegistryTest : public mindspore::CommonTest {
|
||||
public:
|
||||
|
@ -33,9 +33,9 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
|
|||
ASSERT_NE(add_parser, nullptr);
|
||||
auto proposal_parser = node_parser_reg->GetNodeParser("proposal");
|
||||
ASSERT_NE(proposal_parser, nullptr);
|
||||
REG_MODEL_PARSER(FmkType_CAFFE,
|
||||
REG_MODEL_PARSER(kFmkTypeCaffe,
|
||||
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
|
||||
auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
|
||||
auto model_parser = registry::ModelParserRegistry::GetModelParser(kFmkTypeCaffe);
|
||||
ASSERT_NE(model_parser, nullptr);
|
||||
ConverterParameters converter_parameters;
|
||||
auto func_graph = model_parser->Parse(converter_parameters);
|
||||
|
|
|
@ -29,15 +29,15 @@
|
|||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
|
||||
using mindspore::converter::ConverterParameters;
|
||||
using mindspore::converter::FmkType_CAFFE;
|
||||
using mindspore::converter::kFmkTypeCaffe;
|
||||
using mindspore::registry::POSITION_BEGIN;
|
||||
namespace mindspore {
|
||||
class PassRegistryTest : public mindspore::CommonTest {
|
||||
public:
|
||||
PassRegistryTest() = default;
|
||||
void SetUp() override {
|
||||
REG_MODEL_PARSER(FmkType_CAFFE, TestModelParserCreator);
|
||||
auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
|
||||
REG_MODEL_PARSER(kFmkTypeCaffe, TestModelParserCreator);
|
||||
auto model_parser = registry::ModelParserRegistry::GetModelParser(kFmkTypeCaffe);
|
||||
if (model_parser == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -77,12 +77,12 @@ STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, Shap
|
|||
}
|
||||
int GetFormatByFmk(int32_t fmk_type) {
|
||||
switch (fmk_type) {
|
||||
case converter::FmkType_ONNX:
|
||||
case converter::FmkType_CAFFE:
|
||||
case converter::FmkType_MS:
|
||||
case converter::kFmkTypeOnnx:
|
||||
case converter::kFmkTypeCaffe:
|
||||
case converter::kFmkTypeMs:
|
||||
return mindspore::NCHW;
|
||||
case converter::FmkType_TF:
|
||||
case converter::FmkType_TFLITE:
|
||||
case converter::kFmkTypeTf:
|
||||
case converter::kFmkTypeTflite:
|
||||
return mindspore::NHWC;
|
||||
default:
|
||||
return -1;
|
||||
|
|
|
@ -115,7 +115,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
|||
fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
|
||||
}
|
||||
if (config->fmk == converter::FmkType_MS) {
|
||||
if (config->fmk == converter::kFmkTypeMs) {
|
||||
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
|
||||
if (remove_unused_cast_pass == nullptr) {
|
||||
MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
|
||||
|
@ -195,8 +195,8 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
|
|||
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
||||
if (config->fmk == converter::FmkType_TFLITE || config->fmk == converter::FmkType_TF ||
|
||||
config->fmk == converter::FmkType_ONNX) {
|
||||
if (config->fmk == converter::kFmkTypeTflite || config->fmk == converter::kFmkTypeTf ||
|
||||
config->fmk == converter::kFmkTypeOnnx) {
|
||||
graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
|
||||
}
|
||||
auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
|
||||
|
|
|
@ -42,7 +42,7 @@ void InitConverterParameters(const converter::Flags &flag, converter::ConverterP
|
|||
|
||||
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
||||
FuncGraphPtr func_graph = nullptr;
|
||||
if (flag.fmk == converter::FmkType::FmkType_MS) {
|
||||
if (flag.fmk == converter::FmkType::kFmkTypeMs) {
|
||||
kernel::PopulateTrainParameters();
|
||||
MindsporeImporter ms_import;
|
||||
func_graph = ms_import.ImportMindIR(flag);
|
||||
|
|
|
@ -114,21 +114,21 @@ int Flags::InitInputOutputDataType() {
|
|||
|
||||
int Flags::InitFmk() {
|
||||
if (this->fmkIn == "CAFFE") {
|
||||
this->fmk = FmkType_CAFFE;
|
||||
this->fmk = kFmkTypeCaffe;
|
||||
} else if (this->fmkIn == "MINDIR") {
|
||||
this->fmk = FmkType_MS;
|
||||
this->fmk = kFmkTypeMs;
|
||||
} else if (this->fmkIn == "TFLITE") {
|
||||
this->fmk = FmkType_TFLITE;
|
||||
this->fmk = kFmkTypeTflite;
|
||||
} else if (this->fmkIn == "ONNX") {
|
||||
this->fmk = FmkType_ONNX;
|
||||
this->fmk = kFmkTypeOnnx;
|
||||
} else if (this->fmkIn == "TF") {
|
||||
this->fmk = FmkType_TF;
|
||||
this->fmk = kFmkTypeTf;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (this->fmk != FmkType_CAFFE && !weightFile.empty()) {
|
||||
if (this->fmk != kFmkTypeCaffe && !weightFile.empty()) {
|
||||
std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag" << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
@ -196,7 +196,7 @@ int Flags::InitTrainModel() {
|
|||
}
|
||||
|
||||
if (this->trainModel) {
|
||||
if (this->fmk != FmkType_MS) {
|
||||
if (this->fmk != kFmkTypeMs) {
|
||||
std::cerr << "INPUT ILLEGAL: train model converter supporting only MINDIR format" << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
|
|
@ -199,8 +199,8 @@ STATUS ExportModel(const FuncGraphPtr &graph) {
|
|||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
||||
if (flags->fmk == converter::FmkType_TFLITE || flags->fmk == converter::FmkType_TF ||
|
||||
flags->fmk == converter::FmkType_ONNX) {
|
||||
if (flags->fmk == converter::kFmkTypeTflite || flags->fmk == converter::kFmkTypeTf ||
|
||||
flags->fmk == converter::kFmkTypeOnnx) {
|
||||
graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
|
||||
}
|
||||
optimizer->AddPassManager(graph_pm);
|
||||
|
|
|
@ -74,7 +74,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer format_trans_optimizer;
|
||||
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
|
||||
if (!ctx.trainModel && ctx.fmk != converter::kFmkTypeOnnx) {
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
}
|
||||
|
||||
// quantization
|
||||
if (ctx.fmk != converter::FmkType_TF) {
|
||||
if (ctx.fmk != converter::kFmkTypeTf) {
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer tensor_quant_optimizer;
|
||||
|
@ -134,7 +134,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
}
|
||||
|
||||
// quantization
|
||||
if (ctx.fmk != converter::FmkType_TF) {
|
||||
if (ctx.fmk != converter::kFmkTypeTf) {
|
||||
// init old node indices
|
||||
Optimizer quant_node_optimizer;
|
||||
quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
|
|
|
@ -257,7 +257,7 @@ int MindirAdjust::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
|
|||
}
|
||||
|
||||
bool MindirAdjust::Run(const FuncGraphPtr &func_graph) {
|
||||
if (this->fmk_type_ != converter::FmkType_MS) {
|
||||
if (this->fmk_type_ != converter::kFmkTypeMs) {
|
||||
MS_LOG(INFO) << "The framework type of model should be mindir.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ class MindirAdjust {
|
|||
int ComputeQuantParams(AnfNodePtr anf_node);
|
||||
|
||||
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
|
||||
FmkType fmk_type_ = FmkType::FmkType_MS;
|
||||
FmkType fmk_type_ = FmkType::kFmkTypeMs;
|
||||
bool train_flag_ = false;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -113,14 +113,14 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
|||
return nullptr;
|
||||
}
|
||||
func_graph->set_attr("graph_name", MakeValue("main_graph"));
|
||||
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
|
||||
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeMs)));
|
||||
STATUS status;
|
||||
if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Mindir2AnfAdjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_MS, flag.trainModel, flag.quantType);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel, flag.quantType);
|
||||
if (!unify_format->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -519,7 +519,7 @@ int MoveAttrMapResizeGrad(const CNodePtr &cnode) {
|
|||
} // namespace
|
||||
|
||||
bool PrimitiveAdjust::Run(const FuncGraphPtr &func_graphs) {
|
||||
if (this->fmk_type_ != converter::FmkType_MS) {
|
||||
if (this->fmk_type_ != converter::kFmkTypeMs) {
|
||||
MS_LOG(INFO) << "The framework type of model should be mindir.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ class PrimitiveAdjust {
|
|||
bool Run(const FuncGraphPtr &func_graph);
|
||||
|
||||
protected:
|
||||
FmkType fmk_type_ = FmkType::FmkType_MS;
|
||||
FmkType fmk_type_ = FmkType::kFmkTypeMs;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -240,7 +240,7 @@ STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, BNWeight
|
|||
MS_ASSERT(graph->allTensors.size() > bnNode->inputIndex.at(1));
|
||||
auto bnWeightTensorIdxes = bnNode->inputIndex;
|
||||
bnWeightTensorIdxes.erase(bnWeightTensorIdxes.begin());
|
||||
if (fmkType == converter::FmkType_CAFFE) {
|
||||
if (fmkType == converter::kFmkTypeCaffe) {
|
||||
bnWeightTensors->meanTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_MEAN_INDEX]).get();
|
||||
bnWeightTensors->varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_VARIANCE_INDEX]).get();
|
||||
} else {
|
||||
|
@ -258,7 +258,7 @@ STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, BNWeight
|
|||
MS_LOG(ERROR) << "BatchNorm's variance tensor is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (fmkType == converter::FmkType_CAFFE) {
|
||||
if (fmkType == converter::kFmkTypeCaffe) {
|
||||
auto scaleTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_SCALE_INDEX]).get();
|
||||
// calibrate mean and variance
|
||||
float scale_factor_data = (reinterpret_cast<float *>(scaleTensor->data.data()))[0];
|
||||
|
|
|
@ -61,7 +61,7 @@ class BatchNormConvertScalePass : public GraphPass {
|
|||
float *transBias = nullptr;
|
||||
std::unique_ptr<TensorT> newScaleWeightTensor = nullptr;
|
||||
std::unique_ptr<TensorT> newScaleBiasTensor = nullptr;
|
||||
converter::FmkType fmkType = converter::FmkType_TF;
|
||||
converter::FmkType fmkType = converter::kFmkTypeTf;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
using mindspore::converter::FmkType_TF;
|
||||
using mindspore::converter::kFmkTypeTf;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "tools/converter/optimizer.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::converter::FmkType_TF;
|
||||
using mindspore::converter::kFmkTypeTf;
|
||||
using mindspore::schema::TensorT;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -59,7 +59,7 @@ class InferShapePass : public GraphPass {
|
|||
void InitInferTensor(MetaGraphT *graph);
|
||||
int InferSubgraph(const int &subgraph_index, MetaGraphT *graph);
|
||||
|
||||
converter::FmkType fmk_type_ = FmkType_TF;
|
||||
converter::FmkType fmk_type_ = kFmkTypeTf;
|
||||
std::vector<InferTensor> tensors_ = {};
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::converter::FmkType_CAFFE;
|
||||
using mindspore::converter::kFmkTypeCaffe;
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
namespace {
|
||||
|
@ -104,7 +104,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
return nullptr;
|
||||
}
|
||||
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeCaffe)));
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
|
@ -112,7 +112,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_CAFFE, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeCaffe, false, quant_type_);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
@ -555,5 +555,5 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
|
|||
}
|
||||
return layer.name();
|
||||
}
|
||||
REG_MODEL_PARSER(FmkType_CAFFE, converter::LiteModelParserCreator<CaffeModelParser>)
|
||||
REG_MODEL_PARSER(kFmkTypeCaffe, converter::LiteModelParserCreator<CaffeModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -151,13 +151,13 @@ int ReplaceLstmNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func
|
|||
auto lstm_weight_node = GetRealLstmWeightNode(func_graph, lstm_cnode, kWeightIndex);
|
||||
lite::DataInfo data_info;
|
||||
if (lstm_weight_node->isa<Parameter>()) {
|
||||
auto ret = FetchDataFromParameterNode(lstm_cnode, kWeightIndex, converter::FmkType_MS, false, &data_info);
|
||||
auto ret = FetchDataFromParameterNode(lstm_cnode, kWeightIndex, converter::kFmkTypeMs, false, &data_info);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "parse const node failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (lstm_weight_node->isa<ValueNode>()) {
|
||||
auto ret = FetchDataFromValueNode(lstm_cnode, kWeightIndex, converter::FmkType_MS, false, &data_info);
|
||||
auto ret = FetchDataFromValueNode(lstm_cnode, kWeightIndex, converter::kFmkTypeMs, false, &data_info);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "parse const node failed.";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -37,7 +37,7 @@
|
|||
#include "ops/transpose.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::converter::FmkType_ONNX;
|
||||
using mindspore::converter::kFmkTypeOnnx;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
|
@ -79,10 +79,10 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
static auto root_func_manager = Manage(res_graph_);
|
||||
for (auto &subgraph : all_subgraphs_) {
|
||||
subgraph->set_manager(root_func_manager);
|
||||
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
||||
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
|
||||
}
|
||||
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
|
@ -95,7 +95,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_ONNX, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeOnnx, false, quant_type_);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
@ -118,7 +118,7 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
|
|||
}
|
||||
OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
|
||||
onnx_root_graph_ = onnx_model_.graph();
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
|
||||
|
@ -1253,6 +1253,6 @@ int OnnxModelParser::Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graph
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(FmkType_ONNX, converter::LiteModelParserCreator<OnnxModelParser>)
|
||||
REG_MODEL_PARSER(kFmkTypeOnnx, converter::LiteModelParserCreator<OnnxModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -133,7 +133,7 @@ STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::
|
|||
}
|
||||
|
||||
FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
|
||||
auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, converter::FmkType_TF);
|
||||
auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, converter::kFmkTypeTf);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "new graph Partial Node return nullptr";
|
||||
return nullptr;
|
||||
|
|
|
@ -367,7 +367,7 @@ STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
|
|||
|
||||
STATUS FunctionalizeWhile::BuildCondGraph() {
|
||||
cond_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_cond";
|
||||
cond_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, converter::FmkType_TF);
|
||||
cond_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, converter::kFmkTypeTf);
|
||||
if (cond_sub_func_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new cond_sub_func_graph_ return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -522,7 +522,7 @@ STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
|
|||
|
||||
STATUS FunctionalizeWhile::BuildBodyGraph() {
|
||||
body_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_body";
|
||||
body_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, converter::FmkType_TF);
|
||||
body_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, converter::kFmkTypeTf);
|
||||
if (body_sub_func_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new body_sub_func_graph_ return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::converter::FmkType_TF;
|
||||
using mindspore::converter::kFmkTypeTf;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
|
@ -515,7 +515,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
return nullptr;
|
||||
}
|
||||
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
|
||||
|
||||
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
|
||||
auto &node_def = tf_root_graph_->node(i);
|
||||
|
@ -576,7 +576,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_TF, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTf, false, quant_type_);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
@ -727,7 +727,7 @@ STATUS TFModelParser::ConvertSubgraph() {
|
|||
|
||||
FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>();
|
||||
sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name));
|
||||
sub_func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
|
||||
sub_func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
|
||||
std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map;
|
||||
std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map;
|
||||
|
||||
|
@ -1136,6 +1136,6 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(FmkType_TF, converter::LiteModelParserCreator<TFModelParser>)
|
||||
REG_MODEL_PARSER(kFmkTypeTf, converter::LiteModelParserCreator<TFModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::converter::FmkType_TFLITE;
|
||||
using mindspore::converter::kFmkTypeTflite;
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr size_t kConvWeightIndex = 2;
|
||||
|
@ -69,7 +69,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
|
|||
return nullptr;
|
||||
}
|
||||
res_graph_ = std::make_shared<FuncGraph>();
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE)));
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTflite)));
|
||||
|
||||
auto status = ConvertGraphInputs();
|
||||
if (status != RET_OK) {
|
||||
|
@ -105,7 +105,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_TFLITE, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTflite, false, quant_type_);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
@ -546,5 +546,5 @@ int TfliteModelParser::Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_g
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(FmkType_TFLITE, converter::LiteModelParserCreator<TfliteModelParser>)
|
||||
REG_MODEL_PARSER(kFmkTypeTflite, converter::LiteModelParserCreator<TfliteModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -176,13 +176,13 @@ STATUS UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::Tra
|
|||
MS_ASSERT(prim != nullptr);
|
||||
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
|
||||
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
|
||||
if (fmk_type_ == converter::FmkType_TFLITE) {
|
||||
if (fmk_type_ == converter::kFmkTypeTflite) {
|
||||
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
trans_info->pre_ = opt::kNHWC2NCHW;
|
||||
trans_info->post_ = opt::kNCHW2NHWC;
|
||||
} else if (fmk_type_ == converter::FmkType_TF) {
|
||||
} else if (fmk_type_ == converter::kFmkTypeTf) {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
|
||||
trans_info->pre_ = opt::kNCHW2NHWC;
|
||||
trans_info->post_ = opt::kNHWC2NCHW;
|
||||
|
@ -193,7 +193,7 @@ STATUS UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::Tra
|
|||
}
|
||||
} else {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
|
||||
if (fmk_type_ == converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
|
||||
if (fmk_type_ == converter::kFmkTypeOnnx && prim->GetAttr(ops::kFormat) != nullptr &&
|
||||
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
@ -213,10 +213,10 @@ void UnifyFormatToNHWC::SetSensitiveOps() {
|
|||
|
||||
bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
if (fmk_type_ == converter::FmkType_TF || fmk_type_ == converter::FmkType_TFLITE) {
|
||||
if (fmk_type_ == converter::kFmkTypeTf || fmk_type_ == converter::kFmkTypeTflite) {
|
||||
return false;
|
||||
}
|
||||
if (func_graph->get_inputs().size() == 1 && fmk_type_ == converter::FmkType_ONNX &&
|
||||
if (func_graph->get_inputs().size() == 1 && fmk_type_ == converter::kFmkTypeOnnx &&
|
||||
shape[opt::kInputIndexThree] == kInputChannal && shape[1] == -1) {
|
||||
return false;
|
||||
}
|
||||
|
@ -230,11 +230,11 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode,
|
|||
MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
|
||||
*dst_format = schema::Format_KHWC;
|
||||
std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::QuantType, schema::Format *)>>
|
||||
decide_functions = {{converter::FmkType_MS, DecideMINDIRConvWeightSrcFormat},
|
||||
{converter::FmkType_TF, DecideTFConvWeightSrcFormat},
|
||||
{converter::FmkType_TFLITE, DecideTFLITEConvWeightSrcFormat},
|
||||
{converter::FmkType_CAFFE, DecideCAFFEConvWeightSrcFormat},
|
||||
{converter::FmkType_ONNX, DecideONNXConvWeightSrcFormat}};
|
||||
decide_functions = {{converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat},
|
||||
{converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
|
||||
{converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
|
||||
{converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
|
||||
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
|
||||
auto iter = decide_functions.find(fmk_type_);
|
||||
if (iter == decide_functions.end()) {
|
||||
MS_LOG(ERROR) << "current fmk don't support, please check.";
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
class UnifyFormatToNHWC : public opt::ToFormatBase {
|
||||
public:
|
||||
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false,
|
||||
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
|
||||
schema::QuantType quant_type = schema::QuantType_QUANT_NONE)
|
||||
: ToFormatBase(fmk_type, train_flag), quant_type_(quant_type) {}
|
||||
~UnifyFormatToNHWC() override = default;
|
||||
|
|
|
@ -25,7 +25,7 @@ std::map<FmkType, ModelParserCreator> model_parser_room;
|
|||
} // namespace
|
||||
|
||||
ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) {
|
||||
if (fmk < converter::FmkType_TF || fmk > converter::FmkType_TFLITE) {
|
||||
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {
|
||||
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -136,9 +136,9 @@ STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
|
|||
lite::DataInfo data_info;
|
||||
int status;
|
||||
if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
|
||||
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::FmkType_MS, false, &data_info);
|
||||
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
|
||||
} else {
|
||||
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::FmkType_MS, false, &data_info);
|
||||
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
|
||||
}
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch transpose perm data failed.";
|
||||
|
|
|
@ -172,8 +172,8 @@ bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePt
|
|||
}
|
||||
auto splited_axis = split_info->axis;
|
||||
// need to check
|
||||
if (split_info->fmk_type == FmkType::FmkType_CAFFE ||
|
||||
split_info->fmk_type == FmkType::FmkType_ONNX) { // NHWC -> NCHW
|
||||
if (split_info->fmk_type == FmkType::kFmkTypeCaffe ||
|
||||
split_info->fmk_type == FmkType::kFmkTypeOnnx) { // NHWC -> NCHW
|
||||
splited_axis += 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ToFormatBase : public Pass {
|
||||
public:
|
||||
explicit ToFormatBase(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false,
|
||||
explicit ToFormatBase(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
|
||||
std::string pass_name = "to_format_base")
|
||||
: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~ToFormatBase() override = default;
|
||||
|
@ -56,7 +56,7 @@ class ToFormatBase : public Pass {
|
|||
virtual bool DecideWhetherInferShapeForNewNode() { return true; }
|
||||
virtual STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
|
||||
schema::Format *dst_format) = 0;
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::kFmkTypeMs};
|
||||
bool train_flag_{false};
|
||||
mindspore::Format format_{mindspore::NHWC};
|
||||
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ToNCHWFormat : public ToFormatBase {
|
||||
public:
|
||||
explicit ToNCHWFormat(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
|
||||
explicit ToNCHWFormat(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
|
||||
: ToFormatBase(fmk_type, train_flag, "to_nchw_format") {
|
||||
format_ = mindspore::NCHW;
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ToNHWCFormat : public ToFormatBase {
|
||||
public:
|
||||
explicit ToNHWCFormat(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
|
||||
explicit ToNHWCFormat(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
|
||||
: ToFormatBase(fmk_type, train_flag, "to_nhwc_format") {}
|
||||
~ToNHWCFormat() = default;
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ConstFoldPass : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConstFoldPass(converter::FmkType fmk_type = converter::FmkType_MS, bool multigraph = true)
|
||||
explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool multigraph = true)
|
||||
: PatternProcessPass("constfold_pass", multigraph), fmk_type_(fmk_type) {
|
||||
context_ = std::make_shared<lite::InnerContext>();
|
||||
context_->Init();
|
||||
|
@ -41,7 +41,7 @@ class ConstFoldPass : public PatternProcessPass {
|
|||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
converter::FmkType fmk_type_{converter::FmkType_MS};
|
||||
converter::FmkType fmk_type_{converter::kFmkTypeMs};
|
||||
std::shared_ptr<lite::InnerContext> context_{nullptr};
|
||||
std::shared_ptr<mindspore::Context> ms_context_{nullptr};
|
||||
};
|
||||
|
|
|
@ -37,7 +37,7 @@ class ConvTransformFusion : public PatternProcessPass {
|
|||
void SetFmkType(FmkType type) { this->fmk_type_ = type; }
|
||||
|
||||
private:
|
||||
FmkType fmk_type_ = converter::FmkType_TF;
|
||||
FmkType fmk_type_ = converter::kFmkTypeTf;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
|
||||
|
|
|
@ -199,7 +199,7 @@ int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::ve
|
|||
*after_fg = std::make_shared<FuncGraph>();
|
||||
auto manager = main_fg->manager();
|
||||
manager->AddFuncGraph(*after_fg);
|
||||
(*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
|
||||
(*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
|
||||
(*after_fg)->set_attr("graph_name", MakeValue(aim_cnode->fullname_with_scope() + "_after_fg"));
|
||||
(*after_fg)->set_manager(main_fg->manager());
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class DecreaseTransposeAlgo : public Pass {
|
||||
public:
|
||||
explicit DecreaseTransposeAlgo(FmkType fmk_type = FmkType::FmkType_MS, bool train_flag = false)
|
||||
explicit DecreaseTransposeAlgo(FmkType fmk_type = FmkType::kFmkTypeMs, bool train_flag = false)
|
||||
: Pass("DecreaseTransposeAlgo"), fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~DecreaseTransposeAlgo() override = default;
|
||||
void Init(FmkType fmk_type, bool train_flag) {
|
||||
|
@ -63,7 +63,7 @@ class DecreaseTransposeAlgo : public Pass {
|
|||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type);
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::kFmkTypeMs};
|
||||
bool train_flag_{false};
|
||||
NodeInferShape node_infer_shape_;
|
||||
TransposeStrategy transpose_strategy_;
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class InferShapePass : public Pass {
|
||||
public:
|
||||
explicit InferShapePass(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
|
||||
explicit InferShapePass(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
|
||||
: Pass("infer_shape"), fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~InferShapePass() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
@ -40,7 +40,7 @@ class InferShapePass : public Pass {
|
|||
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void ResetSubGraphInput();
|
||||
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::kFmkTypeMs};
|
||||
bool train_flag_{false};
|
||||
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
|
||||
std::map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
|
||||
|
|
|
@ -45,7 +45,7 @@ void FreeTensors(std::vector<lite::Tensor *> *tensors) {
|
|||
|
||||
void RectifyFormat(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (fmk_type != converter::FmkType_ONNX) {
|
||||
if (fmk_type != converter::kFmkTypeOnnx) {
|
||||
return;
|
||||
}
|
||||
for (auto &input : inputs) {
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class NodeInferShape {
|
||||
public:
|
||||
explicit NodeInferShape(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
|
||||
explicit NodeInferShape(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
|
||||
: fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
virtual ~NodeInferShape() = default;
|
||||
void Init(FmkType fmk_type, bool train_flag) {
|
||||
|
@ -54,7 +54,7 @@ class NodeInferShape {
|
|||
STATUS SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs, int status);
|
||||
abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor);
|
||||
abstract::AbstractBasePtr ConvertTensorListToAbstract(lite::Tensor *tensor);
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::kFmkTypeMs};
|
||||
bool train_flag_{false};
|
||||
};
|
||||
} // namespace opt
|
||||
|
|
|
@ -261,13 +261,13 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
|
|||
auto padding_node = cnode->input(kInputIndexTwo);
|
||||
lite::DataInfo data_info;
|
||||
if (utils::isa<Parameter>(padding_node)) {
|
||||
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::FmkType_MS, false, &data_info);
|
||||
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, &data_info);
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "fetch data from parameter node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else if (utils::isa<ValueNode>(padding_node)) {
|
||||
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::FmkType_MS, false, &data_info);
|
||||
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, &data_info);
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "fetch data from value node failed.";
|
||||
return lite::RET_ERROR;
|
||||
|
|
|
@ -1411,7 +1411,7 @@ bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slic
|
|||
}
|
||||
|
||||
bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
|
||||
if (fmk_type != converter::FmkType_TF && fmk_type != converter::FmkType_TFLITE) {
|
||||
if (fmk_type != converter::kFmkTypeTf && fmk_type != converter::kFmkTypeTflite) {
|
||||
MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -95,7 +95,7 @@ class SlicePreposePass : public Pass {
|
|||
static bool MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices);
|
||||
|
||||
private:
|
||||
FmkType fmk_type = converter::FmkType_ONNX;
|
||||
FmkType fmk_type = converter::kFmkTypeOnnx;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ class TransposeStrategy {
|
|||
void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
|
||||
const std::vector<int> &axes, FormatTransNodeType trans_type);
|
||||
std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type);
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::kFmkTypeMs};
|
||||
bool train_flag_{false};
|
||||
NodeInferShape node_infer_shape_;
|
||||
};
|
||||
|
|
|
@ -22,7 +22,7 @@ constexpr size_t kCastInputNum = 3;
|
|||
void RemoveUnusedCastOpPass::SetFmkType(FmkType type) { this->fmk_type = type; }
|
||||
|
||||
bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||
if (this->fmk_type != converter::FmkType_MS) {
|
||||
if (this->fmk_type != converter::kFmkTypeMs) {
|
||||
MS_LOG(ERROR) << "The framework type of model should be mindspore.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ class RemoveUnusedCastOpPass : public Pass {
|
|||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
FmkType fmk_type = converter::FmkType_TF;
|
||||
FmkType fmk_type = converter::kFmkTypeTf;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_
|
||||
|
|
|
@ -57,7 +57,7 @@ std::vector<int> GetTransposePerm(const CNodePtr &node) {
|
|||
}
|
||||
|
||||
bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||
if (this->fmk_type != converter::FmkType_ONNX) {
|
||||
if (this->fmk_type != converter::kFmkTypeOnnx) {
|
||||
MS_LOG(ERROR) << "The framework type of model should be onnx.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ class RemoveUnusedTransposeOpPass : public Pass {
|
|||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
FmkType fmk_type = converter::FmkType_TF;
|
||||
FmkType fmk_type = converter::kFmkTypeTf;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_
|
||||
|
|
|
@ -30,7 +30,7 @@ constexpr int kAnfPopulaterInputNumTwo = 2;
|
|||
|
||||
lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (fmk_type_ != converter::FmkType_TF) {
|
||||
if (fmk_type_ != converter::kFmkTypeTf) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto conv = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode->input(0));
|
||||
|
|
|
@ -33,7 +33,7 @@ class UpdateConv2DParamPass : public Pass {
|
|||
void SetFmkType(FmkType fmk_type) { this->fmk_type_ = fmk_type; }
|
||||
|
||||
private:
|
||||
FmkType fmk_type_ = converter::FmkType_ONNX;
|
||||
FmkType fmk_type_ = converter::kFmkTypeOnnx;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_
|
||||
|
|
Loading…
Reference in New Issue