!6159 optimize TrainOneStepCell for user-defined training cells

Merge pull request !6159 from wangnan39/optim_train_one_step_cell
mindspore-ci-bot authored 2020-09-15 09:29:23 +08:00, committed by Gitee
commit 0f16b7324e
4 changed files with 12 additions and 65 deletions


@@ -185,23 +185,21 @@ class TrainOneStepCell(Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
         self.reducer_flag = False
-        self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
-        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+        self.grad_reducer = F.identity
+        self.parallel_mode = _get_parallel_mode()
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
             mean = _get_gradients_mean()
             degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)

     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(*inputs, sens)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))
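
With grad_reducer defaulting to F.identity, the same TrainOneStepCell now runs unchanged on a single device and under data parallel: the reducer call in construct is a no-op unless a DistributedGradReducer was installed. A minimal usage sketch (the tiny Dense network, loss, and random data are illustrative only, not part of this change):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# Illustrative stand-ins; any user network, loss, and optimizer work the same way.
net = nn.Dense(4, 1)
loss_net = nn.WithLossCell(net, nn.MSELoss())
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

train_net = nn.TrainOneStepCell(loss_net, optimizer)  # sens defaults to 1.0
train_net.set_train()

data = Tensor(np.random.rand(8, 4).astype(np.float32))
label = Tensor(np.random.rand(8, 1).astype(np.float32))
loss = train_net(data, label)  # one forward, backward, and parameter update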


@@ -14,9 +14,8 @@
 # ============================================================================
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor
 from ...common.parameter import Parameter
@@ -163,7 +162,7 @@ class FixedLossScaleUpdateCell(Cell):
         return overflow

-class TrainOneStepWithLossScaleCell(Cell):
+class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     r"""
     Network training with loss scaling.
@@ -203,15 +202,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
         >>> output = train_network(inputs, label, scaling_sens)
         """
     def __init__(self, network, optimizer, scale_sense):
-        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.network.add_flags(defer_inline=True)
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.hyper_map = C.HyperMap()
         if context.get_context("device_target") == "GPU":
             self.gpu_target = True
@@ -228,13 +220,6 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.less_equal = LessEqual()
         self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
-        self.parallel_mode = _get_parallel_mode()
-        self.grad_reducer = F.identity
-        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
-        if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
         self.loss_scaling_manager = None
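
Since TrainOneStepWithLossScaleCell now inherits from TrainOneStepCell, its __init__ only adds the loss-scale specifics and leaves weights, GradOperation, and the gradient reducer to the parent. A simplified sketch of that reuse (hypothetical class name, construct omitted; not the full MindSpore implementation):

from mindspore.nn import TrainOneStepCell
from mindspore.ops import composite as C

class LossScaleCellSketch(TrainOneStepCell):
    def __init__(self, network, optimizer, scale_sense):
        # sens=None: the effective sensitivity comes from the loss scale, so the
        # parent only wires up self.weights, self.grad, and self.grad_reducer.
        super(LossScaleCellSketch, self).__init__(network, optimizer, sens=None)
        self.hyper_map = C.HyperMap()
        self.scale_sense = scale_sense  # Tensor or update cell providing the scale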


@@ -271,23 +271,7 @@ class BertTrainOneStepCell(nn.Cell):
         sens (Number): The adjust parameter. Default: 1.0.
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
@@ -322,9 +306,7 @@
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)


@@ -289,23 +289,7 @@ class BertTrainOneStepCell(nn.Cell):
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
@@ -340,9 +324,7 @@
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)
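
The same pattern extends to any user-defined training wrapper: subclass TrainOneStepCell, let super().__init__ set up network, optimizer, sens, and grad_reducer, and only customize construct. A hedged sketch of such a subclass (the class name and the minimal construct body are illustrative; the real BertTrainOneStepCell additionally casts the sens value and clips gradients, as shown in the hunks above):

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F

class CustomTrainOneStepCell(nn.TrainOneStepCell):
    """User-defined one-step cell: reuse the parent's setup, customize construct."""

    def __init__(self, network, optimizer, sens=1.0):
        # The parent __init__ wires up self.weights, self.grad, self.sens and picks
        # F.identity or DistributedGradReducer for self.grad_reducer.
        super(CustomTrainOneStepCell, self).__init__(network, optimizer, sens)

    def construct(self, *inputs):
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        # Custom per-gradient processing (e.g. clipping) would go here.
        grads = self.grad_reducer(grads)  # no-op (F.identity) on a single device
        return F.depend(loss, self.optimizer(grads))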