From 661f9dfaf80c072eb81000c5d5cdcafbfc8e8404 Mon Sep 17 00:00:00 2001
From: chenzomi
Date: Sat, 9 May 2020 15:14:13 +0800
Subject: [PATCH] add dropout primitive

---
 mindspore/nn/layer/basic.py                 | 12 ++++
 mindspore/ops/_grad/grad_nn_ops.py          | 14 ++++
 mindspore/ops/operations/__init__.py        |  4 +-
 mindspore/ops/operations/nn_ops.py          | 65 +++++++++++++++++++
 tests/ut/python/nn/test_dropout.py          |  2 +
 .../python/pynative_mode/nn/test_dropout.py | 12 ++--
 6 files changed, 102 insertions(+), 7 deletions(-)

diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index 9c8de85a68b..6d8d287b825 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -25,6 +25,7 @@ from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore.common.api import ms_function
+from mindspore import context
 from ..cell import Cell
 from .activation import get_activation
 from ..._checkparam import Validator as validator
@@ -84,8 +85,19 @@ class Dropout(Cell):
         self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed0, Seed1=seed1)
         self.dropout_do_mask = P.DropoutDoMask()
         self.cast = P.Cast()
+        self.is_gpu = context.get_context('device_target') == "GPU"
+
+        if self.is_gpu:
+            self.dropout = P.Dropout(1 - keep_prob)
 
     def construct(self, x):
+        if not self.training:
+            return x
+
+        if self.is_gpu:
+            out, _ = self.dropout(x)
+            return out
+
         shape = self.get_shape(x)
         dtype = P.DType()(x)
         keep_prob = self.cast(self.keep_prob, dtype)
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 153abc0fb63..6a8454a7de3 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -643,3 +643,17 @@ def get_bprop_binary_cross_entropy(self):
         return dx, zeros_like(y), zeros_like(weight)
 
     return bprop
+
+
+@bprop_getters.register(P.Dropout)
+def get_bprop_dropout(self):
+    """Grad definition for `Dropout` operation."""
+    grad = P.DropoutGrad(self.drop_prob)
+
+    def bprop(x, out, dout):
+        _, mask = out
+        dy, _ = dout
+        dx = grad(dy, mask)
+        return (dx,)
+
+    return bprop
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index d83f5accd09..d38f84225ad 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -52,7 +52,7 @@ from .random_ops import (RandomChoiceWithMask)
 from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      BiasAdd, Conv2D,
                      DepthwiseConv2dNative,
-                     DropoutDoMask,
+                     DropoutDoMask, DropoutGrad, Dropout,
                      DropoutGenMask, Flatten, FusedBatchNorm,
                      Gelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss,
@@ -157,6 +157,8 @@ __all__ = [
     'Shape',
     'DropoutDoMask',
     'DropoutGenMask',
+    'DropoutGrad',
+    'Dropout',
     'Neg',
     'Slice',
     'DType',
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 2a2dbe08a8f..9f1c4169fbc 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2762,3 +2762,68 @@ class ConfusionMulGrad(PrimitiveWithInfer):
         validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
         validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
         return input0_dtype, input1_dtype
+
+
+class Dropout(PrimitiveWithInfer):
+    """
+    During training, randomly zeroes some of the elements of the input tensor with probability `drop_prob`.
+
+    Args:
+        drop_prob (float): Probability of an element to be zeroed. Default: 0.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor, with float16 or float32 data type.
+
+    Outputs:
+        - **output** (Tensor) - With the same shape and data type as `input_x`.
+        - **mask** (Tensor) - The generated dropout mask, with the same shape as `input_x`.
+
+    Examples:
+        >>> dropout = P.Dropout(drop_prob=0.5)
+        >>> input_x = Tensor(np.ones([20, 16, 50]), mindspore.float32)
+        >>> output, mask = dropout(input_x)
+    """
+    @prim_attr_register
+    def __init__(self, drop_prob=0):
+        self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
+
+    def infer_shape(self, x_shape):
+        validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
+        mask_shape = x_shape
+        return x_shape, mask_shape
+
+    def infer_dtype(self, x_dtype):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
+        return x_dtype, x_dtype
+
+
+class DropoutGrad(PrimitiveWithInfer):
+    """
+    Computes the gradient of Dropout, applying the mask saved in the forward pass to the output gradient.
+
+    Args:
+        drop_prob (float): Probability of an element to be zeroed. Default: 0.
+
+    Inputs:
+        - **dy** (Tensor) - The gradient of the Dropout output, with float16 or float32 data type.
+        - **mask** (Tensor) - The mask generated by the forward Dropout, with the same shape as `dy`.
+
+    Outputs:
+        Tensor, the gradient of the Dropout input, with the same shape and data type as `dy`.
+
+    Examples:
+        >>> dropout_grad = P.DropoutGrad(drop_prob=0.5)
+        >>> dx = dropout_grad(dy, mask)
+    """
+    @prim_attr_register
+    def __init__(self, drop_prob=0):
+        self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
+
+    def infer_shape(self, dy_shape, mask_shape):
+        return dy_shape
+
+    def infer_dtype(self, dy_dtype, mask_dtype):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
+        return dy_dtype
diff --git a/tests/ut/python/nn/test_dropout.py b/tests/ut/python/nn/test_dropout.py
index 81a20db6f1b..ec67a4c77bf 100644
--- a/tests/ut/python/nn/test_dropout.py
+++ b/tests/ut/python/nn/test_dropout.py
@@ -17,7 +17,9 @@ import numpy as np
 import pytest
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import context
 
+context.set_context(device_target="Ascend")
 
 def test_check_dropout_3():
     Tensor(np.ones([20, 16, 50]).astype(np.int32))
diff --git a/tests/ut/python/pynative_mode/nn/test_dropout.py b/tests/ut/python/pynative_mode/nn/test_dropout.py
index d4c3d47dbaa..ef026b5fc83 100644
--- a/tests/ut/python/pynative_mode/nn/test_dropout.py
+++ b/tests/ut/python/pynative_mode/nn/test_dropout.py
@@ -19,26 +19,26 @@ from mindspore.common.api import _executor
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import dtype as mstype
+from mindspore import context
+
+context.set_context(device_target="Ascend")
 
 
 def test_check_dropout_1():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.8)
-    with pytest.raises(NotImplementedError):
-        m(x)
+    m(x)
 
 
 def test_check_dropout_2():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.3, seed0=1)
-    with pytest.raises(NotImplementedError):
-        m(x)
+    m(x)
 
 
 def test_check_dropout_3():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.3, seed0=1, seed1=1)
-    with pytest.raises(NotImplementedError):
-        m(x)
+    m(x)
 
 
 class Net_Dropout(nn.Cell):
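--
Usage sketch (not part of the patch): a minimal example of how the new primitive
and cell are expected to be driven once this change is applied. It assumes a GPU
build of MindSpore at roughly this revision, running in PyNative mode; the input
shape and probabilities are arbitrary.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore import dtype as mstype
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    # The raw primitive returns (output, mask); the mask is what
    # get_bprop_dropout later feeds to DropoutGrad in the backward pass.
    dropout = P.Dropout(drop_prob=0.5)
    x = Tensor(np.ones([20, 16, 50]), mstype.float32)
    output, mask = dropout(x)

    # The nn.Dropout cell hides the mask and is an identity op outside training.
    net = nn.Dropout(keep_prob=0.5)
    net.set_train()
    y = net(x)

On non-GPU targets the cell keeps the original DropoutGenMask/DropoutDoMask
path, which is why the unit tests above pin device_target to "Ascend".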