forked from mindspore-Ecosystem/mindspore
modify api detect_overflow name in TrainOneStepWithLossScaleCell
This commit is contained in:
parent
aa9bee0ce3
commit
7188a14215
|
@ -298,7 +298,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
loss = self.network(*inputs)
|
||||
scaling_sens = self.scale_sense
|
||||
|
||||
status, scaling_sens = self.start_overflow(loss, scaling_sens)
|
||||
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
|
||||
|
||||
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
|
||||
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
|
||||
|
@ -307,7 +307,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
grads = self.grad_reducer(grads)
|
||||
|
||||
# get the overflow buffer
|
||||
cond = self.detect_overflow(status, grads)
|
||||
cond = self.get_overflow_status(status, grads)
|
||||
overflow = self.process_loss_scale(cond)
|
||||
# if there is no overflow, do optimize
|
||||
if not overflow:
|
||||
|
@ -322,7 +322,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
else:
|
||||
raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
|
||||
|
||||
def start_overflow(self, pre_cond, compute_input):
|
||||
def start_overflow_check(self, pre_cond, compute_input):
|
||||
"""
|
||||
Start floating-point overflow detection. Create and clear the overflow detection state.
|
||||
|
||||
|
@ -355,9 +355,9 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
compute_input = F.depend(compute_input, clear_status)
|
||||
return status, compute_input
|
||||
|
||||
def detect_overflow(self, status, compute_output):
|
||||
def get_overflow_status(self, status, compute_output):
|
||||
"""
|
||||
Detect floating-point overflow status.
|
||||
Get floating-point overflow status.
|
||||
|
||||
Get overflow results after executing the target process for overflow detection.
|
||||
|
||||
|
|
|
@ -378,7 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
|
|||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
status, scaling_sens = self.start_overflow(loss, scaling_sens)
|
||||
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
|
@ -393,7 +393,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
|
|||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
|
||||
cond = self.detect_overflow(status, grads)
|
||||
cond = self.get_overflow_status(status, grads)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
|
@ -454,7 +454,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell)
|
|||
else:
|
||||
scaling_sens = sens
|
||||
|
||||
status, scaling_sens = self.start_overflow(loss, scaling_sens)
|
||||
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
|
@ -468,7 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell)
|
|||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
cond = self.detect_overflow(status, grads)
|
||||
cond = self.get_overflow_status(status, grads)
|
||||
overflow = cond
|
||||
if self.loss_scaling_manager is not None:
|
||||
overflow = self.loss_scaling_manager(scaling_sens, cond)
|
||||
|
|
Loading…
Reference in New Issue