From 95e80f013a98a91df012b73501f6b360f18b0b77 Mon Sep 17 00:00:00 2001 From: wukesong Date: Wed, 1 Sep 2021 16:33:26 +0800 Subject: [PATCH] fix sink_mode --- model_zoo/research/cv/dem/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model_zoo/research/cv/dem/train.py b/model_zoo/research/cv/dem/train.py index 082a45c84ac..409d515401c 100644 --- a/model_zoo/research/cv/dem/train.py +++ b/model_zoo/research/cv/dem/train.py @@ -131,7 +131,7 @@ if __name__ == "__main__": ckpt_config = CheckpointConfig(save_checkpoint_steps=interval_step) ckpt_callback = ModelCheckpoint(prefix='auto_parallel', config=ckpt_config) t1 = time.time() - model.train(epoch_size, train_dataset=custom_data, callbacks=[loss_cb, ckpt_callback], dataset_sink_mode=False) + model.train(epoch_size, train_dataset=custom_data, callbacks=[loss_cb, ckpt_callback], dataset_sink_mode=True) end = time.time() t3 = 1000 * (end - t1) / (88 * epoch_size) print('total time:', end - start) @@ -142,7 +142,7 @@ if __name__ == "__main__": else: for i in range(epoch_size): t1 = time.time() - model.train(1, train_dataset=custom_data, callbacks=LossMonitor(interval_step), dataset_sink_mode=False) + model.train(1, train_dataset=custom_data, callbacks=LossMonitor(interval_step), dataset_sink_mode=True) t2 = time.time() t3 = 1000 * (t2 - t1) / 88 if args.train_mode == 'att':