!12512 optimize class TrainOneStepWithLossScaleCell

From: @wangnan39
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-02-27 19:23:13 +08:00 committed by Gitee
commit 95adf66d30
2 changed files with 106 additions and 152 deletions

View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@ from ...common.parameter import Parameter
from ...ops import functional as F
from ...ops import composite as C
from ...ops import operations as P
from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual
from ...common import dtype as mstype
_grad_scale = C.MultitypeFuncGraph("grad_scale")
@@ -275,22 +274,12 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
def __init__(self, network, optimizer, scale_sense):
super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
self.hyper_map = C.HyperMap()
if context.get_context("device_target") == "GPU":
self.gpu_target = True
self.float_status = P.FloatStatus()
self.addn = P.AddN()
self.reshape = P.Reshape()
else:
self.gpu_target = False
self.alloc_status = NPUAllocFloatStatus()
self.get_status = NPUGetFloatStatus()
self.clear_status = NPUClearFloatStatus()
self.reduce_sum = ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = LessEqual()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.less_equal = P.LessEqual()
self.allreduce = P.AllReduce()
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.gpu_target = (context.get_context("device_target") == "GPU")
self.loss_scaling_manager = None
if isinstance(scale_sense, Cell):
self.loss_scaling_manager = scale_sense
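For orientation (not part of the diff): a minimal construction sketch, assuming the constructor keeps the (network, optimizer, scale_sense) signature shown above and using nn.DynamicLossScaleUpdateCell as the scale-sense Cell; net, loss_fn and opt are placeholders, and the scale values are hypothetical.

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype

# net, loss_fn and opt are assumed to be defined elsewhere.
loss_net = nn.WithLossCell(net, loss_fn)
manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
train_net = nn.TrainOneStepWithLossScaleCell(loss_net, opt, scale_sense=manager)
# A fixed scale can also be passed as a Tensor (per the Tensor branch of the constructor, not shown in this hunk).
train_net_fixed = nn.TrainOneStepWithLossScaleCell(loss_net, opt, scale_sense=Tensor(1024.0, mstype.float32))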
@@ -307,43 +296,19 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
def construct(self, *inputs):
weights = self.weights
loss = self.network(*inputs)
init = False
if not self.gpu_target:
# init overflow buffer
init = self.alloc_status()
# clear overflow buffer after loss calculated
init = F.depend(init, loss)
clear_status = self.clear_status(init)
loss = F.depend(loss, clear_status)
scaling_sens = self.scale_sense
status, scaling_sens = self.start_overflow(loss, scaling_sens)
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# get the overflow buffer
if not self.gpu_target:
# get overflow status after grads calculated
init = F.depend(init, grads)
get_status = self.get_status(init)
init = F.depend(init, get_status)
# sum overflow buffer elements; 0: no overflow, >0: overflow
flag_sum = self.reduce_sum(init, (0,))
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
# convert flag_sum to scalar
flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if self.loss_scaling_manager is not None:
overflow = self.loss_scaling_manager(self.scale_sense, cond)
cond = self.detect_overflow(status, grads)
overflow = self.process_loss_scale(cond)
# if there is no overflow, do optimize
if not overflow:
loss = F.depend(loss, self.optimizer(grads))
@@ -356,3 +321,84 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
self.scale_sense.set_data(sens)
else:
raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
def start_overflow(self, pre_cond, compute_input):
"""
Start floating-point overflow detection. Create and clear the overflow detection state.
Specify the arguments `pre_cond` and `compute_input` to make sure the overflow status is cleared at the right time.
For example, to clear the state after the loss calculation and then detect overflow during the gradient
calculation, `pre_cond` should be the output of the loss function and `compute_input` should be the input of
the gradient-computing function.
Args:
pre_cond(object): A precondition for starting overflow detection. It determines the execution order between
clearing the overflow state and the preceding operations, and ensures that `start_overflow` clears the
status only after the precondition has finished.
compute_input(object): The input of the subsequent computation. Overflow detection is performed around a
certain computation; set `compute_input` as the input of that computation, to ensure the overflow status is
cleared before the computation executes.
Returns:
Tuple[object, object], the first value is False for the GPU backend and an instance of NPUAllocFloatStatus
for other backends; this status is used by the subsequent overflow detection. The second value is the same
as `compute_input`, but now carries a dependency that enforces the execution order.
"""
status = False
if not self.gpu_target:
# init overflow buffer
status = P.NPUAllocFloatStatus()()
status = F.depend(status, pre_cond)
# clear overflow buffer
clear_status = P.NPUClearFloatStatus()(status)
compute_input = F.depend(compute_input, clear_status)
return status, compute_input
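Usage sketch (condensed from the construct shown earlier in this diff): the loss plays the role of pre_cond and the scale tensor fed to the backward pass plays the role of compute_input.

loss = self.network(*inputs)
scaling_sens = self.scale_sense
# pre_cond = loss: clearing waits for the forward pass to finish
# compute_input = scaling_sens: the backward pass depends on the cleared state
status, scaling_sens = self.start_overflow(loss, scaling_sens)
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, self.weights)(*inputs, scaling_sens_filled)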
def detect_overflow(self, status, compute_output):
"""
Detect floating-point overflow status.
Get overflow results after executing the target process for overflow detection.
Args:
status(object): A status instance used to detect the overflow.
compute_output: Overflow detection is performed around a certain computation. Set `compute_output` as
the output of that computation, to ensure the overflow status is acquired only after the computation has finished.
Returns:
bool, whether the overflow occurs or not.
"""
if not self.gpu_target:
status = F.depend(status, compute_output)
get_status = P.NPUGetFloatStatus()(status)
status = F.depend(status, get_status)
# sum overflow buffer elements; 0: no overflow, >0: overflow
flag_sum = self.reduce_sum(status, (0,))
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
flag_sum = P.AddN()(flag_sum)
# convert flag_sum to scalar
flag_sum = P.Reshape()(flag_sum, (()))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
overflow = self.less_equal(self.base, flag_reduce)
else:
overflow = self.less_equal(self.base, flag_sum)
return overflow
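To make the flag arithmetic concrete, a plain-Python illustration with hypothetical values; it mirrors `overflow = self.less_equal(self.base, flag_sum)`, where `self.base` is 1.

base = 1.0
flag_sum_clean = 0.0       # no Inf/NaN recorded in the status buffer / gradients
flag_sum_overflow = 3.0    # at least one Inf/NaN was recorded
assert not (base <= flag_sum_clean)    # no overflow: the optimizer step will run
assert base <= flag_sum_overflow       # overflow: the step will be skipped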
def process_loss_scale(self, overflow):
"""
Calculate the loss scale according to the overflow status.
Args:
overflow(bool): Whether the overflow occurs or not.
Returns:
bool, overflow value.
"""
if self.loss_scaling_manager is not None:
return self.loss_scaling_manager(self.scale_sense, overflow)
return overflow
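Putting the three helpers together (a sketch, not the library's own code): a user-defined train cell can subclass TrainOneStepWithLossScaleCell and reuse start_overflow, detect_overflow and process_loss_scale, which is exactly what the two Bert cells below do. CustomTrainOneStepCell and _unscale are made-up names; the unscale map is redefined locally because the module-level _grad_scale in this file is private, and the rest mirrors the construct shown above.

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

_unscale = C.MultitypeFuncGraph("unscale")
reciprocal = P.Reciprocal()

@_unscale.register("Tensor", "Tensor")
def _tensor_unscale(scale, grad):
    # divide each gradient by the loss scale
    return grad * reciprocal(scale)

class CustomTrainOneStepCell(nn.TrainOneStepWithLossScaleCell):
    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        # 1. clear the overflow state after the forward pass, before the backward pass
        status, scaling_sens = self.start_overflow(loss, scaling_sens)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.hyper_map(F.partial(_unscale, scaling_sens), grads)
        grads = self.grad_reducer(grads)
        # 2. read the overflow flag only after the gradients exist
        cond = self.detect_overflow(status, grads)
        # 3. let the (optional) loss-scale manager update the scale
        overflow = self.process_loss_scale(cond)
        if not overflow:
            loss = F.depend(loss, self.optimizer(grads))
        return loss, cond, scaling_sens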

View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -330,7 +330,7 @@ def _tensor_grad_overflow(grad):
return grad_overflow(grad)
class BertTrainOneStepWithLossScaleCell(nn.Cell):
class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
"""
Encapsulation class of bert network training.
@@ -344,39 +344,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
self.cast = P.Cast()
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
if context.get_context("device_target") == "GPU":
self.gpu_target = True
self.float_status = P.FloatStatus()
self.addn = P.AddN()
self.reshape = P.Reshape()
else:
self.gpu_target = False
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_status = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
@@ -404,13 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
scaling_sens = self.loss_scale
else:
scaling_sens = sens
init = False
if not self.gpu_target:
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
init = F.depend(init, loss)
clear_status = self.clear_status(init)
scaling_sens = F.depend(scaling_sens, clear_status)
status, scaling_sens = self.start_overflow(loss, scaling_sens)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
@@ -424,21 +392,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if not self.gpu_target:
init = F.depend(init, grads)
get_status = self.get_status(init)
init = F.depend(init, get_status)
flag_sum = self.reduce_sum(init, (0,))
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
cond = self.detect_overflow(status, grads)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
@@ -449,7 +404,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)
class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
"""
Encapsulation class of bert network training.
@@ -464,40 +420,12 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
self.cast = P.Cast()
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
if context.get_context("device_target") == "GPU":
self.gpu_target = True
self.float_status = P.FloatStatus()
self.addn = P.AddN()
self.reshape = P.Reshape()
else:
self.gpu_target = False
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_status = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
@@ -525,14 +453,8 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
scaling_sens = self.loss_scale
else:
scaling_sens = sens
init = False
if not self.gpu_target:
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
init = F.depend(init, loss)
clear_status = self.clear_status(init)
scaling_sens = F.depend(scaling_sens, clear_status)
status, scaling_sens = self.start_overflow(loss, scaling_sens)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
@@ -546,21 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if not self.gpu_target:
init = F.depend(init, grads)
get_status = self.get_status(init)
init = F.depend(init, get_status)
flag_sum = self.reduce_sum(init, (0,))
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
cond = self.detect_overflow(status, grads)
overflow = cond
if self.loss_scaling_manager is not None:
overflow = self.loss_scaling_manager(scaling_sens, cond)