!41633 refactor batchnorm

Merge pull request !41633 from changzherui/mod_bn
i-robot 2022-09-14 09:20:21 +00:00 committed by Gitee
commit 08cff886ca
4 changed files with 143 additions and 333 deletions
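The main user-visible change is the removal of the long-deprecated GlobalBatchNorm in favor of SyncBatchNorm, together with moving the sync-group setup out of _BatchNorm. A rough migration sketch, assuming a 4-device job (the [[0, 1], [2, 3]] group layout is only illustrative):

import mindspore.nn as nn

# Before this change (deprecated since 1.2, now removed):
# bn = nn.GlobalBatchNorm(num_features=3, device_num_each_group=2)

# After this change: group devices explicitly with SyncBatchNorm.
# In a real run, mindspore.communication.init() must be called first
# and the script launched on multiple devices.
bn = nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]])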

View File

@@ -25,7 +25,6 @@ mindspore.nn.BatchNorm3d
- **moving_mean_init** (Union[Tensor, str, Initializer, numbers.Number]) - Initialization method for the moving mean. The str value refers to the function `mindspore.common.initializer`, including 'zeros', 'ones', etc. Default: 'zeros'.
- **moving_var_init** (Union[Tensor, str, Initializer, numbers.Number]) - Initialization method for the moving variance. The str value refers to the function `mindspore.common.initializer`, including 'zeros', 'ones', etc. Default: 'ones'.
- **use_batch_statistics** (bool) - If True, use the mean and variance of the current batch. If False, use the specified mean and variance. If None, the training process uses the mean and variance of the current batch and updates the moving mean and variance, while the evaluation process uses the moving mean and variance directly. Default: None.
- **data_format** (str) - The optional value of the data format is 'NCDHW'. Default: 'NCDHW'.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
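For reference, a minimal usage sketch of the interface documented above (shapes are illustrative; this change only drops the data_format argument):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# BatchNorm3d expects a 5D input of shape (N, C, D, H, W).
net = nn.BatchNorm3d(num_features=3)
x = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32)
output = net(x)
print(output.shape)   # (16, 3, 10, 32, 32)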

View File

@@ -24,7 +24,6 @@ from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.tensor import Tensor
from mindspore.common._decorator import deprecated
from mindspore.ops.primitive import constexpr
import mindspore.context as context
from mindspore._checkparam import Rel
@@ -37,7 +36,7 @@ from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.nn.cell import Cell
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
SYNC_BN_GROUP_NAME = ""
@@ -56,9 +55,6 @@ class _BatchNorm(Cell):
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=None,
device_num_each_group=1,
process_groups=0,
input_dims='2d',
data_format='NCHW'):
"""Initialize _BatchNorm."""
super(_BatchNorm, self).__init__()
@@ -69,7 +65,6 @@ class _BatchNorm(Cell):
if momentum < 0 or momentum > 1:
raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], "
f"but got {momentum}.")
self.input_dims = input_dims
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError(f"For '{self.cls_name}', the 'NHWC' format only support in GPU target, but got device "
@@ -92,39 +87,8 @@ class _BatchNorm(Cell):
gamma_init, num_features), name="gamma", requires_grad=affine)
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=affine)
self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group",
self.cls_name)
self.process_groups = process_groups
self.is_global = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
global SYNC_BN_GROUP_NAME
# for GlobalBatchNorm
if self.group_device_num != 1:
self.rank_id = get_rank()
self.rank_size = get_group_size()
self.device_list = [i for i in range(0, self.rank_size)]
self.rank_list = self.list_group(self.device_list, self.group_device_num)
self.rank_list_idx = len(self.rank_list)
self._create_global_groups()
# for SyncBatchNorm
if self.process_groups != 0:
self.rank_id = get_rank()
self.rank_size = get_group_size()
if self.process_groups is not None:
validator.check_isinstance("process_groups", self.process_groups, list)
self._check_rank_ids(self.process_groups, self.rank_size)
self._create_sync_groups()
elif self.rank_size > 1:
self.is_global = True
self.group_device_num = self.rank_size
self.device_list = [i for i in range(0, self.rank_size)]
if context.get_context("device_target") == "Ascend":
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group0"
management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
elif context.get_context("device_target") == "GPU":
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "nccl_world_group"
self.shape = P.Shape()
self.reduce_mean = P.ReduceMean(keep_dims=True)
@@ -136,20 +100,11 @@ class _BatchNorm(Cell):
self._target = context.get_context("device_target")
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
self.momentum = 1.0 - momentum
if context.get_context("enable_ge"):
self.is_ge_backend = True
else:
self.is_ge_backend = False
self.bn_train = P.BatchNorm(is_training=True,
epsilon=self.eps,
momentum=self.momentum,
data_format=self.format)
if self.is_global:
self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
momentum=self.momentum,
group=SYNC_BN_GROUP_NAME,
device_num=self.group_device_num)
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
if _is_in_auto_parallel_mode():
@@ -165,54 +120,13 @@ class _BatchNorm(Cell):
self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
def _check_data_dim(self, x):
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
raise NotImplementedError
def list_group(self, world_rank, group_size):
""" Check whether world_rank and group_size are valid. """
if group_size > get_group_size():
raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' cannot be greater than "
f"local rank size, but got 'device_num_each_group': {group_size}, "
f"local rank size: {get_group_size()}.")
if len(world_rank) % group_size != 0:
raise ValueError(f"For '{self.cls_name}', the dimension of device_list must be divisible by "
f"'device_num_each_group', but got the length of device_list: {len(world_rank)}, "
f"'device_num_each_group': {group_size}.")
world_rank_list = zip(*(iter(world_rank),) * group_size)
group_list = [list(i) for i in world_rank_list]
return group_list
def _check_rank_ids(self, process_groups, rank_size):
seen = set()
for rid in itertools.chain(*process_groups):
validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name)
if rid in seen:
raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
f"but got {process_groups}.")
seen.add(rid)
def _create_global_groups(self):
for i in range(self.rank_list_idx):
if self.rank_id in self.rank_list[i]:
self.is_global = True
global SYNC_BN_GROUP_NAME
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
def _create_sync_groups(self):
for i in range(len(self.process_groups)):
validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
self.group_device_num = len(self.process_groups[i])
if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
self.is_global = True
global SYNC_BN_GROUP_NAME
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
def construct(self, x):
_shape_check_bn(self.shape(x), self.input_dims, self.cls_name)
self._check_input_dim(self.shape(x), self.cls_name)
if self.use_batch_statistics is None:
if self.training:
return self.bn_train(x,
@@ -245,61 +159,6 @@ class _BatchNorm(Cell):
self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
@constexpr
def _channel_check(channel, num_channel, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if channel != num_channel:
raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, "
f"but got channel: {channel}, num_channels: {num_channel}.")
@constexpr
def _shape_check(in_shape, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if len(in_shape) != 4:
raise ValueError(f"{msg_prefix} in_shape must has 4 dims, but got the length of in_shape: {len(in_shape)}.")
@constexpr
def _shape_check_bn(in_shape, in_dims, prim_name=None):
"""check input dims of batch norm."""
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
dim = len(in_shape)
if in_dims == '1d' and dim != 2:
raise ValueError(f"{msg_prefix} in_shape must have 2 dims, but got {len(in_shape)}.")
if in_dims == '2d' and dim != 4:
raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.")
if in_dims == '3d' and dim != 5:
raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.")
if in_dims == 'both' and dim != 2 and dim != 4:
raise ValueError(f"{msg_prefix} in_shape must have 2 dims or 4 dims, but got {len(in_shape)}.")
@constexpr
def _shape_check_in(in_shape, in_dims, prim_name=None):
"""check input dims of batch norm."""
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
dim = len(in_shape)
if in_dims == '1d' and dim != 3:
raise ValueError(f"{msg_prefix} in_shape must have 3 dims, but got {len(in_shape)}.")
if in_dims == '2d' and dim != 4:
raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.")
if in_dims == '3d' and dim != 5:
raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.")
@constexpr
def _shape_infer(x_shape, num_feature):
"""global Batch Normalization shape and axes infer"""
if len(x_shape) == 4:
axes = (0, 2, 3)
re_shape = (1, num_feature, 1, 1)
else:
axes = (0,)
re_shape = (1, num_feature)
return axes, re_shape
class BatchNorm1d(_BatchNorm):
r"""
Batch Normalization layer over a 2D input.
@@ -366,31 +225,12 @@ class BatchNorm1d(_BatchNorm):
[ 0.4999975 0.399998 0.59999704 0.89999545 ]]
"""
def __init__(self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=None):
"""Initialize BatchNorm1d."""
super(BatchNorm1d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics,
input_dims='1d')
def _check_data_dim(self, x):
if x.ndim != 2:
pass
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 2:
raise ValueError(f"For '{cls_name}', the in_shape must have 2 dims, but got {dim}.")
class BatchNorm2d(_BatchNorm):
@@ -479,49 +319,15 @@ class BatchNorm2d(_BatchNorm):
[ 0.999995 0.999995 ]]]]
"""
def __init__(self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=None,
data_format='NCHW'):
"""Initialize BatchNorm2d."""
super(BatchNorm2d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics,
input_dims='2d',
data_format=data_format)
def _check_data_dim(self, x):
if x.ndim != 4:
pass
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 4:
raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
@constexpr
def _check_3d_shape(input_shape, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if len(input_shape) != 5:
raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
f"{len(input_shape)}.")
@constexpr
def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
validator.check_type_name(args_name, dtype, valid_dtypes, prim_name)
class BatchNorm3d(Cell):
class BatchNorm3d(_BatchNorm):
r"""
Batch Normalization layer over a 5D input.
@@ -557,7 +363,6 @@ class BatchNorm3d(Cell):
use the mean value and variance value of specified value. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
the running mean and variance. Default: None.
data_format (str): The optional value for data format is 'NCDHW'. Default: 'NCDHW'.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
@@ -595,12 +400,17 @@ class BatchNorm3d(Cell):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=None,
data_format='NCDHW'):
use_batch_statistics=None):
"""Initialize BatchNorm3d."""
super(BatchNorm3d, self).__init__()
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
self.reshape = P.Reshape()
super(BatchNorm3d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics)
self.bn2d = BatchNorm2d(num_features=num_features,
eps=eps,
momentum=momentum,
@@ -612,58 +422,22 @@ class BatchNorm3d(Cell):
use_batch_statistics=use_batch_statistics,
data_format="NCHW")
def construct(self, input_x):
x_shape = F.shape(input_x)
_check_3d_shape(x_shape, self.cls_name)
input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
bn2d_out = self.bn2d(input_x)
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 5:
raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")
def construct(self, x):
x_shape = self.shape(x)
self._check_input_dim(x_shape, self.cls_name)
x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
bn2d_out = self.bn2d(x)
bn3d_out = self.reshape(bn2d_out, x_shape)
return bn3d_out
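The construct above folds the depth axis into height so the existing BatchNorm2d kernel can be reused; per-channel statistics are unchanged because the reshape only regroups the same (N, D, H, W) elements of each channel. A small numpy sketch of that equivalence, with arbitrary shapes:

import numpy as np

x = np.random.randn(2, 3, 4, 5, 6).astype(np.float32)   # (N, C, D, H, W)
x4d = x.reshape(2, 3, 4 * 5, 6)                          # (N, C, D*H, W)

# Per-channel mean and variance are identical before and after the reshape.
assert np.allclose(x.mean(axis=(0, 2, 3, 4)), x4d.mean(axis=(0, 2, 3)))
assert np.allclose(x.var(axis=(0, 2, 3, 4)), x4d.var(axis=(0, 2, 3)))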
class GlobalBatchNorm(_BatchNorm):
r"""
The GlobalBatchNorm interface is deprecated, please use the :class:`mindspore.nn.SyncBatchNorm` instead.
Supported Platforms:
deprecated
"""
@deprecated("1.2", "SyncBatchNorm", True)
def __init__(self,
num_features,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=None,
device_num_each_group=2):
"""Initialize GlobalBatchNorm."""
super(GlobalBatchNorm, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics,
device_num_each_group,
input_dims='both')
self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group",
self.cls_name)
if self.group_device_num <= 1:
raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' must be greater than 1, "
f"but got {self.group_device_num}.")
def _check_data_dim(self, x):
if x.dim == 0:
pass
class SyncBatchNorm(_BatchNorm):
r"""
Sync Batch Normalization layer over a N-dimension input.
@@ -772,13 +546,61 @@ class SyncBatchNorm(_BatchNorm):
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics,
process_groups=process_groups,
input_dims='both')
use_batch_statistics)
self.is_global = False
global SYNC_BN_GROUP_NAME
self.process_groups = process_groups
if self.process_groups != 0:
self.rank_id = get_rank()
self.rank_size = get_group_size()
if self.process_groups is not None:
validator.check_isinstance("process_groups", self.process_groups, list)
self._check_rank_ids(self.process_groups, self.rank_size)
self._create_sync_groups()
elif self.rank_size > 1:
self.is_global = True
self.group_device_num = self.rank_size
self.device_list = [i for i in range(0, self.rank_size)]
if context.get_context("device_target") == "Ascend":
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group0"
management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
elif context.get_context("device_target") == "GPU":
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "nccl_world_group"
def _check_data_dim(self, x):
if x.dim == 0:
pass
if self.is_global:
self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
momentum=self.momentum,
group=SYNC_BN_GROUP_NAME,
device_num=self.group_device_num)
def _create_sync_groups(self):
for i in range(len(self.process_groups)):
validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
self.group_device_num = len(self.process_groups[i])
if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
self.is_global = True
global SYNC_BN_GROUP_NAME
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim not in (2, 4):
raise ValueError(f"For '{cls_name}', the must have 2 dims or 4 dims, but got {dim}.")
def _check_rank_ids(self, process_groups, rank_size):
seen = set()
for rid in itertools.chain(*process_groups):
validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name)
if rid in seen:
raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
f"but got {process_groups}.")
seen.add(rid)
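The _create_sync_groups logic above only creates a communication group for the sub-list of process_groups that contains the local rank. A simplified plain-Python sketch of that selection (it ignores the module-level SYNC_BN_GROUP_NAME caching; rank ids are illustrative):

def pick_sync_group(rank_id, process_groups):
    """Return (group_name, devices) for the sub-list containing rank_id, else None."""
    for i, devices in enumerate(process_groups):
        if rank_id in devices and len(devices) > 1:
            return "sync_bn_group%d" % i, devices
    return None

print(pick_sync_group(2, [[0, 1], [2, 3]]))   # ('sync_bn_group1', [2, 3])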
class LayerNorm(Cell):
@@ -871,7 +693,6 @@ class LayerNorm(Cell):
class _InstanceNorm(Cell):
"""Instance Normalization base class."""
@cell_attr_register
def __init__(self,
num_features,
@@ -879,8 +700,7 @@ class _InstanceNorm(Cell):
momentum=0.1,
affine=True,
gamma_init='ones',
beta_init='zeros',
input_dims='2d'):
beta_init='zeros'):
"""Initialize Normalization base class."""
super(_InstanceNorm, self).__init__()
validator.check_value_type('num_features', num_features, [int], self.cls_name)
@@ -897,7 +717,6 @@ class _InstanceNorm(Cell):
f"but got {momentum}.")
self.num_features = num_features
self.eps = eps
self.input_dims = input_dims
self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False)
self.gamma = Parameter(initializer(
@@ -910,7 +729,7 @@ class _InstanceNorm(Cell):
self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)
def construct(self, x):
_shape_check_in(self.shape(x), self.input_dims, self.cls_name)
self._check_input_dim(self.shape(x), self.cls_name)
return self.instance_bn(x,
self.gamma,
self.beta,
@@ -1005,21 +824,12 @@ class InstanceNorm1d(_InstanceNorm):
(2, 3, 5)
"""
def __init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
gamma_init='ones',
beta_init='zeros'):
"""Initialize InstanceNorm2d."""
super(InstanceNorm1d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
input_dims='1d')
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 3:
raise ValueError(f"For '{cls_name}', the in_shape must have 3 dims, but got {dim}.")
class InstanceNorm2d(_InstanceNorm):
@@ -1095,21 +905,12 @@ class InstanceNorm2d(_InstanceNorm):
(2, 3, 2, 2)
"""
def __init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
gamma_init='ones',
beta_init='zeros'):
"""Initialize InstanceNorm2d."""
super(InstanceNorm2d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
input_dims='2d')
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 4:
raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
class InstanceNorm3d(_InstanceNorm):
@@ -1184,22 +985,12 @@ class InstanceNorm3d(_InstanceNorm):
>>> print(output.shape)
(2, 3, 5, 2, 2)
"""
def __init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
gamma_init='ones',
beta_init='zeros'):
"""Initialize InstanceNorm2d."""
super(InstanceNorm3d, self).__init__(num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
input_dims='3d')
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 5:
raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")
class GroupNorm(Cell):
@@ -1283,7 +1074,7 @@ class GroupNorm(Cell):
def _cal_output(self, x):
"""calculate groupnorm output"""
batch, channel, height, width = self.shape(x)
_channel_check(channel, self.num_channels, self.cls_name)
self._channel_check(channel, self.num_channels, self.cls_name)
x = self.reshape(x, (batch, self.num_groups, -1))
mean = self.reduce_mean(x, 2)
var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
@@ -1293,11 +1084,31 @@ class GroupNorm(Cell):
output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1))
return output
def construct(self, x):
_shape_check(self.shape(x), self.cls_name)
_check_dtype(x.dtype, [mstype.float16, mstype.float32], "input", self.cls_name)
output = self._cal_output(x)
return output
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 4:
raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
@staticmethod
@constexpr
def _channel_check(channel, num_channel, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if channel != num_channel:
raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, "
f"but got channel: {channel}, num_channels: {num_channel}.")
@staticmethod
@constexpr
def _check_dtype(dtype, valid_dtypes, prim_name=None):
validator.check_type_name("input", dtype, valid_dtypes, prim_name)
def extend_repr(self):
return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
def construct(self, x):
self._check_input_dim(self.shape(x), self.cls_name)
self._check_dtype(x.dtype, [mstype.float16, mstype.float32], self.cls_name)
output = self._cal_output(x)
return output
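For clarity, the per-group statistics computed in _cal_output can be mirrored in numpy as below, assuming affine parameters of ones and zeros and an illustrative shape:

import numpy as np

batch, channel, height, width = 2, 4, 3, 3
num_groups, eps = 2, 1e-5
x = np.random.randn(batch, channel, height, width).astype(np.float32)

g = x.reshape(batch, num_groups, -1)            # group the channels: (N, G, C/G * H * W)
mean = g.mean(axis=2, keepdims=True)
var = ((g - mean) ** 2).sum(axis=2, keepdims=True) / (channel * height * width / num_groups)
g = (g - mean) / np.sqrt(var + eps)
out = g.reshape(batch, channel, height, width)  # gamma = 1, beta = 0 leave this unchanged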

View File

@@ -210,7 +210,7 @@ def find_net_layertype_recur(net, layertype_map):
if subcell.gamma.requires_grad:
layertype_map.append(BatchNorm)
elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
nn.BatchNorm1d, nn.GroupNorm)):
if isinstance(subcell, (nn.Dense, nn.Conv2d)):
get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map)
else:

View File

@@ -134,7 +134,7 @@ class ConvertNetUtils:
elif isinstance(subcell, (nn.DenseThor, nn.Conv2dThor, nn.EmbeddingThor, nn.EmbeddingLookupThor)):
continue
elif isinstance(subcell, (nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm,
nn.GlobalBatchNorm, nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)):
nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)):
continue
elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d, nn.EmbeddingLookup)):
prefix = subcell.param_prefix