forked from OSSInnovation/mindspore
bugfix(side effect): fix adding wrong control depend between AllReduce and GetStatus.
This commit is contained in:
parent
c9fba7f091
commit
5d4144de11
|
@ -370,7 +370,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
self.grad_reducer = F.identity
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("mirror_mean")
|
||||
degree = get_group_size()
|
||||
|
@ -428,9 +428,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
mstype.float32))
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
if self.is_distributed:
|
||||
|
|
|
@ -220,7 +220,7 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
self.depend_parameter_use = ControlDepend(depend_mode=1)
|
||||
self.allreduce = P.AllReduce()
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
self.grad_reducer = None
|
||||
self.grad_reducer = F.identity
|
||||
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
|
||||
if self.reducer_flag:
|
||||
mean = _get_mirror_mean()
|
||||
|
@ -250,9 +250,8 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
scaling_sens = sens
|
||||
grads = self.grad(self.network, weights)(data, label, F.cast(scaling_sens, F.dtype(loss)))
|
||||
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
# get the overflow buffer
|
||||
if not self.gpu_target:
|
||||
self.get_status(init)
|
||||
|
|
Loading…
Reference in New Issue