add global batch normalization

This commit is contained in:
zhaojichen 2020-04-17 21:47:39 -04:00
parent c5120e770c
commit 17e27824c5
1 changed files with 0 additions and 18 deletions

View File

@ -73,21 +73,3 @@ def test_compile_groupnorm():
net = nn.GroupNorm(16, 64)
input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))
_executor.compile(net, input_data)
class GlobalBNNet(nn.Cell):
def __init__(self):
super(GlobalBNNet, self).__init__()
self.bn = nn.GlobalBatchNorm(num_features = 2, group = 2)
def construct(self, x):
return self.bn(x)
def test_global_bn():
init("hccl")
size = 4
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=size, parameter_broadcast=True)
net = GlobalBNNet()
input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32))
_executor.compile(net,input_data)