From 1502b406ab19a35613ffe092662a2e38c9ebde69 Mon Sep 17 00:00:00 2001
From: jiahongqian
Date: Thu, 2 Feb 2023 17:30:48 +0800
Subject: [PATCH] fix allreduce and add openstate

---
 .../mindspore.set_auto_parallel_context.rst   |   1 +
 mindspore/python/mindspore/context.py         |   3 +
 .../parallel/_auto_parallel_context.py        | 187 +++++++++++-------
 .../python/parallel/test_allreduce_fusion.py  |  39 ++++
 tests/ut/python/parallel/test_comm_fusion.py  |  13 ++
 5 files changed, 170 insertions(+), 73 deletions(-)

diff --git a/docs/api/api_python/mindspore/mindspore.set_auto_parallel_context.rst b/docs/api/api_python/mindspore/mindspore.set_auto_parallel_context.rst
index b996d8c314b..e98429df60c 100644
--- a/docs/api/api_python/mindspore/mindspore.set_auto_parallel_context.rst
+++ b/docs/api/api_python/mindspore/mindspore.set_auto_parallel_context.rst
@@ -61,6 +61,7 @@ mindspore.set_auto_parallel_context
     - **comm_fusion** (dict) - 用于设置通信算子的融合配置。可以同一类型的通信算子按梯度张量的大小或者顺序分块传输。输入格式为{"通信类型": {"mode":str, "config": None int 或者 list}},每种通信算子的融合配置有两个键:"mode"和"config"。支持以下通信类型的融合类型和配置:

+      - openstate:是否开启通信融合功能。通过 True 或 False 来开启或关闭通信融合功能。默认值:True。
       - allreduce:进行AllReduce算子的通信融合。"mode"包含:"auto"、"size"和"index"。在"auto"模式下,融合的是梯度变量的大小,默认值阈值为"64"MB,"config"对应的值为None。在"size"模式下,需要用户在config的字典中指定梯度大小阈值,这个值必须大于"0"MB。在"mode"为"index"时,它与"all_reduce_fusion_config"相同,用户需要给"config"传入一个列表,里面每个值表示梯度的索引。
       - allgather:进行AllGather算子的通信融合。"mode"包含:"auto"、"size"。"auto" 和 "size"模式的配置方式与AllReduce相同。
       - reducescatter:进行ReduceScatter算子的通信融合。"mode"包含:"auto"、"size"。"auto" 和 "size"模式的配置方式与AllReduce相同。

diff --git a/mindspore/python/mindspore/context.py b/mindspore/python/mindspore/context.py
index 3c7895df309..042391f871b 100644
--- a/mindspore/python/mindspore/context.py
+++ b/mindspore/python/mindspore/context.py
@@ -606,6 +606,9 @@ def set_auto_parallel_context(**kwargs):
                     communication fusion config has two keys: "mode" and "config".
                     It supports following communication fusion types and configurations:

+                    - openstate: Whether to turn on the communication fusion. If `openstate` is `True`, turn on
+                      the communication fusion; otherwise, turn it off. Default: `True`.
+
                     - allreduce: If communication fusion type is `allreduce`. The `mode` contains: `auto`, `size`
                       and `index`. In `auto` mode, AllReduce fusion is configured by gradients size and the default
                       fusion threshold is `64` MB. In 'size' mode, AllReduce fusion is configured by gradients size

diff --git a/mindspore/python/mindspore/parallel/_auto_parallel_context.py b/mindspore/python/mindspore/parallel/_auto_parallel_context.py
index f461292c09f..bd13248c1e3 100644
--- a/mindspore/python/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/python/mindspore/parallel/_auto_parallel_context.py
@@ -40,6 +40,7 @@ class _ParallelFusionConfig:
     AUTO = "auto"
     INDEX = "index"
     SIZE = "size"
+    OPENSTATE = "openstate"


 class _ParallelOptimizerConfig:
@@ -117,6 +118,9 @@ class _AutoParallelContext:
             KeyError: When key of comm_fusion is not 'allreduce'.
""" self.check_context_handle() + config = config.copy() + if _ParallelFusionConfig.OPENSTATE not in config.keys(): + config[_ParallelFusionConfig.OPENSTATE] = True for key in list(config.keys()): if key == _ParallelFusionConfig.ALLREDUCE: self._set_allreduce_comm_fusion(config[key]) @@ -124,8 +128,11 @@ class _AutoParallelContext: self._set_allgather_comm_fusion(config[key], key) elif key == _ParallelFusionConfig.REDUCESCATTER: self._set_allgather_comm_fusion(config[key], key) + elif key == _ParallelFusionConfig.OPENSTATE: + self._set_openstate_comm_fusion(config[key]) else: - raise KeyError("comm fusion type must be allreduce, allgather or reducescatter, but got {}".format(key)) + raise KeyError("comm fusion type must be openstate," + "allreduce, allgather or reducescatter, but got {}".format(key)) def get_comm_fusion(self): """Get comm fusion config.""" @@ -138,78 +145,6 @@ class _AutoParallelContext: return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode, _ParallelFusionConfig.FUSION_CONFIG: config}} - def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"): - """ - Set allgather and reducescatter fusion method for auto parallel. - - Args: - comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it - supports four fusion methods: `auto` and `size`. - comm_type (str): The name of the communication operator, `allgather` or `reducescatter`. - - Raises: - KeyError: When key of comm_fusion is not 'mode' or 'config'. - KeyError: When `mode` is not 'auto', 'size'. - """ - self.check_context_handle() - if comm_type == "allgather" and not self.get_enable_all_gather_fusion(): - return - if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion(): - return - if not isinstance(comm_fusion, dict): - raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format( - comm_type, type(comm_fusion))) - if _ParallelFusionConfig.MODE not in comm_fusion: - raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") - if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: - raise KeyError("For 'comm_fusion', the key 'config' should be contained.") - check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE] - if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: - self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) - else: - raise KeyError("fusion method mode must be auto or size, but got {}".format( - comm_fusion[_ParallelFusionConfig.MODE])) - - fusion_threshold = 64 - if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO: - fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG] - self.set_fusion_threshold_mb(fusion_threshold, comm_type) - - def _set_allreduce_comm_fusion(self, comm_fusion): - """ - Set fusion method for auto parallel. - - Args: - comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it - supports four fusion methods: `auto`, `size` and `index`. - - Raises: - KeyError: When key of comm_fusion is not 'mode' or 'config'. - KeyError: When `mode` is not 'auto', 'size' or 'index'. 
- """ - self.check_context_handle() - if not self.get_enable_all_reduce_fusion(): - return - if not isinstance(comm_fusion, dict): - raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format( - type(comm_fusion))) - if _ParallelFusionConfig.MODE not in comm_fusion: - raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") - if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: - raise KeyError("For 'comm_fusion', the key 'config' should be contained.") - check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE] - if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: - self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) - else: - raise KeyError("fusion method mode must be auto, index or size, but got {}".format( - comm_fusion[_ParallelFusionConfig.MODE])) - if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO: - self.set_fusion_threshold_mb(fusion_threshold=64) - if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE: - self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) - if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX: - self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) - def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"): """ Set fusion threshold (MB) for auto parallel. @@ -943,6 +878,101 @@ class _AutoParallelContext: group = _DEFAULT_NCCL_FUSION_GROUP_NAME return group + def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"): + """ + Set allgather and reducescatter fusion method for auto parallel. + + Args: + comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it + supports four fusion methods: `auto` and `size`. + comm_type (str): The name of the communication operator, `allgather` or `reducescatter`. + + Raises: + KeyError: When key of comm_fusion is not 'mode' or 'config'. + KeyError: When `mode` is not 'auto', 'size'. + """ + self.check_context_handle() + if comm_type == "allgather" and not self.get_enable_all_gather_fusion(): + return + if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion(): + return + if not isinstance(comm_fusion, dict): + raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format( + comm_type, type(comm_fusion))) + if _ParallelFusionConfig.MODE not in comm_fusion: + raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") + if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: + raise KeyError("For 'comm_fusion', the key 'config' should be contained.") + check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE] + if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: + self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) + else: + raise KeyError("fusion method mode must be auto or size, but got {}".format( + comm_fusion[_ParallelFusionConfig.MODE])) + + fusion_threshold = 64 + if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO: + fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG] + self.set_fusion_threshold_mb(fusion_threshold, comm_type) + + def _set_allreduce_comm_fusion(self, comm_fusion): + """ + Set fusion method for auto parallel. + + Args: + comm_fusion (dict): A dict contains the methods and values for setting the fusion method. 
+                supports three fusion methods: `auto`, `size` and `index`.
+
+        Raises:
+            KeyError: When key of comm_fusion is not 'mode' or 'config'.
+            KeyError: When `mode` is not 'auto', 'size' or 'index'.
+        """
+        self.check_context_handle()
+        if not self.get_enable_all_reduce_fusion():
+            return
+        if not isinstance(comm_fusion, dict):
+            raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
+                type(comm_fusion)))
+        if _ParallelFusionConfig.MODE not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
+        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
+        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
+        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
+            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
+        else:
+            raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
+                comm_fusion[_ParallelFusionConfig.MODE]))
+        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
+            self.set_fusion_threshold_mb(fusion_threshold=64)
+        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
+            self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
+        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
+            self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
+
+    def _set_openstate_comm_fusion(self, openstate):
+        """
+        Set open state for comm fusion.
+
+        Args:
+            openstate (bool): Whether to enable communication fusion. Currently it
+                supports two states: `True` or `False`.
+
+        Raises:
+            TypeError: When the value is not bool.
+        """
+        self.check_context_handle()
+        if not self.get_enable_all_reduce_fusion():
+            return
+        if not isinstance(openstate, bool):
+            raise TypeError("For 'comm_fusion', the 'openstate' must be bool, but got the type : {}.".format(
+                type(openstate)))
+        if not openstate:
+            self.set_enable_all_reduce_fusion(openstate)
+            self.set_enable_all_gather_fusion(openstate)
+            self.set_enable_reduce_scatter_fusion(openstate)
+

 _AUTO_PARALLEL_CONTEXT = None
@@ -1098,12 +1128,23 @@ def _set_auto_parallel_context(**kwargs):
             communication fusion config has two keys: "mode" and "config".
             It supports following communication fusion types and configurations:

+            - openstate: Whether to turn on the communication fusion. If `openstate` is `True`, turn on
+              the communication fusion; otherwise, turn it off. Default: `True`.
+
             - allreduce: if communication fusion type is `allreduce`. The `mode` contains: `auto`, `size` and
               `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default fusion
               threshold is `64` MB. In 'size' mode, allreduce fusion is configured by gradients size manually,
               and the fusion threshold must be larger than `0` MB. In `index` mode, it is same as
               `all_reduce_fusion_config`.

+            - allgather: If communication fusion type is `allgather`. The `mode` contains: `auto` and `size`.
+              In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
+              threshold is `64` MB. In 'size' mode, AllGather fusion is configured by gradients size
+              manually, and the fusion threshold must be larger than `0` MB.
+
+            - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
+              and `size`. The config is the same as for `allgather`.
+
     Raises:
         ValueError: If input key is not attribute in auto parallel context.

diff --git a/tests/ut/python/parallel/test_allreduce_fusion.py b/tests/ut/python/parallel/test_allreduce_fusion.py
index faf3da2a096..a06cb08b1f6 100644
--- a/tests/ut/python/parallel/test_allreduce_fusion.py
+++ b/tests/ut/python/parallel/test_allreduce_fusion.py
@@ -287,3 +287,42 @@ def test_enable_invalid_value_failed():
     """
     with pytest.raises(TypeError):
         auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion="fusion")
+
+
+def test_allreduce_fusion_openstate():
+    """
+    Feature: test priority of "openstate" and "comm_fusion"
+    Description: test priority of "openstate" and "comm_fusion"
+    Expectation: success
+    """
+    comm_fusion_dict = {"openstate": False, "allreduce": {"mode": "size", "config": 32}}
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
+    net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
+    allreduce_fusion_dict = train_common(net)
+    expect_dict = {}
+    assert allreduce_fusion_dict == expect_dict
+
+
+def test_allreduce_fusion_auto_with_openstate():
+    """
+    Feature: test_allreduce_fusion in auto mode with openstate
+    Description: allreduce fusion in auto mode with openstate
+    Expectation: success
+    """
+    comm_fusion_dict = {"openstate": True, "allreduce": {"mode": "auto", "config": None}}
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
+    net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
+    allreduce_fusion_dict = train_common(net)
+    expect_dict = {'backbone2.fc8.weight': 1,
+                   'backbone2.fc7.weight': 1,
+                   'backbone2.fc6.weight': 1,
+                   'backbone1.fc4.weight': 1,
+                   'backbone1.fc3.weight': 1,
+                   'backbone1.fc2.weight': 1,
+                   'backbone2.fc5.weight': 1,
+                   'backbone2.fc4.weight': 1,
+                   'backbone2.fc3.weight': 1,
+                   'backbone2.fc2.weight': 1,
+                   'backbone2.fc1.weight': 1,
+                   'backbone1.fc1.weight': 1}
+    assert allreduce_fusion_dict == expect_dict
diff --git a/tests/ut/python/parallel/test_comm_fusion.py b/tests/ut/python/parallel/test_comm_fusion.py
index 0ea84ebc9a9..85278f58a6d 100644
--- a/tests/ut/python/parallel/test_comm_fusion.py
+++ b/tests/ut/python/parallel/test_comm_fusion.py
@@ -241,3 +241,16 @@ def test_reducescatter_fusion_invalid_value_failed():
     with pytest.raises(KeyError):
         comm_fusion_dict = {"reducescatter": {"mode": "size"}}
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
+
+
+def test_openstate_comm_fusion():
+    """
+    Feature: test_openstate_comm_fusion
+    Description: test openstate in comm_fusion
+    Expectation: success
+    """
+    comm_fusion_dict = {"openstate": False}
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
+    assert auto_parallel_context().get_enable_all_reduce_fusion() is False
+    assert auto_parallel_context().get_enable_all_gather_fusion() is False
+    assert auto_parallel_context().get_enable_reduce_scatter_fusion() is False
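Usage note: a minimal sketch of how the `openstate` switch documented above is driven, pieced together from the API docstrings and the unit tests in this patch. The parallel mode and the 32 MB threshold are illustrative values taken from the tests; network construction and training are omitted.

from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Default behavior: "openstate" is implicitly True, so the allreduce entry
# takes effect (size mode with a 32 MB fusion threshold).
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  comm_fusion={"allreduce": {"mode": "size", "config": 32}})

# Global switch: "openstate": False disables fusion for AllReduce, AllGather
# and ReduceScatter, and in this patch's tests it takes priority over the
# per-operator entry (test_allreduce_fusion_openstate expects an empty
# fusion dict for exactly this config).
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  comm_fusion={"openstate": False,
                                               "allreduce": {"mode": "size", "config": 32}})
assert auto_parallel_context().get_enable_all_reduce_fusion() is False
assert auto_parallel_context().get_enable_all_gather_fusion() is False
assert auto_parallel_context().get_enable_reduce_scatter_fusion() is False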