forked from mindspore-Ecosystem/mindspore
!49755 create group for syncbatchnorm 1.10
Merge pull request !49755 from yangzhenzhang/modify-create-group-for-syncbn-1.10
This commit is contained in:
commit
8570d26ed2
|
@ -18,6 +18,7 @@ from __future__ import division
|
|||
|
||||
import itertools
|
||||
import numbers
|
||||
import hashlib
|
||||
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -36,11 +37,20 @@ from mindspore.communication import management
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.parallel._utils import _is_in_auto_parallel_mode
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore import log as logger
|
||||
|
||||
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
|
||||
'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
|
||||
|
||||
SYNC_BN_GROUP_NAME = ""
|
||||
|
||||
# Module-level cache accessed through _syncbatchnorm_group_dict(); it starts as None
# and is replaced by an empty dict the first time that accessor is called.
SYNCBN_GROUP_DICT = None
|
||||
|
||||
|
||||
def _syncbatchnorm_group_dict():
    """Return the module-wide group cache, creating it on first use.

    Returns:
        dict, the object bound to ``SYNCBN_GROUP_DICT``. The same dict is
        returned on every call, so callers may mutate it to record entries.
    """
    global SYNCBN_GROUP_DICT
    if SYNCBN_GROUP_DICT is not None:
        return SYNCBN_GROUP_DICT
    # First call: bind the empty cache to the module global before returning it.
    SYNCBN_GROUP_DICT = {}
    return SYNCBN_GROUP_DICT
|
||||
|
||||
|
||||
class _BatchNorm(Cell):
|
||||
|
@ -97,18 +107,20 @@ class _BatchNorm(Cell):
|
|||
self.cls_name)
|
||||
self.process_groups = process_groups
|
||||
self.is_global = False
|
||||
self.group_name = None
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
global SYNC_BN_GROUP_NAME
|
||||
|
||||
# for GlobalBatchNorm
|
||||
if self.group_device_num != 1:
|
||||
self.is_global = True
|
||||
self.rank_id = get_rank()
|
||||
self.rank_size = get_group_size()
|
||||
self.device_list = [i for i in range(0, self.rank_size)]
|
||||
self.rank_list = self.list_group(self.device_list, self.group_device_num)
|
||||
self.rank_list_idx = len(self.rank_list)
|
||||
self._create_global_groups()
|
||||
# for SyncBatchNorm
|
||||
if self.process_groups != 0:
|
||||
self.is_global = True
|
||||
self.rank_id = get_rank()
|
||||
self.rank_size = get_group_size()
|
||||
if self.process_groups is not None:
|
||||
|
@ -116,16 +128,11 @@ class _BatchNorm(Cell):
|
|||
self._check_rank_ids(self.process_groups, self.rank_size)
|
||||
self._create_sync_groups()
|
||||
elif self.rank_size > 1:
|
||||
self.is_global = True
|
||||
self.group_device_num = self.rank_size
|
||||
self.device_list = [i for i in range(0, self.rank_size)]
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
if SYNC_BN_GROUP_NAME == "":
|
||||
SYNC_BN_GROUP_NAME = "sync_bn_group0"
|
||||
management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
|
||||
self.group_name = "hccl_world_group"
|
||||
elif context.get_context("device_target") == "GPU":
|
||||
if SYNC_BN_GROUP_NAME == "":
|
||||
SYNC_BN_GROUP_NAME = "nccl_world_group"
|
||||
self.group_name = "nccl_world_group"
|
||||
|
||||
self.shape = P.Shape()
|
||||
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
||||
|
@ -149,7 +156,7 @@ class _BatchNorm(Cell):
|
|||
if self.is_global:
|
||||
self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
|
||||
momentum=self.momentum,
|
||||
group=SYNC_BN_GROUP_NAME,
|
||||
group=self.group_name,
|
||||
device_num=self.group_device_num)
|
||||
|
||||
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
|
||||
|
@ -226,25 +233,34 @@ class _BatchNorm(Cell):
|
|||
f"but got {process_groups}.")
|
||||
seen.add(rid)
|
||||
|
||||
def _create_groups(self, process_groups):
    """Create, or reuse, the group for the sub group that contains this rank.

    For the sub group holding ``self.rank_id`` (and having more than one
    member) a group name of the form ``"<size>_<md5 of the rank list>"`` is
    derived. The name is looked up in the module-wide cache keyed by the
    underscore-joined rank list; a group is only created when no entry
    exists yet, otherwise the cached name is reused.

    Args:
        process_groups (list): Sub groups to inspect; every entry is checked
            to be a list.
    """
    for ranks in process_groups:
        validator.check_isinstance("sub group", ranks, list)
        # NOTE(review): this is assigned for every sub group, so after the loop it
        # holds the size of the LAST sub group, not necessarily the one containing
        # this rank -- confirm that sub groups are always equally sized.
        self.group_device_num = len(ranks)
        if self.rank_id not in ranks or self.group_device_num <= 1:
            continue

        rank_list_name = '_'.join('%s' % rank for rank in ranks)
        cache = _syncbatchnorm_group_dict()
        if rank_list_name in cache:
            # A group for this exact rank list was registered earlier: reuse its name.
            self.group_name = cache[rank_list_name]
            logger.info("the group for {} already exists, no need to create".format(rank_list_name))
            continue

        digest = hashlib.md5(rank_list_name.encode('utf-8')).hexdigest()
        self.group_name = str(self.group_device_num) + '_' + digest
        # Record the name before creating the group, matching the original ordering.
        cache[rank_list_name] = self.group_name
        management.create_group(self.group_name, ranks)
        logger.info("create group for sync batchnorm, the rank list is {}, the group name is {}".format(
            rank_list_name, self.group_name))
