gradreducer modify

yao_yf 2022-12-30 10:57:10 +08:00
parent 14db9b3678
commit 0ae256dae4
3 changed files with 14 additions and 9 deletions


@@ -1,7 +1,7 @@
mindspore.nn.DistributedGradReducer
===================================
.. py:class:: mindspore.nn.DistributedGradReducer(parameters, mean=True, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP)
.. py:class:: mindspore.nn.DistributedGradReducer(parameters, mean=None, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP)
    A distributed optimizer.
@@ -9,7 +9,7 @@ mindspore.nn.DistributedGradReducer
    Parameters:
        - **parameters** (list) - The parameters to be updated.
        - **mean** (bool) - When mean is True, the gradients are averaged after AllReduce. Default: True.
        - **mean** (bool) - When mean is True, the gradients are averaged after AllReduce. When not specified, the `gradients_mean` setting in auto_parallel_context is used. Default: None.
        - **degree** (int) - The mean coefficient, usually equal to the number of devices. Default: None.
        - **fusion_type** (int) - The fusion type of the AllReduce operator. Default: 1.
        - **group** (str) - The communication group of the AllReduce operator. To customize the communication group, call the create_group API. Default: GlobalComm.WORLD_COMM_GROUP.
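For reference, a minimal usage sketch of the behaviour documented above: with `mean` left as None, the reducer reads `gradients_mean` from auto_parallel_context instead of requiring an explicit argument. The network `net` and the single data-parallel setup are assumptions for illustration, not part of this commit.

# Sketch only: assumes a distributed backend is available and that `net`
# is some mindspore.nn.Cell whose gradients should be all-reduced.
from mindspore import context
from mindspore.communication import init
from mindspore.nn import DistributedGradReducer

init()
context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)

# mean is not passed, so the reducer falls back to the gradients_mean
# setting configured above.
grad_reducer = DistributedGradReducer(net.trainable_params())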


@@ -363,6 +363,8 @@ class TrainOneStepCell(Cell):
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            from mindspore.communication.management import GlobalComm
            group = GlobalComm.WORLD_COMM_GROUP
            if isinstance(self.optimizer, (nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell)):
                from mindspore.communication.management import get_group_size, create_group, get_rank
                group_number = get_group_size() // 8
@@ -371,10 +373,8 @@ class TrainOneStepCell(Cell):
                current_index = get_rank() // 8
                server_group_name = "allreduce_" + str(current_index)
                create_group(server_group_name, group_list[current_index])
                self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree,
                                                           group=server_group_name)
            else:
                self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
                group = server_group_name
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group)

    def construct(self, *inputs):
        loss = self.network(*inputs)
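The hunk above folds the two separate DistributedGradReducer constructions into one: `group` now defaults to GlobalComm.WORLD_COMM_GROUP and is only overridden with the AdaSum server group. A rough standalone sketch of that selection logic follows; the consecutive-rank layout of `group_list` is an assumption here, since its construction falls outside the shown hunk.

# Sketch of the group-selection pattern, not the exact library code.
# The consecutive-rank group_list layout below is assumed for illustration.
from mindspore.communication.management import (GlobalComm, create_group,
                                                get_group_size, get_rank)

def select_allreduce_group(optimizer_uses_adasum):
    """Pick the communication group for the gradient reducer."""
    group = GlobalComm.WORLD_COMM_GROUP              # default: every device
    if optimizer_uses_adasum:
        group_number = get_group_size() // 8         # one server group per 8 devices
        group_list = [list(range(i * 8, (i + 1) * 8)) for i in range(group_number)]
        current_index = get_rank() // 8
        server_group_name = "allreduce_" + str(current_index)
        create_group(server_group_name, group_list[current_index])
        group = server_group_name                    # reduce only within this server's ranks
    return group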


@@ -295,7 +295,9 @@ class DistributedGradReducer(Cell):
    Args:
        parameters (list): the parameters to be updated.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: True.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
            When it is not specified, using the configuration `gradients_mean` in auto_parallel_context.
            Default: None.
        degree (int): The mean coefficient. Usually it equals to device number. Default: None.
        fusion_type (int): The type of all reduce fusion. Default: 1.
        group (str): The communication group to work on. Normally, the group should be created by create_group,
@@ -387,9 +389,12 @@ class DistributedGradReducer(Cell):
        256.0
    """
    def __init__(self, parameters, mean=True, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP):
    def __init__(self, parameters, mean=None, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP):
        super(DistributedGradReducer, self).__init__(auto_prefix=False)
        self.map_ = C.Map()
        self.mean = mean
        if mean is None:
            self.mean = auto_parallel_context().get_gradients_mean()
        if degree is None:
            self.degree = get_group_size()
        else:
@@ -399,7 +404,7 @@ class DistributedGradReducer(Cell):
                                 "should large than 0 and be int, degree: {}.".format(degree))
            self.degree = degree
        self.degree = Tensor(1.0 / self.degree, mstype.float32)
        self.mean = mean
        self.allreduce_filter = tuple((x.layerwise_parallel is False) and (x.is_in_shard is False) for x in parameters)
        is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
        split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
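As a usage note for the `group` argument mentioned in the docstring above, here is a hedged sketch of passing a custom group created with create_group; the group name, the two-rank membership, and the `net` Cell are illustrative assumptions only.

# Sketch only: build a custom two-rank communication group and hand it to
# DistributedGradReducer. Assumes init() succeeds, ranks 0 and 1 exist,
# and `net` is a hypothetical mindspore.nn.Cell.
from mindspore.communication import init
from mindspore.communication.management import create_group, get_rank
from mindspore.nn import DistributedGradReducer

init()
custom_group = "allreduce_pair_0"        # assumed name, any unique string works
if get_rank() in (0, 1):                 # only member ranks create and use the group
    create_group(custom_group, [0, 1])
    # mean stays None, so averaging still follows gradients_mean from
    # auto_parallel_context; only the communication scope changes.
    reducer = DistributedGradReducer(net.trainable_params(), group=custom_group)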