forked from mindspore-Ecosystem/mindspore
add dropout primtive
This commit is contained in:
parent
3d3b9d5474
commit
661f9dfaf8
|
@ -25,6 +25,7 @@ from mindspore.ops.operations import _inner_ops as inner
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore import context
|
||||
from ..cell import Cell
|
||||
from .activation import get_activation
|
||||
from ..._checkparam import Validator as validator
|
||||
|
@ -84,8 +85,19 @@ class Dropout(Cell):
|
|||
self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed0, Seed1=seed1)
|
||||
self.dropout_do_mask = P.DropoutDoMask()
|
||||
self.cast = P.Cast()
|
||||
self.is_gpu = context.get_context('device_target') in ["GPU"]
|
||||
|
||||
if self.is_gpu:
|
||||
self.dropout = P.Dropout(keep_prob)
|
||||
|
||||
def construct(self, x):
|
||||
if not self.training:
|
||||
return x
|
||||
|
||||
if self.is_gpu:
|
||||
out, _ = self.dropout(x)
|
||||
return out
|
||||
|
||||
shape = self.get_shape(x)
|
||||
dtype = P.DType()(x)
|
||||
keep_prob = self.cast(self.keep_prob, dtype)
|
||||
|
|
|
@ -643,3 +643,17 @@ def get_bprop_binary_cross_entropy(self):
|
|||
return dx, zeros_like(y), zeros_like(weight)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Dropout)
|
||||
def get_bprop_dropout(self):
|
||||
"""Grad definition for `Dropout` operation."""
|
||||
grad = P.DropoutGrad(self.drop_prob)
|
||||
|
||||
def bprop(x, out, dout):
|
||||
_, mask = out
|
||||
dy, _ = dout
|
||||
dx = grad(dy, mask)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -52,7 +52,7 @@ from .random_ops import (RandomChoiceWithMask)
|
|||
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
||||
BiasAdd, Conv2D,
|
||||
DepthwiseConv2dNative,
|
||||
DropoutDoMask,
|
||||
DropoutDoMask, DropoutGrad, Dropout,
|
||||
DropoutGenMask, Flatten, FusedBatchNorm,
|
||||
Gelu, Elu,
|
||||
GetNext, L2Normalize, LayerNorm, L2Loss,
|
||||
|
@ -157,6 +157,8 @@ __all__ = [
|
|||
'Shape',
|
||||
'DropoutDoMask',
|
||||
'DropoutGenMask',
|
||||
'DropoutGrad',
|
||||
'Dropout',
|
||||
'Neg',
|
||||
'Slice',
|
||||
'DType',
|
||||
|
|
|
@ -2762,3 +2762,68 @@ class ConfusionMulGrad(PrimitiveWithInfer):
|
|||
validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
|
||||
validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
|
||||
return input0_dtype, input1_dtype
|
||||
|
||||
|
||||
class Dropout(PrimitiveWithInfer):
|
||||
"""
|
||||
During training, randomly zeroes some of the elements of the input tensor with probability.
|
||||
|
||||
Args:
|
||||
drop_prob (float): probability of an element to be zeroed. Default: 0.
|
||||
|
||||
Inputs:
|
||||
- **shape** (tuple[int]) - The shape of target mask.
|
||||
|
||||
Outputs:
|
||||
Tensor, the value of generated mask for input shape.
|
||||
|
||||
Examples:
|
||||
>>> dropout = P.Dropout(drop_prob=0.5)
|
||||
>>> in = Tensor((20, 16, 50, 50))
|
||||
>>> out = dropout(in)
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, drop_prob=0):
|
||||
self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
|
||||
mask_shape = x_shape
|
||||
return x_shape, mask_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
valid_types = (mstype.float16, mstype.float32)
|
||||
validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
|
||||
return x_dtype, x_dtype
|
||||
|
||||
|
||||
class DropoutGrad(PrimitiveWithInfer):
|
||||
"""
|
||||
The gradient of Dropout. During training, randomly zeroes some of the elements
|
||||
of the input tensor with probability.
|
||||
|
||||
Args:
|
||||
drop_prob (float): probability of an element to be zeroed. Default: 0.
|
||||
|
||||
Inputs:
|
||||
- **shape** (tuple[int]) - The shape of target mask.
|
||||
|
||||
Outputs:
|
||||
Tensor, the value of generated mask for input shape.
|
||||
|
||||
Examples:
|
||||
>>> dropout_grad = P.DropoutGrad(drop_prob=0.5)
|
||||
>>> in = Tensor((20, 16, 50, 50))
|
||||
>>> out = dropout_grad(in)
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, drop_prob=0):
|
||||
self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
|
||||
|
||||
def infer_shape(self, dy_shape, mask_shape):
|
||||
return dy_shape
|
||||
|
||||
def infer_dtype(self, dy_dtype, mask_dtype):
|
||||
valid_types = (mstype.float16, mstype.float32)
|
||||
validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
|
||||
return dy_dtype
|
||||
|
|
|
@ -17,7 +17,9 @@ import numpy as np
|
|||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_check_dropout_3():
|
||||
Tensor(np.ones([20, 16, 50]).astype(np.int32))
|
||||
|
|
|
@ -19,26 +19,26 @@ from mindspore.common.api import _executor
|
|||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_check_dropout_1():
|
||||
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
|
||||
m = nn.Dropout(0.8)
|
||||
with pytest.raises(NotImplementedError):
|
||||
m(x)
|
||||
m(x)
|
||||
|
||||
|
||||
def test_check_dropout_2():
|
||||
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
|
||||
m = nn.Dropout(0.3, seed0=1)
|
||||
with pytest.raises(NotImplementedError):
|
||||
m(x)
|
||||
m(x)
|
||||
|
||||
|
||||
def test_check_dropout_3():
|
||||
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
|
||||
m = nn.Dropout(0.3, seed0=1, seed1=1)
|
||||
with pytest.raises(NotImplementedError):
|
||||
m(x)
|
||||
m(x)
|
||||
|
||||
|
||||
class Net_Dropout(nn.Cell):
|
||||
|
|
Loading…
Reference in New Issue