!44543 add nn and function api for dropout1d

Merge pull request !44543 from ZhidanLiu/master
This commit is contained in:
i-robot 2022-10-27 08:00:13 +00:00 committed by Gitee
commit e720b5e117
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 335 additions and 3 deletions

View File

@ -156,6 +156,7 @@ Dropout层
:template: classtemplate.rst
mindspore.nn.Dropout
mindspore.nn.Dropout1d
mindspore.nn.Dropout2d
mindspore.nn.Dropout3d

View File

@ -28,6 +28,7 @@ mindspore.ops.function
mindspore.ops.crop_and_resize
mindspore.ops.deformable_conv2d
mindspore.ops.dropout
mindspore.ops.dropout1d
mindspore.ops.dropout2d
mindspore.ops.dropout3d
mindspore.ops.flatten

View File

@ -0,0 +1,29 @@
mindspore.nn.Dropout1d
========================
.. py:class:: mindspore.nn.Dropout1d(p=0.5)
在训练期间,以服从伯努利分布的概率 `p` 随机将输入Tensor的某些通道归零。对于shape为 `NCL` 的三维Tensor
其通道特征图指的是后一维 `L` 的一维特征图)。
例如,在批处理输入中 :math:`i\_th` 批, :math:`j\_th` 通道的 `input[i, j]` `1D` Tensor 是一个待处理数据。
每个通道将会独立依据伯努利分布概率 `p` 来确定是否被清零。
论文 `Dropout: A Simple Way to Prevent Neural Networks from Overfitting <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ 中提出了该技术,并证明其能有效地减少过度拟合,防止神经元共适应。更多详细信息,请参见 `Improving neural networks by preventing co-adaptation of feature detectors <https://arxiv.org/pdf/1207.0580.pdf>`_
`dropout1d` 可以提高通道特征映射之间的独立性。
参数:
- **p** (float) - 通道的丢弃概率,介于 0 和 1 之间,例如 `p` = 0.8意味着80%的清零概率。默认值0.5。
输入:
- **x** (Tensor) - 一个shape为 :math:`(N, C, L)`:math:`(C, L)``3D``2D` Tensor其中N是批处理大小`C` 是通道数,`L` 是特征长度。其数据类型应为int8、int16、int32、int64、float16、float32或float64。
输出:
Tensor输出具有与输入 `x` 相同的shape和数据类型。
异常:
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `x` 的数据类型不是int8、int16、int32、int64、float16、float32或float64。
- **TypeError** - `p` 的数据类型不是float。
- **ValueError** - `p` 值不在 `[0.01.0]` 之间。
- **ValueError** - `x` 的维度不是 `2D``3D`

View File

@ -0,0 +1,28 @@
mindspore.ops.dropout1d
========================
.. py:function:: mindspore.ops.dropout1d(x, p=0.5)
在训练期间,以服从伯努利分布的概率 `p` 随机将输入Tensor的某些通道归零。对于shape为 `NCL` 的三维Tensor
其通道特征图指的是后一维 `L` 的一维特征图)。
例如,在批处理输入中 :math:`i\_th` 批, :math:`j\_th` 通道的 `input[i, j]` `1D` Tensor 是一个待处理数据。
每个通道将会独立依据伯努利分布概率 `p` 来确定是否被清零。
论文 `Dropout: A Simple Way to Prevent Neural Networks from Overfitting <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ 中提出了该技术,并证明其能有效地减少过度拟合,防止神经元共适应。更多详细信息,请参见 `Improving neural networks by preventing co-adaptation of feature detectors <https://arxiv.org/pdf/1207.0580.pdf>`_
`dropout1d` 可以提高通道特征映射之间的独立性。
参数:
- **x** (Tensor) - 一个shape为 :math:`(N, C, L)`:math:`(C, L)``3D``2D` Tensor其中N是批处理大小`C` 是通道数,`L` 是特征长度。其数据类型应为int8、int16、int32、int64、float16、float32或float64。
- **p** (float) - 通道的丢弃概率,介于 0 和 1 之间,例如 `p` = 0.8意味着80%的清零概率。默认值0.5。
- **training** (bool) - 若为True则启用dropout功能。默认值True。
返回:
Tensor输出具有与输入 `x` 相同的shape和数据类型。
异常:
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `x` 的数据类型不是int8、int16、int32、int64、float16、float32或float64。
- **TypeError** - `p` 的数据类型不是float。
- **ValueError** - `p` 值不在 `[0.01.0]` 之间。
- **ValueError** - `x` 的维度不是 `2D``3D`

