add Dropout3d ops for aicpu

This commit is contained in:
yanzhenxiang2020 2021-01-14 16:22:31 +08:00
parent f679fcf075
commit f8147aa57d
6 changed files with 162 additions and 2 deletions

View File

@ -51,8 +51,9 @@ constexpr auto kCacheSwapTable = "CacheSwapTable";
constexpr auto kSubAndFilter = "SubAndFilter";
constexpr auto kPadAndShift = "PadAndShift";
constexpr auto kCustRunApi = "RunCpuKernel";
constexpr auto kDropout3d = "Dropout3d";
const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3d};
struct AicpuParamHead {
uint32_t length; // Total length: include cunstom message

View File

@ -27,6 +27,7 @@ from .unique_with_pad import _unique_with_pad_aicpu
from .sub_and_filter import _sub_and_filter_aicpu
from .pad_and_shift import _pad_and_shift_aicpu
from .dropout_genmask import _dropout_genmask_aicpu
from .dropout3d import _dropout3d_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu
from .topk import _top_k_aicpu

View File

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dropout3d op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
dropout3d_op_info = AiCPURegOp("Dropout3d") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.attr("keep_prob", "float") \
.attr("inplace", "bool") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(dropout3d_op_info)
def _dropout3d_aicpu():
"""Dropout3d AiCPU register"""
return

View File

@ -63,7 +63,7 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
DropoutDoMask, Dropout, DropoutGenMask, Flatten,
DropoutDoMask, Dropout, Dropout3d, DropoutGenMask, Flatten,
FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
Gelu, FastGelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,

View File

@ -6242,6 +6242,58 @@ class Dropout(PrimitiveWithInfer):
return x_dtype, x_dtype
class Dropout3d(PrimitiveWithInfer):
"""
During training, randomly zeroes some of the channels of the input tensor
with probability keep_prob from a Bernoulli distribution.
Args:
keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8,
means dropping out %20 of channels. Default: 0.5.
inplace (bool): When `inplace` is True, this operation will be done in-place. Default: False.
Inputs:
- **input** (Tensor) - A 5-D tensor with shape :math:`(N, C, D, H, W)`.
When `inplace` is True, `input` should be Parameter.
Outputs:
- **output** (Tensor) - with the same shape as the input tensor.
Raises:
TypeError: If the data type of `keep_prob` is not float.
ValueError: If `keep_prob` is out of the range [0.0, 1.0];
or if the dim of input is not 5-D.
Supported Platforms:
``Ascend``
Examples:
>>> dropout = ops.Dropout3d(keep_prob=0.5)
>>> x = Tensor(np.random.randn(2, 1, 2, 1, 2), mindspore.float32)
>>> output = dropout(x)
>>> print(output)
[[[[[0. 0.]]
[[0. 0.]]]]
[[[[-2.98 -0.01]]
[[-0.34 1.57]]]]]
"""
@prim_attr_register
def __init__(self, keep_prob=0.5, inplace=False):
self.inplace = validator.check_value_type("inplace", inplace, [bool], self.name)
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
def infer_shape(self, x_shape):
validator.check_int(len(x_shape), 5, Rel.GE, "dim of input", self.name)
return x_shape
def infer_dtype(self, x_dtype):
valid_dtypes = mstype.number_type + (mstype.bool_,)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype
class CTCLoss(PrimitiveWithInfer):
"""
Calculates the CTC (Connectionist Temporal Classification) loss and the gradient.

View File

@ -0,0 +1,64 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self, keep_prob, inplace):
super(Net, self).__init__()
self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace)
def construct(self, x):
return self.drop(x)
class NetInplace(nn.Cell):
def __init__(self, keep_prob, inplace, x):
super(NetInplace, self).__init__()
self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace)
self.x = x
def construct(self):
return self.drop(self.x)
def test_net_float32():
x = Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32)
net = Net(0.7, False)
output = net(x)
print(x)
print(output)
y = (output.asnumpy() == x.asnumpy()/0.7).reshape(3*4, 3*3*3)
for i in range(3*4):
if not y[i].all():
assert y[i].sum() == 0
def test_net_float32_inplace():
x = mindspore.Parameter(Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32))
net = NetInplace(0.7, True, x)
output = net()
print(Tensor(x))
print(output)
assert np.array_equal(x.asnumpy(), output.asnumpy())