diff --git a/model_zoo/official/cv/googlenet/train.py b/model_zoo/official/cv/googlenet/train.py index b050d6b5325..835a21af87f 100644 --- a/model_zoo/official/cv/googlenet/train.py +++ b/model_zoo/official/cv/googlenet/train.py @@ -205,5 +205,9 @@ if __name__ == '__main__': ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir, config=config_ck) loss_cb = LossMonitor() - model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) + + cbs = [time_cb, ckpoint_cb, loss_cb] + if device_num > 1 and rank != 0: + cbs = [time_cb, loss_cb] + model.train(cfg.epoch_size, dataset, callbacks=cbs) print("train success")