forked from mindspore-Ecosystem/mindspore
!26804 virtual_dataset_avoid_auto_parallel
Merge pull request !26804 from yao_yf/virtual_dataset_avoid_auto_parallel
This commit is contained in:
commit
6ecbc97fd6
|
@ -42,7 +42,12 @@ Status GetNextInfo::InferTensorMap() {
|
|||
if (dim == 1) {
|
||||
tensor_map_index.push_back(MAP_NONE);
|
||||
} else if (dim == shard_num_) {
|
||||
tensor_map_index.push_back(dev_matrix_shape_origin_.size() - 1 - slice_dim);
|
||||
if (repeated_num_in_dev_matrix_right_ && dev_matrix_shape_origin_.size() != dev_matrix_shape_.size() &&
|
||||
is_auto_parallel_) {
|
||||
tensor_map_index.push_back(dev_matrix_shape_origin_.size() - slice_dim);
|
||||
} else {
|
||||
tensor_map_index.push_back(dev_matrix_shape_origin_.size() - 1 - slice_dim);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support fully shard in one dim.";
|
||||
return FAILED;
|
||||
|
|
|
@ -108,7 +108,12 @@ Status VirtualDatasetInfo::InferTensorMap() {
|
|||
if (dim == 1) {
|
||||
tensor_map_index.push_back(MAP_NONE);
|
||||
} else if (dim == shard_num_) {
|
||||
tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim);
|
||||
if (repeated_num_in_dev_matrix_right_ && dev_matrix_shape_.size() != dev_mat_origin.size() &&
|
||||
is_auto_parallel_) {
|
||||
tensor_map_index.push_back(dev_mat_origin.size() - slice_dim);
|
||||
} else {
|
||||
tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
|
||||
return FAILED;
|
||||
|
|
|
@ -82,6 +82,11 @@ def compile_net(net, x, y, b):
|
|||
|
||||
|
||||
def test_virtual_dataset_model_parallel_semi_auto_parallel():
|
||||
"""
|
||||
Feature: distribute operator virtual_dataset in auto parallel.
|
||||
Description: virtual_dataset/model_parallel/fully shard/repeat in left.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy0 = ((1, 8), (1, 8), (1, 8))
|
||||
|
@ -96,6 +101,11 @@ def test_virtual_dataset_model_parallel_semi_auto_parallel():
|
|||
compile_net(net, x, y, b)
|
||||
|
||||
def test_virtual_dataset_model_parallel_auto_parallel():
|
||||
"""
|
||||
Feature: distribute operator virtual_dataset in auto parallel.
|
||||
Description: virtual_dataset/model_parallel/fully shard/repeat in left.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy0 = ((1, 8), (1, 8), (1, 8))
|
||||
|
@ -110,6 +120,11 @@ def test_virtual_dataset_model_parallel_auto_parallel():
|
|||
compile_net(net, x, y, b)
|
||||
|
||||
def test_virtual_dataset_model_parallel_semi_auto_parallel_diff_input_dim():
|
||||
"""
|
||||
Feature: distribute operator virtual_dataset in auto parallel.
|
||||
Description: virtual_dataset/model_parallel/fully shard/repeat in left.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy0 = ((1, 8), (1, 8), (8,))
|
||||
|
@ -124,6 +139,11 @@ def test_virtual_dataset_model_parallel_semi_auto_parallel_diff_input_dim():
|
|||
compile_net(net, x, y, b)
|
||||
|
||||
def test_virtual_dataset_model_parallel_auto_parallel_diff_input_dim():
|
||||
"""
|
||||
Feature: distribute operator virtual_dataset in auto parallel.
|
||||
Description: virtual_dataset/model_parallel/fully shard/repeat in left.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy0 = ((1, 8), (1, 8), (8,))
|
||||
|
@ -138,6 +158,11 @@ def test_virtual_dataset_model_parallel_auto_parallel_diff_input_dim():
|
|||
compile_net(net, x, y, b)
|
||||
|
||||
def test_virtual_dataset_model_parallel_semi_auto_parallel_diff_input_dim_not_fully_shard():
|
||||
"""
|
||||
Feature: distribute operator virtual_dataset in auto parallel.
|
||||
Description: virtual_dataset/model_parallel/not fully shard/repeat in left.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
context.set_auto_parallel_context(device_num=16, global_rank=0)
|
||||
strategy0 = ((1, 8), (1, 8), (1,))
|
||||
|
@ -152,6 +177,11 @@ def test_virtual_dataset_model_parallel_semi_auto_parallel_diff_input_dim_not_fu
|
|||
compile_net(net, x, y, b)
|
||||
|
||||
def test_virtual_dataset_model_parallel_auto_parallel_diff_input_dim_not_fully_shard():
|
||||
"""
|
||||
Feature: distribute operator virtual_dataset in auto parallel.
|
||||
Description: virtual_dataset/model_parallel/not fully shard/repeat in left.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
context.set_auto_parallel_context(device_num=16, global_rank=0)
|
||||
strategy0 = ((1, 8), (1, 8), (1,))
|
||||
|
@ -165,6 +195,27 @@ def test_virtual_dataset_model_parallel_auto_parallel_diff_input_dim_not_fully_s
|
|||
net = GradWrap(NetWithLoss(Net2(strategy1, strategy2, strategy3)))
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
def test_virtual_dataset_data_parallel_not_fully_shard_repeat_right():
|
||||
"""
|
||||
Feature: distribute operator virtual_dataset in auto parallel.
|
||||
Description: virtual_dataset/data_parallel/not fully shard/repeat in right.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
context.set_auto_parallel_context(device_num=16, global_rank=0)
|
||||
strategy0 = ((4, 1), (4, 1), (4,))
|
||||
context.set_auto_parallel_context(dataset_strategy=strategy0)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((1, 8), (8,))
|
||||
strategy3 = ((2, 4),)
|
||||
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([64]), dtype=ms.float32)
|
||||
backbone = Net2(strategy1, strategy2, strategy3)
|
||||
backbone.virtual_dataset.add_prim_attr("repeat_dim_direct", "right")
|
||||
net = GradWrap(NetWithLoss(backbone))
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
Loading…
Reference in New Issue