fix sink_mode
This commit is contained in:
parent
45c7bb382b
commit
95e80f013a
|
@ -131,7 +131,7 @@ if __name__ == "__main__":
|
|||
ckpt_config = CheckpointConfig(save_checkpoint_steps=interval_step)
|
||||
ckpt_callback = ModelCheckpoint(prefix='auto_parallel', config=ckpt_config)
|
||||
t1 = time.time()
|
||||
model.train(epoch_size, train_dataset=custom_data, callbacks=[loss_cb, ckpt_callback], dataset_sink_mode=False)
|
||||
model.train(epoch_size, train_dataset=custom_data, callbacks=[loss_cb, ckpt_callback], dataset_sink_mode=True)
|
||||
end = time.time()
|
||||
t3 = 1000 * (end - t1) / (88 * epoch_size)
|
||||
print('total time:', end - start)
|
||||
|
@ -142,7 +142,7 @@ if __name__ == "__main__":
|
|||
else:
|
||||
for i in range(epoch_size):
|
||||
t1 = time.time()
|
||||
model.train(1, train_dataset=custom_data, callbacks=LossMonitor(interval_step), dataset_sink_mode=False)
|
||||
model.train(1, train_dataset=custom_data, callbacks=LossMonitor(interval_step), dataset_sink_mode=True)
|
||||
t2 = time.time()
|
||||
t3 = 1000 * (t2 - t1) / 88
|
||||
if args.train_mode == 'att':
|
||||
|
|
Loading…
Reference in New Issue