!6159 optimize the TrainOneStepCell for user's define

Merge pull request !6159 from wangnan39/optim_train_one_step_cell
This commit is contained in:
mindspore-ci-bot 2020-09-15 09:29:23 +08:00 committed by Gitee
commit 0f16b7324e
4 changed files with 12 additions and 65 deletions

View File

@ -185,23 +185,21 @@ class TrainOneStepCell(Cell):
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = F.identity
parallel_mode = _get_parallel_mode() self.parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = _get_gradients_mean() mean = _get_gradients_mean()
degree = _get_device_num() degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
def construct(self, *inputs): def construct(self, *inputs):
weights = self.weights weights = self.weights
loss = self.network(*inputs) loss = self.network(*inputs)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*inputs, sens) grads = self.grad(self.network, weights)(*inputs, sens)
if self.reducer_flag: grads = self.grad_reducer(grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads)) return F.depend(loss, self.optimizer(grads))

View File

@ -14,9 +14,8 @@
# ============================================================================ # ============================================================================
"""Loss scale cell for loss scale training.""" """Loss scale cell for loss scale training."""
import mindspore.context as context import mindspore.context as context
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean from .cell_wrapper import TrainOneStepCell
from ..cell import Cell from ..cell import Cell
from ...common import Tensor, RowTensor from ...common import Tensor, RowTensor
from ...common.parameter import Parameter from ...common.parameter import Parameter
@ -163,7 +162,7 @@ class FixedLossScaleUpdateCell(Cell):
return overflow return overflow
class TrainOneStepWithLossScaleCell(Cell): class TrainOneStepWithLossScaleCell(TrainOneStepCell):
r""" r"""
Network training with loss scaling. Network training with loss scaling.
@ -203,15 +202,8 @@ class TrainOneStepWithLossScaleCell(Cell):
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32) >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scaling_sens) >>> output = train_network(inputs, label, scaling_sens)
""" """
def __init__(self, network, optimizer, scale_sense): def __init__(self, network, optimizer, scale_sense):
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
if context.get_context("device_target") == "GPU": if context.get_context("device_target") == "GPU":
self.gpu_target = True self.gpu_target = True
@ -228,13 +220,6 @@ class TrainOneStepWithLossScaleCell(Cell):
self.less_equal = LessEqual() self.less_equal = LessEqual()
self.depend_parameter_use = ControlDepend(depend_mode=1) self.depend_parameter_use = ControlDepend(depend_mode=1)
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
self.grad_reducer = F.identity
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
self.loss_scaling_manager = None self.loss_scaling_manager = None

View File

@ -271,23 +271,7 @@ class BertTrainOneStepCell(nn.Cell):
sens (Number): The adjust parameter. Default: 1.0. sens (Number): The adjust parameter. Default: 1.0.
""" """
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(BertTrainOneStepCell, self).__init__(auto_prefix=False) super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.cast = P.Cast() self.cast = P.Cast()
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
@ -322,9 +306,7 @@ class BertTrainOneStepCell(nn.Cell):
self.cast(F.tuple_to_array((self.sens,)), self.cast(F.tuple_to_array((self.sens,)),
mstype.float32)) mstype.float32))
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag: grads = self.grad_reducer(grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
succ = self.optimizer(grads) succ = self.optimizer(grads)
return F.depend(loss, succ) return F.depend(loss, succ)

View File

@ -289,23 +289,7 @@ class BertTrainOneStepCell(nn.Cell):
""" """
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(BertTrainOneStepCell, self).__init__(auto_prefix=False) super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.cast = P.Cast() self.cast = P.Cast()
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
@ -340,9 +324,7 @@ class BertTrainOneStepCell(nn.Cell):
self.cast(F.tuple_to_array((self.sens,)), self.cast(F.tuple_to_array((self.sens,)),
mstype.float32)) mstype.float32))
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag: grads = self.grad_reducer(grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
succ = self.optimizer(grads) succ = self.optimizer(grads)
return F.depend(loss, succ) return F.depend(loss, succ)