forked from mindspore-Ecosystem/mindspore
delete exportMindIR parameter
parent d79e3079ff
commit 0bd549f328
@@ -7,7 +7,7 @@ mindspore_lite.ModelType
 Applicable to the following scenarios:

-1. During conversion, the `export_mindir` parameter is set; `ModelType` defines the type of the model produced by the conversion.
+1. During conversion, the `save_type` parameter is set; `ModelType` defines the type of the model produced by the conversion.

 2. After conversion, when a model is loaded from a file or built for inference, `ModelType` defines the framework type of the input model.
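For reference, a minimal C++ sketch of the two scenarios with the renamed setter. This is hedged: the Converter(fmk_type, model_file, output_file) constructor, Convert(), Model::Build(), Context and kSuccess are assumed from the existing public API and are not part of this diff; the output path is illustrative.

#include <memory>
#include <string>
// Assumed public headers; adjust to the actual installation layout.
#include "include/converter.h"
#include "include/api/context.h"
#include "include/api/model.h"

int ConvertAndLoad(const std::string &tflite_file) {
  // Scenario 1: convert a third-party model and save it as MindIR.
  mindspore::Converter converter(mindspore::converter::kFmkTypeTflite, tflite_file, "model_out");
  converter.SetSaveType(mindspore::kMindIR);  // was SetExportMindIR before this change
  if (converter.Convert() != mindspore::kSuccess) {
    return -1;
  }
  // Scenario 2: ModelType tells Build() which framework the input file uses.
  mindspore::Model model;
  auto context = std::make_shared<mindspore::Context>();
  if (model.Build("model_out.mindir", mindspore::kMindIR, context) != mindspore::kSuccess) {
    return -1;
  }
  return 0;
}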
@@ -59,8 +59,8 @@ class MS_API Converter {
 void SetOutputDataType(DataType data_type);
 DataType GetOutputDataType();

-void SetExportMindIR(ModelType export_mindir);
-ModelType GetExportMindIR() const;
+void SetSaveType(ModelType save_type);
+ModelType GetSaveType() const;

 inline void SetDecryptKey(const std::string &key);
 inline std::string GetDecryptKey() const;
@@ -44,7 +44,7 @@ enum MS_API FmkType : int {
 /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
 struct MS_API ConverterParameters {
 FmkType fmk;
-ModelType export_mindir = kMindIR_Lite;
+ModelType save_type = kMindIR_Lite;
 std::string model_file;
 std::string weight_file;
 std::map<std::string, std::string> attrs;
@@ -296,7 +296,7 @@ class Converter:
 if output_data_type != DataType.FLOAT32:
 self._converter.set_output_data_type(data_type_py_cxx_map.get(output_data_type))
 if save_type != ModelType.MINDIR_LITE:
-self._converter.set_export_mindir(model_type_py_cxx_map.get(save_type))
+self._converter.set_save_type(model_type_py_cxx_map.get(save_type))
 if decrypt_key != "":
 self._converter.set_decrypt_key(decrypt_key)
 self._converter.set_decrypt_mode(decrypt_mode)
@@ -330,7 +330,7 @@ class Converter:
 f"input_format: {format_cxx_py_map.get(self._converter.get_input_format())},\n" \
 f"input_data_type: {data_type_cxx_py_map.get(self._converter.get_input_data_type())},\n" \
 f"output_data_type: {data_type_cxx_py_map.get(self._converter.get_output_data_type())},\n" \
-f"save_type: {model_type_cxx_py_map.get(self._converter.get_export_mindir())},\n" \
+f"save_type: {model_type_cxx_py_map.get(self._converter.get_save_type())},\n" \
 f"decrypt_key: {self._converter.get_decrypt_key()},\n" \
 f"decrypt_mode: {self._converter.get_decrypt_mode()},\n" \
 f"enable_encryption: {self._converter.get_enable_encryption()},\n" \
@@ -48,8 +48,8 @@ void ConverterPyBind(const py::module &m) {
 .def("get_input_data_type", &Converter::GetInputDataType)
 .def("set_output_data_type", &Converter::SetOutputDataType)
 .def("get_output_data_type", &Converter::GetOutputDataType)
-.def("set_export_mindir", &Converter::SetExportMindIR)
-.def("get_export_mindir", &Converter::GetExportMindIR)
+.def("set_save_type", &Converter::SetSaveType)
+.def("get_save_type", &Converter::GetSaveType)
 .def("set_decrypt_key", py::overload_cast<const std::string &>(&Converter::SetDecryptKey))
 .def("get_decrypt_key", &Converter::GetDecryptKey)
 .def("set_decrypt_mode", py::overload_cast<const std::string &>(&Converter::SetDecryptMode))
@@ -34,7 +34,7 @@ int RuntimeConvert(const mindspore::api::FuncGraphPtr &graph, const std::shared_
 param->output_data_type = mindspore::DataType::kTypeUnknown;
 param->weight_fp16 = false;
 param->train_model = false;
-param->export_mindir = mindspore::kMindIR;
+param->save_type = mindspore::kMindIR;
 param->enable_encryption = false;
 param->is_runtime_converter = true;
@@ -100,7 +100,7 @@ bool LiteRTGraphExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map
 if (fb_model_buf_ == nullptr) {
 auto param = std::make_shared<ConverterPara>();
 param->fmk_type = converter::kFmkTypeMs;
-param->export_mindir = kMindIR;
+param->save_type = kMindIR;
 auto mutable_graph = std::const_pointer_cast<FuncGraph>(graph);
 meta_graph = lite::ConverterToMetaGraph::Build(param, mutable_graph);
 if (meta_graph == nullptr) {
@@ -150,7 +150,7 @@ STATUS PreProcForOnnx(const FuncGraphPtr &func_graph, bool offline) {
 AclPassImpl::AclPassImpl(const std::shared_ptr<ConverterPara> &param)
 : param_(param),
 fmk_type_(param->fmk_type),
-export_mindir_(param->export_mindir),
+export_mindir_(param->save_type),
 user_options_cfg_(std::move(param->aclModelOptionCfgParam)),
 om_parameter_(nullptr),
 custom_node_(nullptr) {}
@@ -693,7 +693,7 @@ STATUS AclPassImpl::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, cons
 STATUS AclPassImpl::PreQuantization(const FuncGraphPtr &func_graph) {
 auto value = func_graph->get_attr(ops::kFormat);
 if (value == nullptr) {
-auto unify_format = std::make_shared<lite::UnifyFormatToNHWC>(fmk_type_, false, param_->export_mindir);
+auto unify_format = std::make_shared<lite::UnifyFormatToNHWC>(fmk_type_, false, param_->save_type);
 CHECK_NULL_RETURN(unify_format);
 if (!unify_format->Run(func_graph)) {
 MS_LOG(ERROR) << "Run insert transpose failed.";
@@ -461,7 +461,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::share
 }
 }
 // adjust for conv2d_transpose
-if (!(param->no_fusion && param->export_mindir == kMindIR)) {
+if (!(param->no_fusion && param->save_type == kMindIR)) {
 std::set<FuncGraphPtr> all_func_graphs = {};
 GetAllFuncGraph(old_graph, &all_func_graphs);
 auto conv2d_transpose_adjust = std::make_shared<Conv2DTransposeInputAdjust>();
@@ -563,7 +563,7 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const std::shared_pt
 }

 int AnfTransform::DoFormatForMindIR(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
-if (param->export_mindir != kMindIR) {
+if (param->save_type != kMindIR) {
 return RET_OK;
 }
 if (param->no_fusion || param->device.find("Ascend") == std::string::npos) {
@@ -709,7 +709,7 @@ int AnfTransform::RunPass(const FuncGraphPtr &old_graph, const std::shared_ptr<C
 STATUS AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
 MS_ASSERT(old_graph != nullptr);
 MS_ASSERT(param != nullptr);
-if (param->no_fusion && param->export_mindir == kMindIR) {  // converter, online
+if (param->no_fusion && param->save_type == kMindIR) {  // converter, online
 if (ProcOnlineTransform(old_graph, param) != lite::RET_OK) {
 MS_LOG(ERROR) << "Proc online transform failed.";
 return RET_ERROR;
@@ -622,12 +622,11 @@ int CheckInputOutputDataType(const std::shared_ptr<ConverterPara> &param) {
 return RET_OK;
 }

-int CheckExportMindIR(const std::shared_ptr<ConverterPara> &param) {
+int CheckSaveType(const std::shared_ptr<ConverterPara> &param) {
 if (param != nullptr) {
 std::set valid_values = {kMindIR, kMindIR_Lite};
-if (std::find(valid_values.begin(), valid_values.end(), param->export_mindir) == valid_values.end()) {
-MS_LOG(ERROR) << "INPUT ILLEGAL: export_mindir is not in {kMindIR, kMindIR_Lite}, but got "
-<< param->export_mindir;
+if (std::find(valid_values.begin(), valid_values.end(), param->save_type) == valid_values.end()) {
+MS_LOG(ERROR) << "INPUT ILLEGAL: save_type is not in {kMindIR, kMindIR_Lite}, but got " << param->save_type;
 return RET_INPUT_PARAM_INVALID;
 }
 }
@@ -722,9 +721,9 @@ int CheckValueParam(const std::shared_ptr<ConverterPara> &param) {
 return RET_INPUT_PARAM_INVALID;
 }

-ret = CheckExportMindIR(param);
+ret = CheckSaveType(param);
 if (ret != RET_OK) {
-MS_LOG(ERROR) << "Check value of export_mindir failed.";
+MS_LOG(ERROR) << "Check value of save_type failed.";
 return RET_INPUT_PARAM_INVALID;
 }
@@ -812,7 +811,7 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, void **m
 int ConverterImpl::SaveGraph(FuncGraphPtr graph, const std::shared_ptr<ConverterPara> &param, void **model_data,
 size_t *data_size, bool not_save) {
 int status = RET_ERROR;
-if (param->export_mindir == kMindIR) {
+if (param->save_type == kMindIR) {
 status = SaveMindIRModel(graph, param, model_data, data_size);
 if (status != RET_OK) {
 MS_LOG(ERROR) << "Save mindir model failed :" << status << " " << GetErrorInfo(status);
@@ -74,7 +74,7 @@ FuncGraphPtr ConverterFuncGraph::Load3rdModelToFuncgraph(const std::shared_ptr<C
 }
 converter::ConverterParameters converter_parameters;
 converter_parameters.fmk = param->fmk_type;
-converter_parameters.export_mindir = param->export_mindir;
+converter_parameters.save_type = param->save_type;
 converter_parameters.model_file = param->model_file;
 converter_parameters.weight_file = param->weight_file;
 func_graph_base = model_parser->Parse(converter_parameters);
@@ -196,8 +196,7 @@ STATUS ConverterFuncGraph::UnifyFuncGraphForInfer(const std::shared_ptr<Converte
 }
 }

-auto unify_format =
-std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, param->train_model, param->export_mindir);
+auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, param->train_model, param->save_type);
 MS_CHECK_TRUE_MSG(unify_format != nullptr, RET_NULL_PTR, "unify_format is nullptr.");
 if (!unify_format->Run(func_graph)) {
 MS_LOG(ERROR) << "Run insert transpose failed.";
@@ -219,7 +218,7 @@ STATUS ConverterFuncGraph::UnifyFuncGraphInputFormat(const std::shared_ptr<Conve
 auto spec_input_format = param->spec_input_format;
 if (spec_input_format == DEFAULT_FORMAT) {
-if (param->export_mindir == kMindIR || param->fmk_type != converter::kFmkTypeMs) {
+if (param->save_type == kMindIR || param->fmk_type != converter::kFmkTypeMs) {
 // if it saves to mindir, the input format must be the same as the original model
 // if it saves to mindir lite, the input format must be the same as the original model for 3rd model
 func_graph->set_attr(kInputFormat, MakeValue(static_cast<int>(cur_input_format)));
@@ -83,12 +83,15 @@ Flags::Flags() {
 "Whether to do pre-inference after convert. "
 "true | false",
 "false");
-AddFlag(&Flags::exportMindIR, "exportMindIR", "MINDIR | MINDIR_LITE", "MINDIR_LITE");
 AddFlag(&Flags::noFusionStr, "NoFusion",
 "Avoid fusion optimization true|false. NoFusion is true when saveType is MINDIR.", "");
 AddFlag(&Flags::device, "device",
 "Set the target device, support Ascend, Ascend310 and Ascend310P will be deprecated.", "");
+#if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
+AddFlag(&Flags::saveTypeStr, "saveType", "The type of saved model. MINDIR | MINDIR_LITE", "MINDIR");
+#else
 AddFlag(&Flags::saveTypeStr, "saveType", "The type of saved model. MINDIR | MINDIR_LITE", "MINDIR_LITE");
+#endif
 AddFlag(&Flags::optimizeStr, "optimize", "The type of optimization. none | general | ascend_oriented", "general");
 AddFlag(&Flags::optimizeTransformerStr, "optimizeTransformer", "Enable Fast-Transformer fusion true|false", "false");
 }
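With the exportMindIR flag gone, the save format is driven solely by --saveType. A hedged command-line sketch of how the renamed flag would typically be exercised (the --fmk, --modelFile and --outputFile flags are assumed from the standard converter_lite usage and are not shown in this diff):

./converter_lite --fmk=TFLITE --modelFile=model.tflite --outputFile=model_out --saveType=MINDIR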
@@ -266,22 +269,6 @@ int Flags::InitOptimize() {
 return RET_OK;
 }

-int Flags::InitExportMindIR() {
-// value check not here, it is in converter c++ API's CheckValueParam method.
-std::map<std::string, ModelType> StrToEnumModelTypeMap = {{"MINDIR", kMindIR}, {"MINDIR_LITE", kMindIR_Lite}};
-if (StrToEnumModelTypeMap.find(this->exportMindIR) != StrToEnumModelTypeMap.end()) {
-this->export_mindir = StrToEnumModelTypeMap.at(this->exportMindIR);
-} else {
-std::cerr << "INPUT ILLEGAL: exportMindIR must be MINDIR|MINDIR_LITE " << std::endl;
-return RET_INPUT_PARAM_INVALID;
-}
-
-if ((this->exportMindIR == "MINDIR") && (this->optimizeTransformer == false)) {
-this->disableFusion = true;
-}
-return RET_OK;
-}
-
 int Flags::InitEncrypt() {
 if (this->encryptionStr == "true") {
 this->encryption = true;
@@ -295,13 +282,9 @@ int Flags::InitEncrypt() {
 }

 int Flags::InitSaveType() {
-// For compatibility of interface, the check will be removed when exportMindIR is deleted
-if (this->exportMindIR == "MINDIR") {
-return RET_OK;
-}
 std::map<std::string, ModelType> StrToEnumModelTypeMap = {{"MINDIR", kMindIR}, {"MINDIR_LITE", kMindIR_Lite}};
 if (StrToEnumModelTypeMap.find(this->saveTypeStr) != StrToEnumModelTypeMap.end()) {
-this->export_mindir = StrToEnumModelTypeMap.at(this->saveTypeStr);
+this->save_type = StrToEnumModelTypeMap.at(this->saveTypeStr);
 } else {
 std::cerr << "INPUT ILLEGAL: saveType must be MINDIR|MINDIR_LITE " << std::endl;
 return RET_INPUT_PARAM_INVALID;
@@ -407,12 +390,6 @@ int Flags::Init(int argc, const char **argv) {
 return RET_INPUT_PARAM_INVALID;
 }

-ret = InitExportMindIR();
-if (ret != RET_OK) {
-std::cerr << "Init export mindir failed." << std::endl;
-return RET_INPUT_PARAM_INVALID;
-}
-
 ret = InitSaveType();
 if (ret != RET_OK) {
 std::cerr << "Init save type failed." << std::endl;
@@ -41,7 +41,6 @@ class Flags : public virtual mindspore::lite::FlagParser {
 int InitSaveFP16();
 int InitNoFusion();
 int InitOptimize();
-int InitExportMindIR();
 int InitSaveType();
 int InitOptimizeTransformer();
@@ -74,10 +73,12 @@ class Flags : public virtual mindspore::lite::FlagParser {
 std::string encMode = "AES-GCM";
 std::string inferStr;
 bool infer = false;
-std::string exportMindIR;
-ModelType export_mindir = kMindIR_Lite;
 std::string saveTypeStr;
+#if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
+ModelType save_type = kMindIR;
+#else
 ModelType save_type = kMindIR_Lite;
+#endif
 std::string optimizeStr;
 #ifdef ENABLE_OPENSSL
 std::string encryptionStr = "true";
@@ -57,7 +57,7 @@ int main(int argc, const char **argv) {
 converter.SetInputFormat(flags.graphInputFormat);
 converter.SetInputDataType(flags.inputDataType);
 converter.SetOutputDataType(flags.outputDataType);
-converter.SetExportMindIR(flags.export_mindir);
+converter.SetSaveType(flags.save_type);
 converter.SetDecryptKey(flags.dec_key);
 flags.dec_key.clear();
 converter.SetDecryptMode(flags.dec_mode);
@@ -158,15 +158,15 @@ DataType Converter::GetOutputDataType() {
 }
 }

-void Converter::SetExportMindIR(ModelType export_mindir) {
+void Converter::SetSaveType(ModelType save_type) {
 if (data_ != nullptr) {
-data_->export_mindir = export_mindir;
+data_->save_type = save_type;
 }
 }

-ModelType Converter::GetExportMindIR() const {
+ModelType Converter::GetSaveType() const {
 if (data_ != nullptr) {
-return data_->export_mindir;
+return data_->save_type;
 } else {
 return kMindIR_Lite;
 }
@@ -54,7 +54,11 @@ struct ConverterPara {
 Format spec_input_format = DEFAULT_FORMAT;
 DataType input_data_type = DataType::kNumberTypeFloat32;
 DataType output_data_type = DataType::kNumberTypeFloat32;
-ModelType export_mindir = kMindIR_Lite;
+#if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
+ModelType save_type = kMindIR;
+#else
+ModelType save_type = kMindIR_Lite;
+#endif
 std::string decrypt_key;
 std::string decrypt_mode = "AES-GCM";
 std::string encrypt_key;
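The runtime-convert call sites earlier in this diff populate this struct directly; a condensed sketch of that pattern (it assumes the surrounding converter headers and namespace, and that graph is a FuncGraphPtr already in hand):

// Hedged sketch mirroring the RuntimeConvert / CompileGraph hunks above.
void BuildMetaGraphSketch(const FuncGraphPtr &graph) {
  auto param = std::make_shared<ConverterPara>();
  param->fmk_type = converter::kFmkTypeMs;
  param->save_type = kMindIR;  // renamed from export_mindir; kMindIR_Lite stays the non-cloud default
  auto mutable_graph = std::const_pointer_cast<FuncGraph>(graph);
  auto meta_graph = lite::ConverterToMetaGraph::Build(param, mutable_graph);
  if (meta_graph == nullptr) {
    MS_LOG(ERROR) << "Convert FuncGraph to MetaGraph failed.";
  }
}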
@@ -169,7 +169,7 @@ api::FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &
 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
 return nullptr;
 }
-auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeCaffe, false, flag.export_mindir);
+auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeCaffe, false, flag.save_type);
 MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
 if (!unify_format->Run(graph)) {
 MS_LOG(ERROR) << "Run insert transpose failed.";
@@ -487,7 +487,7 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph, const converter::Co
 }
 if (opt::CheckPrimitiveType(node, prim::kPrimConstant)) {
 status = ReplaceConstant(func_graph, cnode);
-} else if (opt::CheckPrimitiveType(node, prim::kPrimTranspose) && flag.export_mindir != kMindIR) {
+} else if (opt::CheckPrimitiveType(node, prim::kPrimTranspose) && flag.save_type != kMindIR) {
 status = ReplaceTransposeWithGraphInput(func_graph, cnode);
 } else if (opt::CheckPrimitiveType(node, prim::kPrimStridedSlice)) {
 status = AdjustStridedSlice(func_graph, cnode);
@@ -653,7 +653,7 @@ api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &f
 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
 return nullptr;
 }
-auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false, flag.export_mindir);
+auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false, flag.save_type);
 MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
 if (!unify_format->Run(graph)) {
 MS_LOG(ERROR) << "Run insert transpose failed.";
@@ -74,7 +74,7 @@ api::FuncGraphPtr PytorchModelParser::Parse(const converter::ConverterParameters
 return nullptr;
 }

-auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypePytorch, false, flag.export_mindir);
+auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypePytorch, false, flag.save_type);
 MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
 if (!unify_format->Run(anf_graph)) {
 MS_LOG(ERROR) << "Run insert transpose failed.";
@@ -635,7 +635,7 @@ api::FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &fla
 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
 return nullptr;
 }
-auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false, flag.export_mindir);
+auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false, flag.save_type);
 MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
 if (!unify_format->Run(graph)) {
 MS_LOG(ERROR) << "Run insert transpose failed.";
@@ -207,7 +207,7 @@ api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters
 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
 return nullptr;
 }
-auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false, flag.export_mindir);
+auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false, flag.save_type);
 MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
 if (!unify_format->Run(func_graph)) {
 MS_LOG(ERROR) << "Run insert transpose failed.";
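All of the parser hunks above converge on the same pattern; condensed, it looks roughly like this (a sketch only: flag stands for the converter::ConverterParameters argument, graph for the parsed FuncGraph, and the bail-out return is assumed from the parsers' usual error handling):

// Shared parser pattern: the format-unification pass now receives save_type.
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false, flag.save_type);
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
if (!unify_format->Run(graph)) {
  MS_LOG(ERROR) << "Run insert transpose failed.";
  return nullptr;  // parser bails out on failure (sketch)
}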
@@ -25,8 +25,8 @@ namespace lite {
 class UnifyFormatToNHWC : public opt::ToFormatBase {
 public:
 explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
-mindspore::ModelType export_mindir = kMindIR)
-: ToFormatBase(fmk_type, train_flag, export_mindir) {}
+mindspore::ModelType save_type = kMindIR)
+: ToFormatBase(fmk_type, train_flag, save_type) {}
 ~UnifyFormatToNHWC() override = default;
 bool Run(const FuncGraphPtr &func_graph) override;
@@ -118,7 +118,7 @@ void GraphKernelOptimizer::Run(const FuncGraphPtr &func_graph) {
 (void)pm_list.emplace_back(BuildKernel());

 for (auto &pm : pm_list) {
-pm->SetDumpIr(converter_param_->export_mindir);
+pm->SetDumpIr(converter_param_->save_type);
 optimizer->AddPassManager(pm);
 }
@@ -344,7 +344,7 @@ bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph)
 return false;
 }
 }
-if (main_graph && export_mindir_ != kMindIR) {
+if (main_graph && save_type_ != kMindIR) {
 status = HandleGraphInput(func_graph);
 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
 MS_LOG(ERROR) << "handle graph input failed.";
@@ -32,8 +32,8 @@ namespace opt {
 class ToFormatBase : public Pass {
 public:
 explicit ToFormatBase(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
-ModelType export_mindir = kMindIR, const std::string &pass_name = "ToFormatBase")
-: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag), export_mindir_(export_mindir) {}
+ModelType save_type = kMindIR, const std::string &pass_name = "ToFormatBase")
+: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag), save_type_(save_type) {}
 ~ToFormatBase() override = default;
 bool Run(const FuncGraphPtr &func_graph) override;
 static bool IsConvFamilyNode(const AnfNodePtr &node) {
@@ -68,7 +68,7 @@ class ToFormatBase : public Pass {
 schema::Format *dst_format) = 0;
 FmkType fmk_type_{converter::kFmkTypeMs};
 bool train_flag_{false};
-ModelType export_mindir_ = kMindIR_Lite;
+ModelType save_type_ = kMindIR_Lite;
 mindspore::Format format_{mindspore::NHWC};
 std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
 std::unordered_map<std::string, std::vector<size_t>> sensitive_ops_;
@@ -24,8 +24,8 @@ namespace opt {
 class ToNCHWFormat : public ToFormatBase {
 public:
 explicit ToNCHWFormat(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
-ModelType export_mindir = kMindIR)
-: ToFormatBase(fmk_type, train_flag, export_mindir, "ToNCHWFormat") {
+ModelType save_type = kMindIR)
+: ToFormatBase(fmk_type, train_flag, save_type, "ToNCHWFormat") {
 format_ = mindspore::NCHW;
 }
 ~ToNCHWFormat() = default;
@@ -24,8 +24,8 @@ namespace opt {
 class ToNHWCFormat : public ToFormatBase {
 public:
 explicit ToNHWCFormat(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
-ModelType export_mindir = kMindIR)
-: ToFormatBase(fmk_type, train_flag, export_mindir, "ToNHWCFormat") {}
+ModelType save_type = kMindIR)
+: ToFormatBase(fmk_type, train_flag, save_type, "ToNHWCFormat") {}
 ~ToNHWCFormat() = default;

 protected: