!30326 If a split index is greater than or equal to the number of gradients, log a warning.
Merge pull request !30326 from linqingke/r1.6
This commit is contained in:
commit
b97af2b6eb
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""grad reducer cell for distributed training"""
|
"""grad reducer cell for distributed training"""
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
from mindspore import log as logger
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
from mindspore.communication.management import GlobalComm, get_group_size
|
from mindspore.communication.management import GlobalComm, get_group_size
|
||||||
from mindspore.common.tensor import RowTensor
|
from mindspore.common.tensor import RowTensor
|
||||||
|
@ -29,6 +30,11 @@ reduce_opt = C.MultitypeFuncGraph("reduce_opt")
|
||||||
|
|
||||||
def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
|
def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
|
||||||
""" initialize allreduce communication operators"""
|
""" initialize allreduce communication operators"""
|
||||||
|
for indices in split_indices:
|
||||||
|
if indices >= length:
|
||||||
|
logger.warning(f"AllReduce's split index {indices} is greater than or equal to "
|
||||||
|
f"the total gradient's number of {length}")
|
||||||
|
|
||||||
fusion_type = 2 ** 10
|
fusion_type = 2 ** 10
|
||||||
split = 0
|
split = 0
|
||||||
fusion = ()
|
fusion = ()
|
||||||
|
@ -39,6 +45,7 @@ def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM
|
||||||
if split_indices[split] <= i:
|
if split_indices[split] <= i:
|
||||||
fusion_type += 1
|
fusion_type += 1
|
||||||
split += 1
|
split += 1
|
||||||
|
|
||||||
index = tuple(range(1, length + 1))
|
index = tuple(range(1, length + 1))
|
||||||
op_list = ()
|
op_list = ()
|
||||||
for i in range(length):
|
for i in range(length):
|
||||||
|
@ -70,6 +77,7 @@ def _init_allreduce_operators_by_parameters(parameters, split_indices, group, fu
|
||||||
op.add_prim_attr('index', index)
|
op.add_prim_attr('index', index)
|
||||||
index += 1
|
index += 1
|
||||||
op_list = op_list + (op,)
|
op_list = op_list + (op,)
|
||||||
|
|
||||||
if not param_fusion:
|
if not param_fusion:
|
||||||
if split_indices and fusion_type == 1:
|
if split_indices and fusion_type == 1:
|
||||||
op_list = _init_allreduce_operators(len(parameters), split_indices, group)
|
op_list = _init_allreduce_operators(len(parameters), split_indices, group)
|
||||||
|
|
Loading…
Reference in New Issue