diff --git a/tests/ut/python/parallel/test_virtual_dataset_with_strategy.py b/tests/ut/python/parallel/test_virtual_dataset_with_strategy.py
index 8f276148b78..45264982298 100644
--- a/tests/ut/python/parallel/test_virtual_dataset_with_strategy.py
+++ b/tests/ut/python/parallel/test_virtual_dataset_with_strategy.py
@@ -75,6 +75,18 @@ class Net2(nn.Cell):
         out = self.biasadd(out, b)
         return out
 
+class Net3(nn.Cell):  # MatMul -> GeLU -> MatMul; one shard strategy per operator
+    def __init__(self, strategy1, strategy2, strategy3):
+        super().__init__()
+        self.matmul1 = P.MatMul().shard(strategy1)
+        self.matmul2 = P.MatMul().shard(strategy2)
+        self.gelu = P.GeLU().shard(strategy3)
+
+    def construct(self, x, y, b):
+        out = self.gelu(self.matmul1(x, y))
+        out = self.matmul2(out, b)
+        return out
+
 def compile_net(net, x, y, b):
     net.set_auto_parallel()
     net.set_train()
@@ -216,6 +228,44 @@ def test_virtual_dataset_data_parallel_not_fully_shard_repeat_right():
     net = GradWrap(NetWithLoss(backbone))
     compile_net(net, x, y, b)
 
+def test_without_virtual_dataset_model_parallel_semi_auto_parallel():
+    """
+    Feature: distribute operator virtual_dataset in auto parallel.
+    Description: semi_auto_parallel/dataset_strategy ((1, 8), (1, 8), (1, 8))/Net3 with explicit shard strategies.
+    Expectation: compile done without error.
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    strategy0 = ((1, 8), (1, 8), (1, 8))
+    context.set_auto_parallel_context(dataset_strategy=strategy0)
+    strategy1 = ((2, 2), (2, 2))
+    strategy2 = ((2, 2), (2, 2))
+    strategy3 = ((2, 4),)
+    net = GradWrap(NetWithLoss(Net3(strategy1, strategy2, strategy3)))
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    b = Tensor(np.ones([64, 2048]), dtype=ms.float32)
+    compile_net(net, x, y, b)
+
+def test_without_virtual_dataset_model_parallel_auto_parallel():
+    """
+    Feature: distribute operator virtual_dataset in auto parallel.
+    Description: auto_parallel/dataset_strategy ((1, 8), (1, 8), (1, 8))/Net3 with no shard strategies (None).
+    Expectation: compile done without error.
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    strategy0 = ((1, 8), (1, 8), (1, 8))
+    context.set_auto_parallel_context(dataset_strategy=strategy0)
+    strategy1 = None
+    strategy2 = None
+    strategy3 = None
+    net = GradWrap(NetWithLoss(Net3(strategy1, strategy2, strategy3)))
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    b = Tensor(np.ones([64, 2048]), dtype=ms.float32)
+    compile_net(net, x, y, b)
+
 if __name__ == '__main__':
     context.reset_auto_parallel_context()