forked from mindspore-Ecosystem/mindspore
add global batch normalization
This commit is contained in:
parent
c5120e770c
commit
17e27824c5
|
@ -73,21 +73,3 @@ def test_compile_groupnorm():
|
|||
net = nn.GroupNorm(16, 64)
|
||||
input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))
|
||||
_executor.compile(net, input_data)
|
||||
|
||||
class GlobalBNNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(GlobalBNNet, self).__init__()
|
||||
self.bn = nn.GlobalBatchNorm(num_features = 2, group = 2)
|
||||
def construct(self, x):
|
||||
return self.bn(x)
|
||||
|
||||
def test_global_bn():
|
||||
init("hccl")
|
||||
size = 4
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
device_num=size, parameter_broadcast=True)
|
||||
net = GlobalBNNet()
|
||||
input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32))
|
||||
_executor.compile(net,input_data)
|
||||
|
|
Loading…
Reference in New Issue