!41797 Remove custom_ops of Bessel functions on Ascend platform

Merge pull request !41797 from hedongdong/TBE_Bessel
This commit is contained in:
i-robot 2022-09-13 07:12:45 +00:00 committed by Gitee
commit adbaa9e86f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
12 changed files with 12 additions and 2173 deletions

View File

@ -19,16 +19,6 @@ from .batchnorm_fold2 import _batchnorm_fold2_tbe
from .batchnorm_fold2_grad import _batchnorm_fold2_grad_tbe
from .batchnorm_fold2_grad_reduce import _batchnorm_fold2_grad_reduce_tbe
from .batchnorm_fold_grad import _batchnorm_fold_grad_tbe
from .bessel_i0 import _bessel_i0_tbe
from .bessel_i1 import _bessel_i1_tbe
from .bessel_j0 import _bessel_j0_tbe
from .bessel_j1 import _bessel_j1_tbe
from .bessel_k0 import _bessel_k0_tbe
from .bessel_k1 import _bessel_k1_tbe
from .bessel_k0e import _bessel_k0e_tbe
from .bessel_k1e import _bessel_k1e_tbe
from .bessel_y0 import _bessel_y0_tbe
from .bessel_y1 import _bessel_y1_tbe
from .correction_mul import _correction_mul_tbe
from .correction_mul_grad import _correction_mul_grad_tbe
from .fake_learned_scale_quant_perlayer import _fake_learned_scale_quant_perlayer_tbe

View File

@ -1,117 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselI0 op"""
from tbe import dsl
from te import tvm
from te.platform.fusion_manager import fusion_manager
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE registration metadata for BesselI0: a format-agnostic elementwise op
# with one required input "x" and one output "y", supporting f16 and f32.
bessel_i0_op_info = TBERegOp("BesselI0") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_i0.so") \
    .compute_cost(10) \
    .kernel_name("bessel_i0") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_i0_op_info)
def _bessel_i0_tbe():
    """Register the BesselI0 operator info with the TBE framework."""
# Chebyshev coefficients for the small-argument branch (|x| <= 8) of
# exp(-|x|) * I0(x), evaluated in the variable |x|/2 - 2.
# NOTE(review): values appear to come from the Cephes i0 tables — confirm.
A = [-1.30002500998624804212E-8, 6.04699502254191894932E-8,
     -2.67079385394061173391E-7, 1.11738753912010371815E-6,
     -4.41673835845875056359E-6, 1.64484480707288970893E-5,
     -5.75419501008210370398E-5, 1.88502885095841655729E-4,
     -5.76375574538582365885E-4, 1.63947561694133579842E-3,
     -4.32430999505057594430E-3, 1.05464603945949983183E-2,
     -2.37374148058994688156E-2, 4.93052842396707084878E-2,
     -9.49010970480476444210E-2, 1.71620901522208775349E-1,
     -3.04682672343198398683E-1, 6.76795274409476084995E-1]
# Coefficients for the large-argument branch (|x| > 8), evaluated in 32/|x| - 2.
B = [3.39623202570838634515E-9, 2.26666899049817806459E-8,
     2.04891858946906374183E-7, 2.89137052083475648297E-6,
     6.88975834691682398426E-5, 3.36911647825569408990E-3,
     8.04490411014108831608E-1]
def chebevl(x, num, coef, shape, dtype):
    """Evaluate a Chebyshev series at ``x`` using Clenshaw's recurrence.

    ``coef[0:num]`` are the series coefficients; the final result is halved,
    per the usual chebevl convention.
    """
    b_cur = dsl.broadcast(coef[0], shape, dtype)
    b_prev = dsl.broadcast(0, shape, dtype)
    b_prev2 = None
    for idx in range(1, num):
        b_prev2 = b_prev
        b_prev = b_cur
        term = dsl.broadcast(coef[idx], shape, dtype)
        b_cur = dsl.vsub(dsl.vadd(dsl.vmul(x, b_prev), term), b_prev2)
    return dsl.vmuls(dsl.vsub(b_cur, b_prev2), 0.5)
@fusion_manager.register("bessel_i0")
def bessel_i0_compute(input_x, output, kernel_name="bessel_i0"):
    """Elementwise modified Bessel function I0 of ``input_x``.

    Non-float32 inputs are promoted to float32 for the series evaluation
    and cast back to float16 at the end.
    """
    dtype = input_x.dtype
    shape = input_x.shape
    has_improve_precision = False
    if dtype != "float32":
        # Evaluate in float32 for accuracy; cast back below.
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        has_improve_precision = True
    # I0 is even, so work with |x|.
    y = dsl.vabs(input_x)
    # Small branch (|x| <= 8): Chebyshev series in y/2 - 2, giving exp(-y)*I0(y).
    y_le_eight_in = dsl.vmuls(y, 0.5)
    y_le_eight_in = dsl.vadds(y_le_eight_in, -2.0)
    y_le_eight = chebevl(y_le_eight_in, 18, A, shape, dtype)
    # Large branch (|x| > 8): series in 32/y - 2, scaled by 1/sqrt(y).
    y_gt_eight_in = dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0)
    y_gt_eight = chebevl(y_gt_eight_in, 7, B, shape, dtype)
    y_gt_eight = dsl.vmul(y_gt_eight, dsl.vrsqrt(y))
    res = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
    # Both branches computed exp(-y)*I0; undo the exponential scaling.
    res = dsl.vmul(res, dsl.vexp(y))
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def bessel_i0(x, output, kernel_name="bessel_i0"):
    """Build and compile the BesselI0 TBE kernel for input descriptor ``x``."""
    input_tensor = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
    result = bessel_i0_compute(input_tensor, output, kernel_name)
    # Auto-schedule the computation for the CCE target.
    with tvm.target.cce():
        sched = dsl.auto_schedule(result)
    # Compile the operator binary.
    dsl.build(sched, {"name": kernel_name, "tensor_list": [input_tensor, result]})

View File

@ -1,128 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselI1 op"""
from tbe import dsl
from te import tvm
from te.platform.fusion_manager import fusion_manager
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE registration metadata for BesselI1: a format-agnostic elementwise op
# with one required input "x" and one output "y", supporting f16 and f32.
bessel_i1_op_info = TBERegOp("BesselI1") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_i1.so") \
    .compute_cost(10) \
    .kernel_name("bessel_i1") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_i1_op_info)
def _bessel_i1_tbe():
    """Register the BesselI1 operator info with the TBE framework."""
# Chebyshev coefficients for the small-argument branch (|x| <= 8) of
# exp(-|x|) * I1(x)/x, evaluated in |x|/2 - 2.
# NOTE(review): values appear to come from the Cephes i1 tables — confirm.
A = [2.77791411276104639959E-18, -2.11142121435816608115E-17,
     1.55363195773620046921E-16, -1.10559694773538630805E-15,
     7.60068429473540693410E-15, -5.04218550472791168711E-14,
     3.22379336594557470981E-13, -1.98397439776494371520E-12,
     1.17361862988909016308E-11, -6.66348972350202774223E-11,
     3.62559028155211703701E-10, -1.88724975172282928790E-9,
     9.38153738649577178388E-9, -4.44505912879632808065E-8,
     2.00329475355213526229E-7, -8.56872026469545474066E-7,
     3.47025130813767847674E-6, -1.32731636560394358279E-5,
     4.78156510755005422638E-5, -1.61760815825896745588E-4,
     5.12285956168575772895E-4, -1.51357245063125314899E-3,
     4.15642294431288815669E-3, -1.05640848946261981558E-2,
     2.47264490306265168283E-2, -5.29459812080949914269E-2,
     1.02643658689847095384E-1, -1.76416518357834055153E-1,
     2.52587186443633654823E-1]
# Coefficients for the large-argument branch (|x| > 8), evaluated in 32/|x| - 2.
B = [
    7.51729631084210481353E-18, 4.41434832307170791151E-18,
    -4.65030536848935832153E-17, -3.20952592199342395980E-17,
    2.96262899764595013876E-16, 3.30820231092092828324E-16,
    -1.88035477551078244854E-15, -3.81440307243700780478E-15,
    1.04202769841288027642E-14, 4.27244001671195135429E-14,
    -2.10154184277266431302E-14, -4.08355111109219731823E-13,
    -7.19855177624590851209E-13, 2.03562854414708950722E-12,
    1.41258074366137813316E-11, 3.25260358301548823856E-11,
    -1.89749581235054123450E-11, -5.58974346219658380687E-10,
    -3.83538038596423702205E-9, -2.63146884688951950684E-8,
    -2.51223623787020892529E-7, -3.88256480887769039346E-6,
    -1.10588938762623716291E-4, -9.76109749136146840777E-3,
    7.78576235018280120474E-1]
def chebevl(x, num, coef, shape, dtype):
    """Clenshaw-recurrence evaluation of a Chebyshev series at ``x``.

    Uses ``coef[0:num]``; the result carries the conventional factor 1/2.
    """
    newest = dsl.broadcast(coef[0], shape, dtype)
    middle = dsl.broadcast(0, shape, dtype)
    oldest = None
    for k in range(1, num):
        oldest = middle
        middle = newest
        c_k = dsl.broadcast(coef[k], shape, dtype)
        newest = dsl.vsub(dsl.vadd(dsl.vmul(x, middle), c_k), oldest)
    return dsl.vmuls(dsl.vsub(newest, oldest), 0.5)
@fusion_manager.register("bessel_i1")
def bessel_i1_compute(input_x, output_y, kernel_name="bessel_i1"):
    """Elementwise modified Bessel function I1 of ``input_x``.

    Uses Chebyshev expansions of exp(-|x|)*I1(|x|): for |x| <= 8 a series in
    |x|/2 - 2 scaled by |x|, for |x| > 8 a series in 32/|x| - 2 scaled by
    1/sqrt(|x|). The odd symmetry (I1(-x) = -I1(x)) and the exp(|x|) factor
    are restored at the end. float16 inputs are evaluated in float32.

    BUGFIX: the large-argument branch was missing the 1/sqrt(|x|) factor,
    which is present in the Cephes i1 reference and in the sibling
    bessel_i0 implementations.
    """
    dtype = input_x.dtype
    shape = input_x.shape
    has_improve_precision = False
    if dtype != "float32":
        # Evaluate in float32 for accuracy; cast back below.
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        has_improve_precision = True
    y = dsl.vabs(input_x)
    # Small branch: y * chebevl(y/2 - 2, A).
    y_le_eight = dsl.vmul(y, chebevl(dsl.vadds(dsl.vmuls(y, 0.5), -2), 29, A, shape, dtype))
    # Large branch: chebevl(32/y - 2, B) / sqrt(y).
    y_gt_eight = chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0), 25, B, shape, dtype)
    # BUGFIX: apply the missing 1/sqrt(y) scaling (cf. Cephes i1 and
    # bessel_i0_compute above).
    y_gt_eight = dsl.vmul(y_gt_eight, dsl.vrsqrt(y))
    y = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
    # I1 is odd: negate the result for negative inputs.
    res = dsl.vcmpsel(input_x, 0, 'lt', dsl.vmuls(y, -1.0), y)
    # Undo the exp(-|x|) scaling baked into both branches.
    res = dsl.vmul(res, dsl.vexp(dsl.vabs(input_x)))
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def bessel_i1(x, y, kernel_name="bessel_i1"):
    """Build and compile the BesselI1 TBE kernel for input descriptor ``x``."""
    input_tensor = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
    result = bessel_i1_compute(input_tensor, y, kernel_name)
    # Auto-schedule the computation for the CCE target.
    with tvm.target.cce():
        sched = dsl.auto_schedule(result)
    # Compile the operator binary.
    dsl.build(sched, {"name": kernel_name, "tensor_list": [input_tensor, result]})

View File

