!49755 create group for syncbatchnorm 1.10

Merge pull request !49755 from yangzhenzhang/modify-create-group-for-syncbn-1.10
This commit is contained in:
i-robot 2023-03-07 11:38:48 +00:00 committed by Gitee
commit 8570d26ed2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 128 additions and 31 deletions

View File

@ -18,6 +18,7 @@ from __future__ import division
import itertools
import numbers
import hashlib
from mindspore.ops import operations as P
from mindspore.ops import functional as F
@ -36,11 +37,20 @@ from mindspore.communication import management
from mindspore.common import dtype as mstype
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.nn.cell import Cell
from mindspore import log as logger
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
SYNC_BN_GROUP_NAME = ""
SYNCBN_GROUP_DICT = None
def _syncbatchnorm_group_dict():
global SYNCBN_GROUP_DICT
if SYNCBN_GROUP_DICT is None:
SYNCBN_GROUP_DICT = dict()
return SYNCBN_GROUP_DICT
class _BatchNorm(Cell):
@ -97,18 +107,20 @@ class _BatchNorm(Cell):
self.cls_name)
self.process_groups = process_groups
self.is_global = False
self.group_name = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
global SYNC_BN_GROUP_NAME
# for GlobalBatchNorm
if self.group_device_num != 1:
self.is_global = True
self.rank_id = get_rank()
self.rank_size = get_group_size()
self.device_list = [i for i in range(0, self.rank_size)]
self.rank_list = self.list_group(self.device_list, self.group_device_num)
self.rank_list_idx = len(self.rank_list)
self._create_global_groups()
# for SyncBatchNorm
if self.process_groups != 0:
self.is_global = True
self.rank_id = get_rank()
self.rank_size = get_group_size()
if self.process_groups is not None:
@ -116,16 +128,11 @@ class _BatchNorm(Cell):
self._check_rank_ids(self.process_groups, self.rank_size)
self._create_sync_groups()
elif self.rank_size > 1:
self.is_global = True
self.group_device_num = self.rank_size
self.device_list = [i for i in range(0, self.rank_size)]
if context.get_context("device_target") == "Ascend":
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group0"
management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
self.group_name = "hccl_world_group"
elif context.get_context("device_target") == "GPU":
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "nccl_world_group"
self.group_name = "nccl_world_group"
self.shape = P.Shape()
self.reduce_mean = P.ReduceMean(keep_dims=True)
@ -149,7 +156,7 @@ class _BatchNorm(Cell):
if self.is_global:
self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
momentum=self.momentum,
group=SYNC_BN_GROUP_NAME,
group=self.group_name,
device_num=self.group_device_num)
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
@ -226,25 +233,34 @@ class _BatchNorm(Cell):
f"but got {process_groups}.")
seen.add(rid)
def _create_groups(self, process_groups):
""" create groups by process groups. """
for sub_group in process_groups:
validator.check_isinstance("sub group", sub_group, list)
self.group_device_num = len(sub_group)
if self.rank_id in sub_group and self.group_device_num > 1:
rank_list_name = '_'.join('%s' % id for id in sub_group)
group_dict = _syncbatchnorm_group_dict()
if rank_list_name not in group_dict:
md5 = hashlib.md5()
md5.update(rank_list_name.encode('utf-8'))
hash_name = md5.hexdigest()
self.group_name = str(self.group_device_num) + '_' + hash_name
group_dict[rank_list_name] = self.group_name
management.create_group(self.group_name, sub_group)
logger.info("create group for sync batchnorm, the rank list is {}, the group name is {}".format(
rank_list_name, self.group_name))
else:
self.group_name = group_dict[rank_list_name]
logger.info("the group for {} already exists, no need to create".format(rank_list_name))
def _create_global_groups(self):
for i in range(self.rank_list_idx):
if self.rank_id in self.rank_list[i]:
self.is_global = True
global SYNC_BN_GROUP_NAME
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
""" create global groups. """
self._create_groups(self.rank_list)
def _create_sync_groups(self):
for i in range(len(self.process_groups)):
validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
self.group_device_num = len(self.process_groups[i])
if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
self.is_global = True
global SYNC_BN_GROUP_NAME
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
""" create sync groups. """
self._create_groups(self.process_groups)
@constexpr
@ -304,6 +320,8 @@ def _shape_infer(x_shape, num_feature):
class BatchNorm1d(_BatchNorm):
r"""
Batch Normalization layer over a 2D input.
This layer
applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
reduce internal covariate shift. Batch Normalization is widely used in convolutional networks.
@ -395,6 +413,8 @@ class BatchNorm1d(_BatchNorm):
class BatchNorm2d(_BatchNorm):
r"""
Batch Normalization layer over a 4D input.
Batch Normalization is widely used in convolutional networks. This layer
applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with
additional channel dimension) to avoid internal covariate shift as described
@ -521,6 +541,8 @@ def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
class BatchNorm3d(Cell):
r"""
Batch Normalization layer over a 5D input.
Batch Normalization is widely used in convolutional networks. This layer
applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with
additional channel dimension) to avoid internal covariate shift.
@ -747,7 +769,7 @@ class SyncBatchNorm(_BatchNorm):
[[ 0.999995 0.999995 ]
[ 0.999995 0.999995 ]]]]
"""
@cell_attr_register(attrs=['num_features', 'process_groups'])
def __init__(self,
num_features,
eps=1e-5,
@ -930,8 +952,10 @@ class _InstanceNorm(Cell):
class InstanceNorm1d(_InstanceNorm):
r"""
Instance Normalization layer over a 3D input.
This layer applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with
additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
of data and the learned parameters which can be described in the following formula.
@ -1018,8 +1042,10 @@ class InstanceNorm1d(_InstanceNorm):
class InstanceNorm2d(_InstanceNorm):
r"""
Instance Normalization layer over a 4D input.
This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
of data and the learned parameters which can be described in the following formula.
@ -1106,8 +1132,10 @@ class InstanceNorm2d(_InstanceNorm):
class InstanceNorm3d(_InstanceNorm):
r"""
Instance Normalization layer over a 5D input.
This layer applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with
additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
of data and the learned parameters which can be described in the following formula.

View File

@ -0,0 +1,69 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum, SyncBatchNorm
from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride):
super().__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.bn1 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])
self.bn2 = SyncBatchNorm(num_features=8, process_groups=[[0, 1, 2, 3]])
self.bn3 = SyncBatchNorm(num_features=8)
self.bn4 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])
def construct(self, x, b):
out = self.conv2d(x, self.conv2d_weight)
out = self.bn1(out)
out = self.bn2(out)
out = self.bn3(out)
out = self.bn4(out)
return out
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_train()
_cell_graph_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_syncbatchnorm():
"""
Feature: test syncbatchnorm
Description: create group
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=4, global_rank=0)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)
assert net.bn1.group_name == "2_174882033225436b1440b7de44686450"
assert net.bn2.group_name == "4_937e3b535d29ac4571b6fecb60df6169"
assert net.bn3.group_name == "hccl_world_group"
assert net.bn4.group_name == "2_174882033225436b1440b7de44686450"