!14072 for r1.2, Add nn.Conv3d and nn.Conv3dTranspose.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@c_34
Signed-off-by: @liangchenghui,@c_34
This commit is contained in:
mindspore-ci-bot 2021-03-26 09:09:51 +08:00 committed by Gitee
commit 331de218f1
4 changed files with 470 additions and 51 deletions

View File

@ -21,11 +21,11 @@ from mindspore.ops.primitive import constexpr
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, twice
from mindspore._checkparam import Validator, Rel, twice, triple
from mindspore._extends import cell_attr_register
from ..cell import Cell
__all__ = ['Conv2d', 'Conv2dTranspose', 'Conv1d', 'Conv1dTranspose']
__all__ = ['Conv2d', 'Conv2dTranspose', 'Conv1d', 'Conv1dTranspose', 'Conv3d', 'Conv3dTranspose']
class _Conv(Cell):
@ -55,9 +55,11 @@ class _Conv(Cell):
self.pad_mode = pad_mode
self.weight_init = weight_init
self.bias_init = bias_init
self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
self.format = Validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.cls_name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
raise ValueError("NCDHW format only support in Ascend target.")
if isinstance(padding, int):
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
self.padding = padding
@ -71,30 +73,23 @@ class _Conv(Cell):
self.dilation = dilation
self.group = Validator.check_positive_int(group)
self.has_bias = has_bias
if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \
kernel_size[0] < 1 or kernel_size[1] < 1:
raise ValueError("Attr 'kernel_size' of 'Conv2D' Op passed "
+ str(self.kernel_size) + ", should be a int or tuple and equal to or greater than 1.")
if (not isinstance(stride[0], int)) or (not isinstance(stride[1], int)) or \
isinstance(stride[0], bool) or isinstance(stride[1], bool) or stride[0] < 1 or stride[1] < 1:
raise ValueError("Attr 'stride' of 'Conv2D' Op passed "
+ str(self.stride) + ", should be a int or tuple and equal to or greater than 1.")
if (not isinstance(dilation[0], int)) or (not isinstance(dilation[1], int)) or \
isinstance(dilation[0], bool) or isinstance(dilation[1], bool) or dilation[0] < 1 or dilation[1] < 1:
raise ValueError("Attr 'dilation' of 'Conv2D' Op passed "
+ str(self.dilation) + ", should be a int or tuple and equal to or greater than 1.")
for kernel_size_elem in kernel_size:
Validator.check_positive_int(kernel_size_elem, 'kernel_size item', self.cls_name)
for stride_elem in stride:
Validator.check_positive_int(stride_elem, 'stride item', self.cls_name)
for dilation_elem in dilation:
Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
if in_channels % group != 0:
raise ValueError("Attr 'in_channels' of 'Conv2D' Op must be divisible by "
"attr 'group' of 'Conv2D' Op.")
raise ValueError(f"Attr 'in_channels' of {self.cls_name} Op must be divisible by "
f"attr 'group' of {self.cls_name} Op.")
if out_channels % group != 0:
raise ValueError("Attr 'out_channels' of 'Conv2D' Op must be divisible by "
"attr 'group' of 'Conv2D' Op.")
raise ValueError(f"Attr 'out_channels' {self.cls_name} Op must be divisible by "
f"attr 'group' of {self.cls_name} Op.")
if transposed:
shape = [in_channels, out_channels // group, *kernel_size]
else:
shape = [out_channels, in_channels // group, *kernel_size] if self.format == "NCHW" else \
[out_channels, *kernel_size, in_channels // group]
shape = [out_channels, *kernel_size, in_channels // group] if self.format == "NHWC" else \
[out_channels, in_channels // group, *kernel_size]
self.weight = Parameter(initializer(self.weight_init, shape), name='weight')
if Validator.check_bool(has_bias):
@ -476,6 +471,361 @@ class Conv1d(_Conv):
return s
@constexpr
def _check_input_5dims(input_shape, op_name):
    """
    Validate that a shape describes a 5-dimensional tensor.

    Args:
        input_shape (Sequence[int]): Shape of the tensor handed to the layer.
        op_name (str): Name of the calling operator, used in the error message.

    Raises:
        ValueError: If `input_shape` does not have exactly five entries.
    """
    rank = len(input_shape)
    if rank == 5:
        return
    raise ValueError(f"For {op_name}, input should be 5 dims, but got shape {input_shape}.")
class Conv3d(_Conv):
    r"""
    3D convolution layer.

    Applies a 3D convolution over an input tensor which is typically of shape
    :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` and produces an output of shape
    :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`, where :math:`N` is batch size and :math:`C` is channel number.
    The formula is defined as:

    .. math::

        \operatorname{out}\left(N_{i}, C_{\text {out}_j}\right)=\operatorname{bias}\left(C_{\text {out}_j}\right)+
        \sum_{k=0}^{C_{in}-1} ccor(\text {weight}\left(C_{\text {out}_j}, k\right),
        \operatorname{input}\left(N_{i}, k\right))

    where :math:`ccor` is the cross-correlation operator.

    If the 'pad_mode' is set to be "valid", the output depth, height and width will be
    :math:`\left \lfloor{1 + \frac{D_{in} + 2 \times \text{padding} - \text{ks_d} -
    (\text{ks_d} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 3 integers. Specifies the depth,
            height and width of the 3D convolution window. Single int means the value is for the depth, height
            and the width of the kernel. A tuple of 3 ints means the first value is for the depth, second value
            is for height and the other is for the width of the kernel.
        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the depth, height and width of movement are all strides, or a tuple of three int numbers that
            represent depth, height and width of movement respectively. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are
            "same", "valid", "pad". Default: "same".

            - same: Adopts the way of completion. The depth, height and width of the output will be the same as
              the input. The total number of padding will be calculated in depth, horizontal and vertical
              directions and evenly distributed to head and tail, top and bottom, left and right if possible.
              Otherwise, the last extra padding will be done from the tail, bottom and the right side.
              If this mode is set, `padding` must be 0.

            - valid: Adopts the way of discarding. The possible largest depth, height and width of output
              will be returned without padding. Extra pixels will be discarded. If this mode is set, `padding`
              must be 0.

            - pad: Implicit paddings on both sides of the input in depth, height, width. The number of `padding`
              will be padded to the input Tensor borders. `padding` must be greater than or equal to 0.

        padding (Union(int, tuple[int])): Implicit paddings on both sides of the input.
            The data type is int or a tuple of 6 integers. Default: 0. If `padding` is an integer,
            the paddings of head, tail, top, bottom, left and right are the same, equal to padding.
            If `padding` is a tuple of six integers, the padding of head, tail, top, bottom, left and right equal
            to padding[0], padding[1], padding[2], padding[3], padding[4] and padding[5] correspondingly.
        dilation (Union[int, tuple[int]]): The data type is int or a tuple of 3 integers
            :math:`(dilation_d, dilation_h, dilation_w)`. Currently, dilation on depth only supports the case of 1.
            Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
            there will be :math:`k - 1` pixels skipped for each sampling location.
            Its value must be greater or equal to 1 and bounded by the height and width of the input. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1. Only 1 is currently supported.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.
        data_format (str): The optional value for data format. Currently only support "NCDHW". Default: "NCDHW".

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
          Currently input data type only support float16 and float32.

    Outputs:
        Tensor, the value that applied 3D convolution. The shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `in_channels`, `out_channels` or `group` is not an int.
        TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
        ValueError: If `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
        ValueError: If `padding` is less than 0.
        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
        ValueError: If `padding` is a tuple whose length is not equal to 6.
        ValueError: If `pad_mode` is not equal to 'pad' and `padding` is not equal to (0, 0, 0, 0, 0, 0).
        ValueError: If `data_format` is not 'NCDHW'.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32)
        >>> conv3d = nn.Conv3d(in_channels=3, out_channels=32, kernel_size=(4, 3, 3))
        >>> output = conv3d(input)
        >>> print(output.shape)
        (16, 32, 10, 32, 32)
    """

    @cell_attr_register
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 data_format='NCDHW'):
        # Expand scalar arguments to (depth, height, width) triples before the shared
        # _Conv validation iterates over their elements.
        kernel_size = triple(kernel_size)
        stride = triple(stride)
        dilation = triple(dilation)
        # NOTE(review): `self.cls_name` is read before super().__init__ runs; this presumably
        # relies on it being a class-level property of Cell -- confirm.
        Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
        if isinstance(padding, tuple):
            # A tuple padding must name all six borders: head, tail, top, bottom, left, right.
            Validator.check_equal_int(len(padding), 6, 'padding size', self.cls_name)
        super(Conv3d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            has_bias,
            weight_init,
            bias_init,
            data_format)
        self.conv3d = P.Conv3D(out_channel=self.out_channels,
                               kernel_size=self.kernel_size,
                               mode=1,
                               pad_mode=self.pad_mode,
                               pad=self.padding,
                               stride=self.stride,
                               dilation=self.dilation,
                               group=self.group,
                               data_format=self.format)
        self.bias_add = P.BiasAdd(data_format=self.format)
        self.shape = P.Shape()

    def construct(self, x):
        """Apply the 3D convolution (and the bias, if enabled) to the 5-D input `x`."""
        x_shape = self.shape(x)
        _check_input_5dims(x_shape, self.cls_name)
        output = self.conv3d(x, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        return output

    def extend_repr(self):
        """Return the layer's configuration as a comma-separated `key=value` string."""
        # Every field is separated by ', ' -- the previous concatenation dropped the
        # separator between `has_bias` and `weight_init`, fusing the two fields.
        s = 'input_channels={}, output_channels={}, kernel_size={}, ' \
            'stride={}, pad_mode={}, padding={}, dilation={}, ' \
            'group={}, has_bias={}, ' \
            'weight_init={}, bias_init={}, format={}'.format(
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.pad_mode,
                self.padding,
                self.dilation,
                self.group,
                self.has_bias,
                self.weight_init,
                self.bias_init,
                self.format)
        return s
class Conv3dTranspose(_Conv):
    r"""
    Compute a 3D transposed convolution, which is also known as a deconvolution
    (although it is not an actual deconvolution).

    Input is typically of shape :math:`(N, C, D, H, W)`, where :math:`N` is batch size and :math:`C` is channel number.

    If the 'pad_mode' is set to be "pad", the depth, height and width of output are defined as:

    .. math::

        D_{out} = (D_{in} - 1) \times \text{stride_d} - 2 \times \text{padding_d} + \text{dilation_d} \times
        (\text{kernel_size_d} - 1) + \text{output_padding_d} + 1

        H_{out} = (H_{in} - 1) \times \text{stride_h} - 2 \times \text{padding_h} + \text{dilation_h} \times
        (\text{kernel_size_h} - 1) + \text{output_padding_h} + 1

        W_{out} = (W_{in} - 1) \times \text{stride_w} - 2 \times \text{padding_w} + \text{dilation_w} \times
        (\text{kernel_size_w} - 1) + \text{output_padding_w} + 1

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the depth, height and width of movement are all strides, or a tuple of three int numbers that
            represent depth, height and width of movement respectively. Its value must be equal to or greater
            than 1. Default: 1.
        pad_mode (str): Select the mode of the pad. The optional values are
            "pad", "same", "valid". Default: "same".

            - same: Adopts the way of completion. The depth, height and width of the output will be the same as
              the input. The total number of padding will be calculated in depth, horizontal and vertical
              directions and evenly distributed to head and tail, top and bottom, left and right if possible.
              Otherwise, the last extra padding will be done from the tail, bottom and the right side.
              If this mode is set, `padding` and `output_padding` must be 0.

            - valid: Adopts the way of discarding. The possible largest depth, height and width of output
              will be returned without padding. Extra pixels will be discarded. If this mode is set, `padding`
              and `output_padding` must be 0.

            - pad: Implicit paddings on both sides of the input in depth, height, width. The number of `pad` will
              be padded to the input Tensor borders. `padding` must be greater than or equal to 0.

        padding (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `padding` is an integer,
            the paddings of head, tail, top, bottom, left and right are the same, equal to padding.
            If `padding` is a tuple of six integers, the padding of head, tail, top, bottom, left and right equal
            to padding[0], padding[1], padding[2], padding[3], padding[4] and padding[5] correspondingly.
        dilation (Union(int, tuple[int])): The data type is int or a tuple of 3 integers
            :math:`(dilation_d, dilation_h, dilation_w)`. Currently, dilation on depth only supports the case of 1.
            Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
            there will be :math:`k - 1` pixels skipped for each sampling location.
            Its value must be greater or equal to 1 and bounded by the height and width of the input. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1. Only 1 is currently supported.
        output_padding (Union(int, tuple[int])): Add extra size to each dimension of the output. Default: 0.
            Must be greater than or equal to 0.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.
        data_format (str): The optional value for data format. Currently only support 'NCDHW'. Default: 'NCDHW'.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
          Currently input data type only support float16 and float32.

    Outputs:
        Tensor, the shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `in_channels`, `out_channels` or `group` is not an int.
        TypeError: If `kernel_size`, `stride`, `padding`, `dilation` or `output_padding`
            is neither an int nor a tuple.
        TypeError: If input data type is not float16 or float32.
        ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
        ValueError: If `padding` is less than 0.
        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
        ValueError: If `padding` is a tuple whose length is not equal to 6.
        ValueError: If `pad_mode` is not equal to 'pad' and `padding` is not equal to (0, 0, 0, 0, 0, 0).
        ValueError: If `data_format` is not 'NCDHW'.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float32)
        >>> conv3d_transpose = nn.Conv3dTranspose(in_channels=16, out_channels=3, kernel_size=(4, 6, 2), pad_mode='pad')
        >>> output = conv3d_transpose(input)
        >>> print(output.shape)
        (32, 3, 13, 37, 33)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 output_padding=0,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 data_format='NCDHW'):
        # Expand scalar arguments to (depth, height, width) triples before the shared
        # _Conv validation iterates over their elements.
        kernel_size = triple(kernel_size)
        stride = triple(stride)
        dilation = triple(dilation)
        # NOTE(review): `self.cls_name` is read before super().__init__ runs; this presumably
        # relies on it being a class-level property of Cell -- confirm.
        Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
        if isinstance(padding, tuple):
            # A tuple padding must name all six borders: head, tail, top, bottom, left, right.
            Validator.check_equal_int(len(padding), 6, 'padding size', self.cls_name)
        output_padding = triple(output_padding)
        super(Conv3dTranspose, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            has_bias,
            weight_init,
            bias_init,
            data_format,
            transposed=True)
        # `output_padding` is only forwarded to the primitive; it is not kept on the cell.
        self.conv3d_transpose = P.Conv3DTranspose(in_channel=self.in_channels,
                                                  out_channel=self.out_channels,
                                                  kernel_size=self.kernel_size,
                                                  mode=1,
                                                  pad_mode=self.pad_mode,
                                                  pad=self.padding,
                                                  stride=self.stride,
                                                  dilation=self.dilation,
                                                  group=self.group,
                                                  output_padding=output_padding,
                                                  data_format=self.format)
        self.bias_add = P.BiasAdd(data_format=self.format)
        self.shape = P.Shape()

    def construct(self, x):
        """Apply the 3D transposed convolution (and the bias, if enabled) to the 5-D input `x`."""
        x_shape = self.shape(x)
        _check_input_5dims(x_shape, self.cls_name)
        output = self.conv3d_transpose(x, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        return output

    def extend_repr(self):
        """Return the layer's configuration as a comma-separated `key=value` string."""
        # Use ', ' uniformly between fields; the earlier literal omitted the space after
        # `kernel_size` and `has_bias`, producing unevenly separated output.
        s = 'input_channels={}, output_channels={}, kernel_size={}, ' \
            'stride={}, pad_mode={}, padding={}, dilation={}, ' \
            'group={}, has_bias={}, ' \
            'weight_init={}, bias_init={}'.format(self.in_channels,
                                                  self.out_channels,
                                                  self.kernel_size,
                                                  self.stride,
                                                  self.pad_mode,
                                                  self.padding,
                                                  self.dilation,
                                                  self.group,
                                                  self.has_bias,
                                                  self.weight_init,
                                                  self.bias_init)
        return s
class Conv2dTranspose(_Conv):
r"""
2D transposed convolution layer.
@ -501,7 +851,7 @@ class Conv2dTranspose(_Conv):
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
kernel_size (Union[int, tuple]): int or a tuple of 2 integers, which specifies the height
kernel_size (Union[int, tuple]): int or a tuple of 2 integers, which specifies the height
and width of the 2D convolution window. Single int means the value is for both the height and the width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.

View File

@ -89,13 +89,15 @@ def get_bprop_conv3d(self):
@bprop_getters.register(nps.Conv3DTranspose)
def get_bprop_conv3d_transpose(self):
"""Grad definition for `Conv3DTranspose` operation."""
stride = (self.stride[2], self.stride[3], self.stride[4])
dilation = (self.dilation[2], self.dilation[3], self.dilation[4])
input_grad = nps.Conv3D(
out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
pad=self.pad_list, stride=stride, dilation=dilation, group=self.group, data_format=self.data_format
)
filter_grad = G.Conv3DBackpropFilter(
out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
pad=self.pad_list, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
)
def bprop(x, w, out, dout):

View File

@ -62,7 +62,7 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
LogUniformCandidateSampler)
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
BiasAdd, Conv2D, Conv3D, Conv3DTranspose,
DepthwiseConv2dNative,
DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
@ -139,6 +139,8 @@ __all__ = [
'Xdivy',
'Xlogy',
'Conv2D',
'Conv3D',
'Conv3DTranspose',
'Flatten',
'MaxPoolWithArgmax',
'BNTrainingReduce',

View File

@ -7765,7 +7765,7 @@ class Conv3D(PrimitiveWithInfer):
for each sampling location. Its value must be greater or equal to 1 and
bounded by the height and width of the input. Default: 1.
group (int): Splits filter into groups, `in_ channels` and `out_channels` must be
divisible by the number of groups. Default: 1.
divisible by the number of groups. Default: 1. Only 1 is currently supported.
data_format (str): The optional value for data format. Currently only support "NCDHW".
Inputs:
@ -7814,10 +7814,9 @@ class Conv3D(PrimitiveWithInfer):
"""Initialize Conv3D"""
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True,
ret_five=True)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=False, ret_five=True)
self.add_prim_attr('strides', self.stride)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True,
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=False,
ret_five=True, third_one=True)
self.add_prim_attr('dilations', self.dilation)
validator.check_value_type('pad', pad, (int, tuple), self.name)
@ -7854,7 +7853,7 @@ class Conv3D(PrimitiveWithInfer):
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.group = validator.check_positive_int(group, 'group', self.name)
self.group = validator.check_equal_int(group, 1, 'group', self.name)
self.add_prim_attr('groups', self.group)
self.add_prim_attr('offset_x', 0)
@ -8074,8 +8073,17 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
return out
def _deconv_output_length(input_length, kernel_size, stride_size, dilation_size):
filter_size = kernel_size + (kernel_size - 1) * (dilation_size - 1)
if filter_size - stride_size > 0:
length = input_length * stride_size + filter_size - stride_size
else:
length = input_length * stride_size
return length
class Conv3DTranspose(PrimitiveWithInfer):
"""
r"""
Compute a 3D transposed convolution, which is also known as a deconvolution
(although it is not an actual deconvolution).
@ -8091,24 +8099,38 @@ class Conv3DTranspose(PrimitiveWithInfer):
(\text{kernel_size_h} - 1) + \text{output_padding_h} + 1
W_{out} = (W_{in} - 1) \times \text{stride_w} - 2 \times \text{padding_w} + \text{dilation_w} \times
(\text{kernel_size_w} - 1) + 1
(\text{kernel_size_w} - 1) + \text{output_padding_w} + 1
Args:
in_channel (int): The channel of the input x.
out_channel (int): The channel of the weight x.
kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
mode (int): Modes for different convolutions. Default is 1. Not currently used.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "valid".
- same: Adopts the way of completion. The depth, height and width of the output will be the same as
the input. The total number of padding will be calculated in depth, horizontal and vertical
directions and evenly distributed to head and tail, top and bottom, left and right if possible.
Otherwise, the last extra padding will be done from the tail, bottom and the right side.
If this mode is set, `pad` and `output_padding` must be 0.
- valid: Adopts the way of discarding. The possible largest depth, height and width of output
will be returned without padding. Extra pixels will be discarded. If this mode is set, `pad`
and `output_padding` must be 0.
- pad: Implicit paddings on both sides of the input in depth, height, width. The number of `pad` will
be padded to the input Tensor borders. `pad` must be greater than or equal to 0.
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six integers,
the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3], pad[4]
and pad[5] correspondingly.
stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
group (int): Splits input into groups. Default: 1.
group (int): Splits input into groups. Default: 1. Only 1 is currently supported.
output_padding (Union(int, tuple[int])): Add extra size to each dimension of the output. Default: 0.
data_format (str): The optional value for data format. Currently only support 'NCDHW'.
input_size (tuple[int]): A tuple describes the shape of the input which conforms to the format
:math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Not currently used.
Inputs:
- **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default
@ -8127,7 +8149,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
Raise:
TypeError: If `in_channel`, `out_channel` or `group` is not an int.
TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int not a tuple.
TypeError: If `kernel_size`, `stride`, `pad` , `dilation` or `output_padding` is neither an int not a tuple.
ValueError: If `in_channel`, `out_channel`, `kernel_size`, `stride` or `dilation` is less than 1.
ValueError: If `pad` is less than 0.
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
@ -8152,6 +8174,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
out_channel,
kernel_size,
mode=1,
pad_mode='valid',
pad=0,
stride=1,
dilation=1,
@ -8165,10 +8188,10 @@ class Conv3DTranspose(PrimitiveWithInfer):
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.add_prim_attr('out_channel', self.out_channel)
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True,
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=False,
ret_five=True)
self.add_prim_attr('strides', self.stride)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True,
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=False,
ret_five=True, third_one=True)
self.add_prim_attr('dilations', self.dilation)
validator.check_value_type('pad', pad, (int, tuple), self.name)
@ -8178,8 +8201,15 @@ class Conv3DTranspose(PrimitiveWithInfer):
raise ValueError(f"For `conv3d` attr 'pad' should be an positive int number or a tuple of "
f"six positive int numbers, but got `{len(pad)}`.")
self.pad_list = pad
for item in self.pad_list:
validator.check_non_negative_int(item, 'pad item', self.name)
self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
self.add_prim_attr('pad_mode', self.pad_mode)
if self.pad_mode != 'pad' and pad != (0, 0, 0, 0, 0, 0):
raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.")
if self.pad_mode == 'pad':
for item in self.pad_list:
validator.check_non_negative_int(item, 'pad item', self.name)
validator.check_int_range(self.pad_list[0], 0, self.kernel_size[0], Rel.INC_LEFT,
'pad_d belonging [0, kernel_size_d)', self.name)
validator.check_int_range(self.pad_list[1], 0, self.kernel_size[0], Rel.INC_LEFT,
@ -8194,13 +8224,16 @@ class Conv3DTranspose(PrimitiveWithInfer):
'pad_w belonging [0, kernel_size_w)', self.name)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('mode', self.mode)
self.group = validator.check_positive_int(group, 'group', self.name)
self.mode = validator.check_equal_int(group, 1, 'group', self.name)
self.add_prim_attr('groups', self.group)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name,
allow_five=True, ret_five=True, greater_zero=False)
allow_five=False, ret_five=True, greater_zero=False)
output_padding = (self.output_padding[2], self.output_padding[3], self.output_padding[4])
if self.pad_mode != 'pad' and output_padding != (0, 0, 0):
raise ValueError(f"For '{self.name}', when output_padding is not 0, pad_mode should be set as 'pad'.")
validator.check_int_range(self.kernel_size[0]*self.kernel_size[1]*self.kernel_size[2], 1, 343, Rel.INC_BOTH,
'The product of height, width and depth of kernel_size belonging [1, 343]', self.name)
validator.check_int_range(self.stride[0]*self.stride[1]*self.stride[2], 1, 343, Rel.INC_BOTH,
@ -8213,7 +8246,6 @@ class Conv3DTranspose(PrimitiveWithInfer):
'output_padding_h belonging [0, max(stride_h,dilation_h))', self.name)
validator.check_int_range(self.output_padding[4], 0, max(self.dilation[4], self.stride[4]), Rel.INC_LEFT,
'output_padding_w belonging [0, max(stride_w,dilation_w))', self.name)
self.add_prim_attr('output_padding', self.output_padding)
def __infer__(self, x, w, b=None):
args = {'x': x['dtype'], 'w': w['dtype']}
@ -8230,14 +8262,47 @@ class Conv3DTranspose(PrimitiveWithInfer):
validator.check("filter's batch", w_shape[0], "input x's channel",
x_shape[1], Rel.EQ, self.name)
kernel_d, kernel_h, kernel_w = self.kernel_size
_, _, stride_d, stride_h, stride_w = self.stride
_, _, dilation_d, dilation_h, dilation_w = self.dilation
if self.pad_mode == "valid":
d_out = _deconv_output_length(x_shape[2], kernel_d, stride_d, dilation_d)
h_out = _deconv_output_length(x_shape[3], kernel_h, stride_h, dilation_h)
w_out = _deconv_output_length(x_shape[4], kernel_w, stride_w, dilation_w)
self.pad_list = (0, 0, 0, 0, 0, 0)
self.output_padding = (0, 0, 0, 0, 0)
elif self.pad_mode == "same":
d_out = x_shape[2] * stride_d
h_out = x_shape[3] * stride_h
w_out = x_shape[4] * stride_w
pad_needed_d = max(0, (x_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - d_out)
pad_head = math.floor(pad_needed_d / 2)
pad_tail = pad_needed_d - pad_head
pad_needed_h = max(0, (x_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - h_out)
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top
pad_needed_w = max(0, (x_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - w_out)
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)
self.output_padding = (0, 0, 0, 0, 0)
elif self.pad_mode == 'pad':
pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.pad_list
d_out = (x_shape[2] - 1) * self.stride[2] - (pad_head + pad_tail) + self.dilation[2] * \
(self.kernel_size[0] - 1) + self.output_padding[2] + 1
h_out = (x_shape[3] - 1) * self.stride[3] - (pad_top + pad_bottom) + self.dilation[3] * \
(self.kernel_size[1] - 1) + self.output_padding[3] + 1
w_out = (x_shape[4] - 1) * self.stride[4] - (pad_left + pad_right) + self.dilation[4] * \
(self.kernel_size[2] - 1) + self.output_padding[4] + 1
self.add_prim_attr('pad_list', self.pad_list)
pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.pad_list
d_out = (x_shape[2] - 1) * self.stride[2] - (pad_head + pad_tail) + self.dilation[2] * \
(self.kernel_size[0] - 1) + self.output_padding[2] + 1
h_out = (x_shape[3] - 1) * self.stride[3] - (pad_top + pad_bottom) + self.dilation[3] * \
(self.kernel_size[1] - 1) + self.output_padding[3] + 1
w_out = (x_shape[4] - 1) * self.stride[4] - (pad_left + pad_right) + self.dilation[4] * \
(self.kernel_size[2] - 1) + self.output_padding[4] + 1
self.add_prim_attr('output_padding', self.output_padding)
output_shape = (x_shape[0], w_shape[1]*self.group, d_out, h_out, w_out)
self.add_prim_attr('input_size', output_shape)
out = {