fix allreduce and add openstate

jiahongqian 2023-02-02 17:30:48 +08:00
parent d700e5d073
commit 1502b406ab
5 changed files with 170 additions and 73 deletions


@@ -61,6 +61,7 @@ mindspore.set_auto_parallel_context
- **comm_fusion** (dict) - A dict used to set the fusion configuration of communication operators. Communication operators of the same type can be transmitted in fused chunks according to the size or the order of the gradient tensors. The input format is {"communication type": {"mode": str, "config": None, int or list}}, and the fusion configuration of each communication operator has two keys: "mode" and "config". The following communication types and fusion configurations are supported (a usage sketch follows this list):
- openstate: Whether to enable the communication fusion feature. Enable or disable it with True or False. Default: True.
- allreduce: Communication fusion for AllReduce operators. "mode" can be "auto", "size" or "index". In "auto" mode, fusion is grouped by the size of the gradient variables, the default threshold is "64" MB, and the value of "config" is None. In "size" mode, the user must specify a gradient-size threshold in the config dict, and the threshold must be greater than "0" MB. When "mode" is "index", it is the same as "all_reduce_fusion_config": the user passes a list to "config" in which each value is a gradient index.
- allgather: Communication fusion for AllGather operators. "mode" can be "auto" or "size". The "auto" and "size" modes are configured in the same way as for AllReduce.
- reducescatter: Communication fusion for ReduceScatter operators. "mode" can be "auto" or "size". The "auto" and "size" modes are configured in the same way as for AllReduce.
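A minimal usage sketch of the format described above; the 32 MB threshold and the choice of operators are illustrative:

import mindspore as ms

# Fuse AllReduce operators once accumulated gradients reach 32 MB (illustrative
# value), and let AllGather use the automatic 64 MB default threshold.
comm_fusion_config = {"openstate": True,
                      "allreduce": {"mode": "size", "config": 32},
                      "allgather": {"mode": "auto", "config": None}}
ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                             comm_fusion=comm_fusion_config)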


@@ -606,6 +606,9 @@ def set_auto_parallel_context(**kwargs):
communication fusion config has two keys: "mode" and "config".
It supports following communication fusion types and configurations:
- openstate: Whether to turn on the communication fusion or not. If `openstate` is `True`, the
communication fusion is turned on; otherwise, it is turned off. Default: `True`.
- allreduce: If the communication fusion type is `allreduce`, the `mode` contains: `auto`, `size`
and `index`. In `auto` mode, AllReduce fusion is configured by gradients size and the default
fusion threshold is `64` MB. In `size` mode, AllReduce fusion is configured by gradients size


@@ -40,6 +40,7 @@ class _ParallelFusionConfig:
AUTO = "auto"
INDEX = "index"
SIZE = "size"
OPENSTATE = "openstate"
class _ParallelOptimizerConfig:
@@ -117,6 +118,9 @@ class _AutoParallelContext:
KeyError: When key of comm_fusion is not 'openstate', 'allreduce', 'allgather' or 'reducescatter'.
"""
self.check_context_handle()
config = config.copy()
if _ParallelFusionConfig.OPENSTATE not in config.keys():
config[_ParallelFusionConfig.OPENSTATE] = True
for key in list(config.keys()):
if key == _ParallelFusionConfig.ALLREDUCE:
self._set_allreduce_comm_fusion(config[key])
@@ -124,8 +128,11 @@
self._set_allgather_comm_fusion(config[key], key)
elif key == _ParallelFusionConfig.REDUCESCATTER:
self._set_allgather_comm_fusion(config[key], key)
elif key == _ParallelFusionConfig.OPENSTATE:
self._set_openstate_comm_fusion(config[key])
else:
raise KeyError("comm fusion type must be allreduce, allgather or reducescatter, but got {}".format(key))
raise KeyError("comm fusion type must be openstate,"
"allreduce, allgather or reducescatter, but got {}".format(key))
def get_comm_fusion(self):
"""Get comm fusion config."""
@@ -138,78 +145,6 @@
return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode,
_ParallelFusionConfig.FUSION_CONFIG: config}}
def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
"""
Set fusion threshold (MB) for auto parallel.
@@ -943,6 +878,101 @@ class _AutoParallelContext:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
return group
def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
"""
Set allgather and reducescatter fusion method for auto parallel.
Args:
comm_fusion (dict): A dict that contains the methods and values for setting the fusion method. Currently it
supports two fusion methods: `auto` and `size`.
comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
Raises:
KeyError: When key of comm_fusion is not 'mode' or 'config'.
KeyError: When `mode` is not 'auto' or 'size'.
"""
self.check_context_handle()
if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
return
if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
return
if not isinstance(comm_fusion, dict):
raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
comm_type, type(comm_fusion)))
if _ParallelFusionConfig.MODE not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
else:
raise KeyError("fusion method mode must be auto or size, but got {}".format(
comm_fusion[_ParallelFusionConfig.MODE]))
fusion_threshold = 64
if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO:
fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
self.set_fusion_threshold_mb(fusion_threshold, comm_type)
def _set_allreduce_comm_fusion(self, comm_fusion):
"""
Set fusion method for auto parallel.
Args:
comm_fusion (dict): A dict that contains the methods and values for setting the fusion method. Currently it
supports three fusion methods: `auto`, `size` and `index`.
Raises:
KeyError: When key of comm_fusion is not 'mode' or 'config'.
KeyError: When `mode` is not 'auto', 'size' or 'index'.
"""
self.check_context_handle()
if not self.get_enable_all_reduce_fusion():
return
if not isinstance(comm_fusion, dict):
raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
type(comm_fusion)))
if _ParallelFusionConfig.MODE not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
else:
raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
comm_fusion[_ParallelFusionConfig.MODE]))
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
self.set_fusion_threshold_mb(fusion_threshold=64)
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
def _set_openstate_comm_fusion(self, openstate):
"""
Set open state for comm fusion.
Args:
openstate (bool): Whether to turn on communication fusion. It supports two states: `True` or `False`.
Raises:
TypeError: When the value is not bool.
"""
self.check_context_handle()
if not self.get_enable_all_reduce_fusion():
return
if not isinstance(openstate, bool):
raise TypeError("For 'comm_fusion', the 'openstate' must be bool, but got the type : {}.".format(
type(openstate)))
if not openstate:
self.set_enable_all_reduce_fusion(openstate)
self.set_enable_all_gather_fusion(openstate)
self.set_enable_reduce_scatter_fusion(openstate)
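The `openstate` key gates all three fusion types at once. A minimal sketch of the resulting behavior, mirroring the test added in this commit (the private `auto_parallel_context` accessor below is the one the test files import):

from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# An explicit False disables AllReduce, AllGather and ReduceScatter fusion alike;
# omitting "openstate" defaults it to True and leaves fusion enabled.
context.set_auto_parallel_context(comm_fusion={"openstate": False})
assert auto_parallel_context().get_enable_all_reduce_fusion() is False
assert auto_parallel_context().get_enable_all_gather_fusion() is False
assert auto_parallel_context().get_enable_reduce_scatter_fusion() is False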
_AUTO_PARALLEL_CONTEXT = None
@@ -1098,12 +1128,23 @@ def _set_auto_parallel_context(**kwargs):
communication fusion config has two keys: "mode" and "config".
It supports following communication fusion types and configurations:
- openstate: Whether to turn on the communication fusion or not. If `openstate` is `True`, the
communication fusion is turned on; otherwise, it is turned off. Default: `True`.
- allreduce: If the communication fusion type is `allreduce`, the `mode` contains: `auto`, `size`
and `index`. In `auto` mode, AllReduce fusion is configured by gradients size, and the default
fusion threshold is `64` MB. In `size` mode, AllReduce fusion is configured by gradients size
manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is the same
as `all_reduce_fusion_config` (an `index`-mode sketch follows this excerpt).
- allgather: If the communication fusion type is `allgather`, the `mode` contains: `auto` and
`size`. In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
threshold is `64` MB. In `size` mode, AllGather fusion is configured by gradients size
manually, and the fusion threshold must be larger than `0` MB.
- reducescatter: If the communication fusion type is `reducescatter`, the `mode` contains: `auto`
and `size`. The config is the same as for `allgather`.
Raises:
ValueError: If input key is not attribute in auto parallel context.
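As a concrete illustration of the `index` mode documented above, the sketch below splits AllReduce fusion at illustrative gradient indices; it is equivalent to passing the same list to `all_reduce_fusion_config`:

from mindspore import context

# Gradients with index in [0, 20) form the first fusion group and those in
# [20, 35) the second; the split points 20 and 35 are illustrative values.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  comm_fusion={"allreduce": {"mode": "index",
                                                             "config": [20, 35]}})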


@@ -287,3 +287,42 @@ def test_enable_invalid_value_failed():
"""
with pytest.raises(TypeError):
auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion="fusion")
def test_allreduce_fusion_openstate():
"""
Feature: test priority of "openstate" and "comm_fusion"
Description: test priority of "openstate" and "comm_fusion"
Expectation: success
"""
comm_fusion_dict = {"openstate": False, "allreduce": {"mode": "size", "config": 32}}
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
allreduce_fusion_dict = train_common(net)
expect_dict = {}
assert allreduce_fusion_dict == expect_dict
def test_allreduce_fusion_auto_with_openstate():
"""
Feature: test_allreduce_fusion in auto mode with openstate
Description: allreduce fusion in auto mode with openstate
Expectation: success
"""
comm_fusion_dict = {"openstate": True, "allreduce": {"mode": "auto", "config": None}}
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
allreduce_fusion_dict = train_common(net)
expect_dict = {'backbone2.fc8.weight': 1,
'backbone2.fc7.weight': 1,
'backbone2.fc6.weight': 1,
'backbone1.fc4.weight': 1,
'backbone1.fc3.weight': 1,
'backbone1.fc2.weight': 1,
'backbone2.fc5.weight': 1,
'backbone2.fc4.weight': 1,
'backbone2.fc3.weight': 1,
'backbone2.fc2.weight': 1,
'backbone2.fc1.weight': 1,
'backbone1.fc1.weight': 1}
assert allreduce_fusion_dict == expect_dict


@@ -241,3 +241,16 @@ def test_reducescatter_fusion_invalid_value_failed():
with pytest.raises(KeyError):
comm_fusion_dict = {"reducescatter": {"mode": "size"}}
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
def test_openstate_comm_fusion():
"""
Feature: test_openstate_comm_fusion
Description: test openstate in comm_fusion
Expectation: success
"""
comm_fusion_dict = {"openstate": False}
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
assert auto_parallel_context().get_enable_all_reduce_fusion() is False
assert auto_parallel_context().get_enable_all_gather_fusion() is False
assert auto_parallel_context().get_enable_reduce_scatter_fusion() is False