From 2321a6b855be99ee755c0cce8e87016376592f13 Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Thu, 16 Sep 2021 11:49:22 +0800 Subject: [PATCH] fix lite train issue --- mindspore/lite/src/train/optimizer_kernel.h | 7 +------ mindspore/lite/src/train/train_loop.cc | 1 + mindspore/lite/src/train/train_session.cc | 6 ++++++ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mindspore/lite/src/train/optimizer_kernel.h b/mindspore/lite/src/train/optimizer_kernel.h index b1bdf7f3bc2..56b1310543a 100644 --- a/mindspore/lite/src/train/optimizer_kernel.h +++ b/mindspore/lite/src/train/optimizer_kernel.h @@ -71,8 +71,7 @@ class OptimizerKernel : public InnerKernel { for (size_t ix = 0; ix < indices.size(); ix++) { auto param = in_tensors_.at(indices[ix]); if (param->data() == nullptr) { - MS_LOG(ERROR) << "Tensor: " << param->tensor_name() << "has no data"; - return params; + continue; } params.push_back(param); } @@ -98,10 +97,6 @@ class OptimizerKernel : public InnerKernel { break; } } - if (!found) { - MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elelmts and type " - << param->data_type() << " is not a vlid params tensor"; - } return found; } diff --git a/mindspore/lite/src/train/train_loop.cc b/mindspore/lite/src/train/train_loop.cc index 72ffe00123d..4e07166ca50 100644 --- a/mindspore/lite/src/train/train_loop.cc +++ b/mindspore/lite/src/train/train_loop.cc @@ -37,6 +37,7 @@ TrainLoop::~TrainLoop() {} int TrainLoop::Train(int epochs, Dataset *ds, std::vector cbs, LoadDataFunc load_func) { MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr"); + MS_CHECK_GE(epochs, 0, RET_ERROR); auto ret = train_session_->Train(); if (ret != RET_OK) { MS_LOG(ERROR) << "TrainLoop train failed"; diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index d7c6b598818..fd3ba430d35 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -736,6 +736,8 @@ int TrainSession::SetOptimizerParams(const std::vector ¶ } } if (!found) { + MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elelmts and type " + << param->data_type() << " is not a valid params tensor"; return RET_ERROR; } } @@ -1081,6 +1083,10 @@ int TrainSession::UpdateFeatureMaps(const std::vector &featu session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context, bool train_mode, const lite::TrainCfg *cfg) { + if (context == nullptr) { + MS_LOG(ERROR) << "context cannot be nullptr"; + return nullptr; + } auto session = std::make_unique(); if (session == nullptr) { MS_LOG(ERROR) << "create session failed";