forked from OSSInnovation/mindspore
!3474 fix bug for con1d with 3d input.
Merge pull request !3474 from liuxiao93/fix-conv1d
This commit is contained in:
commit
73a677be44
|
@ -13,10 +13,13 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""conv"""
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import ParamValidator as validator, Rel
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore._checkparam import check_bool, twice, check_int_positive
|
||||
|
@ -254,6 +257,11 @@ class Conv2d(_Conv):
|
|||
return s
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_3d(input_shape):
|
||||
if len(input_shape) != 3:
|
||||
raise ValueError(f"Input should be 3d, but got shape {input_shape}")
|
||||
|
||||
class Conv1d(_Conv):
|
||||
r"""
|
||||
1D convolution layer.
|
||||
|
@ -359,6 +367,15 @@ class Conv1d(_Conv):
|
|||
kernel_size = (1, kernel_size)
|
||||
stride = (1, stride)
|
||||
dilation = (1, dilation)
|
||||
get_shape = P.Shape()
|
||||
get_dtype = P.DType()
|
||||
if isinstance(weight_init, Tensor):
|
||||
weight_init_shape = get_shape(weight_init)
|
||||
Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
|
||||
weight_init_dtype = get_dtype(weight_init)
|
||||
weight_init_value = weight_init.asnumpy()
|
||||
weight_init_value = np.expand_dims(weight_init_value, 2)
|
||||
weight_init = Tensor(weight_init_value, weight_init_dtype)
|
||||
|
||||
super(Conv1d, self).__init__(
|
||||
in_channels,
|
||||
|
@ -391,13 +408,13 @@ class Conv1d(_Conv):
|
|||
|
||||
def construct(self, x):
|
||||
x_shape = self.shape(x)
|
||||
if len(x_shape) == 3:
|
||||
x = self.expand_dims(x, 2)
|
||||
_check_input_3d(x_shape)
|
||||
x = self.expand_dims(x, 2)
|
||||
output = self.conv2d(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
if len(x_shape) == 3:
|
||||
output = self.squeeze(output)
|
||||
|
||||
output = self.squeeze(output)
|
||||
return output
|
||||
|
||||
def extend_repr(self):
|
||||
|
@ -669,6 +686,15 @@ class Conv1dTranspose(_Conv):
|
|||
kernel_size = (1, kernel_size)
|
||||
stride = (1, stride)
|
||||
dilation = (1, dilation)
|
||||
get_shape = P.Shape()
|
||||
get_dtype = P.DType()
|
||||
if isinstance(weight_init, Tensor):
|
||||
weight_init_shape = get_shape(weight_init)
|
||||
Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
|
||||
weight_init_dtype = get_dtype(weight_init)
|
||||
weight_init_value = weight_init.asnumpy()
|
||||
weight_init_value = np.expand_dims(weight_init_value, 2)
|
||||
weight_init = Tensor(weight_init_value, weight_init_dtype)
|
||||
# out_channels and in_channels swap.
|
||||
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
|
||||
# then Conv1dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
|
||||
|
@ -733,8 +759,8 @@ class Conv1dTranspose(_Conv):
|
|||
|
||||
def construct(self, x):
|
||||
x_shape = self.shape(x)
|
||||
if len(x_shape) == 3:
|
||||
x = self.expand_dims(x, 2)
|
||||
_check_input_3d(x_shape)
|
||||
x = self.expand_dims(x, 2)
|
||||
|
||||
n, _, h, w = self.shape(x)
|
||||
|
||||
|
@ -746,8 +772,7 @@ class Conv1dTranspose(_Conv):
|
|||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
|
||||
if len(x_shape) == 3:
|
||||
output = self.squeeze(output)
|
||||
output = self.squeeze(output)
|
||||
return output
|
||||
|
||||
def extend_repr(self):
|
||||
|
|
|
@ -1690,7 +1690,9 @@ class L2Loss(PrimitiveWithInfer):
|
|||
Set `input_x` as x and output as loss.
|
||||
|
||||
.. math::
|
||||
loss = sum(x ** 2) / 2
|
||||
loss = sum(x ** 2) / nelement(x)
|
||||
|
||||
:math:`nelement(x)` represents the number of `input_x`.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - A input Tensor.
|
||||
|
|
Loading…
Reference in New Issue