delete trainModel class

formatting

add include

fix arm32

delete useless code

use char * in external APIs

formatting

change virtual export to static
This commit is contained in:
lz 2021-04-01 18:55:11 +08:00
parent b276cf5aaf
commit 0517ffbff3
15 changed files with 158 additions and 283 deletions

View File

@ -99,7 +99,13 @@ void NetRunner::InitAndFigureInputs() {
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;
session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context);
model_ = mindspore::lite::Model::Import(ms_file_);
if (model_ == nullptr) {
MS_LOG(ERROR) << "import model failed";
return nullptr;
}
session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true);
MS_ASSERT(nullptr != session_);
loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_);
@ -154,7 +160,6 @@ int NetRunner::InitDB() {
std::cout << "No relevant data was found in " << data_dir_ << std::endl;
MS_ASSERT(train_ds_->GetDatasetSize() != 0);
}
return 0;
}
@ -182,7 +187,7 @@ int NetRunner::Main() {
if (epochs_ > 0) {
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
session_->SaveToFile(trained_fn);
Model::Export(model_, trained_fn);
}
return 0;
}

View File

@ -27,6 +27,7 @@
#include "include/train/accuracy_metrics.h"
#include "include/ms_tensor.h"
#include "include/datasets.h"
#include "include/model.h"
using mindspore::dataset::Dataset;
using mindspore::lite::AccuracyMetrics;
@ -36,6 +37,7 @@ class NetRunner {
int Main();
bool ReadArgs(int argc, char *argv[]);
~NetRunner();
mindspore::lite::Model *model_ = nullptr;
private:
void Usage();

View File

@ -45,13 +45,17 @@ struct MS_API Model {
SubGraphPtrVector sub_graphs_;
/// \brief Static method to create a Model pointer.
///
/// \param[in] model_buf Define the buffer read from a model file.
/// \param[in] size Define bytes number of model buffer.
///
/// \return Pointer of MindSpore Lite Model.
static Model *Import(const char *model_buf, size_t size);
/// \brief Static method to create a Model pointer.
static Model *Import(const char *filename);
/// \brief method to export model to file.
static int Export(Model *model, const char *filename);
/// \brief method to export model to buffer.
static int Export(Model *model, char *buf, size_t *size);
/// \brief Free meta graph temporary buffer
virtual void Free() = 0;

View File

@ -32,23 +32,12 @@ class TrainSession : public session::LiteSession {
/// \brief Static method to create a TrainSession object
///
/// \param[in] model_buf A buffer that was read from a MS model file
/// \param[in] size Length of the buffer
/// \param[in] model A buffer that was read from a MS model file
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
///
/// \return Pointer of MindSpore Lite TrainSession
static TrainSession *CreateSession(const char *model_buf, size_t size, lite::Context *context,
bool train_mode = false);
/// \brief Static method to create a TrainSession object
///
/// \param[in] filename Filename to read flatbuffer from
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
///
/// \return Pointer of MindSpore Lite TrainSession
static TrainSession *CreateSession(const std::string &filename, lite::Context *context, bool train_mode = false);
static TrainSession *CreateSession(mindspore::lite::Model *model, lite::Context *context, bool train_mode = false);
/// \brief Static method to create a transfer learning support TrainSession object
///
@ -75,21 +64,6 @@ class TrainSession : public session::LiteSession {
static TrainSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
lite::Context *context, bool train_mode = false);
/// \brief Export the trained model into a buffer
///
/// \param[in] buf The buffer to Export into. If equal to nullptr, buf will be allocated
/// \param[in,out] len Size of the pre-allocated buffer, and returned size of the exported buffer
///
/// \return pointer to the export buffer
virtual void *ExportToBuf(char *buf, size_t *len) const = 0;
/// \brief Save the trained model into a flatbuffer file
///
/// \param[in] filename Filename to save flatbuffer to
///
/// \return 0 on success or -1 in case of error
virtual int SaveToFile(const std::string &filename) const = 0;
/// \brief Set model to train mode
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
virtual int Train() = 0;

View File

@ -111,7 +111,6 @@ if(SUPPORT_TRAIN)
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/transfer_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_model.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/loss_monitor.cc

View File

@ -15,9 +15,13 @@
*/
#include "src/lite_model.h"
#include <sys/stat.h>
#include <iostream>
#include <fstream>
#include <vector>
#include <set>
#include <unordered_map>
#include <memory>
#include "src/common/prim_util.h"
#ifdef ENABLE_V0
#include "src/ops/compat/compat_register.h"
@ -343,5 +347,102 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
return model;
}
std::unique_ptr<char[]> ReadFileToBuf(const std::string &filename, size_t *size) {
std::ifstream ifs(filename);
if (!ifs.good()) {
MS_LOG(ERROR) << "File: " << filename << " does not exist";
return std::unique_ptr<char[]>(nullptr);
}
if (!ifs.is_open()) {
MS_LOG(ERROR) << "File: " << filename << " open failed";
return std::unique_ptr<char[]>(nullptr);
}
ifs.seekg(0, std::ios::end);
auto tellg_ret = ifs.tellg();
if (tellg_ret <= 0) {
MS_LOG(ERROR) << "Could not read file " << filename;
return std::unique_ptr<char[]>(nullptr);
}
size_t fsize = static_cast<size_t>(tellg_ret);
std::unique_ptr<char[]> buf(new (std::nothrow) char[fsize]);
if (buf == nullptr) {
MS_LOG(ERROR) << "malloc buf failed, file: " << filename;
ifs.close();
return std::unique_ptr<char[]>(nullptr);
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf.get(), fsize);
if (!ifs) {
MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename;
ifs.close();
return std::unique_ptr<char[]>(nullptr);
}
ifs.close();
if (size != nullptr) {
*size = fsize;
}
return buf;
}
Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
Model *Model::Import(const char *filename) {
size_t size = -1;
auto buf = ReadFileToBuf(filename, &size);
if (buf == nullptr) {
return nullptr;
}
return ImportFromBuffer(buf.get(), size, false);
}
int Model::Export(Model *model, char *buffer, size_t *len) {
if (len == nullptr) {
MS_LOG(ERROR) << "len is nullptr";
return RET_ERROR;
}
auto *liteModel = reinterpret_cast<LiteModel *>(model);
if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) {
MS_LOG(ERROR) << "model buffer is invalid";
return RET_ERROR;
}
if (*len < liteModel->buf_size_ && buffer != nullptr) {
MS_LOG(ERROR) << "Buffer is too small, Export Failed";
return RET_ERROR;
}
if (buffer == nullptr) {
buffer = reinterpret_cast<char *>(malloc(liteModel->buf_size_));
if (buffer == nullptr) {
MS_LOG(ERROR) << "allocated model buf fail!";
return RET_ERROR;
}
}
memcpy(buffer, liteModel->buf, liteModel->buf_size_);
*len = liteModel->buf_size_;
return RET_OK;
}
int Model::Export(Model *model, const char *filename) {
auto *liteModel = reinterpret_cast<LiteModel *>(model);
if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) {
MS_LOG(ERROR) << "model buf is invalid";
return RET_ERROR;
}
std::ofstream ofs(filename);
if (!ofs.good() || !ofs.is_open()) {
MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing";
return RET_ERROR;
}
ofs.seekp(0, std::ios::beg);
ofs.write(liteModel->buf, liteModel->buf_size_);
ofs.close();
return chmod(filename, S_IRUSR);
}
} // namespace mindspore::lite

