forked from mindspore-Ecosystem/mindspore
add Conv1d ops
This commit is contained in:
parent
ba393c83a9
commit
7f9bbfd338
|
@ -18,11 +18,12 @@ from mindspore.ops import operations as P
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._checkparam import ParamValidator as validator, Rel
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
|
||||
from mindspore._extends import cell_attr_register
|
||||
from ..cell import Cell
|
||||
|
||||
__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d']
|
||||
__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d', 'Conv1d', 'Conv1dTranspose']
|
||||
|
||||
class _Conv(Cell):
|
||||
"""
|
||||
|
@ -241,6 +242,174 @@ class Conv2d(_Conv):
|
|||
return s
|
||||
|
||||
|
||||
class Conv1d(_Conv):
    r"""
    1D convolution layer.

    Applies a 1D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, W_{in})`,
    where :math:`N` is batch size and :math:`C_{in}` is channel number. For each batch of shape
    :math:`(C_{in}, W_{in})`, the formula is defined as:

    .. math::

        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,

    where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th
    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
    of kernel and it has shape :math:`(\text{ks_w})`, where :math:`\text{ks_w}` is the width of the convolution
    kernel. The full kernel has shape :math:`(C_{out}, C_{in} // \text{group}, \text{ks_w})`, where group is the
    group number to split the input in the channel dimension.

    If the 'pad_mode' is set to be "valid", the output width will be
    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor`.

    The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
    <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.

    Internally the layer runs the 2D convolution operator with a kernel height of 1: a 3D input gets a
    dummy height axis inserted before the convolution and removed again afterwards.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (int): The data type is int. Specifies the
            width of the 1D convolution window.
        stride (int): The distance of kernel moving, an int number that represents
            the width of movement. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are
            "same", "valid", "pad". Default: "same".

            - same: Adopts the way of completion. Output width will be the same as the input.
              Total number of padding will be calculated for horizontal
              direction and evenly distributed to left and right if possible. Otherwise, the
              last extra padding will be done from the right side. If this mode is set, `padding`
              must be 0.

            - valid: Adopts the way of discarding. The possibly largest width of output will be returned
              without padding. Extra pixels will be discarded. If this mode is set, `padding`
              must be 0.

            - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
              Tensor borders. `padding` should be greater than or equal to 0.

        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): The data type is int. Specifies the dilation rate
            to use for dilated convolution. If set to be :math:`k > 1`, there will
            be :math:`k - 1` pixels skipped for each sampling location. Its value should
            be greater or equal to 1 and bounded by the width of the
            input. Default: 1.
        group (int): Split filter into groups, `in_channels` and `out_channels` should be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, W_{out})`.

    Raises:
        ValueError: If `pad_mode` is not one of 'valid', 'same', 'pad'.

    Examples:
        >>> net = nn.Conv1d(120, 240, 4, has_bias=False, weight_init='normal')
        >>> input = Tensor(np.ones([1, 120, 640]), mindspore.float32)
        >>> net(input).shape
        (1, 240, 640)
    """
    @cell_attr_register
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros'):

        # The 1D layer only accepts plain ints for its geometry arguments.
        Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
        Validator.check_value_type("stride", stride, [int], self.cls_name)
        Validator.check_value_type("padding", padding, [int], self.cls_name)
        Validator.check_value_type("dilation", dilation, [int], self.cls_name)
        Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
        Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
        Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
        Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
        # Lift every 1D setting to a (height, width) pair with a unit height so
        # that the 2D operator can be reused.
        kernel_size = (1, kernel_size)
        stride = (1, stride)
        dilation = (1, dilation)

        super(Conv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            has_bias,
            weight_init,
            bias_init)
        # Validate pad_mode before the primitive is built; the primitive performs
        # its own pad_mode check, so checking afterwards could never report the
        # layer-specific message below.
        if pad_mode not in ('valid', 'same', 'pad'):
            raise ValueError('Attr \'pad_mode\' of \'Conv1d\' Op passed '
                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
        # (top, bottom, left, right): nothing is padded along the dummy height axis.
        self.padding = (0, 0, padding, padding)
        self.conv2d = P.Conv2D(out_channel=self.out_channels,
                               kernel_size=self.kernel_size,
                               mode=1,
                               pad_mode=self.pad_mode,
                               pad=self.padding,
                               stride=self.stride,
                               dilation=self.dilation,
                               group=self.group)
        self.bias_add = P.BiasAdd()
        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(2)
        self.shape = P.Shape()

    def construct(self, x):
        """Run the convolution, inserting and removing a height axis for 3D inputs."""
        x_shape = self.shape(x)
        if len(x_shape) == 3:
            # (N, C, W) -> (N, C, 1, W)
            x = self.expand_dims(x, 2)
        output = self.conv2d(x, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        if len(x_shape) == 3:
            # Drop the height axis that was inserted above.
            output = self.squeeze(output)
        return output

    def extend_repr(self):
        """Return the layer settings as a string for the cell representation."""
        s = 'input_channels={}, output_channels={}, kernel_size={},' \
            'stride={}, pad_mode={}, padding={}, dilation={}, ' \
            'group={}, has_bias={},' \
            'weight_init={}, bias_init={}'.format(
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.pad_mode,
                self.padding,
                self.dilation,
                self.group,
                self.has_bias,
                self.weight,
                self.bias)

        if self.has_bias:
            s += ', bias={}'.format(self.bias)
        return s
|
||||
|
||||
|
||||
class Conv2dTranspose(_Conv):
|
||||
r"""
|
||||
2D transposed convolution layer.
|
||||
|
@ -400,6 +569,181 @@ class Conv2dTranspose(_Conv):
|
|||
return s
|
||||
|
||||
|
||||
class Conv1dTranspose(_Conv):
    r"""
    1D transposed convolution layer.

    Compute a 1D transposed convolution, which is also known as a deconvolution
    (although it is not actual deconvolution).

    Input is typically of shape :math:`(N, C, W)`, where :math:`N` is batch size and :math:`C` is channel number.

    Internally the layer runs the 2D transposed convolution operator with a kernel height of 1: a 3D input
    gets a dummy height axis inserted before the operator and removed again afterwards.

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        kernel_size (int): int, which specifies the width of the 1D convolution window.
        stride (int): The distance of kernel moving, an int number that represents
            the width of movement. Default: 1.
        pad_mode (str): Select the mode of the pad. The optional values are
            "pad", "same", "valid". Default: "same".

            - pad: Implicit paddings on both sides of the input.

            - same: Adopted the way of completion.

            - valid: Adopted the way of discarding.
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): The data type is int. Specifies the dilation rate
            to use for dilated convolution. If set to be :math:`k > 1`, there will
            be :math:`k - 1` pixels skipped for each sampling location. Its value should
            be greater or equal to 1 and bounded by the width of the
            input. Default: 1.
        group (int): Split filter into groups, `in_channels` and `out_channels` should be
            divisible by the number of groups. This is not supported on Davinci devices when group > 1.
            Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, W_{out})`.

    Raises:
        ValueError: If `pad_mode` is not one of 'valid', 'same', 'pad'.

    Examples:
        >>> net = nn.Conv1dTranspose(3, 64, 4, has_bias=False, weight_init='normal')
        >>> input = Tensor(np.ones([1, 3, 50]), mindspore.float32)
        >>> net(input)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros'):
        # The 1D layer only accepts plain ints for its geometry arguments.
        Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
        Validator.check_value_type("stride", stride, [int], self.cls_name)
        Validator.check_value_type("padding", padding, [int], self.cls_name)
        Validator.check_value_type("dilation", dilation, [int], self.cls_name)
        Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
        Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
        Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
        Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
        # Lift every 1D setting to a (height, width) pair with a unit height so
        # that the 2D operator can be reused.
        kernel_size = (1, kernel_size)
        stride = (1, stride)
        dilation = (1, dilation)
        # out_channels and in_channels swap.
        # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
        # then Conv1dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
        super(Conv1dTranspose, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            has_bias,
            weight_init,
            bias_init,
            transposed=True)
        # (top, bottom, left, right): nothing is padded along the dummy height axis.
        self.padding = (0, 0, padding, padding)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.shape = P.Shape()
        if pad_mode not in ('valid', 'same', 'pad'):
            raise ValueError('Attr \'pad_mode\' of \'Conv1dTranspose\' Op passed '
                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
        # Cache the mode as booleans; they are read by _deconv_output_length.
        self.is_valid = self.pad_mode == 'valid'
        self.is_same = self.pad_mode == 'same'
        self.is_pad = self.pad_mode == 'pad'
        if check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')

        # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.
        self.conv2d_transpose = P.Conv2DBackpropInput(out_channel=in_channels,
                                                      kernel_size=kernel_size,
                                                      mode=1,
                                                      pad_mode=pad_mode,
                                                      pad=self.padding,
                                                      stride=stride,
                                                      dilation=dilation,
                                                      group=group)
        self.bias_add = P.BiasAdd()
        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(2)

    def set_strategy(self, strategy):
        """Forward `strategy` to the underlying transposed-convolution primitive and return self."""
        self.conv2d_transpose.set_strategy(strategy)
        return self

    def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size, padding):
        """Calculate the output length along one spatial axis.

        `padding` is the total padding of that axis (both sides added together).
        """
        length = 0
        # Effective kernel extent once dilation is taken into account.
        filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
        if self.is_valid:
            if filter_size - stride_size > 0:
                length = input_length * stride_size + filter_size - stride_size
            else:
                length = input_length * stride_size
        elif self.is_same:
            length = input_length * stride_size
        elif self.is_pad:
            length = input_length * stride_size - padding + filter_size - stride_size

        return length

    def construct(self, x):
        """Run the transposed convolution, inserting and removing a height axis for 3D inputs."""
        x_shape = self.shape(x)
        if len(x_shape) == 3:
            # (N, C, W) -> (N, C, 1, W)
            x = self.expand_dims(x, 2)

        n, _, h, w = self.shape(x)

        h_out = self._deconv_output_length(h, self.kernel_size[0], self.stride[0], self.dilation[0],
                                           self.padding[0] + self.padding[1])
        w_out = self._deconv_output_length(w, self.kernel_size[1], self.stride[1], self.dilation[1],
                                           self.padding[2] + self.padding[3])
        output = self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out))
        if self.has_bias:
            # Apply the bias first and fall through to the shared squeeze below, so
            # a 3D input yields a 3D output whether or not a bias is used.
            output = self.bias_add(output, self.bias)
        if len(x_shape) == 3:
            # Drop the height axis that was inserted above.
            output = self.squeeze(output)
        return output

    def extend_repr(self):
        """Return the layer settings as a string for the cell representation."""
        s = 'input_channels={}, output_channels={}, kernel_size={},' \
            'stride={}, pad_mode={}, padding={}, dilation={}, ' \
            'group={}, has_bias={},' \
            'weight_init={}, bias_init={}'.format(self.in_channels,
                                                  self.out_channels,
                                                  self.kernel_size,
                                                  self.stride,
                                                  self.pad_mode,
                                                  self.padding,
                                                  self.dilation,
                                                  self.group,
                                                  self.has_bias,
                                                  self.weight,
                                                  self.bias)
        return s
|
||||
|
||||
|
||||
class DepthwiseConv2d(Cell):
|
||||
r"""
|
||||
2D depthwise convolution layer.
|
||||
|
|
|
@ -780,7 +780,9 @@ class Conv2D(PrimitiveWithInfer):
|
|||
mode (int): 0 Math convolution, 1 cross-correlation convolution,
|
||||
2 deconvolution, 3 depthwise convolution. Default: 1.
|
||||
pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
|
||||
pad (int): The pad value to fill. Default: 0.
|
||||
pad (Union(int, tuple[int])): The pad value to fill. Default: 0. If `pad` is one integer, the padding of
|
||||
top, bottom, left and right is same, equal to pad. If `pad` is tuple with four integer, the padding
|
||||
of top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3] with corresponding.
|
||||
stride (Union(int, tuple[int])): The stride to apply conv filter. Default: 1.
|
||||
dilation (Union(int, tuple[int])): Specify the space to use between kernel elements. Default: 1.
|
||||
group (int): Split input into groups. Default: 1.
|
||||
|
@ -820,11 +822,19 @@ class Conv2D(PrimitiveWithInfer):
|
|||
self.add_prim_attr('stride', self.stride)
|
||||
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
||||
self.add_prim_attr('dilation', self.dilation)
|
||||
validator.check_value_type('pad', pad, (int,), self.name)
|
||||
validator.check_value_type('pad', pad, (int, tuple), self.name)
|
||||
if isinstance(pad, int):
|
||||
pad = (pad,) * 4
|
||||
else:
|
||||
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
||||
self.padding = pad
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
||||
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
|
||||
|
||||
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
||||
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
||||
if self.pad_mode == 'pad':
|
||||
validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
|
||||
for item in pad:
|
||||
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
|
||||
|
||||
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
||||
self.add_prim_attr('data_format', "NCHW")
|
||||
|
@ -862,11 +872,11 @@ class Conv2D(PrimitiveWithInfer):
|
|||
pad_left = math.floor(pad_needed_w / 2)
|
||||
pad_right = pad_needed_w - pad_left
|
||||
elif self.pad_mode == 'pad':
|
||||
pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
|
||||
pad_top, pad_bottom, pad_left, pad_right = self.padding
|
||||
|
||||
h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
|
||||
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
|
||||
/ stride_h
|
||||
w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
|
||||
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
|
||||
/ stride_w
|
||||
h_out = math.floor(h_out)
|
||||
w_out = math.floor(w_out)
|
||||
|
@ -1277,7 +1287,9 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|||
out_channel (int): The dimensionality of the output space.
|
||||
kernel_size (Union[int, tuple[int]]): The size of the convolution window.
|
||||
pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
|
||||
pad (int): The pad value to fill. Default: 0.
|
||||
pad (Union[int, tuple[int]]): The pad value to fill. Default: 0. If `pad` is one integer, the padding of
|
||||
top, bottom, left and right is same, equal to pad. If `pad` is tuple with four integer, the padding
|
||||
of top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3] with corresponding.
|
||||
mode (int): 0 Math convolution, 1 cross-correlation convolution,
|
||||
2 deconvolution, 3 depthwise convolution. Default: 1.
|
||||
stride (Union[int, tuple[int]]): The stride to apply conv filter. Default: 1.
|
||||
|
@ -1314,9 +1326,21 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|||
self.add_prim_attr('stride', self.stride)
|
||||
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
||||
self.add_prim_attr('dilation', self.dilation)
|
||||
validator.check_value_type('pad', pad, (int,), self.name)
|
||||
|
||||
validator.check_value_type('pad', pad, (int, tuple), self.name)
|
||||
if isinstance(pad, int):
|
||||
pad = (pad,) * 4
|
||||
self.pad = pad
|
||||
else:
|
||||
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
||||
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
||||
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
|
||||
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
||||
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
||||
if self.pad_mode == 'pad':
|
||||
for item in pad:
|
||||
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
|
||||
|
||||
pad_mode = pad_mode.upper()
|
||||
self.add_prim_attr('pad_mode', pad_mode)
|
||||
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
||||
|
@ -1358,7 +1382,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|||
pad_right = pad_needed_w - pad_left
|
||||
pad_list = (pad_top, pad_bottom, pad_left, pad_right)
|
||||
elif self.pad_mode == 'PAD':
|
||||
pad_list = (self.pad,) * 4
|
||||
pad_list = self.pad
|
||||
self.add_prim_attr('pad_list', pad_list)
|
||||
out = {
|
||||
'value': None,
|
||||
|
|
|
@ -22,11 +22,22 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
|
|||
|
||||
def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
|
||||
"""Rearranges an image to row vector"""
|
||||
batch_num, channel, height, width = img.shape
|
||||
out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1
|
||||
out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1
|
||||
if isinstance(pad, int):
|
||||
pad_top = pad
|
||||
pad_bottom = pad
|
||||
pad_left = pad
|
||||
pad_right = pad
|
||||
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||
else:
|
||||
raise ValueError(f"The \'pad\' should be an int number or "
|
||||
f"a tuple of two or four int numbers, but got {pad}")
|
||||
|
||||
img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
|
||||
batch_num, channel, height, width = img.shape
|
||||
out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1
|
||||
out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1
|
||||
|
||||
img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
|
||||
col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
|
||||
|
||||
for y in range(filter_h):
|
||||
|
@ -43,10 +54,21 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
|
|||
def conv2d(x, weight, bias=None, stride=1, pad=0,
|
||||
dilation=1, groups=1, padding_mode='zeros'):
|
||||
"""Convolution 2D"""
|
||||
if isinstance(pad, int):
|
||||
pad_top = pad
|
||||
pad_bottom = pad
|
||||
pad_left = pad
|
||||
pad_right = pad
|
||||
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||
else:
|
||||
raise ValueError(f"The \'pad\' should be an int number or "
|
||||
f"a tuple of two or four int numbers, but got {pad}")
|
||||
|
||||
batch_num, _, x_h, x_w = x.shape
|
||||
filter_num, _, filter_h, filter_w = weight.shape
|
||||
out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2])
|
||||
out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3])
|
||||
out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2])
|
||||
out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3])
|
||||
col = im2col(x, filter_h, filter_w, stride, pad, dilation)
|
||||
col_w = np.reshape(weight, (filter_num, -1)).T
|
||||
out = np.dot(col, col_w)
|
||||
|
|
|
@ -169,16 +169,32 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
|
|||
raise ValueError(f"The \'stride\' should be an int number or "
|
||||
f"a tuple of two or four int numbers, but got {stride}")
|
||||
|
||||
if isinstance(pad, int):
|
||||
pad_top = pad
|
||||
pad_bottom = pad
|
||||
pad_left = pad
|
||||
pad_right = pad
|
||||
elif isinstance(pad, tuple) and len(pad) == 2:
|
||||
pad_top = pad[0]
|
||||
pad_bottom = pad[0]
|
||||
pad_left = pad[1]
|
||||
pad_right = pad[1]
|
||||
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||
else:
|
||||
raise ValueError(f"The \'pad\' should be an int number or "
|
||||
f"a tuple of two or four int numbers, but got {pad}")
|
||||
|
||||
batch_num, channel, height, width = input_shape
|
||||
out_h = (height + 2 * pad - filter_h) // stride_h + 1
|
||||
out_w = (width + 2 * pad - filter_w) // stride_w + 1
|
||||
out_h = (height + pad_top + pad_bottom - filter_h) // stride_h + 1
|
||||
out_w = (width + pad_left + pad_right - filter_w) // stride_w + 1
|
||||
col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \
|
||||
.transpose(0, 3, 4, 5, 1, 2)
|
||||
|
||||
img = np.zeros((batch_num,
|
||||
channel,
|
||||
height + 2 * pad + stride_h - 1,
|
||||
width + 2 * pad + stride_w - 1)) \
|
||||
height + pad_top + pad_bottom + stride_h - 1,
|
||||
width + pad_left + pad_right + stride_w - 1)) \
|
||||
.astype(col.dtype)
|
||||
for y in range(filter_h):
|
||||
y_max = y + stride_h * out_h
|
||||
|
@ -186,7 +202,7 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
|
|||
x_max = x + stride_h * out_w
|
||||
img[:, :, y:y_max:stride_h, x:x_max:stride_h] += col[:, :, y, x, :, :]
|
||||
|
||||
return img[:, :, pad:height + pad, pad:width + pad]
|
||||
return img[:, :, pad_top:height + pad_bottom, pad_left:width + pad_right]
|
||||
|
||||
|
||||
def convolve(x, w, b=None, pad_mode="valid"):
|
||||
|
@ -243,10 +259,21 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
|
|||
dilation_h = dilation[0]
|
||||
dilation_w = dilation[1]
|
||||
|
||||
if isinstance(pad, int):
|
||||
pad_top = pad
|
||||
pad_bottom = pad
|
||||
pad_left = pad
|
||||
pad_right = pad
|
||||
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||
else:
|
||||
raise ValueError(f"The \'pad\' should be an int number or "
|
||||
f"a tuple of two or four int numbers, but got {pad}")
|
||||
|
||||
batch_num, _, x_h, x_w = x.shape
|
||||
filter_num, _, filter_h, filter_w = weight.shape
|
||||
out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
|
||||
out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
|
||||
out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
|
||||
out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
|
||||
col = im2col(x, filter_h, filter_w, stride, pad, dilation)
|
||||
col_w = np.reshape(weight, (filter_num, -1)).T
|
||||
out = np.dot(col, col_w)
|
||||
|
@ -348,11 +375,22 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
|
|||
raise ValueError(f"The \'dilation\' should be an int number or "
|
||||
f"a tuple of two or four int numbers, but got {dilation}")
|
||||
|
||||
batch_num, channel, height, width = img.shape
|
||||
out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
|
||||
out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1
|
||||
if isinstance(pad, int):
|
||||
pad_top = pad
|
||||
pad_bottom = pad
|
||||
pad_left = pad
|
||||
pad_right = pad
|
||||
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||
else:
|
||||
raise ValueError(f"The \'pad\' should be an int number or "
|
||||
f"a tuple of two or four int numbers, but got {pad}")
|
||||
|
||||
img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
|
||||
batch_num, channel, height, width = img.shape
|
||||
out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
|
||||
out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1
|
||||
|
||||
img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
|
||||
col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
|
||||
|
||||
for y in range(filter_h):
|
||||
|
|
Loading…
Reference in New Issue