forked from mindspore-Ecosystem/mindspore
fix eval error in single device and data parallel mode
This commit is contained in:
parent
77dca3ef0b
commit
f4ce5768f1
|
@ -95,7 +95,7 @@ def test_train_eval(config):
|
|||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig)
|
||||
|
||||
out = model.eval(ds_eval)
|
||||
out = model.eval(ds_eval, dataset_sink_mode=(not sparse))
|
||||
print("=====" * 5 + "model.eval() initialized: {}".format(out))
|
||||
model.train(epochs, ds_train,
|
||||
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb],
|
||||
|
|
|
@ -105,7 +105,7 @@ def train_and_eval(config):
|
|||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
||||
directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/',
|
||||
config=ckptconfig)
|
||||
out = model.eval(ds_eval)
|
||||
out = model.eval(ds_eval, dataset_sink_mode=(not sparse))
|
||||
print("=====" * 5 + "model.eval() initialized: {}".format(out))
|
||||
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
|
||||
if get_rank() == 0:
|
||||
|
|
Loading…
Reference in New Issue