View File

@ -1,91 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/train/train_model.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "src/common/graph_util.h"
namespace mindspore::lite {
TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model buf is nullptr";
return nullptr;
}
TrainModel *model = new (std::nothrow) TrainModel();
if (model == nullptr) {
MS_LOG(ERROR) << "new model fail!";
return nullptr;
}
model->buf = reinterpret_cast<char *>(malloc(size));
if (model->buf == nullptr) {
delete model;
MS_LOG(ERROR) << "malloc inner model buf fail!";
return nullptr;
}
memcpy(model->buf, model_buf, size);
model->buf_size_ = size;
auto status = model->ConstructModel();
if (status != RET_OK) {
MS_LOG(ERROR) << "construct model failed.";
delete model;
return nullptr;
}
return model;
}
void TrainModel::Free() {}
char *TrainModel::ExportBuf(char *buffer, size_t *len) const {
if (len == nullptr) {
MS_LOG(ERROR) << "len is nullptr";
return nullptr;
}
if (buf_size_ == 0 || buf == nullptr) {
MS_LOG(ERROR) << "Model::Export is only available for Train Session";
return nullptr;
}
if (*len < buf_size_ && buffer != nullptr) {
MS_LOG(ERROR) << "Buffer is too small, Export Failed";
return nullptr;
}
if (buffer == nullptr) {
buffer = reinterpret_cast<char *>(malloc(buf_size_));
if (buffer == nullptr) {
MS_LOG(ERROR) << "allocated model buf fail!";
return nullptr;
}
}
memcpy(buffer, buf, buf_size_);
*len = buf_size_;
return buffer;
}
char *TrainModel::GetBuffer(size_t *len) const {
if (len == nullptr) {
MS_LOG(ERROR) << "len is nullptr";
return nullptr;
}
if (buf_size_ == 0 || buf == nullptr) {
MS_LOG(ERROR) << "Model::Export is only available for Train Session";
return nullptr;
}
*len = buf_size_;
return buf;
}
} // namespace mindspore::lite

