From 009fb02ef4dd4467517903dfc20b2c59394c9048 Mon Sep 17 00:00:00 2001
From: linqingke
Date: Mon, 21 Feb 2022 14:40:33 +0800
Subject: [PATCH] warn when a split index is larger than the gradient number.

---
 mindspore/python/mindspore/nn/wrap/grad_reducer.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mindspore/python/mindspore/nn/wrap/grad_reducer.py b/mindspore/python/mindspore/nn/wrap/grad_reducer.py
index 14bb0562166..b2b54995cbb 100644
--- a/mindspore/python/mindspore/nn/wrap/grad_reducer.py
+++ b/mindspore/python/mindspore/nn/wrap/grad_reducer.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """grad reducer cell for distributed training"""
 from mindspore import context
+from mindspore import log as logger
 from mindspore.nn.cell import Cell
 from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.common.tensor import RowTensor
@@ -29,6 +30,11 @@ reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 
 def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
     """ initialize allreduce communication operators"""
+    for indices in split_indices:
+        if indices >= length:
+            logger.warning(f"AllReduce's split index {indices} is greater than or equal to "
+                           f"the total number of gradients {length}")
+
     fusion_type = 2 ** 10
     split = 0
     fusion = ()
@@ -39,6 +45,7 @@ def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM
         if split_indices[split] <= i:
             fusion_type += 1
             split += 1
+
     index = tuple(range(1, length + 1))
     op_list = ()
     for i in range(length):
@@ -70,6 +77,7 @@ def _init_allreduce_operators_by_parameters(parameters, split_indices, group, fu
             op.add_prim_attr('index', index)
             index += 1
         op_list = op_list + (op,)
+
     if not param_fusion:
         if split_indices and fusion_type == 1:
             op_list = _init_allreduce_operators(len(parameters), split_indices, group)
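
Note: the new check only logs a warning; out-of-range indices are not removed, and the
fusion loop below the check simply never reaches them. A minimal standalone sketch of the
added validation logic, using Python's stdlib logging in place of mindspore.log (the helper
name _check_split_indices is hypothetical, introduced here only for illustration):

    import logging

    logger = logging.getLogger("mindspore")

    def _check_split_indices(length, split_indices):
        """Warn about split indices that can never take effect.

        A split index marks the position in the flattened gradient list
        where a new AllReduce fusion group starts, so any index greater
        than or equal to the number of gradients is out of range.
        """
        for index in split_indices:
            if index >= length:
                logger.warning(f"AllReduce's split index {index} is greater than or equal to "
                               f"the total number of gradients {length}")

    # Example: 10 gradients, but the last split index is out of range.
    logging.basicConfig(level=logging.WARNING)
    _check_split_indices(10, [2, 5, 12])   # warns about index 12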