View File

@ -156,6 +156,7 @@ Dropout Layer
:template: classtemplate.rst
mindspore.nn.Dropout
mindspore.nn.Dropout1d
mindspore.nn.Dropout2d
mindspore.nn.Dropout3d

View File

@ -28,6 +28,7 @@ Neural Network
mindspore.ops.crop_and_resize
mindspore.ops.deformable_conv2d
mindspore.ops.dropout
mindspore.ops.dropout1d
mindspore.ops.dropout2d
mindspore.ops.dropout3d
mindspore.ops.flatten

View File

@ -36,8 +36,8 @@ from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu',
'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Dropout2d',
'Dropout3d', 'Roll', 'Identity', 'Unflatten']
'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Dropout1d',
'Dropout2d', 'Dropout3d', 'Roll', 'Identity', 'Unflatten']
class L1Regularizer(Cell):
@ -182,6 +182,84 @@ class Dropout(Cell):
return 'keep_prob={}'.format(self.keep_prob)
class Dropout1d(Cell):
r"""
During training, randomly zeroes entire channels of the input tensor with probability `p`
from a Bernoulli distribution (For a 3-dimensional tensor with a shape of :math:`NCL`,
the channel feature map refers to a 1-dimensional feature map with the shape of :math:`L`).
For example, the :math:`j\_th` channel of the :math:`i\_th` sample in the batched input is a to-be-processed
`1D` tensor input[i,j].
Each channel will be zeroed out independently on every forward call with probability `p` using samples
from a Bernoulli distribution.
The parper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
<http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ mentioned this technologyAnd it is proved that
it can effectively reduce over fitting and prevent neuronal coadaptation.
For more details, refer to `Improving neural networks by preventing co-adaptation of feature detectors
<https://arxiv.org/pdf/1207.0580.pdf>`_ .
`Dropout1d` can improve the independence between channel feature maps.
Args:
p (float): The dropping probability of a channel, between 0 and 1, e.g. `p` = 0.8,
which means an 80% chance of clearing. Default: 0.5.
Inputs:
- **x** (Tensor) - A tensor with shape :math:`(N, C, L)` or :math:`(C, L)`, where `N` is the batch size,
`C` is the number of channels, `L` is the feature length. The data type must be int8, int16, int32,
int64, float16, float32 or float64.
Returns:
Tensor, output, with the same shape and data type as `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not int8, int16, int32, int64, float16, float32 or float64.
TypeError: If the data type of `p` is not float.
ValueError: If `p` is out of the range `[0.0, 1.0]`.
ValueError: If `x` shape is not `2D` or `3D`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.random.randn(4, 3), mindspore.float32)
>>> output = dropout1d(input_x, 0.5)
>>> print(output.shape)
(4, 3)
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> dropout = nn.Dropout1d(p=0.5)
>>> x = Tensor(np.ones([4, 3]), mindspore.float32)
>>> output = dropout(x)
>>> print(output.shape)
(4, 3)
"""
def __init__(self, p=0.5):
"""Initialize Dropout1d."""
super(Dropout1d, self).__init__()
Validator.check_value_type('p', p, [float], self.cls_name)
if p < 0 or p > 1:
raise ValueError(f"For '{self.cls_name}', the 'p' must be a number in range [0, 1], "
f"but got {p}.")
self.prob = p
def construct(self, x):
if not self.training:
return x
if self.prob == 0:
return x
out = F.dropout1d(x, self.prob)
return out
class Dropout2d(Cell):
r"""
During training, randomly zeroes some channels of the input tensor with probability `p`
@ -1678,5 +1756,5 @@ class Unflatten(Cell):
new_shape += input_shape[: self.axis]
new_shape += self.unflattened_size
if self.axis != -1:
new_shape += input_shape[self.axis+1 :]
new_shape += input_shape[self.axis + 1:]
return self.reshape(input_x, new_shape)

View File

@ -301,6 +301,7 @@ from .nn_func import (
bias_add,
binary_cross_entropy,
binary_cross_entropy_with_logits,
dropout1d,
dropout2d,
dropout3d,
deformable_conv2d,

View File

@ -740,6 +740,69 @@ def celu(x, alpha=1.0):
return celu_op(x)
def dropout1d(x, p=0.5, training=True):
r"""
During training, randomly zeroes some channels of the input tensor with probability `p`
from a Bernoulli distribution(For a 3-dimensional tensor with a shape of :math:`NCL`,
the channel feature map refers to a 1-dimensional feature map with the shape of :math:`L`).
For example, the :math:`j\_th` channel of the :math:`i\_th` sample in the batched input is a to-be-processed
`1D` tensor input[i,j].
Each channel will be zeroed out independently on every forward call which based on Bernoulli distribution
probability `p`.
The parper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
<http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ mentioned this technologyAnd it is proved that
it can effectively reduce over fitting and prevent neuronal coadaptation.
For more details, refer to `Improving neural networks by preventing co-adaptation of feature detectors
<https://arxiv.org/pdf/1207.0580.pdf>`_ .
`dropout1d` can improve the independence between channel feature maps.
Args:
x (Tensor): A tensor with shape :math:`(N, C, L)` or :math:`(C, L)`, where `N` is the batch size, `C` is the
number of channels, `L` is the feature length. The data type must be int8, int16, int32, int64, float16,
float32 or float64.
p (float): The dropping probability of a channel, between 0 and 1, e.g. `p` = 0.8,
which means an 80% chance of clearing. Default: 0.5.
training (bool): Apply dropout if is True. Default: True.
Returns:
Tensor, output, with the same shape and data type as `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not int8, int16, int32, int64, float16, float32 or float64.
TypeError: If the data type of `p` is not float.
ValueError: If `p` is out of the range `[0.0, 1.0]`.
ValueError: If `x` shape is not `2D` or `3D`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.random.randn(4, 3), mindspore.float32)
>>> output = dropout1d(input_x, 0.5)
>>> print(output.shape)
(4, 3)
"""
if not training:
p = 0
dropout_2d_op = NN_OPS.Dropout2D(1.0 - p)
if len(x.shape) == 2:
x = x.expand_dims(0)
x = x.expand_dims(-1)
out, _ = dropout_2d_op(x)
out = out.squeeze(-1)
out = out.squeeze(0)
else:
x = x.expand_dims(-1)
out, _ = dropout_2d_op(x)
out = out.squeeze(-1)
return out
def dropout2d(x, p=0.5):
r"""
During training, randomly zeroes some channels of the input tensor with probability `p`
@ -3707,6 +3770,7 @@ __all__ = [
'kl_div',
'celu',
'deformable_conv2d',
'dropout1d',
'dropout2d',
'dropout3d',
'fast_gelu',

View File

@ -0,0 +1,86 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
import mindspore
from mindspore import nn
import mindspore.ops as ops
import mindspore.context as context
class Net(nn.Cell):
"""Net used to test nn.Dropout1d"""
def __init__(self, p):
super(Net, self).__init__()
self.dropout1d = nn.Dropout1d(p)
def construct(self, x):
return self.dropout1d(x)
class FNet(nn.Cell):
"""Net used to test ops.dropout1d"""
def __init__(self, p):
super(FNet, self).__init__()
self.p = p
def construct(self, x):
out = ops.dropout1d(x, self.p)
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_dropout1d(mode):
"""
Feature: dropout1d
Description: Verify the result of Dropout1d
Expectation: success
"""
context.set_context(mode=mode)
x = np.random.randn(4, 3)
dropout = Net(p=1.0)
x = Tensor(x, mindspore.float32)
dropout.set_train()
output = dropout(x)
expect = np.zeros((4, 3))
np.allclose(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_f_dropout1d(mode):
"""
Feature: function api dropout1d
Description: Verify the result of dropout1d
Expectation: success
"""
context.set_context(mode=mode)
x = np.random.randn(4, 3)
x = Tensor(x, mindspore.float32)
net = FNet(p=1.0)
output = net(x)
expect = np.zeros((4, 3))
np.allclose(output.asnumpy(), expect)

View File

@ -0,0 +1,42 @@
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test pooling api
"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor
class Dropout1dNet(nn.Cell):
def __init__(self, p):
super(Dropout1dNet, self).__init__()
self.dropout1 = nn.Dropout1d(p)
def construct(self, x):
return self.dropout1(x)
def test_dropout1_normal():
"""
Feature: dropout1d
Description: Verify the result of Dropout1d
Expectation: success
"""
x = Tensor(np.random.randn(4, 3).astype(np.float32))
net = Dropout1dNet(p=0.5)
_cell_graph_executor.compile(net, x)