forked from mindspore-Ecosystem/mindspore
!20584 [MS][LITE] fix lite train demo memory increase along training
Merge pull request !20584 from zhengjun10/master
This commit is contained in:
commit
2f65ad73ea
|
@ -42,11 +42,11 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
|
|||
|
||||
for (auto cb : cbs) cb->Begin(cb_data);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
for (int i = 0; i < epochs; i++) {
|
||||
cb_data.epoch_ = epoch_++;
|
||||
for (auto cb : cbs) cb->EpochBegin(cb_data);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
MSTensorVec row_vec;
|
||||
int s = 0;
|
||||
|
||||
|
@ -62,7 +62,6 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
|
|||
for (auto cb : cbs) cb->StepEnd(cb_data);
|
||||
iter->GetNextRow(&row_vec);
|
||||
}
|
||||
iter->Stop();
|
||||
int break_loop = false;
|
||||
for (auto cb : cbs) {
|
||||
int ret = cb->EpochEnd(cb_data);
|
||||
|
@ -80,7 +79,7 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
|
|||
break;
|
||||
}
|
||||
}
|
||||
|
||||
iter->Stop();
|
||||
for (auto cb : cbs) cb->End(cb_data);
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue