forked from mindspore-Ecosystem/mindspore

gradreducer modify

commit 0ae256dae4
parent 14db9b3678
@@ -1,7 +1,7 @@
 mindspore.nn.DistributedGradReducer
 ===================================
 
-.. py:class:: mindspore.nn.DistributedGradReducer(parameters, mean=True, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP)
+.. py:class:: mindspore.nn.DistributedGradReducer(parameters, mean=None, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP)
 
     A distributed optimizer.
@@ -9,7 +9,7 @@ mindspore.nn.DistributedGradReducer
 
     Parameters:
         - **parameters** (list) - The parameters to be updated.
-        - **mean** (bool) - When mean is True, the gradients are averaged after AllReduce. Default: True.
+        - **mean** (bool) - When mean is True, the gradients are averaged after AllReduce. When not specified, the `gradients_mean` setting in auto_parallel_context is used. Default: None.
         - **degree** (int) - The averaging coefficient, usually equal to the number of devices. Default: None.
         - **fusion_type** (int) - The fusion type of the AllReduce operator. Default: 1.
         - **group** (str) - The communication group for the AllReduce operator. To use a custom communication group, call the create_group interface. Default: GlobalComm.WORLD_COMM_GROUP.
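A minimal usage sketch (my own illustration under assumptions, not part of this patch) of the new default: when `mean` is omitted, the reducer is expected to pick up the `gradients_mean` flag from the auto-parallel context. It assumes a data-parallel job that has already been launched (e.g. via `mpirun`) so that `init()` succeeds; the network is a placeholder.

```python
from mindspore import context, nn
from mindspore.communication import init

init()  # assumes this process was started as part of a distributed job
context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)

net = nn.Dense(16, 8)  # placeholder network
# `mean` is omitted, so with this change it resolves to the
# gradients_mean=True value configured on the context above.
grad_reducer = nn.DistributedGradReducer(net.trainable_params())
```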
@@ -363,6 +363,8 @@ class TrainOneStepCell(Cell):
         if self.reducer_flag:
             self.mean = _get_gradients_mean()
             self.degree = _get_device_num()
+            from mindspore.communication.management import GlobalComm
+            group = GlobalComm.WORLD_COMM_GROUP
             if isinstance(self.optimizer, (nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell)):
                 from mindspore.communication.management import get_group_size, create_group, get_rank
                 group_number = get_group_size() // 8
@@ -371,10 +373,8 @@ class TrainOneStepCell(Cell):
                 current_index = get_rank() // 8
                 server_group_name = "allreduce_" + str(current_index)
                 create_group(server_group_name, group_list[current_index])
-                self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree,
-                                                           group=server_group_name)
-            else:
-                self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
+                group = server_group_name
+            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group)
 
     def construct(self, *inputs):
         loss = self.network(*inputs)
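For the custom-group path above, the pattern is to create a sub-communication group first and then pass its name to the reducer. The sketch below is an illustration of that pattern, not part of the patch: it assumes an already-initialized multi-device job, a backend that supports `create_group` (e.g. HCCL), and a hypothetical layout of one group per block of 8 ranks; `mean` and `degree` stand in for the values `TrainOneStepCell` computes.

```python
from mindspore import nn
from mindspore.communication.management import create_group, get_rank, init

init()  # assumes a multi-device job on a backend that supports create_group
# hypothetical layout: one AllReduce group per block of 8 ranks
current_index = get_rank() // 8
group_name = "allreduce_" + str(current_index)
rank_ids = [current_index * 8 + i for i in range(8)]
create_group(group_name, rank_ids)

# gradients are then reduced only among the ranks in this group
net = nn.Dense(16, 8)  # placeholder network
grad_reducer = nn.DistributedGradReducer(net.trainable_params(), mean=True, degree=8,
                                         group=group_name)
```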
@@ -295,7 +295,9 @@ class DistributedGradReducer(Cell):
 
     Args:
         parameters (list): the parameters to be updated.
-        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: True.
+        mean (bool): When mean is True, the mean coefficient (degree) is applied to the gradients.
+            When it is not specified, the configuration `gradients_mean` in auto_parallel_context is used.
+            Default: None.
         degree (int): The mean coefficient. Usually it equals to device number. Default: None.
         fusion_type (int): The type of all reduce fusion. Default: 1.
         group (str): The communication group to work on. Normally, the group should be created by create_group,
@@ -387,9 +389,12 @@ class DistributedGradReducer(Cell):
         256.0
     """
 
-    def __init__(self, parameters, mean=True, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP):
+    def __init__(self, parameters, mean=None, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP):
         super(DistributedGradReducer, self).__init__(auto_prefix=False)
         self.map_ = C.Map()
+        self.mean = mean
+        if mean is None:
+            self.mean = auto_parallel_context().get_gradients_mean()
         if degree is None:
             self.degree = get_group_size()
         else:
@@ -399,7 +404,7 @@ class DistributedGradReducer(Cell):
                                  "should large than 0 and be int, degree: {}.".format(degree))
             self.degree = degree
         self.degree = Tensor(1.0 / self.degree, mstype.float32)
-        self.mean = mean
+
         self.allreduce_filter = tuple((x.layerwise_parallel is False) and (x.is_in_shard is False) for x in parameters)
         is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
         split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
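To make the new constructor behaviour easier to follow, here is a small standalone sketch (my own restatement under assumptions, not library code) of the resolution order `__init__` now applies to `mean` and `degree`: an explicit argument wins, otherwise the context flag or the group size is used, and the stored coefficient is the reciprocal of `degree`.

```python
def resolve_reducer_defaults(mean, degree, get_gradients_mean, get_group_size):
    """Illustrative helper mirroring the defaults logic in DistributedGradReducer.__init__."""
    if mean is None:
        # fall back to the gradients_mean flag from the auto-parallel context
        mean = get_gradients_mean()
    if degree is None:
        # fall back to the size of the communication group
        degree = get_group_size()
    elif not isinstance(degree, int) or degree <= 0:
        raise ValueError("degree should be a positive int, got: {}".format(degree))
    # the reducer keeps 1/degree and multiplies it into gradients when mean is True
    return mean, 1.0 / degree

# example: mean and degree left unspecified, context flag True, 8 devices -> (True, 0.125)
print(resolve_reducer_defaults(None, None, lambda: True, lambda: 8))
```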