!12608 change dimension of input for FusedBatchNormEx from 2D to 4D in test_two_matmul_batchnorm_ex.

From: @wangshuide2020
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-02-25 14:06:43 +08:00 committed by Gitee
commit 34daed0fbe
1 changed file with 7 additions and 7 deletions

View File

@ -51,13 +51,13 @@ def test_two_matmul_batchnorm_ex():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
    # Build a small parallel-test network: matmul -> FusedBatchNormEx -> matmul,
    # with shard strategies applied to the two matmul ops.
    #
    # NOTE(review): this span is a rendered diff with the +/- markers stripped,
    # so both the pre-change and post-change assignment lines appear below.
    # In the post-change source only the P.BatchMatMul lines remain.
    super().__init__()
    self.matmul1 = P.MatMul().shard(strategy1)
    # post-change line: replaces the P.MatMul assignment above
    self.matmul1 = P.BatchMatMul().shard(strategy1)
    self.norm = P.FusedBatchNormEx()
    # BatchNorm affine/statistics parameters over the 64-wide channel dim.
    self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
    self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
    self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
    self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
    self.matmul2 = P.MatMul().shard(strategy2)
    # post-change line: replaces the P.MatMul assignment above
    self.matmul2 = P.BatchMatMul().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul1(x, y)
@ -66,12 +66,12 @@ def test_two_matmul_batchnorm_ex():
return out
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8)
strategy1 = ((4, 2), (2, 1))
strategy2 = ((1, 8), (8, 1))
strategy1 = ((1, 1, 4, 2), (1, 1, 2, 1))
strategy2 = ((1, 1, 1, 8), (1, 1, 8, 1))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
x = Tensor(np.ones([64, 64, 128, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64, 64, 64]), dtype=ms.float32)
net.set_train()
_executor.compile(net, x, y, b)