forked from mindspore-Ecosystem/mindspore
!13232 add parallel validation
From: @gong_zi_yan Reviewed-by: @stsuteng,@kisnwang Signed-off-by: @stsuteng
This commit is contained in:
commit
f930710b53
|
@ -19,6 +19,8 @@ import numbers
|
|||
import numpy as np
|
||||
from .._c_expression import ParamInfo
|
||||
from . import dtype as mstype
|
||||
from .. import context
|
||||
from ..parallel._utils import _get_parallel_mode
|
||||
from .initializer import initializer
|
||||
from .tensor import Tensor
|
||||
from .._checkparam import Validator
|
||||
|
@ -292,7 +294,18 @@ class Parameter(Tensor_):
|
|||
|
||||
@comm_fusion.setter
|
||||
def comm_fusion(self, comm_fusion_):
|
||||
"""Set the fusion type for communication operators corresponding to this parameter."""
|
||||
"""
|
||||
In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
|
||||
gradients aggregation are inserted automatically.Set the fusion type for communication operators generated
|
||||
for this parameter. Only `Ascend` and `Graph` mode is supported.
|
||||
|
||||
Args:
|
||||
comm_fusion_ (int): The value of fusion must be greater than or equal to 0.
|
||||
When the value of fusion is 0, operators will not be fused together.
|
||||
"""
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
|
||||
raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
|
||||
Validator.check_non_negative_int(comm_fusion_)
|
||||
self.param_info.comm_fusion = comm_fusion_
|
||||
|
||||
@property
|
||||
|
|
|
@ -1129,7 +1129,18 @@ class Cell(Cell_):
|
|||
param.set_param_ps(init_in_server)
|
||||
|
||||
def set_comm_fusion(self, fusion_type, recurse=True):
|
||||
Validator.check_is_int(fusion_type)
|
||||
"""
|
||||
Set `comm_fusion` for all the parameters in the Net. Please refer to the description of
|
||||
`mindspore.common.parameter.comm_fusion`.
|
||||
|
||||
Note:
|
||||
The value of attribute will be overwritten when the function is called multiply.
|
||||
|
||||
Args:
|
||||
fusion_type (int): The value of `comm_fusion`.
|
||||
recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
|
||||
"""
|
||||
Validator.check_non_negative_int(fusion_type)
|
||||
for param in self.trainable_params(recurse):
|
||||
param.comm_fusion = fusion_type
|
||||
return self
|
||||
|
|
Loading…
Reference in New Issue