forked from mindspore-Ecosystem/mindspore
fix gpu resnet script
This commit is contained in:
parent
0b4de00176
commit
47a8d3e5e8
|
@ -42,6 +42,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
|
|||
device_num = int(os.getenv("DEVICE_NUM"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
else:
|
||||
init("nccl")
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
|
||||
|
|
|
@ -53,15 +53,12 @@ if __name__ == '__main__':
|
|||
mirror_mean=True)
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
|
||||
ckpt_save_dir = config.save_checkpoint_path
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
elif target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
|
||||
init("nccl")
|
||||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction='mean')
|
||||
|
||||
epoch_size = config.epoch_size
|
||||
net = resnet50(class_num=config.class_num)
|
||||
|
||||
|
@ -77,8 +74,11 @@ if __name__ == '__main__':
|
|||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
|
||||
config.weight_decay, config.loss_scale)
|
||||
if target == 'GPU':
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction='mean')
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue