forked from mindspore-Ecosystem/mindspore
!1830 add time_monitor
Merge pull request !1830 from wukesong/widedeep-time
This commit is contained in:
commit
c2b2ed1482
|
@ -14,7 +14,7 @@
|
||||||
""" test_training """
|
""" test_training """
|
||||||
import os
|
import os
|
||||||
from mindspore import Model, context
|
from mindspore import Model, context
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack
|
from src.callbacks import LossCallBack
|
||||||
|
@ -75,7 +75,7 @@ def test_train(configure):
|
||||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=1,
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=1,
|
||||||
keep_checkpoint_max=5)
|
keep_checkpoint_max=5)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig)
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig)
|
||||||
model.train(epochs, ds_train, callbacks=[callback, ckpoint_cb])
|
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue