forked from mindspore-Ecosystem/mindspore
!12512 optimize class TrainOneStepWithLossScaleCell
From: @wangnan39 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
95adf66d30
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -22,7 +22,6 @@ from ...common.parameter import Parameter
|
|||
from ...ops import functional as F
|
||||
from ...ops import composite as C
|
||||
from ...ops import operations as P
|
||||
from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual
|
||||
from ...common import dtype as mstype
|
||||
|
||||
_grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
|
@ -275,22 +274,12 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
def __init__(self, network, optimizer, scale_sense):
|
||||
super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
|
||||
self.hyper_map = C.HyperMap()
|
||||
if context.get_context("device_target") == "GPU":
|
||||
self.gpu_target = True
|
||||
self.float_status = P.FloatStatus()
|
||||
self.addn = P.AddN()
|
||||
self.reshape = P.Reshape()
|
||||
else:
|
||||
self.gpu_target = False
|
||||
self.alloc_status = NPUAllocFloatStatus()
|
||||
self.get_status = NPUGetFloatStatus()
|
||||
self.clear_status = NPUClearFloatStatus()
|
||||
self.reduce_sum = ReduceSum(keep_dims=False)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = LessEqual()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.allreduce = P.AllReduce()
|
||||
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
|
||||
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.gpu_target = (context.get_context("device_target") == "GPU")
|
||||
self.loss_scaling_manager = None
|
||||
if isinstance(scale_sense, Cell):
|
||||
self.loss_scaling_manager = scale_sense
|
||||
|
@ -307,43 +296,19 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
def construct(self, *inputs):
|
||||
weights = self.weights
|
||||
loss = self.network(*inputs)
|
||||
init = False
|
||||
if not self.gpu_target:
|
||||
# init overflow buffer
|
||||
init = self.alloc_status()
|
||||
# clear overflow buffer after loss calculated
|
||||
init = F.depend(init, loss)
|
||||
clear_status = self.clear_status(init)
|
||||
loss = F.depend(loss, clear_status)
|
||||
|
||||
scaling_sens = self.scale_sense
|
||||
|
||||
status, scaling_sens = self.start_overflow(loss, scaling_sens)
|
||||
|
||||
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
|
||||
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
|
||||
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
# get the overflow buffer
|
||||
if not self.gpu_target:
|
||||
# get overflow status after grads calculated
|
||||
init = F.depend(init, grads)
|
||||
get_status = self.get_status(init)
|
||||
init = F.depend(init, get_status)
|
||||
# sum overflow buffer elements, 0:not overflow , >0:overflow
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
# convert flag_sum to scalar
|
||||
flag_sum = self.reshape(flag_sum, (()))
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
overflow = cond
|
||||
if self.loss_scaling_manager is not None:
|
||||
overflow = self.loss_scaling_manager(self.scale_sense, cond)
|
||||
cond = self.detect_overflow(status, grads)
|
||||
overflow = self.process_loss_scale(cond)
|
||||
# if there is no overflow, do optimize
|
||||
if not overflow:
|
||||
loss = F.depend(loss, self.optimizer(grads))
|
||||
|
@ -356,3 +321,84 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
self.scale_sense.set_data(sens)
|
||||
else:
|
||||
raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
|
||||
|
||||
def start_overflow(self, pre_cond, compute_input):
|
||||
"""
|
||||
Start floating-point overflow detection. Create and clear the overflow detection state.
|
||||
|
||||
Specify the argument 'pre_cond' and 'compute_input' to make sure overflow status is cleared at the right time.
|
||||
Taking this situation as an example, we need to execute state clearing after loss calculation and then detect
|
||||
overflow in the process of gradient calculation. In this case, pre_cond should be the output of the loss
|
||||
function, and compute_input should be the input of gradients-computing function.
|
||||
|
||||
Args:
|
||||
pre_cond(object): A precondition for starting overflow detection. It determines the executing order of
|
||||
overflow state clearing and prior processions. It makes sure that the function 'start_overflow' clears
|
||||
status after finishing the process of precondition.
|
||||
compute_input(object): The input of subsequent process. Overflow detection should be performed on a certain
|
||||
computation. Set `compute_input` as the input of the computation, to ensure overflow status is cleared
|
||||
before executing the computation.
|
||||
|
||||
Returns:
|
||||
Tuple[object, object], the first value is False for GPU backend, while it is a instance of
|
||||
NPUAllocFloatStatus for other backend. The status is used to detect overflow during overflow detection.
|
||||
The second value is the same as the input of `compute_input`, but contains some information about the
|
||||
execution order.
|
||||
"""
|
||||
status = False
|
||||
if not self.gpu_target:
|
||||
# init overflow buffer
|
||||
status = P.NPUAllocFloatStatus()()
|
||||
status = F.depend(status, pre_cond)
|
||||
# clear overflow buffer
|
||||
clear_status = P.NPUClearFloatStatus()(status)
|
||||
compute_input = F.depend(compute_input, clear_status)
|
||||
return status, compute_input
|
||||
|
||||
def detect_overflow(self, status, compute_output):
|
||||
"""
|
||||
Detect floating-point overflow status.
|
||||
|
||||
Get overflow results after executing the target process for overflow detection.
|
||||
|
||||
Args:
|
||||
status(object): A status instance used to detect the overflow.
|
||||
compute_output: Overflow detection should be performed on a certain computation. Set `compute_output` as
|
||||
the output of the computation, to ensure overflow status is acquired before executing the computation.
|
||||
|
||||
Returns:
|
||||
bool, whether the overflow occurs or not.
|
||||
"""
|
||||
if not self.gpu_target:
|
||||
status = F.depend(status, compute_output)
|
||||
get_status = P.NPUGetFloatStatus()(status)
|
||||
status = F.depend(status, get_status)
|
||||
# sum overflow buffer elements, 0:not overflow , >0:overflow
|
||||
flag_sum = self.reduce_sum(status, (0,))
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
|
||||
flag_sum = P.AddN()(flag_sum)
|
||||
# convert flag_sum to scalar
|
||||
flag_sum = P.Reshape()(flag_sum, (()))
|
||||
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
overflow = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
overflow = self.less_equal(self.base, flag_sum)
|
||||
return overflow
|
||||
|
||||
def process_loss_scale(self, overflow):
|
||||
"""
|
||||
Calculate loss scale according to the overflow.
|
||||
|
||||
Args:
|
||||
overflow(bool): Whether the overflow occurs or not.
|
||||
|
||||
Returns:
|
||||
bool, overflow value.
|
||||
"""
|
||||
if self.loss_scaling_manager is not None:
|
||||
return self.loss_scaling_manager(self.scale_sense, overflow)
|
||||
return overflow
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -330,7 +330,7 @@ def _tensor_grad_overflow(grad):
|
|||
return grad_overflow(grad)
|
||||
|
||||
|
||||
class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||
class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
|
@ -344,39 +344,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
"""
|
||||
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = F.identity
|
||||
super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
|
||||
self.cast = P.Cast()
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.cast = P.Cast()
|
||||
if context.get_context("device_target") == "GPU":
|
||||
self.gpu_target = True
|
||||
self.float_status = P.FloatStatus()
|
||||
self.addn = P.AddN()
|
||||
self.reshape = P.Reshape()
|
||||
else:
|
||||
self.gpu_target = False
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_status = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
|
@ -404,13 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
init = False
|
||||
if not self.gpu_target:
|
||||
# alloc status and clear should be right before gradoperation
|
||||
init = self.alloc_status()
|
||||
init = F.depend(init, loss)
|
||||
clear_status = self.clear_status(init)
|
||||
scaling_sens = F.depend(scaling_sens, clear_status)
|
||||
status, scaling_sens = self.start_overflow(loss, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
|
@ -424,21 +392,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if not self.gpu_target:
|
||||
init = F.depend(init, grads)
|
||||
get_status = self.get_status(init)
|
||||
init = F.depend(init, get_status)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
flag_sum = self.reshape(flag_sum, (()))
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
|
||||
cond = self.detect_overflow(status, grads)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
|
@ -449,7 +404,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
|
||||
|
||||
class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
|
@ -464,40 +420,12 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
|
|||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = F.identity
|
||||
super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
|
||||
self.cast = P.Cast()
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.cast = P.Cast()
|
||||
if context.get_context("device_target") == "GPU":
|
||||
self.gpu_target = True
|
||||
self.float_status = P.FloatStatus()
|
||||
self.addn = P.AddN()
|
||||
self.reshape = P.Reshape()
|
||||
else:
|
||||
self.gpu_target = False
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_status = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
|
@ -525,14 +453,8 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
|
|||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
init = False
|
||||
if not self.gpu_target:
|
||||
# alloc status and clear should be right before gradoperation
|
||||
init = self.alloc_status()
|
||||
init = F.depend(init, loss)
|
||||
clear_status = self.clear_status(init)
|
||||
scaling_sens = F.depend(scaling_sens, clear_status)
|
||||
|
||||
status, scaling_sens = self.start_overflow(loss, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
|
@ -546,21 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
|
|||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if not self.gpu_target:
|
||||
init = F.depend(init, grads)
|
||||
get_status = self.get_status(init)
|
||||
init = F.depend(init, get_status)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
flag_sum = self.reshape(flag_sum, (()))
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
cond = self.detect_overflow(status, grads)
|
||||
overflow = cond
|
||||
if self.loss_scaling_manager is not None:
|
||||
overflow = self.loss_scaling_manager(scaling_sens, cond)
|
||||
|
|
Loading…
Reference in New Issue