Modify import API and parameter comments

changzherui, 2021-06-23 12:45:49 +08:00
parent ac6d75b803
commit bf18942554
2 changed files with 8 additions and 5 deletions

File 1 of 2:

@@ -384,6 +384,7 @@ class Parameter(Tensor_):
 
     @property
     def layerwise_parallel(self):
+        """Return whether the parameter is layerwise parallel."""
         return self.param_info.layerwise_parallel
 
     @layerwise_parallel.setter
@@ -438,10 +439,11 @@ class Parameter(Tensor_):
 
     @property
     def data(self):
+        """Return the parameter object."""
         return self
 
     def _update_tensor_data(self, data):
-        "Update the parameter by a Tensor."
+        """Update the parameter by a Tensor."""
         if isinstance(self, Tensor):
             self.init_flag = False
             self.init = None
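
As the added docstring states, the `data` property returns the parameter object itself rather than a copy. A minimal sketch of what callers can rely on, assuming a standard MindSpore install:

    import numpy as np
    from mindspore import Tensor, Parameter

    p = Parameter(Tensor(np.ones((2, 3), np.float32)), name="w")
    assert p.data is p  # `data` hands back the Parameter itself, not a copy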
@@ -452,7 +454,7 @@ class Parameter(Tensor_):
 
     def set_data(self, data, slice_shape=False):
         """
-        Set `set_data` of current `Parameter`.
+        Set Parameter's data.
 
         Args:
             data (Union[Tensor, int, float]): new data.
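
A hedged usage sketch for the `set_data` signature shown in this hunk; `slice_shape` is left at its default, and the shapes and dtypes are illustrative only:

    import numpy as np
    from mindspore import Tensor, Parameter

    w = Parameter(Tensor(np.zeros((2, 2), np.float32)), name="w")
    w.set_data(Tensor(np.ones((2, 2), np.float32)))  # new data must match shape and dtype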
@@ -512,7 +514,7 @@ class Parameter(Tensor_):
 
     def init_data(self, layout=None, set_sliced=False):
         """
-        Initialize the parameter data.
+        Initialize the parameter's data.
 
         Args:
             layout (Union[None, list(list(int))]): Parameter slice
@@ -527,6 +529,7 @@ class Parameter(Tensor_):
 
         Raises:
             RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer created.
+            ValueError: If the length of the layout is less than 3.
 
         Returns:
             Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
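
For a parameter built from an `Initializer`, `init_data` is what materializes the real tensor. A minimal sketch, assuming the `initializer` helper from `mindspore.common.initializer`:

    from mindspore import Parameter
    from mindspore.common.initializer import initializer

    # Data stays as initializer metadata until init_data() runs.
    p = Parameter(initializer("normal", [2, 3]), name="w")
    p = p.init_data()  # returns the Parameter with its data initialized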

File 2 of 2:

@@ -23,9 +23,9 @@ from . import amp
 from .amp import build_train_network
 from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
 from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
-    build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint
+    build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, async_ckpt_thread_status
 
 __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
            "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
            "load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
-           "load_distributed_checkpoint"]
+           "load_distributed_checkpoint", "async_ckpt_thread_status"]