diff --git a/example/resnet50_cifar10/dataset.py b/example/resnet50_cifar10/dataset.py index 1d7074d7333..da5cc4a4014 100755 --- a/example/resnet50_cifar10/dataset.py +++ b/example/resnet50_cifar10/dataset.py @@ -42,6 +42,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target=" device_num = int(os.getenv("DEVICE_NUM")) rank_id = int(os.getenv("RANK_ID")) else: + init("nccl") rank_id = get_rank() device_num = get_group_size() diff --git a/example/resnet50_cifar10/train.py b/example/resnet50_cifar10/train.py index 7ef42412354..93efed73381 100755 --- a/example/resnet50_cifar10/train.py +++ b/example/resnet50_cifar10/train.py @@ -53,15 +53,12 @@ if __name__ == '__main__': mirror_mean=True) auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) ckpt_save_dir = config.save_checkpoint_path - loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') elif target == "GPU": context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False) init("nccl") context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" - loss = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction='mean') - epoch_size = config.epoch_size net = resnet50(class_num=config.class_num) @@ -77,8 +74,11 @@ if __name__ == '__main__': opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, config.loss_scale) if target == 'GPU': + loss = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction='mean') + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum) model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) else: + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=True)