vm for BesselI0e and BesselI1e

This commit is contained in:
jiangjinsheng 2020-05-27 17:49:25 +08:00
parent 205cfec632
commit c3f681f0cf
8 changed files with 190 additions and 2 deletions

View File

@ -24,6 +24,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
from .grad_base import bprop_getters
from ..primitive import constexpr
from ..composite.multitype_ops import _constexpr_utils as const_utils
shape_op = P.Shape()
reduce_sum = P.ReduceSum()
@ -875,3 +876,39 @@ def get_bprop_atan2(self):
return binop_grad_common(x, y, bc_dx, bc_dy)
return bprop
@bprop_getters.register(P.BesselI0e)
def get_bprop_bessel_i0e(self):
"""Generate bprop for BesselI0e"""
sign = P.Sign()
bessel_i1e = P.BesselI1e()
def bprop(x, out, dout):
dx = dout * (bessel_i1e(x) - sign(x) * out)
return (dx,)
return bprop
@bprop_getters.register(P.BesselI1e)
def get_bprop_bessel_i1e(self):
"""Generate bprop for BesselI1e"""
sign = P.Sign()
bessel_i0e = P.BesselI0e()
less = P.Less()
select = P.Select()
reciprocal = P.Reciprocal()
cast = P.Cast()
dtype = P.DType()
def bprop(x, out, dout):
zeros = zeros_like(x)
np_eps = const_utils.get_np_eps(dtype(x))
eps = cast(np_eps, dtype(x))
x_is_valid = less(eps, x)
x_safe = select(x_is_valid, x, eps + zeros)
tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe))
dx = select(x_is_valid, tmp, 0.5 + zeros)
return (dx,)
return bprop

View File

@ -200,6 +200,8 @@ from .reduce_prod import _reduce_prod_tbe
from .flatten_grad import _flatten_grad_tbe
from .scatter_add import _scatter_add_tbe
from .atan2 import _atan2_tbe
from .bessel_i0e import _bessel_i0e_tbe
from .bessel_i1e import _bessel_i1e_tbe
from .batch_to_space_nd import _batch_to_space_nd_tbe
from .space_to_batch_nd import _space_to_batch_nd_tbe
from .bitwise_and import bitwise_and_op_info

View File

@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselI0e op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
bessel_i0e_op_info = TBERegOp("BesselI0e") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("bessel_i0e.so") \
.compute_cost(10) \
.kernel_name("bessel_i0e") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(bessel_i0e_op_info)
def _bessel_i0e_tbe():
"""BesselI0e TBE register"""
return

View File

@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselI1e op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
bessel_i1e_op_info = TBERegOp("BesselI1e") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("bessel_i1e.so") \
.compute_cost(10) \
.kernel_name("bessel_i1e") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(bessel_i1e_op_info)
def _bessel_i1e_tbe():
"""BesselI1e TBE register"""
return

View File

@ -631,3 +631,10 @@ def scalar_in_sequence(x, y):
if x in y:
return True
return False
@constexpr
def get_np_eps(input_dtype):
nptype = mstype.dtype_to_nptype(input_dtype)
eps = np.finfo(nptype).eps
return float(eps)

View File

@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul
NPUAllocFloatStatus, NPUClearFloatStatus,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum,
Sin, Sqrt, Rsqrt,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll)
from .random_ops import (RandomChoiceWithMask)
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
@ -270,7 +270,9 @@ __all__ = [
"SquareSumAll",
"BitwiseAnd",
"BitwiseOr",
"BitwiseXor"
"BitwiseXor",
"BesselI0e",
"BesselI1e",
]
__all__.extend(_quant_ops.__all__)

View File

@ -2265,3 +2265,61 @@ class BitwiseXor(_BitwiseBinaryOp):
>>> bitwise_xor(input_x1, input_x2)
[0, 1, 0, 0, -2, 3, 2]
"""
class BesselI0e(PrimitiveWithInfer):
"""
Computes BesselI0e of input element-wise.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor, has the same shape as `input_x`.
Examples:
>>> bessel_i0e = P.BesselI0e()
>>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
>>> output = bessel_i0e(input_x)
[0.7979961, 0.5144438, 0.75117415, 0.9157829]
"""
@prim_attr_register
def __init__(self):
"""init BesselI0e"""
def infer_shape(self, x):
return x
def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x
class BesselI1e(PrimitiveWithInfer):
"""
Computes BesselI1e of input element-wise.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor, has the same shape as `input_x`.
Examples:
>>> bessel_i1e = P.BesselI1e()
>>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
>>> output = bessel_i1e(input_x)
[0.09507662, 0.19699717, 0.11505538, 0.04116856]
"""
@prim_attr_register
def __init__(self):
"""init BesselI1e"""
def infer_shape(self, x):
return x
def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x

View File

@ -656,6 +656,14 @@ test_case_math_ops = [
'desc_const': [1],
'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))],
'desc_bprop': []}),
('BesselI0e', {
'block': P.BesselI0e(),
'desc_inputs': [[2, 3]],
'desc_bprop': [[2, 3]]}),
('BesselI1e', {
'block': P.BesselI1e(),
'desc_inputs': [[2, 3]],
'desc_bprop': [[2, 3]]}),
]
test_case_nn_ops = [