add test cases

This commit is contained in:
yangzhenzhang 2021-08-25 16:54:29 +08:00
parent 6301361570
commit 0b9b2a9458
3 changed files with 120 additions and 3 deletions

View File

@ -64,6 +64,16 @@ def test_conv2d_data_parallel():
compile_net(net) compile_net(net)
def test_conv2d_data_parallel_invalid_stride():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=(2, 2, 1, 1),
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_conv2d_data_parallel_dilation(): def test_conv2d_data_parallel_dilation():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
@ -176,6 +186,15 @@ def test_conv2d_output_can_not_divisible_by_strategy():
compile_net(net) compile_net(net)
def test_conv2d_output_can_not_divisible_by_strategy2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_split_kernel(): def test_split_kernel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2)) strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
@ -203,6 +222,33 @@ def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_v
compile_net(net, _x2) compile_net(net, _x2)
def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net, _x2)
def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net, _x2)
def test_split_h_dimension_and_pad_mode_is_pad():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="pad", stride=2, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride(): def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1)) strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))

View File

@ -112,6 +112,26 @@ def test_conv2d_transpose_all_rank_no_need_overlap():
compile_net(net) compile_net(net)
def test_conv2d_transpose_split_h_or_w_in_pad_mode():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="pad", stride=2,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_conv2d_transpose_split_h_in_same_mode():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_conv2d_transpose_overlap_size_too_large(): def test_conv2d_transpose_overlap_size_too_large():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1)) strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))

View File

@ -52,17 +52,18 @@ class Net2(Cell):
return out return out
_x0 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32) _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32) _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32) _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
def compile_net(net): def compile_net(net, inputs=_x):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer) train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel() train_net.set_auto_parallel()
train_net.set_train() train_net.set_train()
_executor.compile(train_net, _x, _b) _executor.compile(train_net, inputs, _b)
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
@ -99,7 +100,7 @@ def test_maxpool_auto_parallel():
compile_net(net) compile_net(net)
def test_maxpool_output_can_not_divisible_by_strategy(): def test_maxpool_output_is_not_divisible_by_strategy_w_dimension():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),) strategy2 = ((1, 1, 1, 8),)
@ -109,6 +110,56 @@ def test_maxpool_output_can_not_divisible_by_strategy():
compile_net(net) compile_net(net)
def test_maxpool_output_is_not_divisible_by_strategy_h_dimension():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 8, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_maxpool_shard_h_and_kernel_size_larger_than_stride():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 2, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=3, pool_strides=2,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_maxpool_shard_w_and_kernel_size_larger_than_stride():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 2),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=3, pool_strides=2,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_maxpool_shard_h_and_input_slice_is_not_divisible_by_stride():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 2, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=1, pool_strides=3,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net, inputs=_x0)
def test_maxpool_shard_w_and_input_slice_is_not_divisible_by_stride():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 2, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=1, pool_strides=3,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net, inputs=_x0)
def test_avgpool_data_parallel(): def test_avgpool_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))