Error white list

ZPaC 2022-10-11 21:28:12 +08:00
parent 1f0ee25d7b
commit 7f897a6ab8
3 changed files with 51 additions and 5 deletions

View File

@@ -2185,11 +2185,19 @@ class Cell(Cell_):
"""
Set the label for all operators in this cell.
This label tells the MindSpore compiler on which process this cell should be launched.
- Each process's identical id consists of inputs 'role' and 'rank_id'.
+ Each process's label consists of the inputs 'role' and 'rank_id'.
By assigning different labels to different cells, so that they are launched on different processes,
users can launch a distributed training job.
Note:
- 'role' only supports the value 'MS_WORKER' for now.
- This method is effective only after
"mindspore.communication.init()" is called for dynamic cluster building.
- The rank id is unique among processes with the same role.
Args:
role (string): The role of the process on which this cell will be launched.
- rank_id (string): The rank id of the process on which this cell will be launched.
+ rank_id (int): The rank id of the process on which this cell will be launched.
"""
all_ops = self._get_prims_recursively()
for op in all_ops:

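For context, a minimal usage sketch of this API (the network, rank ids, and two-worker setup below are illustrative assumptions, not taken from this commit):

import mindspore.nn as nn
from mindspore.communication import init

init()  # place() only takes effect after dynamic cluster building

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense1 = nn.Dense(16, 16)
        self.dense2 = nn.Dense(16, 16)

    def construct(self, x):
        return self.dense2(self.dense1(x))

net = Net()
# Label the two sub-cells so they are launched on different worker processes.
net.dense1.place('MS_WORKER', 0)
net.dense2.place('MS_WORKER', 1)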
View File

@@ -20,7 +20,8 @@ import copy
from mindspore.common.api import _wrap_func
from mindspore.log import _LogActionOnce
from mindspore import context, log as logger
- from mindspore.parallel._utils import _is_in_auto_parallel_mode
+ from mindspore.parallel._utils import _is_in_auto_parallel_mode, _is_in_data_parallel_mode, _is_in_hybrid_parallel_mode
from mindspore.parallel._ps_context import _is_ps_mode, _is_role_sched
from mindspore.common.parameter import Parameter
from mindspore.common.api import _pynative_executor
from mindspore._c_expression import Primitive_, prim_type
@@ -398,12 +399,41 @@ class Primitive(Primitive_):
"""
Set the label for this primitive.
This label tells the MindSpore compiler on which process this operator should be launched.
- Each process's identical id consists of inputs 'role' and 'rank_id'.
+ Each process's label consists of the inputs 'role' and 'rank_id'.
By assigning different labels to different operators, so that they are launched on different processes,
users can launch a distributed training job.
Note:
- 'role' only supports the value 'MS_WORKER' for now.
- This method is effective only after
"mindspore.communication.init()" is called for dynamic cluster building.
- The rank id is unique among processes with the same role.
Args:
role (string): The role of the process on which this operator will be launched.
- rank_id (string): The rank id of the process on which this operator will be launched.
+ rank_id (int): The rank id of the process on which this operator will be launched.
"""
if _is_role_sched():
return
Validator.check_non_negative_int(rank_id, "rank_id", "Primitive.place")
Validator.check_string(role, "MS_WORKER", "role", "Primitive.place")
# Get the execution context and check whether calling of this 'place' method is valid.
# This is because placing operators on arbitrary processes while another distributed training mode
# is enabled is very unpredictable and may cause fatal errors.
# Some of these cases are under development and others should not be supported.
if _is_ps_mode():
raise RuntimeError(
"You are calling Primitive.place mixed with Parameter Server training. "
"This case is not supported yet. "
"Please call Primitive.place without Parameter Server training.")
if _is_in_auto_parallel_mode() or _is_in_data_parallel_mode() or _is_in_hybrid_parallel_mode():
raise RuntimeError(
"You are calling Primitive.place mixed with other parallel features: "
"'auto_parallel', 'data_parallel' and 'hybrid_parallel'. "
"This case is still under development and not supported yet. "
"Please call Primitive.place without these features.")
self.add_prim_attr("ms_role", role)
self.add_prim_attr("rank_id", rank_id)

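For context, a minimal sketch of what these checks accept and reject (the operator and rank id below are illustrative assumptions, not taken from this commit):

import mindspore.ops as ops
from mindspore.communication import init

init()  # place() only takes effect after dynamic cluster building

matmul = ops.MatMul()
matmul.place('MS_WORKER', 1)      # accepted: role 'MS_WORKER' with a non-negative int rank_id

# Rejected cases:
# matmul.place('MS_PSERVER', 0)   # any role other than 'MS_WORKER' fails the Validator check
# matmul.place('MS_WORKER', -1)   # a negative rank_id fails check_non_negative_int
# With Parameter Server training or auto/data/hybrid parallel enabled, place() raises RuntimeError.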
View File

@@ -43,6 +43,14 @@ def _is_in_auto_parallel_mode():
return _get_parallel_mode() in [ms.ParallelMode.SEMI_AUTO_PARALLEL, ms.ParallelMode.AUTO_PARALLEL]
def _is_in_data_parallel_mode():
return _get_parallel_mode() == ms.ParallelMode.DATA_PARALLEL
def _is_in_hybrid_parallel_mode():
return _get_parallel_mode() == ms.ParallelMode.HYBRID_PARALLEL
def _is_pynative_parallel():
run_mode = context.get_context('mode')
parallel_mode = context.get_auto_parallel_context('parallel_mode')
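For reference, a small sketch of how these helpers track the auto-parallel context (illustrative only; they are private utilities, and the script below is an assumption, not part of this commit):

import mindspore as ms
from mindspore.parallel._utils import _is_in_data_parallel_mode, _is_in_hybrid_parallel_mode

ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
print(_is_in_data_parallel_mode())    # True: Primitive.place would now raise RuntimeError
print(_is_in_hybrid_parallel_mode())  # False

ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.STAND_ALONE)
print(_is_in_data_parallel_mode())    # False: place() is allowed again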