@ -1,226 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselJ0 op"""
import te.lang.cce as tbe
import te.platform as tbe_platform
from te import tvm
from te.platform.fusion_manager import fusion_manager
from tbe import dsl
from tbe.common.utils import shape_util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE registration metadata for BesselJ0: a format-agnostic elementwise op
# with one required input "x" and one output "y", supporting f16 and f32.
bessel_j0_op_info = TBERegOp("BesselJ0") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_j0.so") \
    .compute_cost(10) \
    .kernel_name("bessel_j0") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_j0_op_info)
def _bessel_j0_tbe():
    """Register the BesselJ0 operator info with the TBE framework."""
# Dtype name constants used throughout this file.
FLOAT_16 = "float16"
FLOAT_32 = "float32"
# Rational-approximation coefficient tables for J0.
# NOTE(review): these appear to follow the Cephes j0 tables — confirm.
# PP/PQ and QP/QQ: numerator/denominator polynomials of the asymptotic
# modulus and phase corrections for |x| > 5 (in the variable (5/x)^2).
PP = [7.96936729297347051624E-4, 8.28352392107440799803E-2,
      1.23953371646414299388E0, 5.44725003058768775090E0,
      8.74716500199817011941E0, 5.30324038235394892183E0,
      9.99999999999999997821E-1]
PQ = [9.24408810558863637013E-4, 8.56288474354474431428E-2,
      1.25352743901058953537E0, 5.47097740330417105182E0,
      8.76190883237069594232E0, 5.30605288235394617618E0,
      1.00000000000000000218E0]
QP = [-1.13663838898469149931E-2, -1.28252718670509318512E0,
      -1.95539544257735972385E1, -9.32060152123768231369E1,
      -1.77681167980488050595E2, -1.47077505154951170175E2,
      -5.14105326766599330220E1, -6.05014350600728481186E0]
QQ = [1.00000000000000000000E0, 6.43178256118178023184E1,
      8.56430025976980587198E2, 3.88240183605401609683E3,
      7.24046774195652478189E3, 5.93072701187316984827E3,
      2.06209331660327847417E3, 2.42005740240291393179E2]
# RP/RQ: rational approximation for the interval |x| <= 5 (in x^2).
RP = [-4.79443220978201773821E9, 1.95617491946556577543E12,
      -2.49248344360967716204E14, 9.70862251047306323952E15]
RQ = [1.00000000000000000000E0, 4.99563147152651017219E2,
      1.73785401676374683123E5, 4.84409658339962045305E7,
      1.11855537045356834862E10, 2.11277520115489217587E12,
      3.10518229857422583814E14, 3.18121955943204943306E16,
      1.71086294081043136091E18]
# DR1/DR2: the first two roots of J0, factored out of the small-x form.
DR1 = 5.78318596294678452118E0
DR2 = 3.04712623436620863991E1
# sqrt(2/pi), used by the asymptotic form.
SQ2OPI = 7.9788456080286535587989E-1
# -pi/4 phase shift for the asymptotic cos/sin arguments.
NEG_PIO4 = -0.7853981633974483096
PI = 3.14159265358979
# Taylor-series bookkeeping for the local sin implementation.
FIRST_ORDER = 5
LAST_ORDER = 13
FIRST_FACTOR = -1.0 / 6.0
def besselj0_cos(x):
    """Elementwise cosine via Taylor series after range reduction to one period.

    The input is reduced by subtracting the nearest multiple of 2*pi, then
    cos is approximated by its Taylor terms up to order 14.
    """
    dtype = x.dtype
    shape = shape_util.shape_to_list(x.shape)
    # cast to type float32 when type is float16 (if the platform supports it)
    has_improve_precision = False
    if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
        x = dsl.cast_to(x, FLOAT_32)
        dtype = FLOAT_32
        has_improve_precision = True
    # Range reduction: subtract round(x / 2pi) * 2pi.
    round_fp16 = dsl.round(dsl.vmuls(x, 1.0 / (2 * PI)))
    round_fp32 = dsl.cast_to(round_fp16, dtype)
    input_x_round = dsl.vsub(x, dsl.vmuls(round_fp32, 2 * PI))
    # Taylor series starts at 1.
    const_res = tvm.const(1.0, dtype=dtype)
    res = dsl.broadcast(const_res, shape)
    # Order-2 term: -x^2/2.
    input_x_power = dsl.vmul(input_x_round, input_x_round)
    iter_value = dsl.vmuls(input_x_power, -1.0/2.0)
    res = dsl.vadd(res, iter_value)
    # Orders 4..14: each term derives from the previous via -x^2/(i*(i-1)).
    iter_list = (4, 6, 8, 10, 12, 14)
    for i in iter_list:
        iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value), -1.0/(i*(i-1)))
        res = dsl.vadd(res, iter_value)
    # cast the dtype back to float16
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def _besselj0_sin(x):
    """Taylor-series sine for inputs already reduced to [-pi/2, pi/2]."""
    x_squared = dsl.vmul(x, x)
    term = dsl.vmul(dsl.vmuls(x_squared, FIRST_FACTOR), x)
    total = dsl.vadd(x, term)
    order = FIRST_ORDER
    while order < LAST_ORDER:
        # Next term derives from the previous one via -x^2 / (order*(order-1)).
        term = dsl.vmuls(dsl.vmul(x_squared, term),
                         -1.0 / (order * (order - 1)))
        total = dsl.vadd(total, term)
        order += 2
    return total
def besselj0_sin(x):
    """Elementwise sine via range reduction to [-pi/2, pi/2] plus Taylor series.

    The reduction subtracts round(x/pi)*pi; when that multiple is odd the
    Taylor result must be negated, which is done with an odd/even mask.
    """
    dtype = x.dtype
    shape = shape_util.shape_to_list(x.shape)
    has_improve_precision = False
    cast_dtype = FLOAT_16
    if tbe_platform.api_check_support("te.lang.cce.vmul", "float32"):
        has_improve_precision = True
        cast_dtype = FLOAT_32
    # cast to type float32 when type is float16
    if dtype == FLOAT_16 and has_improve_precision:
        x = tbe.cast_to(x, FLOAT_32)
    pai_multiple = tbe.vmuls(x, 1 / PI)
    round_float = tbe.cast_to(tbe.round(pai_multiple), cast_dtype)
    # to adjust x to [-pai/2,pai/2]
    x = tbe.vsub(x, tbe.vmuls(round_float, PI))
    res = _besselj0_sin(x)
    # if round is odd, the final result need to multiply -1.
    # Need to multiply 1/2 to get the ceil value
    ceil_value = tbe.ceil(tbe.vmuls(round_float, 1 / 2))
    # if odd, ceil*2-round is 1; if even, the value is 0
    sub_value = tbe.vsub(tbe.vmuls(ceil_value, tvm.const(2, dtype)), round_float)
    # odd_even_tensor is -1 where round was odd, +1 where it was even.
    tensor_one = tbe.broadcast(tvm.const(1, cast_dtype), shape)
    odd_tensor = tbe.vsub(tensor_one, sub_value)
    even_tensor = tbe.vsub(odd_tensor, tensor_one)
    odd_even_tensor = tbe.vadd(odd_tensor, even_tensor)
    res = tbe.vmul(res, odd_even_tensor)
    # cast the dtype back to float16
    if dtype == FLOAT_16 and has_improve_precision:
        res = tbe.cast_to(res, FLOAT_16)
    return res
def besselj0_polevl(x, n, coef, shape):
    """Horner evaluation of the polynomial with coefficients ``coef[0..n]`` at ``x``.

    All arithmetic is carried out (and the result returned) in float32.
    """
    dtype = 'float32'
    x = dsl.cast_to(x, dtype)
    # Start from the leading coefficient and fold in the rest iteratively.
    acc = dsl.cast_to(dsl.broadcast(coef[0], shape, output_dtype=dtype), dtype)
    for k in range(1, n + 1):
        c_k = dsl.broadcast(coef[k], shape, output_dtype=dtype)
        acc = dsl.cast_to(dsl.vadd(dsl.vmul(acc, x), c_k), 'float32')
    return acc
@fusion_manager.register("bessel_j0")
def bessel_j0_compute(x, kernel_name="bessel_j0"):
    """Elementwise Bessel function of the first kind J0(x).

    |x| <= 5 uses the rational approximation RP/RQ with the first two roots
    DR1, DR2 factored out (and 1 - x^2/4 for very small |x|); |x| > 5 uses
    the asymptotic form sqrt(2/pi) * (P*cos(y - pi/4) - (5/y)*Q*sin(y - pi/4))
    / sqrt(y). float16 inputs are evaluated in float32 and cast back.

    BUGFIX: the asymptotic Q correction previously divided QP by PQ (degree
    6); the denominator must be QQ (degree 7), as in bessel_j1_compute and
    the Cephes j0 reference (QQ was otherwise defined but unused).
    """
    dtype = x.dtype
    shape = shape_util.shape_to_list(x.shape)
    # cast to type float32 when type is float16
    has_improve_precision = False
    if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
        x = dsl.cast_to(x, FLOAT_32)
        dtype = FLOAT_32
        has_improve_precision = True
    # J0 is even, so work with |x|; z = x^2.
    y = dsl.vabs(x)
    z = dsl.vmul(y, y)
    # Small branch: 1 - z/4 for tiny |x|, else (z-DR1)(z-DR2)*RP(z)/RQ(z).
    y_le_five = dsl.vcmpsel(y, 1.0e-5, 'lt', dsl.vadds(dsl.vmuls(z, -0.25), 1),
                            dsl.vmul(dsl.vmul(dsl.vadds(z, -DR1), dsl.vadds(z, -DR2)),
                                     dsl.vdiv(besselj0_polevl(z, 3, RP, shape),
                                              besselj0_polevl(z, 8, RQ, shape))))
    # Asymptotic branch variables: s = (5/y)^2, w = -5/y.
    s = dsl.vmuls(dsl.vrec(z), 25)
    p = dsl.vdiv(besselj0_polevl(s, 6, PP, shape), besselj0_polevl(s, 6, PQ, shape))
    # BUGFIX: denominator was besselj0_polevl(s, 6, PQ, shape).
    q = dsl.vdiv(besselj0_polevl(s, 7, QP, shape), besselj0_polevl(s, 7, QQ, shape))
    yn = dsl.vadds(y, NEG_PIO4)
    w = dsl.vmuls(dsl.vrec(y), -5.0)
    p = dsl.vadd(dsl.vmul(p, besselj0_cos(yn)), dsl.vmul(w, dsl.vmul(q, besselj0_sin(yn))))
    y_gt_five = dsl.vmul(dsl.vmuls(p, SQ2OPI), dsl.vrsqrt(y))
    res = dsl.vcmpsel(y, 5.0, 'le', y_le_five, y_gt_five)
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def bessel_j0(x, y, kernel_name="bessel_j0"):
    """Build and compile the BesselJ0 TBE kernel for input descriptor ``x``."""
    input_tensor = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
    result = bessel_j0_compute(input_tensor, kernel_name)
    # Auto-schedule the computation for the CCE target.
    with tvm.target.cce():
        sched = dsl.auto_schedule(result)
    # Compile the operator binary.
    dsl.build(sched, {"name": kernel_name, "tensor_list": [input_tensor, result]})

View File

@ -1,246 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselJ1 op"""
import te.lang.cce as tbe
import te.platform as tbe_platform
from te import tvm
from te.platform.fusion_manager import fusion_manager
from tbe import dsl
from tbe.common.utils import shape_util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE registration metadata for BesselJ1: a format-agnostic elementwise op
# with one required input "x" and one output "y", supporting f16 and f32.
bessel_j1_op_info = TBERegOp("BesselJ1") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_j1.so") \
    .compute_cost(10) \
    .kernel_name("bessel_j1") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_j1_op_info)
def _bessel_j1_tbe():
    """Register the BesselJ1 operator info with the TBE framework."""
# Dtype name constants used throughout this file.
FLOAT_16 = "float16"
FLOAT_32 = "float32"
PI = 3.14159265358979
# Taylor-series bookkeeping for the local sin implementation.
FIRST_ORDER = 5
LAST_ORDER = 13
FIRST_FACTOR = -1.0 / 6.0
def besselj1_cos(x):
    """Elementwise cosine via Taylor series after range reduction to one period.

    The input is reduced by subtracting the nearest multiple of 2*pi, then
    cos is approximated by its Taylor terms up to order 14.
    """
    dtype = x.dtype
    shape = shape_util.shape_to_list(x.shape)
    # cast to type float32 when type is float16 (if the platform supports it)
    has_improve_precision = False
    if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
        x = dsl.cast_to(x, FLOAT_32)
        dtype = FLOAT_32
        has_improve_precision = True
    # Range reduction: subtract round(x / 2pi) * 2pi.
    round_fp16 = dsl.round(dsl.vmuls(x, 1.0 / (2 * PI)))
    round_fp32 = dsl.cast_to(round_fp16, dtype)
    besselj1_input_x_round = dsl.vsub(x, dsl.vmuls(round_fp32, 2 * PI))
    # Taylor series starts at 1.
    const_res = tvm.const(1.0, dtype=dtype)
    res = dsl.broadcast(const_res, shape)
    # Order-2 term: -x^2/2.
    input_x_power = dsl.vmul(besselj1_input_x_round, besselj1_input_x_round)
    iter_value = dsl.vmuls(input_x_power, -1.0/2.0)
    res = dsl.vadd(res, iter_value)
    # Orders 4..14: each term derives from the previous via -x^2/(i*(i-1)).
    iter_list = (4, 6, 8, 10, 12, 14)
    for i in iter_list:
        iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value), -1.0/(i*(i-1)))
        res = dsl.vadd(res, iter_value)
    # cast the dtype back to float16
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def _besselj1_sin(x):
    """Taylor-series sine for inputs already reduced to [-pi/2, pi/2]."""
    x_squared = dsl.vmul(x, x)
    term = dsl.vmul(dsl.vmuls(x_squared, FIRST_FACTOR), x)
    total = dsl.vadd(x, term)
    order = FIRST_ORDER
    while order < LAST_ORDER:
        # Next term derives from the previous one via -x^2 / (order*(order-1)).
        term = dsl.vmuls(dsl.vmul(x_squared, term),
                         -1.0 / (order * (order - 1)))
        total = dsl.vadd(total, term)
        order += 2
    return total
def besselj1_sin(besselj1_x):
    """Elementwise sine via range reduction to [-pi/2, pi/2] plus Taylor series.

    The reduction subtracts round(x/pi)*pi; when that multiple is odd the
    Taylor result must be negated, which is done with an odd/even mask.
    """
    dtype = besselj1_x.dtype
    shape = shape_util.shape_to_list(besselj1_x.shape)
    has_improve_precision = False
    cast_dtype = FLOAT_16
    if tbe_platform.api_check_support("te.lang.cce.vmul", "float32"):
        has_improve_precision = True
        cast_dtype = FLOAT_32
    # cast to type float32 when type is float16
    if dtype == FLOAT_16 and has_improve_precision:
        besselj1_x = tbe.cast_to(besselj1_x, FLOAT_32)
    pai_multiple = tbe.vmuls(besselj1_x, 1 / PI)
    round_float = tbe.cast_to(tbe.round(pai_multiple), cast_dtype)
    # to adjust x to [-pai/2,pai/2]
    besselj1_x = tbe.vsub(besselj1_x, tbe.vmuls(round_float, PI))
    besselj1_res = _besselj1_sin(besselj1_x)
    # if round is odd, the final result need to multiply -1.
    # Need to multiply 1/2 to get the ceil value
    ceil_value = tbe.ceil(tbe.vmuls(round_float, 1 / 2))
    # if odd, ceil*2-round is 1; if even, the value is 0
    sub_value = tbe.vsub(tbe.vmuls(ceil_value, tvm.const(2, dtype)), round_float)
    # odd_even_tensor is -1 where round was odd, +1 where it was even.
    tensor_one = tbe.broadcast(tvm.const(1, cast_dtype), shape)
    odd_tensor = tbe.vsub(tensor_one, sub_value)
    even_tensor = tbe.vsub(odd_tensor, tensor_one)
    odd_even_tensor = tbe.vadd(odd_tensor, even_tensor)
    besselj1_res = tbe.vmul(besselj1_res, odd_even_tensor)
    # cast the dtype back to float16
    if dtype == FLOAT_16 and has_improve_precision:
        besselj1_res = tbe.cast_to(besselj1_res, FLOAT_16)
    return besselj1_res
