!26804 virtual_dataset_avoid_auto_parallel

Merge pull request !26804 from yao_yf/virtual_dataset_avoid_auto_parallel
i-robot 2021-11-26 01:48:53 +00:00 committed by Gitee
commit 6ecbc97fd6
3 changed files with 63 additions and 2 deletions

@@ -42,7 +42,12 @@ Status GetNextInfo::InferTensorMap() {
     if (dim == 1) {
      tensor_map_index.push_back(MAP_NONE);
    } else if (dim == shard_num_) {
-      tensor_map_index.push_back(dev_matrix_shape_origin_.size() - 1 - slice_dim);
+      if (repeated_num_in_dev_matrix_right_ && dev_matrix_shape_origin_.size() != dev_matrix_shape_.size() &&
+          is_auto_parallel_) {
+        tensor_map_index.push_back(dev_matrix_shape_origin_.size() - slice_dim);
+      } else {
+        tensor_map_index.push_back(dev_matrix_shape_origin_.size() - 1 - slice_dim);
+      }
    } else {
      MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support fully shard in one dim.";
      return FAILED;

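This hunk and the matching one in VirtualDatasetInfo below guard the same adjustment: under auto parallel, when a leftover repeated device number has been appended on the right of the device matrix (so the origin and full device-matrix sizes differ), the tensor-map index of the fully sharded dataset dimension is shifted by one. A minimal Python sketch of that branch follows (a hedged illustration, not the MindSpore source; slice_dim, the matrix lengths, and the two flags stand in for the C++ members of the same names):

MAP_NONE = -1

def infer_dataset_tensor_map(strategy_dims, shard_num, origin_len, full_len,
                             repeat_on_right, is_auto_parallel, slice_dim=0):
    """Mirror of the branch selection added by the patch: drop the '- 1' offset
    when a repeat axis was appended on the right of the device matrix."""
    tensor_map_index = []
    for dim in strategy_dims:
        if dim == 1:
            tensor_map_index.append(MAP_NONE)
        elif dim == shard_num:
            if repeat_on_right and origin_len != full_len and is_auto_parallel:
                tensor_map_index.append(origin_len - slice_dim)       # patched path
            else:
                tensor_map_index.append(origin_len - 1 - slice_dim)   # original path
        else:
            raise ValueError("dataset shard strategy only supports fully sharding one dim")
    return tensor_map_index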

@@ -108,7 +108,12 @@ Status VirtualDatasetInfo::InferTensorMap() {
     if (dim == 1) {
      tensor_map_index.push_back(MAP_NONE);
    } else if (dim == shard_num_) {
-      tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim);
+      if (repeated_num_in_dev_matrix_right_ && dev_matrix_shape_.size() != dev_mat_origin.size() &&
+          is_auto_parallel_) {
+        tensor_map_index.push_back(dev_mat_origin.size() - slice_dim);
+      } else {
+        tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim);
+      }
    } else {
      MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
      return FAILED;

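The reason for dropping the '- 1' can be seen with concrete numbers (assumed for illustration; the device-matrix derivation is simplified, not the operator's actual construction). Tensor-map values index device-matrix axes from the right, with the rightmost axis as 0, as the original size() - 1 - slice_dim formula suggests, so appending the repeated factor on the right pushes every existing axis one position further from index 0:

# Assume the sharded dataset dimension uses an 8-way device-matrix axis and
# 16 devices exist, leaving a repeated factor of 2 (assumed numbers only).
dev_mat_origin = [8]
repeated_num = 2
slice_dim = 0

old_map = len(dev_mat_origin) - 1 - slice_dim      # 0 -> the 8-way axis of [8]

# With the repeated factor appended on the right, the full matrix is [8, 2];
# index 0 now names the repeat axis, so the map value has to grow by one.
dev_mat_full = dev_mat_origin + [repeated_num]
new_map = len(dev_mat_origin) - slice_dim          # 1 -> still the 8-way axis of [8, 2]

assert dev_mat_full[len(dev_mat_full) - 1 - new_map] == 8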

@@ -82,6 +82,11 @@ def compile_net(net, x, y, b):
 def test_virtual_dataset_model_parallel_semi_auto_parallel():
+    """
+    Feature: distribute operator virtual_dataset in auto parallel.
+    Description: virtual_dataset/model_parallel/fully shard/repeat in left.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     strategy0 = ((1, 8), (1, 8), (1, 8))
@@ -96,6 +101,11 @@ def test_virtual_dataset_model_parallel_semi_auto_parallel():
     compile_net(net, x, y, b)
 def test_virtual_dataset_model_parallel_auto_parallel():
+    """
+    Feature: distribute operator virtual_dataset in auto parallel.
+    Description: virtual_dataset/model_parallel/fully shard/repeat in left.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     strategy0 = ((1, 8), (1, 8), (1, 8))
@@ -110,6 +120,11 @@ def test_virtual_dataset_model_parallel_auto_parallel():
     compile_net(net, x, y, b)
 def test_virtual_dataset_model_parallel_semi_auto_parallel_diff_input_dim():
+    """
+    Feature: distribute operator virtual_dataset in auto parallel.
+    Description: virtual_dataset/model_parallel/fully shard/repeat in left.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     strategy0 = ((1, 8), (1, 8), (8,))
@@ -124,6 +139,11 @@ def test_virtual_dataset_model_parallel_semi_auto_parallel_diff_input_dim():
     compile_net(net, x, y, b)
 def test_virtual_dataset_model_parallel_auto_parallel_diff_input_dim():
+    """
+    Feature: distribute operator virtual_dataset in auto parallel.
+    Description: virtual_dataset/model_parallel/fully shard/repeat in left.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     strategy0 = ((1, 8), (1, 8), (8,))
@@ -138,6 +158,11 @@ def test_virtual_dataset_model_parallel_auto_parallel_diff_input_dim():
     compile_net(net, x, y, b)
 def test_virtual_dataset_model_parallel_semi_auto_parallel_diff_input_dim_not_fully_shard():
+    """
+    Feature: distribute operator virtual_dataset in auto parallel.
+    Description: virtual_dataset/model_parallel/not fully shard/repeat in left.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
     context.set_auto_parallel_context(device_num=16, global_rank=0)
     strategy0 = ((1, 8), (1, 8), (1,))
@@ -152,6 +177,11 @@ def test_virtual_dataset_model_parallel_semi_auto_parallel_diff_input_dim_not_fully_shard():
     compile_net(net, x, y, b)
 def test_virtual_dataset_model_parallel_auto_parallel_diff_input_dim_not_fully_shard():
+    """
+    Feature: distribute operator virtual_dataset in auto parallel.
+    Description: virtual_dataset/model_parallel/not fully shard/repeat in left.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     context.set_auto_parallel_context(device_num=16, global_rank=0)
     strategy0 = ((1, 8), (1, 8), (1,))
@@ -165,6 +195,27 @@ def test_virtual_dataset_model_parallel_auto_parallel_diff_input_dim_not_fully_shard():
     net = GradWrap(NetWithLoss(Net2(strategy1, strategy2, strategy3)))
     compile_net(net, x, y, b)
+def test_virtual_dataset_data_parallel_not_fully_shard_repeat_right():
+    """
+    Feature: distribute operator virtual_dataset in auto parallel.
+    Description: virtual_dataset/data_parallel/not fully shard/repeat in right.
+    Expectation: compile done without error.
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    context.set_auto_parallel_context(device_num=16, global_rank=0)
+    strategy0 = ((4, 1), (4, 1), (4,))
+    context.set_auto_parallel_context(dataset_strategy=strategy0)
+    strategy1 = ((2, 2), (2, 2))
+    strategy2 = ((1, 8), (8,))
+    strategy3 = ((2, 4),)
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    b = Tensor(np.ones([64]), dtype=ms.float32)
+    backbone = Net2(strategy1, strategy2, strategy3)
+    backbone.virtual_dataset.add_prim_attr("repeat_dim_direct", "right")
+    net = GradWrap(NetWithLoss(backbone))
+    compile_net(net, x, y, b)
 if __name__ == '__main__':
     context.reset_auto_parallel_context()
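The new test presumably exercises exactly that right-appended case: with 16 devices and a dataset strategy whose largest shard count is 4, a factor of 16 / 4 = 4 devices is left over, and setting the repeat_dim_direct attribute to "right" requests that this factor be placed on the right of the device matrix (an assumption read off the attribute name; its semantics are not spelled out in this diff). A small arithmetic sketch of that configuration:

# Values taken from the new test; the resulting device-matrix layout is an assumption.
device_num = 16
dataset_strategy = ((4, 1), (4, 1), (4,))
shard_num = max(max(s) for s in dataset_strategy)          # 4
repeated_num = device_num // shard_num                     # 4 devices repeat each slice
dev_mat_with_repeat_on_right = [shard_num, repeated_num]   # assumed layout: [4, 4]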