forked from mindspore-Ecosystem/mindspore
delete group parameter from nn.DepthwiseConv2d
commit 8337ae710e
parent 990c645c85
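The user-visible effect of this commit is at construction time: nn.DepthwiseConv2d no longer accepts a `group` argument. A minimal before/after sketch (assuming the public `mindspore.nn` import path; the commented-out call reflects the old signature shown in the hunks below, whose removed checks forced `group` to equal both channel counts):

import mindspore.nn as nn

# Before this commit the constructor took a redundant `group` argument that
# the (now removed) checks forced to equal both in_channels and out_channels:
# net = nn.DepthwiseConv2d(in_channels=64, out_channels=64, kernel_size=3, group=64)

# After this commit the layer is constructed without it; passing group=...
# would now raise a TypeError (unexpected keyword argument).
net = nn.DepthwiseConv2d(in_channels=64, out_channels=64, kernel_size=3)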
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """conv"""
+
 import numpy as np
 from mindspore import log as logger
 from mindspore.ops import operations as P
@@ -20,7 +21,7 @@ from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
-from mindspore._checkparam import ParamValidator as validator, Rel
+from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator
 from mindspore._checkparam import check_bool, twice, check_int_positive
 from mindspore._extends import cell_attr_register
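Note that the `ParamValidator as validator` alias dropped here appears to have been needed only for the `group` checks removed further down; the other `_checkparam` imports (`Rel`, `Validator`, `check_bool`, `twice`, `check_int_positive`) are left untouched.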
@@ -810,8 +811,7 @@ class DepthwiseConv2d(Cell):
 filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
 of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
 :math:`\text{ks_w}` are the height and width of the convolution kernel. The full kernel has shape
-:math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
-to split the input in the channel dimension.
+:math:`(C_{out}, C_{in}, \text{ks_h}, \text{ks_w})` to split the input in the channel dimension.

 If the 'pad_mode' is set to be "valid", the output height and width will be
 :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
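The per-channel operation this docstring describes (output channel j is input channel j correlated with its own ks_h x ks_w kernel slice) can be sketched in plain NumPy. This is an illustration of the math only, with 'valid' padding, stride 1 and dilation 1, not the MindSpore implementation:

import numpy as np

def depthwise_conv2d_valid(x, w):
    """Per-channel 2-D correlation: x has shape (C, H, W), w has shape (C, ks_h, ks_w)."""
    c, h, width = x.shape
    _, kh, kw = w.shape
    out = np.zeros((c, h - kh + 1, width - kw + 1))
    for j in range(c):                     # each channel uses its own kernel slice W_j
        for r in range(out.shape[1]):
            for s in range(out.shape[2]):
                out[j, r, s] = np.sum(x[j, r:r + kh, s:s + kw] * w[j])
    return out

x = np.random.randn(8, 16, 16)   # 8 channels, 16x16 feature map
w = np.random.randn(8, 3, 3)     # one 3x3 slice per channel
print(depthwise_conv2d_valid(x, w).shape)   # (8, 14, 14), matching the "valid" formula above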
@@ -854,8 +854,6 @@ class DepthwiseConv2d(Cell):
 be :math:`k - 1` pixels skipped for each sampling location. Its value should
 be greater than or equal to 1 and bounded by the height and width of the
 input. Default: 1.
-group (int): Split filter into groups, `in_ channels` and `out_channels` should be
-divisible by the number of groups. Default: 1.
 has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
 weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
 It can be a Tensor, a string, an Initializer or a number. When a string is specified,
@@ -887,7 +885,6 @@ class DepthwiseConv2d(Cell):
 pad_mode='same',
 padding=0,
 dilation=1,
-group=1,
 has_bias=False,
 weight_init='normal',
 bias_init='zeros'):
@@ -897,13 +894,9 @@ class DepthwiseConv2d(Cell):
 self.dilation = twice(dilation)
 self.in_channels = check_int_positive(in_channels)
 self.out_channels = check_int_positive(out_channels)
-validator.check_integer('group', group, in_channels, Rel.EQ)
-validator.check_integer('group', group, out_channels, Rel.EQ)
-validator.check_integer('group', group, 1, Rel.GE)
 self.pad_mode = pad_mode
 self.padding = padding
 self.dilation = dilation
-self.group = group
 self.has_bias = has_bias
 self.weight_init = weight_init
 self.bias_init = bias_init
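For reference, the three deleted check_integer calls enforced `group >= 1` and `group == in_channels == out_channels`, i.e. the layer only ever supported the fully depthwise case, which is why the argument was redundant. A rough Python equivalent of the removed constraint (a sketch, not the original validator code):

def _check_group(group, in_channels, out_channels):
    # constraint formerly enforced by the removed validator.check_integer calls
    assert group >= 1
    assert group == in_channels
    assert group == out_channels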
@@ -931,10 +924,10 @@ class DepthwiseConv2d(Cell):

 def extend_repr(self):
 s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
-'pad_mode={}, padding={}, dilation={}, group={},' \
+'pad_mode={}, padding={}, dilation={}' \
 'has_bias={}, weight_init={}, bias_init={}'.format(
 self.in_channels, self.out_channels, self.kernel_size, self.stride,
-self.pad_mode, self.padding, self.dilation, self.group,
+self.pad_mode, self.padding, self.dilation,
 self.has_bias, self.weight_init, self.bias_init)

 if self.has_bias:
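extend_repr feeds the printable representation of the cell, so after this change printing a DepthwiseConv2d instance lists pad_mode, padding and dilation but no longer a group field. A quick check, reusing the sketch above:

print(net)   # the group= field no longer appears in the output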