forked from mindspore-Ecosystem/mindspore
!21777 [lite]change namespace mindspore::lite::converter to mindspore::converter
Merge pull request !21777 from 徐安越/master
This commit is contained in:
commit
b704f9e502
|
@ -17,34 +17,17 @@
|
|||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "include/lite_utils.h"
|
||||
#include "include/registry/framework.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace converter {
|
||||
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
|
||||
struct MS_API ConverterParameters {
|
||||
FmkType fmk;
|
||||
schema::QuantType quant_type;
|
||||
std::string model_file;
|
||||
std::string weight_file;
|
||||
std::map<std::string, std::string> attrs;
|
||||
};
|
||||
} // namespace converter
|
||||
|
||||
/// \brief ModelParser defined a model parser
|
||||
class MS_API ModelParser;
|
||||
|
||||
/// \brief ModelParserCreator defined function pointer to get a ModelParser class.
|
||||
typedef ModelParser *(*ModelParserCreator)();
|
||||
|
||||
namespace registry {
|
||||
/// \brief ModelParserCreator defined function pointer to get a ModelParser class.
|
||||
typedef converter::ModelParser *(*ModelParserCreator)();
|
||||
|
||||
/// \brief ModelParserRegistry defined registration and storage of ModelParser.
|
||||
class MS_API ModelParserRegistry {
|
||||
public:
|
||||
|
@ -62,7 +45,7 @@ class MS_API ModelParserRegistry {
|
|||
/// \param[in] fmk Define identification of a certain framework.
|
||||
///
|
||||
/// \return Pointer of ModelParser.
|
||||
static ModelParser *GetModelParser(FmkType fmk);
|
||||
static converter::ModelParser *GetModelParser(FmkType fmk);
|
||||
};
|
||||
|
||||
/// \brief Defined registering macro to register ModelParser, which called by user directly.
|
||||
|
|
|
@ -14,13 +14,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_FRAMEWORK_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_FRAMEWORK_H_
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_PARSER_CONTEXT_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_PARSER_CONTEXT_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "include/lite_utils.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace converter {
|
||||
/// \brief FmkType defined frameworks which converter tool supports.
|
||||
enum MS_API FmkType : int {
|
||||
|
@ -30,7 +32,19 @@ enum MS_API FmkType : int {
|
|||
FmkType_MS = 3,
|
||||
FmkType_TFLITE = 4,
|
||||
};
|
||||
|
||||
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
|
||||
struct MS_API ConverterParameters {
|
||||
FmkType fmk;
|
||||
schema::QuantType quant_type;
|
||||
std::string model_file;
|
||||
std::string weight_file;
|
||||
std::map<std::string, std::string> attrs;
|
||||
};
|
||||
|
||||
/// \brief ModelParser defined a model parser
|
||||
class MS_API ModelParser;
|
||||
} // namespace converter
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_FRAMEWORK_H_
|
||||
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_PARSER_CONTEXT_H_
|
|
@ -19,8 +19,8 @@
|
|||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::ConverterParameters;
|
||||
using mindspore::lite::converter::FmkType_CAFFE;
|
||||
using mindspore::converter::ConverterParameters;
|
||||
using mindspore::converter::FmkType_CAFFE;
|
||||
namespace mindspore {
|
||||
class ModelParserRegistryTest : public mindspore::CommonTest {
|
||||
public:
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "include/registry/model_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
FuncGraphPtr ModelParserTest::Parse(const lite::converter::ConverterParameters &flag) {
|
||||
FuncGraphPtr ModelParserTest::Parse(const converter::ConverterParameters &flag) {
|
||||
// construct funcgraph
|
||||
res_graph_ = std::make_shared<FuncGraph>();
|
||||
auto ret = InitOriginModelStructure();
|
||||
|
@ -160,7 +160,7 @@ int ModelParserTest::BuildGraphOutputs() {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
lite::ModelParser *TestModelParserCreator() {
|
||||
converter::ModelParser *TestModelParserCreator() {
|
||||
auto *model_parser = new (std::nothrow) ModelParserTest();
|
||||
if (model_parser == nullptr) {
|
||||
MS_LOG(ERROR) << "new model parser failed";
|
||||
|
|
|
@ -25,10 +25,10 @@
|
|||
#include "tools/converter/model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ModelParserTest : public lite::ModelParser {
|
||||
class ModelParserTest : public converter::ModelParser {
|
||||
public:
|
||||
ModelParserTest() = default;
|
||||
FuncGraphPtr Parse(const lite::converter::ConverterParameters &flag) override;
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
private:
|
||||
int InitOriginModelStructure();
|
||||
|
@ -40,7 +40,7 @@ class ModelParserTest : public lite::ModelParser {
|
|||
std::vector<std::string> model_structure_;
|
||||
};
|
||||
|
||||
lite::ModelParser *TestModelParserCreator();
|
||||
converter::ModelParser *TestModelParserCreator();
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_MODEL_PARSER_TEST_H
|
||||
|
|
|
@ -29,8 +29,8 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
|
||||
using mindspore::lite::converter::ConverterParameters;
|
||||
using mindspore::lite::converter::FmkType_CAFFE;
|
||||
using mindspore::converter::ConverterParameters;
|
||||
using mindspore::converter::FmkType_CAFFE;
|
||||
using mindspore::lite::registry::POSITION_BEGIN;
|
||||
namespace mindspore {
|
||||
class PassRegistryTest : public mindspore::CommonTest {
|
||||
|
|
|
@ -78,11 +78,11 @@ STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, Shap
|
|||
int GetFormatByFmk(int32_t fmk_type) {
|
||||
switch (fmk_type) {
|
||||
case converter::FmkType_ONNX:
|
||||
case lite::converter::FmkType_CAFFE:
|
||||
case lite::converter::FmkType_MS:
|
||||
case converter::FmkType_CAFFE:
|
||||
case converter::FmkType_MS:
|
||||
return mindspore::NCHW;
|
||||
case lite::converter::FmkType_TF:
|
||||
case lite::converter::FmkType_TFLITE:
|
||||
case converter::FmkType_TF:
|
||||
case converter::FmkType_TFLITE:
|
||||
return mindspore::NHWC;
|
||||
default:
|
||||
return -1;
|
||||
|
|
|
@ -114,7 +114,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
|||
fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
|
||||
}
|
||||
if (config->fmk == lite::converter::FmkType_MS) {
|
||||
if (config->fmk == converter::FmkType_MS) {
|
||||
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
|
||||
if (remove_unused_cast_pass == nullptr) {
|
||||
MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
|
||||
|
@ -194,8 +194,8 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
|
|||
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
||||
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
|
||||
config->fmk == lite::converter::FmkType_ONNX) {
|
||||
if (config->fmk == converter::FmkType_TFLITE || config->fmk == converter::FmkType_TF ||
|
||||
config->fmk == converter::FmkType_ONNX) {
|
||||
graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
|
||||
}
|
||||
auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
|
||||
|
|
|
@ -39,7 +39,7 @@ class Converter {
|
|||
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag);
|
||||
|
||||
protected:
|
||||
ModelParser *model_parser_ = nullptr;
|
||||
converter::ModelParser *model_parser_ = nullptr;
|
||||
std::unique_ptr<GraphDefTransform> metagraph_transform_ = std::make_unique<GraphDefTransform>();
|
||||
std::unique_ptr<AnfTransform> funcgraph_transform_ = std::make_unique<AnfTransform>();
|
||||
};
|
||||
|
|
|
@ -27,8 +27,9 @@
|
|||
#include "tools/converter/converter_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace converter {
|
||||
using mindspore::lite::RET_INPUT_PARAM_INVALID;
|
||||
using mindspore::lite::RET_OK;
|
||||
namespace {
|
||||
constexpr int kBase = 10;
|
||||
constexpr int kQuantBitNumInt16 = 16;
|
||||
|
@ -168,11 +169,11 @@ int Flags::QuantParamInputCheck() {
|
|||
|
||||
int Flags::InitQuantParam() {
|
||||
if (this->quantTypeStr == "WeightQuant") {
|
||||
this->quantType = QuantType_WeightQuant;
|
||||
this->quantType = schema::QuantType_WeightQuant;
|
||||
} else if (this->quantTypeStr == "PostTraining") {
|
||||
this->quantType = QuantType_PostTraining;
|
||||
this->quantType = schema::QuantType_PostTraining;
|
||||
} else if (this->quantTypeStr.empty()) {
|
||||
this->quantType = QuantType_QUANT_NONE;
|
||||
this->quantType = schema::QuantType_QUANT_NONE;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining" << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
|
@ -212,10 +213,10 @@ int Flags::InitTrainModel() {
|
|||
int Flags::InitInTensorShape() {
|
||||
std::string content = this->inTensorShape;
|
||||
std::vector<int64_t> shape;
|
||||
auto shape_strs = StrSplit(content, std::string(";"));
|
||||
auto shape_strs = lite::StrSplit(content, std::string(";"));
|
||||
for (const auto &shape_str : shape_strs) {
|
||||
shape.clear();
|
||||
auto string_split = StrSplit(shape_str, std::string(":"));
|
||||
auto string_split = lite::StrSplit(shape_str, std::string(":"));
|
||||
auto name = string_split[0];
|
||||
if (name.empty()) {
|
||||
MS_LOG(ERROR) << "input tensor name is empty";
|
||||
|
@ -224,19 +225,19 @@ int Flags::InitInTensorShape() {
|
|||
if (dim_strs.empty()) {
|
||||
MS_LOG(ERROR) << "input tensor dim string is empty";
|
||||
}
|
||||
auto dims = StrSplit(dim_strs, std::string(","));
|
||||
auto dims = lite::StrSplit(dim_strs, std::string(","));
|
||||
if (dims.empty()) {
|
||||
MS_LOG(ERROR) << "input tensor dim is empty";
|
||||
}
|
||||
for (const auto &dim : dims) {
|
||||
if (std::stoi(dim) < 0) {
|
||||
MS_LOG(ERROR) << "Unsupported dim < 0.";
|
||||
return RET_ERROR;
|
||||
return lite::RET_ERROR;
|
||||
} else {
|
||||
shape.push_back(std::stoi(dim));
|
||||
}
|
||||
}
|
||||
ConverterContext::GetInstance()->UpdateGraphInputTensorShape(name, shape);
|
||||
lite::ConverterContext::GetInstance()->UpdateGraphInputTensorShape(name, shape);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -247,7 +248,7 @@ int Flags::InitConfigFile() {
|
|||
const char *delimiter = ";";
|
||||
auto relative_path = SplitStringToVector(plugins_path_str, *delimiter);
|
||||
for (size_t i = 0; i < relative_path.size(); i++) {
|
||||
this->pluginsPath.push_back(RealPath(relative_path[i].c_str()));
|
||||
this->pluginsPath.push_back(lite::RealPath(relative_path[i].c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -271,9 +272,9 @@ int Flags::Init(int argc, const char **argv) {
|
|||
int ret;
|
||||
if (argc == 1) {
|
||||
std::cout << this->Usage() << std::endl;
|
||||
return RET_SUCCESS_EXIT;
|
||||
return lite::RET_SUCCESS_EXIT;
|
||||
}
|
||||
Option<std::string> err = this->ParseFlags(argc, argv);
|
||||
lite::Option<std::string> err = this->ParseFlags(argc, argv);
|
||||
|
||||
if (err.IsSome()) {
|
||||
std::cerr << err.Get() << std::endl;
|
||||
|
@ -283,7 +284,7 @@ int Flags::Init(int argc, const char **argv) {
|
|||
|
||||
if (this->help) {
|
||||
std::cout << this->Usage() << std::endl;
|
||||
return RET_SUCCESS_EXIT;
|
||||
return lite::RET_SUCCESS_EXIT;
|
||||
}
|
||||
if (this->modelFile.empty()) {
|
||||
std::cerr << "INPUT MISSING: model file path is necessary" << std::endl;
|
||||
|
@ -488,5 +489,4 @@ std::vector<std::string> SplitStringToVector(const std::string &raw_str, const c
|
|||
return res;
|
||||
}
|
||||
} // namespace converter
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,19 +19,14 @@
|
|||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/registry/framework.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
#include "tools/common/flag_parser.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
using mindspore::schema::QuantType;
|
||||
using mindspore::schema::QuantType_AwareTraining;
|
||||
using mindspore::schema::QuantType_PostTraining;
|
||||
using mindspore::schema::QuantType_QUANT_NONE;
|
||||
using mindspore::schema::QuantType_WeightQuant;
|
||||
namespace converter {
|
||||
using mindspore::schema::QuantType;
|
||||
enum ParallelSplitType { SplitNo = 0, SplitByUserRatio = 1, SplitByUserAttr = 2 };
|
||||
constexpr auto kMaxSplitRatio = 10;
|
||||
constexpr auto kComputeRate = "computeRate";
|
||||
|
@ -106,7 +101,6 @@ std::string GetStrFromConfigFile(const std::string &file, const std::string &tar
|
|||
|
||||
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter);
|
||||
} // namespace converter
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif
|
||||
|
|
|
@ -199,8 +199,8 @@ STATUS ExportModel(const FuncGraphPtr &graph) {
|
|||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
||||
if (flags->fmk == lite::converter::FmkType_TFLITE || flags->fmk == lite::converter::FmkType_TF ||
|
||||
flags->fmk == lite::converter::FmkType_ONNX) {
|
||||
if (flags->fmk == converter::FmkType_TFLITE || flags->fmk == converter::FmkType_TF ||
|
||||
flags->fmk == converter::FmkType_ONNX) {
|
||||
graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
|
||||
}
|
||||
optimizer->AddPassManager(graph_pm);
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
void ExportModelInit(lite::converter::Flags *flag);
|
||||
void ExportModelInit(converter::Flags *flag);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -257,7 +257,7 @@ int MindirAdjust::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
|
|||
}
|
||||
|
||||
bool MindirAdjust::Run(const FuncGraphPtr &func_graph) {
|
||||
if (this->fmk_type_ != lite::converter::FmkType_MS) {
|
||||
if (this->fmk_type_ != converter::FmkType_MS) {
|
||||
MS_LOG(INFO) << "The framework type of model should be mindir.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
using mindspore::schema::QuantType;
|
||||
namespace mindspore::lite {
|
||||
class MindirAdjust {
|
||||
|
|
|
@ -119,7 +119,7 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_MS, flag.trainModel, flag.quantType);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_MS, flag.trainModel, flag.quantType);
|
||||
if (!unify_format->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -31,7 +31,7 @@ class MindsporeImporter {
|
|||
|
||||
private:
|
||||
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
|
||||
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
|
||||
};
|
||||
|
||||
|
|
|
@ -519,7 +519,7 @@ int MoveAttrMapResizeGrad(const CNodePtr &cnode) {
|
|||
} // namespace
|
||||
|
||||
bool PrimitiveAdjust::Run(const FuncGraphPtr &func_graphs) {
|
||||
if (this->fmk_type_ != lite::converter::FmkType_MS) {
|
||||
if (this->fmk_type_ != converter::FmkType_MS) {
|
||||
MS_LOG(INFO) << "The framework type of model should be mindir.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
typedef int (*PrimitiveAdjustCreator)(const CNodePtr &value_node);
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_TF;
|
||||
using mindspore::converter::FmkType_TF;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "tools/converter/optimizer.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_TF;
|
||||
using mindspore::converter::FmkType_TF;
|
||||
using mindspore::schema::TensorT;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -59,7 +59,7 @@ class InferShapePass : public GraphPass {
|
|||
void InitInferTensor(MetaGraphT *graph);
|
||||
int InferSubgraph(const int &subgraph_index, MetaGraphT *graph);
|
||||
|
||||
lite::converter::FmkType fmk_type_ = FmkType_TF;
|
||||
converter::FmkType fmk_type_ = FmkType_TF;
|
||||
std::vector<InferTensor> tensors_ = {};
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -25,8 +25,7 @@
|
|||
#include "include/registry/model_parser_registry.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
using namespace schema;
|
||||
namespace mindspore::converter {
|
||||
class ModelParser {
|
||||
public:
|
||||
ModelParser() = default;
|
||||
|
@ -50,6 +49,6 @@ ModelParser *LiteModelParserCreator() {
|
|||
}
|
||||
return parser;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
} // namespace mindspore::converter
|
||||
|
||||
#endif
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_CAFFE;
|
||||
using mindspore::converter::FmkType_CAFFE;
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
namespace {
|
||||
|
@ -112,7 +112,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_CAFFE, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_CAFFE, false, quant_type_);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
@ -555,5 +555,5 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
|
|||
}
|
||||
return layer.name();
|
||||
}
|
||||
REG_MODEL_PARSER(FmkType_CAFFE, LiteModelParserCreator<CaffeModelParser>)
|
||||
REG_MODEL_PARSER(FmkType_CAFFE, converter::LiteModelParserCreator<CaffeModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
using STATUS = int;
|
||||
namespace mindspore::lite {
|
||||
class CaffeModelParser : public ModelParser {
|
||||
class CaffeModelParser : public converter::ModelParser {
|
||||
public:
|
||||
CaffeModelParser();
|
||||
|
||||
|
@ -66,7 +66,7 @@ class CaffeModelParser : public ModelParser {
|
|||
caffe::NetParameter caffe_weight_;
|
||||
std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
|
||||
std::unordered_map<std::string, AnfNodePtr> nodes_;
|
||||
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::lite {
|
||||
class OnnxInputAdjust {
|
||||
public:
|
||||
|
|
|
@ -37,7 +37,7 @@
|
|||
#include "ops/transpose.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_ONNX;
|
||||
using mindspore::converter::FmkType_ONNX;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
|
@ -95,7 +95,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_ONNX, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_ONNX, false, quant_type_);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
@ -253,7 +253,7 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
|
|||
continue;
|
||||
}
|
||||
if (primitive_c->GetAttr(ops::kFormat) == nullptr) {
|
||||
primitive_c->AddAttr(mindspore::ops::kFormat, MakeValue<int64_t>(Format_NCHW));
|
||||
primitive_c->AddAttr(mindspore::ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
|
||||
}
|
||||
status = ConvertOpQuantParams(onnx_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
|
@ -1253,6 +1253,6 @@ int OnnxModelParser::Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graph
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(FmkType_ONNX, LiteModelParserCreator<OnnxModelParser>)
|
||||
REG_MODEL_PARSER(FmkType_ONNX, converter::LiteModelParserCreator<OnnxModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,7 +36,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxModelParser : public ModelParser {
|
||||
class OnnxModelParser : public converter::ModelParser {
|
||||
public:
|
||||
OnnxModelParser() = default;
|
||||
|
||||
|
@ -99,7 +99,7 @@ class OnnxModelParser : public ModelParser {
|
|||
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
|
||||
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
|
||||
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
|
||||
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -133,7 +133,7 @@ STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::
|
|||
}
|
||||
|
||||
FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
|
||||
auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, mindspore::lite::converter::FmkType_TF);
|
||||
auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, converter::FmkType_TF);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "new graph Partial Node return nullptr";
|
||||
return nullptr;
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
|
||||
typedef enum { kThenBranch = 0, kElseBranch = 1 } BranchType;
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
using AimFunc = std::function<bool(const AnfNodePtr &)>;
|
||||
class FunctionalizeControlOpPass : public Pass {
|
||||
|
|
|
@ -367,8 +367,7 @@ STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
|
|||
|
||||
STATUS FunctionalizeWhile::BuildCondGraph() {
|
||||
cond_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_cond";
|
||||
cond_sub_func_graph_ =
|
||||
FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, mindspore::lite::converter::FmkType_TF);
|
||||
cond_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, converter::FmkType_TF);
|
||||
if (cond_sub_func_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new cond_sub_func_graph_ return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -523,8 +522,7 @@ STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
|
|||
|
||||
STATUS FunctionalizeWhile::BuildBodyGraph() {
|
||||
body_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_body";
|
||||
body_sub_func_graph_ =
|
||||
FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, mindspore::lite::converter::FmkType_TF);
|
||||
body_sub_func_graph_ = FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, converter::FmkType_TF);
|
||||
if (body_sub_func_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new body_sub_func_graph_ return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
|
||||
constexpr const int POS_INVALID = -1;
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_TF;
|
||||
using mindspore::converter::FmkType_TF;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
|
@ -576,7 +576,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_TF, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_TF, false, quant_type_);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
@ -1122,6 +1122,6 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(FmkType_TF, LiteModelParserCreator<TFModelParser>)
|
||||
REG_MODEL_PARSER(FmkType_TF, converter::LiteModelParserCreator<TFModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFModelParser : public ModelParser {
|
||||
class TFModelParser : public converter::ModelParser {
|
||||
public:
|
||||
TFModelParser() = default;
|
||||
~TFModelParser() override = default;
|
||||
|
@ -106,7 +106,7 @@ class TFModelParser : public ModelParser {
|
|||
std::vector<std::string> while_cond_branch_name_;
|
||||
std::vector<std::string> if_then_branch_name_;
|
||||
std::unordered_map<std::string, int> node_output_num_;
|
||||
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
std::map<CNodePtr, FuncGraphPtr> while_cond_map_, while_body_map_, if_then_map_, if_else_map_;
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_TFLITE;
|
||||
using mindspore::converter::FmkType_TFLITE;
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr size_t kConvWeightIndex = 2;
|
||||
|
@ -105,7 +105,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_TFLITE, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::FmkType_TFLITE, false, quant_type_);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
@ -546,5 +546,5 @@ int TfliteModelParser::Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_g
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(FmkType_TFLITE, LiteModelParserCreator<TfliteModelParser>)
|
||||
REG_MODEL_PARSER(FmkType_TFLITE, converter::LiteModelParserCreator<TfliteModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TfliteModelParser : public ModelParser {
|
||||
class TfliteModelParser : public converter::ModelParser {
|
||||
public:
|
||||
TfliteModelParser() = default;
|
||||
|
||||
|
@ -52,7 +52,7 @@ class TfliteModelParser : public ModelParser {
|
|||
STATUS ConvertGraphOutputs();
|
||||
static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params,
|
||||
int round_type = 1);
|
||||
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -50,10 +50,10 @@ STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quan
|
|||
}
|
||||
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
|
||||
switch (quant_type) {
|
||||
case QuantType_AwareTraining:
|
||||
case QuantType_PostTraining:
|
||||
case QuantType_WeightQuant:
|
||||
case QuantType_QUANT_NONE: {
|
||||
case schema::QuantType_AwareTraining:
|
||||
case schema::QuantType_PostTraining:
|
||||
case schema::QuantType_WeightQuant:
|
||||
case schema::QuantType_QUANT_NONE: {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
|
||||
if (!is_depth_wise) {
|
||||
*src_format = schema::Format_HWCK;
|
||||
|
@ -85,10 +85,10 @@ STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType
|
|||
}
|
||||
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
|
||||
switch (quant_type) {
|
||||
case QuantType_AwareTraining:
|
||||
case QuantType_PostTraining:
|
||||
case QuantType_WeightQuant:
|
||||
case QuantType_QUANT_NONE: {
|
||||
case schema::QuantType_AwareTraining:
|
||||
case schema::QuantType_PostTraining:
|
||||
case schema::QuantType_WeightQuant:
|
||||
case schema::QuantType_QUANT_NONE: {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
|
||||
if (!is_depth_wise) {
|
||||
*src_format = schema::Format_KHWC;
|
||||
|
@ -127,7 +127,7 @@ STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType qu
|
|||
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
|
||||
int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
|
||||
switch (quant_type) {
|
||||
case QuantType_AwareTraining: {
|
||||
case schema::QuantType_AwareTraining: {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
|
||||
if (!is_depth_wise) {
|
||||
*src_format = schema::Format_KHWC;
|
||||
|
@ -141,9 +141,9 @@ STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType qu
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
} break;
|
||||
case QuantType_PostTraining:
|
||||
case QuantType_WeightQuant:
|
||||
case QuantType_QUANT_NONE: {
|
||||
case schema::QuantType_PostTraining:
|
||||
case schema::QuantType_WeightQuant:
|
||||
case schema::QuantType_QUANT_NONE: {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
|
||||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
|
||||
if (format == schema::Format_NHWC) {
|
||||
|
@ -176,13 +176,13 @@ STATUS UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::Tra
|
|||
MS_ASSERT(prim != nullptr);
|
||||
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
|
||||
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
|
||||
if (fmk_type_ == lite::converter::FmkType_TFLITE) {
|
||||
if (fmk_type_ == converter::FmkType_TFLITE) {
|
||||
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
trans_info->pre_ = opt::kNHWC2NCHW;
|
||||
trans_info->post_ = opt::kNCHW2NHWC;
|
||||
} else if (fmk_type_ == lite::converter::FmkType_TF) {
|
||||
} else if (fmk_type_ == converter::FmkType_TF) {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
|
||||
trans_info->pre_ = opt::kNCHW2NHWC;
|
||||
trans_info->post_ = opt::kNHWC2NCHW;
|
||||
|
@ -193,7 +193,7 @@ STATUS UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::Tra
|
|||
}
|
||||
} else {
|
||||
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
|
||||
if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
|
||||
if (fmk_type_ == converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
|
||||
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_g
|
|||
if (fmk_type_ == converter::FmkType_TF || fmk_type_ == converter::FmkType_TFLITE) {
|
||||
return false;
|
||||
}
|
||||
if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX &&
|
||||
if (func_graph->get_inputs().size() == 1 && fmk_type_ == converter::FmkType_ONNX &&
|
||||
shape[opt::kInputIndexThree] == kInputChannal && shape[1] == -1) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -19,12 +19,12 @@
|
|||
|
||||
#include "tools/optimizer/format/to_format_base.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class UnifyFormatToNHWC : public opt::ToFormatBase {
|
||||
public:
|
||||
explicit UnifyFormatToNHWC(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false,
|
||||
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false,
|
||||
schema::QuantType quant_type = schema::QuantType_QUANT_NONE)
|
||||
: ToFormatBase(fmk_type, train_flag), quant_type_(quant_type) {}
|
||||
~UnifyFormatToNHWC() override = default;
|
||||
|
|
|
@ -53,7 +53,7 @@ class Quantizer {
|
|||
|
||||
virtual STATUS DoQuantize(FuncGraphPtr func_graph) = 0;
|
||||
|
||||
mindspore::lite::converter::Flags flags;
|
||||
converter::Flags flags;
|
||||
|
||||
protected:
|
||||
FuncGraphPtr funcGraph = nullptr;
|
||||
|
|
|
@ -33,7 +33,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator
|
|||
model_parser_room[fmk] = creator;
|
||||
}
|
||||
|
||||
ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
|
||||
converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
|
||||
auto it = model_parser_room.find(fmk);
|
||||
if (it != model_parser_room.end()) {
|
||||
auto creator = it->second;
|
||||
|
|
|
@ -136,9 +136,9 @@ STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
|
|||
lite::DataInfo data_info;
|
||||
int status;
|
||||
if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
|
||||
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, lite::converter::FmkType_MS, false, &data_info);
|
||||
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::FmkType_MS, false, &data_info);
|
||||
} else {
|
||||
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, lite::converter::FmkType_MS, false, &data_info);
|
||||
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::FmkType_MS, false, &data_info);
|
||||
}
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch transpose perm data failed.";
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
std::vector<int64_t> GetSplitPadList(const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim, int64_t input_h,
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -27,12 +27,12 @@
|
|||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/optimizer/graph/infershape_pass.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ToFormatBase : public Pass {
|
||||
public:
|
||||
explicit ToFormatBase(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false,
|
||||
explicit ToFormatBase(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false,
|
||||
std::string pass_name = "to_format_base")
|
||||
: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~ToFormatBase() override = default;
|
||||
|
@ -56,7 +56,7 @@ class ToFormatBase : public Pass {
|
|||
virtual bool DecideWhetherInferShapeForNewNode() { return true; }
|
||||
virtual STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
|
||||
schema::Format *dst_format) = 0;
|
||||
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
bool train_flag_{false};
|
||||
mindspore::Format format_{mindspore::NHWC};
|
||||
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ToNCHWFormat : public ToFormatBase {
|
||||
public:
|
||||
explicit ToNCHWFormat(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
|
||||
explicit ToNCHWFormat(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
|
||||
: ToFormatBase(fmk_type, train_flag, "to_nchw_format") {
|
||||
format_ = mindspore::NCHW;
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ToNHWCFormat : public ToFormatBase {
|
||||
public:
|
||||
explicit ToNHWCFormat(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
|
||||
explicit ToNHWCFormat(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
|
||||
: ToFormatBase(fmk_type, train_flag, "to_nhwc_format") {}
|
||||
~ToNHWCFormat() = default;
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *out
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &cnode, lite::converter::FmkType fmk_type) {
|
||||
std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &cnode, converter::FmkType fmk_type) {
|
||||
MS_ASSERT(CNode != nullptr);
|
||||
std::vector<Tensor *> tensors;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ConstFoldPass : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConstFoldPass(lite::converter::FmkType fmk_type = lite::converter::FmkType_MS, bool multigraph = true)
|
||||
explicit ConstFoldPass(converter::FmkType fmk_type = converter::FmkType_MS, bool multigraph = true)
|
||||
: PatternProcessPass("constfold_pass", multigraph), fmk_type_(fmk_type) {
|
||||
context_ = std::make_shared<lite::InnerContext>();
|
||||
context_->Init();
|
||||
|
@ -41,7 +41,7 @@ class ConstFoldPass : public PatternProcessPass {
|
|||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
lite::converter::FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
converter::FmkType fmk_type_{converter::FmkType_MS};
|
||||
std::shared_ptr<lite::InnerContext> context_{nullptr};
|
||||
std::shared_ptr<mindspore::Context> ms_context_{nullptr};
|
||||
};
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
class ConvTransformFusion : public PatternProcessPass {
|
||||
public:
|
||||
|
@ -37,7 +37,7 @@ class ConvTransformFusion : public PatternProcessPass {
|
|||
void SetFmkType(FmkType type) { this->fmk_type_ = type; }
|
||||
|
||||
private:
|
||||
FmkType fmk_type_ = lite::converter::FmkType_TF;
|
||||
FmkType fmk_type_ = converter::FmkType_TF;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
using mindspore::schema::QuantType;
|
||||
namespace mindspore::opt {
|
||||
class ClipConvertActivationPass : public Pass {
|
||||
|
|
|
@ -199,7 +199,7 @@ int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::ve
|
|||
*after_fg = std::make_shared<FuncGraph>();
|
||||
auto manager = main_fg->manager();
|
||||
manager->AddFuncGraph(*after_fg);
|
||||
(*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(lite::converter::FmkType_TF)));
|
||||
(*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
|
||||
(*after_fg)->set_attr("graph_name", MakeValue(aim_cnode->fullname_with_scope() + "_after_fg"));
|
||||
(*after_fg)->set_manager(main_fg->manager());
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/optimizer/graph/transpose_strategy.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DecreaseTransposeAlgo : public Pass {
|
||||
|
@ -62,7 +62,7 @@ class DecreaseTransposeAlgo : public Pass {
|
|||
void ResetSubGraphInput();
|
||||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
bool train_flag_{false};
|
||||
NodeInferShape node_infer_shape_;
|
||||
TransposeStrategy transpose_strategy_;
|
||||
|
|
|
@ -105,7 +105,7 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
|
|||
return false;
|
||||
}
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class InferShapePass : public Pass {
|
||||
public:
|
||||
explicit InferShapePass(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
|
||||
explicit InferShapePass(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
|
||||
: Pass("infer_shape"), fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~InferShapePass() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
@ -40,7 +40,7 @@ class InferShapePass : public Pass {
|
|||
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void ResetSubGraphInput();
|
||||
|
||||
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
bool train_flag_{false};
|
||||
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
|
||||
std::map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
|
||||
|
|
|
@ -45,7 +45,7 @@ void FreeTensors(std::vector<lite::Tensor *> *tensors) {
|
|||
|
||||
void RectifyFormat(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (fmk_type != lite::converter::FmkType_ONNX) {
|
||||
if (fmk_type != converter::FmkType_ONNX) {
|
||||
return;
|
||||
}
|
||||
for (auto &input : inputs) {
|
||||
|
|
|
@ -27,12 +27,12 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class NodeInferShape {
|
||||
public:
|
||||
explicit NodeInferShape(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false)
|
||||
explicit NodeInferShape(FmkType fmk_type = converter::FmkType_MS, bool train_flag = false)
|
||||
: fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
virtual ~NodeInferShape() = default;
|
||||
void Init(FmkType fmk_type, bool train_flag) {
|
||||
|
@ -54,7 +54,7 @@ class NodeInferShape {
|
|||
STATUS SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs, int status);
|
||||
abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor);
|
||||
abstract::AbstractBasePtr ConvertTensorListToAbstract(lite::Tensor *tensor);
|
||||
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
bool train_flag_{false};
|
||||
};
|
||||
} // namespace opt
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/optimizer/graph/transpose_strategy.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ReduceSameActPass : public Pass {
|
||||
|
|
|
@ -261,13 +261,13 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
|
|||
auto padding_node = cnode->input(kInputIndexTwo);
|
||||
lite::DataInfo data_info;
|
||||
if (utils::isa<Parameter>(padding_node)) {
|
||||
auto status = lite::FetchDataFromParameterNode(cnode, 2, lite::converter::FmkType_MS, false, &data_info);
|
||||
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::FmkType_MS, false, &data_info);
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "fetch data from parameter node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else if (utils::isa<ValueNode>(padding_node)) {
|
||||
auto status = lite::FetchDataFromValueNode(cnode, 2, lite::converter::FmkType_MS, false, &data_info);
|
||||
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::FmkType_MS, false, &data_info);
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "fetch data from value node failed.";
|
||||
return lite::RET_ERROR;
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
class RemoveRedundantOpPass : public Pass {
|
||||
public:
|
||||
|
|
|
@ -1411,7 +1411,7 @@ bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slic
|
|||
}
|
||||
|
||||
bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
|
||||
if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) {
|
||||
if (fmk_type != converter::FmkType_TF && fmk_type != converter::FmkType_TFLITE) {
|
||||
MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "mindspore/core/ir/manager.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
using lite::RET_ERROR;
|
||||
using lite::RET_OK;
|
||||
|
@ -95,7 +95,7 @@ class SlicePreposePass : public Pass {
|
|||
static bool MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices);
|
||||
|
||||
private:
|
||||
FmkType fmk_type = lite::converter::FmkType_ONNX;
|
||||
FmkType fmk_type = converter::FmkType_ONNX;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/optimizer/graph/transpose_strategy.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SplitOnePass : public Pass {
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/optimizer/graph/node_infershape.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TransposeStrategy {
|
||||
|
@ -58,7 +58,7 @@ class TransposeStrategy {
|
|||
void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
|
||||
const std::vector<int> &axes, FormatTransNodeType trans_type);
|
||||
std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type);
|
||||
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
FmkType fmk_type_{converter::FmkType_MS};
|
||||
bool train_flag_{false};
|
||||
NodeInferShape node_infer_shape_;
|
||||
};
|
||||
|
|
|
@ -22,7 +22,7 @@ constexpr size_t kCastInputNum = 3;
|
|||
void RemoveUnusedCastOpPass::SetFmkType(FmkType type) { this->fmk_type = type; }
|
||||
|
||||
bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||
if (this->fmk_type != lite::converter::FmkType_MS) {
|
||||
if (this->fmk_type != converter::FmkType_MS) {
|
||||
MS_LOG(ERROR) << "The framework type of model should be mindspore.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
class RemoveUnusedCastOpPass : public Pass {
|
||||
public:
|
||||
|
@ -30,7 +30,7 @@ class RemoveUnusedCastOpPass : public Pass {
|
|||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
FmkType fmk_type = lite::converter::FmkType_TF;
|
||||
FmkType fmk_type = converter::FmkType_TF;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_
|
||||
|
|
|
@ -57,7 +57,7 @@ std::vector<int> GetTransposePerm(const CNodePtr &node) {
|
|||
}
|
||||
|
||||
bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||
if (this->fmk_type != lite::converter::FmkType_ONNX) {
|
||||
if (this->fmk_type != converter::FmkType_ONNX) {
|
||||
MS_LOG(ERROR) << "The framework type of model should be onnx.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
class RemoveUnusedTransposeOpPass : public Pass {
|
||||
public:
|
||||
|
@ -30,7 +30,7 @@ class RemoveUnusedTransposeOpPass : public Pass {
|
|||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
FmkType fmk_type = lite::converter::FmkType_TF;
|
||||
FmkType fmk_type = converter::FmkType_TF;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_
|
||||
|
|
|
@ -30,7 +30,7 @@ constexpr int kAnfPopulaterInputNumTwo = 2;
|
|||
|
||||
lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (fmk_type_ != lite::converter::FmkType_TF) {
|
||||
if (fmk_type_ != converter::FmkType_TF) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto conv = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode->input(0));
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
class UpdateConv2DParamPass : public Pass {
|
||||
public:
|
||||
|
@ -33,7 +33,7 @@ class UpdateConv2DParamPass : public Pass {
|
|||
void SetFmkType(FmkType fmk_type) { this->fmk_type_ = fmk_type; }
|
||||
|
||||
private:
|
||||
FmkType fmk_type_ = lite::converter::FmkType_ONNX;
|
||||
FmkType fmk_type_ = converter::FmkType_ONNX;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "tools/optimizer/parallel/split_strategy.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
using mindspore::converter::FmkType;
|
||||
using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
Loading…
Reference in New Issue