diff --git a/docs/api/api_python/nn/mindspore.nn.BatchNorm3d.rst b/docs/api/api_python/nn/mindspore.nn.BatchNorm3d.rst index 7a16cc99aa1..bfd57bae4df 100644 --- a/docs/api/api_python/nn/mindspore.nn.BatchNorm3d.rst +++ b/docs/api/api_python/nn/mindspore.nn.BatchNorm3d.rst @@ -25,7 +25,6 @@ mindspore.nn.BatchNorm3d - **moving_mean_init** (Union[Tensor, str, Initializer, numbers.Number]) - 动态均值和动态方差所使用的动量。平均值的初始化方法。str的值引用自函数 `mindspore.common.initializer` ,包括'zeros'、'ones'等。默认值:'zeros'。 - **moving_var_init** (Union[Tensor, str, Initializer, numbers.Number]) - 动态均值和动态方差所使用的动量。方差的初始化方法。str的值引用自函数 `mindspore.common.initializer` ,包括'zeros'、'ones'等。默认值:'ones'。 - **use_batch_statistics** (bool) - 如果为True,则使用当前批次数据的平均值和方差值。如果为False,则使用指定的平均值和方差值。如果为None,训练时,将使用当前批次数据的均值和方差,并更新动态均值和方差,验证过程将直接使用动态均值和方差。默认值:None。 - - **data_format** (str) - 数据格式的可选值为'NCDHW'。默认值:'NCDHW'。 输入: - **x** (Tensor) - 输入shape为 :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` 的Tensor。 diff --git a/mindspore/python/mindspore/nn/layer/normalization.py b/mindspore/python/mindspore/nn/layer/normalization.py index 4c66f0fd76b..2e7a5468ffc 100644 --- a/mindspore/python/mindspore/nn/layer/normalization.py +++ b/mindspore/python/mindspore/nn/layer/normalization.py @@ -24,7 +24,6 @@ from mindspore.ops.operations import _inner_ops as inner from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer, Initializer from mindspore.common.tensor import Tensor -from mindspore.common._decorator import deprecated from mindspore.ops.primitive import constexpr import mindspore.context as context from mindspore._checkparam import Rel @@ -37,7 +36,7 @@ from mindspore.parallel._utils import _is_in_auto_parallel_mode from mindspore.nn.cell import Cell __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm', - 'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d'] + 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d'] SYNC_BN_GROUP_NAME = "" @@ -56,9 +55,6 @@ class _BatchNorm(Cell): moving_mean_init='zeros', moving_var_init='ones', use_batch_statistics=None, - device_num_each_group=1, - process_groups=0, - input_dims='2d', data_format='NCHW'): """Initialize _BatchNorm.""" super(_BatchNorm, self).__init__() @@ -69,7 +65,6 @@ class _BatchNorm(Cell): if momentum < 0 or momentum > 1: raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], " f"but got {momentum}.") - self.input_dims = input_dims self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) if context.get_context("device_target") != "GPU" and self.format == "NHWC": raise ValueError(f"For '{self.cls_name}', the 'NHWC' format only support in GPU target, but got device " @@ -92,39 +87,8 @@ class _BatchNorm(Cell): gamma_init, num_features), name="gamma", requires_grad=affine) self.beta = Parameter(initializer( beta_init, num_features), name="beta", requires_grad=affine) - self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group", - self.cls_name) - self.process_groups = process_groups - self.is_global = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - global SYNC_BN_GROUP_NAME - # for GlobalBatchNorm - if self.group_device_num != 1: - self.rank_id = get_rank() - self.rank_size = get_group_size() - self.device_list = [i for i in range(0, self.rank_size)] - self.rank_list = self.list_group(self.device_list, self.group_device_num) - 
self.rank_list_idx = len(self.rank_list) - self._create_global_groups() - # for SyncBatchNorm - if self.process_groups != 0: - self.rank_id = get_rank() - self.rank_size = get_group_size() - if self.process_groups is not None: - validator.check_isinstance("process_groups", self.process_groups, list) - self._check_rank_ids(self.process_groups, self.rank_size) - self._create_sync_groups() - elif self.rank_size > 1: - self.is_global = True - self.group_device_num = self.rank_size - self.device_list = [i for i in range(0, self.rank_size)] - if context.get_context("device_target") == "Ascend": - if SYNC_BN_GROUP_NAME == "": - SYNC_BN_GROUP_NAME = "sync_bn_group0" - management.create_group(SYNC_BN_GROUP_NAME, self.device_list) - elif context.get_context("device_target") == "GPU": - if SYNC_BN_GROUP_NAME == "": - SYNC_BN_GROUP_NAME = "nccl_world_group" self.shape = P.Shape() self.reduce_mean = P.ReduceMean(keep_dims=True) @@ -136,20 +100,11 @@ class _BatchNorm(Cell): self._target = context.get_context("device_target") self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE self.momentum = 1.0 - momentum - if context.get_context("enable_ge"): - self.is_ge_backend = True - else: - self.is_ge_backend = False self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps, momentum=self.momentum, data_format=self.format) - if self.is_global: - self.bn_train = inner.SyncBatchNorm(epsilon=self.eps, - momentum=self.momentum, - group=SYNC_BN_GROUP_NAME, - device_num=self.group_device_num) self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) if _is_in_auto_parallel_mode(): @@ -165,54 +120,13 @@ class _BatchNorm(Cell): self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy) self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy) - def _check_data_dim(self, x): + @staticmethod + @constexpr + def _check_input_dim(shape, cls_name): raise NotImplementedError - def list_group(self, world_rank, group_size): - """ Check whether world_rank and group_size are valid. 
""" - if group_size > get_group_size(): - raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' cannot be greater than " - f"local rank size, but got 'device_num_each_group': {group_size}, " - f"local rank size: {get_group_size()}.") - if len(world_rank) % group_size != 0: - raise ValueError(f"For '{self.cls_name}', the dimension of device_list must be divisible by " - f"'device_num_each_group', but got the length of device_list: {len(world_rank)}, " - f"'device_num_each_group': {group_size}.") - world_rank_list = zip(*(iter(world_rank),) * group_size) - group_list = [list(i) for i in world_rank_list] - return group_list - - def _check_rank_ids(self, process_groups, rank_size): - seen = set() - for rid in itertools.chain(*process_groups): - validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name) - if rid in seen: - raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, " - f"but got {process_groups}.") - seen.add(rid) - - def _create_global_groups(self): - for i in range(self.rank_list_idx): - if self.rank_id in self.rank_list[i]: - self.is_global = True - global SYNC_BN_GROUP_NAME - if SYNC_BN_GROUP_NAME == "": - SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i - management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) - - def _create_sync_groups(self): - for i in range(len(self.process_groups)): - validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list) - self.group_device_num = len(self.process_groups[i]) - if self.rank_id in self.process_groups[i] and self.group_device_num > 1: - self.is_global = True - global SYNC_BN_GROUP_NAME - if SYNC_BN_GROUP_NAME == "": - SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i - management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i]) - def construct(self, x): - _shape_check_bn(self.shape(x), self.input_dims, self.cls_name) + self._check_input_dim(self.shape(x), self.cls_name) if self.use_batch_statistics is None: if self.training: return self.bn_train(x, @@ -245,61 +159,6 @@ class _BatchNorm(Cell): self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance) -@constexpr -def _channel_check(channel, num_channel, prim_name=None): - msg_prefix = f"For '{prim_name}', the" if prim_name else "The" - if channel != num_channel: - raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, " - f"but got channel: {channel}, num_channels: {num_channel}.") - - -@constexpr -def _shape_check(in_shape, prim_name=None): - msg_prefix = f"For '{prim_name}', the" if prim_name else "The" - if len(in_shape) != 4: - raise ValueError(f"{msg_prefix} in_shape must has 4 dims, but got the length of in_shape: {len(in_shape)}.") - - -@constexpr -def _shape_check_bn(in_shape, in_dims, prim_name=None): - """check input dims of batch norm.""" - msg_prefix = f"For '{prim_name}', the" if prim_name else "The" - dim = len(in_shape) - if in_dims == '1d' and dim != 2: - raise ValueError(f"{msg_prefix} in_shape must have 2 dims, but got {len(in_shape)}.") - if in_dims == '2d' and dim != 4: - raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.") - if in_dims == '3d' and dim != 5: - raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.") - if in_dims == 'both' and dim != 2 and dim != 4: - raise ValueError(f"{msg_prefix} in_shape must have 2 dims or 4 dims, but got {len(in_shape)}.") - - -@constexpr -def 
_shape_check_in(in_shape, in_dims, prim_name=None): - """check input dims of batch norm.""" - msg_prefix = f"For '{prim_name}', the" if prim_name else "The" - dim = len(in_shape) - if in_dims == '1d' and dim != 3: - raise ValueError(f"{msg_prefix} in_shape must have 3 dims, but got {len(in_shape)}.") - if in_dims == '2d' and dim != 4: - raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.") - if in_dims == '3d' and dim != 5: - raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.") - - -@constexpr -def _shape_infer(x_shape, num_feature): - """global Batch Normalization shape and axes infer""" - if len(x_shape) == 4: - axes = (0, 2, 3) - re_shape = (1, num_feature, 1, 1) - else: - axes = (0,) - re_shape = (1, num_feature) - return axes, re_shape - - class BatchNorm1d(_BatchNorm): r""" Batch Normalization layer over a 2D input. @@ -366,31 +225,12 @@ class BatchNorm1d(_BatchNorm): [ 0.4999975 0.399998 0.59999704 0.89999545 ]] """ - def __init__(self, - num_features, - eps=1e-5, - momentum=0.9, - affine=True, - gamma_init='ones', - beta_init='zeros', - moving_mean_init='zeros', - moving_var_init='ones', - use_batch_statistics=None): - """Initialize BatchNorm1d.""" - super(BatchNorm1d, self).__init__(num_features, - eps, - momentum, - affine, - gamma_init, - beta_init, - moving_mean_init, - moving_var_init, - use_batch_statistics, - input_dims='1d') - - def _check_data_dim(self, x): - if x.ndim != 2: - pass + @staticmethod + @constexpr + def _check_input_dim(shape, cls_name): + dim = len(shape) + if dim != 2: + raise ValueError(f"For '{cls_name}', the in_shape must have 2 dims, but got {dim}.") class BatchNorm2d(_BatchNorm): @@ -479,49 +319,15 @@ class BatchNorm2d(_BatchNorm): [ 0.999995 0.999995 ]]]] """ - def __init__(self, - num_features, - eps=1e-5, - momentum=0.9, - affine=True, - gamma_init='ones', - beta_init='zeros', - moving_mean_init='zeros', - moving_var_init='ones', - use_batch_statistics=None, - data_format='NCHW'): - """Initialize BatchNorm2d.""" - super(BatchNorm2d, self).__init__(num_features, - eps, - momentum, - affine, - gamma_init, - beta_init, - moving_mean_init, - moving_var_init, - use_batch_statistics, - input_dims='2d', - data_format=data_format) - - def _check_data_dim(self, x): - if x.ndim != 4: - pass + @staticmethod + @constexpr + def _check_input_dim(shape, cls_name): + dim = len(shape) + if dim != 4: + raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.") -@constexpr -def _check_3d_shape(input_shape, prim_name=None): - msg_prefix = f"For '{prim_name}', the" if prim_name else "The" - if len(input_shape) != 5: - raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: " - f"{len(input_shape)}.") - - -@constexpr -def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None): - validator.check_type_name(args_name, dtype, valid_dtypes, prim_name) - - -class BatchNorm3d(Cell): +class BatchNorm3d(_BatchNorm): r""" Batch Normalization layer over a 5D input. @@ -557,7 +363,6 @@ class BatchNorm3d(Cell): use the mean value and variance value of specified value. If None, the training process will use the mean and variance of current batch data and track the running mean and variance, the evaluation process will use the running mean and variance. Default: None. - data_format (str): The optional value for data format is 'NCDHW'. Default: 'NCDHW'. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. 
@@ -595,12 +400,17 @@ class BatchNorm3d(Cell): beta_init='zeros', moving_mean_init='zeros', moving_var_init='ones', - use_batch_statistics=None, - data_format='NCDHW'): + use_batch_statistics=None): """Initialize BatchNorm3d.""" - super(BatchNorm3d, self).__init__() - self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name) - self.reshape = P.Reshape() + super(BatchNorm3d, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics) self.bn2d = BatchNorm2d(num_features=num_features, eps=eps, momentum=momentum, @@ -612,58 +422,22 @@ class BatchNorm3d(Cell): use_batch_statistics=use_batch_statistics, data_format="NCHW") - def construct(self, input_x): - x_shape = F.shape(input_x) - _check_3d_shape(x_shape, self.cls_name) - input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4])) - bn2d_out = self.bn2d(input_x) + @staticmethod + @constexpr + def _check_input_dim(shape, cls_name): + dim = len(shape) + if dim != 5: + raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.") + + def construct(self, x): + x_shape = self.shape(x) + self._check_input_dim(x_shape, self.cls_name) + x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4])) + bn2d_out = self.bn2d(x) bn3d_out = self.reshape(bn2d_out, x_shape) return bn3d_out -class GlobalBatchNorm(_BatchNorm): - r""" - The GlobalBatchNorm interface is deprecated, please use the :class:`mindspore.nn.SyncBatchNorm` instead. - - Supported Platforms: - deprecated - """ - - @deprecated("1.2", "SyncBatchNorm", True) - def __init__(self, - num_features, - eps=1e-5, - momentum=0.9, - affine=True, - gamma_init='ones', - beta_init='zeros', - moving_mean_init='zeros', - moving_var_init='ones', - use_batch_statistics=None, - device_num_each_group=2): - """Initialize GlobalBatchNorm.""" - super(GlobalBatchNorm, self).__init__(num_features, - eps, - momentum, - affine, - gamma_init, - beta_init, - moving_mean_init, - moving_var_init, - use_batch_statistics, - device_num_each_group, - input_dims='both') - self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group", - self.cls_name) - if self.group_device_num <= 1: - raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' must be greater than 1, " - f"but got {self.group_device_num}.") - - def _check_data_dim(self, x): - if x.dim == 0: - pass - - class SyncBatchNorm(_BatchNorm): r""" Sync Batch Normalization layer over a N-dimension input. 
@@ -772,13 +546,61 @@ class SyncBatchNorm(_BatchNorm):
                                             beta_init,
                                             moving_mean_init,
                                             moving_var_init,
-                                            use_batch_statistics,
-                                            process_groups=process_groups,
-                                            input_dims='both')
+                                            use_batch_statistics)
+        self.is_global = False
+        global SYNC_BN_GROUP_NAME
+        self.process_groups = process_groups
+        if self.process_groups != 0:
+            self.rank_id = get_rank()
+            self.rank_size = get_group_size()
+            if self.process_groups is not None:
+                validator.check_isinstance("process_groups", self.process_groups, list)
+                self._check_rank_ids(self.process_groups, self.rank_size)
+                self._create_sync_groups()
+            elif self.rank_size > 1:
+                self.is_global = True
+                self.group_device_num = self.rank_size
+                self.device_list = [i for i in range(0, self.rank_size)]
+                if context.get_context("device_target") == "Ascend":
+                    if SYNC_BN_GROUP_NAME == "":
+                        SYNC_BN_GROUP_NAME = "sync_bn_group0"
+                        management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
+                elif context.get_context("device_target") == "GPU":
+                    if SYNC_BN_GROUP_NAME == "":
+                        SYNC_BN_GROUP_NAME = "nccl_world_group"
 
-    def _check_data_dim(self, x):
-        if x.dim == 0:
-            pass
+        if self.is_global:
+            self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
+                                                momentum=self.momentum,
+                                                group=SYNC_BN_GROUP_NAME,
+                                                device_num=self.group_device_num)
+
+    def _create_sync_groups(self):
+        for i in range(len(self.process_groups)):
+            validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
+            self.group_device_num = len(self.process_groups[i])
+            if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
+                self.is_global = True
+                global SYNC_BN_GROUP_NAME
+                if SYNC_BN_GROUP_NAME == "":
+                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
+                    management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
+
+    @staticmethod
+    @constexpr
+    def _check_input_dim(shape, cls_name):
+        dim = len(shape)
+        if dim not in (2, 4):
+            raise ValueError(f"For '{cls_name}', the in_shape must have 2 dims or 4 dims, but got {dim}.")
+
+    def _check_rank_ids(self, process_groups, rank_size):
+        seen = set()
+        for rid in itertools.chain(*process_groups):
+            validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name)
+            if rid in seen:
+                raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
+                                 f"but got {process_groups}.")
+            seen.add(rid)
 
 
 class LayerNorm(Cell):
@@ -871,7 +693,6 @@ class LayerNorm(Cell):
 
 class _InstanceNorm(Cell):
     """Instance Normalization base class."""
 
-    @cell_attr_register
     def __init__(self,
                  num_features,
@@ -879,8 +700,7 @@
                  momentum=0.1,
                  affine=True,
                  gamma_init='ones',
-                 beta_init='zeros',
-                 input_dims='2d'):
+                 beta_init='zeros'):
         """Initialize Normalization base class."""
         super(_InstanceNorm, self).__init__()
         validator.check_value_type('num_features', num_features, [int], self.cls_name)
@@ -897,7 +717,6 @@
                             f"but got {momentum}.")
         self.num_features = num_features
         self.eps = eps
-        self.input_dims = input_dims
         self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
         self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False)
         self.gamma = Parameter(initializer(
@@ -910,7 +729,7 @@
         self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)
 
     def construct(self, x):
-        _shape_check_in(self.shape(x), self.input_dims, self.cls_name)
+        self._check_input_dim(self.shape(x), self.cls_name)
         return self.instance_bn(x,
self.gamma, self.beta, @@ -1005,21 +824,12 @@ class InstanceNorm1d(_InstanceNorm): (2, 3, 5) """ - def __init__(self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - gamma_init='ones', - beta_init='zeros'): - """Initialize InstanceNorm2d.""" - super(InstanceNorm1d, self).__init__(num_features, - eps, - momentum, - affine, - gamma_init, - beta_init, - input_dims='1d') + @staticmethod + @constexpr + def _check_input_dim(shape, cls_name): + dim = len(shape) + if dim != 3: + raise ValueError(f"For '{cls_name}', the in_shape must have 3 dims, but got {dim}.") class InstanceNorm2d(_InstanceNorm): @@ -1095,21 +905,12 @@ class InstanceNorm2d(_InstanceNorm): (2, 3, 2, 2) """ - def __init__(self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - gamma_init='ones', - beta_init='zeros'): - """Initialize InstanceNorm2d.""" - super(InstanceNorm2d, self).__init__(num_features, - eps, - momentum, - affine, - gamma_init, - beta_init, - input_dims='2d') + @staticmethod + @constexpr + def _check_input_dim(shape, cls_name): + dim = len(shape) + if dim != 4: + raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.") class InstanceNorm3d(_InstanceNorm): @@ -1184,22 +985,12 @@ class InstanceNorm3d(_InstanceNorm): >>> print(output.shape) (2, 3, 5, 2, 2) """ - - def __init__(self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - gamma_init='ones', - beta_init='zeros'): - """Initialize InstanceNorm2d.""" - super(InstanceNorm3d, self).__init__(num_features, - eps, - momentum, - affine, - gamma_init, - beta_init, - input_dims='3d') + @staticmethod + @constexpr + def _check_input_dim(shape, cls_name): + dim = len(shape) + if dim != 5: + raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.") class GroupNorm(Cell): @@ -1283,7 +1074,7 @@ class GroupNorm(Cell): def _cal_output(self, x): """calculate groupnorm output""" batch, channel, height, width = self.shape(x) - _channel_check(channel, self.num_channels, self.cls_name) + self._channel_check(channel, self.num_channels, self.cls_name) x = self.reshape(x, (batch, self.num_groups, -1)) mean = self.reduce_mean(x, 2) var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups) @@ -1293,11 +1084,31 @@ class GroupNorm(Cell): output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1)) return output - def construct(self, x): - _shape_check(self.shape(x), self.cls_name) - _check_dtype(x.dtype, [mstype.float16, mstype.float32], "input", self.cls_name) - output = self._cal_output(x) - return output + @staticmethod + @constexpr + def _check_input_dim(shape, cls_name): + dim = len(shape) + if dim != 4: + raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.") + + @staticmethod + @constexpr + def _channel_check(channel, num_channel, prim_name=None): + msg_prefix = f"For '{prim_name}', the" if prim_name else "The" + if channel != num_channel: + raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, " + f"but got channel: {channel}, num_channels: {num_channel}.") + + @staticmethod + @constexpr + def _check_dtype(dtype, valid_dtypes, prim_name=None): + validator.check_type_name("input", dtype, valid_dtypes, prim_name) def extend_repr(self): return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels) + + def construct(self, x): + self._check_input_dim(self.shape(x), self.cls_name) + self._check_dtype(x.dtype, [mstype.float16, 
mstype.float32], self.cls_name) + output = self._cal_output(x) + return output diff --git a/mindspore/python/mindspore/nn/optim/thor.py b/mindspore/python/mindspore/nn/optim/thor.py index fe0da8d8472..b563584b394 100644 --- a/mindspore/python/mindspore/nn/optim/thor.py +++ b/mindspore/python/mindspore/nn/optim/thor.py @@ -210,7 +210,7 @@ def find_net_layertype_recur(net, layertype_map): if subcell.gamma.requires_grad: layertype_map.append(BatchNorm) elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose, - nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)): + nn.BatchNorm1d, nn.GroupNorm)): if isinstance(subcell, (nn.Dense, nn.Conv2d)): get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map) else: diff --git a/mindspore/python/mindspore/train/train_thor/convert_utils.py b/mindspore/python/mindspore/train/train_thor/convert_utils.py index 86c6c2151d6..4e0eb581b97 100644 --- a/mindspore/python/mindspore/train/train_thor/convert_utils.py +++ b/mindspore/python/mindspore/train/train_thor/convert_utils.py @@ -134,7 +134,7 @@ class ConvertNetUtils: elif isinstance(subcell, (nn.DenseThor, nn.Conv2dThor, nn.EmbeddingThor, nn.EmbeddingLookupThor)): continue elif isinstance(subcell, (nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm, - nn.GlobalBatchNorm, nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)): + nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)): continue elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d, nn.EmbeddingLookup)): prefix = subcell.param_prefix
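
Usage note (not part of the patch): the short sketch below exercises the refactored BatchNorm3d to show the user-visible effect of the changes above, namely the removed data_format argument and the per-class _check_input_dim rank check that replaces the deleted _check_3d_shape helper. It assumes a MindSpore build that already contains this change; the tensor shapes and channel counts are illustrative only.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# BatchNorm3d no longer accepts data_format; it always expects NCDHW input.
bn3d = nn.BatchNorm3d(num_features=3)
x = Tensor(np.ones([16, 3, 10, 32, 32]).astype(np.float32))
out = bn3d(x)      # reshaped to 4D internally and normalized by the wrapped BatchNorm2d
print(out.shape)   # (16, 3, 10, 32, 32)

# An input of the wrong rank is now rejected up front by the ValueError raised
# in BatchNorm3d._check_input_dim (previously handled by _check_3d_shape).
try:
    bn3d(Tensor(np.ones([16, 3, 32, 32]).astype(np.float32)))
except ValueError as err:
    print(err)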