From 95dbfe0636601f8ff1f6576307f561a3ab91f64b Mon Sep 17 00:00:00 2001
From: yanzhenxiang2020
Date: Fri, 22 Jan 2021 15:18:37 +0800
Subject: [PATCH] add Dropout2D and rename Dropout3d to Dropout3D

---
 .../kernel_compiler/aicpu/aicpu_util.h        |  6 +-
 mindspore/ops/_grad/grad_nn_ops.py            | 42 ++++++++++
 mindspore/ops/_op_impl/aicpu/__init__.py      |  1 +
 mindspore/ops/_op_impl/aicpu/dropout2d.py     | 42 +++++++++
 mindspore/ops/_op_impl/aicpu/dropout3d.py     | 32 +++----
 mindspore/ops/operations/__init__.py          |  4 +-
 mindspore/ops/operations/nn_ops.py            | 88 ++++++++++++++++---
 .../ascend/test_aicpu_ops/test_dropout2d.py   | 69 +++++++++++++++
 .../ascend/test_aicpu_ops/test_dropout3d.py   | 53 ++++++-----
 9 files changed, 280 insertions(+), 57 deletions(-)
 create mode 100644 mindspore/ops/_op_impl/aicpu/dropout2d.py
 create mode 100644 tests/st/ops/ascend/test_aicpu_ops/test_dropout2d.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
index fb054e5b896..cd6a8ce9fc3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
@@ -52,9 +52,11 @@ constexpr auto kCacheSwapTable = "CacheSwapTable";
 constexpr auto kSubAndFilter = "SubAndFilter";
 constexpr auto kPadAndShift = "PadAndShift";
 constexpr auto kCustRunApi = "RunCpuKernel";
-constexpr auto kDropout3d = "Dropout3d";
+constexpr auto kDropout2D = "Dropout2D";
+constexpr auto kDropout3D = "Dropout3D";
 const std::set<std::string> kCustAiCpuKernelOps{kIdentity};
-const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3d};
+const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
+                                            kPadAndShift, kDropout3D, kDropout2D};
 
 struct AicpuParamHead {
   uint32_t length;  // Total length: include cunstom message
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index f3006a8fb3a..ba68df82f6a 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -1263,6 +1263,48 @@ def get_bprop_dropout(self):
     return bprop
 
 
+@bprop_getters.register(P.Dropout2D)
+def get_bprop_dropout2d(self):
+    """Grad definition for `Dropout2D` operation."""
+    dtype = P.DType()
+    cast = P.Cast()
+    mul = P.Mul()
+    keep_prob = self.keep_prob
+
+    def bprop(x, out, dout):
+        # out is (output, mask); dout carries the sens for each output.
+        # Since output = x * mask / keep_prob, dx = dy * mask / keep_prob.
+        _, mask = out
+        dy, _ = dout
+        y = cast(mask, mstype.float32)
+        if keep_prob != 0:
+            y = y * (1 / keep_prob)
+        y = mul(dy, y)
+        y = cast(y, dtype(x))
+        return (y,)
+
+    return bprop
+
+
+@bprop_getters.register(P.Dropout3D)
+def get_bprop_dropout3d(self):
+    """Grad definition for `Dropout3D` operation."""
+    dtype = P.DType()
+    cast = P.Cast()
+    mul = P.Mul()
+    keep_prob = self.keep_prob
+
+    def bprop(x, out, dout):
+        _, mask = out
+        dy, _ = dout
+        y = cast(mask, mstype.float32)
+        if keep_prob != 0:
+            y = y * (1 / keep_prob)
+        y = mul(dy, y)
+        y = cast(y, dtype(x))
+        return (y,)
+
+    return bprop
+
+
 @bprop_getters.register(P.CTCLoss)
 def get_bprop_ctc_loss(self):
     """Grad definition for `CTCLoss` operation"""
diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py
index 958e029428c..d5e18adf3cc 100644
--- a/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -27,6 +27,7 @@ from .unique_with_pad import _unique_with_pad_aicpu
 from .sub_and_filter import _sub_and_filter_aicpu
 from .pad_and_shift import _pad_and_shift_aicpu
 from .dropout_genmask import _dropout_genmask_aicpu
+from .dropout2d import _dropout2d_aicpu
 from .dropout3d import _dropout3d_aicpu
 from .get_next import _get_next_aicpu
 from .print_tensor import _print_aicpu
diff --git a/mindspore/ops/_op_impl/aicpu/dropout2d.py b/mindspore/ops/_op_impl/aicpu/dropout2d.py
new file mode 100644
index 00000000000..dd898f4cf0f
--- /dev/null
+++ b/mindspore/ops/_op_impl/aicpu/dropout2d.py
@@ -0,0 +1,42 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Dropout2D op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+dropout2d_op_info = AiCPURegOp("Dropout2D") \
+    .fusion_type("OPAQUE") \
+    .input(0, "x", "required") \
+    .output(0, "y", "required") \
+    .output(1, "mask", "required") \
+    .attr("keep_prob", "float") \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \
+    .get_op_info()
+
+
+@op_info_register(dropout2d_op_info)
+def _dropout2d_aicpu():
+    """Dropout2D AiCPU register"""
+    return
diff --git a/mindspore/ops/_op_impl/aicpu/dropout3d.py b/mindspore/ops/_op_impl/aicpu/dropout3d.py
index ed686acea11..2d4ae2ea74e 100644
--- a/mindspore/ops/_op_impl/aicpu/dropout3d.py
+++ b/mindspore/ops/_op_impl/aicpu/dropout3d.py
@@ -13,30 +13,30 @@
 # limitations under the License.
 # ============================================================================
 
-"""Dropout3d op"""
+"""Dropout3D op"""
 from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
 
-dropout3d_op_info = AiCPURegOp("Dropout3d") \
+dropout3d_op_info = AiCPURegOp("Dropout3D") \
     .fusion_type("OPAQUE") \
     .input(0, "x", "required") \
     .output(0, "y", "required") \
+    .output(1, "mask", "required") \
     .attr("keep_prob", "float") \
-    .attr("inplace", "bool") \
-    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
-    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
-    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
-    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
-    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
-    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
-    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
-    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \
     .get_op_info()
 
 
 @op_info_register(dropout3d_op_info)
 def _dropout3d_aicpu():
-    """Dropout3d AiCPU register"""
+    """Dropout3D AiCPU register"""
     return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 527b2249c59..e914a91e8d4 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -64,7 +64,7 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
 from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam,
                      ApplyMomentum, BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative,
-                     DropoutDoMask, Dropout, Dropout3d, DropoutGenMask, Flatten,
+                     DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
                      FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
                      GeLU, Gelu, FastGeLU, FastGelu, Elu,
@@ -243,6 +243,8 @@ __all__ = [
     'DropoutDoMask',
     'DropoutGenMask',
     'Dropout',
+    'Dropout2D',
+    'Dropout3D',
     'Neg',
     'InplaceAdd',
     'InplaceSub',
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 22f66c8d07d..784fea9a3d8 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -6657,22 +6657,77 @@ class Dropout(PrimitiveWithCheck):
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) -class Dropout3d(PrimitiveWithInfer): +class Dropout2D(PrimitiveWithInfer): """ During training, randomly zeroes some of the channels of the input tensor - with probability keep_prob from a Bernoulli distribution. + with probability 1-`keep_prob` from a Bernoulli distribution. Args: keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8, - means dropping out %20 of channels. Default: 0.5. - inplace (bool): When `inplace` is True, this operation will be done in-place. Default: False. + means dropping out 20% of channels. Default: 0.5. + + Inputs: + - **input** (Tensor) - A 4-D tensor with shape :math:`(N, C, H, W)`. + + Outputs: + - **output** (Tensor) - with the same shape and data type as the input tensor. + - **mask** (Tensor[bool]) - with the same shape as the input tensor. + + Raises: + TypeError: If the data type of `keep_prob` is not float. + ValueError: If `keep_prob` is out of the range [0.0, 1.0]; + or if the dim of input is not 4-D. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> dropout = ops.Dropout2D(keep_prob=0.5) + >>> x = Tensor(np.random.randn(2, 1, 2, 3), mindspore.float32) + >>> output, mask = dropout(x) + >>> print(output) + [[[[0. 0. 0.] + [0. 0. 0.]]] + [[[0.88 -2.98 -0.01] + [2.16 -0.34 1.57]]]] + >>> print(mask) + [[[[False False False] + [False False False]]] + [[[True True True] + [True True True]]]] + """ + + @prim_attr_register + def __init__(self, keep_prob=0.5): + self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) + self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name) + + def infer_shape(self, x_shape): + validator.check_int(len(x_shape), 4, Rel.EQ, "dim of input", self.name) + return x_shape, x_shape + + def infer_dtype(self, x_dtype): + valid_dtypes = mstype.int_type + (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) + mask_dtype = mstype.tensor_type(mstype.bool_) + return x_dtype, mask_dtype + + +class Dropout3D(PrimitiveWithInfer): + """ + During training, randomly zeroes some of the channels of the input tensor + with probability 1-`keep_prob` from a Bernoulli distribution. + + Args: + keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8, + means dropping out 20% of channels. Default: 0.5. Inputs: - **input** (Tensor) - A 5-D tensor with shape :math:`(N, C, D, H, W)`. - When `inplace` is True, `input` should be Parameter. Outputs: - - **output** (Tensor) - with the same shape as the input tensor. + - **output** (Tensor) - with the same shape and data type as the input tensor. + - **mask** (Tensor[bool]) - with the same shape as the input tensor. Raises: TypeError: If the data type of `keep_prob` is not float. @@ -6683,30 +6738,35 @@ class Dropout3d(PrimitiveWithInfer): ``Ascend`` Examples: - >>> dropout = ops.Dropout3d(keep_prob=0.5) + >>> dropout = ops.Dropout3D(keep_prob=0.5) >>> x = Tensor(np.random.randn(2, 1, 2, 1, 2), mindspore.float32) - >>> output = dropout(x) + >>> output, mask = dropout(x) >>> print(output) [[[[[0. 0.]] [[0. 
          [[[[-2.98 -0.01]]
            [[-0.34 1.57]]]]]
+        >>> print(mask)
+        [[[[[False False]]
+           [[False False]]]]
+         [[[[True True]]
+           [[True True]]]]]
         """
 
     @prim_attr_register
-    def __init__(self, keep_prob=0.5, inplace=False):
-        self.inplace = validator.check_value_type("inplace", inplace, [bool], self.name)
+    def __init__(self, keep_prob=0.5):
         self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
         self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
 
     def infer_shape(self, x_shape):
-        validator.check_int(len(x_shape), 5, Rel.GE, "dim of input", self.name)
-        return x_shape
+        validator.check_int(len(x_shape), 5, Rel.EQ, "dim of input", self.name)
+        return x_shape, x_shape
 
     def infer_dtype(self, x_dtype):
-        valid_dtypes = mstype.number_type + (mstype.bool_,)
+        valid_dtypes = mstype.int_type + (mstype.float16, mstype.float32)
         validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
-        return x_dtype
+        mask_dtype = mstype.tensor_type(mstype.bool_)
+        return x_dtype, mask_dtype
 
 
 class CTCLoss(PrimitiveWithInfer):
diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_dropout2d.py b/tests/st/ops/ascend/test_aicpu_ops/test_dropout2d.py
new file mode 100644
index 00000000000..82f6a1bfc20
--- /dev/null
+++ b/tests/st/ops/ascend/test_aicpu_ops/test_dropout2d.py
@@ -0,0 +1,69 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.composite import GradOperation
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+dtype = np.float16
+x0 = Tensor(np.random.randn(3, 4, 3, 3).astype(dtype))
+x1 = Tensor(np.random.randn(3, 4, 3, 3).astype(dtype))
+
+
+class Net(nn.Cell):
+    def __init__(self, keep_prob):
+        super(Net, self).__init__()
+        self.drop = P.Dropout2D(keep_prob=keep_prob)
+
+    def construct(self, x):
+        return self.drop(x)
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(get_all=True, sens_param=True)
+        self.network = network
+        self.network.set_train()
+
+    def construct(self, x, y):
+        return self.grad(self.network)(x, y)
+
+
+def test_net_float32():
+    net = Net(0.7)
+    output, mask = net(x0)
+    print(x0)
+    print(output)
+
+    y = (output.asnumpy() == (x0.asnumpy()/0.7).astype(dtype)).reshape(3*4, 3*3)
+    output_reshape = output.asnumpy().reshape(3*4, 3*3)
+    for i in range(3*4):
+        if not y[i].all():
+            assert output_reshape[i].sum() == 0
+    return output, mask
+
+
+def test_net_grad():
+    net = Grad(Net(0.7))
+    y = test_net_float32()
+    output = net(x1, y)
+    print("input: ", x1)
+    print("forward output: ", y)
+    print("backward output: ", output)
diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_dropout3d.py b/tests/st/ops/ascend/test_aicpu_ops/test_dropout3d.py
index 8cf5b93fae6..cdc933cc097 100644
--- a/tests/st/ops/ascend/test_aicpu_ops/test_dropout3d.py
+++ b/tests/st/ops/ascend/test_aicpu_ops/test_dropout3d.py
@@ -13,52 +13,57 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
-
-import mindspore
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops.composite import GradOperation
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 
+dtype = np.float16
+x0 = Tensor(np.random.randn(3, 4, 3, 3, 3).astype(dtype))
+x1 = Tensor(np.random.randn(3, 4, 3, 3, 3).astype(dtype))
+
 
 class Net(nn.Cell):
-    def __init__(self, keep_prob, inplace):
+    def __init__(self, keep_prob):
         super(Net, self).__init__()
-        self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace)
+        self.drop = P.Dropout3D(keep_prob=keep_prob)
 
     def construct(self, x):
         return self.drop(x)
 
 
-class NetInplace(nn.Cell):
-    def __init__(self, keep_prob, inplace, x):
-        super(NetInplace, self).__init__()
-        self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace)
-        self.x = x
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(get_all=True, sens_param=True)
+        self.network = network
+        self.network.set_train()
 
-    def construct(self):
-        return self.drop(self.x)
+    def construct(self, x, y):
+        return self.grad(self.network)(x, y)
 
 
 def test_net_float32():
-    x = Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32)
-    net = Net(0.7, False)
-    output = net(x)
-    print(x)
+    net = Net(0.7)
+    output, mask = net(x0)
+    print(x0)
     print(output)
 
-    y = (output.asnumpy() == x.asnumpy()/0.7).reshape(3*4, 3*3*3)
+    y = (output.asnumpy() == (x0.asnumpy()/0.7).astype(dtype)).reshape(3*4, 3*3*3)
+    output_reshape = output.asnumpy().reshape(3*4, 3*3*3)
     for i in range(3*4):
         if not y[i].all():
-            assert y[i].sum() == 0
+            assert output_reshape[i].sum() == 0
+    return output, mask
 
 
-def test_net_float32_inplace():
-    x = mindspore.Parameter(Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32))
-    net = NetInplace(0.7, True, x)
-    output = net()
-    print(Tensor(x))
-    print(output)
-    assert np.array_equal(x.asnumpy(), output.asnumpy())
+def test_net_grad():
+    net = Grad(Net(0.7))
+    y = test_net_float32()
+    output = net(x1, y)
+    print("input: ", x1)
+    print("forward output: ", y)
+    print("backward output: ", output)
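
Usage sketch for the interface introduced above. This is illustrative only: it
assumes a MindSpore build that contains this patch and a reachable Ascend
device, and the shapes and keep_prob value are arbitrary. It relies only on
the documented contract (kept channels are rescaled by 1/keep_prob, dropped
channels are zeroed, and the bprop returns sens * mask / keep_prob).

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class DropNet(nn.Cell):
    def __init__(self, keep_prob):
        super(DropNet, self).__init__()
        self.drop = P.Dropout2D(keep_prob=keep_prob)

    def construct(self, x):
        return self.drop(x)


class GradNet(nn.Cell):
    def __init__(self, network):
        super(GradNet, self).__init__()
        self.grad = GradOperation(get_all=True, sens_param=True)
        self.network = network
        self.network.set_train()

    def construct(self, x, sens):
        return self.grad(self.network)(x, sens)


keep_prob = 0.8
x = Tensor(np.random.randn(2, 4, 5, 5).astype(np.float32))
net = DropNet(keep_prob)
output, mask = net(x)
out_np, mask_np = output.asnumpy(), mask.asnumpy()

# Forward contract: kept elements equal x / keep_prob, dropped ones are zero.
assert np.allclose(out_np[mask_np], x.asnumpy()[mask_np] / keep_prob, rtol=1e-3)
assert not out_np[~mask_np].any()

# Backward contract: with an all-ones sens for the output (and the mask itself
# in the mask's sens slot, mirroring the st tests above), dx = mask / keep_prob.
sens = (Tensor(np.ones(x.shape, np.float32)), mask)
dx = GradNet(net)(x, sens)[0]
assert np.allclose(dx.asnumpy(), mask_np.astype(np.float32) / keep_prob, rtol=1e-3)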