forked from mindspore-Ecosystem/mindspore
!1830 add time_monitor
Merge pull request !1830 from wukesong/widedeep-time
This commit is contained in:
commit
c2b2ed1482
|
@ -14,7 +14,7 @@
|
|||
""" test_training """
|
||||
import os
|
||||
from mindspore import Model, context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
|
||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||
from src.callbacks import LossCallBack
|
||||
|
@ -75,7 +75,7 @@ def test_train(configure):
|
|||
ckptconfig = CheckpointConfig(save_checkpoint_steps=1,
|
||||
keep_checkpoint_max=5)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig)
|
||||
model.train(epochs, ds_train, callbacks=[callback, ckpoint_cb])
|
||||
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue