forked from mindspore-Ecosystem/mindspore
vm for BesselI0e and BesselI1e
This commit is contained in:
parent
205cfec632
commit
c3f681f0cf
|
@ -24,6 +24,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
|
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
|
||||||
from .grad_base import bprop_getters
|
from .grad_base import bprop_getters
|
||||||
from ..primitive import constexpr
|
from ..primitive import constexpr
|
||||||
|
from ..composite.multitype_ops import _constexpr_utils as const_utils
|
||||||
|
|
||||||
shape_op = P.Shape()
|
shape_op = P.Shape()
|
||||||
reduce_sum = P.ReduceSum()
|
reduce_sum = P.ReduceSum()
|
||||||
|
@ -875,3 +876,39 @@ def get_bprop_atan2(self):
|
||||||
return binop_grad_common(x, y, bc_dx, bc_dy)
|
return binop_grad_common(x, y, bc_dx, bc_dy)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.BesselI0e)
|
||||||
|
def get_bprop_bessel_i0e(self):
|
||||||
|
"""Generate bprop for BesselI0e"""
|
||||||
|
sign = P.Sign()
|
||||||
|
bessel_i1e = P.BesselI1e()
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
dx = dout * (bessel_i1e(x) - sign(x) * out)
|
||||||
|
return (dx,)
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.BesselI1e)
|
||||||
|
def get_bprop_bessel_i1e(self):
|
||||||
|
"""Generate bprop for BesselI1e"""
|
||||||
|
|
||||||
|
sign = P.Sign()
|
||||||
|
bessel_i0e = P.BesselI0e()
|
||||||
|
less = P.Less()
|
||||||
|
select = P.Select()
|
||||||
|
reciprocal = P.Reciprocal()
|
||||||
|
cast = P.Cast()
|
||||||
|
dtype = P.DType()
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
zeros = zeros_like(x)
|
||||||
|
np_eps = const_utils.get_np_eps(dtype(x))
|
||||||
|
eps = cast(np_eps, dtype(x))
|
||||||
|
x_is_valid = less(eps, x)
|
||||||
|
x_safe = select(x_is_valid, x, eps + zeros)
|
||||||
|
tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe))
|
||||||
|
dx = select(x_is_valid, tmp, 0.5 + zeros)
|
||||||
|
return (dx,)
|
||||||
|
return bprop
|
||||||
|
|
|
@ -200,6 +200,8 @@ from .reduce_prod import _reduce_prod_tbe
|
||||||
from .flatten_grad import _flatten_grad_tbe
|
from .flatten_grad import _flatten_grad_tbe
|
||||||
from .scatter_add import _scatter_add_tbe
|
from .scatter_add import _scatter_add_tbe
|
||||||
from .atan2 import _atan2_tbe
|
from .atan2 import _atan2_tbe
|
||||||
|
from .bessel_i0e import _bessel_i0e_tbe
|
||||||
|
from .bessel_i1e import _bessel_i1e_tbe
|
||||||
from .batch_to_space_nd import _batch_to_space_nd_tbe
|
from .batch_to_space_nd import _batch_to_space_nd_tbe
|
||||||
from .space_to_batch_nd import _space_to_batch_nd_tbe
|
from .space_to_batch_nd import _space_to_batch_nd_tbe
|
||||||
from .bitwise_and import bitwise_and_op_info
|
from .bitwise_and import bitwise_and_op_info
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""BesselI0e op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
bessel_i0e_op_info = TBERegOp("BesselI0e") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("bessel_i0e.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("bessel_i0e") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(bessel_i0e_op_info)
|
||||||
|
def _bessel_i0e_tbe():
|
||||||
|
"""BesselI0e TBE register"""
|
||||||
|
return
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""BesselI1e op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
bessel_i1e_op_info = TBERegOp("BesselI1e") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("bessel_i1e.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("bessel_i1e") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(bessel_i1e_op_info)
|
||||||
|
def _bessel_i1e_tbe():
|
||||||
|
"""BesselI1e TBE register"""
|
||||||
|
return
|
|
@ -631,3 +631,10 @@ def scalar_in_sequence(x, y):
|
||||||
if x in y:
|
if x in y:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def get_np_eps(input_dtype):
|
||||||
|
nptype = mstype.dtype_to_nptype(input_dtype)
|
||||||
|
eps = np.finfo(nptype).eps
|
||||||
|
return float(eps)
|
||||||
|
|
|
@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul
|
||||||
NPUAllocFloatStatus, NPUClearFloatStatus,
|
NPUAllocFloatStatus, NPUClearFloatStatus,
|
||||||
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
||||||
Reciprocal, CumSum,
|
Reciprocal, CumSum,
|
||||||
Sin, Sqrt, Rsqrt,
|
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
|
||||||
Square, Sub, TensorAdd, Sign, Round, SquareSumAll)
|
Square, Sub, TensorAdd, Sign, Round, SquareSumAll)
|
||||||
from .random_ops import (RandomChoiceWithMask)
|
from .random_ops import (RandomChoiceWithMask)
|
||||||
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
||||||
|
@ -270,7 +270,9 @@ __all__ = [
|
||||||
"SquareSumAll",
|
"SquareSumAll",
|
||||||
"BitwiseAnd",
|
"BitwiseAnd",
|
||||||
"BitwiseOr",
|
"BitwiseOr",
|
||||||
"BitwiseXor"
|
"BitwiseXor",
|
||||||
|
"BesselI0e",
|
||||||
|
"BesselI1e",
|
||||||
]
|
]
|
||||||
|
|
||||||
__all__.extend(_quant_ops.__all__)
|
__all__.extend(_quant_ops.__all__)
|
||||||
|
|
|
@ -2265,3 +2265,61 @@ class BitwiseXor(_BitwiseBinaryOp):
|
||||||
>>> bitwise_xor(input_x1, input_x2)
|
>>> bitwise_xor(input_x1, input_x2)
|
||||||
[0, 1, 0, 0, -2, 3, 2]
|
[0, 1, 0, 0, -2, 3, 2]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BesselI0e(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Computes BesselI0e of input element-wise.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, has the same shape as `input_x`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> bessel_i0e = P.BesselI0e()
|
||||||
|
>>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
|
||||||
|
>>> output = bessel_i0e(input_x)
|
||||||
|
[0.7979961, 0.5144438, 0.75117415, 0.9157829]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init BesselI0e"""
|
||||||
|
|
||||||
|
def infer_shape(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def infer_dtype(self, x):
|
||||||
|
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BesselI1e(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Computes BesselI1e of input element-wise.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, has the same shape as `input_x`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> bessel_i1e = P.BesselI1e()
|
||||||
|
>>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
|
||||||
|
>>> output = bessel_i1e(input_x)
|
||||||
|
[0.09507662, 0.19699717, 0.11505538, 0.04116856]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init BesselI1e"""
|
||||||
|
|
||||||
|
def infer_shape(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def infer_dtype(self, x):
|
||||||
|
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
|
||||||
|
return x
|
||||||
|
|
|
@ -656,6 +656,14 @@ test_case_math_ops = [
|
||||||
'desc_const': [1],
|
'desc_const': [1],
|
||||||
'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))],
|
'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))],
|
||||||
'desc_bprop': []}),
|
'desc_bprop': []}),
|
||||||
|
('BesselI0e', {
|
||||||
|
'block': P.BesselI0e(),
|
||||||
|
'desc_inputs': [[2, 3]],
|
||||||
|
'desc_bprop': [[2, 3]]}),
|
||||||
|
('BesselI1e', {
|
||||||
|
'block': P.BesselI1e(),
|
||||||
|
'desc_inputs': [[2, 3]],
|
||||||
|
'desc_bprop': [[2, 3]]}),
|
||||||
]
|
]
|
||||||
|
|
||||||
test_case_nn_ops = [
|
test_case_nn_ops = [
|
||||||
|
|
Loading…
Reference in New Issue