forked from OSSInnovation/mindspore
!5584 delete the redundant argument when initializing GradOperation
Merge pull request !5584 from shibeiji/master
This commit is contained in:
commit d5e02cf474
@@ -121,9 +121,10 @@ def run_pretrain():
     new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
     if args_opt.train_steps > 0:
-        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+        train_steps = args_opt.train_steps * args_opt.accumulation_steps
+        new_repeat_count = min(new_repeat_count, train_steps // args_opt.data_sink_steps)
     else:
-        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
+        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() // args_opt.accumulation_steps
     logger.info("train steps: {}".format(args_opt.train_steps))
 
     if cfg.optimizer == 'Lamb':
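The hunk above scales train_steps by accumulation_steps before dividing by data_sink_steps, since with gradient accumulation each optimizer step consumes accumulation_steps batches. A quick numeric check of that arithmetic; all values below are illustrative, not defaults from the repository:

# Sanity check of the repeat-count arithmetic in the hunk above.
# Every value here is a made-up example, not taken from the commit.
epoch_size = 2
dataset_size = 1000          # stands in for ds.get_dataset_size()
data_sink_steps = 100
accumulation_steps = 4
train_steps = 300            # exercises the args_opt.train_steps > 0 branch

new_repeat_count = epoch_size * dataset_size // data_sink_steps        # 20
scaled = train_steps * accumulation_steps                              # 1200
new_repeat_count = min(new_repeat_count, scaled // data_sink_steps)    # min(20, 12) = 12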
@@ -487,9 +487,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         self.accu_overflow = Parameter(initializer(0, [1], mstype.int32), name="accu_overflow")
         self.loss = Parameter(initializer(0, [1], mstype.float32), name="accu_loss")
 
-        self.grad = C.GradOperation('grad',
-                                    get_by_list=True,
-                                    sens_param=True)
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.reducer_flag = False
         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
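This second hunk is the change the commit title refers to: GradOperation no longer accepts a leading name argument, only the keyword flags. A minimal usage sketch of the keyword-only form; the toy network and tensor shapes are illustrative assumptions, not taken from this file:

# Minimal sketch of the keyword-only GradOperation form; the toy network
# and tensor shapes are assumptions for illustration, not from the diff.
import numpy as np
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore import Tensor, ParameterTuple

class ToyNet(nn.Cell):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.dense = nn.Dense(3, 2)

    def construct(self, x):
        return self.dense(x)

net = ToyNet()
weights = ParameterTuple(net.trainable_params())
# Old form removed by this commit: C.GradOperation('grad', get_by_list=True, sens_param=True)
grad_op = C.GradOperation(get_by_list=True, sens_param=True)
x = Tensor(np.ones([1, 3]).astype(np.float32))
sens = Tensor(np.ones([1, 2]).astype(np.float32))   # sensitivity w.r.t. the output
grads = grad_op(net, weights)(x, sens)              # per-parameter gradients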