forked from mindspore-Ecosystem/mindspore
!12512 optimize class TrainOneStepWithLossScaleCell
From: @wangnan39 Reviewed-by: Signed-off-by:
Commit 95adf66d30
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@ from ...common.parameter import Parameter
 from ...ops import functional as F
 from ...ops import composite as C
 from ...ops import operations as P
-from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual
 from ...common import dtype as mstype

 _grad_scale = C.MultitypeFuncGraph("grad_scale")
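Note: `_grad_scale` is a MultitypeFuncGraph used to unscale gradients after the backward pass; its registration lives elsewhere in this file and is not part of this diff. Shown only for context, the pattern is roughly:

_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    # multiply each gradient by 1/scale to undo the loss scaling applied to the loss
    return grad * F.cast(reciprocal(scale), F.dtype(grad))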
@@ -275,22 +274,12 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     def __init__(self, network, optimizer, scale_sense):
         super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.hyper_map = C.HyperMap()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.gpu_target = False
-            self.alloc_status = NPUAllocFloatStatus()
-            self.get_status = NPUGetFloatStatus()
-            self.clear_status = NPUClearFloatStatus()
-            self.reduce_sum = ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
-        self.less_equal = LessEqual()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.less_equal = P.LessEqual()
         self.allreduce = P.AllReduce()
-        self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.gpu_target = (context.get_context("device_target") == "GPU")
         self.loss_scaling_manager = None
         if isinstance(scale_sense, Cell):
             self.loss_scaling_manager = scale_sense
@@ -307,43 +296,19 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
-        init = False
-        if not self.gpu_target:
-            # init overflow buffer
-            init = self.alloc_status()
-            # clear overflow buffer after loss calculated
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            loss = F.depend(loss, clear_status)
         scaling_sens = self.scale_sense

+        status, scaling_sens = self.start_overflow(loss, scaling_sens)
+
         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
         grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)

         # get the overflow buffer
-        if not self.gpu_target:
-            # get overflow status after grads calculated
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            # sum overflow buffer elements, 0:not overflow , >0:overflow
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            # convert flag_sum to scalar
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if self.loss_scaling_manager is not None:
-            overflow = self.loss_scaling_manager(self.scale_sense, cond)
+        cond = self.detect_overflow(status, grads)
+        overflow = self.process_loss_scale(cond)
         # if there is no overflow, do optimize
         if not overflow:
             loss = F.depend(loss, self.optimizer(grads))
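Note: a minimal usage sketch of the refactored cell, assuming a loss-wrapped network `net_with_loss`, an optimizer `opt` and a batch `data`, `label` (all placeholders); the cell returns the loss, the overflow flag and the current loss scale:

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype

# dynamic loss scaling: the update cell raises or lowers the scale based on the overflow flag
manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12, scale_factor=2, scale_window=1000)
train_cell = nn.TrainOneStepWithLossScaleCell(net_with_loss, opt, scale_sense=manager)
# a fixed scale can be supplied instead as a scalar Tensor:
# train_cell = nn.TrainOneStepWithLossScaleCell(net_with_loss, opt, scale_sense=Tensor(1024.0, mstype.float32))
loss, overflow, scaling_sens = train_cell(data, label)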
@@ -356,3 +321,84 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
             self.scale_sense.set_data(sens)
         else:
             raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
+
+    def start_overflow(self, pre_cond, compute_input):
+        """
+        Start floating-point overflow detection. Create and clear the overflow detection state.
+
+        Specify the arguments 'pre_cond' and 'compute_input' to make sure the overflow status is cleared at
+        the right time. For example, if the state should be cleared after the loss is calculated and overflow
+        should then be detected while the gradients are computed, 'pre_cond' should be the output of the loss
+        function and 'compute_input' should be the input of the gradient-computing function.
+
+        Args:
+            pre_cond(object): A precondition for starting overflow detection. It determines the execution order
+                of the overflow state clearing and the preceding computation, ensuring that 'start_overflow'
+                clears the status only after the precondition has finished.
+            compute_input(object): The input of the subsequent computation. Overflow detection should be
+                performed on a certain computation. Set `compute_input` as the input of that computation to
+                ensure the overflow status is cleared before the computation is executed.
+
+        Returns:
+            Tuple[object, object], the first value is False for the GPU backend and an instance of
+            NPUAllocFloatStatus for other backends; it is used later to detect overflow. The second value is
+            the same as `compute_input`, but carries the required execution-order information.
+        """
+        status = False
+        if not self.gpu_target:
+            # init overflow buffer
+            status = P.NPUAllocFloatStatus()()
+            status = F.depend(status, pre_cond)
+            # clear overflow buffer
+            clear_status = P.NPUClearFloatStatus()(status)
+            compute_input = F.depend(compute_input, clear_status)
+        return status, compute_input
+
+    def detect_overflow(self, status, compute_output):
+        """
+        Detect floating-point overflow status.
+
+        Get the overflow result after the target computation has been executed.
+
+        Args:
+            status(object): A status instance used to detect the overflow.
+            compute_output: Overflow detection should be performed on a certain computation. Set
+                `compute_output` as the output of that computation to ensure the overflow status is acquired
+                only after the computation has been executed.
+
+        Returns:
+            bool, whether the overflow occurs or not.
+        """
+        if not self.gpu_target:
+            status = F.depend(status, compute_output)
+            get_status = P.NPUGetFloatStatus()(status)
+            status = F.depend(status, get_status)
+            # sum overflow buffer elements, 0:not overflow , >0:overflow
+            flag_sum = self.reduce_sum(status, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
+            flag_sum = P.AddN()(flag_sum)
+            # convert flag_sum to scalar
+            flag_sum = P.Reshape()(flag_sum, (()))
+
+        if self.is_distributed:
+            # sum overflow flag over devices
+            flag_reduce = self.allreduce(flag_sum)
+            overflow = self.less_equal(self.base, flag_reduce)
+        else:
+            overflow = self.less_equal(self.base, flag_sum)
+        return overflow
+
+    def process_loss_scale(self, overflow):
+        """
+        Calculate the loss scale according to the overflow.
+
+        Args:
+            overflow(bool): Whether the overflow occurs or not.
+
+        Returns:
+            bool, overflow value.
+        """
+        if self.loss_scaling_manager is not None:
+            return self.loss_scaling_manager(self.scale_sense, overflow)
+        return overflow
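Note: a sketch of how a user-defined training cell might reuse the three helpers added above, assuming this refactored base class; the class name is hypothetical, and gradient unscaling/clipping is elided for brevity:

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F


class MyTrainOneStepCell(nn.TrainOneStepWithLossScaleCell):
    """Hypothetical subclass that reuses the overflow helpers of the base class."""

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        # clear the overflow state once the loss has been computed
        status, scaling_sens = self.start_overflow(loss, scaling_sens)
        # run the backward pass with the loss scale as the sensitivity
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.grad_reducer(grads)
        # unscale (and optionally clip) the gradients here before applying them
        cond = self.detect_overflow(status, grads)    # did any gradient overflow?
        overflow = self.process_loss_scale(cond)      # update the loss scale if a manager is attached
        if not overflow:
            loss = F.depend(loss, self.optimizer(grads))
        return loss, cond, scaling_sens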
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -330,7 +330,7 @@ def _tensor_grad_overflow(grad):
     return grad_overflow(grad)


-class BertTrainOneStepWithLossScaleCell(nn.Cell):
+class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
     """
     Encapsulation class of bert network training.

@@ -344,39 +344,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
     """

     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = F.identity
+        super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
+        self.cast = P.Cast()
         self.degree = 1
         if self.reducer_flag:
             self.degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
-        self.cast = P.Cast()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.gpu_target = False
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-            self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
         self.loss_scale = None
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
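Note: with this change the Bert training cells inherit the common plumbing (weights, grad, grad_reducer, hyper_map and the new overflow helpers) from nn.TrainOneStepWithLossScaleCell, so a derived cell only declares what differs. A minimal sketch of that pattern, with a hypothetical class name:

import mindspore.nn as nn
from mindspore.ops import operations as P


class MyBertTrainOneStepCell(nn.TrainOneStepWithLossScaleCell):
    # the parent __init__ wires up the network, optimizer, gradient computation and loss scaling;
    # only the Bert-specific operators are added here
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(MyBertTrainOneStepCell, self).__init__(network, optimizer, scale_update_cell)
        self.cast = P.Cast()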
@@ -404,13 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
-        init = False
-        if not self.gpu_target:
-            # alloc status and clear should be right before gradoperation
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -424,21 +392,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
+        cond = self.detect_overflow(status, grads)
         overflow = cond
         if sens is None:
             overflow = self.loss_scaling_manager(self.loss_scale, cond)
@@ -449,7 +404,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)

-class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
+
+class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
     """
     Encapsulation class of bert network training.

@@ -464,40 +420,12 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = F.identity
+        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
+        self.cast = P.Cast()
         self.degree = 1
         if self.reducer_flag:
             self.degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
-        self.cast = P.Cast()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.gpu_target = False
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-            self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
         self.loss_scale = None
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
@@ -525,14 +453,8 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
-        init = False
-        if not self.gpu_target:
-            # alloc status and clear should be right before gradoperation
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)

+        status, scaling_sens = self.start_overflow(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -546,21 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
+        cond = self.detect_overflow(status, grads)
         overflow = cond
         if self.loss_scaling_manager is not None:
             overflow = self.loss_scaling_manager(scaling_sens, cond)