forked from mindspore-Ecosystem/mindspore
default fusion group for ge
This commit is contained in:
parent
961af9fed1
commit
699166e552
|
@ -274,10 +274,7 @@ class _AutoParallelContext:
|
||||||
|
|
||||||
self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
|
self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
|
||||||
if context.get_context("device_target") == "Ascend":
|
if context.get_context("device_target") == "Ascend":
|
||||||
if group == "":
|
_set_fusion_strategy_by_idx(indices)
|
||||||
_set_fusion_strategy_by_idx(indices)
|
|
||||||
else:
|
|
||||||
_set_fusion_strategy_by_idx(indices, group)
|
|
||||||
|
|
||||||
def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"):
|
def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"):
|
||||||
"""
|
"""
|
||||||
|
@ -330,10 +327,7 @@ class _AutoParallelContext:
|
||||||
|
|
||||||
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
|
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
|
||||||
if context.get_context("device_target") == "Ascend":
|
if context.get_context("device_target") == "Ascend":
|
||||||
if group == "":
|
_set_fusion_strategy_by_size(sizes)
|
||||||
_set_fusion_strategy_by_size(sizes)
|
|
||||||
else:
|
|
||||||
_set_fusion_strategy_by_size(sizes, group)
|
|
||||||
|
|
||||||
def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"):
|
def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue