diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index efe9e5dcc0d..6e6b7296245 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -384,6 +384,7 @@ class Parameter(Tensor_):
 
     @property
     def layerwise_parallel(self):
+        """Return whether the parameter is layerwise parallel."""
         return self.param_info.layerwise_parallel
 
     @layerwise_parallel.setter
@@ -438,10 +439,11 @@ class Parameter(Tensor_):
 
     @property
     def data(self):
+        """Return the parameter object."""
         return self
 
     def _update_tensor_data(self, data):
-        "Update the parameter by a Tensor."
+        """Update the parameter by a Tensor."""
         if isinstance(self, Tensor):
             self.init_flag = False
             self.init = None
@@ -452,7 +454,7 @@ class Parameter(Tensor_):
 
     def set_data(self, data, slice_shape=False):
         """
-        Set `set_data` of current `Parameter`.
+        Set Parameter's data.
 
         Args:
             data (Union[Tensor, int, float]): new data.
@@ -512,7 +514,7 @@ class Parameter(Tensor_):
 
     def init_data(self, layout=None, set_sliced=False):
         """
-        Initialize the parameter data.
+        Initialize the parameter's data.
 
         Args:
             layout (Union[None, list(list(int))]): Parameter slice
@@ -527,6 +529,7 @@ class Parameter(Tensor_):
 
         Raises:
             RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer created.
+            ValueError: If the length of the layout is less than 3.
 
         Returns:
             Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
diff --git a/mindspore/train/__init__.py b/mindspore/train/__init__.py
index 5a05c655b0c..2f567d39e64 100644
--- a/mindspore/train/__init__.py
+++ b/mindspore/train/__init__.py
@@ -23,9 +23,9 @@ from . import amp
 from .amp import build_train_network
 from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
 from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
-    build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint
+    build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, async_ckpt_thread_status
 
 __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
            "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
            "load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
-           "load_distributed_checkpoint"]
+           "load_distributed_checkpoint", "async_ckpt_thread_status"]
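
With `async_ckpt_thread_status` re-exported from `mindspore.train`, callers can poll it to confirm that an asynchronous checkpoint save has finished before the process exits. The snippet below is a minimal sketch of that usage, assuming `save_checkpoint` keeps its existing `async_save=True` option and that the helper returns True while the background save thread is still running; the network and file name are hypothetical placeholders.

```python
# Sketch only: assumes save_checkpoint(..., async_save=True) hands the write
# off to a background thread, and async_ckpt_thread_status() reports whether
# that thread is still alive. Network and file name are placeholders.
import time

from mindspore import nn
from mindspore.train import save_checkpoint, async_ckpt_thread_status

net = nn.Dense(3, 4)  # hypothetical network to checkpoint

# Start an asynchronous checkpoint save.
save_checkpoint(net, "dense.ckpt", async_save=True)

# Wait for the async checkpoint thread to finish before shutting down,
# so the checkpoint file is not left partially written.
while async_ckpt_thread_status():
    time.sleep(0.1)
```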