create group for syncbn

yangzhenzhang 2023-03-06 14:31:12 +08:00
parent 82ed9ac3fa
commit 91bbca34a4
2 changed files with 106 additions and 22 deletions

View File

@@ -18,6 +18,7 @@ from __future__ import division
import itertools
import numbers
+import hashlib
from mindspore.ops import operations as P
from mindspore.ops import functional as F
@@ -35,12 +36,11 @@ from mindspore.communication import management
from mindspore.common import dtype as mstype
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.nn.cell import Cell
+from mindspore import log as logger

__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
           'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']

-SYNC_BN_GROUP_NAME = ""

class _BatchNorm(Cell):
    """Batch Normalization base class."""
@@ -404,6 +404,16 @@ class BatchNorm3d(Cell):
        return bn3d_out


+SYNCBN_GROUP_DICT = None
+
+
+def _syncbatchnorm_group_dict():
+    global SYNCBN_GROUP_DICT
+    if SYNCBN_GROUP_DICT is None:
+        SYNCBN_GROUP_DICT = dict()
+    return SYNCBN_GROUP_DICT
+

class SyncBatchNorm(_BatchNorm):
    r"""
    Sync Batch Normalization layer over a N-dimension input.
@@ -500,7 +510,7 @@ class SyncBatchNorm(_BatchNorm):
        [[ 0.999995 0.999995 ]
         [ 0.999995 0.999995 ]]]]
    """

    @cell_attr_register(attrs=['num_features', 'process_groups'])
    def __init__(self,
                 num_features,
                 eps=1e-5,
@@ -523,9 +533,10 @@ class SyncBatchNorm(_BatchNorm):
                                 moving_var_init,
                                 use_batch_statistics)
        self.is_global = False
-        global SYNC_BN_GROUP_NAME
+        self.group_name = None
        self.process_groups = process_groups
        if self.process_groups != 0:
+            self.is_global = True
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            if self.process_groups is not None:
@@ -533,34 +544,38 @@ class SyncBatchNorm(_BatchNorm):
                self._check_rank_ids(self.process_groups, self.rank_size)
                self._create_sync_groups()
            elif self.rank_size > 1:
-                self.is_global = True
                self.group_device_num = self.rank_size
-                self.device_list = [i for i in range(0, self.rank_size)]
                if context.get_context("device_target") == "Ascend":
-                    if SYNC_BN_GROUP_NAME == "":
-                        SYNC_BN_GROUP_NAME = "sync_bn_group0"
-                        management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
+                    self.group_name = "hccl_world_group"
                elif context.get_context("device_target") == "GPU":
-                    if SYNC_BN_GROUP_NAME == "":
-                        SYNC_BN_GROUP_NAME = "nccl_world_group"
+                    self.group_name = "nccl_world_group"
        if self.is_global:
            self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
                                                momentum=self.momentum,
-                                                group=SYNC_BN_GROUP_NAME,
+                                                group=self.group_name,
                                                device_num=self.group_device_num)

    def _create_sync_groups(self):
-        for i in range(len(self.process_groups)):
-            validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
-            self.group_device_num = len(self.process_groups[i])
-            if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
-                self.is_global = True
-                global SYNC_BN_GROUP_NAME
-                if SYNC_BN_GROUP_NAME == "":
-                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
-                    management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
+        """ create groups by process groups. """
+        for sub_group in self.process_groups:
+            validator.check_isinstance("sub group", sub_group, list)
+            self.group_device_num = len(sub_group)
+            if self.rank_id in sub_group and self.group_device_num > 1:
+                rank_list_name = '_'.join('%s' % id for id in sub_group)
+                group_dict = _syncbatchnorm_group_dict()
+                if rank_list_name not in group_dict:
+                    md5 = hashlib.md5()
+                    md5.update(rank_list_name.encode('utf-8'))
+                    hash_name = md5.hexdigest()
+                    self.group_name = str(self.group_device_num) + '_' + hash_name
+                    group_dict[rank_list_name] = self.group_name
+                    management.create_group(self.group_name, sub_group)
+                    logger.info("create group for sync batchnorm, the rank list is {}, the group name is {}".format(
+                        rank_list_name, self.group_name))
+                else:
+                    self.group_name = group_dict[rank_list_name]
+                    logger.info("the group for {} already exists, no need to create".format(rank_list_name))

    def _check_rank_ids(self, process_groups, rank_size):
        seen = set()
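Taken together, the changes above replace the old first-wins global `SYNC_BN_GROUP_NAME` with a deterministic scheme: each rank list maps to the name `<group size>_<md5 of the joined rank ids>`, and a process-level dict guarantees `management.create_group` runs only once per rank list. Below is a minimal standalone sketch of that naming-and-caching logic; `_get_or_create_group` and the stubbed `create_group` callback are illustrative stand-ins, not part of the commit:

```python
import hashlib

_GROUP_CACHE = {}  # mirrors SYNCBN_GROUP_DICT: rank-list string -> group name


def _get_or_create_group(sub_group, create_group=lambda name, ranks: None):
    """Return the deterministic group name for sub_group, creating it at most once."""
    rank_list_name = '_'.join('%s' % rank for rank in sub_group)
    if rank_list_name not in _GROUP_CACHE:
        # "<group size>_<md5 of rank list>": identical rank lists always get
        # the same name, independent of the order in which cells are built.
        hash_name = hashlib.md5(rank_list_name.encode('utf-8')).hexdigest()
        _GROUP_CACHE[rank_list_name] = str(len(sub_group)) + '_' + hash_name
        create_group(_GROUP_CACHE[rank_list_name], sub_group)
    return _GROUP_CACHE[rank_list_name]


# Two layers built over the same ranks share one group name (and one create call):
assert _get_or_create_group([0, 1]) == _get_or_create_group([0, 1])
```

This is why multiple SyncBatchNorm cells over the same ranks no longer race on a single global name, and why the test below can assert exact group names.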

View File

@@ -0,0 +1,69 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum, SyncBatchNorm
from mindspore.ops import operations as P


class Net(Cell):
    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride):
        super().__init__()
        self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                               pad_mode=pad_mode, stride=stride)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        self.bn1 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])
        self.bn2 = SyncBatchNorm(num_features=8, process_groups=[[0, 1, 2, 3]])
        self.bn3 = SyncBatchNorm(num_features=8)
        self.bn4 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])

    def construct(self, x, b):
        out = self.conv2d(x, self.conv2d_weight)
        out = self.bn1(out)
        out = self.bn2(out)
        out = self.bn3(out)
        out = self.bn4(out)
        return out


_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


def compile_net(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_train()
    _cell_graph_executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()

def test_syncbatchnorm():
    """
    Feature: test syncbatchnorm
    Description: create communication groups with and without process_groups
    Expectation: compile success and deterministic group names
    """
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=4, global_rank=0)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
    compile_net(net)
    # Group names are "<size>_<md5 of rank list>"; bn4 uses the same rank list
    # as bn1, so it must reuse the cached group rather than create a new one.
    assert net.bn1.group_name == "2_174882033225436b1440b7de44686450"
    assert net.bn2.group_name == "4_937e3b535d29ac4571b6fecb60df6169"
    assert net.bn3.group_name == "hccl_world_group"
    assert net.bn4.group_name == "2_174882033225436b1440b7de44686450"
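The hard-coded names in the asserts come straight from that naming rule: on rank 0, bn1 and bn4 fall into sub-group [0, 1] and bn2 into [0, 1, 2, 3]. A quick sanity check of the constants, assuming the md5-based scheme shown in the diff (`expected_group_name` is an illustrative helper, not part of the test):

```python
import hashlib


def expected_group_name(sub_group):
    # Reproduces the naming rule from _create_sync_groups above.
    rank_list_name = '_'.join(str(rank) for rank in sub_group)
    return str(len(sub_group)) + '_' + hashlib.md5(rank_list_name.encode('utf-8')).hexdigest()


print(expected_group_name([0, 1]))        # 2_174882033225436b1440b7de44686450 per the asserts
print(expected_group_name([0, 1, 2, 3]))  # 4_937e3b535d29ac4571b6fecb60df6169 per the asserts
```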