forked from mindspore-Ecosystem/mindspore

!1054 Add slice shape for param info

Merge pull request !1054 from yangzhenzhang/add-slice-shape-for-param-info

Commit: 2af6ee2482
@@ -42,7 +42,8 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
     } else {
       auto device_arrangement = tensor_layout->device_arrangement().array();
       auto tensor_map = tensor_layout->tensor_map().array();
-      std::pair<std::vector<int32_t>, std::vector<int32_t>> layout(device_arrangement, tensor_map);
+      auto slice_shape = tensor_layout->slice_shape().array();
+      std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
       dict[py::str(name)] = layout;
       MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
     }
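On the C++ side, GetParameterLayout previously stored a pair (device arrangement, tensor map) for each parameter; it now stores a three-element vector that also carries the slice shape. The sketch below shows what a Python consumer of the resulting layout dictionary can now expect; the entry values mirror the updated test further down, and the dictionary name is illustrative, standing in for net.parameter_layout_dict.

    # Illustrative sketch only: layout_dict stands in for net.parameter_layout_dict.
    layout_dict = {
        "w1": [
            [2, 4],    # device_arrangement: the device matrix
            [0, -1],   # tensor_map: which device dimension splits each tensor dimension
            [16, 32],  # slice_shape: shape of the local slice on each device (new in this commit)
        ],
    }
    device_arrangement, tensor_map, slice_shape = layout_dict["w1"]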
@@ -203,19 +203,19 @@ def _load_tensor_by_layout(tensor, layout):

     Args:
         tensor (Tensor): The input tensor.
-        layout (tuple): The tensor layout in auto parallel.
+        layout (list): The tensor layout in auto parallel.

     Returns:
-        Tensor, the sliced tensor..
+        Tensor, the sliced tensor.

     Raises:
-        TypeError: If layout is not tuple.
-        ValueError: If the length of layout is not 2.
+        TypeError: If layout is not list.
+        ValueError: If the length of layout is not 3.
     """
-    if not isinstance(layout, tuple):
-        raise TypeError("layout should be tuple! layout is {}".format(layout))
-    if len(layout) != 2:
-        raise ValueError("The length of layout must be 2! layout is {}".format(layout))
+    if not isinstance(layout, list):
+        raise TypeError("The layout should be list! layout is {}".format(layout))
+    if len(layout) != 3:
+        raise ValueError("The length of layout must be 3! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
     if tensor.size() == 1:
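The loader _load_tensor_by_layout is updated to match: it now rejects anything that is not a three-element list. A minimal, standalone sketch of the new validation and unpacking, using a hypothetical helper name (the real function goes on to slice the tensor, which is omitted here):

    # Hypothetical helper mirroring the new checks; not the actual MindSpore implementation.
    def unpack_layout(layout):
        if not isinstance(layout, list):
            raise TypeError("The layout should be list! layout is {}".format(layout))
        if len(layout) != 3:
            raise ValueError("The length of layout must be 3! layout is {}".format(layout))
        dev_mat, tensor_map, slice_shape = layout  # third element is the new slice shape
        return dev_mat, tensor_map, slice_shape

    # Example with the layout used in the updated test below:
    print(unpack_layout([[2, 4], [1, -1], [16, 32]]))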
@@ -48,8 +48,8 @@ def test_get_parameter_layout():
     net.set_auto_parallel()
     exe = me._executor
     exe.compile(net, x, auto_parallel_mode=True)
-    x_layout = ([2, 4], [1, -1])  # device_arrangement = [2, 4], tensor_map = [1, -1]
-    weight_layout = ([2, 4], [0, -1])  # device_arrangement = [2, 4], tensor_map = [0, -1]
+    x_layout = [[2, 4], [1, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
+    weight_layout = [[2, 4], [0, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
     # to be resolved: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
     assert (net.parameter_layout_dict == expect_dict)
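For reference, the new slice-shape entry is consistent with the other two: each tensor dimension whose tensor_map value is not -1 is divided by the size of the device-matrix dimension it maps to, with the device matrix indexed from its rightmost dimension. A hedged sketch of that relationship; the full shape [32, 32] for x is an assumption, since the test above does not show the input shapes.

    # infer_slice_shape is a hypothetical helper; the [32, 32] full shape is an assumption.
    def infer_slice_shape(full_shape, device_arrangement, tensor_map):
        slice_shape = []
        for dim_size, mapped in zip(full_shape, tensor_map):
            if mapped == -1:
                slice_shape.append(dim_size)  # dimension is not split
            else:
                # tensor_map indexes the device matrix from its rightmost dimension
                split = device_arrangement[len(device_arrangement) - 1 - mapped]
                slice_shape.append(dim_size // split)
        return slice_shape

    print(infer_slice_shape([32, 32], [2, 4], [1, -1]))  # -> [16, 32], matching x_layout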