From ab23750a15a9211cfbc26ffa34e9253e91dbec6d Mon Sep 17 00:00:00 2001 From: Ziyan Date: Fri, 12 Mar 2021 15:41:09 +0800 Subject: [PATCH] add parallel validation --- mindspore/common/parameter.py | 15 ++++++++++++++- mindspore/nn/cell.py | 13 ++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 57efd5bc6d5..6e52409c13e 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -19,6 +19,8 @@ import numbers import numpy as np from .._c_expression import ParamInfo from . import dtype as mstype +from .. import context +from ..parallel._utils import _get_parallel_mode from .initializer import initializer from .tensor import Tensor from .._checkparam import Validator @@ -292,7 +294,18 @@ class Parameter(Tensor_): @comm_fusion.setter def comm_fusion(self, comm_fusion_): - """Set the fusion type for communication operators corresponding to this parameter.""" + """ + In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or + gradients aggregation are inserted automatically.Set the fusion type for communication operators generated + for this parameter. Only `Ascend` and `Graph` mode is supported. + + Args: + comm_fusion_ (int): The value of fusion must be greater than or equal to 0. + When the value of fusion is 0, operators will not be fused together. + """ + if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode(): + raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE") + Validator.check_non_negative_int(comm_fusion_) self.param_info.comm_fusion = comm_fusion_ @property diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 21a7f9fdda1..0f59a873132 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -1129,7 +1129,18 @@ class Cell(Cell_): param.set_param_ps(init_in_server) def set_comm_fusion(self, fusion_type, recurse=True): - Validator.check_is_int(fusion_type) + """ + Set `comm_fusion` for all the parameters in the Net. Please refer to the description of + `mindspore.common.parameter.comm_fusion`. + + Note: + The value of attribute will be overwritten when the function is called multiply. + + Args: + fusion_type (int): The value of `comm_fusion`. + recurse (bool): Whether sets the trainable parameters of subcells. Default: True. + """ + Validator.check_non_negative_int(fusion_type) for param in self.trainable_params(recurse): param.comm_fusion = fusion_type return self