forked from mindspore-Ecosystem/mindspore
fix allreduce and add openstate
This commit is contained in:
parent
d700e5d073
commit
1502b406ab
|
@ -61,6 +61,7 @@ mindspore.set_auto_parallel_context
|
|||
|
||||
- **comm_fusion** (dict) - 用于设置通信算子的融合配置。可以同一类型的通信算子按梯度张量的大小或者顺序分块传输。输入格式为{"通信类型": {"mode":str, "config": None int 或者 list}},每种通信算子的融合配置有两个键:"mode"和"config"。支持以下通信类型的融合类型和配置:
|
||||
|
||||
- openstate:是否开启通信融合功能。通过 True 或 False 来开启或关闭通信融合功能。默认值:True。
|
||||
- allreduce:进行AllReduce算子的通信融合。"mode"包含:"auto"、"size"和"index"。在"auto"模式下,融合的是梯度变量的大小,默认值阈值为"64"MB,"config"对应的值为None。在"size"模式下,需要用户在config的字典中指定梯度大小阈值,这个值必须大于"0"MB。在"mode"为"index"时,它与"all_reduce_fusion_config"相同,用户需要给"config"传入一个列表,里面每个值表示梯度的索引。
|
||||
- allgather:进行AllGather算子的通信融合。"mode"包含:"auto"、"size"。"auto" 和 "size"模式的配置方式与AllReduce相同。
|
||||
- reducescatter:进行ReduceScatter算子的通信融合。"mode"包含:"auto"、"size"。"auto" 和 "size"模式的配置方式与AllReduce相同。
|
||||
|
|
|
@ -606,6 +606,9 @@ def set_auto_parallel_context(**kwargs):
|
|||
communication fusion config has two keys: "mode" and "config".
|
||||
It supports following communication fusion types and configurations:
|
||||
|
||||
- openstate: Whether turn on the communication fusion or not. If `openstate` is `True`, turn on
|
||||
the communication fusion, otherwise, turn off the communication fusion. Default: `True`.
|
||||
|
||||
- allreduce: If communication fusion type is `allreduce`. The `mode` contains: `auto`, `size`
|
||||
and `index`. In `auto` mode, AllReduce fusion is configured by gradients size and the default
|
||||
fusion threshold is `64` MB. In 'size' mode, AllReduce fusion is configured by gradients size
|
||||
|
|
|
@ -40,6 +40,7 @@ class _ParallelFusionConfig:
|
|||
AUTO = "auto"
|
||||
INDEX = "index"
|
||||
SIZE = "size"
|
||||
OPENSTATE = "openstate"
|
||||
|
||||
|
||||
class _ParallelOptimizerConfig:
|
||||
|
@ -117,6 +118,9 @@ class _AutoParallelContext:
|
|||
KeyError: When key of comm_fusion is not 'allreduce'.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
config = config.copy()
|
||||
if _ParallelFusionConfig.OPENSTATE not in config.keys():
|
||||
config[_ParallelFusionConfig.OPENSTATE] = True
|
||||
for key in list(config.keys()):
|
||||
if key == _ParallelFusionConfig.ALLREDUCE:
|
||||
self._set_allreduce_comm_fusion(config[key])
|
||||
|
@ -124,8 +128,11 @@ class _AutoParallelContext:
|
|||
self._set_allgather_comm_fusion(config[key], key)
|
||||
elif key == _ParallelFusionConfig.REDUCESCATTER:
|
||||
self._set_allgather_comm_fusion(config[key], key)
|
||||
elif key == _ParallelFusionConfig.OPENSTATE:
|
||||
self._set_openstate_comm_fusion(config[key])
|
||||
else:
|
||||
raise KeyError("comm fusion type must be allreduce, allgather or reducescatter, but got {}".format(key))
|
||||
raise KeyError("comm fusion type must be openstate,"
|
||||
"allreduce, allgather or reducescatter, but got {}".format(key))
|
||||
|
||||
def get_comm_fusion(self):
|
||||
"""Get comm fusion config."""
|
||||
|
@ -138,78 +145,6 @@ class _AutoParallelContext:
|
|||
return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode,
|
||||
_ParallelFusionConfig.FUSION_CONFIG: config}}
|
||||
|
||||
def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
|
||||
"""
|
||||
Set allgather and reducescatter fusion method for auto parallel.
|
||||
|
||||
Args:
|
||||
comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
|
||||
supports four fusion methods: `auto` and `size`.
|
||||
comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
|
||||
|
||||
Raises:
|
||||
KeyError: When key of comm_fusion is not 'mode' or 'config'.
|
||||
KeyError: When `mode` is not 'auto', 'size'.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
|
||||
return
|
||||
if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
|
||||
return
|
||||
if not isinstance(comm_fusion, dict):
|
||||
raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
|
||||
comm_type, type(comm_fusion)))
|
||||
if _ParallelFusionConfig.MODE not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
|
||||
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
|
||||
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
|
||||
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
|
||||
else:
|
||||
raise KeyError("fusion method mode must be auto or size, but got {}".format(
|
||||
comm_fusion[_ParallelFusionConfig.MODE]))
|
||||
|
||||
fusion_threshold = 64
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO:
|
||||
fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
|
||||
self.set_fusion_threshold_mb(fusion_threshold, comm_type)
|
||||
|
||||
def _set_allreduce_comm_fusion(self, comm_fusion):
|
||||
"""
|
||||
Set fusion method for auto parallel.
|
||||
|
||||
Args:
|
||||
comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
|
||||
supports four fusion methods: `auto`, `size` and `index`.
|
||||
|
||||
Raises:
|
||||
KeyError: When key of comm_fusion is not 'mode' or 'config'.
|
||||
KeyError: When `mode` is not 'auto', 'size' or 'index'.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
if not self.get_enable_all_reduce_fusion():
|
||||
return
|
||||
if not isinstance(comm_fusion, dict):
|
||||
raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
|
||||
type(comm_fusion)))
|
||||
if _ParallelFusionConfig.MODE not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
|
||||
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
|
||||
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
|
||||
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
|
||||
else:
|
||||
raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
|
||||
comm_fusion[_ParallelFusionConfig.MODE]))
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
|
||||
self.set_fusion_threshold_mb(fusion_threshold=64)
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
|
||||
self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
|
||||
self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
|
||||
|
||||
def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
|
||||
"""
|
||||
Set fusion threshold (MB) for auto parallel.
|
||||
|
@ -943,6 +878,101 @@ class _AutoParallelContext:
|
|||
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
|
||||
return group
|
||||
|
||||
def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
|
||||
"""
|
||||
Set allgather and reducescatter fusion method for auto parallel.
|
||||
|
||||
Args:
|
||||
comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
|
||||
supports four fusion methods: `auto` and `size`.
|
||||
comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
|
||||
|
||||
Raises:
|
||||
KeyError: When key of comm_fusion is not 'mode' or 'config'.
|
||||
KeyError: When `mode` is not 'auto', 'size'.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
|
||||
return
|
||||
if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
|
||||
return
|
||||
if not isinstance(comm_fusion, dict):
|
||||
raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
|
||||
comm_type, type(comm_fusion)))
|
||||
if _ParallelFusionConfig.MODE not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
|
||||
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
|
||||
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
|
||||
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
|
||||
else:
|
||||
raise KeyError("fusion method mode must be auto or size, but got {}".format(
|
||||
comm_fusion[_ParallelFusionConfig.MODE]))
|
||||
|
||||
fusion_threshold = 64
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO:
|
||||
fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
|
||||
self.set_fusion_threshold_mb(fusion_threshold, comm_type)
|
||||
|
||||
def _set_allreduce_comm_fusion(self, comm_fusion):
|
||||
"""
|
||||
Set fusion method for auto parallel.
|
||||
|
||||
Args:
|
||||
comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
|
||||
supports four fusion methods: `auto`, `size` and `index`.
|
||||
|
||||
Raises:
|
||||
KeyError: When key of comm_fusion is not 'mode' or 'config'.
|
||||
KeyError: When `mode` is not 'auto', 'size' or 'index'.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
if not self.get_enable_all_reduce_fusion():
|
||||
return
|
||||
if not isinstance(comm_fusion, dict):
|
||||
raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
|
||||
type(comm_fusion)))
|
||||
if _ParallelFusionConfig.MODE not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
|
||||
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
|
||||
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
|
||||
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
|
||||
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
|
||||
else:
|
||||
raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
|
||||
comm_fusion[_ParallelFusionConfig.MODE]))
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
|
||||
self.set_fusion_threshold_mb(fusion_threshold=64)
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
|
||||
self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
|
||||
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
|
||||
self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
|
||||
|
||||
def _set_openstate_comm_fusion(self, openstate):
|
||||
"""
|
||||
Set open state for comm fusion.
|
||||
|
||||
Args:
|
||||
openstate (bool): The open state value to set the fusion method whether or not. Currently it
|
||||
supports two states: `True`, or `Flase`.
|
||||
|
||||
Raises:
|
||||
TypeError: When the value is not bool.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
if not self.get_enable_all_reduce_fusion():
|
||||
return
|
||||
if not isinstance(openstate, bool):
|
||||
raise TypeError("For 'comm_fusion', the 'openstate' must be bool, but got the type : {}.".format(
|
||||
type(openstate)))
|
||||
if not openstate:
|
||||
self.set_enable_all_reduce_fusion(openstate)
|
||||
self.set_enable_all_gather_fusion(openstate)
|
||||
self.set_enable_reduce_scatter_fusion(openstate)
|
||||
|
||||
|
||||
|
||||
_AUTO_PARALLEL_CONTEXT = None
|
||||
|
||||
|
@ -1098,12 +1128,23 @@ def _set_auto_parallel_context(**kwargs):
|
|||
communication fusion config has two keys: "mode" and "config".
|
||||
It supports following communication fusion types and configurations:
|
||||
|
||||
- openstate: Whether turn on the communication fusion or not. If `openstate` is `True`, turn on
|
||||
the communication fusion, otherwise, turn off the communication fusion. Default: `True`.
|
||||
|
||||
- allreduce: if communication fusion type is `allreduce`. The `mode` contains: `auto`, `size`
|
||||
and `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default
|
||||
fusion threshold is `64` MB. In 'size' mode, allreduce fusion is configured by gradients size
|
||||
manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is same as
|
||||
`all_reduce_fusion_config`.
|
||||
|
||||
- allgather: If communication fusion type is `allgather`. The `mode` contains: `auto`, `size`.
|
||||
In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
|
||||
threshold is `64` MB. In 'size' mode, AllGather fusion is configured by gradients size
|
||||
manually, and the fusion threshold must be larger than `0` MB.
|
||||
|
||||
- reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
|
||||
and `size`. Config is same as `allgather`.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
|
|
|
@ -287,3 +287,42 @@ def test_enable_invalid_value_failed():
|
|||
"""
|
||||
with pytest.raises(TypeError):
|
||||
auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion="fusion")
|
||||
|
||||
|
||||
def test_allreduce_fusion_openstate():
|
||||
"""
|
||||
Feature: test priority of "openstate" and "comm_fusion"
|
||||
Description: test priority of "openstate" and "comm_fusion"
|
||||
Expectation: success
|
||||
"""
|
||||
comm_fusion_dict = {"openstate": False, "allreduce": {"mode": "size", "config": 32}}
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
|
||||
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
|
||||
allreduce_fusion_dict = train_common(net)
|
||||
expect_dict = {}
|
||||
assert allreduce_fusion_dict == expect_dict
|
||||
|
||||
|
||||
def test_allreduce_fusion_auto_with_openstate():
|
||||
"""
|
||||
Feature: test_allreduce_fusion in auto mode with openstate
|
||||
Description: allreduce fusion in auto mode with openstate
|
||||
Expectation: success
|
||||
"""
|
||||
comm_fusion_dict = {"openstate": True, "allreduce": {"mode": "auto", "config": None}}
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
|
||||
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
|
||||
allreduce_fusion_dict = train_common(net)
|
||||
expect_dict = {'backbone2.fc8.weight': 1,
|
||||
'backbone2.fc7.weight': 1,
|
||||
'backbone2.fc6.weight': 1,
|
||||
'backbone1.fc4.weight': 1,
|
||||
'backbone1.fc3.weight': 1,
|
||||
'backbone1.fc2.weight': 1,
|
||||
'backbone2.fc5.weight': 1,
|
||||
'backbone2.fc4.weight': 1,
|
||||
'backbone2.fc3.weight': 1,
|
||||
'backbone2.fc2.weight': 1,
|
||||
'backbone2.fc1.weight': 1,
|
||||
'backbone1.fc1.weight': 1}
|
||||
assert allreduce_fusion_dict == expect_dict
|
||||
|
|
|
@ -241,3 +241,16 @@ def test_reducescatter_fusion_invalid_value_failed():
|
|||
with pytest.raises(KeyError):
|
||||
comm_fusion_dict = {"reducescatter": {"mode": "size"}}
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
|
||||
|
||||
|
||||
def test_openstate_comm_fusion():
|
||||
"""
|
||||
Feature: test_openstate_comm_fusion
|
||||
Description: test openstate in comm_fusion
|
||||
Expectation: success
|
||||
"""
|
||||
comm_fusion_dict = {"openstate": False}
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
|
||||
assert auto_parallel_context().get_enable_all_reduce_fusion() is False
|
||||
assert auto_parallel_context().get_enable_all_gather_fusion() is False
|
||||
assert auto_parallel_context().get_enable_reduce_scatter_fusion() is False
|
||||
|
|
Loading…
Reference in New Issue