!13232 add parallel validation

From: @gong_zi_yan
Reviewed-by: @stsuteng,@kisnwang
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2021-03-16 17:27:24 +08:00 committed by Gitee
commit f930710b53
2 changed files with 26 additions and 2 deletions

View File

@ -19,6 +19,8 @@ import numbers
import numpy as np
from .._c_expression import ParamInfo
from . import dtype as mstype
from .. import context
from ..parallel._utils import _get_parallel_mode
from .initializer import initializer
from .tensor import Tensor
from .._checkparam import Validator
@ -292,7 +294,18 @@ class Parameter(Tensor_):
@comm_fusion.setter
def comm_fusion(self, comm_fusion_):
"""Set the fusion type for communication operators corresponding to this parameter."""
"""
In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
gradients aggregation are inserted automatically.Set the fusion type for communication operators generated
for this parameter. Only `Ascend` and `Graph` mode is supported.
Args:
comm_fusion_ (int): The value of fusion must be greater than or equal to 0.
When the value of fusion is 0, operators will not be fused together.
"""
if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
Validator.check_non_negative_int(comm_fusion_)
self.param_info.comm_fusion = comm_fusion_
@property

View File

@ -1129,7 +1129,18 @@ class Cell(Cell_):
param.set_param_ps(init_in_server)
def set_comm_fusion(self, fusion_type, recurse=True):
Validator.check_is_int(fusion_type)
"""
Set `comm_fusion` for all the parameters in the Net. Please refer to the description of
`mindspore.common.parameter.comm_fusion`.
Note:
The value of attribute will be overwritten when the function is called multiply.
Args:
fusion_type (int): The value of `comm_fusion`.
recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
"""
Validator.check_non_negative_int(fusion_type)
for param in self.trainable_params(recurse):
param.comm_fusion = fusion_type
return self