# Rational-approximation coefficient tables for J1.
# NOTE(review): these appear to follow the Cephes j1 tables — confirm.
# PP/PQ and QP/QQ: numerator/denominator polynomials of the asymptotic
# modulus and phase corrections for |x| > 5 (in the variable (5/x)^2).
PP = [7.62125616208173112003E-4, 7.31397056940917570436E-2,
      1.12719608129684925192E0, 5.11207951146807644818E0,
      8.42404590141772420927E0, 5.21451598682361504063E0,
      1.00000000000000000254E0]
PQ = [5.71323128072548699714E-4, 6.88455908754495404082E-2,
      1.10514232634061696926E0, 5.07386386128601488557E0,
      8.39985554327604159757E0, 5.20982848682361821619E0,
      9.99999999999999997461E-1]
QP = [5.10862594750176621635E-2, 4.98213872951233449420E0,
      7.58238284132545283818E1, 3.66779609360150777800E2,
      7.10856304998926107277E2, 5.97489612400613639965E2,
      2.11688757100572135698E2, 2.52070205858023719784E1]
QQ = [1.00000000000000000000E0, 6.43178256118178023184E1,
      8.56430025976980587198E2, 3.88240183605401609683E3,
      7.24046774195652478189E3, 5.93072701187316984827E3,
      2.06209331660327847417E3, 2.42005740240291393179E2]
# RP/RQ: rational approximation for the interval |x| <= 5 (in x^2).
RP = [-8.99971225705559398224E8, 4.52228297998194034323E11,
      -7.27494245221818276015E13, 3.68295732863852883286E15]
RQ = [1.00000000000000000000E0, 6.20836478118054335476E2,
      2.56987256757748830383E5, 8.35146791431949253037E7,
      2.21511595479792499675E10, 4.74914122079991414898E12,
      7.84369607876235854894E14, 8.95222336184627338078E16,
      5.32278620332680085395E18]
# Z1/Z2: the first two roots of J1 (in x^2), factored out of the small-x form.
Z1 = 1.46819706421238932572E1
Z2 = 4.92184563216946036703E1
# -3*pi/4 phase shift for the asymptotic cos/sin arguments.
NEG_THPIO4 = -2.35619449019234492885
# sqrt(2/pi), used by the asymptotic form.
SQ2OPI = 7.9788456080286535587989E-1
def polevl(x, n, coef, shape):
    """Horner evaluation of the polynomial with coefficients ``coef[0..n]`` at ``x``.

    All arithmetic is carried out (and the result returned) in float32.
    """
    dtype = 'float32'
    x = dsl.cast_to(x, dtype)
    # Start from the leading coefficient and fold in the rest iteratively.
    acc = dsl.cast_to(dsl.broadcast(coef[0], shape, output_dtype=dtype), dtype)
    for k in range(1, n + 1):
        c_k = dsl.broadcast(coef[k], shape, output_dtype=dtype)
        acc = dsl.cast_to(dsl.vadd(dsl.vmul(acc, x), c_k), 'float32')
    return acc
@fusion_manager.register("bessel_j1")
def bessel_j1_compute(x, kernel_name="bessel_j1"):
    """Elementwise Bessel function of the first kind J1(x).

    |x| <= 5 uses the rational approximation RP/RQ with the first two roots
    Z1, Z2 factored out; |x| > 5 uses the asymptotic form
    sqrt(2/pi) * (P*cos(y - 3pi/4) - (5/y)*Q*sin(y - 3pi/4)) / sqrt(y),
    negated for negative x since J1 is odd. float16 inputs are evaluated
    in float32 and cast back.
    """
    dtype = x.dtype
    shape = shape_util.shape_to_list(x.shape)
    # cast to type float32 when type is float16
    has_improve_precision = False
    if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
        x = dsl.cast_to(x, FLOAT_32)
        dtype = FLOAT_32
        has_improve_precision = True
    # Work with |x|; z = x^2.
    y = dsl.vabs(x)
    z = dsl.vmul(y, y)
    # Small branch: x * (z-Z1)(z-Z2) * RP(z)/RQ(z).
    y_le_five = dsl.vdiv(polevl(z, 3, RP, shape), polevl(z, 8, RQ, shape))
    y_le_five = dsl.vmul(y_le_five, dsl.vmul(x, dsl.vmul(dsl.vadds(z, -Z1), dsl.vadds(z, -Z2))))
    # Asymptotic branch variables: s = (5/y)^2, w = -5/y.
    s = dsl.vmuls(dsl.vrec(z), 25)
    p = dsl.vdiv(polevl(s, 6, PP, shape), polevl(s, 6, PQ, shape))
    q = dsl.vdiv(polevl(s, 7, QP, shape), polevl(s, 7, QQ, shape))
    yn = dsl.vadds(y, NEG_THPIO4)
    w = dsl.vmuls(dsl.vrec(y), -5.0)
    p = dsl.vadd(dsl.vmul(p, besselj1_cos(yn)), dsl.vmul(w, dsl.vmul(q, besselj1_sin(yn))))
    y_gt_five = dsl.vmul(dsl.vmuls(p, SQ2OPI), dsl.vrsqrt(y))
    # J1 is odd: negate the asymptotic result for negative inputs.
    y_gt_five = dsl.vcmpsel(x, 0.0, 'lt', dsl.vmuls(y_gt_five, -1.0), y_gt_five)
    res = dsl.vcmpsel(y, 5.0, 'le', y_le_five, y_gt_five)
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def bessel_j1(x, y, kernel_name="bessel_j1"):
    """Build and compile the BesselJ1 TBE kernel for input descriptor ``x``."""
    input_tensor = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
    result = bessel_j1_compute(input_tensor, kernel_name)
    # Auto-schedule the computation for the CCE target.
    with tvm.target.cce():
        sched = dsl.auto_schedule(result)
    # Compile the operator binary.
    dsl.build(sched, {"name": kernel_name, "tensor_list": [input_tensor, result]})

View File

@ -1,186 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselK0 op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from tbe import dsl
from te import tvm
from te.platform.fusion_manager import fusion_manager
# TBE registration metadata for BesselK0: a format-agnostic elementwise op
# with one required input "x" and one output "y", supporting f16 and f32.
bessel_k0_op_info = TBERegOp("BesselK0") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_k0.so") \
    .compute_cost(10) \
    .kernel_name("bessel_k0") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_k0_op_info)
def _bessel_k0_tbe():
    """Register the BesselK0 operator info with the TBE framework."""
# Chebyshev coefficients for the K0 branch x <= 2, evaluated in x^2 - 2.
# NOTE(review): values appear to come from the Cephes k0/i0 tables — confirm.
A = [1.37446543561352307156E-16,
     4.25981614279661018399E-14,
     1.03496952576338420167E-11,
     1.90451637722020886025E-9,
     2.53479107902614945675E-7,
     2.28621210311945178607E-5,
     1.26461541144692592338E-3,
     3.59799365153615016266E-2,
     3.44289899924628486886E-1,
     -5.35327393233902768720E-1]
# Coefficients for the K0 branch x > 2 (exp(x)*sqrt(x)*K0(x), in 8/x - 2).
B = [
    5.30043377268626276149E-18, -1.64758043015242134646E-17,
    5.21039150503902756861E-17, -1.67823109680541210385E-16,
    5.51205597852431940784E-16, -1.84859337734377901440E-15,
    6.34007647740507060557E-15, -2.22751332699166985548E-14,
    8.03289077536357521100E-14, -2.98009692317273043925E-13,
    1.14034058820847496303E-12, -4.51459788337394416547E-12,
    1.85594911495471785253E-11, -7.95748924447710747776E-11,
    3.57739728140030116597E-10, -1.69753450938905987466E-9,
    8.57403401741422608519E-9, -4.66048989768794782956E-8,
    2.76681363944501510342E-7, -1.83175552271911948767E-6,
    1.39498137188764993662E-5, -1.28495495816278026384E-4,
    1.56988388573005337491E-3, -3.14481013119645005427E-2,
    2.44030308206595545468E0]
# AA/BB: coefficients for the embedded I0 evaluation (small/large branches),
# matching the ones used by the standalone bessel_i0 operator.
AA = [-4.41534164647933937950E-18, 3.33079451882223809783E-17,
      -2.43127984654795469359E-16, 1.71539128555513303061E-15,
      -1.16853328779934516808E-14, 7.67618549860493561688E-14,
      -4.85644678311192946090E-13, 2.95505266312963983461E-12,
      -1.72682629144155570723E-11, 9.67580903537323691224E-11,
      -5.18979560163526290666E-10, 2.65982372468238665035E-9,
      -1.30002500998624804212E-8, 6.04699502254191894932E-8,
      -2.67079385394061173391E-7, 1.11738753912010371815E-6,
      -4.41673835845875056359E-6, 1.64484480707288970893E-5,
      -5.75419501008210370398E-5, 1.88502885095841655729E-4,
      -5.76375574538582365885E-4, 1.63947561694133579842E-3,
      -4.32430999505057594430E-3, 1.05464603945949983183E-2,
      -2.37374148058994688156E-2, 4.93052842396707084878E-2,
      -9.49010970480476444210E-2, 1.71620901522208775349E-1,
      -3.04682672343198398683E-1, 6.76795274409476084995E-1
      ]
BB = [
    -7.23318048787475395456E-18, -4.83050448594418207126E-18,
    4.46562142029675999901E-17, 3.46122286769746109310E-17,
    -2.82762398051658348494E-16, -3.42548561967721913462E-16,
    1.77256013305652638360E-15, 3.81168066935262242075E-15,
    -9.55484669882830764870E-15, -4.15056934728722208663E-14,
    1.54008621752140982691E-14, 3.85277838274214270114E-13,
    7.18012445138366623367E-13, -1.79417853150680611778E-12,
    -1.32158118404477131188E-11, -3.14991652796324136454E-11,
    1.18891471078464383424E-11, 4.94060238822496958910E-10,
    3.39623202570838634515E-9, 2.26666899049817806459E-8,
    2.04891858946906374183E-7, 2.89137052083475648297E-6,
    6.88975834691682398426E-5, 3.36911647825569408990E-3,
    8.04490411014108831608E-1]
# Value returned for the domain error x <= 0 (K0 is undefined there).
MAXNUM = 4294967295.0
# Branch-selection threshold between the two expansions.
TWO = 2.0
def chebevl(x, n, coef, shape, dtype):
    """Clenshaw-recurrence evaluation of a Chebyshev series at ``x``.

    Uses ``coef[0:n]``; the result carries the conventional factor 1/2.
    """
    current = dsl.broadcast(coef[0], shape, dtype)
    previous = dsl.broadcast(0, shape, dtype)
    oldest = None
    for idx in range(1, n):
        oldest = previous
        previous = current
        coeff_tensor = dsl.broadcast(coef[idx], shape, dtype)
        current = dsl.vsub(dsl.vadd(dsl.vmul(x, previous), coeff_tensor), oldest)
    return dsl.vmuls(dsl.vsub(current, oldest), 0.5)
def bessel_i0_compute(input_x):
    """Elementwise modified Bessel function I0, used as a helper for K0.

    Same algorithm as the standalone bessel_i0 operator: Chebyshev series
    of exp(-|x|)*I0 on |x| <= 8 (AA) and |x| > 8 (BB, scaled by 1/sqrt(|x|)),
    followed by the exp(|x|) factor.
    """
    dtype = input_x.dtype
    shape = input_x.shape
    has_improve_precision = False
    if dtype != "float32":
        # Evaluate in float32 for accuracy; cast back below.
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        has_improve_precision = True
    y = dsl.vabs(input_x)
    # Small branch: series in y/2 - 2.
    y_le_eight_in = dsl.vmuls(y, 0.5)
    y_le_eight_in = dsl.vadds(y_le_eight_in, -2.0)
    y_le_eight = chebevl(y_le_eight_in, 30, AA, shape, dtype)
    # Large branch: series in 32/y - 2, scaled by 1/sqrt(y).
    y_gt_eight_in = dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0)
    y_gt_eight = chebevl(y_gt_eight_in, 25, BB, shape, dtype)
    y_gt_eight = dsl.vmul(y_gt_eight, dsl.vrsqrt(y))
    res = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
    # Undo the exp(-y) scaling baked into both branches.
    res = dsl.vmul(res, dsl.vexp(y))
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
@fusion_manager.register("bessel_k0")
def bessel_k0_compute(input_x, output_y, kernel_name="bessel_k0"):
    """Elementwise modified Bessel function of the second kind K0(x).

    x <= 2 uses chebevl(x^2 - 2, A) - log(x/2)*I0(x) (MAXNUM for the
    out-of-domain x <= 0); x > 2 uses exp(-x)*chebevl(8/x - 2, B)/sqrt(x).
    float16 inputs are evaluated in float32 and cast back.
    """
    shape = input_x.shape
    dtype = input_x.dtype
    has_improve_precision = False
    if dtype != "float32":
        # Evaluate in float32 for accuracy; cast back below.
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        has_improve_precision = True
    # Small branch: chebevl(x^2 - 2, A) - log(x/2) * I0(x).
    x_le_two = chebevl(dsl.vadds(dsl.vmul(input_x, input_x), -2.0), 10, A, shape, dtype)
    x_le_two = dsl.vadd(dsl.vmul(bessel_i0_compute(input_x),
                                 dsl.vmuls(dsl.vlog(dsl.vmuls(input_x, 0.5)), -1.0)), x_le_two)
    # K0 is undefined for x <= 0: return MAXNUM there.
    x_le_two = dsl.vcmpsel(input_x, 0.0, 'le', MAXNUM, x_le_two)
    # Large branch: exp(-x) * chebevl(8/x - 2, B) / sqrt(x).
    x_gt_two = dsl.vmul(dsl.vmul(dsl.vexp(dsl.vmuls(input_x, -1.0)),
                                 chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(input_x), 8.0), -2.0), 25, B,
                                         shape, dtype)),
                        (dsl.vrsqrt(input_x)))
    res = dsl.vcmpsel(input_x, TWO, 'le', x_le_two, x_gt_two)
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def bessel_k0(x, output, kernel_name="bessel_k0"):
    """Build and compile the BesselK0 TBE kernel for input descriptor ``x``."""
    input_tensor = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
    result = bessel_k0_compute(input_tensor, output, kernel_name)
    # Auto-schedule the computation for the CCE target.
    with tvm.target.cce():
        sched = dsl.auto_schedule(result)
    # Compile the operator binary.
    dsl.build(sched, {"name": kernel_name, "tensor_list": [input_tensor, result]})

