[feat][assistant][I4CRK7] add new Ascend operator BesselJ0 and BesselJ1
This commit is contained in:
parent
1c42f6dbed
commit
63f08d4b36
|
@ -341,6 +341,33 @@ def get_bprop_bessel_i1(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(math.BesselJ0)
|
||||
def get_bprop_bessel_j0(self):
|
||||
"""Generate bprop for BesselJ0"""
|
||||
bessel_j1 = math.BesselJ1()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = -dout * bessel_j1(x)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(math.BesselJ1)
|
||||
def get_bprop_bessel_j1(self):
|
||||
"""Generate bprop for BesselJ1"""
|
||||
equal = P.Equal()
|
||||
div = P.Div()
|
||||
cast = P.Cast()
|
||||
dtype = P.DType()
|
||||
bessel_j0 = math.BesselJ0()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dout_dx = mnp.where(equal(x, 0.), cast(0.5, dtype(x)), bessel_j0(x) - div(out, x))
|
||||
dx = dout * dout_dx
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(math.BesselK0)
|
||||
def get_bprop_bessel_k0(self):
|
||||
"""Generate bprop for BesselK0"""
|
||||
|
|
|
@ -21,6 +21,8 @@ from .batchnorm_fold2_grad_reduce import _batchnorm_fold2_grad_reduce_tbe
|
|||
from .batchnorm_fold_grad import _batchnorm_fold_grad_tbe
|
||||
from .bessel_i0 import _bessel_i0_tbe
|
||||
from .bessel_i1 import _bessel_i1_tbe
|
||||
from .bessel_j0 import _bessel_j0_tbe
|
||||
from .bessel_j1 import _bessel_j1_tbe
|
||||
from .bessel_k0 import _bessel_k0_tbe
|
||||
from .bessel_k1 import _bessel_k1_tbe
|
||||
from .bessel_k0e import _bessel_k0e_tbe
|
||||
|
|
|
@ -0,0 +1,226 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""BesselJ0 op"""
|
||||
|
||||
import te.lang.cce as tbe
|
||||
import te.platform as tbe_platform
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from tbe import dsl
|
||||
from tbe.common.utils import shape_util
|
||||
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bessel_j0_op_info = TBERegOp("BesselJ0") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("bessel_j0.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("bessel_j0") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(bessel_j0_op_info)
|
||||
def _bessel_j0_tbe():
|
||||
"""BesselJ0 TBE register"""
|
||||
return
|
||||
|
||||
FLOAT_16 = "float16"
|
||||
FLOAT_32 = "float32"
|
||||
|
||||
PP = [7.96936729297347051624E-4, 8.28352392107440799803E-2,
|
||||
1.23953371646414299388E0, 5.44725003058768775090E0,
|
||||
8.74716500199817011941E0, 5.30324038235394892183E0,
|
||||
9.99999999999999997821E-1]
|
||||
PQ = [9.24408810558863637013E-4, 8.56288474354474431428E-2,
|
||||
1.25352743901058953537E0, 5.47097740330417105182E0,
|
||||
8.76190883237069594232E0, 5.30605288235394617618E0,
|
||||
1.00000000000000000218E0]
|
||||
QP = [-1.13663838898469149931E-2, -1.28252718670509318512E0,
|
||||
-1.95539544257735972385E1, -9.32060152123768231369E1,
|
||||
-1.77681167980488050595E2, -1.47077505154951170175E2,
|
||||
-5.14105326766599330220E1, -6.05014350600728481186E0]
|
||||
QQ = [1.00000000000000000000E0, 6.43178256118178023184E1,
|
||||
8.56430025976980587198E2, 3.88240183605401609683E3,
|
||||
7.24046774195652478189E3, 5.93072701187316984827E3,
|
||||
2.06209331660327847417E3, 2.42005740240291393179E2]
|
||||
RP = [-4.79443220978201773821E9, 1.95617491946556577543E12,
|
||||
-2.49248344360967716204E14, 9.70862251047306323952E15]
|
||||
RQ = [1.00000000000000000000E0, 4.99563147152651017219E2,
|
||||
1.73785401676374683123E5, 4.84409658339962045305E7,
|
||||
1.11855537045356834862E10, 2.11277520115489217587E12,
|
||||
3.10518229857422583814E14, 3.18121955943204943306E16,
|
||||
1.71086294081043136091E18]
|
||||
|
||||
DR1 = 5.78318596294678452118E0
|
||||
DR2 = 3.04712623436620863991E1
|
||||
SQ2OPI = 7.9788456080286535587989E-1
|
||||
NEG_PIO4 = -0.7853981633974483096
|
||||
|
||||
PI = 3.14159265358979
|
||||
FIRST_ORDER = 5
|
||||
LAST_ORDER = 13
|
||||
FIRST_FACTOR = -1.0 / 6.0
|
||||
|
||||
|
||||
def besselj0_cos(x):
|
||||
"""cos"""
|
||||
dtype = x.dtype
|
||||
shape = shape_util.shape_to_list(x.shape)
|
||||
# cast to type float32 when type is float16
|
||||
has_improve_precision = False
|
||||
if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
|
||||
x = dsl.cast_to(x, FLOAT_32)
|
||||
dtype = FLOAT_32
|
||||
has_improve_precision = True
|
||||
# round the input
|
||||
round_fp16 = dsl.round(dsl.vmuls(x, 1.0 / (2 * PI)))
|
||||
round_fp32 = dsl.cast_to(round_fp16, dtype)
|
||||
input_x_round = dsl.vsub(x, dsl.vmuls(round_fp32, 2 * PI))
|
||||
# the initial value one
|
||||
const_res = tvm.const(1.0, dtype=dtype)
|
||||
res = dsl.broadcast(const_res, shape)
|
||||
# compute the rank 2
|
||||
input_x_power = dsl.vmul(input_x_round, input_x_round)
|
||||
iter_value = dsl.vmuls(input_x_power, -1.0/2.0)
|
||||
res = dsl.vadd(res, iter_value)
|
||||
# compute the rank 4~14
|
||||
iter_list = (4, 6, 8, 10, 12, 14)
|
||||
for i in iter_list:
|
||||
iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value), -1.0/(i*(i-1)))
|
||||
res = dsl.vadd(res, iter_value)
|
||||
# cast the dtype to float16
|
||||
if has_improve_precision:
|
||||
res = dsl.cast_to(res, "float16")
|
||||
return res
|
||||
|
||||
|
||||
def _besselj0_sin(x):
|
||||
"""_besselj0_sin"""
|
||||
input_x_power = dsl.vmul(x, x)
|
||||
iter_value = dsl.vmul(dsl.vmuls(input_x_power, FIRST_FACTOR), x)
|
||||
res = dsl.vadd(x, iter_value)
|
||||
signal = FIRST_ORDER
|
||||
while signal < LAST_ORDER:
|
||||
iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value),
|
||||
-1.0 / (signal*(signal - 1)))
|
||||
res = dsl.vadd(res, iter_value)
|
||||
signal = signal + 2
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def besselj0_sin(x):
|
||||
"""sin"""
|
||||
dtype = x.dtype
|
||||
shape = shape_util.shape_to_list(x.shape)
|
||||
has_improve_precision = False
|
||||
cast_dtype = FLOAT_16
|
||||
if tbe_platform.api_check_support("te.lang.cce.vmul", "float32"):
|
||||
has_improve_precision = True
|
||||
cast_dtype = FLOAT_32
|
||||
|
||||
# cast to type float32 when type is float16
|
||||
if dtype == FLOAT_16 and has_improve_precision:
|
||||
x = tbe.cast_to(x, FLOAT_32)
|
||||
|
||||
pai_multiple = tbe.vmuls(x, 1 / PI)
|
||||
round_float = tbe.cast_to(tbe.round(pai_multiple), cast_dtype)
|
||||
# to adjust x to [-pai/2,pai/2]
|
||||
x = tbe.vsub(x, tbe.vmuls(round_float, PI))
|
||||
res = _besselj0_sin(x)
|
||||
# if round is odd, the final result need to multiply -1.Need to multiply 1/2 to get the ceil value
|
||||
ceil_value = tbe.ceil(tbe.vmuls(round_float, 1 / 2))
|
||||
# if odd, ceil*2-round is 1,if even, the value is 0
|
||||
sub_value = tbe.vsub(tbe.vmuls(ceil_value, tvm.const(2, dtype)), round_float)
|
||||
tensor_one = tbe.broadcast(tvm.const(1, cast_dtype), shape)
|
||||
odd_tensor = tbe.vsub(tensor_one, sub_value)
|
||||
even_tensor = tbe.vsub(odd_tensor, tensor_one)
|
||||
odd_even_tensor = tbe.vadd(odd_tensor, even_tensor)
|
||||
res = tbe.vmul(res, odd_even_tensor)
|
||||
|
||||
# cast the dtype to float16
|
||||
if dtype == FLOAT_16 and has_improve_precision:
|
||||
res = tbe.cast_to(res, FLOAT_16)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def besselj0_polevl(x, n, coef, shape):
|
||||
"""polevl"""
|
||||
dtype = 'float32'
|
||||
x = dsl.cast_to(x, dtype)
|
||||
|
||||
if n == 0:
|
||||
coef_0 = dsl.broadcast(coef[0], shape, output_dtype=dtype)
|
||||
return dsl.cast_to(coef_0, dtype)
|
||||
|
||||
coef_n = dsl.broadcast(coef[n], shape, output_dtype=dtype)
|
||||
res = dsl.vadd(dsl.vmul(besselj0_polevl(x, n-1, coef, shape), x), coef_n)
|
||||
return dsl.cast_to(res, 'float32')
|
||||
|
||||
|
||||
@fusion_manager.register("bessel_j0")
|
||||
def bessel_j0_compute(x, kernel_name="bessel_j0"):
|
||||
"""bessel_j0_compute"""
|
||||
dtype = x.dtype
|
||||
shape = shape_util.shape_to_list(x.shape)
|
||||
|
||||
# cast to type float32 when type is float16
|
||||
has_improve_precision = False
|
||||
if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
|
||||
x = dsl.cast_to(x, FLOAT_32)
|
||||
dtype = FLOAT_32
|
||||
has_improve_precision = True
|
||||
|
||||
y = dsl.vabs(x)
|
||||
z = dsl.vmul(y, y)
|
||||
y_le_five = dsl.vcmpsel(y, 1.0e-5, 'lt', dsl.vadds(dsl.vmuls(z, -0.25), 1),
|
||||
dsl.vmul(dsl.vmul(dsl.vadds(z, -DR1), dsl.vadds(z, -DR2)),
|
||||
dsl.vdiv(besselj0_polevl(z, 3, RP, shape), besselj0_polevl(z, 8, RQ, shape))))
|
||||
s = dsl.vmuls(dsl.vrec(z), 25)
|
||||
p = dsl.vdiv(besselj0_polevl(s, 6, PP, shape), besselj0_polevl(s, 6, PQ, shape))
|
||||
q = dsl.vdiv(besselj0_polevl(s, 7, QP, shape), besselj0_polevl(s, 6, PQ, shape))
|
||||
yn = dsl.vadds(y, NEG_PIO4)
|
||||
w = dsl.vmuls(dsl.vrec(y), -5.0)
|
||||
p = dsl.vadd(dsl.vmul(p, besselj0_cos(yn)), dsl.vmul(w, dsl.vmul(q, besselj0_sin(yn))))
|
||||
y_gt_five = dsl.vmul(dsl.vmuls(p, SQ2OPI), dsl.vrsqrt(y))
|
||||
res = dsl.vcmpsel(y, 5.0, 'le', y_le_five, y_gt_five)
|
||||
|
||||
if has_improve_precision:
|
||||
res = dsl.cast_to(res, "float16")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def bessel_j0(x, y, kernel_name="bessel_j0"):
|
||||
"""bessel_j0"""
|
||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
||||
res = bessel_j0_compute(data_x, kernel_name)
|
||||
|
||||
# auto schedule
|
||||
with tvm.target.cce():
|
||||
schedule = dsl.auto_schedule(res)
|
||||
|
||||
# operator build
|
||||
config = {"name": kernel_name,
|
||||
"tensor_list": [data_x, res]}
|
||||
dsl.build(schedule, config)
|
|
@ -0,0 +1,246 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""BesselJ1 op"""
|
||||
|
||||
|
||||
import te.lang.cce as tbe
|
||||
import te.platform as tbe_platform
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from tbe import dsl
|
||||
from tbe.common.utils import shape_util
|
||||
|
||||
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bessel_j1_op_info = TBERegOp("BesselJ1") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("bessel_j1.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("bessel_j1") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(bessel_j1_op_info)
|
||||
def _bessel_j1_tbe():
|
||||
"""BesselJ1 TBE register"""
|
||||
return
|
||||
|
||||
|
||||
FLOAT_16 = "float16"
|
||||
FLOAT_32 = "float32"
|
||||
|
||||
PI = 3.14159265358979
|
||||
FIRST_ORDER = 5
|
||||
LAST_ORDER = 13
|
||||
FIRST_FACTOR = -1.0 / 6.0
|
||||
|
||||
|
||||
def besselj1_cos(x):
|
||||
"""cos"""
|
||||
dtype = x.dtype
|
||||
shape = shape_util.shape_to_list(x.shape)
|
||||
|
||||
# cast to type float32 when type is float16
|
||||
has_improve_precision = False
|
||||
if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
|
||||
x = dsl.cast_to(x, FLOAT_32)
|
||||
dtype = FLOAT_32
|
||||
has_improve_precision = True
|
||||
# round the input
|
||||
round_fp16 = dsl.round(dsl.vmuls(x, 1.0 / (2 * PI)))
|
||||
round_fp32 = dsl.cast_to(round_fp16, dtype)
|
||||
besselj1_input_x_round = dsl.vsub(x, dsl.vmuls(round_fp32, 2 * PI))
|
||||
# the initial value one
|
||||
const_res = tvm.const(1.0, dtype=dtype)
|
||||
res = dsl.broadcast(const_res, shape)
|
||||
# compute the rank 2
|
||||
input_x_power = dsl.vmul(besselj1_input_x_round, besselj1_input_x_round)
|
||||
iter_value = dsl.vmuls(input_x_power, -1.0/2.0)
|
||||
res = dsl.vadd(res, iter_value)
|
||||
# compute the rank 4~14
|
||||
iter_list = (4, 6, 8, 10, 12, 14)
|
||||
for i in iter_list:
|
||||
iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value), -1.0/(i*(i-1)))
|
||||
res = dsl.vadd(res, iter_value)
|
||||
# cast the dtype to float16
|
||||
if has_improve_precision:
|
||||
res = dsl.cast_to(res, "float16")
|
||||
return res
|
||||
|
||||
|
||||
def _besselj1_sin(x):
|
||||
"""_sin"""
|
||||
input_x_power = dsl.vmul(x, x)
|
||||
iter_value = dsl.vmul(dsl.vmuls(input_x_power, FIRST_FACTOR), x)
|
||||
besselj1_res = dsl.vadd(x, iter_value)
|
||||
|
||||
signal = FIRST_ORDER
|
||||
while signal < LAST_ORDER:
|
||||
iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value),
|
||||
-1.0 / (signal*(signal - 1)))
|
||||
besselj1_res = dsl.vadd(besselj1_res, iter_value)
|
||||
signal = signal + 2
|
||||
|
||||
return besselj1_res
|
||||
|
||||
|
||||
def besselj1_sin(besselj1_x):
|
||||
"""sin"""
|
||||
dtype = besselj1_x.dtype
|
||||
shape = shape_util.shape_to_list(besselj1_x.shape)
|
||||
|
||||
has_improve_precision = False
|
||||
cast_dtype = FLOAT_16
|
||||
if tbe_platform.api_check_support("te.lang.cce.vmul", "float32"):
|
||||
has_improve_precision = True
|
||||
cast_dtype = FLOAT_32
|
||||
|
||||
# cast to type float32 when type is float16
|
||||
if dtype == FLOAT_16 and has_improve_precision:
|
||||
besselj1_x = tbe.cast_to(besselj1_x, FLOAT_32)
|
||||
|
||||
pai_multiple = tbe.vmuls(besselj1_x, 1 / PI)
|
||||
round_float = tbe.cast_to(tbe.round(pai_multiple), cast_dtype)
|
||||
# to adjust x to [-pai/2,pai/2]
|
||||
besselj1_x = tbe.vsub(besselj1_x, tbe.vmuls(round_float, PI))
|
||||
|
||||
besselj1_res = _besselj1_sin(besselj1_x)
|
||||
|
||||
# if round is odd, the final result need to multiply -1.Need to multiply 1/2 to get the ceil value
|
||||
ceil_value = tbe.ceil(tbe.vmuls(round_float, 1 / 2))
|
||||
# if odd, ceil*2-round is 1,if even, the value is 0
|
||||
sub_value = tbe.vsub(tbe.vmuls(ceil_value, tvm.const(2, dtype)), round_float)
|
||||
tensor_one = tbe.broadcast(tvm.const(1, cast_dtype), shape)
|
||||
odd_tensor = tbe.vsub(tensor_one, sub_value)
|
||||
even_tensor = tbe.vsub(odd_tensor, tensor_one)
|
||||
odd_even_tensor = tbe.vadd(odd_tensor, even_tensor)
|
||||
besselj1_res = tbe.vmul(besselj1_res, odd_even_tensor)
|
||||
|
||||
# cast the dtype to float16
|
||||
if dtype == FLOAT_16 and has_improve_precision:
|
||||
besselj1_res = tbe.cast_to(besselj1_res, FLOAT_16)
|
||||
|
||||
return besselj1_res
|
||||
|
||||
|
||||
PP = [7.62125616208173112003E-4, 7.31397056940917570436E-2,
|
||||
1.12719608129684925192E0, 5.11207951146807644818E0,
|
||||
8.42404590141772420927E0, 5.21451598682361504063E0,
|
||||
1.00000000000000000254E0]
|
||||
|
||||
PQ = [5.71323128072548699714E-4, 6.88455908754495404082E-2,
|
||||
1.10514232634061696926E0, 5.07386386128601488557E0,
|
||||
8.39985554327604159757E0, 5.20982848682361821619E0,
|
||||
9.99999999999999997461E-1]
|
||||
|
||||
QP = [5.10862594750176621635E-2, 4.98213872951233449420E0,
|
||||
7.58238284132545283818E1, 3.66779609360150777800E2,
|
||||
7.10856304998926107277E2, 5.97489612400613639965E2,
|
||||
2.11688757100572135698E2, 2.52070205858023719784E1]
|
||||
|
||||
QQ = [1.00000000000000000000E0, 6.43178256118178023184E1,
|
||||
8.56430025976980587198E2, 3.88240183605401609683E3,
|
||||
7.24046774195652478189E3, 5.93072701187316984827E3,
|
||||
2.06209331660327847417E3, 2.42005740240291393179E2]
|
||||
|
||||
RP = [-8.99971225705559398224E8, 4.52228297998194034323E11,
|
||||
-7.27494245221818276015E13, 3.68295732863852883286E15]
|
||||
|
||||
RQ = [1.00000000000000000000E0, 6.20836478118054335476E2,
|
||||
2.56987256757748830383E5, 8.35146791431949253037E7,
|
||||
2.21511595479792499675E10, 4.74914122079991414898E12,
|
||||
7.84369607876235854894E14, 8.95222336184627338078E16,
|
||||
5.32278620332680085395E18]
|
||||
|
||||
Z1 = 1.46819706421238932572E1
|
||||
Z2 = 4.92184563216946036703E1
|
||||
NEG_THPIO4 = -2.35619449019234492885
|
||||
SQ2OPI = 7.9788456080286535587989E-1
|
||||
|
||||
|
||||
def polevl(x, n, coef, shape):
|
||||
"""polevl"""
|
||||
dtype = 'float32'
|
||||
x = dsl.cast_to(x, dtype)
|
||||
|
||||
if n == 0:
|
||||
coef_0 = dsl.broadcast(coef[0], shape, output_dtype=dtype)
|
||||
return dsl.cast_to(coef_0, dtype)
|
||||
|
||||
coef_n = dsl.broadcast(coef[n], shape, output_dtype=dtype)
|
||||
res = dsl.vadd(dsl.vmul(polevl(x, n-1, coef, shape), x), coef_n)
|
||||
return dsl.cast_to(res, 'float32')
|
||||
|
||||
|
||||
@fusion_manager.register("bessel_j1")
|
||||
def bessel_j1_compute(x, kernel_name="bessel_j1"):
|
||||
"""bessel_j1_compute"""
|
||||
dtype = x.dtype
|
||||
shape = shape_util.shape_to_list(x.shape)
|
||||
|
||||
# cast to type float32 when type is float16
|
||||
has_improve_precision = False
|
||||
if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
|
||||
x = dsl.cast_to(x, FLOAT_32)
|
||||
dtype = FLOAT_32
|
||||
has_improve_precision = True
|
||||
|
||||
y = dsl.vabs(x)
|
||||
z = dsl.vmul(y, y)
|
||||
y_le_five = dsl.vdiv(polevl(z, 3, RP, shape), polevl(z, 8, RQ, shape))
|
||||
y_le_five = dsl.vmul(y_le_five, dsl.vmul(x, dsl.vmul(dsl.vadds(z, -Z1), dsl.vadds(z, -Z2))))
|
||||
s = dsl.vmuls(dsl.vrec(z), 25)
|
||||
p = dsl.vdiv(polevl(s, 6, PP, shape), polevl(s, 6, PQ, shape))
|
||||
q = dsl.vdiv(polevl(s, 7, QP, shape), polevl(s, 7, QQ, shape))
|
||||
yn = dsl.vadds(y, NEG_THPIO4)
|
||||
w = dsl.vmuls(dsl.vrec(y), -5.0)
|
||||
p = dsl.vadd(dsl.vmul(p, besselj1_cos(yn)), dsl.vmul(w, dsl.vmul(q, besselj1_sin(yn))))
|
||||
|
||||
y_gt_five = dsl.vmul(dsl.vmuls(p, SQ2OPI), dsl.vrsqrt(y))
|
||||
y_gt_five = dsl.vcmpsel(x, 0.0, 'lt', dsl.vmuls(y_gt_five, -1.0), y_gt_five)
|
||||
|
||||
res = dsl.vcmpsel(y, 5.0, 'le', y_le_five, y_gt_five)
|
||||
|
||||
if has_improve_precision:
|
||||
res = dsl.cast_to(res, "float16")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def bessel_j1(x, y, kernel_name="bessel_j1"):
|
||||
"""
|
||||
To do: Implement the operator by referring to the
|
||||
TBE Operator Development Guide.
|
||||
"""
|
||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
||||
|
||||
res = bessel_j1_compute(data_x, kernel_name)
|
||||
|
||||
# auto schedule
|
||||
with tvm.target.cce():
|
||||
schedule = dsl.auto_schedule(res)
|
||||
|
||||
# operator build
|
||||
config = {"name": kernel_name,
|
||||
"tensor_list": [data_x, res]}
|
||||
dsl.build(schedule, config)
|
|
@ -4761,7 +4761,7 @@ class BesselJ0(Primitive):
|
|||
Tensor, has the same shape as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor of float16, float32.
|
||||
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
@ -4793,7 +4793,7 @@ class BesselJ1(Primitive):
|
|||
Tensor, has the same shape as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor of float16, float32.
|
||||
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
|
|
@ -30,7 +30,8 @@ from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes
|
|||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
from mindspore.ops.operations.math_ops import BesselK0, BesselK1, BesselK0e, BesselK1e, Bucketize
|
||||
from mindspore.ops.operations.math_ops import BesselJ0, BesselJ1, BesselK0, BesselK1, BesselK0e, \
|
||||
BesselK1e, Bucketize
|
||||
from mindspore.ops.operations.math_ops import ReduceStd
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.ops.operations.array_ops import Tril
|
||||
|
@ -1832,6 +1833,14 @@ test_case_math_ops = [
|
|||
'block': P.BesselI1e(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
'desc_bprop': [[2, 3]]}),
|
||||
('BesselJ0', {
|
||||
'block': BesselJ0(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
'desc_bprop': [[2, 3]]}),
|
||||
('BesselJ1', {
|
||||
'block': BesselJ1(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
'desc_bprop': [[2, 3]]}),
|
||||
('BesselK0', {
|
||||
'block': BesselK0(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
|
|
Loading…
Reference in New Issue