!23593 [MS][LITE] fix lite train issue

Merge pull request !23593 from zhengjun10/fix
This commit is contained in:
i-robot 2021-09-16 06:26:09 +00:00 committed by Gitee
commit 61ab434493
3 changed files with 8 additions and 6 deletions

View File

@ -71,8 +71,7 @@ class OptimizerKernel : public InnerKernel {
for (size_t ix = 0; ix < indices.size(); ix++) {
auto param = in_tensors_.at(indices[ix]);
if (param->data() == nullptr) {
MS_LOG(ERROR) << "Tensor: " << param->tensor_name() << "has no data";
return params;
continue;
}
params.push_back(param);
}
@ -98,10 +97,6 @@ class OptimizerKernel : public InnerKernel {
break;
}
}
if (!found) {
MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elelmts and type "
<< param->data_type() << " is not a vlid params tensor";
}
return found;
}

View File

@ -37,6 +37,7 @@ TrainLoop::~TrainLoop() {}
int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
MS_CHECK_GE(epochs, 0, RET_ERROR);
auto ret = train_session_->Train();
if (ret != RET_OK) {
MS_LOG(ERROR) << "TrainLoop train failed";

View File

@ -736,6 +736,8 @@ int TrainSession::SetOptimizerParams(const std::vector<tensor::MSTensor *> &para
}
}
if (!found) {
MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elelmts and type "
<< param->data_type() << " is not a valid params tensor";
return RET_ERROR;
}
}
@ -1081,6 +1083,10 @@ int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &featu
session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
bool train_mode, const lite::TrainCfg *cfg) {
if (context == nullptr) {
MS_LOG(ERROR) << "context cannot be nullptr";
return nullptr;
}
auto session = std::make_unique<lite::TrainSession>();
if (session == nullptr) {
MS_LOG(ERROR) << "create session failed";