resolve frontend layput

This commit is contained in:
caozhou 2020-08-24 18:53:47 +08:00
parent 8d0b52fb13
commit 8287445f95
1 changed files with 11 additions and 7 deletions

View File

@ -712,8 +712,8 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
Args:
sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
strategy (dict): Parameter slice strategy. Default: None.
If strategy is None, just merge parameter slices in 0 axis order.
strategy (dict): Parameter slice strategy, the default is None.
If strategy is None, just merge parameter slices in 0 axis order.
- key (str): Parameter name.
- value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter.
@ -728,11 +728,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
Examples:
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
>>> sliced_parameters = [\
Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \
Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \
Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_tabel"), \
Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
>>> sliced_parameters = [
>>> Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
>>> "network.embedding_table"),
>>> Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
>>> "network.embedding_table"),
>>> Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
>>> "network.embedding_tabel"),
>>> Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
>>> "network.embedding_table")]
>>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
"""
if not isinstance(sliced_parameters, list):