add Conv1d ops

This commit is contained in:
jiangjinsheng 2020-07-07 16:03:33 +08:00
parent ba393c83a9
commit 7f9bbfd338
4 changed files with 457 additions and 29 deletions

View File

@ -18,11 +18,12 @@ from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator, Rel
from mindspore._checkparam import Validator
from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
from mindspore._extends import cell_attr_register
from ..cell import Cell
__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d']
__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d', 'Conv1d', 'Conv1dTranspose']
class _Conv(Cell):
"""
@ -241,6 +242,174 @@ class Conv2d(_Conv):
return s
class Conv1d(_Conv):
r"""
1D convolution layer.
Applies a 1D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, W_{in})`,
where :math:`N` is batch size and :math:`C_{in}` is channel number. For each batch of shape
:math:`(C_{in}, W_{in})`, the formula is defined as:
.. math::
out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th
filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
of kernel and it has shape :math:`(\text{ks_w})`, where :math:`\text{ks_w}` are width of the convolution kernel.
The full kernel has shape :math:`(C_{out}, C_{in} // \text{group}, \text{ks_w})`, where group is the group number
to split the input in the channel dimension.
If the 'pad_mode' is set to be "valid", the output width will be
:math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
(\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
<http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (int): The data type is int. Specifies the
width of the 1D convolution window.
stride (int): The distance of kernel moving, an int number that represents
the width of movement. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".
- same: Adopts the way of completion. Output width will be the same as the input.
Total number of padding will be calculated for horizontal
direction and evenly distributed to left and right if possible. Otherwise, the
last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
must be 0.
- valid: Adopts the way of discarding. The possibly largest width of output will be return
without padding. Extra pixels will be discarded. If this mode is set, `padding`
must be 0.
- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
Tensor borders. `padding` should be greater than or equal to 0.
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): The data type is int. Specifies the dilation rate
to use for dilated convolution. If set to be :math:`k > 1`, there will
be :math:`k - 1` pixels skipped for each sampling location. Its value should
be greater or equal to 1 and bounded by the height and width of the
input. Default: 1.
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, W_{out})`.
Examples:
>>> net = nn.Conv1d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 640]), mindspore.float32)
>>> net(input).shape
(1, 240, 640)
"""
@cell_attr_register
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros'):
Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
kernel_size = (1, kernel_size)
stride = (1, stride)
dilation = (1, dilation)
super(Conv1d, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
has_bias,
weight_init,
bias_init)
self.padding = (0, 0, padding, padding)
self.conv2d = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
self.bias_add = P.BiasAdd()
if pad_mode not in ('valid', 'same', 'pad'):
raise ValueError('Attr \'pad_mode\' of \'Conv1d\' Op passed '
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
self.expand_dims = P.ExpandDims()
self.squeeze = P.Squeeze(2)
self.shape = P.Shape()
def construct(self, x):
x_shape = self.shape(x)
if len(x_shape) == 3:
x = self.expand_dims(x, 2)
output = self.conv2d(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
if len(x_shape) == 3:
output = self.squeeze(output)
return output
def extend_repr(self):
s = 'input_channels={}, output_channels={}, kernel_size={},' \
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
'group={}, has_bias={},' \
'weight_init={}, bias_init={}'.format(
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.pad_mode,
self.padding,
self.dilation,
self.group,
self.has_bias,
self.weight,
self.bias)
if self.has_bias:
s += ', bias={}'.format(self.bias)
return s
class Conv2dTranspose(_Conv):
r"""
2D transposed convolution layer.
@ -400,6 +569,181 @@ class Conv2dTranspose(_Conv):
return s
class Conv1dTranspose(_Conv):
r"""
1D transposed convolution layer.
Compute a 1D transposed convolution, which is also know as a deconvolution
(although it is not actual deconvolution).
Input is typically of shape :math:`(N, C, W)`, where :math:`N` is batch size and :math:`C` is channel number.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
kernel_size (int): int, which specifies the width of the 1D convolution window.
stride (int): The distance of kernel moving, an int number that represents
the width of movement. Default: 1.
pad_mode (str): Select the mode of the pad. The optional values are
"pad", "same", "valid". Default: "same".
- pad: Implicit paddings on both sides of the input.
- same: Adopted the way of completion.
- valid: Adopted the way of discarding.
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): The data type is int. Specifies the dilation rate
to use for dilated convolution. If set to be :math:`k > 1`, there will
be :math:`k - 1` pixels skipped for each sampling location. Its value should
be greater or equal to 1 and bounded by the width of the
input. Default: 1.
group (int): Split filter into groups, `in_channels` and `out_channels` should be
divisible by the number of groups. This is not support for Davinci devices when group > 1. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, W_{out})`.
Examples:
>>> net = nn.Conv1dTranspose(3, 64, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 3, 50]), mindspore.float32)
>>> net(input)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros'):
Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
kernel_size = (1, kernel_size)
stride = (1, stride)
dilation = (1, dilation)
# out_channels and in_channels swap.
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
# then Conv1dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
super(Conv1dTranspose, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
has_bias,
weight_init,
bias_init,
transposed=True)
self.padding = (0, 0, padding, padding)
self.in_channels = in_channels
self.out_channels = out_channels
self.shape = P.Shape()
if pad_mode not in ('valid', 'same', 'pad'):
raise ValueError('Attr \'pad_mode\' of \'Conv1dTranspose\' Op passed '
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
self.is_valid = self.pad_mode == 'valid'
self.is_same = self.pad_mode == 'same'
self.is_pad = self.pad_mode == 'pad'
if check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.
self.conv2d_transpose = P.Conv2DBackpropInput(out_channel=in_channels,
kernel_size=kernel_size,
mode=1,
pad_mode=pad_mode,
pad=self.padding,
stride=stride,
dilation=dilation,
group=group)
self.bias_add = P.BiasAdd()
self.expand_dims = P.ExpandDims()
self.squeeze = P.Squeeze(2)
def set_strategy(self, strategy):
self.conv2d_transpose.set_strategy(strategy)
return self
def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size, padding):
"""Calculate the width and height of output."""
length = 0
filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
if self.is_valid:
if filter_size - stride_size > 0:
length = input_length * stride_size + filter_size - stride_size
else:
length = input_length * stride_size
elif self.is_same:
length = input_length * stride_size
elif self.is_pad:
length = input_length * stride_size - padding + filter_size - stride_size
return length
def construct(self, x):
x_shape = self.shape(x)
if len(x_shape) == 3:
x = self.expand_dims(x, 2)
n, _, h, w = self.shape(x)
h_out = self._deconv_output_length(h, self.kernel_size[0], self.stride[0], self.dilation[0],
self.padding[0] + self.padding[1])
w_out = self._deconv_output_length(w, self.kernel_size[1], self.stride[1], self.dilation[1],
self.padding[2] + self.padding[3])
if self.has_bias:
return self.bias_add(self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out)),
self.bias)
output = self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out))
if len(x_shape) == 3:
output = self.squeeze(output)
return output
def extend_repr(self):
s = 'input_channels={}, output_channels={}, kernel_size={},' \
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
'group={}, has_bias={},' \
'weight_init={}, bias_init={}'.format(self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.pad_mode,
self.padding,
self.dilation,
self.group,
self.has_bias,
self.weight,
self.bias)
return s
class DepthwiseConv2d(Cell):
r"""
2D depthwise convolution layer.

View File

@ -780,7 +780,9 @@ class Conv2D(PrimitiveWithInfer):
mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
2 deconvolution, 3 depthwise convolution. Default: 1.
pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
pad (int): The pad value to fill. Default: 0.
pad (Union(int, tuple[int])): The pad value to fill. Default: 0. If `pad` is one integer, the padding of
top, bottom, left and right is same, equal to pad. If `pad` is tuple with four integer, the padding
of top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3] with corresponding.
stride (Union(int, tuple[int])): The stride to apply conv filter. Default: 1.
dilation (Union(int, tuple[int])): Specify the space to use between kernel elements. Default: 1.
group (int): Split input into groups. Default: 1.
@ -820,11 +822,19 @@ class Conv2D(PrimitiveWithInfer):
self.add_prim_attr('stride', self.stride)
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('dilation', self.dilation)
validator.check_value_type('pad', pad, (int,), self.name)
validator.check_value_type('pad', pad, (int, tuple), self.name)
if isinstance(pad, int):
pad = (pad,) * 4
else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
self.padding = pad
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
if self.pad_mode == 'pad':
validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
for item in pad:
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
self.add_prim_attr('data_format', "NCHW")
@ -862,11 +872,11 @@ class Conv2D(PrimitiveWithInfer):
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
elif self.pad_mode == 'pad':
pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
pad_top, pad_bottom, pad_left, pad_right = self.padding
h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
/ stride_h
w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
/ stride_w
h_out = math.floor(h_out)
w_out = math.floor(w_out)
@ -1277,7 +1287,9 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
out_channel (int): The dimensionality of the output space.
kernel_size (Union[int, tuple[int]]): The size of the convolution window.
pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
pad (int): The pad value to fill. Default: 0.
pad (Union[int, tuple[int]]): The pad value to fill. Default: 0. If `pad` is one integer, the padding of
top, bottom, left and right is same, equal to pad. If `pad` is tuple with four integer, the padding
of top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3] with corresponding.
mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
2 deconvolution, 3 depthwise convolution. Default: 1.
stride (Union[int. tuple[int]]): The stride to apply conv filter. Default: 1.
@ -1314,9 +1326,21 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
self.add_prim_attr('stride', self.stride)
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('dilation', self.dilation)
validator.check_value_type('pad', pad, (int,), self.name)
validator.check_value_type('pad', pad, (int, tuple), self.name)
if isinstance(pad, int):
pad = (pad,) * 4
self.pad = pad
else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
if self.pad_mode == 'pad':
for item in pad:
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
pad_mode = pad_mode.upper()
self.add_prim_attr('pad_mode', pad_mode)
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
@ -1358,7 +1382,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
pad_right = pad_needed_w - pad_left
pad_list = (pad_top, pad_bottom, pad_left, pad_right)
elif self.pad_mode == 'PAD':
pad_list = (self.pad,) * 4
pad_list = self.pad
self.add_prim_attr('pad_list', pad_list)
out = {
'value': None,

View File

@ -22,11 +22,22 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
"""Rearranges an image to row vector"""
batch_num, channel, height, width = img.shape
out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1
out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1
if isinstance(pad, int):
pad_top = pad
pad_bottom = pad
pad_left = pad
pad_right = pad
elif isinstance(pad, tuple) and len(pad) == 4:
pad_top, pad_bottom, pad_left, pad_right = pad
else:
raise ValueError(f"The \'pad\' should be an int number or "
f"a tuple of two or four int numbers, but got {pad}")
img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
batch_num, channel, height, width = img.shape
out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1
out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1
img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
for y in range(filter_h):
@ -43,10 +54,21 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
def conv2d(x, weight, bias=None, stride=1, pad=0,
dilation=1, groups=1, padding_mode='zeros'):
"""Convolution 2D"""
if isinstance(pad, int):
pad_top = pad
pad_bottom = pad
pad_left = pad
pad_right = pad
elif isinstance(pad, tuple) and len(pad) == 4:
pad_top, pad_bottom, pad_left, pad_right = pad
else:
raise ValueError(f"The \'pad\' should be an int number or "
f"a tuple of two or four int numbers, but got {pad}")
batch_num, _, x_h, x_w = x.shape
filter_num, _, filter_h, filter_w = weight.shape
out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2])
out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3])
out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2])
out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3])
col = im2col(x, filter_h, filter_w, stride, pad, dilation)
col_w = np.reshape(weight, (filter_num, -1)).T
out = np.dot(col, col_w)

View File

@ -169,16 +169,32 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
raise ValueError(f"The \'stride\' should be an int number or "
f"a tuple of two or four int numbers, but got {stride}")
if isinstance(pad, int):
pad_top = pad
pad_bottom = pad
pad_left = pad
pad_right = pad
elif isinstance(pad, tuple) and len(pad) == 2:
pad_top = pad[0]
pad_bottom = pad[0]
pad_left = pad[1]
pad_right = pad[1]
elif isinstance(pad, tuple) and len(pad) == 4:
pad_top, pad_bottom, pad_left, pad_right = pad
else:
raise ValueError(f"The \'pad\' should be an int number or "
f"a tuple of two or four int numbers, but got {pad}")
batch_num, channel, height, width = input_shape
out_h = (height + 2 * pad - filter_h) // stride_h + 1
out_w = (width + 2 * pad - filter_w) // stride_w + 1
out_h = (height + pad_top + pad_bottom - filter_h) // stride_h + 1
out_w = (width + pad_left + pad_right - filter_w) // stride_w + 1
col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \
.transpose(0, 3, 4, 5, 1, 2)
img = np.zeros((batch_num,
channel,
height + 2 * pad + stride_h - 1,
width + 2 * pad + stride_w - 1)) \
height + pad_top + pad_bottom + stride_h - 1,
width + pad_left + pad_right + stride_w - 1)) \
.astype(col.dtype)
for y in range(filter_h):
y_max = y + stride_h * out_h
@ -186,7 +202,7 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
x_max = x + stride_h * out_w
img[:, :, y:y_max:stride_h, x:x_max:stride_h] += col[:, :, y, x, :, :]
return img[:, :, pad:height + pad, pad:width + pad]
return img[:, :, pad_top:height + pad_bottom, pad_left:width + pad_right]
def convolve(x, w, b=None, pad_mode="valid"):
@ -243,10 +259,21 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
dilation_h = dilation[0]
dilation_w = dilation[1]
if isinstance(pad, int):
pad_top = pad
pad_bottom = pad
pad_left = pad
pad_right = pad
elif isinstance(pad, tuple) and len(pad) == 4:
pad_top, pad_bottom, pad_left, pad_right = pad
else:
raise ValueError(f"The \'pad\' should be an int number or "
f"a tuple of two or four int numbers, but got {pad}")
batch_num, _, x_h, x_w = x.shape
filter_num, _, filter_h, filter_w = weight.shape
out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
col = im2col(x, filter_h, filter_w, stride, pad, dilation)
col_w = np.reshape(weight, (filter_num, -1)).T
out = np.dot(col, col_w)
@ -348,11 +375,22 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
raise ValueError(f"The \'dilation\' should be an int number or "
f"a tuple of two or four int numbers, but got {dilation}")
batch_num, channel, height, width = img.shape
out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1
if isinstance(pad, int):
pad_top = pad
pad_bottom = pad
pad_left = pad
pad_right = pad
elif isinstance(pad, tuple) and len(pad) == 4:
pad_top, pad_bottom, pad_left, pad_right = pad
else:
raise ValueError(f"The \'pad\' should be an int number or "
f"a tuple of two or four int numbers, but got {pad}")
img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
batch_num, channel, height, width = img.shape
out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1
img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
for y in range(filter_h):