modify parallel API

lilei 2021-09-06 11:00:21 +08:00
parent fa12d62d4d
commit 62e832eaab
4 changed files with 41 additions and 16 deletions

View File

@@ -86,18 +86,17 @@ def init(backend_name=None):
This method should be used after set_context.
Args:
backend_name (str): Backend, using HCCL/NCCL. if not been set, infer it by device_target. Default: None.
backend_name (str): Backend, using HCCL/NCCL. If `backend_name` is None, the backend
is inferred from `device_target`. Default: None.
Raises:
TypeError: If `backend_name` is not a string.
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails,
or the environment variables RANK_ID/MINDSPORE_HCCL_CONFIG_PATH
have not been exported when backend is HCCL.
ValueError: If the environment variable RANK_ID has not been exported as a number.
Examples:
>>> from mindspore.context import set_context
>>> set_context(device_target="Ascend")
>>> init()
"""
if _is_role_pserver() or _is_role_sched():
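A minimal usage sketch of the updated contract (not part of the diff; assumes an Ascend device with RANK_ID and MINDSPORE_HCCL_CONFIG_PATH exported):
from mindspore.context import set_context
from mindspore.communication import init

set_context(device_target="Ascend")
init()  # backend_name=None: "hccl" is inferred from device_target
# init(0) would raise TypeError, since backend_name must be a string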
@@ -172,6 +171,8 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid.
RuntimeError: If HCCL/NCCL is not available.
"""
if not isinstance(group, str):
raise TypeError("Group name must be a string, but got {}".format(type(group)))
return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@@ -195,6 +196,8 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid.
RuntimeError: If HCCL/NCCL is not available or this is the GPU version of MindSpore.
"""
if not isinstance(group, str):
raise TypeError("Group name must be a string, but got {}".format(type(group)))
return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@@ -217,6 +220,8 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid.
RuntimeError: If HCCL/NCCL is not available.
"""
if not isinstance(group, str):
raise TypeError("Group name must be a string, but got {}".format(type(group)))
return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@@ -240,6 +245,8 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
ValueError: If backend is invalid.
RuntimeError: If HCCL/NCCL is not available or this is the GPU version of MindSpore.
"""
if not isinstance(group, str):
raise TypeError("Group name must be a string, but got {}".format(type(group)))
return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
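The four hunks above add the same guard: `group` is type-checked before any backend helper runs. A hedged sketch of the visible behavior change (assumes init() has already succeeded):
from mindspore.communication import get_rank, get_group_size

rank = get_rank()        # default group is the world group name, a string
size = get_group_size()
# get_rank(group=None) now raises:
# TypeError: Group name must be a string, but got <class 'NoneType'>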
@@ -273,8 +280,11 @@ def get_world_rank_from_group_rank(group, group_rank_id):
>>> rank_ids = [0,4]
>>> create_group(group, rank_ids)
>>> world_rank_id = get_world_rank_from_group_rank(group, 1)
>>> print("world_rank_id is: ", world_rank_id) # world_rank_id is: 4
>>> print("world_rank_id is: ", world_rank_id)
world_rank_id is: 4
"""
if not isinstance(group, str):
raise TypeError("Group name must be a string, but got {}".format(type(group)))
return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)
@@ -308,8 +318,11 @@ def get_group_rank_from_world_rank(world_rank_id, group):
>>> rank_ids = [0,4]
>>> create_group(group, rank_ids)
>>> group_rank_id = get_group_rank_from_world_rank(4, group)
>>> print("group_rank_id is: ", group_rank_id) # group_rank_id is: 1
>>> print("group_rank_id is: ", group_rank_id)
group_rank_id is: 1
"""
if not isinstance(group, str):
raise TypeError("Group name must be a string, but got {}".format(type(group)))
return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)
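Read together, the two conversions are inverses on a group's members. A sketch under the doctests' assumptions (group "0-4" built from world ranks 0 and 4):
from mindspore.communication import create_group
from mindspore.communication import get_world_rank_from_group_rank, get_group_rank_from_world_rank

create_group("0-4", [0, 4])
world_id = get_world_rank_from_group_rank("0-4", 1)         # 4
group_id = get_group_rank_from_world_rank(world_id, "0-4")  # 1, round-trips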
@@ -320,9 +333,7 @@ def create_group(group, rank_ids):
Note:
GPU version of MindSpore doesn't support this method.
The size of rank_ids should be larger than 1.
Rank_ids should not have duplicate data.
The size of rank_ids should be larger than 1, and rank_ids should not contain duplicate data.
This method should be used after init().
@@ -345,6 +356,8 @@ def create_group(group, rank_ids):
>>> rank_ids = [0,8]
>>> create_group(group, rank_ids)
"""
if not isinstance(group, str):
raise TypeError("Group name must be a string, but got {}".format(type(group)))
_create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND)
@@ -365,4 +378,6 @@ def destroy_group(group):
ValueError: If group is "hccl_world_group" or backend is invalid.
RuntimeError: If HCCL/NCCL is not available or this is the GPU version of MindSpore.
"""
if not isinstance(group, str):
raise TypeError("Group name must be a string, but got {}".format(type(group)))
_destroy_group_helper(group, backend=GlobalComm.BACKEND)
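A lifecycle sketch for the two group APIs under the new checks (HCCL backend assumed; the group name must be a string at both call sites):
from mindspore.communication import create_group, destroy_group

group = "0-8"
create_group(group, [0, 8])   # rank_ids: size larger than 1, no duplicates
destroy_group(group)
# destroy_group(None) would now raise TypeError before reaching the backend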

View File

@@ -336,7 +336,7 @@ class Cell(Cell_):
res.append(self.cast(item, dst_type))
return tuple(res)
def do_parameter_broadcast(self):
def _do_parameter_broadcast(self):
if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
if not self.parameter_broadcast_done:
_pynative_executor.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
@@ -392,7 +392,7 @@ class Cell(Cell_):
return out
# Run in PyNative mode.
self.do_parameter_broadcast()
self._do_parameter_broadcast()
for item in inputs:
if isinstance(item, numpy.ndarray):
raise TypeError("The cell inputs should not be numpy arrays.")
@@ -671,6 +671,12 @@ class Cell(Cell_):
return _cell_graph_executor(self, *new_inputs, phase=self.phase)
def auto_parallel_compile_and_run(self):
"""
Whether to execute the compile-and-run process.
Returns:
bool, the value of `_auto_parallel_compile_and_run`.
"""
return self._auto_parallel_compile_and_run
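A hedged usage sketch of the new accessor (Net is a placeholder Cell; the flag is assumed to be False until compile-and-run has actually been entered):
net = Net()                                 # Net: any user-defined Cell (hypothetical)
print(net.auto_parallel_compile_and_run())  # assumed False before the first run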
def exec_checkpoint_graph(self):
@@ -860,6 +866,15 @@ class Cell(Cell_):
return param_dict
def parameters_broadcast_dict(self, recurse=True):
"""
Gets the parameters broadcast dictionary of this cell.
Args:
recurse (bool): Whether to include the parameters of subcells. Default: True.
Returns:
OrderedDict, the parameters broadcast dictionary.
"""
param_dict = OrderedDict()
for param in self.get_parameters(expand=recurse):
if param.layerwise_parallel is False:
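A hedged usage sketch: judging by the loop above, only parameters with layerwise_parallel left False are collected (keys assumed to be parameter names, mirroring parameters_dict):
net = Net()  # placeholder Cell with registered parameters
broadcast_dict = net.parameters_broadcast_dict()
for name in broadcast_dict:  # parameter names mapped to Parameter objects (assumed)
    print(name)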

View File

@@ -1018,7 +1018,7 @@ class Model:
>>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
>>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
>>> model = Model(Net())
>>> model.infer_predict_layout(input_data)
>>> predict_map = model.infer_predict_layout(input_data)
"""
if context.get_context("mode") != context.GRAPH_MODE:
raise RuntimeError('Pre-compile process only supports GRAPH MODE currently.')
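The doctest fix above is more than cosmetic: the returned layout is the input to distributed checkpoint loading. A hedged sketch (ckpt_files is a hypothetical list of per-rank checkpoint paths):
from mindspore import load_distributed_checkpoint

net = Net()  # as in the doctest above
model = Model(net)
predict_map = model.infer_predict_layout(input_data)
load_distributed_checkpoint(net, ckpt_files, predict_strategy=predict_map)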

View File

@@ -132,13 +132,8 @@ def test_raise_error_funcs():
assert has_raise_error(destroy_group2, '0-1') is False
def test_get_rank_none():
assert D.get_rank(group=None) == 0
def test_group_funs():
D.GlobalComm.BACKEND = D.Backend.HCCL
assert D.get_group_size(group=None) == 1
assert D.get_group_size('2-abcd') == 2
assert D.get_world_rank_from_group_rank('0-1', 0) == 0
assert D.get_group_rank_from_world_rank(0, '0-1') == 0
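Since get_rank now rejects non-string groups, the deleted test_get_rank_none asserted behavior that no longer exists. A sketch of the equivalent negative test (using pytest.raises rather than the file's has_raise_error helper):
import pytest

def test_get_rank_type_error():
    with pytest.raises(TypeError):
        D.get_rank(group=None)  # group must now be a string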