!20548 [LITE][TRAIN] support resize for train

Merge pull request !20548 from yefeng/133-resize_for_train
i-robot 2021-07-21 01:05:39 +00:00 committed by Gitee
commit 6fe2817146
3 changed files with 122 additions and 52 deletions

@@ -81,7 +81,7 @@ class TrainSession : virtual public lite::LiteSession {
return lite::LiteSession::GetOutputByTensorName(tensor_name);
}
int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) override {
return lite::RET_ERROR;
return lite::LiteSession::Resize(inputs, dims);
}
std::vector<tensor::MSTensor *> GetPredictions() const override {

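The override used to reject every call with lite::RET_ERROR; it now forwards to lite::LiteSession::Resize, so a training session's inputs can be resized the same way as an inference session's. A minimal usage sketch under assumptions: the model path, the NHWC shape and the include paths below are placeholders, only the session calls themselves appear in this change.

// Hypothetical sketch: resize a train session's inputs before switching to train mode.
// "lenet_train.ms", the shape {1, 32, 32, 1} and the include paths are assumptions.
#include <memory>
#include <vector>
#include "include/context.h"
#include "include/errorcode.h"
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"

int ResizeThenTrain() {
  mindspore::lite::Context context;
  mindspore::lite::TrainCfg train_cfg;
  auto session = std::unique_ptr<mindspore::session::LiteSession>(
      mindspore::session::TrainSession::CreateTrainSession("lenet_train.ms", &context, true, &train_cfg));
  if (session == nullptr) {
    return mindspore::lite::RET_ERROR;
  }
  // Before this commit the train override returned lite::RET_ERROR unconditionally;
  // it now returns whatever lite::LiteSession::Resize reports.
  auto ret = session->Resize(session->GetInputs(), {{1, 32, 32, 1}});
  if (ret != mindspore::lite::RET_OK) {
    return ret;
  }
  session->Train();  // switch to train mode; training steps would follow
  return mindspore::lite::RET_OK;
}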
@@ -35,6 +35,8 @@
namespace mindspore {
namespace lite {
static const char *DELIM_SLASH = "/";
constexpr const char *DELIM_COLON = ":";
constexpr const char *DELIM_COMMA = ",";
constexpr int RET_TOO_BIG = -9;
namespace {
@@ -319,10 +321,78 @@ static CpuBindMode FlagToBindMode(int flag) {
return NO_BIND;
}
std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForTrain(const std::string &filename,
const std::string &bb_filename,
const Context &context,
const TrainCfg &train_cfg, int epochs) {
std::unique_ptr<session::LiteSession> session = nullptr;
std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
if (!bb_filename.empty()) {
MS_LOG(INFO) << "CreateTransferSession from model files " << filename << " and " << bb_filename;
std::cout << "CreateTransferSession from model files " << filename << " and " << bb_filename << std::endl;
session = std::unique_ptr<session::LiteSession>(
session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
if (session == nullptr) {
MS_LOG(ERROR) << "RunNetTrain CreateTransferSession failed while running " << model_name.c_str();
std::cout << "RunNetTrain CreateTransferSession failed while running " << model_name.c_str() << std::endl;
return nullptr;
}
} else {
MS_LOG(INFO) << "CreateTrainSession from model file " << filename.c_str();
std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
session = std::unique_ptr<session::LiteSession>(
session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg));
if (session == nullptr) {
MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str();
std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl;
return nullptr;
}
}
if (epochs > 0) {
if (flags_->virtual_batch_) {
session->SetupVirtualBatch(epochs);
}
session->Train();
}
return session;
}
std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForInference(const std::string &filename,
const Context &context) {
std::unique_ptr<session::LiteSession> session = nullptr;
std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
std::string filenamems = filename;
if (filenamems.substr(filenamems.find_last_of(".") + 1) != "ms") {
filenamems = filenamems + ".ms";
}
MS_LOG(INFO) << "start reading model file " << filenamems.c_str();
std::cout << "start reading model file " << filenamems.c_str() << std::endl;
auto *model = mindspore::lite::Model::Import(filenamems.c_str());
if (model == nullptr) {
MS_LOG(ERROR) << "create model for train session failed";
return nullptr;
}
session = std::unique_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context));
if (session == nullptr) {
MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl;
delete model;
return nullptr;
}
if (session->CompileGraph(model) != RET_OK) {
MS_LOG(ERROR) << "Cannot compile model";
delete model;
return nullptr;
}
delete model;
return session;
}
int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session,
int epochs, bool check_accuracy) {
auto start_prepare_time = GetTimeUs();
std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
Context context;
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = FlagToBindMode(flags_->cpu_bind_mode_);
context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
@@ -335,60 +405,26 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
}
std::unique_ptr<session::LiteSession> session;
if (train_session) {
if (!bb_filename.empty()) {
MS_LOG(INFO) << "CreateTransferSession from models files" << filename << " and " << bb_filename;
std::cout << "CreateTranferSession from model file " << filename << " and " << bb_filename << std::endl;
session = std::unique_ptr<session::LiteSession>(
session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
if (session == nullptr) {
MS_LOG(ERROR) << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str();
std::cout << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str() << std::endl;
return RET_ERROR;
}
} else {
MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str();
std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
session = std::unique_ptr<session::LiteSession>(
session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg));
if (session == nullptr) {
MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str();
std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl;
return RET_ERROR;
}
}
if (epochs > 0) {
if (flags_->virtual_batch_) {
session->SetupVirtualBatch(epochs);
}
session->Train();
session = CreateAndRunNetworkForTrain(filename, bb_filename, context, train_cfg, epochs);
if (session == nullptr) {
MS_LOG(ERROR) << "CreateAndRunNetworkForTrain failed.";
return RET_ERROR;
}
} else {
std::string filenamems = filename;
if (filenamems.substr(filenamems.find_last_of(".") + 1) != "ms") {
filenamems = filenamems + ".ms";
}
MS_LOG(INFO) << "start reading model file " << filenamems.c_str();
std::cout << "start reading model file " << filenamems.c_str() << std::endl;
auto *model = mindspore::lite::Model::Import(filenamems.c_str());
if (model == nullptr) {
MS_LOG(ERROR) << "create model for train session failed";
return RET_ERROR;
}
session = std::unique_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context));
session = CreateAndRunNetworkForInference(filename, context);
if (session == nullptr) {
MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl;
delete model;
MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed.";
return RET_ERROR;
}
if (session->CompileGraph(model) != RET_OK) {
MS_LOG(ERROR) << "Cannot compile model";
delete model;
return RET_ERROR;
}
if (!flags_->resize_dims_.empty()) {
auto ret = session->Resize(session->GetInputs(), flags_->resize_dims_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Input tensor resize failed.";
std::cout << "Input tensor resize failed." << std::endl;
return ret;
}
delete model;
}
auto end_prepare_time = GetTimeUs();
@@ -598,6 +634,23 @@ int NetTrain::InitCallbackParameter() {
return RET_OK;
}
void NetTrainFlags::InitResizeDimsList() {
std::string content = this->resize_dims_in_;
std::vector<int> shape;
auto shape_strs = StringSplit(content, std::string(DELIM_COLON));
for (const auto &shape_str : shape_strs) {
shape.clear();
auto dim_strs = StringSplit(shape_str, std::string(DELIM_COMMA));
std::cout << "Resize Dims: ";
for (const auto &dim_str : dim_strs) {
std::cout << dim_str << " ";
shape.emplace_back(static_cast<int>(std::stoi(dim_str)));
}
std::cout << std::endl;
this->resize_dims_.emplace_back(shape);
}
}
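For reference, the inputShapes string consumed by InitResizeDimsList separates inputs with ':' and dimensions with ','. A standalone sketch of the same parsing; Split() below stands in for the tool's StringSplit helper and the sample flag value is hypothetical.

// Standalone sketch of the colon/comma parsing done by InitResizeDimsList.
// Split() stands in for the tool's StringSplit; the sample flag value is hypothetical.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

static std::vector<std::string> Split(const std::string &str, char delim) {
  std::vector<std::string> parts;
  std::stringstream ss(str);
  std::string item;
  while (std::getline(ss, item, delim)) {
    parts.push_back(item);
  }
  return parts;
}

int main() {
  // Two inputs: "1,32,32,32" and "1,1,32,32" -> {{1, 32, 32, 32}, {1, 1, 32, 32}}.
  const std::string flag = "1,32,32,32:1,1,32,32";
  std::vector<std::vector<int>> resize_dims;
  for (const auto &shape_str : Split(flag, ':')) {
    std::vector<int> shape;
    for (const auto &dim_str : Split(shape_str, ',')) {
      shape.push_back(std::stoi(dim_str));
    }
    resize_dims.push_back(shape);
  }
  std::cout << "parsed " << resize_dims.size() << " input shapes" << std::endl;
  return 0;
}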
int NetTrain::Init() {
if (this->flags_ == nullptr) {
return 1;
@@ -654,7 +707,13 @@ int NetTrain::Init() {
return RET_ERROR;
}
}
flags_->InitResizeDimsList();
if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
return RET_ERROR;
}
return RET_OK;
}

@@ -73,9 +73,12 @@ class MS_API NetTrainFlags : public virtual FlagParser {
AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", "");
AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", "");
AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false);
AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes",
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
}
~NetTrainFlags() override = default;
void InitResizeDimsList();
public:
// common
@@ -101,7 +104,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
std::string export_file_ = "";
std::string resize_dims_in_ = "";
bool layer_checksum_ = false;
std::vector<std::vector<int64_t>> resize_dims_;
std::vector<std::vector<int>> resize_dims_;
std::string loss_name_ = "";
std::string inference_file_ = "";
};
@@ -127,6 +130,14 @@ class MS_API NetTrain {
int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session, int epochs,
bool check_accuracy = true);
std::unique_ptr<session::LiteSession> CreateAndRunNetworkForInference(const std::string &filename,
const Context &context);
std::unique_ptr<session::LiteSession> CreateAndRunNetworkForTrain(const std::string &filename,
const std::string &bb_filename,
const Context &context, const TrainCfg &train_cfg,
int epochs);
int InitCallbackParameter();
int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result);
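For comparison, the pre-existing inference path (now factored into CreateAndRunNetworkForInference) reaches the same Resize call. A hypothetical, self-contained sketch using only calls that appear in this diff; the model path, the shape and the include paths are placeholders.

// Hypothetical sketch of the inference-side flow: import, compile, then resize.
// "model.ms", the shape {1, 64, 64, 3} and the include paths are assumptions.
#include <memory>
#include <vector>
#include "include/context.h"
#include "include/errorcode.h"
#include "include/lite_session.h"
#include "include/model.h"

int PrepareInferenceWithResize() {
  mindspore::lite::Context context;
  auto *model = mindspore::lite::Model::Import("model.ms");
  if (model == nullptr) {
    return mindspore::lite::RET_ERROR;
  }
  auto session = std::unique_ptr<mindspore::session::LiteSession>(
      mindspore::session::LiteSession::CreateSession(&context));
  if (session == nullptr || session->CompileGraph(model) != mindspore::lite::RET_OK) {
    delete model;
    return mindspore::lite::RET_ERROR;
  }
  delete model;  // released after compilation, as in CreateAndRunNetworkForInference above
  // The same Resize call the benchmark now issues for train sessions as well.
  return session->Resize(session->GetInputs(), {{1, 64, 64, 3}});
}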