diff --git a/tests/ut/python/parallel/test_conv2d.py b/tests/ut/python/parallel/test_conv2d.py index 34e63d04e7b..d89929c0b0b 100644 --- a/tests/ut/python/parallel/test_conv2d.py +++ b/tests/ut/python/parallel/test_conv2d.py @@ -64,6 +64,16 @@ def test_conv2d_data_parallel(): compile_net(net) +def test_conv2d_data_parallel_invalid_stride(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) + strategy2 = ((8, 1, 1, 1),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=(2, 2, 1, 1), + strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net) + + def test_conv2d_data_parallel_dilation(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) @@ -176,6 +186,15 @@ def test_conv2d_output_can_not_divisible_by_strategy(): compile_net(net) +def test_conv2d_output_can_not_divisible_by_strategy2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1)) + strategy2 = ((1, 1, 1, 8),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net) + + def test_split_kernel(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2)) @@ -203,6 +222,33 @@ def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_v compile_net(net, _x2) +def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1)) + strategy2 = ((1, 1, 1, 8),) + net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net, _x2) + + +def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1)) + strategy2 = ((1, 1, 1, 8),) + net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net, _x2) + + +def test_split_h_dimension_and_pad_mode_is_pad(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1)) + strategy2 = ((1, 1, 1, 8),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="pad", stride=2, strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net) + + def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1)) diff --git a/tests/ut/python/parallel/test_conv2d_transpose.py b/tests/ut/python/parallel/test_conv2d_transpose.py index bd2cd2d32f1..aa7202fe656 100644 --- a/tests/ut/python/parallel/test_conv2d_transpose.py +++ b/tests/ut/python/parallel/test_conv2d_transpose.py @@ -112,6 +112,26 @@ def test_conv2d_transpose_all_rank_no_need_overlap(): compile_net(net) +def test_conv2d_transpose_split_h_or_w_in_pad_mode(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1)) + strategy2 = ((2, 2, 1, 4),) + net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="pad", stride=2, + strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_conv2d_transpose_split_h_in_same_mode(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((2, 2, 4, 1), (2, 1, 1, 1)) + strategy2 = ((2, 2, 1, 4),) + net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2, + strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net) + + def test_conv2d_transpose_overlap_size_too_large(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1)) diff --git a/tests/ut/python/parallel/test_maxpool_avgpool.py b/tests/ut/python/parallel/test_maxpool_avgpool.py index 9604282d4a2..b88d2370fc1 100644 --- a/tests/ut/python/parallel/test_maxpool_avgpool.py +++ b/tests/ut/python/parallel/test_maxpool_avgpool.py @@ -52,17 +52,18 @@ class Net2(Cell): return out +_x0 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32) _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32) _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32) _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32) -def compile_net(net): +def compile_net(net, inputs=_x): optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) train_net = TrainOneStepCell(net, optimizer) train_net.set_auto_parallel() train_net.set_train() - _executor.compile(train_net, _x, _b) + _executor.compile(train_net, inputs, _b) context.reset_auto_parallel_context() @@ -99,7 +100,7 @@ def test_maxpool_auto_parallel(): compile_net(net) -def test_maxpool_output_can_not_divisible_by_strategy(): +def test_maxpool_output_is_not_divisible_by_strategy_w_dimension(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) strategy2 = ((1, 1, 1, 8),) @@ -109,6 +110,56 @@ def test_maxpool_output_can_not_divisible_by_strategy(): compile_net(net) +def test_maxpool_output_is_not_divisible_by_strategy_h_dimension(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) + strategy2 = ((1, 1, 8, 1),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2, + strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_maxpool_shard_h_and_kernel_size_larger_than_stride(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) + strategy2 = ((1, 1, 2, 1),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=3, pool_strides=2, + strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_maxpool_shard_w_and_kernel_size_larger_than_stride(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) + strategy2 = ((1, 1, 1, 2),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=3, pool_strides=2, + strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_maxpool_shard_h_and_input_slice_is_not_divisible_by_stride(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) + strategy2 = ((1, 1, 2, 1),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=1, pool_strides=3, + strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net, inputs=_x0) + + +def test_maxpool_shard_w_and_input_slice_is_not_divisible_by_stride(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) + strategy2 = ((1, 1, 2, 1),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=1, pool_strides=3, + strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net, inputs=_x0) + + def test_avgpool_data_parallel(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))