forked from mindspore-Ecosystem/mindspore
!6159 optimize the TrainOneStepCell for user-defined extension
Merge pull request !6159 from wangnan39/optim_train_one_step_cell
commit 0f16b7324e
@@ -185,23 +185,21 @@ class TrainOneStepCell(Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
         self.reducer_flag = False
-        self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
-        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+        self.grad_reducer = F.identity
+        self.parallel_mode = _get_parallel_mode()
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
             mean = _get_gradients_mean()
             degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
 
     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(*inputs, sens)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))
 
 
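The point of this hunk is that self.grad_reducer now always holds a callable (F.identity in stand-alone mode, a DistributedGradReducer in data/hybrid parallel), so construct applies it unconditionally and a user-defined subclass no longer needs to re-implement the reducer branch. Below is a minimal sketch of such a subclass; the class name and clip_norm argument are invented for illustration, and mindspore.ops.composite.clip_by_global_norm is assumed to be available in the installed MindSpore version.

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class TrainOneStepWithClipCell(nn.TrainOneStepCell):
    """Hypothetical user-defined cell: one training step plus global-norm gradient clipping."""

    def __init__(self, network, optimizer, sens=1.0, clip_norm=1.0):
        super(TrainOneStepWithClipCell, self).__init__(network, optimizer, sens)
        self.clip_norm = clip_norm

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        # self.grad_reducer is F.identity unless the base class installed a
        # DistributedGradReducer, so this call is safe in every parallel mode.
        grads = self.grad_reducer(grads)
        grads = C.clip_by_global_norm(grads, self.clip_norm)
        return F.depend(loss, self.optimizer(grads))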
@@ -14,9 +14,8 @@
 # ============================================================================
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor
 from ...common.parameter import Parameter
@@ -163,7 +162,7 @@ class FixedLossScaleUpdateCell(Cell):
         return overflow
 
 
-class TrainOneStepWithLossScaleCell(Cell):
+class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     r"""
     Network training with loss scaling.
 
@@ -203,15 +202,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
         >>> output = train_network(inputs, label, scaling_sens)
     """
 
     def __init__(self, network, optimizer, scale_sense):
-        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.network.add_flags(defer_inline=True)
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.hyper_map = C.HyperMap()
         if context.get_context("device_target") == "GPU":
             self.gpu_target = True
@@ -228,13 +220,6 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.less_equal = LessEqual()
         self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
-        self.parallel_mode = _get_parallel_mode()
-        self.grad_reducer = F.identity
-        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
-        if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
 
         self.loss_scaling_manager = None
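With TrainOneStepWithLossScaleCell now deriving from TrainOneStepCell, the base constructor supplies self.network, self.weights, self.optimizer, self.grad, self.reducer_flag and self.grad_reducer, and the loss-scale cell keeps only its overflow-detection and scaling members. Typical usage is unchanged; the following is a rough sketch only, with the toy network, optimizer and batch shapes invented for illustration (the real cell is normally run on Ascend/GPU).

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# Toy network and optimizer; any Cell returning a scalar loss works the same way.
net = nn.WithLossCell(nn.Dense(16, 10), nn.SoftmaxCrossEntropyWithLogits(sparse=True))
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# scale_sense may be an update cell (as here) or a scalar Tensor.
scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2 ** 12)
train_network = nn.TrainOneStepWithLossScaleCell(net, opt, scale_sense=scale_manager)
train_network.set_train()

inputs = Tensor(np.random.randn(32, 16).astype(np.float32))
label = Tensor(np.random.randint(0, 10, (32,)).astype(np.int32))
# The cell returns the loss, an overflow flag and the current loss scale.
loss, overflow, scaling_sens = train_network(inputs, label)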
@@ -271,23 +271,7 @@ class BertTrainOneStepCell(nn.Cell):
         sens (Number): The adjust parameter. Default: 1.0.
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
 
@@ -322,9 +306,7 @@ class BertTrainOneStepCell(nn.Cell):
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)
 
@@ -289,23 +289,7 @@ class BertTrainOneStepCell(nn.Cell):
     """
 
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
 
@@ -340,9 +324,7 @@ class BertTrainOneStepCell(nn.Cell):
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)
 
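For both copies of BertTrainOneStepCell the change is the same: the constructor collapses to a super call plus the BERT-specific members (which presumably means the class now derives from nn.TrainOneStepCell; the class declaration line itself is outside this excerpt), and construct applies self.grad_reducer unconditionally. Assembled from the lines the diff keeps, the reduced constructor reads roughly as follows; this is a sketch, not a verbatim copy of the file.

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P


class BertTrainOneStepCell(nn.TrainOneStepCell):
    """Sketch of the slimmed-down BERT training wrapper after this commit."""

    def __init__(self, network, optimizer, sens=1.0):
        # Network/optimizer/grad/grad_reducer setup now lives in nn.TrainOneStepCell.
        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()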