View File

@ -1,187 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselK0e op"""
from tbe import dsl
from te import tvm
from te.platform.fusion_manager import fusion_manager
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE operator information for BesselK0e: elementwise, format-agnostic,
# registered for float16 and float32.
bessel_k0e_op_info = TBERegOp("BesselK0e") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_k0e.so") \
    .compute_cost(10) \
    .kernel_name("bessel_k0e") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_k0e_op_info)
def _bessel_k0e_tbe():
    """BesselK0e TBE register"""
    return
# Chebyshev series coefficients. Used below via chebevl():
#   A  (10 terms): bessel_k0e_compute, x <= 2 branch, argument x*x - 2
#   B  (25 terms): bessel_k0e_compute, x > 2 branch, argument 8/x - 2
#   AA (30 terms): bessel_i0_compute, |x| <= 8 branch, argument |x|/2 - 2
#   BB (25 terms): bessel_i0_compute, |x| > 8 branch, argument 32/|x| - 2
A = [1.37446543561352307156E-16,
     4.25981614279661018399E-14,
     1.03496952576338420167E-11,
     1.90451637722020886025E-9,
     2.53479107902614945675E-7,
     2.28621210311945178607E-5,
     1.26461541144692592338E-3,
     3.59799365153615016266E-2,
     3.44289899924628486886E-1,
     -5.35327393233902768720E-1]
B = [
    5.30043377268626276149E-18, -1.64758043015242134646E-17,
    5.21039150503902756861E-17, -1.67823109680541210385E-16,
    5.51205597852431940784E-16, -1.84859337734377901440E-15,
    6.34007647740507060557E-15, -2.22751332699166985548E-14,
    8.03289077536357521100E-14, -2.98009692317273043925E-13,
    1.14034058820847496303E-12, -4.51459788337394416547E-12,
    1.85594911495471785253E-11, -7.95748924447710747776E-11,
    3.57739728140030116597E-10, -1.69753450938905987466E-9,
    8.57403401741422608519E-9, -4.66048989768794782956E-8,
    2.76681363944501510342E-7, -1.83175552271911948767E-6,
    1.39498137188764993662E-5, -1.28495495816278026384E-4,
    1.56988388573005337491E-3, -3.14481013119645005427E-2,
    2.44030308206595545468E0]
AA = [-4.41534164647933937950E-18, 3.33079451882223809783E-17,
      -2.43127984654795469359E-16, 1.71539128555513303061E-15,
      -1.16853328779934516808E-14, 7.67618549860493561688E-14,
      -4.85644678311192946090E-13, 2.95505266312963983461E-12,
      -1.72682629144155570723E-11, 9.67580903537323691224E-11,
      -5.18979560163526290666E-10, 2.65982372468238665035E-9,
      -1.30002500998624804212E-8, 6.04699502254191894932E-8,
      -2.67079385394061173391E-7, 1.11738753912010371815E-6,
      -4.41673835845875056359E-6, 1.64484480707288970893E-5,
      -5.75419501008210370398E-5, 1.88502885095841655729E-4,
      -5.76375574538582365885E-4, 1.63947561694133579842E-3,
      -4.32430999505057594430E-3, 1.05464603945949983183E-2,
      -2.37374148058994688156E-2, 4.93052842396707084878E-2,
      -9.49010970480476444210E-2, 1.71620901522208775349E-1,
      -3.04682672343198398683E-1, 6.76795274409476084995E-1
      ]
BB = [
    -7.23318048787475395456E-18, -4.83050448594418207126E-18,
    4.46562142029675999901E-17, 3.46122286769746109310E-17,
    -2.82762398051658348494E-16, -3.42548561967721913462E-16,
    1.77256013305652638360E-15, 3.81168066935262242075E-15,
    -9.55484669882830764870E-15, -4.15056934728722208663E-14,
    1.54008621752140982691E-14, 3.85277838274214270114E-13,
    7.18012445138366623367E-13, -1.79417853150680611778E-12,
    -1.32158118404477131188E-11, -3.14991652796324136454E-11,
    1.18891471078464383424E-11, 4.94060238822496958910E-10,
    3.39623202570838634515E-9, 2.26666899049817806459E-8,
    2.04891858946906374183E-7, 2.89137052083475648297E-6,
    6.88975834691682398426E-5, 3.36911647825569408990E-3,
    8.04490411014108831608E-1]
# saturation value returned for the undefined domain x <= 0
MAXNUM = 4294967295.0
# branch boundary between the two Chebyshev expansions
TWO = 2.0
def chebevl(x, n, coef, shape, dtype):
    """Evaluate a Chebyshev series at tensor `x` via the Clenshaw recurrence.

    `coef` holds the first `n` series coefficients (highest order first);
    `shape`/`dtype` describe the broadcast target. Returns (b0 - b2) / 2.
    """
    b0 = dsl.broadcast(coef[0], shape, dtype)
    b1 = dsl.broadcast(0, shape, dtype)
    b2 = None
    for idx in range(1, n):
        # Clenshaw step: b2 <- b1, b1 <- b0, b0 <- x*b1 + coef[idx] - b2
        b2 = b1
        b1 = b0
        term = dsl.broadcast(coef[idx], shape, dtype)
        b0 = dsl.vsub(dsl.vadd(dsl.vmul(x, b1), term), b2)
    return dsl.vmuls(dsl.vsub(b0, b2), 0.5)
def bessel_i0_compute(input_x):
    """
    Compute the modified Bessel function of the first kind, I0(x).

    Evaluates the exponentially scaled series (Cephes-style i0) and then
    multiplies by exp(|x|) to obtain the unscaled I0:
      |x| <= 8: chebevl(|x|/2 - 2, AA)
      |x| >  8: chebevl(32/|x| - 2, BB) / sqrt(|x|)
    I0 is even, so only |x| is used.
    """
    dtype = input_x.dtype
    shape = input_x.shape
    k0e_has_improve_precision = False
    if dtype != "float32":
        # promote float16 to float32 for the series evaluation
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        k0e_has_improve_precision = True
    y = dsl.vabs(input_x)
    # small-argument expansion in |x|/2 - 2
    y_le_eight_in = dsl.vmuls(y, 0.5)
    y_le_eight_in = dsl.vadds(y_le_eight_in, -2.0)
    y_le_eight = chebevl(y_le_eight_in, 30, AA, shape, dtype)
    # large-argument expansion in 32/|x| - 2, scaled by 1/sqrt(|x|)
    y_gt_eight_in = dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0)
    y_gt_eight = chebevl(y_gt_eight_in, 25, BB, shape, dtype)
    y_gt_eight = dsl.vmul(y_gt_eight, dsl.vrsqrt(y))
    res = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
    # undo the exponential scaling of the series: I0(x) = exp(|x|) * I0e(x)
    res = dsl.vmul(res, dsl.vexp(y))
    if k0e_has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
@fusion_manager.register("bessel_k0e")
def bessel_k0e_compute(input_x, output_y, kernel_name="bessel_k0e"):
    """
    Compute the exponentially scaled modified Bessel function
    K0e(x) = exp(x) * K0(x).

    Branches (Cephes-style k0e):
      x <= 2: exp(x) * (chebevl(x*x - 2, A) + I0(x) * (-log(x/2)))
      x >  2: chebevl(8/x - 2, B) / sqrt(x)
    K0 is undefined for x <= 0; those lanes saturate to MAXNUM.

    Fix: the x > 2 branch previously multiplied by exp(-x), which produces
    the UNSCALED K0(x) rather than K0e(x); the exponentially scaled form has
    no exp(-x) factor in that branch (see Cephes k0e / scipy.special.k0e).
    """
    shape = input_x.shape
    dtype = input_x.dtype
    has_improve_precision = False
    if dtype != "float32":
        # promote float16 to float32 for the series evaluation
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        has_improve_precision = True
    # small-argument branch, then multiplied by exp(x) to apply the scaling
    x_le_two = chebevl(dsl.vadds(dsl.vmul(input_x, input_x), -2.0), 10, A, shape, dtype)
    x_le_two = dsl.vadd(dsl.vmul(bessel_i0_compute(input_x), dsl.vmuls(dsl.vlog(dsl.vmuls(input_x, 0.5)), -1.0)),
                        x_le_two)
    x_le_two = dsl.vmul(dsl.vexp(input_x), x_le_two)
    # undefined for x <= 0: saturate
    x_le_two = dsl.vcmpsel(input_x, 0.0, 'le', MAXNUM, x_le_two)
    # large-argument branch: chebevl(8/x - 2, B) / sqrt(x) — already scaled,
    # no exp(-x) factor here (that factor belongs to the unscaled K0).
    x_gt_two = dsl.vmul(chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(input_x), 8.0), -2.0), 25, B, shape, dtype),
                        dsl.vrsqrt(input_x))
    res = dsl.vcmpsel(input_x, TWO, 'le', x_le_two, x_gt_two)
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def bessel_k0e(x, output, kernel_name="bessel_k0e"):
    """Build entry: lower the BesselK0e computation to a TBE kernel.

    `x` is the input tensor description dict (keys "shape" and "dtype");
    `output` is forwarded to the compute function; `kernel_name` names the
    generated kernel binary.
    """
    placeholder_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
    result = bessel_k0e_compute(placeholder_x, output, kernel_name)
    # derive a schedule automatically on the CCE target
    with tvm.target.cce():
        sched = dsl.auto_schedule(result)
    # emit the kernel
    build_args = {
        "name": kernel_name,
        "tensor_list": [placeholder_x, result],
    }
    dsl.build(sched, build_args)

# (diff-view artifact: "View File" / hunk header "@ -1,178 +0,0 @@" — boundary of deleted file bessel_k1.py)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselK1 op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from tbe import dsl
from te import tvm
from te.platform.fusion_manager import fusion_manager
# TBE operator information for BesselK1: elementwise, format-agnostic,
# registered for float16 and float32.
bessel_k1_op_info = TBERegOp("BesselK1") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_k1.so") \
    .compute_cost(10) \
    .kernel_name("bessel_k1") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_k1_op_info)
def _bessel_k1_tbe():
    """BesselK1 TBE register"""
    return
# Chebyshev series coefficients. Used below via bessel_k1_chebevl():
#   AA (29 terms): bessel_i1_compute, |x| <= 8 branch, argument |x|/2 - 2
#   BB (25 terms): bessel_i1_compute, |x| > 8 branch, argument 32/|x| - 2
#   A  (11 terms): bessel_k1_compute, x <= 2 branch, argument x*x - 2
#   B  (25 terms): bessel_k1_compute, x > 2 branch, argument 8/x - 2
AA = [2.77791411276104639959E-18, -2.11142121435816608115E-17,
      1.55363195773620046921E-16, -1.10559694773538630805E-15,
      7.60068429473540693410E-15, -5.04218550472791168711E-14,
      3.22379336594557470981E-13, -1.98397439776494371520E-12,
      1.17361862988909016308E-11, -6.66348972350202774223E-11,
      3.62559028155211703701E-10, -1.88724975172282928790E-9,
      9.38153738649577178388E-9, -4.44505912879632808065E-8,
      2.00329475355213526229E-7, -8.56872026469545474066E-7,
      3.47025130813767847674E-6, -1.32731636560394358279E-5,
      4.78156510755005422638E-5, -1.61760815825896745588E-4,
      5.12285956168575772895E-4, -1.51357245063125314899E-3,
      4.15642294431288815669E-3, -1.05640848946261981558E-2,
      2.47264490306265168283E-2, -5.29459812080949914269E-2,
      1.02643658689847095384E-1, -1.76416518357834055153E-1,
      2.52587186443633654823E-1]
BB = [
    7.51729631084210481353E-18, 4.41434832307170791151E-18,
    -4.65030536848935832153E-17, -3.20952592199342395980E-17,
    2.96262899764595013876E-16, 3.30820231092092828324E-16,
    -1.88035477551078244854E-15, -3.81440307243700780478E-15,
    1.04202769841288027642E-14, 4.27244001671195135429E-14,
    -2.10154184277266431302E-14, -4.08355111109219731823E-13,
    -7.19855177624590851209E-13, 2.03562854414708950722E-12,
    1.41258074366137813316E-11, 3.25260358301548823856E-11,
    -1.89749581235054123450E-11, -5.58974346219658380687E-10,
    -3.83538038596423702205E-9, -2.63146884688951950684E-8,
    -2.51223623787020892529E-7, -3.88256480887769039346E-6,
    -1.10588938762623716291E-4, -9.76109749136146840777E-3,
    7.78576235018280120474E-1]
A = [-7.02386347938628759343E-18, -2.42744985051936593393E-15,
     -6.66690169419932900609E-13, -1.41148839263352776110E-10,
     -2.21338763073472585583E-8, -2.43340614156596823496E-6,
     -1.73028895751305206302E-4, -6.97572385963986435018E-3,
     -1.22611180822657148235E-1, -3.53155960776544875667E-1,
     1.52530022733894777053E0]
B = [-5.75674448366501715755E-18, 1.79405087314755922667E-17,
     -5.68946255844285935196E-17, 1.83809354436663880070E-16,
     -6.05704724837331885336E-16, 2.03870316562433424052E-15,
     -7.01983709041831346144E-15, 2.47715442448130437068E-14,
     -8.97670518232499435011E-14, 3.34841966607842919884E-13,
     -1.28917396095102890680E-12, 5.13963967348173025100E-12,
     -2.12996783842756842877E-11, 9.21831518760500529508E-11,
     -4.19035475934189648750E-10, 2.01504975519703286596E-9,
     -1.03457624656780970260E-8, 5.74108412545004946722E-8,
     -3.50196060308781257119E-7, 2.40648494783721712015E-6,
     -1.93619797416608296024E-5, 1.95215518471351631108E-4,
     -2.85781685962277938680E-3, 1.03923736576817238437E-1,
     2.72062619048444266945E0]
# saturation value returned for the undefined domain x <= 0
MAX_NUM = 4294967295.0
# branch boundary between the two Chebyshev expansions
NUM_TWO = 2.0
def bessel_k1_chebevl(x, n, coef, shape, dtype):
    """Evaluate a Chebyshev series at tensor `x` via the Clenshaw recurrence.

    `coef` holds the first `n` series coefficients (highest order first);
    `shape`/`dtype` describe the broadcast target. Returns (b0 - b2) / 2.
    """
    b0 = dsl.broadcast(coef[0], shape, dtype)
    b1 = dsl.broadcast(0, shape, dtype)
    b2 = None
    for idx in range(1, n):
        # Clenshaw step: b2 <- b1, b1 <- b0, b0 <- x*b1 + coef[idx] - b2
        b2 = b1
        b1 = b0
        term = dsl.broadcast(coef[idx], shape, dtype)
        b0 = dsl.vsub(dsl.vadd(dsl.vmul(x, b1), term), b2)
    return dsl.vmuls(dsl.vsub(b0, b2), 0.5)
def bessel_i1_compute(input_x):
    """
    Compute the modified Bessel function of the first kind, I1(x).

    Evaluates the exponentially scaled series (Cephes-style i1), restores the
    odd symmetry (I1(-x) = -I1(x)), then multiplies by exp(|x|):
      |x| <= 8: |x| * chebevl(|x|/2 - 2, AA)
      |x| >  8: chebevl(32/|x| - 2, BB) / sqrt(|x|)
    """
    dtype = input_x.dtype
    shape = input_x.shape
    k1_has_improve_precision = False
    if dtype != "float32":
        # promote float16 to float32 for the series evaluation
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        k1_has_improve_precision = True
    y = dsl.vabs(input_x)
    # small-argument expansion: |x| * chebevl(|x|/2 - 2, AA)
    y_le_eight = dsl.vmul(y, bessel_k1_chebevl(dsl.vadds(dsl.vmuls(y, 0.5), -2), 29, AA, shape, dtype))
    # large-argument expansion: chebevl(32/|x| - 2, BB) / sqrt(|x|)
    y_gt_eight = dsl.vmul(bessel_k1_chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0), 25, BB,
                                            shape, dtype), dsl.vrsqrt(y))
    y = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
    # I1 is odd: negate the result for negative inputs
    res = dsl.vcmpsel(input_x, 0, 'lt', dsl.vmuls(y, -1.0), y)
    # undo the exponential scaling of the series: I1(x) = exp(|x|) * I1e(x)
    res = dsl.vmul(res, dsl.vexp(dsl.vabs(input_x)))
    if k1_has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
@fusion_manager.register("bessel_k1")
def bessel_k1_compute(input_x, output_y, kernel_name="bessel_k1"):
    """
    Compute the modified Bessel function of the second kind of order 1, K1(x).

    Two-branch Chebyshev approximation (Cephes-style k1):
      x <= 2: chebevl(x*x - 2, A) / x + I1(x) * log(x/2)
      x >  2: exp(-x) * chebevl(8/x - 2, B) / sqrt(x)
    K1 is undefined for x <= 0; those lanes saturate to MAX_NUM.
    float16 input is promoted to float32 for the computation and cast back.
    """
    shape = input_x.shape
    dtype = input_x.dtype
    has_improve_precision = False
    if dtype != "float32":
        # promote float16 to float32 for the series evaluation
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        has_improve_precision = True
    # small-argument branch: chebevl(x*x - 2, A)/x + I1(x) * log(x/2)
    x_le_two = dsl.vdiv(bessel_k1_chebevl(dsl.vadds(dsl.vmul(input_x, input_x), -2.0), 11, A, shape, dtype), input_x)
    x_le_two = dsl.vadd(dsl.vmul(bessel_i1_compute(input_x), dsl.vlog(dsl.vmuls(input_x, 0.5))), x_le_two)
    # K1 diverges as x -> 0+ and is undefined for x <= 0: saturate to MAX_NUM
    x_le_two = dsl.vcmpsel(input_x, 0.0, 'le', MAX_NUM, x_le_two)
    # large-argument branch: exp(-x) * chebevl(8/x - 2, B) / sqrt(x)
    x_gt_two = dsl.vmul(dsl.vmul(dsl.vexp(dsl.vmuls(input_x, -1.0)),
                        bessel_k1_chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(input_x), 8.0), -2.0), 25,
                                          B, shape, dtype)), (dsl.vrsqrt(input_x)))
    res = dsl.vcmpsel(input_x, NUM_TWO, 'le', x_le_two, x_gt_two)
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def bessel_k1(x, output, kernel_name="bessel_k1"):
    """Build entry: lower the BesselK1 computation to a TBE kernel.

    `x` is the input tensor description dict (keys "shape" and "dtype");
    `output` is forwarded to the compute function; `kernel_name` names the
    generated kernel binary.
    """
    placeholder_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
    result = bessel_k1_compute(placeholder_x, output, kernel_name)
    # derive a schedule automatically on the CCE target
    with tvm.target.cce():
        sched = dsl.auto_schedule(result)
    # emit the kernel
    build_args = {
        "name": kernel_name,
        "tensor_list": [placeholder_x, result],
    }
    dsl.build(sched, build_args)

# (diff-view artifact: "View File" / hunk header "@ -1,178 +0,0 @@" — boundary of deleted file bessel_k1e.py)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselK1e op"""
from tbe import dsl
from te import tvm
from te.platform.fusion_manager import fusion_manager
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE operator information for BesselK1e: elementwise, format-agnostic,
# registered for float16 and float32.
bessel_k1e_op_info = TBERegOp("BesselK1e") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_k1e.so") \
    .compute_cost(10) \
    .kernel_name("bessel_k1e") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_k1e_op_info)
