forked from mindspore-Ecosystem/mindspore
!23593 [MS][LITE] fix lite train issue
Merge pull request !23593 from zhengjun10/fix
This commit is contained in:
commit
61ab434493
|
@ -71,8 +71,7 @@ class OptimizerKernel : public InnerKernel {
|
|||
for (size_t ix = 0; ix < indices.size(); ix++) {
|
||||
auto param = in_tensors_.at(indices[ix]);
|
||||
if (param->data() == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor: " << param->tensor_name() << "has no data";
|
||||
return params;
|
||||
continue;
|
||||
}
|
||||
params.push_back(param);
|
||||
}
|
||||
|
@ -98,10 +97,6 @@ class OptimizerKernel : public InnerKernel {
|
|||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elelmts and type "
|
||||
<< param->data_type() << " is not a vlid params tensor";
|
||||
}
|
||||
return found;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ TrainLoop::~TrainLoop() {}
|
|||
|
||||
int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
|
||||
MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
|
||||
MS_CHECK_GE(epochs, 0, RET_ERROR);
|
||||
auto ret = train_session_->Train();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "TrainLoop train failed";
|
||||
|
|
|
@ -736,6 +736,8 @@ int TrainSession::SetOptimizerParams(const std::vector<tensor::MSTensor *> ¶
|
|||
}
|
||||
}
|
||||
if (!found) {
|
||||
MS_LOG(ERROR) << "Tensor " << param->tensor_name() << " with " << param->ElementsNum() << " elelmts and type "
|
||||
<< param->data_type() << " is not a valid params tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -1081,6 +1083,10 @@ int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &featu
|
|||
|
||||
session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
|
||||
bool train_mode, const lite::TrainCfg *cfg) {
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "context cannot be nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto session = std::make_unique<lite::TrainSession>();
|
||||
if (session == nullptr) {
|
||||
MS_LOG(ERROR) << "create session failed";
|
||||
|
|
Loading…
Reference in New Issue