!12724 add Dropout2D and rename Dropout3d to Dropout3D

From: @yanzhenxiang2020
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-15 14:29:55 +08:00 committed by Gitee
commit 5a6bb251b0
9 changed files with 278 additions and 57 deletions

View File

@ -52,9 +52,11 @@ constexpr auto kCacheSwapTable = "CacheSwapTable";
constexpr auto kSubAndFilter = "SubAndFilter";
constexpr auto kPadAndShift = "PadAndShift";
constexpr auto kCustRunApi = "RunCpuKernel";
constexpr auto kDropout3d = "Dropout3d";
constexpr auto kDropout2D = "Dropout2D";
constexpr auto kDropout3D = "Dropout3D";
const std::set<std::string> kCustAiCpuKernelOps{kIdentity};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3d};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
kPadAndShift, kDropout3D, kDropout2D};
struct AicpuParamHead {
uint32_t length; // Total length: include cunstom message

View File

@ -1188,6 +1188,46 @@ def get_bprop_dropout(self):
return bprop
@bprop_getters.register(P.Dropout2D)
def get_bprop_dropout2d(self):
"""Grad definition for `Dropout2D` operation."""
dtype = P.DType()
cast = P.Cast()
mul = P.Mul()
keep_prob = self.keep_prob
def bprop(x, out, dout):
_, mask = dout
y = cast(mask, mstype.float32)
if keep_prob != 0:
y = y * (1 / keep_prob)
y = mul(x, y)
y = cast(y, dtype(x))
return (y,)
return bprop
@bprop_getters.register(P.Dropout3D)
def get_bprop_dropout3d(self):
"""Grad definition for `Dropout3D` operation."""
dtype = P.DType()
cast = P.Cast()
mul = P.Mul()
keep_prob = self.keep_prob
def bprop(x, out, dout):
_, mask = dout
y = cast(mask, mstype.float32)
if keep_prob != 0:
y = y * (1 / keep_prob)
y = mul(x, y)
y = cast(y, dtype(x))
return (y,)
return bprop
@bprop_getters.register(P.CTCLoss)
def get_bprop_ctc_loss(self):
"""Grad definition for `CTCLoss` operation"""

View File

@ -27,6 +27,7 @@ from .unique_with_pad import _unique_with_pad_aicpu
from .sub_and_filter import _sub_and_filter_aicpu
from .pad_and_shift import _pad_and_shift_aicpu
from .dropout_genmask import _dropout_genmask_aicpu
from .dropout2d import _dropout2d_aicpu
from .dropout3d import _dropout3d_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu

View File

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dropout2D op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
dropout2d_op_info = AiCPURegOp("Dropout2D") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.output(1, "mask", "required") \
.attr("keep_prob", "float") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(dropout2d_op_info)
def _dropout2d_aicpu():
"""Dropout2D AiCPU register"""
return

View File

@ -13,30 +13,30 @@
# limitations under the License.
# ============================================================================
"""Dropout3d op"""
"""Dropout3D op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
dropout3d_op_info = AiCPURegOp("Dropout3d") \
dropout3d_op_info = AiCPURegOp("Dropout3D") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.output(1, "mask", "required") \
.attr("keep_prob", "float") \
.attr("inplace", "bool") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(dropout3d_op_info)
def _dropout3d_aicpu():
"""Dropout3d AiCPU register"""
"""Dropout3D AiCPU register"""
return

View File

@ -64,7 +64,7 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
DropoutDoMask, Dropout, Dropout3d, DropoutGenMask, Flatten,
DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
GeLU, Gelu, FastGeLU, FastGelu, Elu,
@ -252,6 +252,8 @@ __all__ = [
'DropoutDoMask',
'DropoutGenMask',
'Dropout',
'Dropout2D',
'Dropout3D',
'Neg',
'InplaceAdd',
'InplaceSub',

View File

@ -7055,22 +7055,77 @@ class Dropout(PrimitiveWithCheck):
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
class Dropout3d(PrimitiveWithInfer):
class Dropout2D(PrimitiveWithInfer):
"""
During training, randomly zeroes some of the channels of the input tensor
with probability keep_prob from a Bernoulli distribution.
with probability 1-`keep_prob` from a Bernoulli distribution.
Args:
keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8,
means dropping out %20 of channels. Default: 0.5.
inplace (bool): When `inplace` is True, this operation will be done in-place. Default: False.
means dropping out 20% of channels. Default: 0.5.
Inputs:
- **input** (Tensor) - A 4-D tensor with shape :math:`(N, C, H, W)`.
Outputs:
- **output** (Tensor) - with the same shape and data type as the input tensor.
- **mask** (Tensor[bool]) - with the same shape as the input tensor.
Raises:
TypeError: If the data type of `keep_prob` is not float.
ValueError: If `keep_prob` is out of the range [0.0, 1.0];
or if the dim of input is not 4-D.
Supported Platforms:
``Ascend``
Examples:
>>> dropout = ops.Dropout2D(keep_prob=0.5)
>>> x = Tensor(np.random.randn(2, 1, 2, 3), mindspore.float32)
>>> output, mask = dropout(x)
>>> print(output)
[[[[0. 0. 0.]
[0. 0. 0.]]]
[[[0.88 -2.98 -0.01]
[2.16 -0.34 1.57]]]]
>>> print(mask)
[[[[False False False]
[False False False]]]
[[[True True True]
[True True True]]]]
"""
@prim_attr_register
def __init__(self, keep_prob=0.5):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
def infer_shape(self, x_shape):
validator.check_int(len(x_shape), 4, Rel.EQ, "dim of input", self.name)
return x_shape, x_shape
def infer_dtype(self, x_dtype):
valid_dtypes = mstype.int_type + (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
mask_dtype = mstype.tensor_type(mstype.bool_)
return x_dtype, mask_dtype
class Dropout3D(PrimitiveWithInfer):
"""
During training, randomly zeroes some of the channels of the input tensor
with probability 1-`keep_prob` from a Bernoulli distribution.
Args:
keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8,
means dropping out 20% of channels. Default: 0.5.
Inputs:
- **input** (Tensor) - A 5-D tensor with shape :math:`(N, C, D, H, W)`.
When `inplace` is True, `input` should be Parameter.
Outputs:
- **output** (Tensor) - with the same shape as the input tensor.
- **output** (Tensor) - with the same shape and data type as the input tensor.
- **mask** (Tensor[bool]) - with the same shape as the input tensor.
Raises:
TypeError: If the data type of `keep_prob` is not float.
@ -7081,30 +7136,35 @@ class Dropout3d(PrimitiveWithInfer):
``Ascend``
Examples:
>>> dropout = ops.Dropout3d(keep_prob=0.5)
>>> dropout = ops.Dropout3D(keep_prob=0.5)
>>> x = Tensor(np.random.randn(2, 1, 2, 1, 2), mindspore.float32)
>>> output = dropout(x)
>>> output, mask = dropout(x)
>>> print(output)
[[[[[0. 0.]]
[[0. 0.]]]]
[[[[-2.98 -0.01]]
[[-0.34 1.57]]]]]
>>> print(mask)
[[[[[False False]]
[[False False]]]]
[[[[True True]]
[[True True]]]]]
"""
@prim_attr_register
def __init__(self, keep_prob=0.5, inplace=False):
self.inplace = validator.check_value_type("inplace", inplace, [bool], self.name)
def __init__(self, keep_prob=0.5):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
def infer_shape(self, x_shape):
validator.check_int(len(x_shape), 5, Rel.GE, "dim of input", self.name)
return x_shape
validator.check_int(len(x_shape), 5, Rel.EQ, "dim of input", self.name)
return x_shape, x_shape
def infer_dtype(self, x_dtype):
valid_dtypes = mstype.number_type + (mstype.bool_,)
valid_dtypes = mstype.int_type + (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype
mask_dtype = mstype.tensor_type(mstype.bool_)
return x_dtype, mask_dtype
class CTCLoss(PrimitiveWithInfer):

View File

@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
dtype = np.float16
x0 = Tensor(np.random.randn(3, 4, 3, 3).astype(dtype))
x1 = Tensor(np.random.randn(3, 4, 3, 3).astype(dtype))
class Net(nn.Cell):
def __init__(self, keep_prob):
super(Net, self).__init__()
self.drop = P.Dropout2D(keep_prob=keep_prob)
def construct(self, x):
return self.drop(x)
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network
self.network.set_train()
def construct(self, x, y):
return self.grad(self.network)(x, y)
def test_net_float32():
net = Net(0.7)
output, mask = net(x0)
print(x0)
print(output)
y = (output.asnumpy() == (x0.asnumpy()/0.7).astype(dtype)).reshape(3*4, 3*3)
output_reshape = output.asnumpy().reshape(3*4, 3*3)
for i in range(3*4):
if not y[i].all():
assert output_reshape[i].sum() == 0
return output, mask
def test_net_grad():
net = Grad(Net(0.7))
y = test_net_float32()
output = net(x1, y)
print("input: ", x1)
print("forward output: ", y)
print("backward output: ", output)

View File

@ -13,52 +13,57 @@
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
dtype = np.float16
x0 = Tensor(np.random.randn(3, 4, 3, 3, 3).astype(dtype))
x1 = Tensor(np.random.randn(3, 4, 3, 3, 3).astype(dtype))
class Net(nn.Cell):
def __init__(self, keep_prob, inplace):
def __init__(self, keep_prob):
super(Net, self).__init__()
self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace)
self.drop = P.Dropout3D(keep_prob=keep_prob)
def construct(self, x):
return self.drop(x)
class NetInplace(nn.Cell):
def __init__(self, keep_prob, inplace, x):
super(NetInplace, self).__init__()
self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace)
self.x = x
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network
self.network.set_train()
def construct(self):
return self.drop(self.x)
def construct(self, x, y):
return self.grad(self.network)(x, y)
def test_net_float32():
x = Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32)
net = Net(0.7, False)
output = net(x)
print(x)
net = Net(0.7)
output, mask = net(x0)
print(x0)
print(output)
y = (output.asnumpy() == x.asnumpy()/0.7).reshape(3*4, 3*3*3)
y = (output.asnumpy() == (x0.asnumpy()/0.7).astype(dtype)).reshape(3*4, 3*3*3)
output_reshape = output.asnumpy().reshape(3*4, 3*3*3)
for i in range(3*4):
if not y[i].all():
assert y[i].sum() == 0
assert output_reshape[i].sum() == 0
return output, mask
def test_net_float32_inplace():
x = mindspore.Parameter(Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32))
net = NetInplace(0.7, True, x)
output = net()
print(Tensor(x))
print(output)
assert np.array_equal(x.asnumpy(), output.asnumpy())
def test_net_grad():
net = Grad(Net(0.7))
y = test_net_float32()
output = net(x1, y)
print("input: ", x1)
print("forward output: ", y)
print("backward output: ", output)