forked from mindspore-Ecosystem/mindspore
!7155 fix resnext50 gpu amp
Merge pull request !7155 from zhaoting/master
This commit is contained in:
commit
1d7c759cea
|
@ -240,12 +240,8 @@ def train(cloud_args=None):
|
|||
else:
|
||||
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
|
||||
|
||||
if args.platform == "Ascend":
|
||||
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
|
||||
metrics={'acc'}, amp_level="O3")
|
||||
else:
|
||||
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
|
||||
metrics={'acc'}, amp_level="O2")
|
||||
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
|
||||
metrics={'acc'}, amp_level="O3")
|
||||
|
||||
# checkpoint save
|
||||
progress_cb = ProgressMonitor(args)
|
||||
|
|
Loading…
Reference in New Issue