def _bessel_k1e_tbe():
    """BesselK1e TBE register"""
    return
# Chebyshev series coefficients (same values as bessel_k1.py).
# Used below via bessel_k1e_chebevl():
#   AA (29 terms): bessel_i1_compute, |x| <= 8 branch, argument |x|/2 - 2
#   BB (25 terms): bessel_i1_compute, |x| > 8 branch, argument 32/|x| - 2
#   A  (11 terms): bessel_k1e_compute, x <= 2 branch, argument x*x - 2
#   B  (25 terms): bessel_k1e_compute, x > 2 branch, argument 8/x - 2
AA = [2.77791411276104639959E-18, -2.11142121435816608115E-17,
      1.55363195773620046921E-16, -1.10559694773538630805E-15,
      7.60068429473540693410E-15, -5.04218550472791168711E-14,
      3.22379336594557470981E-13, -1.98397439776494371520E-12,
      1.17361862988909016308E-11, -6.66348972350202774223E-11,
      3.62559028155211703701E-10, -1.88724975172282928790E-9,
      9.38153738649577178388E-9, -4.44505912879632808065E-8,
      2.00329475355213526229E-7, -8.56872026469545474066E-7,
      3.47025130813767847674E-6, -1.32731636560394358279E-5,
      4.78156510755005422638E-5, -1.61760815825896745588E-4,
      5.12285956168575772895E-4, -1.51357245063125314899E-3,
      4.15642294431288815669E-3, -1.05640848946261981558E-2,
      2.47264490306265168283E-2, -5.29459812080949914269E-2,
      1.02643658689847095384E-1, -1.76416518357834055153E-1,
      2.52587186443633654823E-1]
BB = [
    7.51729631084210481353E-18, 4.41434832307170791151E-18,
    -4.65030536848935832153E-17, -3.20952592199342395980E-17,
    2.96262899764595013876E-16, 3.30820231092092828324E-16,
    -1.88035477551078244854E-15, -3.81440307243700780478E-15,
    1.04202769841288027642E-14, 4.27244001671195135429E-14,
    -2.10154184277266431302E-14, -4.08355111109219731823E-13,
    -7.19855177624590851209E-13, 2.03562854414708950722E-12,
    1.41258074366137813316E-11, 3.25260358301548823856E-11,
    -1.89749581235054123450E-11, -5.58974346219658380687E-10,
    -3.83538038596423702205E-9, -2.63146884688951950684E-8,
    -2.51223623787020892529E-7, -3.88256480887769039346E-6,
    -1.10588938762623716291E-4, -9.76109749136146840777E-3,
    7.78576235018280120474E-1]
A = [-7.02386347938628759343E-18, -2.42744985051936593393E-15,
     -6.66690169419932900609E-13, -1.41148839263352776110E-10,
     -2.21338763073472585583E-8, -2.43340614156596823496E-6,
     -1.73028895751305206302E-4, -6.97572385963986435018E-3,
     -1.22611180822657148235E-1, -3.53155960776544875667E-1,
     1.52530022733894777053E0]
B = [-5.75674448366501715755E-18, 1.79405087314755922667E-17,
     -5.68946255844285935196E-17, 1.83809354436663880070E-16,
     -6.05704724837331885336E-16, 2.03870316562433424052E-15,
     -7.01983709041831346144E-15, 2.47715442448130437068E-14,
     -8.97670518232499435011E-14, 3.34841966607842919884E-13,
     -1.28917396095102890680E-12, 5.13963967348173025100E-12,
     -2.12996783842756842877E-11, 9.21831518760500529508E-11,
     -4.19035475934189648750E-10, 2.01504975519703286596E-9,
     -1.03457624656780970260E-8, 5.74108412545004946722E-8,
     -3.50196060308781257119E-7, 2.40648494783721712015E-6,
     -1.93619797416608296024E-5, 1.95215518471351631108E-4,
     -2.85781685962277938680E-3, 1.03923736576817238437E-1,
     2.72062619048444266945E0]
# saturation value returned for the undefined domain x <= 0
MAXNUM = 4294967295.0
# branch boundary between the two Chebyshev expansions
TWO = 2.0
def bessel_k1e_chebevl(x, n, coef, shape, dtype):
    """Evaluate a Chebyshev series at tensor `x` via the Clenshaw recurrence.

    `coef` holds the first `n` series coefficients (highest order first);
    `shape`/`dtype` describe the broadcast target. Returns (b0 - b2) / 2.
    """
    b0 = dsl.broadcast(coef[0], shape, dtype)
    b1 = dsl.broadcast(0, shape, dtype)
    b2 = None
    for idx in range(1, n):
        # Clenshaw step: b2 <- b1, b1 <- b0, b0 <- x*b1 + coef[idx] - b2
        b2 = b1
        b1 = b0
        term = dsl.broadcast(coef[idx], shape, dtype)
        b0 = dsl.vsub(dsl.vadd(dsl.vmul(x, b1), term), b2)
    return dsl.vmuls(dsl.vsub(b0, b2), 0.5)
def bessel_i1_compute(input_x):
    """
    Compute the modified Bessel function of the first kind, I1(x).

    Duplicate of the helper in bessel_k1.py. Evaluates the exponentially
    scaled series, restores the odd symmetry (I1(-x) = -I1(x)), then
    multiplies by exp(|x|):
      |x| <= 8: |x| * chebevl(|x|/2 - 2, AA)
      |x| >  8: chebevl(32/|x| - 2, BB) / sqrt(|x|)
    """
    dtype = input_x.dtype
    shape = input_x.shape
    has_improve_precision = False
    if dtype != "float32":
        # promote float16 to float32 for the series evaluation
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        has_improve_precision = True
    y = dsl.vabs(input_x)
    # small-argument expansion: |x| * chebevl(|x|/2 - 2, AA)
    y_le_eight = dsl.vmul(y, bessel_k1e_chebevl(dsl.vadds(dsl.vmuls(y, 0.5), -2), 29, AA, shape, dtype))
    # large-argument expansion: chebevl(32/|x| - 2, BB) / sqrt(|x|)
    y_gt_eight = dsl.vmul(bessel_k1e_chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0), 25, BB, shape, dtype),
                          dsl.vrsqrt(y))
    y = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
    # I1 is odd: negate the result for negative inputs
    res = dsl.vcmpsel(input_x, 0, 'lt', dsl.vmuls(y, -1.0), y)
    # undo the exponential scaling of the series: I1(x) = exp(|x|) * I1e(x)
    res = dsl.vmul(res, dsl.vexp(dsl.vabs(input_x)))
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
@fusion_manager.register("bessel_k1e")
def bessel_k1e_compute(input_x, output_y, kernel_name="bessel_k1e"):
    """
    Compute the exponentially scaled modified Bessel function
    K1e(x) = exp(x) * K1(x).

    Branches (Cephes-style k1e):
      x <= 2: exp(x) * (chebevl(x*x - 2, A) / x + I1(x) * log(x/2))
      x >  2: chebevl(8/x - 2, B) / sqrt(x)
    K1 is undefined for x <= 0; those lanes saturate to MAXNUM.

    Fix: the x > 2 branch previously multiplied by exp(-x), which produces
    the UNSCALED K1(x) rather than K1e(x); the exponentially scaled form has
    no exp(-x) factor in that branch (see Cephes k1e / scipy.special.k1e).
    """
    shape = input_x.shape
    dtype = input_x.dtype
    has_improve_precision = False
    if dtype != "float32":
        # promote float16 to float32 for the series evaluation
        input_x = dsl.cast_to(input_x, "float32")
        dtype = "float32"
        has_improve_precision = True
    # small-argument branch, then multiplied by exp(x) to apply the scaling
    x_le_two = dsl.vdiv(bessel_k1e_chebevl(dsl.vadds(dsl.vmul(input_x, input_x), -2.0), 11, A, shape, dtype), input_x)
    x_le_two = dsl.vadd(dsl.vmul(bessel_i1_compute(input_x), dsl.vlog(dsl.vmuls(input_x, 0.5))), x_le_two)
    x_le_two = dsl.vmul(x_le_two, dsl.vexp(input_x))
    # undefined for x <= 0: saturate
    x_le_two = dsl.vcmpsel(input_x, 0.0, 'le', MAXNUM, x_le_two)
    # large-argument branch: chebevl(8/x - 2, B) / sqrt(x) — already scaled,
    # no exp(-x) factor here (that factor belongs to the unscaled K1).
    x_gt_two = dsl.vmul(bessel_k1e_chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(input_x), 8.0), -2.0), 25, B, shape, dtype),
                        dsl.vrsqrt(input_x))
    res = dsl.vcmpsel(input_x, TWO, 'le', x_le_two, x_gt_two)
    if has_improve_precision:
        res = dsl.cast_to(res, "float16")
    return res
