add slice shape for param info

This commit is contained in:
yangzhenzhang 2020-05-11 14:07:35 +08:00
parent 95d4665db9
commit 05fde3d23d
3 changed files with 12 additions and 11 deletions

View File

@ -42,7 +42,8 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
} else {
auto device_arrangement = tensor_layout->device_arrangement().array();
auto tensor_map = tensor_layout->tensor_map().array();
std::pair<std::vector<int32_t>, std::vector<int32_t>> layout(device_arrangement, tensor_map);
auto slice_shape = tensor_layout->slice_shape().array();
std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
dict[py::str(name)] = layout;
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
}

View File

@ -203,19 +203,19 @@ def _load_tensor_by_layout(tensor, layout):
Args:
tensor (Tensor): The input tensor.
layout (tuple): The tensor layout in auto parallel.
layout (list): The tensor layout in auto parallel.
Returns:
Tensor, the sliced tensor..
Tensor, the sliced tensor.
Raises:
TypeError: If layout is not tuple.
ValueError: If the length of layout is not 2.
TypeError: If layout is not list.
ValueError: If the length of layout is not 3.
"""
if not isinstance(layout, tuple):
raise TypeError("layout should be tuple! layout is {}".format(layout))
if len(layout) != 2:
raise ValueError("The length of layout must be 2! layout is {}".format(layout))
if not isinstance(layout, list):
raise TypeError("The layout should be list! layout is {}".format(layout))
if len(layout) != 3:
raise ValueError("The length of layout must be 3! layout is {}".format(layout))
dev_mat = layout[0]
tensor_map = layout[1]
if tensor.size() == 1:

View File

@ -48,8 +48,8 @@ def test_get_parameter_layout():
net.set_auto_parallel()
exe = me._executor
exe.compile(net, x, auto_parallel_mode=True)
x_layout = ([2, 4], [1, -1]) # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = ([2, 4], [0, -1]) # device_arrangement = [2, 4], tensor_map = [0, -1]
x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1]
expect_dict = {'x': x_layout, 'w1': weight_layout}
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
assert (net.parameter_layout_dict == expect_dict)