diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 070ce68545c..6b43d0d0eb3 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -712,8 +712,8 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): Args: sliced_parameters (list[Parameter]): Parameter slices in order of rank_id. - strategy (dict): Parameter slice strategy. Default: None. - If strategy is None, just merge parameter slices in 0 axis order. + strategy (dict): Parameter slice strategy, the default is None. + If strategy is None, just merge parameter slices in 0 axis order. - key (str): Parameter name. - value (): Slice strategy of this parameter. @@ -728,11 +728,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): Examples: >>> strategy = build_searched_strategy("./strategy_train.ckpt") - >>> sliced_parameters = [\ - Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \ - Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \ - Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_tabel"), \ - Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")] + >>> sliced_parameters = [ + >>> Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), + >>> "network.embedding_table"), + >>> Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), + >>> "network.embedding_table"), + >>> Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), + >>> "network.embedding_tabel"), + >>> Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), + >>> "network.embedding_table")] >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy) """ if not isinstance(sliced_parameters, list):