fix eval error in single device and data parallel mode

This commit is contained in:
huangxinjing 2020-11-25 09:32:15 +08:00
parent 77dca3ef0b
commit f4ce5768f1
2 changed files with 2 additions and 2 deletions

View File

@ -95,7 +95,7 @@ def test_train_eval(config):
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig)
out = model.eval(ds_eval)
out = model.eval(ds_eval, dataset_sink_mode=(not sparse))
print("=====" * 5 + "model.eval() initialized: {}".format(out))
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb],

View File

@ -105,7 +105,7 @@ def train_and_eval(config):
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/',
config=ckptconfig)
out = model.eval(ds_eval)
out = model.eval(ds_eval, dataset_sink_mode=(not sparse))
print("=====" * 5 + "model.eval() initialized: {}".format(out))
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
if get_rank() == 0: