diff --git a/mindspore/python/mindspore/nn/layer/normalization.py b/mindspore/python/mindspore/nn/layer/normalization.py
index c79965eb87b..07c59f5285c 100644
--- a/mindspore/python/mindspore/nn/layer/normalization.py
+++ b/mindspore/python/mindspore/nn/layer/normalization.py
@@ -18,6 +18,7 @@ from __future__ import division
 import itertools
 import numbers
+import hashlib
 
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 
@@ -36,11 +37,20 @@ from mindspore.communication import management
 from mindspore.common import dtype as mstype
 from mindspore.parallel._utils import _is_in_auto_parallel_mode
 from mindspore.nn.cell import Cell
+from mindspore import log as logger
 
 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
            'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
 
-SYNC_BN_GROUP_NAME = ""
+
+SYNCBN_GROUP_DICT = None
+
+
+def _syncbatchnorm_group_dict():
+    global SYNCBN_GROUP_DICT
+    if SYNCBN_GROUP_DICT is None:
+        SYNCBN_GROUP_DICT = dict()
+    return SYNCBN_GROUP_DICT
 
 
 class _BatchNorm(Cell):
@@ -97,18 +107,20 @@ class _BatchNorm(Cell):
                                               self.cls_name)
         self.process_groups = process_groups
         self.is_global = False
+        self.group_name = None
         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        global SYNC_BN_GROUP_NAME
+
         # for GlobalBatchNorm
         if self.group_device_num != 1:
+            self.is_global = True
             self.rank_id = get_rank()
             self.rank_size = get_group_size()
             self.device_list = [i for i in range(0, self.rank_size)]
             self.rank_list = self.list_group(self.device_list, self.group_device_num)
-            self.rank_list_idx = len(self.rank_list)
             self._create_global_groups()
         # for SyncBatchNorm
         if self.process_groups != 0:
+            self.is_global = True
             self.rank_id = get_rank()
             self.rank_size = get_group_size()
             if self.process_groups is not None:
@@ -116,16 +128,11 @@
                 self._check_rank_ids(self.process_groups, self.rank_size)
                 self._create_sync_groups()
             elif self.rank_size > 1:
-                self.is_global = True
                 self.group_device_num = self.rank_size
-                self.device_list = [i for i in range(0, self.rank_size)]
                 if context.get_context("device_target") == "Ascend":
-                    if SYNC_BN_GROUP_NAME == "":
-                        SYNC_BN_GROUP_NAME = "sync_bn_group0"
-                    management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
+                    self.group_name = "hccl_world_group"
                 elif context.get_context("device_target") == "GPU":
-                    if SYNC_BN_GROUP_NAME == "":
-                        SYNC_BN_GROUP_NAME = "nccl_world_group"
+                    self.group_name = "nccl_world_group"
 
         self.shape = P.Shape()
         self.reduce_mean = P.ReduceMean(keep_dims=True)
@@ -149,7 +156,7 @@
         if self.is_global:
             self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
                                                 momentum=self.momentum,
-                                                group=SYNC_BN_GROUP_NAME,
+                                                group=self.group_name,
                                                 device_num=self.group_device_num)
 
         self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
@@ -226,25 +233,34 @@
                                  f"but got {process_groups}.")
             seen.add(rid)
 
+    def _create_groups(self, process_groups):
""" + for sub_group in process_groups: + validator.check_isinstance("sub group", sub_group, list) + self.group_device_num = len(sub_group) + if self.rank_id in sub_group and self.group_device_num > 1: + rank_list_name = '_'.join('%s' % id for id in sub_group) + group_dict = _syncbatchnorm_group_dict() + if rank_list_name not in group_dict: + md5 = hashlib.md5() + md5.update(rank_list_name.encode('utf-8')) + hash_name = md5.hexdigest() + self.group_name = str(self.group_device_num) + '_' + hash_name + group_dict[rank_list_name] = self.group_name + management.create_group(self.group_name, sub_group) + logger.info("create group for sync batchnorm, the rank list is {}, the group name is {}".format( + rank_list_name, self.group_name)) + else: + self.group_name = group_dict[rank_list_name] + logger.info("the group for {} already exists, no need to create".format(rank_list_name)) + def _create_global_groups(self): - for i in range(self.rank_list_idx): - if self.rank_id in self.rank_list[i]: - self.is_global = True - global SYNC_BN_GROUP_NAME - if SYNC_BN_GROUP_NAME == "": - SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i - management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) + """ create global groups. """ + self._create_groups(self.rank_list) def _create_sync_groups(self): - for i in range(len(self.process_groups)): - validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list) - self.group_device_num = len(self.process_groups[i]) - if self.rank_id in self.process_groups[i] and self.group_device_num > 1: - self.is_global = True - global SYNC_BN_GROUP_NAME - if SYNC_BN_GROUP_NAME == "": - SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i - management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i]) + """ create sync groups. """ + self._create_groups(self.process_groups) @constexpr @@ -304,6 +320,8 @@ def _shape_infer(x_shape, num_feature): class BatchNorm1d(_BatchNorm): r""" + Batch Normalization layer over a 2D input. + This layer applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to reduce internal covariate shift. Batch Normalization is widely used in convolutional networks. @@ -395,6 +413,8 @@ class BatchNorm1d(_BatchNorm): class BatchNorm2d(_BatchNorm): r""" + Batch Normalization layer over a 4D input. + Batch Normalization is widely used in convolutional networks. This layer applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) to avoid internal covariate shift as described @@ -521,6 +541,8 @@ def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None): class BatchNorm3d(Cell): r""" + Batch Normalization layer over a 5D input. + Batch Normalization is widely used in convolutional networks. This layer applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) to avoid internal covariate shift. @@ -747,7 +769,7 @@ class SyncBatchNorm(_BatchNorm): [[ 0.999995 0.999995 ] [ 0.999995 0.999995 ]]]] """ - + @cell_attr_register(attrs=['num_features', 'process_groups']) def __init__(self, num_features, eps=1e-5, @@ -930,8 +952,10 @@ class _InstanceNorm(Cell): class InstanceNorm1d(_InstanceNorm): r""" + Instance Normalization layer over a 3D input. + This layer applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with - additional channel dimension). 
-    additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
+    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
     Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature
     using a mini-batch of data and the learned parameters which can be described
     in the following formula.
@@ -1018,8 +1042,10 @@ class InstanceNorm1d(_InstanceNorm):
 
 class InstanceNorm2d(_InstanceNorm):
     r"""
+    Instance Normalization layer over a 4D input.
+
     This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
-    additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
+    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
     Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature
     using a mini-batch of data and the learned parameters which can be described
     in the following formula.
@@ -1106,8 +1132,10 @@ class InstanceNorm2d(_InstanceNorm):
 
 class InstanceNorm3d(_InstanceNorm):
     r"""
+    Instance Normalization layer over a 5D input.
+
     This layer applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with
-    additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
+    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
     Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature
     using a mini-batch of data and the learned parameters which can be described
     in the following formula.
diff --git a/tests/ut/python/parallel/test_syncbatchnorm.py b/tests/ut/python/parallel/test_syncbatchnorm.py
new file mode 100644
index 00000000000..c912ac170e3
--- /dev/null
+++ b/tests/ut/python/parallel/test_syncbatchnorm.py
@@ -0,0 +1,69 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _cell_graph_executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum, SyncBatchNorm
+from mindspore.ops import operations as P
+
+
+class Net(Cell):
+    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride):
+        super().__init__()
+        self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
+                               pad_mode=pad_mode, stride=stride)
+        self.conv2d_weight = Parameter(conv2d_weight, "w1")
+        self.bn1 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])
+        self.bn2 = SyncBatchNorm(num_features=8, process_groups=[[0, 1, 2, 3]])
+        self.bn3 = SyncBatchNorm(num_features=8)
+        self.bn4 = SyncBatchNorm(num_features=8, process_groups=[[0, 1], [2, 3]])
+
+    def construct(self, x, b):
+        out = self.conv2d(x, self.conv2d_weight)
+        out = self.bn1(out)
+        out = self.bn2(out)
+        out = self.bn3(out)
+        out = self.bn4(out)
+        return out
+
+
+_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
+_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
+_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_train()
+    _cell_graph_executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_syncbatchnorm():
+    """
+    Feature: test syncbatchnorm
+    Description: create group
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=4, global_rank=0)
+    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
+    compile_net(net)
+    assert net.bn1.group_name == "2_174882033225436b1440b7de44686450"
+    assert net.bn2.group_name == "4_937e3b535d29ac4571b6fecb60df6169"
+    assert net.bn3.group_name == "hccl_world_group"
+    assert net.bn4.group_name == "2_174882033225436b1440b7de44686450"
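
Note on the naming scheme introduced above: _create_groups derives the communication group name as "<group size>_<md5 of the underscore-joined rank list>", so every rank handed the same sub-group computes the same name without any shared global state, and SYNCBN_GROUP_DICT only caches names to avoid creating the same group twice in one process. A minimal standalone sketch of that derivation follows (the helper name sync_bn_group_name is illustrative and not part of the patch):

import hashlib

def sync_bn_group_name(sub_group):
    """Derive the deterministic SyncBatchNorm group name for one sub-group of ranks."""
    rank_list_name = '_'.join('%s' % rank for rank in sub_group)
    md5 = hashlib.md5()
    md5.update(rank_list_name.encode('utf-8'))
    return str(len(sub_group)) + '_' + md5.hexdigest()

# Ranks [0, 1] and [2, 3] yield two distinct names; a second SyncBatchNorm built
# with process_groups=[[0, 1], [2, 3]] reuses the cached name instead of creating
# the group again (see bn1 and bn4 in the test above).
print(sync_bn_group_name([0, 1]))
print(sync_bn_group_name([2, 3]))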