View File

@ -1,57 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_
#include <vector>
#include "src/lite_model.h"
namespace mindspore {
namespace lite {
/// \brief TrainModel Defines a class that allows to import and export a mindspore trainable model
struct TrainModel : public lite::LiteModel {
/// \brief Static method to create a TrainModel object
///
/// \param[in] model_buf A buffer that was read from a MS model file
/// \param[in] size Length of the buffer
//
/// \return Pointer to MindSpore Lite TrainModel
static TrainModel *Import(const char *model_buf, size_t size);
/// \brief Free meta graph related data
void Free() override;
/// \brief Class destructor, free all memory
virtual ~TrainModel() = default;
/// \brief Export Model into a buffer
///
/// \param[in] buf The buffer to Export into. If equal to nullptr, buf will be allocated
/// \param[in,out] len Size of the pre-allocated buffer, and returned size of the exported buffer
///
/// \return Pointer to buffer with exported model
char *ExportBuf(char *buf, size_t *len) const;
/// \brief Get Model buffer
///
/// \param[in,out] len Return size of the buffer
///
/// \return Pointer to model buffer
char *GetBuffer(size_t *len) const;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_

View File

@ -25,6 +25,7 @@
#include "include/errorcode.h"
#include "src/common/utils.h"
#include "src/tensor.h"
#include "src/lite_model.h"
#include "src/train/loss_kernel.h"
#include "src/train/optimizer_kernel.h"
#include "src/sub_graph_kernel.h"
@ -39,47 +40,6 @@
namespace mindspore {
namespace lite {
std::unique_ptr<char[]> ReadFileToBuf(const std::string &filename, size_t *size) {
std::ifstream ifs(filename);
if (!ifs.good()) {
MS_LOG(ERROR) << "File: " << filename << " does not exist";
return std::unique_ptr<char[]>(nullptr);
}
if (!ifs.is_open()) {
MS_LOG(ERROR) << "File: " << filename << " open failed";
return std::unique_ptr<char[]>(nullptr);
}
ifs.seekg(0, std::ios::end);
auto tellg_ret = ifs.tellg();
if (tellg_ret <= 0) {
MS_LOG(ERROR) << "Could not read file " << filename;
return std::unique_ptr<char[]>(nullptr);
}
size_t fsize = static_cast<size_t>(tellg_ret);
std::unique_ptr<char[]> buf(new (std::nothrow) char[fsize]);
if (buf == nullptr) {
MS_LOG(ERROR) << "malloc buf failed, file: " << filename;
ifs.close();
return std::unique_ptr<char[]>(nullptr);
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf.get(), fsize);
if (!ifs) {
MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename;
ifs.close();
return std::unique_ptr<char[]>(nullptr);
}
ifs.close();
if (size != nullptr) {
*size = fsize;
}
return buf;
}
static size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter) {
for (size_t i = 0; i < where.size(); i++) {
if (where[i] == searchParameter) {
@ -139,7 +99,7 @@ void TrainSession::AllocWorkSpace() {
int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }
int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) {
int TrainSession::CompileTrainGraph(mindspore::lite::Model *model) {
model_ = model;
auto restore = ReplaceOps();
@ -171,8 +131,6 @@ TrainSession::~TrainSession() {
}
}
void *TrainSession::ExportToBuf(char *buf, size_t *len) const { return model_->ExportBuf(buf, len); }
int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
this->outputs_.clear();
@ -214,25 +172,6 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a
return RET_OK;
}
int TrainSession::SaveToFile(const std::string &filename) const {
size_t fb_size = 0;
const auto *buf = reinterpret_cast<char *>(model_->GetBuffer(&fb_size));
if (buf == nullptr) {
MS_LOG(ERROR) << "Could not Export Trained model";
return lite::RET_NULL_PTR;
}
std::ofstream ofs(filename);
if ((true != ofs.good()) || (true != ofs.is_open())) {
MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing";
return RET_ERROR;
}
ofs.seekp(0, std::ios::beg);
ofs.write(buf, fb_size);
ofs.close();
return chmod(filename.c_str(), S_IRUSR);
}
int TrainSession::Train() {
// shift kernels to train mode
train_mode_ = true;
@ -522,14 +461,8 @@ int TrainSession::SetLossName(std::string loss_name) {
}
} // namespace lite
session::TrainSession *session::TrainSession::CreateSession(const char *model_buf, size_t size, lite::Context *context,
session::TrainSession *session::TrainSession::CreateSession(mindspore::lite::Model *model, lite::Context *context,
bool train_mode) {
auto model = mindspore::lite::TrainModel::Import(model_buf, size);
if (model == nullptr) {
MS_LOG(ERROR) << "create model for train session failed";
return nullptr;
}
auto session = new (std::nothrow) lite::TrainSession();
if (session == nullptr) {
delete model;
@ -564,15 +497,4 @@ session::TrainSession *session::TrainSession::CreateSession(const char *model_bu
return session;
}
session::TrainSession *session::TrainSession::CreateSession(const std::string &filename, lite::Context *context,
bool train_mode) {
size_t size = -1;
auto buf = lite::ReadFileToBuf(filename, &size);
if (buf == nullptr) {
return nullptr;
}
return session::TrainSession::CreateSession(buf.get(), size, context, train_mode);
}
} // namespace mindspore

View File

@ -21,7 +21,6 @@
#include <unordered_map>
#include <memory>
#include "include/train/train_session.h"
#include "src/train/train_model.h"
#include "src/lite_session.h"
/*
@ -52,10 +51,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
int CompileGraph(lite::Model *model) override;
virtual int CompileTrainGraph(lite::TrainModel *model);
void *ExportToBuf(char *buf, size_t *len) const override;
int SaveToFile(const std::string &filename) const override;
virtual int CompileTrainGraph(lite::Model *model);
int Train() override;
int Eval() override;
@ -108,7 +104,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
virtual void CompileTrainOutputs();
virtual void CompileEvalOutputs();
TrainModel *model_ = nullptr;
Model *model_ = nullptr;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_node_map_;
std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_;
std::vector<std::string> orig_output_tensor_names_;

View File

@ -190,7 +190,7 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char *
return nullptr;
}
auto model = lite::TrainModel::Import(model_buf_head, size_head);
auto model = lite::Model::Import(model_buf_head, size_head);
if (model == nullptr) {
MS_LOG(ERROR) << "create model for head train session failed";
delete session;

View File

@ -20,7 +20,6 @@
#include <tuple>
#include <unordered_map>
#include <utility>
#include "src/train/train_model.h"
#include "src/lite_session.h"
#include "src/train/train_session.h"

View File

@ -281,7 +281,6 @@ if(SUPPORT_TRAIN)
${LITE_DIR}/src/train/train_populate_parameter_v0.cc
${LITE_DIR}/src/train/train_session.cc
${LITE_DIR}/src/train/transfer_session.cc
${LITE_DIR}/src/train/train_model.cc
${LITE_DIR}/src/lite_session.cc
)
else()

View File

@ -359,10 +359,14 @@ TEST_F(NetworkTest, tuning_layer) {
meta_graph.reset();
content = nullptr;
auto *model = mindspore::lite::Model::Import(content, size);
ASSERT_NE(nullptr, model);
lite::Context context;
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context.thread_num_ = 1;
auto session = session::TrainSession::CreateSession(content, size, &context);
auto session = session::TrainSession::CreateSession(model, &context);
ASSERT_NE(nullptr, session);
session->Train();
session->Train(); // Just double check that calling Train twice does not cause a problem
@ -513,7 +517,10 @@ TEST_F(NetworkTest, efficient_net) {
context->thread_num_ = 1;
std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms";
auto session = session::TrainSession::CreateSession(net, context, false);
auto *model = mindspore::lite::Model::Import(net.c_str());
ASSERT_NE(model, nullptr);
auto session = session::TrainSession::CreateSession(model, context, false);
ASSERT_NE(session, nullptr);
std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin";
@ -530,7 +537,6 @@ TEST_F(NetworkTest, mobileface_net) {
std::string net = "./test_data/nets/mobilefacenet0924.ms";
ReadFile(net.c_str(), &net_size, &buf);
// auto model = lite::TrainModel::Import(buf, net_size);
auto model = lite::Model::Import(buf, net_size);
delete[] buf;
auto context = new lite::Context;
@ -538,7 +544,6 @@ TEST_F(NetworkTest, mobileface_net) {
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context->thread_num_ = 1;
// auto session = session::TrainSession::CreateSession(context);
auto session = session::LiteSession::CreateSession(context);
ASSERT_NE(session, nullptr);
auto ret = session->CompileGraph(model);
@ -560,7 +565,10 @@ TEST_F(NetworkTest, setname) {
lite::Context context;
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context.thread_num_ = 1;
auto session = mindspore::session::TrainSession::CreateSession(net, &context);
auto *model = mindspore::lite::Model::Import(net.c_str());
ASSERT_NE(model, nullptr);
auto session = mindspore::session::TrainSession::CreateSession(model, &context);
ASSERT_NE(session, nullptr);
auto tensors_map = session->GetOutputs();

View File

@ -25,6 +25,7 @@
#include "include/context.h"
#include "src/runtime/runtime_api.h"
#include "include/version.h"
#include "include/model.h"
namespace mindspore {
namespace lite {
@ -326,7 +327,14 @@ int NetTrain::RunExportedNet() {
}
context->thread_num_ = flags_->num_threads_;
session_ = session::TrainSession::CreateSession(flags_->export_file_.c_str(), context.get());
auto *model = mindspore::lite::Model::Import(flags_->export_file_.c_str());
if (model == nullptr) {
MS_LOG(ERROR) << "create model for train session failed";
return RET_ERROR;
}
session_ = session::TrainSession::CreateSession(model, context.get());
if (session_ == nullptr) {
MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl;
@ -388,7 +396,13 @@ int NetTrain::RunNetTrain() {
context->device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
layer_checksum_ = flags_->layer_checksum_;
context->thread_num_ = flags_->num_threads_;
session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get());
auto *model = mindspore::lite::Model::Import(flags_->model_file_.c_str());
if (model == nullptr) {
MS_LOG(ERROR) << "create model for train session failed";
return RET_ERROR;
}
session_ = session::TrainSession::CreateSession(model, context.get());
if (session_ == nullptr) {
MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str();
std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl;
@ -432,7 +446,7 @@ int NetTrain::RunNetTrain() {
}
}
if (!flags_->export_file_.empty()) {
auto ret = session_->SaveToFile(flags_->export_file_);
auto ret = Model::Export(model, flags_->export_file_.c_str());
if (ret != RET_OK) {
MS_LOG(ERROR) << "SaveToFile error";
std::cout << "Run SaveToFile error";