def bessel_k1e(x, output, kernel_name="bessel_k1e"):
    """Build entry: lower the BesselK1e computation to a TBE kernel.

    `x` is the input tensor description dict (keys "shape" and "dtype");
    `output` is forwarded to the compute function; `kernel_name` names the
    generated kernel binary.
    """
    placeholder_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
    result = bessel_k1e_compute(placeholder_x, output, kernel_name)
    # derive a schedule automatically on the CCE target
    with tvm.target.cce():
        sched = dsl.auto_schedule(result)
    # emit the kernel
    build_args = {
        "name": kernel_name,
        "tensor_list": [placeholder_x, result],
    }
    dsl.build(sched, build_args)

# (diff-view artifact: "View File" / hunk header "@ -1,424 +0,0 @@" — boundary of deleted file bessel_y0.py)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselY0 op"""
import te.lang.cce as tbe
from te import tvm
from te.utils import para_check
from te.utils import shape_util
from tbe.common.platform import api_check_support
from tbe.common.register import register_op_compute
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE operator information for BesselY0: elementwise, format-agnostic,
# registered for float16 and float32.
bessel_y0_op_info = TBERegOp("BesselY0") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_y0.so") \
    .compute_cost(10) \
    .kernel_name("bessel_y0") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_y0_op_info)
def _bessel_y0_tbe():
    """BesselY0 TBE register"""
    return
# Rational-approximation coefficients (Numerical Recipes style):
#   ITR_A1/ITR_A2: numerator/denominator for J0, |x| < 8
#   ITR_A33/ITR_A44: asymptotic series factors for |x| >= 8
#   ITR_A5/ITR_A6: numerator/denominator for Y0, x < 8
#   ITR_A: (phase offset -pi/4, 2/pi) used by the asymptotic forms
ITR_A1 = (57568490574.0, -13362590354.0, 651619640.7, -11214424.18, 77392.33017, -184.9052456)
ITR_A2 = (57568490411.0, 1029532985.0, 9494680.718, 59272.64853, 267.8532712, 1.0)
ITR_A33 = (1.0, -0.1098628627e-2, 0.2734510407e-4, -0.2073370639e-5, 0.2093887211e-6)
ITR_A44 = (-0.1562499995e-1, 0.1430488765e-3, -0.6911147651e-5, 0.7621095161e-6, 0.934935152e-7)
ITR_A5 = (-2957821389.0, 7062834065.0, -512359803.6, 10879881.29, -86327.92757, 228.4622733)
ITR_A6 = (40076544269.0, 745249964.8, 7189466.438, 47447.26470, 226.1030244, 1.0)
ITR_A = (-0.785398164, 0.636619772)
# lengths of the coefficient tuples above
LEN_A1256 = 6
LEN_A34 = 5
EIGHT = 8.0
ZERO = 0
NEG_ONE = -1.0
ONE = 1.0
PI2 = 6.2831853071796
HALF_PI = 1.5707963267948966192313216916398
# 2^(-0.5): branch boundary of the asin Taylor expansion
BOUNDARY_1 = 0.70710678118654752440084436210485
# Taylor coefficient
COEF = (1.0,
        0.16666666666666666666666666666667,
        0.075,
        0.04464285714285714285714285714286,
        0.03038194444444444444444444444444,
        0.02237215909090909090909090909091,
        0.01735276442307692307692307692308,
        0.01396484375)
# TAYLOR COUNT
TAYLOR_COUNT = 7
# negative min float16 value
NEG_MIN_FP16 = -2 ** (-24)
# min float16 * 2
TWO_MIN_FP16 = 2 ** (-23)
def _taylor_compute(data_x, x_square=None):
    """Evaluate the asin Taylor polynomial at `data_x` in Horner form.

    If `x_square` is given it is used as data_x**2 (saves one multiply);
    otherwise it is computed here. Returns data_x * P(data_x**2).
    """
    square = tbe.vmul(data_x, data_x) if x_square is None else x_square
    acc = tbe.vmuls(square, tvm.const(COEF[TAYLOR_COUNT], "float32"))
    for order in range(TAYLOR_COUNT - 1, -1, -1):
        acc = tbe.vadds(acc, tvm.const(COEF[order], "float32"))
        if order:
            # inner Horner step: fold in another power of x^2
            acc = tbe.vmul(square, acc)
        else:
            # final step multiplies by x itself (odd series)
            acc = tbe.vmul(acc, data_x)
    return acc
def acos_compute(x):
    """
    do element-wise acos compute using asin op
    acos(x) = HALF_PI - asin(x)
    asin(x) = | arcsin(sqrt(1-x^2)) - HALF_PI, x belongs to (-1, -2^(-0.5))
              | the 15th order taylor expansion, x belongs to (-2^(-0.5),
              | 2^(-0.5))
              | HALF_PI - arcsin(sqrt(1-x^2)), x belongs to (2^(-0.5), 1)
    Parameters:
    ----------
    x: the placeholder of data input
    y : the dict of output
    Returns : A Tensor. Has the same type as x.
    -------
    """
    shape = x.shape
    dtype = x.dtype
    # Change dtype to float32
    if (dtype in ('float16', 'double')) and \
            api_check_support("te.lang.cce.vadd", "float32"):
        x = tbe.cast_to(x, "float32")
    # to fix bug for input data is 1.0 (nudge off the exact boundary)
    x = tbe.vadds(x, NEG_MIN_FP16)
    # Sign mask: -1 for negative lanes, +1 otherwise
    sign = tbe.vcmpsel(x, 0, 'lt', NEG_ONE, ONE)
    # All positive
    x = tbe.vmul(x, sign)
    # x belongs to (0, 2^(-0.5)): choice_1 == 1 selects the Taylor branch
    if api_check_support("te.lang.cce.vmins", x.dtype):
        choice_1 = tbe.vmins(x, tvm.const(BOUNDARY_1, x.dtype))
    else:
        boundary_mask1 = tbe.broadcast(tvm.const(BOUNDARY_1, x.dtype), shape)
        choice_1 = tbe.vmin(x, boundary_mask1)
    if api_check_support("te.lang.cce.vsubs", choice_1.dtype):
        choice_1 = tbe.vsubs(choice_1, tvm.const(BOUNDARY_1, choice_1.dtype))
    else:
        boundary_mask1 = tbe.broadcast(tvm.const(BOUNDARY_1, choice_1.dtype), shape)
        choice_1 = tbe.vsub(choice_1, boundary_mask1)
    choice_1 = tbe.vmuls(tbe.floor(choice_1), NEG_ONE)
    res_1 = _taylor_compute(x)
    res_1 = tbe.vmul(res_1, choice_1)
    # x belongs to (2^(-0.5), 1): complementary mask choice_2 = 1 - choice_1
    choice_2 = tbe.vmuls(choice_1, tvm.const(NEG_ONE, x.dtype))
    choice_2 = tbe.vadds(choice_2, tvm.const(ONE, x.dtype))
    # to fix bug for input data is 1.0
    x = tbe.vadds(x, TWO_MIN_FP16)
    # asin(x) = HALF_PI - asin(sqrt(1 - x^2)) on this branch
    res_2 = tbe.vmul(x, x)
    res_2 = tbe.vmuls(res_2, tvm.const(NEG_ONE, x.dtype))
    res_2 = tbe.vadds(res_2, tvm.const(ONE, x.dtype))
    res_2_sqrt = tbe.vsqrt(res_2, 1)
    res_2 = _taylor_compute(res_2_sqrt, res_2)
    res_2 = tbe.vmuls(res_2, tvm.const(NEG_ONE, x.dtype))
    res_2 = tbe.vadds(res_2, tvm.const(HALF_PI, x.dtype))
    res_2 = tbe.vmul(res_2, choice_2)
    # Restore sign of asin, then acos(x) = HALF_PI - asin(x)
    res_1 = tbe.vadd(res_1, res_2)
    res_1 = tbe.vmul(res_1, sign)
    res_1 = tbe.vmuls(res_1, tvm.const(NEG_ONE, x.dtype))
    res_1 = tbe.vadds(res_1, tvm.const(HALF_PI, x.dtype))
    return res_1
def asin_compute(x):
    """
    do element-wise asin compute
    asin(x) = | arcsin(sqrt(1-x^2)) - HALF_PI, x belongs to (-1, -2^(-0.5))
              | the 15th order taylor expansion, x belongs to (-2^(-0.5), 2^(-0.5))
              | HALF_PI - arcsin(sqrt(1-x^2)), x belongs to (2^(-0.5), 1)
    Parameters:
    ----------
    x: the placeholder of data input
    y : the dict of output
    Returns : A Tensor. Has the same type as data_input.
    -------
    """
    shape = x.shape
    dtype = x.dtype
    # Change dtype to float32
    if (dtype in ('float16', 'double')) and api_check_support("te.lang.cce.vadd", "float32"):
        x = tbe.cast_to(x, "float32")
    # Sign mask: -1 for negative lanes, +1 otherwise
    bessely0_sign = tbe.vcmpsel(x, 0, 'lt', NEG_ONE, ONE)
    # All positive
    x = tbe.vmul(x, bessely0_sign)
    # x belongs to (0, 2^(-0.5)): y0_choice_1 == 1 selects the Taylor branch
    if api_check_support("te.lang.cce.vmins", x.dtype):
        y0_choice_1 = tbe.vmins(x, tvm.const(BOUNDARY_1, x.dtype))
    else:
        boundary_mask1 = tbe.broadcast(tvm.const(BOUNDARY_1, x.dtype), shape)
        y0_choice_1 = tbe.vmin(x, boundary_mask1)
    if api_check_support("te.lang.cce.vsubs", y0_choice_1.dtype):
        y0_choice_1 = tbe.vsubs(y0_choice_1, tvm.const(BOUNDARY_1, y0_choice_1.dtype))
    else:
        boundary_mask1 = tbe.broadcast(tvm.const(BOUNDARY_1, y0_choice_1.dtype), shape)
        y0_choice_1 = tbe.vsub(y0_choice_1, boundary_mask1)
    y0_choice_1 = tbe.vmuls(tbe.floor(y0_choice_1), NEG_ONE)
    res_1 = _taylor_compute(x)
    res_1 = tbe.vmul(res_1, y0_choice_1)
    # x belongs to (2^(-0.5), 1): complementary mask choice_2 = 1 - y0_choice_1
    choice_2 = tbe.vmuls(y0_choice_1, tvm.const(NEG_ONE, x.dtype))
    choice_2 = tbe.vadds(choice_2, tvm.const(ONE, x.dtype))
    # asin(x) = HALF_PI - asin(sqrt(1 - x^2)) on this branch
    res_2 = tbe.vmul(x, x)
    res_2 = tbe.vmuls(res_2, tvm.const(NEG_ONE, x.dtype))
    res_2 = tbe.vadds(res_2, tvm.const(ONE, x.dtype))
    res_2_sqrt = tbe.vsqrt(res_2)
    res_2 = _taylor_compute(res_2_sqrt, res_2)
    res_2 = tbe.vmuls(res_2, tvm.const(NEG_ONE, x.dtype))
    res_2 = tbe.vadds(res_2, tvm.const(HALF_PI, x.dtype))
    res_2 = tbe.vmul(res_2, choice_2)
    # Restore sign
    res_1 = tbe.vadd(res_1, res_2)
    res_1 = tbe.vmul(res_1, bessely0_sign)
    return res_1
