delete exportMindIR para

This commit is contained in:
zhou_chao1993 2023-02-20 09:51:10 +08:00
parent d79e3079ff
commit 0bd549f328
28 changed files with 60 additions and 80 deletions

View File

@ -7,7 +7,7 @@ mindspore_lite.ModelType
适用于以下场景:
1. Converter时设置 `export_mindir` 参数, `ModelType` 用于定义转换生成的模型类型。
1. Converter时设置 `save_type` 参数, `ModelType` 用于定义转换生成的模型类型。
2. Converter之后当从文件加载或构建模型以进行推理时 `ModelType` 用于定义输入模型框架类型。

View File

@ -59,8 +59,8 @@ class MS_API Converter {
void SetOutputDataType(DataType data_type);
DataType GetOutputDataType();
void SetExportMindIR(ModelType export_mindir);
ModelType GetExportMindIR() const;
void SetSaveType(ModelType save_type);
ModelType GetSaveType() const;
inline void SetDecryptKey(const std::string &key);
inline std::string GetDecryptKey() const;

View File

@ -44,7 +44,7 @@ enum MS_API FmkType : int {
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
struct MS_API ConverterParameters {
FmkType fmk;
ModelType export_mindir = kMindIR_Lite;
ModelType save_type = kMindIR_Lite;
std::string model_file;
std::string weight_file;
std::map<std::string, std::string> attrs;

View File

@ -296,7 +296,7 @@ class Converter:
if output_data_type != DataType.FLOAT32:
self._converter.set_output_data_type(data_type_py_cxx_map.get(output_data_type))
if save_type != ModelType.MINDIR_LITE:
self._converter.set_export_mindir(model_type_py_cxx_map.get(save_type))
self._converter.set_save_type(model_type_py_cxx_map.get(save_type))
if decrypt_key != "":
self._converter.set_decrypt_key(decrypt_key)
self._converter.set_decrypt_mode(decrypt_mode)
@ -330,7 +330,7 @@ class Converter:
f"input_format: {format_cxx_py_map.get(self._converter.get_input_format())},\n" \
f"input_data_type: {data_type_cxx_py_map.get(self._converter.get_input_data_type())},\n" \
f"output_data_type: {data_type_cxx_py_map.get(self._converter.get_output_data_type())},\n" \
f"save_type: {model_type_cxx_py_map.get(self._converter.get_export_mindir())},\n" \
f"save_type: {model_type_cxx_py_map.get(self._converter.get_save_type())},\n" \
f"decrypt_key: {self._converter.get_decrypt_key()},\n" \
f"decrypt_mode: {self._converter.get_decrypt_mode()},\n" \
f"enable_encryption: {self._converter.get_enable_encryption()},\n" \

View File

@ -48,8 +48,8 @@ void ConverterPyBind(const py::module &m) {
.def("get_input_data_type", &Converter::GetInputDataType)
.def("set_output_data_type", &Converter::SetOutputDataType)
.def("get_output_data_type", &Converter::GetOutputDataType)
.def("set_export_mindir", &Converter::SetExportMindIR)
.def("get_export_mindir", &Converter::GetExportMindIR)
.def("set_save_type", &Converter::SetSaveType)
.def("get_save_type", &Converter::GetSaveType)
.def("set_decrypt_key", py::overload_cast<const std::string &>(&Converter::SetDecryptKey))
.def("get_decrypt_key", &Converter::GetDecryptKey)
.def("set_decrypt_mode", py::overload_cast<const std::string &>(&Converter::SetDecryptMode))

View File

@ -34,7 +34,7 @@ int RuntimeConvert(const mindspore::api::FuncGraphPtr &graph, const std::shared_
param->output_data_type = mindspore::DataType::kTypeUnknown;
param->weight_fp16 = false;
param->train_model = false;
param->export_mindir = mindspore::kMindIR;
param->save_type = mindspore::kMindIR;
param->enable_encryption = false;
param->is_runtime_converter = true;

View File

@ -100,7 +100,7 @@ bool LiteRTGraphExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map
if (fb_model_buf_ == nullptr) {
auto param = std::make_shared<ConverterPara>();
param->fmk_type = converter::kFmkTypeMs;
param->export_mindir = kMindIR;
param->save_type = kMindIR;
auto mutable_graph = std::const_pointer_cast<FuncGraph>(graph);
meta_graph = lite::ConverterToMetaGraph::Build(param, mutable_graph);
if (meta_graph == nullptr) {

View File

@ -150,7 +150,7 @@ STATUS PreProcForOnnx(const FuncGraphPtr &func_graph, bool offline) {
AclPassImpl::AclPassImpl(const std::shared_ptr<ConverterPara> &param)
: param_(param),
fmk_type_(param->fmk_type),
export_mindir_(param->export_mindir),
export_mindir_(param->save_type),
user_options_cfg_(std::move(param->aclModelOptionCfgParam)),
om_parameter_(nullptr),
custom_node_(nullptr) {}
@ -693,7 +693,7 @@ STATUS AclPassImpl::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, cons
STATUS AclPassImpl::PreQuantization(const FuncGraphPtr &func_graph) {
auto value = func_graph->get_attr(ops::kFormat);
if (value == nullptr) {
auto unify_format = std::make_shared<lite::UnifyFormatToNHWC>(fmk_type_, false, param_->export_mindir);
auto unify_format = std::make_shared<lite::UnifyFormatToNHWC>(fmk_type_, false, param_->save_type);
CHECK_NULL_RETURN(unify_format);
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";

View File

@ -461,7 +461,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::share
}
}
// adjust for conv2d_transpose
if (!(param->no_fusion && param->export_mindir == kMindIR)) {
if (!(param->no_fusion && param->save_type == kMindIR)) {
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(old_graph, &all_func_graphs);
auto conv2d_transpose_adjust = std::make_shared<Conv2DTransposeInputAdjust>();
@ -563,7 +563,7 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const std::shared_pt
}
int AnfTransform::DoFormatForMindIR(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
if (param->export_mindir != kMindIR) {
if (param->save_type != kMindIR) {
return RET_OK;
}
if (param->no_fusion || param->device.find("Ascend") == std::string::npos) {
@ -709,7 +709,7 @@ int AnfTransform::RunPass(const FuncGraphPtr &old_graph, const std::shared_ptr<C
STATUS AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
MS_ASSERT(old_graph != nullptr);
MS_ASSERT(param != nullptr);
if (param->no_fusion && param->export_mindir == kMindIR) { // converter, online
if (param->no_fusion && param->save_type == kMindIR) { // converter, online
if (ProcOnlineTransform(old_graph, param) != lite::RET_OK) {
MS_LOG(ERROR) << "Proc online transform failed.";
return RET_ERROR;

View File

@ -622,12 +622,11 @@ int CheckInputOutputDataType(const std::shared_ptr<ConverterPara> &param) {
return RET_OK;
}
int CheckExportMindIR(const std::shared_ptr<ConverterPara> &param) {
int CheckSaveType(const std::shared_ptr<ConverterPara> &param) {
if (param != nullptr) {
std::set valid_values = {kMindIR, kMindIR_Lite};
if (std::find(valid_values.begin(), valid_values.end(), param->export_mindir) == valid_values.end()) {
MS_LOG(ERROR) << "INPUT ILLEGAL: export_mindir is not in {kMindIR, kMindIR_Lite}, but got "
<< param->export_mindir;
if (std::find(valid_values.begin(), valid_values.end(), param->save_type) == valid_values.end()) {
MS_LOG(ERROR) << "INPUT ILLEGAL: save_type is not in {kMindIR, kMindIR_Lite}, but got " << param->save_type;
return RET_INPUT_PARAM_INVALID;
}
}
@ -722,9 +721,9 @@ int CheckValueParam(const std::shared_ptr<ConverterPara> &param) {
return RET_INPUT_PARAM_INVALID;
}
ret = CheckExportMindIR(param);
ret = CheckSaveType(param);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Check value of export_mindir failed.";
MS_LOG(ERROR) << "Check value of save_type failed.";
return RET_INPUT_PARAM_INVALID;
}
@ -812,7 +811,7 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, void **m
int ConverterImpl::SaveGraph(FuncGraphPtr graph, const std::shared_ptr<ConverterPara> &param, void **model_data,
size_t *data_size, bool not_save) {
int status = RET_ERROR;
if (param->export_mindir == kMindIR) {
if (param->save_type == kMindIR) {
status = SaveMindIRModel(graph, param, model_data, data_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Save mindir model failed :" << status << " " << GetErrorInfo(status);

View File

@ -74,7 +74,7 @@ FuncGraphPtr ConverterFuncGraph::Load3rdModelToFuncgraph(const std::shared_ptr<C
}
converter::ConverterParameters converter_parameters;
converter_parameters.fmk = param->fmk_type;
converter_parameters.export_mindir = param->export_mindir;
converter_parameters.save_type = param->save_type;
converter_parameters.model_file = param->model_file;
converter_parameters.weight_file = param->weight_file;
func_graph_base = model_parser->Parse(converter_parameters);
@ -196,8 +196,7 @@ STATUS ConverterFuncGraph::UnifyFuncGraphForInfer(const std::shared_ptr<Converte
}
}
auto unify_format =
std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, param->train_model, param->export_mindir);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, param->train_model, param->save_type);
MS_CHECK_TRUE_MSG(unify_format != nullptr, RET_NULL_PTR, "unify_format is nullptr.");
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
@ -219,7 +218,7 @@ STATUS ConverterFuncGraph::UnifyFuncGraphInputFormat(const std::shared_ptr<Conve
auto spec_input_format = param->spec_input_format;
if (spec_input_format == DEFAULT_FORMAT) {
if (param->export_mindir == kMindIR || param->fmk_type != converter::kFmkTypeMs) {
if (param->save_type == kMindIR || param->fmk_type != converter::kFmkTypeMs) {
// if it saves to mindir, the input format must be the same as the original model
// if it saves to mindir lite, the input format must be the same as the original model for 3rd model
func_graph->set_attr(kInputFormat, MakeValue(static_cast<int>(cur_input_format)));

View File

@ -83,12 +83,15 @@ Flags::Flags() {
"Whether to do pre-inference after convert. "
"true | false",
"false");
AddFlag(&Flags::exportMindIR, "exportMindIR", "MINDIR | MINDIR_LITE", "MINDIR_LITE");
AddFlag(&Flags::noFusionStr, "NoFusion",
"Avoid fusion optimization true|false. NoFusion is true when saveType is MINDIR.", "");
AddFlag(&Flags::device, "device",
"Set the target device, support Ascend, Ascend310 and Ascend310P will be deprecated.", "");
#if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
AddFlag(&Flags::saveTypeStr, "saveType", "The type of saved model. MINDIR | MINDIR_LITE", "MINDIR");
#else
AddFlag(&Flags::saveTypeStr, "saveType", "The type of saved model. MINDIR | MINDIR_LITE", "MINDIR_LITE");
#endif
AddFlag(&Flags::optimizeStr, "optimize", "The type of optimization. none | general | ascend_oriented", "general");
AddFlag(&Flags::optimizeTransformerStr, "optimizeTransformer", "Enable Fast-Transformer fusion true|false", "false");
}
@ -266,22 +269,6 @@ int Flags::InitOptimize() {
return RET_OK;
}
int Flags::InitExportMindIR() {
// value check not here, it is in converter c++ API's CheckValueParam method.
std::map<std::string, ModelType> StrToEnumModelTypeMap = {{"MINDIR", kMindIR}, {"MINDIR_LITE", kMindIR_Lite}};
if (StrToEnumModelTypeMap.find(this->exportMindIR) != StrToEnumModelTypeMap.end()) {
this->export_mindir = StrToEnumModelTypeMap.at(this->exportMindIR);
} else {
std::cerr << "INPUT ILLEGAL: exportMindIR must be MINDIR|MINDIR_LITE " << std::endl;
return RET_INPUT_PARAM_INVALID;
}
if ((this->exportMindIR == "MINDIR") && (this->optimizeTransformer == false)) {
this->disableFusion = true;
}
return RET_OK;
}
int Flags::InitEncrypt() {
if (this->encryptionStr == "true") {
this->encryption = true;
@ -295,13 +282,9 @@ int Flags::InitEncrypt() {
}
int Flags::InitSaveType() {
// For compatibility of interface, the check will be removed when exportMindIR is deleted
if (this->exportMindIR == "MINDIR") {
return RET_OK;
}
std::map<std::string, ModelType> StrToEnumModelTypeMap = {{"MINDIR", kMindIR}, {"MINDIR_LITE", kMindIR_Lite}};
if (StrToEnumModelTypeMap.find(this->saveTypeStr) != StrToEnumModelTypeMap.end()) {
this->export_mindir = StrToEnumModelTypeMap.at(this->saveTypeStr);
this->save_type = StrToEnumModelTypeMap.at(this->saveTypeStr);
} else {
std::cerr << "INPUT ILLEGAL: saveType must be MINDIR|MINDIR_LITE " << std::endl;
return RET_INPUT_PARAM_INVALID;
@ -407,12 +390,6 @@ int Flags::Init(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID;
}
ret = InitExportMindIR();
if (ret != RET_OK) {
std::cerr << "Init export mindir failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
ret = InitSaveType();
if (ret != RET_OK) {
std::cerr << "Init save type failed." << std::endl;

View File

@ -41,7 +41,6 @@ class Flags : public virtual mindspore::lite::FlagParser {
int InitSaveFP16();
int InitNoFusion();
int InitOptimize();
int InitExportMindIR();
int InitSaveType();
int InitOptimizeTransformer();
@ -74,10 +73,12 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string encMode = "AES-GCM";
std::string inferStr;
bool infer = false;
std::string exportMindIR;
ModelType export_mindir = kMindIR_Lite;
std::string saveTypeStr;
#if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
ModelType save_type = kMindIR;
#else
ModelType save_type = kMindIR_Lite;
#endif
std::string optimizeStr;
#ifdef ENABLE_OPENSSL
std::string encryptionStr = "true";

View File

@ -57,7 +57,7 @@ int main(int argc, const char **argv) {
converter.SetInputFormat(flags.graphInputFormat);
converter.SetInputDataType(flags.inputDataType);
converter.SetOutputDataType(flags.outputDataType);
converter.SetExportMindIR(flags.export_mindir);
converter.SetSaveType(flags.save_type);
converter.SetDecryptKey(flags.dec_key);
flags.dec_key.clear();
converter.SetDecryptMode(flags.dec_mode);

View File

@ -158,15 +158,15 @@ DataType Converter::GetOutputDataType() {
}
}
void Converter::SetExportMindIR(ModelType export_mindir) {
void Converter::SetSaveType(ModelType save_type) {
if (data_ != nullptr) {
data_->export_mindir = export_mindir;
data_->save_type = save_type;
}
}
ModelType Converter::GetExportMindIR() const {
ModelType Converter::GetSaveType() const {
if (data_ != nullptr) {
return data_->export_mindir;
return data_->save_type;
} else {
return kMindIR_Lite;
}

View File

@ -54,7 +54,11 @@ struct ConverterPara {
Format spec_input_format = DEFAULT_FORMAT;
DataType input_data_type = DataType::kNumberTypeFloat32;
DataType output_data_type = DataType::kNumberTypeFloat32;
ModelType export_mindir = kMindIR_Lite;
#if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
ModelType save_type = kMindIR;
#else
ModelType save_type = kMindIR_Lite;
#endif
std::string decrypt_key;
std::string decrypt_mode = "AES-GCM";
std::string encrypt_key;

View File

@ -169,7 +169,7 @@ api::FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeCaffe, false, flag.export_mindir);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeCaffe, false, flag.save_type);
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
if (!unify_format->Run(graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";

View File

@ -487,7 +487,7 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph, const converter::Co
}
if (opt::CheckPrimitiveType(node, prim::kPrimConstant)) {
status = ReplaceConstant(func_graph, cnode);
} else if (opt::CheckPrimitiveType(node, prim::kPrimTranspose) && flag.export_mindir != kMindIR) {
} else if (opt::CheckPrimitiveType(node, prim::kPrimTranspose) && flag.save_type != kMindIR) {
status = ReplaceTransposeWithGraphInput(func_graph, cnode);
} else if (opt::CheckPrimitiveType(node, prim::kPrimStridedSlice)) {
status = AdjustStridedSlice(func_graph, cnode);

View File

@ -653,7 +653,7 @@ api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &f
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false, flag.export_mindir);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false, flag.save_type);
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
if (!unify_format->Run(graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";

View File

@ -74,7 +74,7 @@ api::FuncGraphPtr PytorchModelParser::Parse(const converter::ConverterParameters
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypePytorch, false, flag.export_mindir);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypePytorch, false, flag.save_type);
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
if (!unify_format->Run(anf_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";

View File

@ -635,7 +635,7 @@ api::FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &fla
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false, flag.export_mindir);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false, flag.save_type);
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
if (!unify_format->Run(graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";

View File

@ -207,7 +207,7 @@ api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false, flag.export_mindir);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false, flag.save_type);
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";

View File

@ -25,8 +25,8 @@ namespace lite {
class UnifyFormatToNHWC : public opt::ToFormatBase {
public:
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
mindspore::ModelType export_mindir = kMindIR)
: ToFormatBase(fmk_type, train_flag, export_mindir) {}
mindspore::ModelType save_type = kMindIR)
: ToFormatBase(fmk_type, train_flag, save_type) {}
~UnifyFormatToNHWC() override = default;
bool Run(const FuncGraphPtr &func_graph) override;

View File

@ -118,7 +118,7 @@ void GraphKernelOptimizer::Run(const FuncGraphPtr &func_graph) {
(void)pm_list.emplace_back(BuildKernel());
for (auto &pm : pm_list) {
pm->SetDumpIr(converter_param_->export_mindir);
pm->SetDumpIr(converter_param_->save_type);
optimizer->AddPassManager(pm);
}

View File

@ -344,7 +344,7 @@ bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph)
return false;
}
}
if (main_graph && export_mindir_ != kMindIR) {
if (main_graph && save_type_ != kMindIR) {
status = HandleGraphInput(func_graph);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "handle graph input failed.";

View File

@ -32,8 +32,8 @@ namespace opt {
class ToFormatBase : public Pass {
public:
explicit ToFormatBase(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
ModelType export_mindir = kMindIR, const std::string &pass_name = "ToFormatBase")
: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag), export_mindir_(export_mindir) {}
ModelType save_type = kMindIR, const std::string &pass_name = "ToFormatBase")
: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag), save_type_(save_type) {}
~ToFormatBase() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
static bool IsConvFamilyNode(const AnfNodePtr &node) {
@ -68,7 +68,7 @@ class ToFormatBase : public Pass {
schema::Format *dst_format) = 0;
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
ModelType export_mindir_ = kMindIR_Lite;
ModelType save_type_ = kMindIR_Lite;
mindspore::Format format_{mindspore::NHWC};
std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr};
std::unordered_map<std::string, std::vector<size_t>> sensitive_ops_;

View File

@ -24,8 +24,8 @@ namespace opt {
class ToNCHWFormat : public ToFormatBase {
public:
explicit ToNCHWFormat(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
ModelType export_mindir = kMindIR)
: ToFormatBase(fmk_type, train_flag, export_mindir, "ToNCHWFormat") {
ModelType save_type = kMindIR)
: ToFormatBase(fmk_type, train_flag, save_type, "ToNCHWFormat") {
format_ = mindspore::NCHW;
}
~ToNCHWFormat() = default;

View File

@ -24,8 +24,8 @@ namespace opt {
class ToNHWCFormat : public ToFormatBase {
public:
explicit ToNHWCFormat(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
ModelType export_mindir = kMindIR)
: ToFormatBase(fmk_type, train_flag, export_mindir, "ToNHWCFormat") {}
ModelType save_type = kMindIR)
: ToFormatBase(fmk_type, train_flag, save_type, "ToNHWCFormat") {}
~ToNHWCFormat() = default;
protected: