forked from mindspore-Ecosystem/mindspore
!11174 add_mish_mulnonan_selu_operations
From: @jiangzg001 Reviewed-by: @liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit 8170669909

@@ -416,6 +416,58 @@ def get_bprop_dropout_do_mask(self):
    return bprop


@bprop_getters.register(P.Mish)
def get_bprop_mish(self):
    """Grad definition for `Mish` operation."""
    tanh = P.Tanh()
    tanh_grad = SG.TanhGrad()
    softplus = P.Softplus()
    softplus_grad = G.SoftplusGrad()

    def bprop(x, out, dout):
        dx1 = tanh(softplus(x))
        dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
        dx = dx1 * dout + dx2
        return (dx,)

    return bprop
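
A minimal NumPy sketch (not part of this commit) that checks the gradient identity implemented above against a finite-difference estimate. It assumes the usual semantics TanhGrad(y, dy) = dy * (1 - y**2) and SoftplusGrad(dy, x) = dy * sigmoid(x), matching how the two grad ops are composed in bprop:

import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def mish(x):
    # Mish forward: x * tanh(softplus(x))
    return x * np.tanh(softplus(x))

def mish_grad(x, dout):
    # Mirrors bprop: dx1 = tanh(softplus(x));
    # dx2 = SoftplusGrad(TanhGrad(dx1, x * dout), x)
    dx1 = np.tanh(softplus(x))
    dx2 = sigmoid(x) * (x * dout) * (1.0 - dx1 ** 2)
    return dx1 * dout + dx2

x = np.array([-1.0, 0.5, 2.0])
eps = 1e-6
numeric = (mish(x + eps) - mish(x - eps)) / (2 * eps)
assert np.allclose(mish_grad(x, np.ones_like(x)), numeric, atol=1e-4)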

@bprop_getters.register(P.SeLU)
def get_bprop_selu(self):
    """Grad definition for `SeLU` operation."""
    scale = 1.0507009873554804934193349852946
    elu_grad = G.EluGrad()

    def bprop(x, out, dout):
        dx = elu_grad(dout, out) * scale
        return (dx,)

    return bprop
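
A minimal NumPy sketch (not part of this commit) of the analytic SeLU derivative, assuming the standard alpha constant that pairs with the scale used above, verified against finite differences:

import numpy as np

SCALE = 1.0507009873554804934193349852946
ALPHA = 1.6732632423543772848170429916717  # standard SELU alpha (assumed here)

def selu(x):
    return SCALE * np.where(x >= 0, x, ALPHA * (np.exp(x) - 1.0))

def selu_grad(x, dout):
    # d/dx selu(x) = scale for x >= 0, scale * alpha * exp(x) for x < 0
    return dout * SCALE * np.where(x >= 0, 1.0, ALPHA * np.exp(x))

x = np.array([-1.0, 0.5, 2.0])
eps = 1e-6
numeric = (selu(x + eps) - selu(x - eps)) / (2 * eps)
assert np.allclose(selu_grad(x, np.ones_like(x)), numeric, atol=1e-4)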

@bprop_getters.register(P.MulNoNan)
def get_bprop_mul_no_nan(self):
    """Grad definition for `MulNoNan` operation."""
    mul_no_nan = P.MulNoNan()
    reduce_sum = P.ReduceSum()
    reshape = P.Reshape()

    def bprop(x, y, out, dout):
        x_shape = F.shape(x)
        y_shape = F.shape(y)
        dx = mul_no_nan(dout, y)
        dy = mul_no_nan(x, dout)
        broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
        if broadcast_x != ():
            dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
        if broadcast_y != ():
            dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
        return dx, dy

    return bprop
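
A minimal NumPy sketch (not part of this commit) of the broadcast-reduction step in get_bprop_mul_no_nan: the gradient flowing into a broadcast input is summed over the broadcast axes (what reduce_sum plus reshape do above) so each input's gradient matches that input's own shape:

import numpy as np

def mul_no_nan(a, b):
    # Forward semantics: a * b, but 0 wherever b == 0 (even if a is inf/nan).
    with np.errstate(divide='ignore', invalid='ignore'):
        out = np.multiply(a, b)
        out[b == 0] = 0
    return out

x = np.random.rand(3, 1, 5)          # axis 1 is broadcast against y
y = np.random.rand(3, 4, 5)
dout = np.ones((3, 4, 5))            # incoming gradient, broadcast shape

dx = mul_no_nan(dout, y)             # shape (3, 4, 5): too large for x
dx = dx.sum(axis=1, keepdims=True)   # sum over the broadcast axis -> (3, 1, 5)
assert dx.shape == x.shape

dy = mul_no_nan(x, dout)             # already (3, 4, 5) after broadcasting
assert dy.shape == y.shape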

@bprop_getters.register(P.ReLU)
def get_bprop_relu(self):
    """Grad definition for `ReLU` operation."""
@@ -356,3 +356,6 @@ from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe
from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe
from .nll_loss import _nll_loss_tbe
from .nll_loss_grad import _nll_loss_grad_tbe
from .mish import _mish_tbe
from .mul_no_nan import _mul_no_nan_tbe
from .selu import _selu_tbe
@@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Mish op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

mish_op_info = TBERegOp("Mish") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("mish.so") \
    .compute_cost(10) \
    .kernel_name("mish") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(mish_op_info)
def _mish_tbe():
    """Mish TBE register"""
    return
@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MulNoNan op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

mul_no_nan_op_info = TBERegOp("MulNoNan") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("mul_no_nan.so") \
    .compute_cost(10) \
    .kernel_name("mul_no_nan") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .get_op_info()


@op_info_register(mul_no_nan_op_info)
def _mul_no_nan_tbe():
    """MulNoNan TBE register"""
    return
@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Selu op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

selu_op_info = TBERegOp("Selu") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("selu.so") \
    .compute_cost(10) \
    .kernel_name("selu") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.I8_None, DataType.I8_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(selu_op_info)
def _selu_tbe():
    """Selu TBE register"""
    return
@@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
                        Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
                        Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
-                       LogicalNot, LogicalOr, MatMul, Maximum,
+                       LogicalNot, LogicalOr, MatMul, Maximum, MulNoNan,
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
                        NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,

@@ -70,8 +70,8 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                      LogSoftmax,
                      MaxPool, DataFormatDimMap,
                      AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
-                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
-                     ResizeBilinear, Sigmoid,
+                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
+                     ResizeBilinear, Sigmoid, SeLU,
                      SigmoidCrossEntropyWithLogits, NLLLoss,
                      SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
                      SoftmaxCrossEntropyWithLogits, ROIAlign,
@@ -194,6 +194,9 @@ __all__ = [
    'ZerosLike',
    'Select',
    'Split',
    'Mish',
    'SeLU',
    'MulNoNan',
    'ReLU',
    'ReLU6',
    'Elu',
@@ -2035,6 +2035,58 @@ class DivNoNan(_MathBinaryOp):
        return None


class MulNoNan(_MathBinaryOp):
    r"""
    Computes x * y element-wise. If y is zero, the result is zero regardless of x,
    even when x is infinity or NaN.

    Inputs of `input_x` and `input_y` comply with the implicit type conversion rules to make the data types
    consistent. The inputs must be two tensors, or one tensor and one scalar.
    When the inputs are two tensors, their shapes can be broadcast.
    When the inputs are one tensor and one scalar, the scalar can only be a constant.

    Note:
        The shapes of `input_x` and `input_y` should be the same or broadcastable.

    Inputs:
        - **input_x** (Tensor) - The first input is a tensor whose data type is number.
        - **input_y** (Tensor) - The second input is a tensor whose data type is number.

    Outputs:
        Tensor, the shape is the same as the one after broadcasting,
        and the data type is the one with higher precision or higher digits among the two inputs.

    Supported Platforms:
        ``Ascend``

    Raises:
        TypeError: If `input_x` or `input_y` is a bool tensor.

    Examples:
        >>> x = Tensor(np.array([[-1.0, 6.0, np.inf], [np.nan, -7.0, 4.0]]), ms.float32)
        >>> y = Tensor(np.array([[-1.0, 4.0, 0], [0, -3.0, 1.0]]), ms.float32)
        >>> mul_no_nan = ops.MulNoNan()
        >>> output = mul_no_nan(x, y)
        >>> print(output)
        [[ 1. 24.  0.]
         [ 0. 21.  4.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MulNoNan"""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

    def infer_value(self, x, y):
        if x is not None and y is not None:
            x = x.asnumpy()
            y = y.asnumpy()
            with np.errstate(divide='ignore', invalid='ignore'):
                out = np.multiply(x, y)
                out[y == 0] = 0
            return out
        return None
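
A minimal NumPy sketch (not part of this commit) mirroring the infer_value semantics above: wherever y == 0 the result is forced to 0, even where ordinary multiplication of inf or nan by 0 would yield nan:

import numpy as np

x = np.array([[-1.0, 6.0, np.inf], [np.nan, -7.0, 4.0]], dtype=np.float32)
y = np.array([[-1.0, 4.0, 0.0], [0.0, -3.0, 1.0]], dtype=np.float32)

with np.errstate(divide='ignore', invalid='ignore'):
    out = np.multiply(x, y)
    out[y == 0] = 0
print(out)
# [[ 1. 24.  0.]
#  [ 0. 21.  4.]]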


class FloorDiv(_MathBinaryOp):
    """
    Divides the first input tensor by the second input tensor element-wise and round down to the closest integer.

@@ -4041,6 +4093,7 @@ class LinSpace(PrimitiveWithInfer):
               'value': None}
        return out


class MatrixInverse(PrimitiveWithInfer):
    """
    Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown

@@ -329,6 +329,99 @@ class ReLU(PrimitiveWithCheck):
        validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)


class Mish(PrimitiveWithInfer):
    r"""
    Computes MISH of input tensors element-wise.

    The function is shown as follows:

    .. math::

        \text{output} = x * \tanh(\log(1 + \exp(x)))

    Inputs:
        - **x** (Tensor) - The input tensor. Only float16 and float32 are supported.

    Outputs:
        Tensor, with the same type and shape as `x`.

    Supported Platforms:
        ``Ascend``

    Raises:
        TypeError: If the data type of `x` is neither float16 nor float32.

    Examples:
        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> mish = ops.Mish()
        >>> output = mish(input_x)
        >>> print(output)
        [[-3.034014e-01  3.997413e+00 -2.682209e-03]
         [ 1.943959e+00 -3.357619e-02  8.999999e+00]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Mish"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
        return x_dtype
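
A quick NumPy sketch (not part of this commit) reproducing the docstring example, since Mish is x * tanh(softplus(x)):

import numpy as np

x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
out = x * np.tanh(np.log1p(np.exp(x)))
print(out)
# approximately [[-0.30340  3.99741 -0.00268]
#                [ 1.94396 -0.03358  9.00000]]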


class SeLU(PrimitiveWithInfer):
    r"""
    Computes SeLU (Scaled Exponential Linear Unit) of input tensors element-wise.

    The activation function is defined as:

    .. math::

        E_{i} =
        scale *
        \begin{cases}
        x_i, &\text{if } x_i \geq 0; \cr
        \text{alpha} * (\exp(x_i) - 1), &\text{otherwise.}
        \end{cases}

    where :math:`scale` is approximately 1.0507 and :math:`alpha` is approximately 1.6733.

    Inputs:
        - **input_x** (Tensor) - The input tensor.

    Outputs:
        Tensor, with the same type and shape as `input_x`.

    Supported Platforms:
        ``Ascend``

    Raises:
        TypeError: If the data type of `input_x` is not one of int8, int32, float16 or float32.

    Examples:
        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> selu = ops.SeLU()
        >>> output = selu(input_x)
        >>> print(output)
        [[-1.1113307  4.202804  -1.7575096]
         [ 2.101402  -1.7462534  9.456309 ]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SeLU"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
        return x_dtype
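
A quick NumPy sketch (not part of this commit) reproducing the SeLU docstring example with the standard constants (alpha is assumed; scale matches the bprop above):

import numpy as np

SCALE = 1.0507009873554804934193349852946
ALPHA = 1.6732632423543772848170429916717

x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
out = SCALE * np.where(x >= 0, x, ALPHA * (np.exp(x) - 1.0))
print(out)
# approximately [[-1.11133  4.20280 -1.75751]
#                [ 2.10140 -1.74625  9.45631]]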


class ReLU6(PrimitiveWithInfer):
    r"""
    Computes ReLU (Rectified Linear Unit) upper bounded by 6 of input tensors element-wise.

@@ -320,6 +320,42 @@ class CountNonZero(nn.Cell):
        return nonzero_num


class Mish(nn.Cell):
    """Mish net definition"""

    def __init__(self):
        super(Mish, self).__init__()
        self.mish = P.Mish()

    def construct(self, input_x):
        out = self.mish(input_x)
        return out


class SeLU(nn.Cell):
    """Selu net definition"""

    def __init__(self):
        super(SeLU, self).__init__()
        self.selu = P.SeLU()

    def construct(self, input_x):
        out = self.selu(input_x)
        return out


class MulNoNan(nn.Cell):
    """MulNoNan net definition"""

    def __init__(self):
        super(MulNoNan, self).__init__()
        self.mul_no_nan = P.MulNoNan()

    def construct(self, input_x, input_y):
        out = self.mul_no_nan(input_x, input_y)
        return out


class ScatterUpdate(nn.Cell):
    """ScatterUpdate net definition"""

@@ -1315,6 +1351,19 @@ test_case_math_ops = [
        Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
        Tensor(np.array([6, 1, 2, 3]), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)]}),
    ('Mish', {
        'block': Mish(),
        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
    ('SeLU', {
        'block': SeLU(),
        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
    ('MulNoNan', {
        'block': MulNoNan(),
        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32),
                        Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
    ('Rank', {
        'block': P.Rank(),
        'desc_inputs': [[2, 3]],