def _besselj0(x):
    """
    Algrithm:
    y = x * x;
    ans1 = 57568490574.0 + y * (-13362590354.0 + y * (651619640.7 + y * (-11214424.18 + y *
           (77392.33017 + y * (-184.9052456)))))
    ans2 = 57568490411.0 + y * (1029532985.0 + y * (9494680.718 + y * (59272.64853 + y * (267.8532712 + y * 1.0))))
    ans = ans1 / ans2 (x < 8.0)
    z = 8.0 / x
    y = z * z
    xx = ax - 0.785398164;
    ans1 = 1.0 + y * (-0.1098628627e-2 + y * (0.2734510407e-4 + y * (-0.2073370639e-5 + y * 0.2093887211e-6)))
    ans2 = -0.1562499995e-1 + y * (0.1430488765e-3 + y * (-0.6911147651e-5 +
           y * (0.7621095161e-6 - y * 0.934935152e-7)));
    ans = sqrt(0.636619772 / ax) * (cos(xx) * ans1 - z * sin(xx) * ans2), (x >= 8.0)
    Parameters
    ----------
    x: the placeholder of data input
    y : the dict of output
    Returns
    -------
    A tensor. Has the same type as x.
    """
    # NOTE(review): the formula above calls for cos(xx)/sin(xx), but the code
    # below evaluates acos_compute(jxx)/asin_compute(jxx) — verify this is the
    # intended approximation on this platform.
    jax = tbe.vabs(x)
    tensor_eight = tbe.broadcast(tvm.const(EIGHT, jax.dtype), jax.shape)
    # |x| >= 8 asymptotic branch; vmax keeps the small lanes numerically safe
    first_res = tbe.vmax(jax, tensor_eight)
    jz = tbe.vdiv(tensor_eight, first_res)
    jy = tbe.vmul(jz, jz)
    jxx = tbe.vadds(first_res, ITR_A[0])
    # Horner evaluation of the two asymptotic series in jy
    jans1 = tbe.vmuls(jy, tvm.const(ITR_A33[LEN_A34 - 1]))
    jans1 = tbe.vadds(jans1, ITR_A33[LEN_A34 - 2])
    for index in reversed(range(LEN_A34 - 2)):
        jans1 = tbe.vmul(jans1, jy)
        jans1 = tbe.vadds(jans1, ITR_A33[index])
    jans2 = tbe.vmuls(jy, tvm.const(ITR_A44[LEN_A34 - 1]))
    jans2 = tbe.vadds(jans2, ITR_A44[LEN_A34 - 2])
    for index in reversed(range(LEN_A34 - 2)):
        jans2 = tbe.vmul(jans2, jy)
        jans2 = tbe.vadds(jans2, ITR_A44[index])
    jansres1 = tbe.vmul(tbe.vsqrt(tbe.vmuls(tbe.vrec(first_res), ITR_A[1])),
                        tbe.vsub(tbe.vmul(acos_compute(jxx), jans1), tbe.vmul(jz, tbe.vmul(asin_compute(jxx), jans2))))
    # |x| < 8 rational-approximation branch; vmin keeps the large lanes safe
    first_res = tbe.vmin(jax, tensor_eight)
    jy = tbe.vmul(first_res, first_res)
    jans1 = tbe.vmuls(jy, tvm.const(ITR_A1[LEN_A1256 - 1]))
    jans1 = tbe.vadds(jans1, ITR_A1[LEN_A1256 - 2])
    for index in reversed(range(LEN_A1256 - 2)):
        jans1 = tbe.vmul(jans1, jy)
        jans1 = tbe.vadds(jans1, ITR_A1[index])
    jans2 = tbe.vmuls(jy, tvm.const(ITR_A2[LEN_A1256 - 1]))
    jans2 = tbe.vadds(jans2, ITR_A2[LEN_A1256 - 2])
    for index in reversed(range(LEN_A1256 - 2)):
        jans2 = tbe.vmul(jans2, jy)
        jans2 = tbe.vadds(jans2, ITR_A2[index])
    jansres2 = tbe.vdiv(jans1, jans2)
    # pick per-lane result by branch condition |x| < 8
    res = tbe.vcmpsel(jax, tensor_eight, 'lt', jansres2, jansres1)
    return res
@register_op_compute("bessel_y0")
def bessel_y0_compute(x, y, kernel_name="bessel_y0"):
    """
    Algrithm:
    y = x * x;
    ans1 = -2957821389.0 + y * (7062834065.0 + y * (-512359803.6 + y * (10879881.29 +
           y * (-86327.92757 + y * 228.4622733))))
    ans2 = 40076544269.0 + y * (745249964.8 + y * (7189466.438 + y * (47447.26470 + y * (226.1030244 + y * 1.0))))
    ans = (ans1 / ans2) + 0.636619772 * bessj0(x) * math.log(x), (x < 8.0)
    z = 8.0 / x
    y = z * z
    xx = x - 0.785398164
    ans1 = 1.0 + y * (-0.1098628627e-2 + y * (0.2734510407e-4 + y * (-0.2073370639e-5 + y * 0.2093887211e-6)))
    ans2 = -0.1562499995e-1 + y * (0.1430488765e-3 + y * (-0.6911147651e-5 +
           y * (0.7621095161e-6 + y * (-0.934945152e-7))))
    ans = math.sqrt(0.636619772 / x) * (math.asin(xx) * ans1 + z * math.acos(xx) * ans2), (x >= 8.0)
    Parameters
    ----------
    x: the placeholder of data input
    y : the dict of output
    kernel_name : cce kernel name, default value is "bessel_y0"
    Returns
    -------
    A tensor. Has the same type as x.
    """
    dtype_input = x.dtype
    # chose the type of data in begin
    if (dtype_input in ('float16', 'double', 'float32')) \
            and api_check_support("te.lang.cce.vadd", "float32"):
        x = tbe.cast_to(x, "float32")
    else:
        raise RuntimeError("BesselY0 kernel data type [%s] not support." % dtype_input)
    x = tbe.vabs(x)
    tensor_eight = tbe.broadcast(tvm.const(EIGHT, x.dtype), x.shape)
    # x < 8 branch: rational approximation plus (2/pi) * J0(x) * log(x)
    first_res = tbe.vmin(x, tensor_eight)
    yy = tbe.vmul(first_res, first_res)
    ans1 = tbe.vmuls(yy, tvm.const(ITR_A5[LEN_A1256 - 1]))
    ans1 = tbe.vadds(ans1, ITR_A5[LEN_A1256 - 2])
    for index in reversed(range(LEN_A1256 - 2)):
        ans1 = tbe.vmul(ans1, yy)
        ans1 = tbe.vadds(ans1, ITR_A5[index])
    ans2 = tbe.vmuls(yy, tvm.const(ITR_A6[LEN_A1256 - 1]))
    ans2 = tbe.vadds(ans2, ITR_A6[LEN_A1256 - 2])
    for index in reversed(range(LEN_A1256 - 2)):
        ans2 = tbe.vmul(ans2, yy)
        ans2 = tbe.vadds(ans2, ITR_A6[index])
    res1 = tbe.vadd(tbe.vdiv(ans1, ans2), tbe.vmuls(tbe.vmul(_besselj0(first_res), tbe.vlog(first_res)), ITR_A[1]))
    # x >= 8 asymptotic branch
    first_res = tbe.vmax(x, tensor_eight)
    z = tbe.vdiv(tensor_eight, first_res)
    y = tbe.vmul(z, z)
    xx = tbe.vadds(first_res, ITR_A[0])
    ans1 = tbe.vmuls(y, tvm.const(ITR_A33[LEN_A34 - 1]))
    ans1 = tbe.vadds(ans1, ITR_A33[LEN_A34 - 2])
    for index in reversed(range(LEN_A34 - 2)):
        ans1 = tbe.vmul(ans1, y)
        ans1 = tbe.vadds(ans1, ITR_A33[index])
    ans2 = tbe.vmuls(y, tvm.const(ITR_A44[LEN_A34 - 1]))
    ans2 = tbe.vadds(ans2, ITR_A44[LEN_A34 - 2])
    for index in reversed(range(LEN_A34 - 2)):
        ans2 = tbe.vmul(ans2, y)
        ans2 = tbe.vadds(ans2, ITR_A44[index])
    res2 = tbe.vmul(tbe.vsqrt(tbe.vmuls(tbe.vrec(first_res), ITR_A[1])),
                    tbe.vadd(tbe.vmul(asin_compute(xx), ans1),
                             tbe.vmul(z, tbe.vmul(acos_compute(xx), ans2))))
    # select per lane by branch condition x < 8
    res = tbe.vcmpsel(x, tensor_eight, 'lt', res1, res2)
    # Restore dtype
    if dtype_input == "float16":
        res = tbe.cast_to(res, "float16")
    if dtype_input == "double":
        res = tbe.cast_to(res, "double")
    return res
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME)
def bessel_y0(x, y, kernel_name="bessel_y0"):
    """
    Computes the Bessel function of the second kind of order 0 element-wise.

    Parameters
    ----------
    x : dict describing the input tensor; dtype one of float16/float32/double
    y : dict describing the output tensor
    kernel_name : cce kernel name, default value is "bessel_y0"

    Returns
    -------
    None
    """
    in_shape = x.get("shape")
    in_dtype = x.get("dtype")
    # validate the input description before building the kernel
    para_check.check_shape(in_shape, param_name="x")
    in_shape, _ = shape_util.refine_shape_axes(in_shape, [])
    para_check.check_dtype(in_dtype, ("float16", "float32", "double"), param_name="x")
    data = tvm.placeholder(in_shape, dtype=in_dtype.lower(), name="data_input")
    res = bessel_y0_compute(data, y, kernel_name)
    # derive a schedule automatically on the CCE target and build
    with tvm.target.cce():
        sch = tbe.auto_schedule(res)
    config = {"name": kernel_name,
              "print_ir": False,
              "tensor_list": (data, res),
              "bool_storage_as_1bit": False}
    tbe.cce_build_code(sch, config)

# (diff-view artifact: "View File" / hunk header "@ -1,281 +0,0 @@" — boundary of deleted file bessel_y1.py)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BesselY1 op"""
import tbe.dsl as dsl
from tbe.common.platform import api_check_support
from tbe.common.register import register_op_compute
from tbe.common.utils import para_check
from tbe.common.utils import shape_util
from te import tvm
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE operator-information record for BesselY1: an element-wise,
# format-agnostic op with one required input "x" and one output "y";
# float16 and float32 are the registered dtypes.
bessel_y1_op_info = TBERegOp("BesselY1") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("bessel_y1.so") \
    .compute_cost(10) \
    .kernel_name("bessel_y1") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
@op_info_register(bessel_y1_op_info)
def _bessel_y1_tbe():
    """BesselY1 TBE registration stub; the decorator performs the actual registration."""
    return
# |x| below this limit uses the rational-polynomial branch; at or above
# it the asymptotic-expansion branch is used.
CONST_LIMIT = 8.0
# Polynomial coefficients, lowest order first, consumed by prod().
# The values appear to match the classic Numerical Recipes rational fits
# for J1 (numerator/denominator in x*x) -- TODO confirm against the
# reference tables.
ITR_BEFORE1_J = [72362614232.0, -7895059235.0, 242396853.1, -2972611.439, 15704.48260, -30.16036606]
ITR_BEFORE2_J = [144725228442.0, 2300535178.0, 18583304.74, 99447.43394, 376.9991397, 1.0]
# Asymptotic-expansion coefficients for J1 at |x| >= 8 (in y = (8/x)^2).
ITR_AFTER1_J = [1.0, 0.183105e-2, -0.3516396496e-4, 0.2457520174e-5, -0.240337019e-6]
ITR_AFTER2_J = [0.04687499995, -0.2002690873e-3, 0.8449199096e-5, -0.88228987e-6, 0.105787412e-6]
# Rational-fit coefficients for Y1 on the small-|x| branch.
ITR_BEFORE1 = [-0.4900604943e13, 0.1275274390e13, -0.5153438139e11, 0.7349264551e9,
               -0.4237922726e7, 0.8511937935e4]
ITR_BEFORE2 = [0.2499580570e14, 0.4244419664e12, 0.3733650367e10, 0.2245904002e8,
               0.1020426050e6, 0.3549632885e3, 1.0]
# 3*pi/4 -- phase shift appearing in the large-|x| asymptotic forms.
THREE_QUARTERS_PI = 2.356194490192345
# Asymptotic-expansion coefficients for Y1 (identical tables to the J1 ones above).
ITR_AFTER1 = [1.0, 0.183105e-2, -0.3516396496e-4, 0.2457520174e-5, -0.240337019e-6]
ITR_AFTER2 = [0.04687499995, -0.2002690873e-3, 0.8449199096e-5, -0.88228987e-6, 0.105787412e-6]
# 2/pi (name keeps the original spelling used throughout this file).
TOW_OVER_PI = 0.636619772367581
PI = 3.141592653589793
def angle_trans_cal(x):
    """Return (sin, cos) of a phase-shifted angle for the asymptotic branch.

    Builds sin(x/4) and cos(x/4) via cordic(), then applies multiple-angle
    identities; the -sqrt(2)/2 factor (`coe`) suggests a folded-in 3*pi/4
    phase shift, i.e. the result is presumably sin(x - 3*pi/4) and
    cos(x - 3*pi/4) -- TODO confirm against the Bessel asymptotic formulas.
    """
    # NOTE(review): consult/floor_consult/fixed_x implement a 2*pi range
    # reduction, but fixed_x is never used below -- quarter_x is computed
    # from the raw x. Either dead code or a missed substitution; verify.
    consult = dsl.vdiv(x, dsl.broadcast(tvm.const(PI * 2), x.shape))
    floor_consult = dsl.cast_to(dsl.floor(consult), 'float32')
    fixed_x = dsl.vsub(x, dsl.vmuls(floor_consult, PI * 2))
    coe = -0.707106781186548  # -sqrt(2)/2
    quarter_x = dsl.vmuls(x, 0.25)
    sin_quar_x, cos_quar_x = cordic(quarter_x)
    # Powers of the quarter-angle sin/cos used by the identities below.
    cos_quar_x2 = dsl.vmul(cos_quar_x, cos_quar_x)
    sin_quar_x2 = dsl.vmul(sin_quar_x, sin_quar_x)
    cos_quar_x4 = dsl.vmul(cos_quar_x2, cos_quar_x2)
    # temp_res1 = 8*cos^4(x/4) - 8*cos^2(x/4) + 1  (= cos(x))
    temp_res1 = dsl.vadds(dsl.vadd(dsl.vmuls(cos_quar_x2, -8.0), dsl.vmuls(cos_quar_x4, 8.0)), 1)
    # temp_res2 = -4*(2*sin^2(x/4) - 1)*sin(x/4)*cos(x/4)  (= sin(x))
    temp_res2 = dsl.vmuls(sin_quar_x2, 2.0)
    temp_res2 = dsl.vadds(temp_res2, -1.0)
    temp_res2 = dsl.vmul(temp_res2, sin_quar_x)
    temp_res2 = dsl.vmul(temp_res2, cos_quar_x)
    temp_res2 = dsl.vmuls(temp_res2, -4.0)
    sin_res = dsl.vmuls(dsl.vadd(temp_res1, temp_res2), coe)
    cos_res = dsl.vmuls(dsl.vsub(temp_res2, temp_res1), coe)
    return sin_res, cos_res
