forked from mindspore-Ecosystem/mindspore
!11066 GPU add restrict for bert script
From: @VectorSL Reviewed-by: @gaoxiong1,@dylangeng,@anyrenwei Signed-off-by: @gaoxiong1
This commit is contained in:
commit
30560be800
|
@@ -86,7 +86,7 @@ def _get_optimizer(args_opt, network):
|
|||
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0},
|
||||
{'order_params': params}]
|
||||
- if args_opt.enable_lossscale == "true":
|
||||
+ if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
|
||||
optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
||||
else:
|
||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
||||
|
@@ -214,7 +214,7 @@ def run_pretrain():
|
|||
accumulation_steps = args_opt.accumulation_steps
|
||||
enable_global_norm = cfg.enable_global_norm
|
||||
if accumulation_steps <= 1:
|
||||
- if cfg.optimizer == 'AdamWeightDecay':
|
||||
+ if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
|
||||
net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
|
||||
scale_update_cell=update_cell)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue