From 17e27824c54c163cb11bd1d4b8d3b257b149b123 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Fri, 17 Apr 2020 21:47:39 -0400 Subject: [PATCH] add global batch normalization --- tests/ut/python/nn/test_batchnorm.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py index b6e27e69502..10b4cb00a1e 100644 --- a/tests/ut/python/nn/test_batchnorm.py +++ b/tests/ut/python/nn/test_batchnorm.py @@ -73,21 +73,3 @@ def test_compile_groupnorm(): net = nn.GroupNorm(16, 64) input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32)) _executor.compile(net, input_data) - -class GlobalBNNet(nn.Cell): - def __init__(self): - super(GlobalBNNet, self).__init__() - self.bn = nn.GlobalBatchNorm(num_features = 2, group = 2) - def construct(self, x): - return self.bn(x) - -def test_global_bn(): - init("hccl") - size = 4 - context.set_context(mode=context.GRAPH_MODE) - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, - device_num=size, parameter_broadcast=True) - net = GlobalBNNet() - input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32)) - _executor.compile(net,input_data)