diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc index cbffc10e701..3e24aa017e0 100644 --- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc +++ b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc @@ -42,7 +42,8 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { } else { auto device_arrangement = tensor_layout->device_arrangement().array(); auto tensor_map = tensor_layout->tensor_map().array(); - std::pair, std::vector> layout(device_arrangement, tensor_map); + auto slice_shape = tensor_layout->slice_shape().array(); + std::vector> layout = {device_arrangement, tensor_map, slice_shape}; dict[py::str(name)] = layout; MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString(); } diff --git a/mindspore/parallel/_tensor.py b/mindspore/parallel/_tensor.py index 3d314cf724d..019ec4aa11b 100644 --- a/mindspore/parallel/_tensor.py +++ b/mindspore/parallel/_tensor.py @@ -203,19 +203,19 @@ def _load_tensor_by_layout(tensor, layout): Args: tensor (Tensor): The input tensor. - layout (tuple): The tensor layout in auto parallel. + layout (list): The tensor layout in auto parallel. Returns: - Tensor, the sliced tensor.. + Tensor, the sliced tensor. Raises: - TypeError: If layout is not tuple. - ValueError: If the length of layout is not 2. + TypeError: If layout is not list. + ValueError: If the length of layout is not 3. """ - if not isinstance(layout, tuple): - raise TypeError("layout should be tuple! layout is {}".format(layout)) - if len(layout) != 2: - raise ValueError("The length of layout must be 2! layout is {}".format(layout)) + if not isinstance(layout, list): + raise TypeError("The layout should be list! layout is {}".format(layout)) + if len(layout) != 3: + raise ValueError("The length of layout must be 3! layout is {}".format(layout)) dev_mat = layout[0] tensor_map = layout[1] if tensor.size() == 1: diff --git a/tests/ut/python/parallel/test_get_parameter_layout.py b/tests/ut/python/parallel/test_get_parameter_layout.py index b390ce9b335..8588cec21a3 100644 --- a/tests/ut/python/parallel/test_get_parameter_layout.py +++ b/tests/ut/python/parallel/test_get_parameter_layout.py @@ -48,8 +48,8 @@ def test_get_parameter_layout(): net.set_auto_parallel() exe = me._executor exe.compile(net, x, auto_parallel_mode=True) - x_layout = ([2, 4], [1, -1]) # device_arrangement = [2, 4], tensor_map = [1, -1] - weight_layout = ([2, 4], [0, -1]) # device_arrangement = [2, 4], tensor_map = [0, -1] + x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1] + weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1] expect_dict = {'x': x_layout, 'w1': weight_layout} # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut assert (net.parameter_layout_dict == expect_dict)