forked from mindspore-Ecosystem/mindspore
!41633 refactor batchnorm
Merge pull request !41633 from changzherui/mod_bn
This commit is contained in: commit 08cff886ca
@@ -25,7 +25,6 @@ mindspore.nn.BatchNorm3d
    - **moving_mean_init** (Union[Tensor, str, Initializer, numbers.Number]) - Initialization method for the moving mean. The str values refer to the function `mindspore.common.initializer`, including 'zeros', 'ones', etc. Default: 'zeros'.
    - **moving_var_init** (Union[Tensor, str, Initializer, numbers.Number]) - Initialization method for the moving variance. The str values refer to the function `mindspore.common.initializer`, including 'zeros', 'ones', etc. Default: 'ones'.
    - **use_batch_statistics** (bool) - If True, use the mean and variance of the current batch. If False, use the specified mean and variance. If None, training uses the mean and variance of the current batch and updates the moving mean and variance, while evaluation uses the moving mean and variance directly. Default: None.
    - **data_format** (str) - The optional value for data format is 'NCDHW'. Default: 'NCDHW'.

    Inputs:
    - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
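For orientation, here is a minimal usage sketch of the `mindspore.nn.BatchNorm3d` interface documented above; the shapes are chosen purely for illustration:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# Illustrative 5D input (N, C, D, H, W) = (2, 3, 4, 4, 4); num_features must equal C.
net = nn.BatchNorm3d(num_features=3)
x = Tensor(np.ones([2, 3, 4, 4, 4]).astype(np.float32))
output = net(x)
print(output.shape)  # (2, 3, 4, 4, 4)
```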
@@ -24,7 +24,6 @@ from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.tensor import Tensor
from mindspore.common._decorator import deprecated
from mindspore.ops.primitive import constexpr
import mindspore.context as context
from mindspore._checkparam import Rel
@@ -37,7 +36,7 @@ from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.nn.cell import Cell

__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
           'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
           'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']

SYNC_BN_GROUP_NAME = ""
@@ -56,9 +55,6 @@ class _BatchNorm(Cell):
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=1,
                 process_groups=0,
                 input_dims='2d',
                 data_format='NCHW'):
        """Initialize _BatchNorm."""
        super(_BatchNorm, self).__init__()
@@ -69,7 +65,6 @@ class _BatchNorm(Cell):
        if momentum < 0 or momentum > 1:
            raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], "
                             f"but got {momentum}.")
        self.input_dims = input_dims
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError(f"For '{self.cls_name}', the 'NHWC' format only support in GPU target, but got device "
@@ -92,39 +87,8 @@ class _BatchNorm(Cell):
            gamma_init, num_features), name="gamma", requires_grad=affine)
        self.beta = Parameter(initializer(
            beta_init, num_features), name="beta", requires_grad=affine)
        self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group",
                                                             self.cls_name)
        self.process_groups = process_groups
        self.is_global = False

        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        global SYNC_BN_GROUP_NAME
        # for GlobalBatchNorm
        if self.group_device_num != 1:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            self.device_list = [i for i in range(0, self.rank_size)]
            self.rank_list = self.list_group(self.device_list, self.group_device_num)
            self.rank_list_idx = len(self.rank_list)
            self._create_global_groups()
        # for SyncBatchNorm
        if self.process_groups != 0:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            if self.process_groups is not None:
                validator.check_isinstance("process_groups", self.process_groups, list)
                self._check_rank_ids(self.process_groups, self.rank_size)
                self._create_sync_groups()
            elif self.rank_size > 1:
                self.is_global = True
                self.group_device_num = self.rank_size
                self.device_list = [i for i in range(0, self.rank_size)]
                if context.get_context("device_target") == "Ascend":
                    if SYNC_BN_GROUP_NAME == "":
                        SYNC_BN_GROUP_NAME = "sync_bn_group0"
                        management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
                elif context.get_context("device_target") == "GPU":
                    if SYNC_BN_GROUP_NAME == "":
                        SYNC_BN_GROUP_NAME = "nccl_world_group"

        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
@@ -136,20 +100,11 @@ class _BatchNorm(Cell):
        self._target = context.get_context("device_target")
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        self.momentum = 1.0 - momentum
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False

        self.bn_train = P.BatchNorm(is_training=True,
                                    epsilon=self.eps,
                                    momentum=self.momentum,
                                    data_format=self.format)
        if self.is_global:
            self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
                                                momentum=self.momentum,
                                                group=SYNC_BN_GROUP_NAME,
                                                device_num=self.group_device_num)

        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
        if _is_in_auto_parallel_mode():
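Note the `self.momentum = 1.0 - momentum` flip above: the layer-level `momentum` (default 0.9) is the weight kept on the running statistics, while the underlying op takes the weight given to the current batch statistics. Schematically, under that convention:

```python
# Running-statistics update implied by the layer-level convention:
#   running_mean = momentum * running_mean + (1 - momentum) * batch_mean
momentum = 0.9
running_mean, batch_mean = 0.0, 1.0
running_mean = momentum * running_mean + (1 - momentum) * batch_mean
print(running_mean)  # 0.1
```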
@@ -165,54 +120,13 @@ class _BatchNorm(Cell):
            self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
            self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)

    def _check_data_dim(self, x):
    @staticmethod
    @constexpr
    def _check_input_dim(shape, cls_name):
        raise NotImplementedError

    def list_group(self, world_rank, group_size):
        """ Check whether world_rank and group_size are valid. """
        if group_size > get_group_size():
            raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' cannot be greater than "
                             f"local rank size, but got 'device_num_each_group': {group_size}, "
                             f"local rank size: {get_group_size()}.")
        if len(world_rank) % group_size != 0:
            raise ValueError(f"For '{self.cls_name}', the length of device_list must be divisible by "
                             f"'device_num_each_group', but got the length of device_list: {len(world_rank)}, "
                             f"'device_num_each_group': {group_size}.")
        world_rank_list = zip(*(iter(world_rank),) * group_size)
        group_list = [list(i) for i in world_rank_list]
        return group_list

    def _check_rank_ids(self, process_groups, rank_size):
        seen = set()
        for rid in itertools.chain(*process_groups):
            validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name)
            if rid in seen:
                raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
                                 f"but got {process_groups}.")
            seen.add(rid)

    def _create_global_groups(self):
        for i in range(self.rank_list_idx):
            if self.rank_id in self.rank_list[i]:
                self.is_global = True
                global SYNC_BN_GROUP_NAME
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
                management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])

    def _create_sync_groups(self):
        for i in range(len(self.process_groups)):
            validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
            self.group_device_num = len(self.process_groups[i])
            if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
                self.is_global = True
                global SYNC_BN_GROUP_NAME
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
                management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])

    def construct(self, x):
        _shape_check_bn(self.shape(x), self.input_dims, self.cls_name)
        self._check_input_dim(self.shape(x), self.cls_name)
        if self.use_batch_statistics is None:
            if self.training:
                return self.bn_train(x,
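The `zip(*(iter(world_rank),) * group_size)` idiom in `list_group` above chunks the flat rank list into consecutive, equally sized groups; a quick standalone sketch of the behavior:

```python
world_rank = [0, 1, 2, 3, 4, 5]
group_size = 2
# One shared iterator repeated group_size times: each zip step pulls consecutive items.
group_list = [list(g) for g in zip(*(iter(world_rank),) * group_size)]
print(group_list)  # [[0, 1], [2, 3], [4, 5]]
```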
@@ -245,61 +159,6 @@ class _BatchNorm(Cell):
        self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)


@constexpr
def _channel_check(channel, num_channel, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if channel != num_channel:
        raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, "
                         f"but got channel: {channel}, num_channels: {num_channel}.")


@constexpr
def _shape_check(in_shape, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(in_shape) != 4:
        raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got the length of in_shape: {len(in_shape)}.")


@constexpr
def _shape_check_bn(in_shape, in_dims, prim_name=None):
    """check input dims of batch norm."""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    dim = len(in_shape)
    if in_dims == '1d' and dim != 2:
        raise ValueError(f"{msg_prefix} in_shape must have 2 dims, but got {len(in_shape)}.")
    if in_dims == '2d' and dim != 4:
        raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.")
    if in_dims == '3d' and dim != 5:
        raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.")
    if in_dims == 'both' and dim != 2 and dim != 4:
        raise ValueError(f"{msg_prefix} in_shape must have 2 dims or 4 dims, but got {len(in_shape)}.")


@constexpr
def _shape_check_in(in_shape, in_dims, prim_name=None):
    """check input dims of instance norm."""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    dim = len(in_shape)
    if in_dims == '1d' and dim != 3:
        raise ValueError(f"{msg_prefix} in_shape must have 3 dims, but got {len(in_shape)}.")
    if in_dims == '2d' and dim != 4:
        raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.")
    if in_dims == '3d' and dim != 5:
        raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.")


@constexpr
def _shape_infer(x_shape, num_feature):
    """global Batch Normalization shape and axes infer"""
    if len(x_shape) == 4:
        axes = (0, 2, 3)
        re_shape = (1, num_feature, 1, 1)
    else:
        axes = (0,)
        re_shape = (1, num_feature)
    return axes, re_shape
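`_shape_infer` selects the reduction axes and the broadcast shape for the per-channel statistics; a plain-Python rendering for illustration:

```python
def shape_infer(x_shape, num_feature):
    # Plain-Python equivalent of the @constexpr helper above.
    if len(x_shape) == 4:
        return (0, 2, 3), (1, num_feature, 1, 1)  # NCHW: reduce batch and spatial dims
    return (0,), (1, num_feature)                 # NC: reduce the batch dim only

print(shape_infer((32, 16, 8, 8), 16))  # ((0, 2, 3), (1, 16, 1, 1))
print(shape_infer((32, 16), 16))        # ((0,), (1, 16))
```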


class BatchNorm1d(_BatchNorm):
    r"""
    Batch Normalization layer over a 2D input.
@@ -366,31 +225,12 @@ class BatchNorm1d(_BatchNorm):
         [ 0.4999975 0.399998 0.59999704 0.89999545 ]]
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None):
        """Initialize BatchNorm1d."""
        super(BatchNorm1d, self).__init__(num_features,
                                          eps,
                                          momentum,
                                          affine,
                                          gamma_init,
                                          beta_init,
                                          moving_mean_init,
                                          moving_var_init,
                                          use_batch_statistics,
                                          input_dims='1d')

    def _check_data_dim(self, x):
        if x.ndim != 2:
            pass
    @staticmethod
    @constexpr
    def _check_input_dim(shape, cls_name):
        dim = len(shape)
        if dim != 2:
            raise ValueError(f"For '{cls_name}', the in_shape must have 2 dims, but got {dim}.")


class BatchNorm2d(_BatchNorm):
@@ -479,49 +319,15 @@ class BatchNorm2d(_BatchNorm):
          [ 0.999995 0.999995 ]]]]
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 data_format='NCHW'):
        """Initialize BatchNorm2d."""
        super(BatchNorm2d, self).__init__(num_features,
                                          eps,
                                          momentum,
                                          affine,
                                          gamma_init,
                                          beta_init,
                                          moving_mean_init,
                                          moving_var_init,
                                          use_batch_statistics,
                                          input_dims='2d',
                                          data_format=data_format)

    def _check_data_dim(self, x):
        if x.ndim != 4:
            pass
    @staticmethod
    @constexpr
    def _check_input_dim(shape, cls_name):
        dim = len(shape)
        if dim != 4:
            raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")


@constexpr
def _check_3d_shape(input_shape, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(input_shape) != 5:
        raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
                         f"{len(input_shape)}.")


@constexpr
def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
    validator.check_type_name(args_name, dtype, valid_dtypes, prim_name)


class BatchNorm3d(Cell):
class BatchNorm3d(_BatchNorm):
    r"""
    Batch Normalization layer over a 5D input.
@@ -557,7 +363,6 @@ class BatchNorm3d(Cell):
            use the mean value and variance value of specified value. If None, the training process will use the mean
            and variance of current batch data and track the running mean and variance, the evaluation process will use
            the running mean and variance. Default: None.
        data_format (str): The optional value for data format is 'NCDHW'. Default: 'NCDHW'.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
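With `use_batch_statistics=None` (the default), which statistics are used follows the cell's training flag; a minimal sketch:

```python
import mindspore.nn as nn

net = nn.BatchNorm2d(num_features=3)  # use_batch_statistics=None by default
net.set_train(True)    # batch statistics are used; moving mean/variance are updated
net.set_train(False)   # moving mean/variance are used directly
```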
@@ -595,12 +400,17 @@ class BatchNorm3d(Cell):
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 data_format='NCDHW'):
                 use_batch_statistics=None):
        """Initialize BatchNorm3d."""
        super(BatchNorm3d, self).__init__()
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
        self.reshape = P.Reshape()
        super(BatchNorm3d, self).__init__(num_features,
                                          eps,
                                          momentum,
                                          affine,
                                          gamma_init,
                                          beta_init,
                                          moving_mean_init,
                                          moving_var_init,
                                          use_batch_statistics)
        self.bn2d = BatchNorm2d(num_features=num_features,
                                eps=eps,
                                momentum=momentum,
@@ -612,58 +422,22 @@ class BatchNorm3d(Cell):
                                use_batch_statistics=use_batch_statistics,
                                data_format="NCHW")

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        _check_3d_shape(x_shape, self.cls_name)
        input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
        bn2d_out = self.bn2d(input_x)
    @staticmethod
    @constexpr
    def _check_input_dim(shape, cls_name):
        dim = len(shape)
        if dim != 5:
            raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")

    def construct(self, x):
        x_shape = self.shape(x)
        self._check_input_dim(x_shape, self.cls_name)
        x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
        bn2d_out = self.bn2d(x)
        bn3d_out = self.reshape(bn2d_out, x_shape)
        return bn3d_out
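The `construct` above reuses the 4D kernel by folding depth into height and restoring the shape afterwards; since the channel axis is untouched, the per-channel statistics are unaffected. A shape-only sketch:

```python
import numpy as np

# Illustrative values: (N, C, D, H, W) = (2, 3, 4, 5, 6).
n, c, d, h, w = 2, 3, 4, 5, 6
x = np.zeros((n, c, d, h, w), dtype=np.float32)
x4d = x.reshape(n, c, d * h, w)    # fold D into H for the 2D batch norm
print(x4d.shape)                   # (2, 3, 20, 6)
x5d = x4d.reshape(n, c, d, h, w)   # restore the original 5D shape
print(x5d.shape)                   # (2, 3, 4, 5, 6)
```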


class GlobalBatchNorm(_BatchNorm):
    r"""
    The GlobalBatchNorm interface is deprecated, please use the :class:`mindspore.nn.SyncBatchNorm` instead.

    Supported Platforms:
        deprecated
    """

    @deprecated("1.2", "SyncBatchNorm", True)
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=2):
        """Initialize GlobalBatchNorm."""
        super(GlobalBatchNorm, self).__init__(num_features,
                                              eps,
                                              momentum,
                                              affine,
                                              gamma_init,
                                              beta_init,
                                              moving_mean_init,
                                              moving_var_init,
                                              use_batch_statistics,
                                              device_num_each_group,
                                              input_dims='both')
        self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group",
                                                             self.cls_name)
        if self.group_device_num <= 1:
            raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' must be greater than 1, "
                             f"but got {self.group_device_num}.")

    def _check_data_dim(self, x):
        if x.dim == 0:
            pass


class SyncBatchNorm(_BatchNorm):
    r"""
    Sync Batch Normalization layer over an N-dimension input.
@@ -772,13 +546,61 @@ class SyncBatchNorm(_BatchNorm):
                                            beta_init,
                                            moving_mean_init,
                                            moving_var_init,
                                            use_batch_statistics,
                                            process_groups=process_groups,
                                            input_dims='both')
                                            use_batch_statistics)
        self.is_global = False
        global SYNC_BN_GROUP_NAME
        self.process_groups = process_groups
        if self.process_groups != 0:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            if self.process_groups is not None:
                validator.check_isinstance("process_groups", self.process_groups, list)
                self._check_rank_ids(self.process_groups, self.rank_size)
                self._create_sync_groups()
            elif self.rank_size > 1:
                self.is_global = True
                self.group_device_num = self.rank_size
                self.device_list = [i for i in range(0, self.rank_size)]
                if context.get_context("device_target") == "Ascend":
                    if SYNC_BN_GROUP_NAME == "":
                        SYNC_BN_GROUP_NAME = "sync_bn_group0"
                        management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
                elif context.get_context("device_target") == "GPU":
                    if SYNC_BN_GROUP_NAME == "":
                        SYNC_BN_GROUP_NAME = "nccl_world_group"

    def _check_data_dim(self, x):
        if x.dim == 0:
            pass
        if self.is_global:
            self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
                                                momentum=self.momentum,
                                                group=SYNC_BN_GROUP_NAME,
                                                device_num=self.group_device_num)

    def _create_sync_groups(self):
        for i in range(len(self.process_groups)):
            validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
            self.group_device_num = len(self.process_groups[i])
            if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
                self.is_global = True
                global SYNC_BN_GROUP_NAME
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
                management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])

    @staticmethod
    @constexpr
    def _check_input_dim(shape, cls_name):
        dim = len(shape)
        if dim not in (2, 4):
            raise ValueError(f"For '{cls_name}', the input must have 2 dims or 4 dims, but got {dim}.")

    def _check_rank_ids(self, process_groups, rank_size):
        seen = set()
        for rid in itertools.chain(*process_groups):
            validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name)
            if rid in seen:
                raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
                                 f"but got {process_groups}.")
            seen.add(rid)

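For reference, `process_groups` is a nested list of device ranks; with an assumed 4-device job, something like the following synchronizes statistics only within each sublist:

```python
import mindspore.nn as nn

# Ranks 0/1 form one sync group, ranks 2/3 another (topology assumed for illustration).
sync_bn = nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]])
```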
class LayerNorm(Cell):
@@ -871,7 +693,6 @@ class LayerNorm(Cell):

class _InstanceNorm(Cell):
    """Instance Normalization base class."""

    @cell_attr_register
    def __init__(self,
                 num_features,
@@ -879,8 +700,7 @@ class _InstanceNorm(Cell):
                 momentum=0.1,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 input_dims='2d'):
                 beta_init='zeros'):
        """Initialize Normalization base class."""
        super(_InstanceNorm, self).__init__()
        validator.check_value_type('num_features', num_features, [int], self.cls_name)
@@ -897,7 +717,6 @@ class _InstanceNorm(Cell):
                             f"but got {momentum}.")
        self.num_features = num_features
        self.eps = eps
        self.input_dims = input_dims
        self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
        self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False)
        self.gamma = Parameter(initializer(
@@ -910,7 +729,7 @@ class _InstanceNorm(Cell):
        self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)

    def construct(self, x):
        _shape_check_in(self.shape(x), self.input_dims, self.cls_name)
        self._check_input_dim(self.shape(x), self.cls_name)
        return self.instance_bn(x,
                                self.gamma,
                                self.beta,
@@ -1005,21 +824,12 @@ class InstanceNorm1d(_InstanceNorm):
        (2, 3, 5)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros'):
        """Initialize InstanceNorm1d."""
        super(InstanceNorm1d, self).__init__(num_features,
                                             eps,
                                             momentum,
                                             affine,
                                             gamma_init,
                                             beta_init,
                                             input_dims='1d')
    @staticmethod
    @constexpr
    def _check_input_dim(shape, cls_name):
        dim = len(shape)
        if dim != 3:
            raise ValueError(f"For '{cls_name}', the in_shape must have 3 dims, but got {dim}.")


class InstanceNorm2d(_InstanceNorm):
@@ -1095,21 +905,12 @@ class InstanceNorm2d(_InstanceNorm):
        (2, 3, 2, 2)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros'):
        """Initialize InstanceNorm2d."""
        super(InstanceNorm2d, self).__init__(num_features,
                                             eps,
                                             momentum,
                                             affine,
                                             gamma_init,
                                             beta_init,
                                             input_dims='2d')
    @staticmethod
    @constexpr
    def _check_input_dim(shape, cls_name):
        dim = len(shape)
        if dim != 4:
            raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")


class InstanceNorm3d(_InstanceNorm):
@@ -1184,22 +985,12 @@ class InstanceNorm3d(_InstanceNorm):
        >>> print(output.shape)
        (2, 3, 5, 2, 2)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros'):
        """Initialize InstanceNorm3d."""
        super(InstanceNorm3d, self).__init__(num_features,
                                             eps,
                                             momentum,
                                             affine,
                                             gamma_init,
                                             beta_init,
                                             input_dims='3d')
    @staticmethod
    @constexpr
    def _check_input_dim(shape, cls_name):
        dim = len(shape)
        if dim != 5:
            raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")


class GroupNorm(Cell):
@@ -1283,7 +1074,7 @@ class GroupNorm(Cell):
    def _cal_output(self, x):
        """calculate groupnorm output"""
        batch, channel, height, width = self.shape(x)
        _channel_check(channel, self.num_channels, self.cls_name)
        self._channel_check(channel, self.num_channels, self.cls_name)
        x = self.reshape(x, (batch, self.num_groups, -1))
        mean = self.reduce_mean(x, 2)
        var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
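After the reshape to `(batch, num_groups, -1)`, each group's variance above is taken over its `channel * height * width / num_groups` elements; a numpy sketch of the same statistics (illustrative shapes):

```python
import numpy as np

batch, channel, height, width, num_groups = 2, 4, 3, 3, 2
x = np.random.randn(batch, channel, height, width).astype(np.float32)
g = x.reshape(batch, num_groups, -1)            # each group holds C*H*W/num_groups elements
mean = g.mean(axis=2, keepdims=True)
var = ((g - mean) ** 2).sum(axis=2, keepdims=True) / (channel * height * width / num_groups)
print(mean.shape, var.shape)                    # (2, 2, 1) (2, 2, 1)
```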
@@ -1293,11 +1084,31 @@ class GroupNorm(Cell):
        output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1))
        return output

    def construct(self, x):
        _shape_check(self.shape(x), self.cls_name)
        _check_dtype(x.dtype, [mstype.float16, mstype.float32], "input", self.cls_name)
        output = self._cal_output(x)
        return output
    @staticmethod
    @constexpr
    def _check_input_dim(shape, cls_name):
        dim = len(shape)
        if dim != 4:
            raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")

    @staticmethod
    @constexpr
    def _channel_check(channel, num_channel, prim_name=None):
        msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
        if channel != num_channel:
            raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, "
                             f"but got channel: {channel}, num_channels: {num_channel}.")

    @staticmethod
    @constexpr
    def _check_dtype(dtype, valid_dtypes, prim_name=None):
        validator.check_type_name("input", dtype, valid_dtypes, prim_name)

    def extend_repr(self):
        return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)

    def construct(self, x):
        self._check_input_dim(self.shape(x), self.cls_name)
        self._check_dtype(x.dtype, [mstype.float16, mstype.float32], self.cls_name)
        output = self._cal_output(x)
        return output
@@ -210,7 +210,7 @@ def find_net_layertype_recur(net, layertype_map):
        if subcell.gamma.requires_grad:
            layertype_map.append(BatchNorm)
    elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
                              nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
                              nn.BatchNorm1d, nn.GroupNorm)):
        if isinstance(subcell, (nn.Dense, nn.Conv2d)):
            get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map)
        else:
@@ -134,7 +134,7 @@ class ConvertNetUtils:
        elif isinstance(subcell, (nn.DenseThor, nn.Conv2dThor, nn.EmbeddingThor, nn.EmbeddingLookupThor)):
            continue
        elif isinstance(subcell, (nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm,
                                  nn.GlobalBatchNorm, nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)):
                                  nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)):
            continue
        elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d, nn.EmbeddingLookup)):
            prefix = subcell.param_prefix