forked from mindspore-Ecosystem/mindspore
change_dropout_keep_prob_to_p_master
This commit is contained in:
parent 7d31b2b82d
commit 3001c0344c
@@ -1,21 +1,21 @@
 mindspore.nn.Dropout
 ====================

-.. py:class:: mindspore.nn.Dropout(keep_prob=0.5, dtype=mstype.float32)
+.. py:class:: mindspore.nn.Dropout(keep_prob=0.5, p=None)

 Dropout layer.

-Dropout is a regularization technique. During training, the operator randomly sets some neuron outputs to 0 according to the dropout probability :math:`1 - keep\_prob`, which reduces overfitting by preventing correlation between neuron nodes. During inference, this layer returns the same Tensor as `x`.
+Dropout is a regularization technique. During training, the operator randomly sets some neuron outputs to 0 according to the dropout probability `p`, which reduces overfitting by preventing correlation between neuron nodes. During inference, this layer returns the same Tensor as `x`.

 The technique was proposed in the paper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ and was shown to effectively reduce overfitting and prevent co-adaptation of neurons. For more details, see `Improving neural networks by preventing co-adaptation of feature detectors <https://arxiv.org/pdf/1207.0580.pdf>`_.

 .. note::
-    Each channel (or neuron) is dropped independently at every training step.
-    The `dtype` parameter will be removed in a future version. Using this parameter is not recommended.
+    - Each channel (or neuron) is dropped independently at every training step.
+    - The `keep_prob` parameter will be removed in a future version; please use the `p` parameter instead. `p` is the probability that an element of the input Tensor is set to 0.

 Parameters:
-    - **keep_prob** (float) - The keep rate of input neurons, in the range 0 to 1. For example, rate=0.9 drops 10% of the neurons. Default: 0.5.
-    - **dtype** (:class:`mindspore.dtype`) - Data type of `x`. Default: mstype.float32.
+    - **keep_prob** (float) - Deprecated. The keep rate of input neurons, in the range (0, 1]. For example, `keep_prob` =0.9 drops 10% of the neurons. Default: 0.5.
+    - **p** (Union(float, int, None)) - The drop rate of input neurons, in the range [0, 1). For example, `p` =0.9 drops 90% of the neurons. Default: None.

 Inputs:
     - **x** (Tensor) - The input of Dropout, a Tensor of arbitrary dimensions. The data type must be float16 or float32.
@@ -25,7 +25,8 @@ mindspore.nn.Dropout

 Raises:
-    - **TypeError** - `keep_prob` is not a float.
+    - **TypeError** - The data type of `p` is neither float nor int.
     - **TypeError** - The dtype of `x` is neither float16 nor float32.
-    - **ValueError** - `keep_prob` is not within the range (0, 1].
+    - **ValueError** - `keep_prob` is not in the range (0, 1].
+    - **ValueError** - `p` is not in the range [0, 1).
     - **ValueError** - The length of the shape of `x` is less than 1.
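Taken together, the documentation change above boils down to one migration rule: the new `p` argument is the drop probability, so `p = 1 - keep_prob`. A minimal usage sketch of the two styles (the 0.3/0.7 rates and the input shape are illustrative values, not taken from this diff):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    x = Tensor(np.ones([2, 2, 3]), ms.float32)

    # Old style, now deprecated: keep 70% of the activations.
    old_net = nn.Dropout(keep_prob=0.7)

    # New style: drop 30% of the activations, i.e. p = 1 - keep_prob.
    new_net = nn.Dropout(p=0.3)

    new_net.set_train()      # dropout only takes effect in training mode
    print(new_net(x).shape)  # (2, 2, 3)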
@@ -35,7 +35,7 @@ class NiN(nn.Cell):
     nn.Conv2d(in_channels=160, out_channels=96, kernel_size=1, stride=1, has_bias=True),
     nn.ReLU(),
     nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same'),
-    nn.Dropout(1.0)
+    nn.Dropout(p=0.0)
 )
 self.block1 = nn.SequentialCell(
     # block 1
@@ -46,7 +46,7 @@ class NiN(nn.Cell):
     nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1, stride=1, has_bias=True),
     nn.ReLU(),
     nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same'),
-    nn.Dropout(1.0)
+    nn.Dropout(p=0.0)
 )
 self.block2 = nn.SequentialCell(
     # block 2
@@ -46,7 +46,7 @@ class AlexNet(nn.Cell):
 self.fc1 = fc_with_initialize(20*3*3, 1024)
 self.fc2 = fc_with_initialize(1024, 1024)
 self.fc3 = fc_with_initialize(1024, num_classes)
-self.dropout = nn.Dropout(dropout_ratio)
+self.dropout = nn.Dropout(p=1-dropout_ratio)

 def construct(self, x):
     """define network"""
@@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=Fa

 n = Xception(num_classes=1000)
-n.dropout = nn.Dropout(keep_prob=1.0)
+n.dropout = nn.Dropout(p=0.0)

 loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
 optimizer = nn.SGD(n.trainable_params(), learning_rate=0.01, momentum=0.9, dampening=0.0, weight_decay=0.0,
@@ -20,7 +20,8 @@ import math
 import numpy as np

 import mindspore.common.dtype as mstype
-from mindspore import context, log as logger
+from mindspore import context
+from mindspore.log import logging
 from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
 from mindspore.common.seed import _get_graph_seed
 from mindspore.common.tensor import Tensor
@@ -107,7 +108,7 @@ class Dropout(Cell):
 r"""
 Dropout layer for the input.

-Randomly set some elements of the input tensor to zero with probability :math:`1 - keep\_prob` during training
+Randomly set some elements of the input tensor to zero with probability `p` during training
 using samples from a Bernoulli distribution.

 The outputs are scaled by a factor of :math:`\frac{1}{keep\_prob}` during training so
@@ -121,13 +122,15 @@ class Dropout(Cell):
 <https://arxiv.org/pdf/1207.0580.pdf>`_.

 Note:
-    Each channel will be zeroed out independently on every construct call.
-    Parameter `dtype` will be removed in a future version. It is not recommended to define this parameter.
+    - Each channel will be zeroed out independently on every construct call.
+    - Parameter `keep_prob` will be removed in a future version, please use parameter `p` instead.
+      Parameter `p` means the probability of the element of the input tensor to be zeroed.

 Args:
-    keep_prob (float): The keep rate, greater than 0 and less than or equal to 1. E.g. rate=0.9,
-        dropping out 10% of input units. Default: 0.5.
-    dtype (:class:`mindspore.dtype`): Data type of `x`. Default: mindspore.float32.
+    keep_prob (float): Deprecated. The keep rate, greater than 0 and less than or equal to 1.
+        E.g. rate=0.9, dropping out 10% of input neurons. Default: 0.5.
+    p (Union(float, int, None)): The dropout rate, greater than or equal to 0 and less than 1.
+        E.g. rate=0.9, dropping out 90% of input neurons. Default: None.

 Inputs:
     - **x** (Tensor) - The input of Dropout with data type of float16 or float32.
@@ -138,8 +141,10 @@ class Dropout(Cell):

 Raises:
     TypeError: If `keep_prob` is not a float.
+    TypeError: If the dtype of `p` is not float or int.
     TypeError: If dtype of `x` is neither float16 nor float32.
     ValueError: If `keep_prob` is not in range (0, 1].
+    ValueError: If `p` is not in range [0, 1).
     ValueError: If length of shape of `x` is less than 1.

 Supported Platforms:
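As a quick illustration of the Raises section above, out-of-range or wrongly typed values of `p` are rejected at construction time (a hypothetical check, assuming a build that already contains this change):

    import mindspore.nn as nn

    try:
        nn.Dropout(p=1.0)    # p must lie in [0, 1), so 1.0 is rejected
    except ValueError as err:
        print(err)

    try:
        nn.Dropout(p="0.5")  # p must be a float or an int
    except TypeError as err:
        print(err)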
@@ -147,45 +152,46 @@ class Dropout(Cell):

     Examples:
         >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
-        >>> net = nn.Dropout(keep_prob=0.8)
+        >>> net = nn.Dropout(p=0.2)
         >>> net.set_train()
-        Dropout<keep_prob=0.8>
         >>> output = net(x)
         >>> print(output.shape)
         (2, 2, 3)
     """

-    def __init__(self, keep_prob=0.5, dtype=mstype.float32):
+    def __init__(self, keep_prob=0.5, p=None):
         """Initialize Dropout."""
         super(Dropout, self).__init__()
-        Validator.check_value_type('keep_prob', keep_prob, [
-            float], self.cls_name)
-        if keep_prob <= 0 or keep_prob > 1:
-            raise ValueError(f"For '{self.cls_name}', the 'keep_prob' must be a number in range (0, 1], "
-                             f"but got {keep_prob}.")
-        Validator.check_subclass(
-            "dtype", dtype, mstype.number_type, self.cls_name)
-        if dtype != mstype.float32:
-            logger.info(
-                "This parameter `dtype` will be deleted or invisible in the future. Please don't use it.")
+        if p is None:
+            logging.warning("This parameter `keep_prob` will be deprecated, please use `p` instead.")
+            Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
+            if keep_prob <= 0 or keep_prob > 1:
+                raise ValueError(f"For '{self.cls_name}', the 'keep_prob' must be a number in range (0, 1], "
+                                 f"but got {keep_prob}.")
+            seed0, seed1 = _get_graph_seed(0, "dropout")
+            self.dropout = P.Dropout(keep_prob, seed0, seed1)
+        else:
+            Validator.check_value_type('p', p, [float, int], self.cls_name)
+            if p < 0 or p >= 1:
+                raise ValueError(f"For '{self.cls_name}', the 'p' must be a number in range [0, 1), "
+                                 f"but got {p}.")
+            seed0, seed1 = _get_graph_seed(0, "dropout")
+            self.dropout = P.Dropout(1.0 - p, seed0, seed1)
+        self.p = p
         self.keep_prob = keep_prob
-        seed0, seed1 = _get_graph_seed(0, "dropout")
         self.seed0 = seed0
         self.seed1 = seed1
-        self.dropout = P.Dropout(keep_prob, seed0, seed1)

     def construct(self, x):
-        if not self.training:
-            return x
-
-        if self.keep_prob == 1:
+        if not self.training or self.keep_prob == 1 or self.p == 0:
             return x

         out, _ = self.dropout(x)
         return out

     def extend_repr(self):
-        return 'keep_prob={}'.format(self.keep_prob)
+        if self.p is None:
+            logging.warning("This parameter `keep_prob` will be deprecated, please use `p` instead.")
+            return f'keep_prob={self.keep_prob}'
+        return f'p={self.p}'


 class Dropout1d(Cell):
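A small sketch of how the reworked construct above behaves at its two short-circuit points, eval mode and p == 0, versus ordinary training (illustrative only; the zero pattern is random and seed dependent):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    x = Tensor(np.ones([4, 4]), ms.float32)

    drop = nn.Dropout(p=0.5)
    print((drop(x).asnumpy() == 1.0).all())      # True: not in training mode, the input passes through

    drop.set_train()
    print(drop(x).asnumpy())                     # roughly half the entries zeroed, survivors scaled by 1 / (1 - p) = 2.0

    identity = nn.Dropout(p=0.0)
    identity.set_train()
    print((identity(x).asnumpy() == 1.0).all())  # True: p == 0 returns the input unchanged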
@@ -408,7 +408,7 @@ class _RNNBase(Cell):
 self.batch_first = batch_first
 self.num_layers = num_layers
 self.dropout = dropout
-self.dropout_op = nn.Dropout(float(1 - dropout))
+self.dropout_op = nn.Dropout(p=float(dropout))
 self.bidirectional = bidirectional
 self.has_bias = has_bias
 num_directions = 2 if bidirectional else 1
@@ -273,14 +273,14 @@ class TransformerEncoderLayer(Cell):
 self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
 # Implementation of Feedforward model
 self.linear1 = _Linear(d_model, dim_feedforward)
-self.dropout = Dropout(1-dropout)
+self.dropout = Dropout(p=dropout)
 self.linear2 = _Linear(dim_feedforward, d_model)

 self.norm_first = norm_first
 self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
 self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
-self.dropout1 = Dropout(1-dropout)
-self.dropout2 = Dropout(1-dropout)
+self.dropout1 = Dropout(p=dropout)
+self.dropout2 = Dropout(p=dropout)

 # Legacy string support for activation function.
 if isinstance(activation, str):
@@ -380,16 +380,16 @@ class TransformerDecoderLayer(Cell):
 self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
 # Implementation of Feedforward model
 self.linear1 = _Linear(d_model, dim_feedforward)
-self.dropout = Dropout(1-dropout)
+self.dropout = Dropout(p=dropout)
 self.linear2 = _Linear(dim_feedforward, d_model)

 self.norm_first = norm_first
 self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
 self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
 self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps)
-self.dropout1 = Dropout(1-dropout)
-self.dropout2 = Dropout(1-dropout)
-self.dropout3 = Dropout(1-dropout)
+self.dropout1 = Dropout(p=dropout)
+self.dropout2 = Dropout(p=dropout)
+self.dropout3 = Dropout(p=dropout)

 # Legacy string support for activation function.
 if isinstance(activation, str):
@@ -496,9 +496,9 @@ class FeedForward(Cell):
 else:
     self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)))
     self.projection.bias.parallel_optimizer = False
-    self.dropout = nn.Dropout(1 - dropout_rate)
-    self.dropout_3d = nn.Dropout(1 - dropout_rate)
-    self.dropout_4d = nn.Dropout(1 - dropout_rate)
+    self.dropout = nn.Dropout(p=dropout_rate)
+    self.dropout_3d = nn.Dropout(p=dropout_rate)
+    self.dropout_4d = nn.Dropout(p=dropout_rate)
     self.cast = P.Cast()
 else:
     _check_config(parallel_config)
@@ -556,11 +556,11 @@ class FeedForward(Cell):
 self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
                       strategy_bias=((dp, 1), (1,)))
 self.projection.bias.parallel_optimizer = False
-self.dropout = nn.Dropout(1 - dropout_rate)
+self.dropout = nn.Dropout(p=dropout_rate)
 self.dropout.dropout.shard(((dp, 1),))
-self.dropout_3d = nn.Dropout(1 - dropout_rate)
+self.dropout_3d = nn.Dropout(p=dropout_rate)
 self.dropout_3d.dropout.shard(((dp, 1, 1),))
-self.dropout_4d = nn.Dropout(1 - dropout_rate)
+self.dropout_4d = nn.Dropout(p=dropout_rate)
 self.dropout_4d.dropout.shard(((dp, ep, 1, 1),))
 self.cast = P.Cast()
@@ -950,8 +950,8 @@ class MultiHeadAttention(Cell):
 # Normalize factor for attention, sqrt(dk) as widely used
 self.scale_factor = Tensor(math.sqrt(math.sqrt(self.size_per_head)))
 self.use_past = use_past
-self.dropout = nn.Dropout(1 - hidden_dropout_rate)
-self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
+self.dropout = nn.Dropout(p=hidden_dropout_rate)
+self.prob_dropout = nn.Dropout(p=attention_dropout_rate)
 self.softmax = nn.Softmax().to_float(softmax_compute_type)
 self.softmax_3d = nn.Softmax().to_float(softmax_compute_type)
 self.expand_dims = P.ExpandDims()
@@ -1051,9 +1051,9 @@ class MultiHeadAttention(Cell):
 # Normalize factor for attention, sqrt(dk) as widely used
 self.scale_factor = Tensor(math.sqrt(math.sqrt(self.size_per_head)))
 self.use_past = use_past
-self.dropout = nn.Dropout(1 - hidden_dropout_rate)
+self.dropout = nn.Dropout(p=hidden_dropout_rate)
 self.dropout.dropout.shard(((parallel_config.data_parallel, 1),))
-self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
+self.prob_dropout = nn.Dropout(p=attention_dropout_rate)
 self.prob_dropout.dropout.shard(
     ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
 self.softmax = nn.Softmax().to_float(softmax_compute_type)
@@ -251,7 +251,7 @@ class BertAttentionSoftmax(nn.Cell):
 self.weight = TruncatedNormal(initializer_range)

 self.softmax = nn.Softmax()
-self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
+self.dropout = nn.Dropout(p=attention_probs_dropout_prob)
 self.transpose = P.Transpose()

 self.value_layer = nn.Dense(self.to_tensor_width,
@@ -70,7 +70,7 @@ class DNN(nn.Cell):
 dense_layer = nn.Dense(in_channels=self.hidden_units[i], out_channels=self.hidden_units[i + 1],
                        activation=self.activation, weight_init="heUniform")
 dense_layers.append(dense_layer)
-drop_layer = nn.Dropout(1.0 - self.dropout_rate)
+drop_layer = nn.Dropout(p=self.dropout_rate)
 drop_layers.append(drop_layer)
 self.dense_layers = nn.CellList(dense_layers)
 self.drop_layers = nn.CellList(drop_layers)
@@ -91,7 +91,7 @@ class MultiheadAttention(nn.Cell):
 self.matmul = P.BatchMatMul()

 self.softmax = nn.Softmax()
-self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
+self.dropout = nn.Dropout(p=attention_probs_dropout_prob)
 self.sub = P.Sub()
 self.add = P.TensorAdd()
 self.cast = P.Cast()
@@ -192,7 +192,7 @@ class ResidualNorm(nn.Cell):

 def __init__(self, size, dropout_prob=0.1):
     super(ResidualNorm, self).__init__()
-    self.dropout = nn.Dropout(1 - dropout_prob)
+    self.dropout = nn.Dropout(p=dropout_prob)
     self.add = P.TensorAdd()
     self.layernorm = nn.LayerNorm([size])
     self.out_shape = (-1, size)
@@ -213,7 +213,7 @@ class FeedForward(nn.Cell):
 def __init__(self, attention_size, intermediate_size,
              hidden_act, hidden_dropout_prob):
     super(FeedForward, self).__init__()
-    self.dropout = nn.Dropout(1 - hidden_dropout_prob)
+    self.dropout = nn.Dropout(p=hidden_dropout_prob)
     self.linear1 = CustomDense(in_channels=attention_size,
                                out_channels=intermediate_size,
                                activation=hidden_act,
@@ -303,7 +303,7 @@ class EncoderCell(nn.Cell):
     has_attention_mask=has_attention_mask,
     compute_type=compute_type)

-self.dropout = nn.Dropout(1 - hidden_dropout_prob)
+self.dropout = nn.Dropout(p=hidden_dropout_prob)
 self.intermediate = CustomDense(in_channels=size, out_channels=intermediate_size,
                                 activation=hidden_act, weight_init="zeros")
 self.res_norm = ResidualNorm(size, dropout_prob=hidden_dropout_prob)
@@ -345,7 +345,7 @@ class PositionalEncoding(nn.Cell):
 super(PositionalEncoding, self).__init__()

 xscale = math.sqrt(dim)
-self.dropout = nn.Dropout(1 - dropout_rate)
+self.dropout = nn.Dropout(p=dropout_rate)
 self.mul = P.Mul()
 self.add = P.TensorAdd()
 self.shape = P.Shape()
@@ -593,7 +593,7 @@ class CTC(nn.Cell):
 self.reshape = P.Reshape()
 self.adim = adim
 self.odim = odim
-self.dropout = nn.Dropout(1 - dropout_prob)
+self.dropout = nn.Dropout(p=dropout_prob)
 self.cast = P.Cast()
 self.not_equal = P.NotEqual()
 self.ignore_id = ignore_id
@@ -708,7 +708,7 @@ class TransformerEncoderLayer(nn.Cell):
 self.feed_forward = feed_forward
 self.norm1 = CustomLayerNorm(size, epsilon=1e-5)
 self.norm2 = CustomLayerNorm(size, epsilon=1e-5)
-self.dropout = nn.Dropout(keep_prob=1 - dropout_rate)
+self.dropout = nn.Dropout(p=dropout_rate)
 self.normalize_before = normalize_before
 self.concat_after = concat_after
 if self.concat_after:
@@ -979,7 +979,7 @@ class DecoderLayer(nn.Cell):
 self.norm1 = CustomLayerNorm(size, epsilon=1e-12)
 self.norm2 = CustomLayerNorm(size, epsilon=1e-12)
 self.norm3 = CustomLayerNorm(size, epsilon=1e-12)
-self.dropout = nn.Dropout(keep_prob=1.0 - dropout_rate)
+self.dropout = nn.Dropout(p=dropout_rate)
 self.normalize_before = normalize_before
 self.concat_after = concat_after
 if self.concat_after:
@@ -1216,7 +1216,7 @@ class PositionwiseFeedForward(nn.Cell):
 super(PositionwiseFeedForward, self).__init__()
 self.w_1 = Dense(idim, hidden_units).to_float(compute_type)
 self.activation = activation
-self.dropout = nn.Dropout(1 - dropout_rate)
+self.dropout = nn.Dropout(p=dropout_rate)
 self.w_2 = Dense(hidden_units, idim).to_float(compute_type)

 def construct(self, xs):
@@ -1318,7 +1318,7 @@ class PositionalEncoding(nn.Cell):
 super().__init__()
 self.d_model = d_model
 self.xscale = Tensor([math.sqrt(self.d_model)], dtype=mstype.float32)
-self.dropout = nn.Dropout(1 - dropout_rate)
+self.dropout = nn.Dropout(p=dropout_rate)
 self.max_len = max_len

 self.pe = np.zeros((self.max_len, self.d_model))
@@ -1399,7 +1399,7 @@ class MultiHeadedAttention(nn.Cell):
 self.linear_k = Dense(n_feat, n_feat).to_float(compute_type)
 self.linear_v = Dense(n_feat, n_feat).to_float(compute_type)
 self.linear_out = Dense(n_feat, n_feat).to_float(compute_type)
-self.dropout = nn.Dropout(keep_prob=1 - dropout_rate)
+self.dropout = nn.Dropout(p=dropout_rate)
 self.softmax = nn.Softmax()

 self.expand_dims = ops.ExpandDims()
@@ -373,7 +373,7 @@ class EmbeddingPostprocessor(nn.Cell):
 self.scores_mul = Tensor([math.sqrt(float(embedding_size))], dtype=ms.float32)
 self.multiply = ops.Mul()
 self.add = ops.Add()
-self.dropout = nn.Dropout(1 - dropout_prob, dtype=ms.float32)
+self.dropout = nn.Dropout(p=dropout_prob)
 self.use_dropout = dropout_prob > 0
 self.expand_dims = ops.ExpandDims()
 self.position_embedding_table = Tensor(position_encoding(max_position_embeddings, embedding_size),
@@ -436,7 +436,7 @@ class LayerPostprocess(nn.Cell):
              dropout_prob=0.1):
 super(LayerPostprocess, self).__init__()
 self.add = ops.Add()
-self.dropout = nn.Dropout(1 - dropout_prob)
+self.dropout = nn.Dropout(p=dropout_prob)
 self.use_dropout = dropout_prob > 0

 def construct(self, hidden_tensor, input_tensor):
@@ -535,7 +535,7 @@ class MultiheadAttention(nn.Cell):
 self.matmul = ops.BatchMatMul()

 self.softmax = nn.Softmax()
-self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
+self.dropout = nn.Dropout(p=attention_probs_dropout_prob)
 self.use_dropout = attention_probs_dropout_prob > 0

 if self.has_attention_mask:
@@ -704,7 +704,7 @@ class FeedForward(nn.Cell):

 self.reshape = ops.Reshape()
 self.shape = (-1, in_channels)
-self.dropout = nn.Dropout(1 - hidden_dropout_prob)
+self.dropout = nn.Dropout(p=hidden_dropout_prob)
 self.use_dropout = hidden_dropout_prob > 0

 def construct(self, input_tensor):
@@ -46,7 +46,7 @@ class MeanConv(nn.Cell):
 self.matmul = P.MatMul()
 self.concat = P.Concat(axis=1)
 self.reduce_mean = P.ReduceMean(keep_dims=False)
-self.dropout = nn.Dropout(keep_prob=1 - dropout)
+self.dropout = nn.Dropout(p=dropout)

 def construct(self, self_feature, neigh_feature):
     neigh_matrix = self.reduce_mean(neigh_feature, 1)
@@ -72,7 +72,7 @@ class AttenConv(nn.Cell):
 self.matmul = P.MatMul()
 self.matmul_3 = P.BatchMatMul()
 self.matmul_t = P.BatchMatMul(transpose_b=True)
-self.dropout = nn.Dropout(keep_prob=1 - dropout)
+self.dropout = nn.Dropout(p=dropout)

 def construct(self, self_feature, neigh_feature):
     query = self.expanddims(self_feature, 1)
@@ -23,7 +23,7 @@ from mindspore.ops.composite import GradOperation
 class DropoutNet(nn.Cell):
     def __init__(self, keep_prob):
         super(DropoutNet, self).__init__()
-        self.drop = nn.Dropout(keep_prob)
+        self.drop = nn.Dropout(p=1.0 - keep_prob)
         self.relu = ops.ReLU()

     def construct(self, x):
@@ -160,7 +160,7 @@ class _BaseAggregator(nn.Cell):
     has_bias=self.has_bias)
 self.dropout_ratio = dropout_ratio
 if self.dropout_ratio is not None:
-    self.dropout = nn.Dropout(keep_prob=self.dropout_ratio)
+    self.dropout = nn.Dropout(p=1.0 - self.dropout_ratio)
 self.dropout_flag = self.dropout_ratio is not None
 self.activation = get_activation(activation)
 self.activation_flag = self.activation is not None
@@ -263,8 +263,8 @@ class AttentionHead(nn.Cell):
 self.in_channel = Validator.check_positive_int(in_channel)
 self.out_channel = Validator.check_positive_int(out_channel)
 self.in_drop_ratio = in_drop_ratio
-self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
-self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
+self.in_drop = nn.Dropout(p=in_drop_ratio)
+self.in_drop_2 = nn.Dropout(p=in_drop_ratio)
 self.feature_transform = GNNFeatureTransform(
     in_channels=self.in_channel,
     out_channels=self.out_channel,
@@ -278,7 +278,7 @@ class AttentionHead(nn.Cell):
     out_channels=1)
 self.softmax = nn.Softmax()

-self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio)
+self.coef_drop = nn.Dropout(p=coef_drop_ratio)
 self.batch_matmul = P.BatchMatMul()
 self.bias_add = P.BiasAdd()
 self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
@@ -23,7 +23,6 @@ import mindspore.context as context

 from gnngraph_dataset import GraphDataset, GatherNet, CSRReduceSumNet
-

 DATASET_PATH = "/home/workspace/mindspore_dataset/cora/cora_mr/cora_v2_with_mask.npz"
 FEAT_DROPOUT = 0.5
 EDGE_DROPOUT = 0.5
@@ -47,7 +46,7 @@ class APPNPConv(ms.nn.Cell):
 super().__init__()
 self.k_ = k
 self.alpha_ = alpha
-self.edge_drop = ms.nn.Dropout(edge_drop)
+self.edge_drop = ms.nn.Dropout(p=1.0 - edge_drop)
 self.min_clip = Tensor(1, ms.int32)
 self.max_clip = Tensor(10000000, ms.int32)
 self.gather = GatherNet(indptr_backward, indices_backward)
@@ -86,7 +85,7 @@ class APPNPNet(nn.Cell):
 self.fc0 = nn.Dense(in_feats, hidden_dim, weight_init=XavierUniform())
 self.fc1 = nn.Dense(hidden_dim, n_classes, weight_init=XavierUniform())
 self.act = activation()
-self.feat_drop = nn.Dropout(feat_dropout)
+self.feat_drop = nn.Dropout(p=1.0 - feat_dropout)
 self.propagate = APPNPConv(k, alpha, edge_dropout, indptr_backward, indices_backward)

 def construct(self, x, in_deg, out_deg, n_nodes, indptr, indices):
@@ -66,8 +66,8 @@ class GATConv(ms.nn.Cell):
 self.attn_d = ms.Parameter(initializer(XavierUniform(gain), [num_attn_head, out_size], ms.float32),
                            name="attn_d")
 self.bias = ms.Parameter(initializer('zero', [num_attn_head, out_size], ms.float32), name='bias')
-self.feat_drop = ms.nn.Dropout(input_drop_out_rate)
-self.attn_drop = ms.nn.Dropout(attn_drop_out_rate)
+self.feat_drop = ms.nn.Dropout(p=1.0 - input_drop_out_rate)
+self.attn_drop = ms.nn.Dropout(p=1.0 - attn_drop_out_rate)
 self.leaky_relu = ms.nn.LeakyReLU(leaky_relu_slope)
 self.exp = ms.ops.Exp()
 if add_norm:
@@ -51,7 +51,7 @@ class GCNConv(ms.nn.Cell):
 self.activation = activation
 self.min_clip = Tensor(1, ms.int32)
 self.max_clip = Tensor(100000000, ms.int32)
-self.drop_out = ms.nn.Dropout(dropout)
+self.drop_out = ms.nn.Dropout(p=1.0 - dropout)
 self.gather = GatherNet(indptr_backward, indices_backward)
 self.csr_reduce_sum = CSRReduceSumNet(indices_backward)
@@ -137,7 +137,7 @@ class DenseLayer(nn.Cell):
 self.matmul = P.MatMul(transpose_b=False)
 self.bias_add = P.BiasAdd()
 self.cast = P.Cast()
-self.dropout = Dropout(keep_prob=1.0)
+self.dropout = Dropout(p=0.0)
 self.mul = P.Mul()
 self.realDiv = P.RealDiv()
 self.scale_coef = scale_coef
@@ -180,7 +180,7 @@ class EmbeddingPostprocessor(nn.Cell):
 self.array_mul = P.MatMul()
 self.reshape = P.Reshape()
 self.shape = tuple(embedding_shape)
-self.dropout = nn.Dropout(1 - dropout_prob)
+self.dropout = nn.Dropout(p=dropout_prob)
 self.gather = P.Gather()
 self.use_relative_positions = use_relative_positions
 self.slice = P.StridedSlice()
@@ -230,7 +230,7 @@ class BertOutput(nn.Cell):
 super(BertOutput, self).__init__()
 self.dense = nn.Dense(in_channels, out_channels,
                       weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
-self.dropout = nn.Dropout(1 - dropout_prob)
+self.dropout = nn.Dropout(p=dropout_prob)
 self.dropout_prob = dropout_prob
 self.add = P.Add()
 self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
@@ -433,7 +433,7 @@ class BertAttention(nn.Cell):
 self.matmul = P.BatchMatMul()

 self.softmax = nn.Softmax()
-self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
+self.dropout = nn.Dropout(p=attention_probs_dropout_prob)

 if self.has_attention_mask:
     self.expand_dims = P.ExpandDims()
@@ -193,7 +193,7 @@ class EmbeddingPostprocessor(nn.Cell):
 self.reshape = P.Reshape()
 self.shape = tuple(embedding_shape)
 self.layernorm = nn.LayerNorm((embedding_size,))
-self.dropout = nn.Dropout(1 - dropout_prob)
+self.dropout = nn.Dropout(p=dropout_prob)
 self.gather = P.Gather()
 self.use_relative_positions = use_relative_positions
 self.slice = P.StridedSlice()
@@ -247,7 +247,7 @@ class BertOutput(nn.Cell):
 super(BertOutput, self).__init__()
 self.dense = nn.Dense(in_channels, out_channels,
                       weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
-self.dropout = nn.Dropout(1 - dropout_prob)
+self.dropout = nn.Dropout(p=dropout_prob)
 self.dropout_prob = dropout_prob
 self.add = P.Add()
 self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
@@ -469,7 +469,7 @@ class BertAttention(nn.Cell):
 self.matmul = P.BatchMatMul()

 self.softmax = nn.Softmax()
-self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
+self.dropout = nn.Dropout(p=attention_probs_dropout_prob)

 if self.has_attention_mask:
     self.expand_dims = P.ExpandDims()
@@ -300,7 +300,7 @@ class SingleDeepLabV3(nn.Cell):
     float(feature_shape[3])]

 self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
-self.dropout = nn.Dropout(keep_prob=0.9)
+self.dropout = nn.Dropout(p=0.1)
 self.shape = P.Shape()
 self.decoder_output_stride = decoder_output_stride
 if decoder_output_stride is not None:
@@ -380,7 +380,7 @@ class CellDropDense(nn.Cell):
 def __init__(self):
     super(CellDropDense, self).__init__()
     self.fc = nn.Dense(100, 100)
-    self.drop = nn.Dropout(1.0 - 0.1)
+    self.drop = nn.Dropout(p=0.1)

 def construct(self, input_x):
     out = self.fc(input_x)
@@ -57,7 +57,7 @@ def test_net():
 class Drop(nn.Cell):
     def __init__(self):
         super(Drop, self).__init__()
-        self.drop = nn.Dropout(1.0 - 0.5)
+        self.drop = nn.Dropout(p=0.5)

     def construct(self, out):
         out = self.drop(out)
@@ -144,7 +144,7 @@ class Conv2dNet(nn.Cell):
 class DropoutNet(nn.Cell):
     def __init__(self):
         super(DropoutNet, self).__init__()
-        self.drop = nn.Dropout(0.5)
+        self.drop = nn.Dropout(p=0.5)
         self.relu = ops.ReLU()

     def construct(self, x):
@@ -31,7 +31,7 @@ class EmbeddingPostprocessor(Cell):
 super(EmbeddingPostprocessor, self).__init__()
 self.layernorm = nn.LayerNorm((768,))
 self.add = P.Add()
-self.dropout = nn.Dropout(1 - 0.1)
+self.dropout = nn.Dropout(p=0.1)

 def construct(self, word_embeddings, token_type_embeddings, position_embeddings):
     output = word_embeddings
@@ -31,7 +31,7 @@ class BertAttentionPiece(Cell):
 def __init__(self):
     super(BertAttentionPiece, self).__init__()
     self.add = P.Add()
-    self.dropout = nn.Dropout(1 - 0.1)
+    self.dropout = nn.Dropout(p=0.1)
     self.softmax = nn.Softmax()
     self.multiply_data = -10000.0
     self.sub = P.Sub()
@@ -570,7 +570,7 @@ test_cases = [
     'desc_bprop': [[128, 32, 32, 64]],
 }),
 ('DropoutGrad', {
-    'block': DropoutGrad(VirtualNetWithLoss(nn.Dropout())),
+    'block': DropoutGrad(VirtualNetWithLoss(nn.Dropout(p=0.5))),
     'desc_inputs': [[128, 32, 32, 64]],
     'desc_bprop': [[128, 32, 32, 64]],
 }),
@@ -3101,7 +3101,7 @@ test_case_nn_ops = [
     'desc_inputs': [[64, 12, 128, 128], Tensor(np.ones(1572864).astype(np.uint8))],
     'desc_bprop': [[64, 12, 128, 128]]}),
 ('Dropout', {
-    'block': nn.Dropout(0.5),
+    'block': nn.Dropout(p=0.5),
     'desc_inputs': [[64, 12, 128, 128]],
     'desc_bprop': [[64, 12, 128, 128]]}),
 ('ReduceMean0', {
@@ -135,7 +135,7 @@ test_case_reid_ops = [
     'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
     'desc_bprop': [convert([256], np.bool_)]}),
 ('Dropout', {
-    'block': nn.Dropout(),
+    'block': nn.Dropout(p=0.5),
     'desc_inputs': [[1, 512, 7, 7]],
     'desc_bprop': [[1, 512, 7, 7]]}),
 ('MatMul', {
@@ -57,7 +57,7 @@ def test_two_matmul_dropout():
 def __init__(self):
     super().__init__()
     self.matmul1 = P.MatMul()
-    self.dropout = nn.Dropout()
+    self.dropout = nn.Dropout(p=0.5)
     self.matmul2 = P.MatMul()

 def construct(self, x, y, b):
@@ -193,8 +193,8 @@ class FeedForward(Cell):
     param_init_type=param_init_type)
 self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)))
 self.projection.bias.parallel_optimizer = False
-self.dropout = nn.Dropout(1 - dropout_rate)
-self.dropout_3d = nn.Dropout(1 - dropout_rate)
+self.dropout = nn.Dropout(p=dropout_rate)
+self.dropout_3d = nn.Dropout(p=dropout_rate)
 self.cast = P.Cast()

 def construct(self, x):
@@ -246,8 +246,8 @@ class MultiHeadAttention(Cell):
 self.mul = P.Mul()
 self.add = P.Add()
 self.scale_factor = Tensor(math.sqrt(self.size_per_head))
-self.dropout = nn.Dropout(1 - hidden_dropout_rate)
-self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
+self.dropout = nn.Dropout(p=hidden_dropout_rate)
+self.prob_dropout = nn.Dropout(p=attention_dropout_rate)
 self.softmax = nn.Softmax().to_float(softmax_compute_type)
 self.expand_dims = P.ExpandDims()
 # Query
@@ -474,7 +474,7 @@ class EmbeddingLayer(nn.Cell):
 self.word_embedding = VocabEmbedding(vocab_size=40000, embedding_size=2560)
 self.position_embedding = VocabEmbedding(vocab_size=40000, embedding_size=2560)
 self.add = P.Add()
-self.dropout = nn.Dropout(0.9)
+self.dropout = nn.Dropout(p=0.1)

 def construct(self, input_ids, input_position, init_reset, batch_valid_length):
     word_embedding, word_table = self.word_embedding(input_ids)
@@ -57,7 +57,7 @@ def test_batch_parallel_dropout():
 def __init__(self):
     super().__init__()
     self.matmul1 = P.MatMul()
-    self.dropout = nn.Dropout()
+    self.dropout = nn.Dropout(p=0.5)
     self.matmul2 = P.MatMul()

 def construct(self, x, y, b):
@@ -206,9 +206,9 @@ class Mlp(nn.Cell):
 self.fc2.matmul.shard(((dp, mp), (1, mp)))
 self.fc2.bias_add.shard(((dp, 1), (1,)))

-self.drop = nn.Dropout(1.0-drop)
+self.drop = nn.Dropout(p=drop)
 self.drop.dropout.shard(((dp, 1),))
-self.drop2 = nn.Dropout(1.0-drop)
+self.drop2 = nn.Dropout(p=drop)
 self.drop2.dropout.shard(((dp, mp),))

 def construct(self, x):
@@ -263,14 +263,14 @@ class Attention(nn.Cell):
 self.softmax.softmax.shard(((dp, mp, 1, 1),))

 self.batmatmul_trans_b = P.BatchMatMul().shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
-self.attn_drop = nn.Dropout(1. - attn_drop)
+self.attn_drop = nn.Dropout(p=attn_drop)
 self.attn_drop.dropout.shard(((dp, mp, 1, 1),))

 self.proj = nn.Dense(hidden_dim, dim, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16)
 self.proj.matmul.shard(((dp, mp), (1, mp)))
 self.proj.bias_add.shard(((dp, 1), (1,)))

-self.proj_drop = nn.Dropout(1. - proj_drop)
+self.proj_drop = nn.Dropout(p=proj_drop)
 self.proj_drop.dropout.shard(((dp, 1),))

 self.transpose = P.Transpose().shard(((dp, 1, mp, 1),))
@@ -25,14 +25,14 @@ context.set_context(device_target="Ascend")

 def test_check_dropout():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
-    m = nn.Dropout(0.8)
+    m = nn.Dropout(p=0.2)
     m(x)


 class Net_Dropout(nn.Cell):
     def __init__(self):
         super(Net_Dropout, self).__init__()
-        self.dropout = nn.Dropout(0.5)
+        self.dropout = nn.Dropout(p=0.5)

     def construct(self, x):
         return self.dropout(x)