diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index d602917ed85..80d74e39714 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -81,7 +81,7 @@ class TrainSession : virtual public lite::LiteSession { return lite::LiteSession::GetOutputByTensorName(tensor_name); } int Resize(const std::vector &inputs, const std::vector> &dims) override { - return lite::RET_ERROR; + return lite::LiteSession::Resize(inputs, dims); } std::vector GetPredictions() const override { diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc index 542d89f060d..0104e4b189f 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.cc +++ b/mindspore/lite/tools/benchmark_train/net_train.cc @@ -35,6 +35,8 @@ namespace mindspore { namespace lite { static const char *DELIM_SLASH = "/"; +constexpr const char *DELIM_COLON = ":"; +constexpr const char *DELIM_COMMA = ","; constexpr int RET_TOO_BIG = -9; namespace { @@ -319,10 +321,78 @@ static CpuBindMode FlagToBindMode(int flag) { return NO_BIND; } +std::unique_ptr NetTrain::CreateAndRunNetworkForTrain(const std::string &filename, + const std::string &bb_filename, + const Context &context, + const TrainCfg &train_cfg, int epochs) { + std::unique_ptr session = nullptr; + std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1); + if (!bb_filename.empty()) { + MS_LOG(INFO) << "CreateTransferSession from models files" << filename << " and " << bb_filename; + std::cout << "CreateTranferSession from model file " << filename << " and " << bb_filename << std::endl; + session = std::unique_ptr( + session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg)); + if (session == nullptr) { + MS_LOG(ERROR) << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str(); + std::cout << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str() << std::endl; + return nullptr; + } + + } else { + MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str(); + std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl; + session = std::unique_ptr( + session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg)); + if (session == nullptr) { + MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str(); + std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl; + return nullptr; + } + } + if (epochs > 0) { + if (flags_->virtual_batch_) { + session->SetupVirtualBatch(epochs); + } + session->Train(); + } + return session; +} + +std::unique_ptr NetTrain::CreateAndRunNetworkForInference(const std::string &filename, + const Context &context) { + std::unique_ptr session = nullptr; + std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1); + std::string filenamems = filename; + if (filenamems.substr(filenamems.find_last_of(".") + 1) != "ms") { + filenamems = filenamems + ".ms"; + } + + MS_LOG(INFO) << "start reading model file " << filenamems.c_str(); + std::cout << "start reading model file " << filenamems.c_str() << std::endl; + auto *model = mindspore::lite::Model::Import(filenamems.c_str()); + if (model == nullptr) { + MS_LOG(ERROR) << "create model for train session failed"; + return nullptr; + } + session = std::unique_ptr(session::LiteSession::CreateSession(&context)); + if (session == nullptr) { + MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str(); + std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl; + delete model; + return nullptr; + } + if (session->CompileGraph(model) != RET_OK) { + MS_LOG(ERROR) << "Cannot compile model"; + delete model; + return nullptr; + } + delete model; + return session; +} + int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session, int epochs, bool check_accuracy) { auto start_prepare_time = GetTimeUs(); - std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1); Context context; context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = FlagToBindMode(flags_->cpu_bind_mode_); context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; @@ -335,60 +405,26 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string } std::unique_ptr session; if (train_session) { - if (!bb_filename.empty()) { - MS_LOG(INFO) << "CreateTransferSession from models files" << filename << " and " << bb_filename; - std::cout << "CreateTranferSession from model file " << filename << " and " << bb_filename << std::endl; - session = std::unique_ptr( - session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg)); - if (session == nullptr) { - MS_LOG(ERROR) << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str(); - std::cout << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str() << std::endl; - return RET_ERROR; - } - - } else { - MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str(); - std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl; - session = std::unique_ptr( - session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg)); - if (session == nullptr) { - MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str(); - std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl; - return RET_ERROR; - } - } - if (epochs > 0) { - if (flags_->virtual_batch_) { - session->SetupVirtualBatch(epochs); - } - session->Train(); + session = CreateAndRunNetworkForTrain(filename, bb_filename, context, train_cfg, epochs); + if (session == nullptr) { + MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed."; + return RET_ERROR; } } else { - std::string filenamems = filename; - if (filenamems.substr(filenamems.find_last_of(".") + 1) != "ms") { - filenamems = filenamems + ".ms"; - } - - MS_LOG(INFO) << "start reading model file " << filenamems.c_str(); - std::cout << "start reading model file " << filenamems.c_str() << std::endl; - auto *model = mindspore::lite::Model::Import(filenamems.c_str()); - if (model == nullptr) { - MS_LOG(ERROR) << "create model for train session failed"; - return RET_ERROR; - } - session = std::unique_ptr(session::LiteSession::CreateSession(&context)); + session = CreateAndRunNetworkForInference(filename, context); if (session == nullptr) { - MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str(); - std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl; - delete model; + MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed."; return RET_ERROR; } - if (session->CompileGraph(model) != RET_OK) { - MS_LOG(ERROR) << "Cannot compile model"; - delete model; - return RET_ERROR; + } + + if (!flags_->resize_dims_.empty()) { + auto ret = session->Resize(session->GetInputs(), flags_->resize_dims_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Input tensor resize failed."; + std::cout << "Input tensor resize failed."; + return ret; } - delete model; } auto end_prepare_time = GetTimeUs(); @@ -598,6 +634,23 @@ int NetTrain::InitCallbackParameter() { return RET_OK; } +void NetTrainFlags::InitResizeDimsList() { + std::string content = this->resize_dims_in_; + std::vector shape; + auto shape_strs = StringSplit(content, std::string(DELIM_COLON)); + for (const auto &shape_str : shape_strs) { + shape.clear(); + auto dim_strs = StringSplit(shape_str, std::string(DELIM_COMMA)); + std::cout << "Resize Dims: "; + for (const auto &dim_str : dim_strs) { + std::cout << dim_str << " "; + shape.emplace_back(static_cast(std::stoi(dim_str))); + } + std::cout << std::endl; + this->resize_dims_.emplace_back(shape); + } +} + int NetTrain::Init() { if (this->flags_ == nullptr) { return 1; @@ -654,7 +707,13 @@ int NetTrain::Init() { return RET_ERROR; } } - + flags_->InitResizeDimsList(); + if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() && + flags_->resize_dims_.size() != flags_->input_data_list_.size()) { + MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath"; + std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl; + return RET_ERROR; + } return RET_OK; } diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h index c1478ed8a91..5a33af0fbff 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.h +++ b/mindspore/lite/tools/benchmark_train/net_train.h @@ -73,9 +73,12 @@ class MS_API NetTrainFlags : public virtual FlagParser { AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", ""); AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", ""); AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false); + AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes", + "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", ""); } ~NetTrainFlags() override = default; + void InitResizeDimsList(); public: // common @@ -101,7 +104,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { std::string export_file_ = ""; std::string resize_dims_in_ = ""; bool layer_checksum_ = false; - std::vector> resize_dims_; + std::vector> resize_dims_; std::string loss_name_ = ""; std::string inference_file_ = ""; }; @@ -127,6 +130,14 @@ class MS_API NetTrain { int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session, int epochs, bool check_accuracy = true); + std::unique_ptr CreateAndRunNetworkForInference(const std::string &filename, + const Context &context); + + std::unique_ptr CreateAndRunNetworkForTrain(const std::string &filename, + const std::string &bb_filename, + const Context &context, const TrainCfg &train_cfg, + int epochs); + int InitCallbackParameter(); int PrintResult(const std::vector &title, const std::map> &result);