Forked from mindspore-Ecosystem/mindspore.
Commit bc62e24d94 (parent b87cc92a94): add_virtualdataset_ut — adds unit tests for the virtual-dataset operator.
Diff hunk: @@ -75,6 +75,18 @@ class Net2(nn.Cell):
|
|||
out = self.biasadd(out, b)
|
||||
return out
|
||||
|
||||
class Net3(nn.Cell):
    """Two-layer matmul network with a GeLU activation in between.

    Each operator is sharded with the strategy supplied at construction,
    so tests can drive different parallel layouts through one network.
    A strategy of ``None`` leaves the operator unsharded (auto search).
    """

    def __init__(self, strategy1, strategy2, strategy3):
        super().__init__()
        # strategy1/strategy2 shard the two MatMuls, strategy3 the GeLU.
        self.matmul1 = P.MatMul().shard(strategy1)
        self.matmul2 = P.MatMul().shard(strategy2)
        self.gelu = P.GeLU().shard(strategy3)

    def construct(self, x, y, b):
        hidden = self.matmul1(x, y)
        activated = self.gelu(hidden)
        return self.matmul2(activated, b)
|
||||
|
||||
def compile_net(net, x, y, b):
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
|
@ -216,6 +228,44 @@ def test_virtual_dataset_data_parallel_not_fully_shard_repeat_right():
|
|||
net = GradWrap(NetWithLoss(backbone))
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
def test_without_virtual_dataset_model_parallel_semi_auto_parallel():
    """
    Feature: distribute operator virtual_dataset in auto parallel.
    Description: dataset sharded through the context ``dataset_strategy``
        (no explicit VirtualDataset op) with explicit model-parallel
        operator strategies under semi_auto_parallel.
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    # Shard the inputs via the context instead of inserting a VirtualDataset op.
    context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8), (1, 8)))
    mm1_strategy = ((2, 2), (2, 2))
    mm2_strategy = ((2, 2), (2, 2))
    gelu_strategy = ((2, 4),)
    net = GradWrap(NetWithLoss(Net3(mm1_strategy, mm2_strategy, gelu_strategy)))
    first_input = Tensor(np.ones([128, 32]), dtype=ms.float32)
    second_input = Tensor(np.ones([32, 64]), dtype=ms.float32)
    bias_input = Tensor(np.ones([64, 2048]), dtype=ms.float32)
    compile_net(net, first_input, second_input, bias_input)
|
||||
|
||||
def test_without_virtual_dataset_model_parallel_auto_parallel():
    """
    Feature: distribute operator virtual_dataset in auto parallel.
    Description: dataset sharded through the context ``dataset_strategy``
        (no explicit VirtualDataset op) while operator strategies are left
        to the auto_parallel strategy search (all ``None``).
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                      device_num=8, global_rank=0)
    context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8), (1, 8)))
    # No explicit shard strategies: the cost model chooses them.
    backbone = Net3(None, None, None)
    net = GradWrap(NetWithLoss(backbone))
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 2048]), dtype=ms.float32)
    compile_net(net, x, y, b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Restore the default parallel context when run as a standalone script,
    # so state set by earlier tests does not leak into subsequent runs.
    context.reset_auto_parallel_context()
|
||||
|
|
Loading…
Reference in New Issue