|
||||
|
||||
def _create_global_groups(self):
    """Create global groups.

    Delegates to ``_create_groups`` with ``self.rank_list``, which names the
    group after its rank list and reuses an already registered group
    instead of creating it again.
    """
    # The former in-place loop created the group under the single module-wide
    # SYNC_BN_GROUP_NAME (only ever set once) and was followed by this same
    # delegation, so the group was handled twice; only the delegation is kept.
    self._create_groups(self.rank_list)
|
||||
|
||||
def _create_sync_groups(self):
    """Create sync groups.

    Delegates to ``_create_groups`` with ``self.process_groups``, which
    checks that every sub group is a list, names the group after its rank
    list and reuses an already registered group instead of creating it again.
    """
    # The former in-place loop created the group under the single module-wide
    # SYNC_BN_GROUP_NAME (only ever set once) and was followed by this same
    # delegation, so the group was handled twice; only the delegation is kept.
    self._create_groups(self.process_groups)
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -304,6 +320,8 @@ def _shape_infer(x_shape, num_feature):
|
|||
|
||||
class BatchNorm1d(_BatchNorm):
|
||||
r"""
|
||||
Batch Normalization layer over a 2D input.
|
||||
|
||||
This layer
|
||||
applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
|
||||
reduce internal covariate shift. Batch Normalization is widely used in convolutional networks.
|
||||
|
@ -395,6 +413,8 @@ class BatchNorm1d(_BatchNorm):
|
|||
|
||||
class BatchNorm2d(_BatchNorm):
|
||||
r"""
|
||||
Batch Normalization layer over a 4D input.
|
||||
|
||||
Batch Normalization is widely used in convolutional networks. This layer
|
||||
applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with
|
||||
additional channel dimension) to avoid internal covariate shift as described
|
||||
|
@ -521,6 +541,8 @@ def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
|
|||
|
||||
class BatchNorm3d(Cell):
|
||||
r"""
|
||||
Batch Normalization layer over a 5D input.
|
||||
|
||||
Batch Normalization is widely used in convolutional networks. This layer
|
||||
applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with
|
||||
additional channel dimension) to avoid internal covariate shift.
|
||||
|
@ -747,7 +769,7 @@ class SyncBatchNorm(_BatchNorm):
|
|||
[[ 0.999995 0.999995 ]
|
||||
[ 0.999995 0.999995 ]]]]
|
||||
"""
|
||||
|
||||
@cell_attr_register(attrs=['num_features', 'process_groups'])
|
||||
def __init__(self,
|
||||
num_features,
|
||||
eps=1e-5,
|
||||
|
@ -930,8 +952,10 @@ class _InstanceNorm(Cell):
|
|||
|
||||
class InstanceNorm1d(_InstanceNorm):
|
||||
r"""
|
||||
Instance Normalization layer over a 3D input.
|
||||
|
||||
This layer applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with
|
||||
additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
|
||||
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
|
||||
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
|
||||
of data and the learned parameters which can be described in the following formula.
|
||||
|
||||
|
@ -1018,8 +1042,10 @@ class InstanceNorm1d(_InstanceNorm):
|
|||
|
||||
class InstanceNorm2d(_InstanceNorm):
|
||||
r"""
|
||||
Instance Normalization layer over a 4D input.
|
||||
|
||||
This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
|
||||
additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
|
||||
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
|
||||
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
|
||||
of data and the learned parameters which can be described in the following formula.
|
||||
|
||||
|
@ -1106,8 +1132,10 @@ class InstanceNorm2d(_InstanceNorm):
|
|||
|
||||
class InstanceNorm3d(_InstanceNorm):
|
||||
r"""
|
||||
Instance Normalization layer over a 5D input.
|
||||
|
||||
This layer applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with
|
||||
additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
|
||||
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
|
||||
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
|
||||
of data and the learned parameters which can be described in the following formula.
|
||||
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum, SyncBatchNorm
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(Cell):
    """Conv2D followed by four SyncBatchNorm layers with different process groups.

    ``bn1`` and ``bn4`` are built with the same ``process_groups`` value,
    ``bn2`` with a single four-rank group, and ``bn3`` without any
    ``process_groups`` argument.
    """

    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride):
        super().__init__()
        self.conv2d = P.Conv2D(
            out_channel=out_channel,
            kernel_size=kernel_size,
            pad_mode=pad_mode,
            stride=stride,
        )
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        self.bn1 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])
        self.bn2 = SyncBatchNorm(num_features=8, process_groups=[[0, 1, 2, 3]])
        self.bn3 = SyncBatchNorm(num_features=8)
        self.bn4 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])

    def construct(self, x, b):
        """Apply the convolution and then bn1..bn4 in order; ``b`` is not used."""
        out = self.conv2d(x, self.conv2d_weight)
        out = self.bn1(out)
        out = self.bn2(out)
        out = self.bn3(out)
        return self.bn4(out)
|
||||
|
||||
|
||||
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net):
    """Wrap ``net`` in a Momentum training step, compile it, then reset the parallel context.

    Args:
        net (Cell): Network to compile with the module-level inputs ``_x`` and ``_b``.
    """
    momentum_opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    step_cell = TrainOneStepCell(net, momentum_opt)
    step_cell.set_train()
    _cell_graph_executor.compile(step_cell, _x, _b)
    # Reset the auto-parallel context once compilation has finished.
    context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_syncbatchnorm():
    """
    Feature: test syncbatchnorm
    Description: create group; layers built with the same process groups must share one group name
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=4, global_rank=0)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
    compile_net(net)

    # Expected names follow the "<group size>_<md5 hex digest>" pattern.
    two_rank_group = "2_174882033225436b1440b7de44686450"
    four_rank_group = "4_937e3b535d29ac4571b6fecb60df6169"

    assert net.bn1.group_name == two_rank_group
    assert net.bn2.group_name == four_rank_group
    assert net.bn3.group_name == "hccl_world_group"
    # bn4 was built with the same process groups as bn1, so it gets the same name.
    assert net.bn4.group_name == two_rank_group
|
Loading…
Reference in New Issue