modify import api and parameter comment

This commit is contained in:
changzherui 2021-06-23 12:45:49 +08:00
parent ac6d75b803
commit bf18942554
2 changed files with 8 additions and 5 deletions

View File

@ -384,6 +384,7 @@ class Parameter(Tensor_):
@property @property
def layerwise_parallel(self): def layerwise_parallel(self):
"""Return whether the parameter is layerwise parallel."""
return self.param_info.layerwise_parallel return self.param_info.layerwise_parallel
@layerwise_parallel.setter @layerwise_parallel.setter
@ -438,10 +439,11 @@ class Parameter(Tensor_):
@property @property
def data(self): def data(self):
"""Return the parameter object."""
return self return self
def _update_tensor_data(self, data): def _update_tensor_data(self, data):
"Update the parameter by a Tensor." """Update the parameter by a Tensor."""
if isinstance(self, Tensor): if isinstance(self, Tensor):
self.init_flag = False self.init_flag = False
self.init = None self.init = None
@ -452,7 +454,7 @@ class Parameter(Tensor_):
def set_data(self, data, slice_shape=False): def set_data(self, data, slice_shape=False):
""" """
Set `set_data` of current `Parameter`. Set Parameter's data.
Args: Args:
data (Union[Tensor, int, float]): new data. data (Union[Tensor, int, float]): new data.
@ -512,7 +514,7 @@ class Parameter(Tensor_):
def init_data(self, layout=None, set_sliced=False): def init_data(self, layout=None, set_sliced=False):
""" """
Initialize the parameter data. Initialize the parameter's data.
Args: Args:
layout (Union[None, list(list(int))]): Parameter slice layout (Union[None, list(list(int))]): Parameter slice
@ -527,6 +529,7 @@ class Parameter(Tensor_):
Raises: Raises:
RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer created. RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer created.
ValueError: If the length of the layout is less than 3.
Returns: Returns:
Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before, Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,

View File

@ -23,9 +23,9 @@ from . import amp
from .amp import build_train_network from .amp import build_train_network
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\ from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, async_ckpt_thread_status
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager", __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint", "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
"load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter", "load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
"load_distributed_checkpoint"] "load_distributed_checkpoint", "async_ckpt_thread_status"]