vm for BesselI0e and BesselI1e
parent 205cfec632
commit c3f681f0cf
@@ -24,6 +24,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
 from .grad_base import bprop_getters
 from ..primitive import constexpr
+from ..composite.multitype_ops import _constexpr_utils as const_utils

 shape_op = P.Shape()
 reduce_sum = P.ReduceSum()
@@ -875,3 +876,39 @@ def get_bprop_atan2(self):
         return binop_grad_common(x, y, bc_dx, bc_dy)

     return bprop
+
+
+@bprop_getters.register(P.BesselI0e)
+def get_bprop_bessel_i0e(self):
+    """Generate bprop for BesselI0e"""
+    sign = P.Sign()
+    bessel_i1e = P.BesselI1e()
+
+    def bprop(x, out, dout):
+        dx = dout * (bessel_i1e(x) - sign(x) * out)
+        return (dx,)
+    return bprop
+
+
+@bprop_getters.register(P.BesselI1e)
+def get_bprop_bessel_i1e(self):
+    """Generate bprop for BesselI1e"""
+
+    sign = P.Sign()
+    bessel_i0e = P.BesselI0e()
+    less = P.Less()
+    select = P.Select()
+    reciprocal = P.Reciprocal()
+    cast = P.Cast()
+    dtype = P.DType()
+
+    def bprop(x, out, dout):
+        zeros = zeros_like(x)
+        np_eps = const_utils.get_np_eps(dtype(x))
+        eps = cast(np_eps, dtype(x))
+        x_is_valid = less(eps, x)
+        x_safe = select(x_is_valid, x, eps + zeros)
+        tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe))
+        dx = select(x_is_valid, tmp, 0.5 + zeros)
+        return (dx,)
+    return bprop
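The gradient formulas in this hunk follow from the standard definitions of the exponentially scaled Bessel functions, :math:`I0e(x) = e^{-|x|} I_0(x)` and :math:`I1e(x) = e^{-|x|} I_1(x)`, together with the identities :math:`I_0'(x) = I_1(x)` and :math:`I_1'(x) = I_0(x) - I_1(x)/x`:

\frac{d}{dx}\, I0e(x) = I1e(x) - \mathrm{sign}(x)\, I0e(x)

\frac{d}{dx}\, I1e(x) = I0e(x) - \left(\mathrm{sign}(x) + \frac{1}{x}\right) I1e(x)

The second expression is indeterminate at x = 0; since :math:`I_1(x) \approx x/2` near zero, the derivative tends to 1/2 there. That is what the BesselI1e bprop implements: `eps` is substituted for tiny inputs before the `Reciprocal`, and the `select` on `x_is_valid` returns `0.5` for those entries.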
@@ -200,6 +200,8 @@ from .reduce_prod import _reduce_prod_tbe
 from .flatten_grad import _flatten_grad_tbe
 from .scatter_add import _scatter_add_tbe
 from .atan2 import _atan2_tbe
+from .bessel_i0e import _bessel_i0e_tbe
+from .bessel_i1e import _bessel_i1e_tbe
 from .batch_to_space_nd import _batch_to_space_nd_tbe
 from .space_to_batch_nd import _space_to_batch_nd_tbe
 from .bitwise_and import bitwise_and_op_info
@@ -0,0 +1,37 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BesselI0e op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+bessel_i0e_op_info = TBERegOp("BesselI0e") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("bessel_i0e.so") \
+    .compute_cost(10) \
+    .kernel_name("bessel_i0e") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(bessel_i0e_op_info)
+def _bessel_i0e_tbe():
+    """BesselI0e TBE register"""
+    return
@@ -0,0 +1,37 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BesselI1e op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+bessel_i1e_op_info = TBERegOp("BesselI1e") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("bessel_i1e.so") \
+    .compute_cost(10) \
+    .kernel_name("bessel_i1e") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(bessel_i1e_op_info)
+def _bessel_i1e_tbe():
+    """BesselI1e TBE register"""
+    return
@@ -631,3 +631,10 @@ def scalar_in_sequence(x, y):
     if x in y:
         return True
     return False
+
+
+@constexpr
+def get_np_eps(input_dtype):
+    nptype = mstype.dtype_to_nptype(input_dtype)
+    eps = np.finfo(nptype).eps
+    return float(eps)
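For context, a minimal NumPy sketch of the guard pattern that `get_np_eps` supports in the BesselI1e bprop (the helper name `safe_reciprocal` is hypothetical, not part of this code): inputs at or below machine epsilon are replaced by epsilon before the division, and a fallback value is selected afterwards.

import numpy as np


def safe_reciprocal(x, dtype=np.float32):
    """Illustrative NumPy analogue of the eps guard used around Reciprocal in the bprop."""
    x = np.asarray(x, dtype=dtype)
    eps = float(np.finfo(dtype).eps)        # what get_np_eps would return for this dtype
    x_is_valid = eps < x                    # same predicate as less(eps, x)
    x_safe = np.where(x_is_valid, x, eps)   # never feed ~0 into the division
    return np.where(x_is_valid, 1.0 / x_safe, 0.0)


print(safe_reciprocal([2.0, 0.0, 1e-30]))   # [0.5 0.  0. ]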
@@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul
                        NPUAllocFloatStatus, NPUClearFloatStatus,
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum,
-                       Sin, Sqrt, Rsqrt,
+                       Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
                        Square, Sub, TensorAdd, Sign, Round, SquareSumAll)
 from .random_ops import (RandomChoiceWithMask)
 from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
@@ -270,7 +270,9 @@ __all__ = [
     "SquareSumAll",
     "BitwiseAnd",
     "BitwiseOr",
-    "BitwiseXor"
+    "BitwiseXor",
+    "BesselI0e",
+    "BesselI1e",
 ]

 __all__.extend(_quant_ops.__all__)
@@ -2265,3 +2265,61 @@ class BitwiseXor(_BitwiseBinaryOp):
         >>> bitwise_xor(input_x1, input_x2)
         [0, 1, 0, 0, -2, 3, 2]
     """
+
+
+class BesselI0e(PrimitiveWithInfer):
+    """
+    Computes BesselI0e of input element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Outputs:
+        Tensor, has the same shape as `input_x`.
+
+    Examples:
+        >>> bessel_i0e = P.BesselI0e()
+        >>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
+        >>> output = bessel_i0e(input_x)
+        [0.7979961, 0.5144438, 0.75117415, 0.9157829]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init BesselI0e"""
+
+    def infer_shape(self, x):
+        return x
+
+    def infer_dtype(self, x):
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
+        return x
+
+
+class BesselI1e(PrimitiveWithInfer):
+    """
+    Computes BesselI1e of input element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Outputs:
+        Tensor, has the same shape as `input_x`.
+
+    Examples:
+        >>> bessel_i1e = P.BesselI1e()
+        >>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
+        >>> output = bessel_i1e(input_x)
+        [0.09507662, 0.19699717, 0.11505538, 0.04116856]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init BesselI1e"""
+
+    def infer_shape(self, x):
+        return x
+
+    def infer_dtype(self, x):
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
+        return x
@@ -656,6 +656,14 @@ test_case_math_ops = [
         'desc_const': [1],
         'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))],
         'desc_bprop': []}),
+    ('BesselI0e', {
+        'block': P.BesselI0e(),
+        'desc_inputs': [[2, 3]],
+        'desc_bprop': [[2, 3]]}),
+    ('BesselI1e', {
+        'block': P.BesselI1e(),
+        'desc_inputs': [[2, 3]],
+        'desc_bprop': [[2, 3]]}),
 ]

 test_case_nn_ops = [
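A small standalone sketch, assuming SciPy is available, that checks the two derivative formulas used in the bprops against central finite differences of scipy.special.i0e and scipy.special.i1e:

import numpy as np
from scipy.special import i0e, i1e

x = np.array([0.3, 1.7, 4.2], dtype=np.float64)
h = 1e-6

# d/dx i0e(x) = i1e(x) - sign(x) * i0e(x)
analytic_i0e = i1e(x) - np.sign(x) * i0e(x)
numeric_i0e = (i0e(x + h) - i0e(x - h)) / (2 * h)
assert np.allclose(analytic_i0e, numeric_i0e, atol=1e-5)

# d/dx i1e(x) = i0e(x) - (sign(x) + 1/x) * i1e(x)
analytic_i1e = i0e(x) - (np.sign(x) + 1.0 / x) * i1e(x)
numeric_i1e = (i1e(x + h) - i1e(x - h)) / (2 * h)
assert np.allclose(analytic_i1e, numeric_i1e, atol=1e-5)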