!49755 create group for syncbatchnorm 1.10

Merge pull request !49755 from yangzhenzhang/modify-create-group-for-syncbn-1.10
commit 8570d26ed2
i-robot, 2023-03-07 11:38:48 +00:00, committed by Gitee
2 changed files with 128 additions and 31 deletions


@@ -18,6 +18,7 @@ from __future__ import division
 import itertools
 import numbers
+import hashlib
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -36,11 +37,20 @@ from mindspore.communication import management
 from mindspore.common import dtype as mstype
 from mindspore.parallel._utils import _is_in_auto_parallel_mode
 from mindspore.nn.cell import Cell
+from mindspore import log as logger

 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
            'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']

-SYNC_BN_GROUP_NAME = ""
+SYNCBN_GROUP_DICT = None
+
+
+def _syncbatchnorm_group_dict():
+    global SYNCBN_GROUP_DICT
+    if SYNCBN_GROUP_DICT is None:
+        SYNCBN_GROUP_DICT = dict()
+    return SYNCBN_GROUP_DICT


 class _BatchNorm(Cell):
@@ -97,18 +107,20 @@ class _BatchNorm(Cell):
                                  self.cls_name)
         self.process_groups = process_groups
         self.is_global = False
+        self.group_name = None
         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        global SYNC_BN_GROUP_NAME
         # for GlobalBatchNorm
         if self.group_device_num != 1:
+            self.is_global = True
             self.rank_id = get_rank()
             self.rank_size = get_group_size()
             self.device_list = [i for i in range(0, self.rank_size)]
             self.rank_list = self.list_group(self.device_list, self.group_device_num)
-            self.rank_list_idx = len(self.rank_list)
             self._create_global_groups()
         # for SyncBatchNorm
         if self.process_groups != 0:
+            self.is_global = True
             self.rank_id = get_rank()
             self.rank_size = get_group_size()
             if self.process_groups is not None:
@@ -116,16 +128,11 @@ class _BatchNorm(Cell):
                 self._check_rank_ids(self.process_groups, self.rank_size)
                 self._create_sync_groups()
             elif self.rank_size > 1:
+                self.is_global = True
                 self.group_device_num = self.rank_size
-                self.device_list = [i for i in range(0, self.rank_size)]
                 if context.get_context("device_target") == "Ascend":
-                    if SYNC_BN_GROUP_NAME == "":
-                        SYNC_BN_GROUP_NAME = "sync_bn_group0"
-                        management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
+                    self.group_name = "hccl_world_group"
                 elif context.get_context("device_target") == "GPU":
-                    if SYNC_BN_GROUP_NAME == "":
-                        SYNC_BN_GROUP_NAME = "nccl_world_group"
+                    self.group_name = "nccl_world_group"

         self.shape = P.Shape()
         self.reduce_mean = P.ReduceMean(keep_dims=True)
@@ -149,7 +156,7 @@ class _BatchNorm(Cell):
         if self.is_global:
             self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
                                                 momentum=self.momentum,
-                                                group=SYNC_BN_GROUP_NAME,
+                                                group=self.group_name,
                                                 device_num=self.group_device_num)

         self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
@@ -226,25 +233,34 @@ class _BatchNorm(Cell):
                                  f"but got {process_groups}.")
                 seen.add(rid)

+    def _create_groups(self, process_groups):
+        """ create groups by process groups. """
+        for sub_group in process_groups:
+            validator.check_isinstance("sub group", sub_group, list)
+            self.group_device_num = len(sub_group)
+            if self.rank_id in sub_group and self.group_device_num > 1:
+                rank_list_name = '_'.join('%s' % id for id in sub_group)
+                group_dict = _syncbatchnorm_group_dict()
+                if rank_list_name not in group_dict:
+                    md5 = hashlib.md5()
+                    md5.update(rank_list_name.encode('utf-8'))
+                    hash_name = md5.hexdigest()
+                    self.group_name = str(self.group_device_num) + '_' + hash_name
+                    group_dict[rank_list_name] = self.group_name
+                    management.create_group(self.group_name, sub_group)
+                    logger.info("create group for sync batchnorm, the rank list is {}, the group name is {}".format(
+                        rank_list_name, self.group_name))
+                else:
+                    self.group_name = group_dict[rank_list_name]
+                    logger.info("the group for {} already exists, no need to create".format(rank_list_name))
+
     def _create_global_groups(self):
-        for i in range(self.rank_list_idx):
-            if self.rank_id in self.rank_list[i]:
-                self.is_global = True
-                global SYNC_BN_GROUP_NAME
-                if SYNC_BN_GROUP_NAME == "":
-                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
-                    management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
+        """ create global groups. """
+        self._create_groups(self.rank_list)

     def _create_sync_groups(self):
-        for i in range(len(self.process_groups)):
-            validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
-            self.group_device_num = len(self.process_groups[i])
-            if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
-                self.is_global = True
-                global SYNC_BN_GROUP_NAME
-                if SYNC_BN_GROUP_NAME == "":
-                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
-                    management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
+        """ create sync groups. """
+        self._create_groups(self.process_groups)


 @constexpr
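Note: the effect of the new _create_groups helper is that group creation becomes idempotent per rank list. The name is derived deterministically from the ranks, and the module-level dict ensures management.create_group runs at most once per list. A minimal standalone sketch of the same caching pattern (get_or_create_group and the stubbed create_group callback are illustrative, not MindSpore API):

    import hashlib

    _GROUP_CACHE = {}  # rank_list_name -> group_name, mirrors SYNCBN_GROUP_DICT

    def get_or_create_group(sub_group, create_group=lambda name, ranks: None):
        """Return the deterministic group name for sub_group, creating it at most once."""
        rank_list_name = '_'.join('%s' % rank_id for rank_id in sub_group)
        if rank_list_name not in _GROUP_CACHE:
            digest = hashlib.md5(rank_list_name.encode('utf-8')).hexdigest()
            _GROUP_CACHE[rank_list_name] = '%d_%s' % (len(sub_group), digest)
            # The real code calls management.create_group(group_name, sub_group) here.
            create_group(_GROUP_CACHE[rank_list_name], sub_group)
        return _GROUP_CACHE[rank_list_name]

    # Two cells with the same rank list now share one group instead of clashing
    # on the old global SYNC_BN_GROUP_NAME.
    assert get_or_create_group([0, 1]) == get_or_create_group([0, 1])

Hashing the rank list keeps names stable across processes, which matters because every rank in a group must pass the same name to create_group.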
@@ -304,6 +320,8 @@ def _shape_infer(x_shape, num_feature):
 class BatchNorm1d(_BatchNorm):
     r"""
+    Batch Normalization layer over a 2D input.
+
     This layer
     applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
     reduce internal covariate shift. Batch Normalization is widely used in convolutional networks.
@@ -395,6 +413,8 @@ class BatchNorm1d(_BatchNorm):
 class BatchNorm2d(_BatchNorm):
     r"""
+    Batch Normalization layer over a 4D input.
+
     Batch Normalization is widely used in convolutional networks. This layer
     applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with
     additional channel dimension) to avoid internal covariate shift as described
@@ -521,6 +541,8 @@ def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
 class BatchNorm3d(Cell):
     r"""
+    Batch Normalization layer over a 5D input.
+
     Batch Normalization is widely used in convolutional networks. This layer
     applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with
     additional channel dimension) to avoid internal covariate shift.
@@ -747,7 +769,7 @@ class SyncBatchNorm(_BatchNorm):
         [[ 0.999995  0.999995 ]
          [ 0.999995  0.999995 ]]]]
     """

-    @cell_attr_register(attrs=['num_features', 'process_groups'])
     def __init__(self,
                  num_features,
                  eps=1e-5,
@@ -930,8 +952,10 @@ class _InstanceNorm(Cell):
 class InstanceNorm1d(_InstanceNorm):
     r"""
+    Instance Normalization layer over a 3D input.
+
     This layer applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with
-    additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
+    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
     Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
     of data and the learned parameters which can be described in the following formula.
@@ -1018,8 +1042,10 @@ class InstanceNorm1d(_InstanceNorm):
 class InstanceNorm2d(_InstanceNorm):
     r"""
+    Instance Normalization layer over a 4D input.
+
     This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
-    additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
+    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
     Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
     of data and the learned parameters which can be described in the following formula.
@@ -1106,8 +1132,10 @@ class InstanceNorm2d(_InstanceNorm):
 class InstanceNorm3d(_InstanceNorm):
     r"""
+    Instance Normalization layer over a 5D input.
+
     This layer applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with
-    additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
+    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
     Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
     of data and the learned parameters which can be described in the following formula.


@@ -0,0 +1,69 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum, SyncBatchNorm
from mindspore.ops import operations as P


class Net(Cell):
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride):
super().__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.bn1 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])
self.bn2 = SyncBatchNorm(num_features=8, process_groups=[[0, 1, 2, 3]])
self.bn3 = SyncBatchNorm(num_features=8)
        self.bn4 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])

    def construct(self, x, b):
out = self.conv2d(x, self.conv2d_weight)
out = self.bn1(out)
out = self.bn2(out)
out = self.bn3(out)
out = self.bn4(out)
        return out


_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_train()
_cell_graph_executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_syncbatchnorm():
"""
Feature: test syncbatchnorm
Description: create group
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=4, global_rank=0)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)
assert net.bn1.group_name == "2_174882033225436b1440b7de44686450"
assert net.bn2.group_name == "4_937e3b535d29ac4571b6fecb60df6169"
assert net.bn3.group_name == "hccl_world_group"
assert net.bn4.group_name == "2_174882033225436b1440b7de44686450"
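
Note: the hard-coded strings in these assertions are just the md5-based names produced by _create_groups. A quick sketch to reproduce them (expected values copied from the assertions above):

    import hashlib

    # Recompute the group names asserted in test_syncbatchnorm.
    for ranks in ([0, 1], [0, 1, 2, 3]):
        rank_list_name = '_'.join(str(r) for r in ranks)
        digest = hashlib.md5(rank_list_name.encode('utf-8')).hexdigest()
        print('%d_%s' % (len(ranks), digest))
    # Per the assertions above, this should print:
    #   2_174882033225436b1440b7de44686450
    #   4_937e3b535d29ac4571b6fecb60df6169

The bn4 assertion matching bn1 shows the dict cache at work: the same rank list [[0, 1], [2, 3]] resolves to the already-created group rather than triggering a second create_group call.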