forked from mindspore-Ecosystem/mindspore
modify parallel API
parent fa12d62d4d
commit 62e832eaab
@@ -86,18 +86,17 @@ def init(backend_name=None):
     This method should be used after set_context.
 
     Args:
-        backend_name (str): Backend, using HCCL/NCCL. if not been set, infer it by device_target. Default: None.
+        backend_name (str): Backend, using HCCL/NCCL. If the `backend_name` is None, system will
+            recognize `device_target` by devices. Default: None.
 
     Raises:
         TypeError: If `backend_name` is not a string.
         RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails,
            or the environment variables RANK_ID/MINDSPORE_HCCL_CONFIG_PATH
            have not been exported when backend is HCCL.
         ValueError: If the environment variable RANK_ID has not been exported as a number.
 
     Examples:
         >>> from mindspore.context import set_context
         >>> set_context(device_target="Ascend")
         >>> init()
     """
     if _is_role_pserver() or _is_role_sched():
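For orientation, here is a minimal usage sketch of the documented init() flow. It is not part of the commit and assumes an Ascend job with RANK_ID and MINDSPORE_HCCL_CONFIG_PATH already exported; the backend is inferred from device_target when backend_name is None.

# Minimal sketch: initialize the collective library after set_context.
from mindspore import context
from mindspore.communication import init, get_rank

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()                      # backend_name=None -> HCCL inferred from device_target
print("rank:", get_rank())  # rank id within the world group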
@@ -172,6 +171,8 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
         ValueError: If backend is invalid.
         RuntimeError: If HCCL/NCCL is not available.
     """
+    if not isinstance(group, str):
+        raise TypeError("Group name must be a string, but got {}".format(type(group)))
     return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
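A small illustrative sketch (mine, not from the commit) of what the group-name type check means for callers: the group argument must be a string, so passing None now raises TypeError.

from mindspore.communication import init, get_rank

init()
rank_id = get_rank()    # defaults to the world communication group
# get_rank(group=None)  # with the check above, this raises TypeError instead of returning a rank
print("rank_id:", rank_id)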
@@ -195,6 +196,8 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
         ValueError: If backend is invalid.
         RuntimeError: If HCCL/NCCL is not available or MindSpore is GPU version.
     """
+    if not isinstance(group, str):
+        raise TypeError("Group name must be a string, but got {}".format(type(group)))
     return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@@ -217,6 +220,8 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
         ValueError: If backend is invalid.
         RuntimeError: If HCCL/NCCL is not available.
     """
+    if not isinstance(group, str):
+        raise TypeError("Group name must be a string, but got {}".format(type(group)))
     return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@@ -240,6 +245,8 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
         ValueError: If backend is invalid.
         RuntimeError: If HCCL/NCCL is not available or MindSpore is GPU version.
     """
+    if not isinstance(group, str):
+        raise TypeError("Group name must be a string, but got {}".format(type(group)))
     return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
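A consolidated sketch of the rank/size query functions touched above, written under the assumption of an Ascend/HCCL multi-server job (the local-rank variants document a RuntimeError on the GPU build):

from mindspore.communication import (init, get_rank, get_local_rank,
                                     get_group_size, get_local_rank_size)

init()
print("world size :", get_group_size())       # devices in the world group
print("global rank:", get_rank())             # rank id across all servers
print("local size :", get_local_rank_size())  # devices on this server
print("local rank :", get_local_rank())       # rank id within this server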
@@ -273,8 +280,11 @@ def get_world_rank_from_group_rank(group, group_rank_id):
         >>> rank_ids = [0,4]
         >>> create_group(group, rank_ids)
         >>> world_rank_id = get_world_rank_from_group_rank(group, 1)
-        >>> print("world_rank_id is: ", world_rank_id) # world_rank_id is: 4
+        >>> print("world_rank_id is: ", world_rank_id)
+        world_rank_id is: 4
     """
+    if not isinstance(group, str):
+        raise TypeError("Group name must be a string, but got {}".format(type(group)))
     return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)
@@ -308,8 +318,11 @@ def get_group_rank_from_world_rank(world_rank_id, group):
         >>> rank_ids = [0,4]
         >>> create_group(group, rank_ids)
         >>> group_rank_id = get_group_rank_from_world_rank(4, group)
-        >>> print("group_rank_id is: ", group_rank_id) # group_rank_id is: 1
+        >>> print("group_rank_id is: ", group_rank_id)
+        group_rank_id is: 1
     """
+    if not isinstance(group, str):
+        raise TypeError("Group name must be a string, but got {}".format(type(group)))
     return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)
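The two conversion functions above are inverses of each other. A hedged round-trip sketch, assuming an 8-device HCCL job so that world ranks 0 and 4 exist:

from mindspore.communication import (init, create_group,
                                     get_world_rank_from_group_rank,
                                     get_group_rank_from_world_rank)

init()
group = "0-4"
create_group(group, [0, 4])                    # group rank 0 -> world rank 0, group rank 1 -> world rank 4
assert get_world_rank_from_group_rank(group, 1) == 4
assert get_group_rank_from_world_rank(4, group) == 1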
@@ -320,9 +333,7 @@ def create_group(group, rank_ids):
     Note:
         GPU version of MindSpore doesn't support this method.
 
-        The size of rank_ids should be larger than 1.
-
-        Rank_ids should not have duplicate data.
+        The size of rank_ids should be larger than 1, rank_ids should not have duplicate data.
 
         This method should be used after init().
@@ -345,6 +356,8 @@ def create_group(group, rank_ids):
         >>> rank_ids = [0,8]
         >>> create_group(group, rank_ids)
     """
+    if not isinstance(group, str):
+        raise TypeError("Group name must be a string, but got {}".format(type(group)))
     _create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND)
@@ -365,4 +378,6 @@ def destroy_group(group):
         ValueError: If group is "hccl_world_group" or backend is invalid.
         RuntimeError: If HCCL/NCCL is not available or MindSpore is GPU version.
     """
+    if not isinstance(group, str):
+        raise TypeError("Group name must be a string, but got {}".format(type(group)))
     _destroy_group_helper(group, backend=GlobalComm.BACKEND)
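A lifecycle sketch tying create_group and destroy_group together, under stated assumptions: Ascend/HCCL, init() already possible, a user-defined group that is not the world group, and an AllReduce that is purely illustrative.

from mindspore.communication import init, create_group, destroy_group
from mindspore.ops import AllReduce

init()
group = "group_0_8"
create_group(group, [0, 8])          # more than one rank id, no duplicates, called after init()
all_reduce = AllReduce(group=group)  # collectives can now target the new group
# ... run the network / collectives ...
destroy_group(group)                 # not allowed for "hccl_world_group"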
@@ -336,7 +336,7 @@ class Cell(Cell_):
             res.append(self.cast(item, dst_type))
         return tuple(res)
 
-    def do_parameter_broadcast(self):
+    def _do_parameter_broadcast(self):
         if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
             if not self.parameter_broadcast_done:
                 _pynative_executor.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
@@ -392,7 +392,7 @@ class Cell(Cell_):
             return out
 
         # Run in PyNative mode.
-        self.do_parameter_broadcast()
+        self._do_parameter_broadcast()
         for item in inputs:
             if isinstance(item, numpy.ndarray):
                 raise TypeError("The cell inputs should not be numpy arrays.")
@@ -671,6 +671,12 @@ class Cell(Cell_):
         return _cell_graph_executor(self, *new_inputs, phase=self.phase)
 
+    def auto_parallel_compile_and_run(self):
+        """
+        Whether or not to execute compile and run.
+
+        Returns:
+            bool, `_auto_parallel_compile_and_run` value.
+        """
+        return self._auto_parallel_compile_and_run
+
     def exec_checkpoint_graph(self):
@@ -860,6 +866,15 @@ class Cell(Cell_):
         return param_dict
 
+    def parameters_broadcast_dict(self, recurse=True):
+        """
+        Gets the parameters broadcast dictionary of this cell.
+
+        Args:
+            recurse (bool): Whether contains the parameters of subcells. Default: True.
+
+        Returns:
+            OrderedDict, return parameters broadcast dictionary.
+        """
+        param_dict = OrderedDict()
+        for param in self.get_parameters(expand=recurse):
+            if param.layerwise_parallel is False:
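A brief sketch of calling the two new Cell helpers added above. TinyNet is a hypothetical stand-in network, and the default value of `_auto_parallel_compile_and_run` is assumed to be False before any auto-parallel run.

import mindspore.nn as nn

class TinyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)

net = TinyNet()
print(net.auto_parallel_compile_and_run())        # bool flag; assumed False until compile-and-run has executed
broadcast_dict = net.parameters_broadcast_dict()  # OrderedDict of parameters with layerwise_parallel=False
print(list(broadcast_dict.keys()))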
@@ -1018,7 +1018,7 @@ class Model:
             >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
             >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
             >>> model = Model(Net())
-            >>> model.infer_predict_layout(input_data)
+            >>> predict_map = model.infer_predict_layout(input_data)
         """
         if context.get_context("mode") != context.GRAPH_MODE:
             raise RuntimeError('Pre-compile process only supports GRAPH MODE currently.')
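A sketch of the updated example flow that keeps the returned layout for later use. Assumptions: graph mode on Ascend, a semi-auto-parallel job initialized with init(), and a toy Net standing in for the user network from the docstring example.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.communication import init
from mindspore.train.model import Model

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense = nn.Dense(32 * 32, 10)

    def construct(self, x):
        return self.dense(self.flatten(x))

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()
context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
model = Model(Net())
predict_map = model.infer_predict_layout(input_data)  # parameter layout, e.g. for distributed checkpoint loading
output = model.predict(input_data)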
@@ -132,13 +132,8 @@ def test_raise_error_funcs():
     assert has_raise_error(destroy_group2, '0-1') is False
 
 
-def test_get_rank_none():
-    assert D.get_rank(group=None) == 0
-
-
 def test_group_funs():
     D.GlobalComm.BACKEND = D.Backend.HCCL
-    assert D.get_group_size(group=None) == 1
     assert D.get_group_size('2-abcd') == 2
     assert D.get_world_rank_from_group_rank('0-1', 0) == 0
     assert D.get_group_rank_from_world_rank(0, '0-1') == 0