forked from mindspore-Ecosystem/mindspore
!11174 add_mish_mulnonan_selu_operations
From: @jiangzg001 Reviewed-by: @liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
8170669909
|
@ -416,6 +416,58 @@ def get_bprop_dropout_do_mask(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.Mish)
|
||||||
|
def get_bprop_mish(self):
|
||||||
|
"""Grad definition for `Mish` operation."""
|
||||||
|
tanh = P.Tanh()
|
||||||
|
tanh_grad = SG.TanhGrad()
|
||||||
|
softplus = P.Softplus()
|
||||||
|
softplus_grad = G.SoftplusGrad()
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
dx1 = tanh(softplus(x))
|
||||||
|
dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
|
||||||
|
dx = (dx1 * dout + dx2)
|
||||||
|
return (dx,)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.SeLU)
|
||||||
|
def get_bprop_selu(self):
|
||||||
|
"""Grad definition for `SeLU` operation."""
|
||||||
|
scale = 1.0507009873554804934193349852946
|
||||||
|
elu_grad = G.EluGrad()
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
dx = elu_grad(dout, out) * scale
|
||||||
|
return (dx,)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.MulNoNan)
|
||||||
|
def get_bprop_mul_no_nan(self):
|
||||||
|
"""Grad definition for `MulNoNan` operation."""
|
||||||
|
mul_no_nan = P.MulNoNan()
|
||||||
|
reduce_sum = P.ReduceSum()
|
||||||
|
reshape = P.Reshape()
|
||||||
|
|
||||||
|
def bprop(x, y, out, dout):
|
||||||
|
x_shape = F.shape(x)
|
||||||
|
y_shape = F.shape(y)
|
||||||
|
dx = mul_no_nan(dout, y)
|
||||||
|
dy = mul_no_nan(x, dout)
|
||||||
|
broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
|
||||||
|
if broadcast_x != ():
|
||||||
|
dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
|
||||||
|
if broadcast_y != ():
|
||||||
|
dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
|
||||||
|
return dx, dy
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.ReLU)
|
@bprop_getters.register(P.ReLU)
|
||||||
def get_bprop_relu(self):
|
def get_bprop_relu(self):
|
||||||
"""Grad definition for `ReLU` operation."""
|
"""Grad definition for `ReLU` operation."""
|
||||||
|
|
|
@ -356,3 +356,6 @@ from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe
|
||||||
from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe
|
from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe
|
||||||
from .nll_loss import _nll_loss_tbe
|
from .nll_loss import _nll_loss_tbe
|
||||||
from .nll_loss_grad import _nll_loss_grad_tbe
|
from .nll_loss_grad import _nll_loss_grad_tbe
|
||||||
|
from .mish import _mish_tbe
|
||||||
|
from .mul_no_nan import _mul_no_nan_tbe
|
||||||
|
from .selu import _selu_tbe
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Mish op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
mish_op_info = TBERegOp("Mish") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("mish.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("mish") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||||
|
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(mish_op_info)
|
||||||
|
def _mish_tbe():
|
||||||
|
"""Mish TBE register"""
|
||||||
|
return
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""MulNoNan op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
mul_no_nan_op_info = TBERegOp("MulNoNan") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("mul_no_nan.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("mul_no_nan") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.input(0, "x1", False, "required", "all") \
|
||||||
|
.input(1, "x2", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("broadcast") \
|
||||||
|
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
|
||||||
|
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
|
||||||
|
.dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(mul_no_nan_op_info)
|
||||||
|
def _mul_no_nan_tbe():
|
||||||
|
"""MulNoNan TBE register"""
|
||||||
|
return
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Selu op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
selu_op_info = TBERegOp("Selu") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("selu.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("selu") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", True, "required", "all") \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.dtype_format(DataType.I8_None, DataType.I8_None) \
|
||||||
|
.dtype_format(DataType.I32_None, DataType.I32_None) \
|
||||||
|
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||||
|
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(selu_op_info)
|
||||||
|
def _selu_tbe():
|
||||||
|
"""Selu TBE register"""
|
||||||
|
return
|
|
@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
||||||
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
|
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
|
||||||
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
|
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
|
||||||
Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
|
Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
|
||||||
LogicalNot, LogicalOr, MatMul, Maximum,
|
LogicalNot, LogicalOr, MatMul, Maximum, MulNoNan,
|
||||||
Minimum, Mul, Neg, NMSWithMask, NotEqual,
|
Minimum, Mul, Neg, NMSWithMask, NotEqual,
|
||||||
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
|
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
|
||||||
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
||||||
|
@ -70,8 +70,8 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
|
||||||
LogSoftmax,
|
LogSoftmax,
|
||||||
MaxPool, DataFormatDimMap,
|
MaxPool, DataFormatDimMap,
|
||||||
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
|
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
|
||||||
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
|
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
|
||||||
ResizeBilinear, Sigmoid,
|
ResizeBilinear, Sigmoid, SeLU,
|
||||||
SigmoidCrossEntropyWithLogits, NLLLoss,
|
SigmoidCrossEntropyWithLogits, NLLLoss,
|
||||||
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
|
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
|
||||||
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
||||||
|
@ -194,6 +194,9 @@ __all__ = [
|
||||||
'ZerosLike',
|
'ZerosLike',
|
||||||
'Select',
|
'Select',
|
||||||
'Split',
|
'Split',
|
||||||
|
'Mish',
|
||||||
|
'SeLU',
|
||||||
|
'MulNoNan',
|
||||||
'ReLU',
|
'ReLU',
|
||||||
'ReLU6',
|
'ReLU6',
|
||||||
'Elu',
|
'Elu',
|
||||||
|
|
|
@ -2035,6 +2035,58 @@ class DivNoNan(_MathBinaryOp):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MulNoNan(_MathBinaryOp):
|
||||||
|
r"""
|
||||||
|
Computes x * y element-wise. if y is zero, No matter what x is, it will return 0.
|
||||||
|
|
||||||
|
Inputs of `input_x` and `input_y` comply with the implicit type conversion rules to make the data types consistent.
|
||||||
|
The inputs must be two tensors or one tensor and one scalar.
|
||||||
|
When the inputs are two tensors, the shapes of them could be broadcast.
|
||||||
|
When the inputs are one tensor and one scalar, the scalar could only be a constant.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The shapes of X and y should be same or can be broadcasting.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
|
||||||
|
- **input_y** (Union[Tensor]) - The second input is a tensor whose data type is number.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the shape is the same as the one after broadcasting,
|
||||||
|
and the data type is the one with higher precision or higher digits among the two inputs.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend``
|
||||||
|
|
||||||
|
Raise:
|
||||||
|
TypeError: If x or y is a bool tensor.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> x = Tensor(np.array([[-1.0, 6.0, np.inf], [np.nan, -7.0, 4.0]]), ms.float32)
|
||||||
|
>>> y = Tensor(np.array([[-1.0, 4.0, 0], [0, -3.0, 1.0]]), ms.float32)
|
||||||
|
>>> mul_no_nan = ops.MulNoNan()
|
||||||
|
>>> output = mul_no_nan(x, y)
|
||||||
|
>>> print(output)
|
||||||
|
[[ 1. 24. 0.]
|
||||||
|
[ 0. 21. 4.]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize _BinaryOp"""
|
||||||
|
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||||
|
|
||||||
|
def infer_value(self, x, y):
|
||||||
|
if x is not None and y is not None:
|
||||||
|
x = x.asnumpy()
|
||||||
|
y = y.asnumpy()
|
||||||
|
with np.errstate(divide='ignore', invalid='ignore'):
|
||||||
|
out = np.multiply(x, y)
|
||||||
|
out[y == 0] = 0
|
||||||
|
return out
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class FloorDiv(_MathBinaryOp):
|
class FloorDiv(_MathBinaryOp):
|
||||||
"""
|
"""
|
||||||
Divides the first input tensor by the second input tensor element-wise and round down to the closest integer.
|
Divides the first input tensor by the second input tensor element-wise and round down to the closest integer.
|
||||||
|
@ -4041,6 +4093,7 @@ class LinSpace(PrimitiveWithInfer):
|
||||||
'value': None}
|
'value': None}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class MatrixInverse(PrimitiveWithInfer):
|
class MatrixInverse(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown
|
Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown
|
||||||
|
|
|
@ -329,6 +329,99 @@ class ReLU(PrimitiveWithCheck):
|
||||||
validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)
|
validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)
|
||||||
|
|
||||||
|
|
||||||
|
class Mish(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Computes MISH of input tensors element-wise.
|
||||||
|
|
||||||
|
The function is shown as follows:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
\text{output} = x * \tan(\log(1 + \exp(\text{x})))
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - The input tensor. Only support float16 and float32.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, with the same type and shape as the `x`.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend``
|
||||||
|
|
||||||
|
Raise:
|
||||||
|
TypeError: If num_features data type not float16 and float32 Tensor.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
|
||||||
|
>>> mish = ops.Mish()
|
||||||
|
>>> output = mish(input_x)
|
||||||
|
>>> print(output)
|
||||||
|
[[-3.034014e-01 3.997413e+00 -2.682209e-03]
|
||||||
|
[ 1.943959e+00 -3.357619e-02 8.999999e+00]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize Mish"""
|
||||||
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_dtype):
|
||||||
|
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
|
||||||
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class SeLU(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Computes SeLU (scaled exponential Linear Unit) of input tensors element-wise.
|
||||||
|
|
||||||
|
The activation function is defined as:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
E_{i} =
|
||||||
|
scale *
|
||||||
|
\begin{cases}
|
||||||
|
x, &\text{if } x \geq 0; \cr
|
||||||
|
\text{alpha} * (\exp(x_i) - 1), &\text{otherwise.}
|
||||||
|
\end{cases}
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The input tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, with the same type and shape as the `input_x`.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend``
|
||||||
|
|
||||||
|
Raise:
|
||||||
|
TypeError: If num_features data type not int8, int32, float16 and float32 Tensor.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
|
||||||
|
>>> selu = ops.SeLU()
|
||||||
|
>>> output = selu(input_x)
|
||||||
|
>>> print(output)
|
||||||
|
[[-1.1113307 4.202804 -1.7575096]
|
||||||
|
[ 2.101402 -1.7462534 9.456309 ]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize SeLU"""
|
||||||
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_dtype):
|
||||||
|
valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
|
||||||
|
validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
|
||||||
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
class ReLU6(PrimitiveWithInfer):
|
class ReLU6(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Computes ReLU (Rectified Linear Unit) upper bounded by 6 of input tensors element-wise.
|
Computes ReLU (Rectified Linear Unit) upper bounded by 6 of input tensors element-wise.
|
||||||
|
|
|
@ -320,6 +320,42 @@ class CountNonZero(nn.Cell):
|
||||||
return nonzero_num
|
return nonzero_num
|
||||||
|
|
||||||
|
|
||||||
|
class Mish(nn.Cell):
|
||||||
|
"""Mish net definition"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Mish, self).__init__()
|
||||||
|
self.mish = P.Mish()
|
||||||
|
|
||||||
|
def construct(self, input_x):
|
||||||
|
out = self.mish(input_x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SeLU(nn.Cell):
|
||||||
|
"""Selu net definition"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(SeLU, self).__init__()
|
||||||
|
self.selu = P.SeLU()
|
||||||
|
|
||||||
|
def construct(self, input_x):
|
||||||
|
out = self.selu(input_x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MulNoNan(nn.Cell):
|
||||||
|
"""MulNoNan net definition"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(MulNoNan, self).__init__()
|
||||||
|
self.mul_no_nan = P.MulNoNan()
|
||||||
|
|
||||||
|
def construct(self, input_x, input_y):
|
||||||
|
out = self.mul_no_nan(input_x, input_y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ScatterUpdate(nn.Cell):
|
class ScatterUpdate(nn.Cell):
|
||||||
"""ScatterUpdate net definition"""
|
"""ScatterUpdate net definition"""
|
||||||
|
|
||||||
|
@ -1315,6 +1351,19 @@ test_case_math_ops = [
|
||||||
Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
|
Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
|
||||||
Tensor(np.array([6, 1, 2, 3]), mstype.float32)],
|
Tensor(np.array([6, 1, 2, 3]), mstype.float32)],
|
||||||
'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)]}),
|
'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)]}),
|
||||||
|
('Mish', {
|
||||||
|
'block': Mish(),
|
||||||
|
'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
|
||||||
|
'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
|
||||||
|
('SeLU', {
|
||||||
|
'block': SeLU(),
|
||||||
|
'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
|
||||||
|
'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
|
||||||
|
('MulNoNan', {
|
||||||
|
'block': MulNoNan(),
|
||||||
|
'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32),
|
||||||
|
Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
|
||||||
|
'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
|
||||||
('Rank', {
|
('Rank', {
|
||||||
'block': P.Rank(),
|
'block': P.Rank(),
|
||||||
'desc_inputs': [[2, 3]],
|
'desc_inputs': [[2, 3]],
|
||||||
|
|
Loading…
Reference in New Issue