diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index ee71979f28d..04ebe823131 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -24,6 +24,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..functional import broadcast_gradient_args, reduced_shape, tuple_div from .grad_base import bprop_getters from ..primitive import constexpr +from ..composite.multitype_ops import _constexpr_utils as const_utils shape_op = P.Shape() reduce_sum = P.ReduceSum() @@ -875,3 +876,39 @@ def get_bprop_atan2(self): return binop_grad_common(x, y, bc_dx, bc_dy) return bprop + + +@bprop_getters.register(P.BesselI0e) +def get_bprop_bessel_i0e(self): + """Generate bprop for BesselI0e""" + sign = P.Sign() + bessel_i1e = P.BesselI1e() + + def bprop(x, out, dout): + dx = dout * (bessel_i1e(x) - sign(x) * out) + return (dx,) + return bprop + + +@bprop_getters.register(P.BesselI1e) +def get_bprop_bessel_i1e(self): + """Generate bprop for BesselI1e""" + + sign = P.Sign() + bessel_i0e = P.BesselI0e() + less = P.Less() + select = P.Select() + reciprocal = P.Reciprocal() + cast = P.Cast() + dtype = P.DType() + + def bprop(x, out, dout): + zeros = zeros_like(x) + np_eps = const_utils.get_np_eps(dtype(x)) + eps = cast(np_eps, dtype(x)) + x_is_valid = less(eps, x) + x_safe = select(x_is_valid, x, eps + zeros) + tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe)) + dx = select(x_is_valid, tmp, 0.5 + zeros) + return (dx,) + return bprop diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index cd90cc9f8f2..7ad6bfe4d13 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -200,6 +200,8 @@ from .reduce_prod import _reduce_prod_tbe from .flatten_grad import _flatten_grad_tbe from .scatter_add import _scatter_add_tbe from .atan2 import _atan2_tbe +from .bessel_i0e import _bessel_i0e_tbe +from .bessel_i1e import _bessel_i1e_tbe from .batch_to_space_nd import _batch_to_space_nd_tbe from .space_to_batch_nd import _space_to_batch_nd_tbe from .bitwise_and import bitwise_and_op_info diff --git a/mindspore/ops/_op_impl/tbe/bessel_i0e.py b/mindspore/ops/_op_impl/tbe/bessel_i0e.py new file mode 100644 index 00000000000..ad0030d93ad --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/bessel_i0e.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BesselI0e op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +bessel_i0e_op_info = TBERegOp("BesselI0e") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("bessel_i0e.so") \ + .compute_cost(10) \ + .kernel_name("bessel_i0e") \ + .partial_flag(True) \ + .op_pattern("formatAgnostic") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(bessel_i0e_op_info) +def _bessel_i0e_tbe(): + """BesselI0e TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/bessel_i1e.py b/mindspore/ops/_op_impl/tbe/bessel_i1e.py new file mode 100644 index 00000000000..39abb5dad4d --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/bessel_i1e.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BesselI1e op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +bessel_i1e_op_info = TBERegOp("BesselI1e") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("bessel_i1e.so") \ + .compute_cost(10) \ + .kernel_name("bessel_i1e") \ + .partial_flag(True) \ + .op_pattern("formatAgnostic") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(bessel_i1e_op_info) +def _bessel_i1e_tbe(): + """BesselI1e TBE register""" + return diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 796933c6f17..e4d42aed033 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -631,3 +631,10 @@ def scalar_in_sequence(x, y): if x in y: return True return False + + +@constexpr +def get_np_eps(input_dtype): + nptype = mstype.dtype_to_nptype(input_dtype) + eps = np.finfo(nptype).eps + return float(eps) diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 6805f14b7ea..5462c783144 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul NPUAllocFloatStatus, NPUClearFloatStatus, NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, Reciprocal, CumSum, - Sin, Sqrt, Rsqrt, + Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, Square, Sub, TensorAdd, Sign, Round, SquareSumAll) from .random_ops import (RandomChoiceWithMask) from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, @@ -270,7 +270,9 @@ __all__ = [ "SquareSumAll", "BitwiseAnd", "BitwiseOr", - "BitwiseXor" + "BitwiseXor", + "BesselI0e", + "BesselI1e", ] __all__.extend(_quant_ops.__all__) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index d7917b6c7d9..47467a2b87c 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -2265,3 +2265,61 @@ class BitwiseXor(_BitwiseBinaryOp): >>> bitwise_xor(input_x1, input_x2) [0, 1, 0, 0, -2, 3, 2] """ + + +class BesselI0e(PrimitiveWithInfer): + """ + Computes BesselI0e of input element-wise. + + Inputs: + - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + + Outputs: + Tensor, has the same shape as `input_x`. + + Examples: + >>> bessel_i0e = P.BesselI0e() + >>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32) + >>> output = bessel_i0e(input_x) + [0.7979961, 0.5144438, 0.75117415, 0.9157829] + """ + + @prim_attr_register + def __init__(self): + """init BesselI0e""" + + def infer_shape(self, x): + return x + + def infer_dtype(self, x): + validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) + return x + + +class BesselI1e(PrimitiveWithInfer): + """ + Computes BesselI1e of input element-wise. + + Inputs: + - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + + Outputs: + Tensor, has the same shape as `input_x`. + + Examples: + >>> bessel_i1e = P.BesselI1e() + >>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32) + >>> output = bessel_i1e(input_x) + [0.09507662, 0.19699717, 0.11505538, 0.04116856] + """ + + @prim_attr_register + def __init__(self): + """init BesselI1e""" + + def infer_shape(self, x): + return x + + def infer_dtype(self, x): + validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) + return x diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 842c632842c..b41cbdaeb99 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -656,6 +656,14 @@ test_case_math_ops = [ 'desc_const': [1], 'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))], 'desc_bprop': []}), + ('BesselI0e', { + 'block': P.BesselI0e(), + 'desc_inputs': [[2, 3]], + 'desc_bprop': [[2, 3]]}), + ('BesselI1e', { + 'block': P.BesselI1e(), + 'desc_inputs': [[2, 3]], + 'desc_bprop': [[2, 3]]}), ] test_case_nn_ops = [