def cordic(angle):
    """Rotation-mode CORDIC approximation of (sin(angle), cos(angle)).

    Ten micro-rotations are applied (the tables hold 21 entries, but only
    the first 10 are consumed). `step` holds the 2**(-i) shift factors and
    `step_angle` the matching rotation angles arctan(2**(-i)) expressed in
    degrees; `gain` is the CORDIC gain-correction constant.
    """
    shape = angle.shape
    dtype = angle.dtype
    step = [1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32, 1 / 64, 1 / 128, 1 / 256,
            1 / 512, 1 / 1024, 1 / 2048, 1 / 4096, 1 / 8192, 1 / 16384, 1 / 32768,
            1 / 65536, 1 / 131072, 1 / 262144, 1 / 524288, 1 / 1048576]
    step_angle = [45, 26.565051177078, 14.0362434679265, 7.1250163489018,
                  3.57633437499735, 1.78991060824607, 0.8951737102111,
                  0.4476141708606, 0.2238105003685, 0.1119056770662,
                  0.0559528918938, 0.027976452617, 0.01398822714227,
                  0.006994113675353, 0.003497056950704, 0.001748528426980,
                  0.000874264213694, 0.000437132106872, 0.000218566053439,
                  0.000109283026720, 0.000054641513360]
    gain = 0.60725293500888
    cos_acc = dsl.broadcast(1.0, shape, dtype)
    sin_acc = dsl.broadcast(0.0, shape, dtype)
    residual = angle
    for idx in range(10):
        plus_one = dsl.broadcast(1.0, shape, dtype)
        minus_one = dsl.broadcast(-1.0, shape, dtype)
        # Rotate towards zero residual: direction is +1 where residual > 0.
        positive = dsl.vcmp(residual, dsl.broadcast(0.0, shape, dtype))
        direction = dsl.vsel(positive, plus_one, minus_one)
        prev_cos = cos_acc
        scaled_dir = dsl.vmuls(direction, step[idx])
        cos_acc = dsl.vsub(prev_cos, dsl.vmul(sin_acc, scaled_dir))
        sin_acc = dsl.vadd(sin_acc, dsl.vmul(prev_cos, scaled_dir))
        residual = dsl.vsub(residual, dsl.vmuls(direction, step_angle[idx]))
    return dsl.vmuls(sin_acc, gain), dsl.vmuls(cos_acc, gain)
# pylint: disable=locally-disabled,too-many-arguments,unused-argument,invalid-name,too-many-locals,
def prod(data, iter_arr):
    """Evaluate a polynomial in *data* by Horner's rule.

    `iter_arr` holds the coefficients from lowest to highest order; the
    result is iter_arr[0] + iter_arr[1]*data + ... computed element-wise.
    """
    shape = data.shape
    dtype = data.dtype
    acc = dsl.broadcast(tvm.const(iter_arr[-1], dtype), shape)
    for coeff in iter_arr[-2::-1]:
        acc = dsl.vadd(dsl.vmul(acc, data),
                       dsl.broadcast(tvm.const(coeff, dtype), shape))
    return acc
def bessel_j1(x):
    """Compute the Bessel function of the first kind of order 1, J1(x), element-wise.

    For |x| < 8 a rational approximation x * P(x*x) / Q(x*x) is used; for
    |x| >= 8 the asymptotic form sqrt(2/(pi*|x|)) * (cos(xx)*P1(y) - z*sin(xx)*P2(y))
    with z = 8/|x| and y = z*z, whose absolute value is then taken
    (the sign flip below mirrors the original kernel's behavior).

    Parameters
    ----------
    x : tvm tensor, float16 or float32

    Returns
    -------
    tvm tensor of the same shape; dtype is cast back to float16 when the
    input was float16

    Raises
    ------
    RuntimeError: if the dtype is unsupported on this platform.
    """
    shape_input = x.shape
    dtype_input = x.dtype
    # 1. promote to float32 when the platform supports it
    if (dtype_input in ('float16', 'float32')) and api_check_support("te.lang.cce.vadd", "float32"):
        x = dsl.cast_to(x, "float32")
    else:
        # Fixed: the message previously named BesselY0 (copy-paste error).
        raise RuntimeError("BesselJ1 kernel data type [%s] not support." % dtype_input)
    abs_data = dsl.vabs(x)
    # 2. compute bessel_j1 for data in (-8, 8)
    broad_const_limit = dsl.broadcast(tvm.const(CONST_LIMIT, x.dtype), shape_input)
    before_abs_data = dsl.vmin(abs_data, broad_const_limit)
    square_data = dsl.vmul(before_abs_data, before_abs_data)  # x * x
    before_res1 = prod(square_data, ITR_BEFORE1_J)
    before_res1 = dsl.vmul(before_res1, before_abs_data)
    before_res2 = prod(square_data, ITR_BEFORE2_J)
    before_final_res = dsl.vdiv(before_res1, before_res2)
    # 3. compute bessel_j1 for data in (-inf, -8) or (8, inf)
    div_data = dsl.vdiv(dsl.broadcast(tvm.const(8.0), shape_input), abs_data)
    square_div_data = dsl.vmul(div_data, div_data)
    # Removed the unused `minus_pi_data` computation: the 3*pi/4 phase
    # shift appears to be folded into angle_trans_cal, so the explicit
    # subtraction was dead code.
    after_res1 = prod(square_div_data, ITR_AFTER1_J)
    after_res2 = prod(square_div_data, ITR_AFTER2_J)
    # 3.1 sqrt(0.636619772/ax)
    tmp_res1 = dsl.vsqrt(dsl.vdiv(dsl.broadcast(tvm.const(TOW_OVER_PI), shape_input), abs_data),
                         impl_mode='high_precision')
    # 3.2 cos(xx)*ans1
    sinv, cosv = angle_trans_cal(abs_data)
    tmp_res2 = dsl.vmul(cosv, after_res1)
    # 3.3 z*math.sin(xx)*ans2
    tmp_res3 = dsl.vmul(dsl.vmul(div_data, sinv), after_res2)
    after_final_res = dsl.vmul(tmp_res1, dsl.vsub(tmp_res2, tmp_res3))
    # 4. the asymptotic branch yields the magnitude: negate negative values
    zero = dsl.broadcast(0.0, shape_input, 'float32')
    neg_cond = dsl.vcmp(after_final_res, zero, operation='lt', mode='bool')
    neg_after_res = dsl.vmuls(after_final_res, -1.0)
    after_final_res = dsl.vsel(neg_cond, neg_after_res, after_final_res)
    # 5. select res
    # 5.1 compare with limit
    select_condition = dsl.vcmp(abs_data, broad_const_limit, operation='lt', mode='bool')
    # 5.2 select
    res = dsl.vsel(select_condition, before_final_res, after_final_res)
    # 6. restore the caller's dtype
    if dtype_input == "float16":
        res = dsl.cast_to(res, "float16")
    return res
@register_op_compute("bessel_y1")
def bessel_y1_compute(x, y, kernel_name="bessel_y1"):
    """Compute the Bessel function of the second kind of order 1, Y1(x).

    For x in (0, 8) a rational approximation is combined with the series
    term 2/pi * (J1(x)*ln(x) - 1/x); for x >= 8 the asymptotic form
    sqrt(2/(pi*x)) * (sin(xx)*P1(y) + z*cos(xx)*P2(y)) is used with
    z = 8/x and y = z*z. Note the input is taken through vabs first.

    Parameters
    ----------
    x : tvm tensor, float16 or float32
    y : output descriptor (unused in the computation)
    kernel_name : cce kernel name, default "bessel_y1"

    Returns
    -------
    tvm tensor of the same shape; dtype cast back to float16 when the
    input was float16

    Raises
    ------
    RuntimeError: if the dtype is unsupported on this platform.
    """
    shape_input = x.shape
    dtype_input = x.dtype
    # promote to float32 when the platform supports it
    if (dtype_input in ('float16', 'float32')) and api_check_support("te.lang.cce.vadd", "float32"):
        x = dsl.cast_to(x, "float32")
    else:
        # Fixed: the message previously named BesselY0 (copy-paste error).
        raise RuntimeError("BesselY1 kernel data type [%s] not support." % dtype_input)
    x = dsl.vabs(x)
    # compute bessel_y1 for data in (0, 8)
    broad_const_limit = dsl.broadcast(tvm.const(CONST_LIMIT, x.dtype), shape_input)
    square_data = dsl.vmul(x, x)  # y = x * x
    before_res1 = prod(square_data, ITR_BEFORE1)
    before_res1 = dsl.vmul(before_res1, x)
    before_res2 = prod(square_data, ITR_BEFORE2)
    # 2/pi * (J1(x)*ln(x) - 1/x) + x*P(y)/Q(y)
    before_final_res = dsl.vmul(bessel_j1(x), dsl.vlog(x, impl_mode="high_precision"))
    before_final_res = dsl.vsub(before_final_res, dsl.vrec(x, impl_mode="high_precision"))
    before_final_res = dsl.vmuls(before_final_res, TOW_OVER_PI)
    before_final_res = dsl.vadd(before_final_res, dsl.vdiv(before_res1, before_res2))
    # compute bessel_y1 for data in (8, inf)
    div_data = dsl.vdiv(broad_const_limit, x)  # z = 8.0 / x
    square_div_data = dsl.vmul(div_data, div_data)  # y = z * z
    # Removed the unused `minus_pi_data` computation: the 3*pi/4 phase
    # shift appears to be folded into angle_trans_cal, so the explicit
    # subtraction was dead code.
    after_res1 = prod(square_div_data, ITR_AFTER1)
    after_res2 = prod(square_div_data, ITR_AFTER2)
    tmp_res1 = dsl.vsqrt(dsl.vdiv(dsl.broadcast(tvm.const(TOW_OVER_PI), shape_input), x), impl_mode="high_precision")
    sinv, cosv = angle_trans_cal(x)
    tmp_res2 = dsl.vmul(sinv, after_res1)
    tmp_res3 = dsl.vmul(dsl.vmul(div_data, cosv), after_res2)
    after_final_res = dsl.vmul(tmp_res1, dsl.vadd(tmp_res2, tmp_res3))
    # select branch by x < 8
    select_condition = dsl.vcmp(x, broad_const_limit, operation='lt', mode='bool')
    res = dsl.vsel(select_condition, before_final_res, after_final_res)
    # restore the caller's dtype
    if dtype_input == "float16":
        res = dsl.cast_to(res, "float16")
    return res
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME)
def bessel_y1(x, y, kernel_name="bessel_y1"):
    """Build and compile the BesselY1 TBE kernel.

    Parameters
    ----------
    x : dict describing the input tensor; float16 and float32 dtypes are
        supported
    y : dict describing the output tensor
    kernel_name : cce kernel name, default value is "bessel_y1"

    Returns
    -------
    None
    """
    in_shape = x.get("shape")
    in_dtype = x.get("dtype")
    para_check.check_shape(in_shape, param_name="x")
    # Collapse the shape to a single fused axis for an element-wise op.
    in_shape, _ = shape_util.refine_shape_axes(in_shape, [])
    para_check.check_dtype(in_dtype, ("float16", "float32"), param_name="x")
    placeholder = tvm.placeholder(in_shape, dtype=in_dtype.lower(), name="data_input")
    result = bessel_y1_compute(placeholder, y, kernel_name)
    with tvm.target.cce():
        schedule = dsl.auto_schedule(result)
    dsl.build(schedule, {"name": kernel_name,
                         "print_ir": False,
                         "tensor_list": (placeholder, result)})

View File

@@ -4767,7 +4767,7 @@ class BesselI0(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32 or float64.
     Supported Platforms:
-        ``Ascend`` ``CPU`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_i0 = ops.BesselI0()
@@ -4798,7 +4798,7 @@ class BesselI1(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32 or float64.
     Supported Platforms:
-        ``Ascend`` ``CPU`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_i1 = ops.BesselI1()
@@ -4837,7 +4837,7 @@ class BesselI0e(Primitive):
         TypeError: If dtype of `x` is not float16, float32 or float64.
     Supported Platforms:
-        ``GPU`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
     Examples:
         >>> bessel_i0e = ops.BesselI0e()
@@ -4877,7 +4877,7 @@ class BesselI1e(Primitive):
         TypeError: If dtype of `x` is not float16, float32 or float64.
     Supported Platforms:
-        ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
     Examples:
         >>> bessel_i1e = ops.BesselI1e()
@@ -4909,7 +4909,7 @@ class BesselK0(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32, float64.
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_k0 = ops.BesselK0()
@@ -4940,7 +4940,7 @@ class BesselK1(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32, float64.
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_k1 = ops.BesselK1()
@@ -4971,7 +4971,7 @@ class BesselK0e(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32, float64.
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_k0e = ops.BesselK0e()
@@ -5002,7 +5002,7 @@ class BesselK1e(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32, float64.
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_k1e = ops.BesselK1e()
@@ -5033,7 +5033,7 @@ class BesselJ0(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32 or float64.
     Supported Platforms:
-        ``CPU`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_j0 = ops.BesselJ0()
@@ -5065,7 +5065,7 @@ class BesselJ1(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32 or float64.
     Supported Platforms:
-        ``CPU`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_j1 = ops.BesselJ1()
@@ -5097,7 +5097,7 @@ class BesselY0(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32, float64.
     Supported Platforms:
-        ``Ascend`` ``CPU`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_y0 = ops.BesselY0()
@@ -5129,7 +5129,7 @@ class BesselY1(Primitive):
         TypeError: If `x` is not a Tensor of float16, float32, float64.
     Supported Platforms:
-        ``Ascend`` ``CPU`` ``GPU``
+        ``GPU`` ``CPU``
     Examples:
         >>> bessel_y1 = ops.BesselY1()