forked from mindspore-Ecosystem/mindspore
!41797 Remove custom_ops of Bessel functions on Ascend platform
Merge pull request !41797 from hedongdong/TBE_Bessel
This commit is contained in:
commit
adbaa9e86f
|
@ -19,16 +19,6 @@ from .batchnorm_fold2 import _batchnorm_fold2_tbe
|
||||||
from .batchnorm_fold2_grad import _batchnorm_fold2_grad_tbe
|
from .batchnorm_fold2_grad import _batchnorm_fold2_grad_tbe
|
||||||
from .batchnorm_fold2_grad_reduce import _batchnorm_fold2_grad_reduce_tbe
|
from .batchnorm_fold2_grad_reduce import _batchnorm_fold2_grad_reduce_tbe
|
||||||
from .batchnorm_fold_grad import _batchnorm_fold_grad_tbe
|
from .batchnorm_fold_grad import _batchnorm_fold_grad_tbe
|
||||||
from .bessel_i0 import _bessel_i0_tbe
|
|
||||||
from .bessel_i1 import _bessel_i1_tbe
|
|
||||||
from .bessel_j0 import _bessel_j0_tbe
|
|
||||||
from .bessel_j1 import _bessel_j1_tbe
|
|
||||||
from .bessel_k0 import _bessel_k0_tbe
|
|
||||||
from .bessel_k1 import _bessel_k1_tbe
|
|
||||||
from .bessel_k0e import _bessel_k0e_tbe
|
|
||||||
from .bessel_k1e import _bessel_k1e_tbe
|
|
||||||
from .bessel_y0 import _bessel_y0_tbe
|
|
||||||
from .bessel_y1 import _bessel_y1_tbe
|
|
||||||
from .correction_mul import _correction_mul_tbe
|
from .correction_mul import _correction_mul_tbe
|
||||||
from .correction_mul_grad import _correction_mul_grad_tbe
|
from .correction_mul_grad import _correction_mul_grad_tbe
|
||||||
from .fake_learned_scale_quant_perlayer import _fake_learned_scale_quant_perlayer_tbe
|
from .fake_learned_scale_quant_perlayer import _fake_learned_scale_quant_perlayer_tbe
|
||||||
|
|
|
@ -1,117 +0,0 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesselI0 op"""
|
|
||||||
|
|
||||||
from tbe import dsl
|
|
||||||
from te import tvm
|
|
||||||
from te.platform.fusion_manager import fusion_manager
|
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
|
|
||||||
bessel_i0_op_info = TBERegOp("BesselI0") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_i0.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_i0") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_i0_op_info)
|
|
||||||
def _bessel_i0_tbe():
|
|
||||||
"""BesselI0 TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
A = [-1.30002500998624804212E-8, 6.04699502254191894932E-8,
|
|
||||||
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
|
|
||||||
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
|
|
||||||
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
|
|
||||||
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
|
|
||||||
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
|
|
||||||
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
|
|
||||||
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
|
|
||||||
-3.04682672343198398683E-1, 6.76795274409476084995E-1]
|
|
||||||
|
|
||||||
B = [3.39623202570838634515E-9, 2.26666899049817806459E-8,
|
|
||||||
2.04891858946906374183E-7, 2.89137052083475648297E-6,
|
|
||||||
6.88975834691682398426E-5, 3.36911647825569408990E-3,
|
|
||||||
8.04490411014108831608E-1]
|
|
||||||
|
|
||||||
|
|
||||||
def chebevl(x, num, coef, shape, dtype):
|
|
||||||
"""chebevl"""
|
|
||||||
broad_coef = dsl.broadcast(coef[0], shape, dtype)
|
|
||||||
broad_zero = dsl.broadcast(0, shape, dtype)
|
|
||||||
none_signal = None
|
|
||||||
for i in range(1, num):
|
|
||||||
none_signal = broad_zero
|
|
||||||
broad_zero = broad_coef
|
|
||||||
coef_i = dsl.broadcast(coef[i], shape, dtype)
|
|
||||||
broad_coef = dsl.vsub(dsl.vadd(dsl.vmul(x, broad_zero), coef_i), none_signal)
|
|
||||||
return dsl.vmuls(dsl.vsub(broad_coef, none_signal), 0.5)
|
|
||||||
|
|
||||||
|
|
||||||
@fusion_manager.register("bessel_i0")
|
|
||||||
def bessel_i0_compute(input_x, output, kernel_name="bessel_i0"):
|
|
||||||
"""bessel_i0_compute"""
|
|
||||||
dtype = input_x.dtype
|
|
||||||
shape = input_x.shape
|
|
||||||
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
y = dsl.vabs(input_x)
|
|
||||||
|
|
||||||
y_le_eight_in = dsl.vmuls(y, 0.5)
|
|
||||||
y_le_eight_in = dsl.vadds(y_le_eight_in, -2.0)
|
|
||||||
y_le_eight = chebevl(y_le_eight_in, 18, A, shape, dtype)
|
|
||||||
|
|
||||||
y_gt_eight_in = dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0)
|
|
||||||
y_gt_eight = chebevl(y_gt_eight_in, 7, B, shape, dtype)
|
|
||||||
y_gt_eight = dsl.vmul(y_gt_eight, dsl.vrsqrt(y))
|
|
||||||
|
|
||||||
res = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
|
|
||||||
res = dsl.vmul(res, dsl.vexp(y))
|
|
||||||
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_i0(x, output, kernel_name="bessel_i0"):
|
|
||||||
"""bessel_i0"""
|
|
||||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
|
||||||
|
|
||||||
res = bessel_i0_compute(data_x, output, kernel_name)
|
|
||||||
|
|
||||||
# auto schedule
|
|
||||||
with tvm.target.cce():
|
|
||||||
schedule = dsl.auto_schedule(res)
|
|
||||||
|
|
||||||
# operator build
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"tensor_list": [data_x, res]}
|
|
||||||
dsl.build(schedule, config)
|
|
|
@ -1,128 +0,0 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesselI1 op"""
|
|
||||||
|
|
||||||
from tbe import dsl
|
|
||||||
from te import tvm
|
|
||||||
from te.platform.fusion_manager import fusion_manager
|
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
|
|
||||||
bessel_i1_op_info = TBERegOp("BesselI1") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_i1.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_i1") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_i1_op_info)
|
|
||||||
def _bessel_i1_tbe():
|
|
||||||
"""BesselI1 TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
A = [2.77791411276104639959E-18, -2.11142121435816608115E-17,
|
|
||||||
1.55363195773620046921E-16, -1.10559694773538630805E-15,
|
|
||||||
7.60068429473540693410E-15, -5.04218550472791168711E-14,
|
|
||||||
3.22379336594557470981E-13, -1.98397439776494371520E-12,
|
|
||||||
1.17361862988909016308E-11, -6.66348972350202774223E-11,
|
|
||||||
3.62559028155211703701E-10, -1.88724975172282928790E-9,
|
|
||||||
9.38153738649577178388E-9, -4.44505912879632808065E-8,
|
|
||||||
2.00329475355213526229E-7, -8.56872026469545474066E-7,
|
|
||||||
3.47025130813767847674E-6, -1.32731636560394358279E-5,
|
|
||||||
4.78156510755005422638E-5, -1.61760815825896745588E-4,
|
|
||||||
5.12285956168575772895E-4, -1.51357245063125314899E-3,
|
|
||||||
4.15642294431288815669E-3, -1.05640848946261981558E-2,
|
|
||||||
2.47264490306265168283E-2, -5.29459812080949914269E-2,
|
|
||||||
1.02643658689847095384E-1, -1.76416518357834055153E-1,
|
|
||||||
2.52587186443633654823E-1]
|
|
||||||
B = [
|
|
||||||
7.51729631084210481353E-18, 4.41434832307170791151E-18,
|
|
||||||
-4.65030536848935832153E-17, -3.20952592199342395980E-17,
|
|
||||||
2.96262899764595013876E-16, 3.30820231092092828324E-16,
|
|
||||||
-1.88035477551078244854E-15, -3.81440307243700780478E-15,
|
|
||||||
1.04202769841288027642E-14, 4.27244001671195135429E-14,
|
|
||||||
-2.10154184277266431302E-14, -4.08355111109219731823E-13,
|
|
||||||
-7.19855177624590851209E-13, 2.03562854414708950722E-12,
|
|
||||||
1.41258074366137813316E-11, 3.25260358301548823856E-11,
|
|
||||||
-1.89749581235054123450E-11, -5.58974346219658380687E-10,
|
|
||||||
-3.83538038596423702205E-9, -2.63146884688951950684E-8,
|
|
||||||
-2.51223623787020892529E-7, -3.88256480887769039346E-6,
|
|
||||||
-1.10588938762623716291E-4, -9.76109749136146840777E-3,
|
|
||||||
7.78576235018280120474E-1]
|
|
||||||
|
|
||||||
|
|
||||||
def chebevl(x, num, coef, shape, dtype):
|
|
||||||
"""chebevl"""
|
|
||||||
broad_coef = dsl.broadcast(coef[0], shape, dtype)
|
|
||||||
broad_zero = dsl.broadcast(0, shape, dtype)
|
|
||||||
none_signal = None
|
|
||||||
for i in range(1, num):
|
|
||||||
none_signal = broad_zero
|
|
||||||
broad_zero = broad_coef
|
|
||||||
coef_i = dsl.broadcast(coef[i], shape, dtype)
|
|
||||||
broad_coef = dsl.vsub(dsl.vadd(dsl.vmul(x, broad_zero), coef_i), none_signal)
|
|
||||||
return dsl.vmuls(dsl.vsub(broad_coef, none_signal), 0.5)
|
|
||||||
|
|
||||||
|
|
||||||
@fusion_manager.register("bessel_i1")
|
|
||||||
def bessel_i1_compute(input_x, output_y, kernel_name="bessel_i1"):
|
|
||||||
"""bessel_i1_compute"""
|
|
||||||
dtype = input_x.dtype
|
|
||||||
shape = input_x.shape
|
|
||||||
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
y = dsl.vabs(input_x)
|
|
||||||
|
|
||||||
y_le_eight = dsl.vmul(y, chebevl(dsl.vadds(dsl.vmuls(y, 0.5), -2), 29, A, shape, dtype))
|
|
||||||
y_gt_eight = chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0), 25, B, shape, dtype)
|
|
||||||
|
|
||||||
y = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
|
|
||||||
res = dsl.vcmpsel(input_x, 0, 'lt', dsl.vmuls(y, -1.0), y)
|
|
||||||
res = dsl.vmul(res, dsl.vexp(dsl.vabs(input_x)))
|
|
||||||
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_i1(x, y, kernel_name="bessel_i1"):
|
|
||||||
"""bessel_i1"""
|
|
||||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
|
||||||
|
|
||||||
res = bessel_i1_compute(data_x, y, kernel_name)
|
|
||||||
|
|
||||||
# auto schedule
|
|
||||||
with tvm.target.cce():
|
|
||||||
schedule = dsl.auto_schedule(res)
|
|
||||||
|
|
||||||
# operator build
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"tensor_list": [data_x, res]}
|
|
||||||
dsl.build(schedule, config)
|
|
|
@ -1,226 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesselJ0 op"""
|
|
||||||
|
|
||||||
import te.lang.cce as tbe
|
|
||||||
import te.platform as tbe_platform
|
|
||||||
from te import tvm
|
|
||||||
from te.platform.fusion_manager import fusion_manager
|
|
||||||
from tbe import dsl
|
|
||||||
from tbe.common.utils import shape_util
|
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
|
|
||||||
bessel_j0_op_info = TBERegOp("BesselJ0") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_j0.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_j0") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_j0_op_info)
|
|
||||||
def _bessel_j0_tbe():
|
|
||||||
"""BesselJ0 TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
FLOAT_16 = "float16"
|
|
||||||
FLOAT_32 = "float32"
|
|
||||||
|
|
||||||
PP = [7.96936729297347051624E-4, 8.28352392107440799803E-2,
|
|
||||||
1.23953371646414299388E0, 5.44725003058768775090E0,
|
|
||||||
8.74716500199817011941E0, 5.30324038235394892183E0,
|
|
||||||
9.99999999999999997821E-1]
|
|
||||||
PQ = [9.24408810558863637013E-4, 8.56288474354474431428E-2,
|
|
||||||
1.25352743901058953537E0, 5.47097740330417105182E0,
|
|
||||||
8.76190883237069594232E0, 5.30605288235394617618E0,
|
|
||||||
1.00000000000000000218E0]
|
|
||||||
QP = [-1.13663838898469149931E-2, -1.28252718670509318512E0,
|
|
||||||
-1.95539544257735972385E1, -9.32060152123768231369E1,
|
|
||||||
-1.77681167980488050595E2, -1.47077505154951170175E2,
|
|
||||||
-5.14105326766599330220E1, -6.05014350600728481186E0]
|
|
||||||
QQ = [1.00000000000000000000E0, 6.43178256118178023184E1,
|
|
||||||
8.56430025976980587198E2, 3.88240183605401609683E3,
|
|
||||||
7.24046774195652478189E3, 5.93072701187316984827E3,
|
|
||||||
2.06209331660327847417E3, 2.42005740240291393179E2]
|
|
||||||
RP = [-4.79443220978201773821E9, 1.95617491946556577543E12,
|
|
||||||
-2.49248344360967716204E14, 9.70862251047306323952E15]
|
|
||||||
RQ = [1.00000000000000000000E0, 4.99563147152651017219E2,
|
|
||||||
1.73785401676374683123E5, 4.84409658339962045305E7,
|
|
||||||
1.11855537045356834862E10, 2.11277520115489217587E12,
|
|
||||||
3.10518229857422583814E14, 3.18121955943204943306E16,
|
|
||||||
1.71086294081043136091E18]
|
|
||||||
|
|
||||||
DR1 = 5.78318596294678452118E0
|
|
||||||
DR2 = 3.04712623436620863991E1
|
|
||||||
SQ2OPI = 7.9788456080286535587989E-1
|
|
||||||
NEG_PIO4 = -0.7853981633974483096
|
|
||||||
|
|
||||||
PI = 3.14159265358979
|
|
||||||
FIRST_ORDER = 5
|
|
||||||
LAST_ORDER = 13
|
|
||||||
FIRST_FACTOR = -1.0 / 6.0
|
|
||||||
|
|
||||||
|
|
||||||
def besselj0_cos(x):
|
|
||||||
"""cos"""
|
|
||||||
dtype = x.dtype
|
|
||||||
shape = shape_util.shape_to_list(x.shape)
|
|
||||||
# cast to type float32 when type is float16
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
|
|
||||||
x = dsl.cast_to(x, FLOAT_32)
|
|
||||||
dtype = FLOAT_32
|
|
||||||
has_improve_precision = True
|
|
||||||
# round the input
|
|
||||||
round_fp16 = dsl.round(dsl.vmuls(x, 1.0 / (2 * PI)))
|
|
||||||
round_fp32 = dsl.cast_to(round_fp16, dtype)
|
|
||||||
input_x_round = dsl.vsub(x, dsl.vmuls(round_fp32, 2 * PI))
|
|
||||||
# the initial value one
|
|
||||||
const_res = tvm.const(1.0, dtype=dtype)
|
|
||||||
res = dsl.broadcast(const_res, shape)
|
|
||||||
# compute the rank 2
|
|
||||||
input_x_power = dsl.vmul(input_x_round, input_x_round)
|
|
||||||
iter_value = dsl.vmuls(input_x_power, -1.0/2.0)
|
|
||||||
res = dsl.vadd(res, iter_value)
|
|
||||||
# compute the rank 4~14
|
|
||||||
iter_list = (4, 6, 8, 10, 12, 14)
|
|
||||||
for i in iter_list:
|
|
||||||
iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value), -1.0/(i*(i-1)))
|
|
||||||
res = dsl.vadd(res, iter_value)
|
|
||||||
# cast the dtype to float16
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def _besselj0_sin(x):
|
|
||||||
"""_besselj0_sin"""
|
|
||||||
input_x_power = dsl.vmul(x, x)
|
|
||||||
iter_value = dsl.vmul(dsl.vmuls(input_x_power, FIRST_FACTOR), x)
|
|
||||||
res = dsl.vadd(x, iter_value)
|
|
||||||
signal = FIRST_ORDER
|
|
||||||
while signal < LAST_ORDER:
|
|
||||||
iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value),
|
|
||||||
-1.0 / (signal*(signal - 1)))
|
|
||||||
res = dsl.vadd(res, iter_value)
|
|
||||||
signal = signal + 2
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def besselj0_sin(x):
|
|
||||||
"""sin"""
|
|
||||||
dtype = x.dtype
|
|
||||||
shape = shape_util.shape_to_list(x.shape)
|
|
||||||
has_improve_precision = False
|
|
||||||
cast_dtype = FLOAT_16
|
|
||||||
if tbe_platform.api_check_support("te.lang.cce.vmul", "float32"):
|
|
||||||
has_improve_precision = True
|
|
||||||
cast_dtype = FLOAT_32
|
|
||||||
|
|
||||||
# cast to type float32 when type is float16
|
|
||||||
if dtype == FLOAT_16 and has_improve_precision:
|
|
||||||
x = tbe.cast_to(x, FLOAT_32)
|
|
||||||
|
|
||||||
pai_multiple = tbe.vmuls(x, 1 / PI)
|
|
||||||
round_float = tbe.cast_to(tbe.round(pai_multiple), cast_dtype)
|
|
||||||
# to adjust x to [-pai/2,pai/2]
|
|
||||||
x = tbe.vsub(x, tbe.vmuls(round_float, PI))
|
|
||||||
res = _besselj0_sin(x)
|
|
||||||
# if round is odd, the final result need to multiply -1.Need to multiply 1/2 to get the ceil value
|
|
||||||
ceil_value = tbe.ceil(tbe.vmuls(round_float, 1 / 2))
|
|
||||||
# if odd, ceil*2-round is 1,if even, the value is 0
|
|
||||||
sub_value = tbe.vsub(tbe.vmuls(ceil_value, tvm.const(2, dtype)), round_float)
|
|
||||||
tensor_one = tbe.broadcast(tvm.const(1, cast_dtype), shape)
|
|
||||||
odd_tensor = tbe.vsub(tensor_one, sub_value)
|
|
||||||
even_tensor = tbe.vsub(odd_tensor, tensor_one)
|
|
||||||
odd_even_tensor = tbe.vadd(odd_tensor, even_tensor)
|
|
||||||
res = tbe.vmul(res, odd_even_tensor)
|
|
||||||
|
|
||||||
# cast the dtype to float16
|
|
||||||
if dtype == FLOAT_16 and has_improve_precision:
|
|
||||||
res = tbe.cast_to(res, FLOAT_16)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def besselj0_polevl(x, n, coef, shape):
|
|
||||||
"""polevl"""
|
|
||||||
dtype = 'float32'
|
|
||||||
x = dsl.cast_to(x, dtype)
|
|
||||||
|
|
||||||
if n == 0:
|
|
||||||
coef_0 = dsl.broadcast(coef[0], shape, output_dtype=dtype)
|
|
||||||
return dsl.cast_to(coef_0, dtype)
|
|
||||||
|
|
||||||
coef_n = dsl.broadcast(coef[n], shape, output_dtype=dtype)
|
|
||||||
res = dsl.vadd(dsl.vmul(besselj0_polevl(x, n-1, coef, shape), x), coef_n)
|
|
||||||
return dsl.cast_to(res, 'float32')
|
|
||||||
|
|
||||||
|
|
||||||
@fusion_manager.register("bessel_j0")
|
|
||||||
def bessel_j0_compute(x, kernel_name="bessel_j0"):
|
|
||||||
"""bessel_j0_compute"""
|
|
||||||
dtype = x.dtype
|
|
||||||
shape = shape_util.shape_to_list(x.shape)
|
|
||||||
|
|
||||||
# cast to type float32 when type is float16
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
|
|
||||||
x = dsl.cast_to(x, FLOAT_32)
|
|
||||||
dtype = FLOAT_32
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
y = dsl.vabs(x)
|
|
||||||
z = dsl.vmul(y, y)
|
|
||||||
y_le_five = dsl.vcmpsel(y, 1.0e-5, 'lt', dsl.vadds(dsl.vmuls(z, -0.25), 1),
|
|
||||||
dsl.vmul(dsl.vmul(dsl.vadds(z, -DR1), dsl.vadds(z, -DR2)),
|
|
||||||
dsl.vdiv(besselj0_polevl(z, 3, RP, shape), besselj0_polevl(z, 8, RQ, shape))))
|
|
||||||
s = dsl.vmuls(dsl.vrec(z), 25)
|
|
||||||
p = dsl.vdiv(besselj0_polevl(s, 6, PP, shape), besselj0_polevl(s, 6, PQ, shape))
|
|
||||||
q = dsl.vdiv(besselj0_polevl(s, 7, QP, shape), besselj0_polevl(s, 6, PQ, shape))
|
|
||||||
yn = dsl.vadds(y, NEG_PIO4)
|
|
||||||
w = dsl.vmuls(dsl.vrec(y), -5.0)
|
|
||||||
p = dsl.vadd(dsl.vmul(p, besselj0_cos(yn)), dsl.vmul(w, dsl.vmul(q, besselj0_sin(yn))))
|
|
||||||
y_gt_five = dsl.vmul(dsl.vmuls(p, SQ2OPI), dsl.vrsqrt(y))
|
|
||||||
res = dsl.vcmpsel(y, 5.0, 'le', y_le_five, y_gt_five)
|
|
||||||
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_j0(x, y, kernel_name="bessel_j0"):
|
|
||||||
"""bessel_j0"""
|
|
||||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
|
||||||
res = bessel_j0_compute(data_x, kernel_name)
|
|
||||||
|
|
||||||
# auto schedule
|
|
||||||
with tvm.target.cce():
|
|
||||||
schedule = dsl.auto_schedule(res)
|
|
||||||
|
|
||||||
# operator build
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"tensor_list": [data_x, res]}
|
|
||||||
dsl.build(schedule, config)
|
|
|
@ -1,246 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesselJ1 op"""
|
|
||||||
|
|
||||||
|
|
||||||
import te.lang.cce as tbe
|
|
||||||
import te.platform as tbe_platform
|
|
||||||
from te import tvm
|
|
||||||
from te.platform.fusion_manager import fusion_manager
|
|
||||||
from tbe import dsl
|
|
||||||
from tbe.common.utils import shape_util
|
|
||||||
|
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
|
|
||||||
bessel_j1_op_info = TBERegOp("BesselJ1") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_j1.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_j1") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_j1_op_info)
|
|
||||||
def _bessel_j1_tbe():
|
|
||||||
"""BesselJ1 TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
FLOAT_16 = "float16"
|
|
||||||
FLOAT_32 = "float32"
|
|
||||||
|
|
||||||
PI = 3.14159265358979
|
|
||||||
FIRST_ORDER = 5
|
|
||||||
LAST_ORDER = 13
|
|
||||||
FIRST_FACTOR = -1.0 / 6.0
|
|
||||||
|
|
||||||
|
|
||||||
def besselj1_cos(x):
|
|
||||||
"""cos"""
|
|
||||||
dtype = x.dtype
|
|
||||||
shape = shape_util.shape_to_list(x.shape)
|
|
||||||
|
|
||||||
# cast to type float32 when type is float16
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
|
|
||||||
x = dsl.cast_to(x, FLOAT_32)
|
|
||||||
dtype = FLOAT_32
|
|
||||||
has_improve_precision = True
|
|
||||||
# round the input
|
|
||||||
round_fp16 = dsl.round(dsl.vmuls(x, 1.0 / (2 * PI)))
|
|
||||||
round_fp32 = dsl.cast_to(round_fp16, dtype)
|
|
||||||
besselj1_input_x_round = dsl.vsub(x, dsl.vmuls(round_fp32, 2 * PI))
|
|
||||||
# the initial value one
|
|
||||||
const_res = tvm.const(1.0, dtype=dtype)
|
|
||||||
res = dsl.broadcast(const_res, shape)
|
|
||||||
# compute the rank 2
|
|
||||||
input_x_power = dsl.vmul(besselj1_input_x_round, besselj1_input_x_round)
|
|
||||||
iter_value = dsl.vmuls(input_x_power, -1.0/2.0)
|
|
||||||
res = dsl.vadd(res, iter_value)
|
|
||||||
# compute the rank 4~14
|
|
||||||
iter_list = (4, 6, 8, 10, 12, 14)
|
|
||||||
for i in iter_list:
|
|
||||||
iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value), -1.0/(i*(i-1)))
|
|
||||||
res = dsl.vadd(res, iter_value)
|
|
||||||
# cast the dtype to float16
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def _besselj1_sin(x):
|
|
||||||
"""_sin"""
|
|
||||||
input_x_power = dsl.vmul(x, x)
|
|
||||||
iter_value = dsl.vmul(dsl.vmuls(input_x_power, FIRST_FACTOR), x)
|
|
||||||
besselj1_res = dsl.vadd(x, iter_value)
|
|
||||||
|
|
||||||
signal = FIRST_ORDER
|
|
||||||
while signal < LAST_ORDER:
|
|
||||||
iter_value = dsl.vmuls(dsl.vmul(input_x_power, iter_value),
|
|
||||||
-1.0 / (signal*(signal - 1)))
|
|
||||||
besselj1_res = dsl.vadd(besselj1_res, iter_value)
|
|
||||||
signal = signal + 2
|
|
||||||
|
|
||||||
return besselj1_res
|
|
||||||
|
|
||||||
|
|
||||||
def besselj1_sin(besselj1_x):
|
|
||||||
"""sin"""
|
|
||||||
dtype = besselj1_x.dtype
|
|
||||||
shape = shape_util.shape_to_list(besselj1_x.shape)
|
|
||||||
|
|
||||||
has_improve_precision = False
|
|
||||||
cast_dtype = FLOAT_16
|
|
||||||
if tbe_platform.api_check_support("te.lang.cce.vmul", "float32"):
|
|
||||||
has_improve_precision = True
|
|
||||||
cast_dtype = FLOAT_32
|
|
||||||
|
|
||||||
# cast to type float32 when type is float16
|
|
||||||
if dtype == FLOAT_16 and has_improve_precision:
|
|
||||||
besselj1_x = tbe.cast_to(besselj1_x, FLOAT_32)
|
|
||||||
|
|
||||||
pai_multiple = tbe.vmuls(besselj1_x, 1 / PI)
|
|
||||||
round_float = tbe.cast_to(tbe.round(pai_multiple), cast_dtype)
|
|
||||||
# to adjust x to [-pai/2,pai/2]
|
|
||||||
besselj1_x = tbe.vsub(besselj1_x, tbe.vmuls(round_float, PI))
|
|
||||||
|
|
||||||
besselj1_res = _besselj1_sin(besselj1_x)
|
|
||||||
|
|
||||||
# if round is odd, the final result need to multiply -1.Need to multiply 1/2 to get the ceil value
|
|
||||||
ceil_value = tbe.ceil(tbe.vmuls(round_float, 1 / 2))
|
|
||||||
# if odd, ceil*2-round is 1,if even, the value is 0
|
|
||||||
sub_value = tbe.vsub(tbe.vmuls(ceil_value, tvm.const(2, dtype)), round_float)
|
|
||||||
tensor_one = tbe.broadcast(tvm.const(1, cast_dtype), shape)
|
|
||||||
odd_tensor = tbe.vsub(tensor_one, sub_value)
|
|
||||||
even_tensor = tbe.vsub(odd_tensor, tensor_one)
|
|
||||||
odd_even_tensor = tbe.vadd(odd_tensor, even_tensor)
|
|
||||||
besselj1_res = tbe.vmul(besselj1_res, odd_even_tensor)
|
|
||||||
|
|
||||||
# cast the dtype to float16
|
|
||||||
if dtype == FLOAT_16 and has_improve_precision:
|
|
||||||
besselj1_res = tbe.cast_to(besselj1_res, FLOAT_16)
|
|
||||||
|
|
||||||
return besselj1_res
|
|
||||||
|
|
||||||
|
|
||||||
PP = [7.62125616208173112003E-4, 7.31397056940917570436E-2,
|
|
||||||
1.12719608129684925192E0, 5.11207951146807644818E0,
|
|
||||||
8.42404590141772420927E0, 5.21451598682361504063E0,
|
|
||||||
1.00000000000000000254E0]
|
|
||||||
|
|
||||||
PQ = [5.71323128072548699714E-4, 6.88455908754495404082E-2,
|
|
||||||
1.10514232634061696926E0, 5.07386386128601488557E0,
|
|
||||||
8.39985554327604159757E0, 5.20982848682361821619E0,
|
|
||||||
9.99999999999999997461E-1]
|
|
||||||
|
|
||||||
QP = [5.10862594750176621635E-2, 4.98213872951233449420E0,
|
|
||||||
7.58238284132545283818E1, 3.66779609360150777800E2,
|
|
||||||
7.10856304998926107277E2, 5.97489612400613639965E2,
|
|
||||||
2.11688757100572135698E2, 2.52070205858023719784E1]
|
|
||||||
|
|
||||||
QQ = [1.00000000000000000000E0, 6.43178256118178023184E1,
|
|
||||||
8.56430025976980587198E2, 3.88240183605401609683E3,
|
|
||||||
7.24046774195652478189E3, 5.93072701187316984827E3,
|
|
||||||
2.06209331660327847417E3, 2.42005740240291393179E2]
|
|
||||||
|
|
||||||
RP = [-8.99971225705559398224E8, 4.52228297998194034323E11,
|
|
||||||
-7.27494245221818276015E13, 3.68295732863852883286E15]
|
|
||||||
|
|
||||||
RQ = [1.00000000000000000000E0, 6.20836478118054335476E2,
|
|
||||||
2.56987256757748830383E5, 8.35146791431949253037E7,
|
|
||||||
2.21511595479792499675E10, 4.74914122079991414898E12,
|
|
||||||
7.84369607876235854894E14, 8.95222336184627338078E16,
|
|
||||||
5.32278620332680085395E18]
|
|
||||||
|
|
||||||
Z1 = 1.46819706421238932572E1
|
|
||||||
Z2 = 4.92184563216946036703E1
|
|
||||||
NEG_THPIO4 = -2.35619449019234492885
|
|
||||||
SQ2OPI = 7.9788456080286535587989E-1
|
|
||||||
|
|
||||||
|
|
||||||
def polevl(x, n, coef, shape):
|
|
||||||
"""polevl"""
|
|
||||||
dtype = 'float32'
|
|
||||||
x = dsl.cast_to(x, dtype)
|
|
||||||
|
|
||||||
if n == 0:
|
|
||||||
coef_0 = dsl.broadcast(coef[0], shape, output_dtype=dtype)
|
|
||||||
return dsl.cast_to(coef_0, dtype)
|
|
||||||
|
|
||||||
coef_n = dsl.broadcast(coef[n], shape, output_dtype=dtype)
|
|
||||||
res = dsl.vadd(dsl.vmul(polevl(x, n-1, coef, shape), x), coef_n)
|
|
||||||
return dsl.cast_to(res, 'float32')
|
|
||||||
|
|
||||||
|
|
||||||
@fusion_manager.register("bessel_j1")
|
|
||||||
def bessel_j1_compute(x, kernel_name="bessel_j1"):
|
|
||||||
"""bessel_j1_compute"""
|
|
||||||
dtype = x.dtype
|
|
||||||
shape = shape_util.shape_to_list(x.shape)
|
|
||||||
|
|
||||||
# cast to type float32 when type is float16
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype.lower() == FLOAT_16 and tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul", "float32"):
|
|
||||||
x = dsl.cast_to(x, FLOAT_32)
|
|
||||||
dtype = FLOAT_32
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
y = dsl.vabs(x)
|
|
||||||
z = dsl.vmul(y, y)
|
|
||||||
y_le_five = dsl.vdiv(polevl(z, 3, RP, shape), polevl(z, 8, RQ, shape))
|
|
||||||
y_le_five = dsl.vmul(y_le_five, dsl.vmul(x, dsl.vmul(dsl.vadds(z, -Z1), dsl.vadds(z, -Z2))))
|
|
||||||
s = dsl.vmuls(dsl.vrec(z), 25)
|
|
||||||
p = dsl.vdiv(polevl(s, 6, PP, shape), polevl(s, 6, PQ, shape))
|
|
||||||
q = dsl.vdiv(polevl(s, 7, QP, shape), polevl(s, 7, QQ, shape))
|
|
||||||
yn = dsl.vadds(y, NEG_THPIO4)
|
|
||||||
w = dsl.vmuls(dsl.vrec(y), -5.0)
|
|
||||||
p = dsl.vadd(dsl.vmul(p, besselj1_cos(yn)), dsl.vmul(w, dsl.vmul(q, besselj1_sin(yn))))
|
|
||||||
|
|
||||||
y_gt_five = dsl.vmul(dsl.vmuls(p, SQ2OPI), dsl.vrsqrt(y))
|
|
||||||
y_gt_five = dsl.vcmpsel(x, 0.0, 'lt', dsl.vmuls(y_gt_five, -1.0), y_gt_five)
|
|
||||||
|
|
||||||
res = dsl.vcmpsel(y, 5.0, 'le', y_le_five, y_gt_five)
|
|
||||||
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_j1(x, y, kernel_name="bessel_j1"):
|
|
||||||
"""
|
|
||||||
To do: Implement the operator by referring to the
|
|
||||||
TBE Operator Development Guide.
|
|
||||||
"""
|
|
||||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
|
||||||
|
|
||||||
res = bessel_j1_compute(data_x, kernel_name)
|
|
||||||
|
|
||||||
# auto schedule
|
|
||||||
with tvm.target.cce():
|
|
||||||
schedule = dsl.auto_schedule(res)
|
|
||||||
|
|
||||||
# operator build
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"tensor_list": [data_x, res]}
|
|
||||||
dsl.build(schedule, config)
|
|
|
@ -1,186 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesseK0 op"""
|
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
from tbe import dsl
|
|
||||||
from te import tvm
|
|
||||||
from te.platform.fusion_manager import fusion_manager
|
|
||||||
|
|
||||||
bessel_k0_op_info = TBERegOp("BesselK0") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_k0.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_k0") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_k0_op_info)
|
|
||||||
def _bessel_k0_tbe():
|
|
||||||
"""BesselK0 TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
A = [1.37446543561352307156E-16,
|
|
||||||
4.25981614279661018399E-14,
|
|
||||||
1.03496952576338420167E-11,
|
|
||||||
1.90451637722020886025E-9,
|
|
||||||
2.53479107902614945675E-7,
|
|
||||||
2.28621210311945178607E-5,
|
|
||||||
1.26461541144692592338E-3,
|
|
||||||
3.59799365153615016266E-2,
|
|
||||||
3.44289899924628486886E-1,
|
|
||||||
-5.35327393233902768720E-1]
|
|
||||||
B = [
|
|
||||||
5.30043377268626276149E-18, -1.64758043015242134646E-17,
|
|
||||||
5.21039150503902756861E-17, -1.67823109680541210385E-16,
|
|
||||||
5.51205597852431940784E-16, -1.84859337734377901440E-15,
|
|
||||||
6.34007647740507060557E-15, -2.22751332699166985548E-14,
|
|
||||||
8.03289077536357521100E-14, -2.98009692317273043925E-13,
|
|
||||||
1.14034058820847496303E-12, -4.51459788337394416547E-12,
|
|
||||||
1.85594911495471785253E-11, -7.95748924447710747776E-11,
|
|
||||||
3.57739728140030116597E-10, -1.69753450938905987466E-9,
|
|
||||||
8.57403401741422608519E-9, -4.66048989768794782956E-8,
|
|
||||||
2.76681363944501510342E-7, -1.83175552271911948767E-6,
|
|
||||||
1.39498137188764993662E-5, -1.28495495816278026384E-4,
|
|
||||||
1.56988388573005337491E-3, -3.14481013119645005427E-2,
|
|
||||||
2.44030308206595545468E0]
|
|
||||||
|
|
||||||
AA = [-4.41534164647933937950E-18, 3.33079451882223809783E-17,
|
|
||||||
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
|
|
||||||
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
|
|
||||||
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
|
|
||||||
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
|
|
||||||
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
|
|
||||||
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
|
|
||||||
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
|
|
||||||
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
|
|
||||||
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
|
|
||||||
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
|
|
||||||
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
|
|
||||||
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
|
|
||||||
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
|
|
||||||
-3.04682672343198398683E-1, 6.76795274409476084995E-1
|
|
||||||
]
|
|
||||||
BB = [
|
|
||||||
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
|
|
||||||
4.46562142029675999901E-17, 3.46122286769746109310E-17,
|
|
||||||
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
|
|
||||||
1.77256013305652638360E-15, 3.81168066935262242075E-15,
|
|
||||||
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
|
|
||||||
1.54008621752140982691E-14, 3.85277838274214270114E-13,
|
|
||||||
7.18012445138366623367E-13, -1.79417853150680611778E-12,
|
|
||||||
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
|
|
||||||
1.18891471078464383424E-11, 4.94060238822496958910E-10,
|
|
||||||
3.39623202570838634515E-9, 2.26666899049817806459E-8,
|
|
||||||
2.04891858946906374183E-7, 2.89137052083475648297E-6,
|
|
||||||
6.88975834691682398426E-5, 3.36911647825569408990E-3,
|
|
||||||
8.04490411014108831608E-1]
|
|
||||||
|
|
||||||
MAXNUM = 4294967295.0
|
|
||||||
TWO = 2.0
|
|
||||||
|
|
||||||
|
|
||||||
def chebevl(x, n, coef, shape, dtype):
|
|
||||||
"""chebevl"""
|
|
||||||
k0_broad_coef = dsl.broadcast(coef[0], shape, dtype)
|
|
||||||
k0_broad_zero = dsl.broadcast(0, shape, dtype)
|
|
||||||
k0_none_signal = None
|
|
||||||
for i in range(1, n):
|
|
||||||
k0_none_signal = k0_broad_zero
|
|
||||||
k0_broad_zero = k0_broad_coef
|
|
||||||
coef_i = dsl.broadcast(coef[i], shape, dtype)
|
|
||||||
k0_broad_coef = dsl.vsub(dsl.vadd(dsl.vmul(x, k0_broad_zero), coef_i), k0_none_signal)
|
|
||||||
return dsl.vmuls(dsl.vsub(k0_broad_coef, k0_none_signal), 0.5)
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_i0_compute(input_x):
|
|
||||||
"""bessel_i0_compute"""
|
|
||||||
dtype = input_x.dtype
|
|
||||||
shape = input_x.shape
|
|
||||||
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
y = dsl.vabs(input_x)
|
|
||||||
|
|
||||||
y_le_eight_in = dsl.vmuls(y, 0.5)
|
|
||||||
y_le_eight_in = dsl.vadds(y_le_eight_in, -2.0)
|
|
||||||
y_le_eight = chebevl(y_le_eight_in, 30, AA, shape, dtype)
|
|
||||||
|
|
||||||
y_gt_eight_in = dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0)
|
|
||||||
y_gt_eight = chebevl(y_gt_eight_in, 25, BB, shape, dtype)
|
|
||||||
y_gt_eight = dsl.vmul(y_gt_eight, dsl.vrsqrt(y))
|
|
||||||
|
|
||||||
res = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
|
|
||||||
res = dsl.vmul(res, dsl.vexp(y))
|
|
||||||
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@fusion_manager.register("bessel_k0")
|
|
||||||
def bessel_k0_compute(input_x, output_y, kernel_name="bessel_k0"):
|
|
||||||
"""bessel_k0_compute"""
|
|
||||||
shape = input_x.shape
|
|
||||||
dtype = input_x.dtype
|
|
||||||
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
x_le_two = chebevl(dsl.vadds(dsl.vmul(input_x, input_x), -2.0), 10, A, shape, dtype)
|
|
||||||
x_le_two = dsl.vadd(dsl.vmul(bessel_i0_compute(input_x),
|
|
||||||
dsl.vmuls(dsl.vlog(dsl.vmuls(input_x, 0.5)), -1.0)), x_le_two)
|
|
||||||
x_le_two = dsl.vcmpsel(input_x, 0.0, 'le', MAXNUM, x_le_two)
|
|
||||||
x_gt_two = dsl.vmul(dsl.vmul(dsl.vexp(dsl.vmuls(input_x, -1.0)),
|
|
||||||
chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(input_x), 8.0), -2.0), 25, B, shape, dtype)),
|
|
||||||
(dsl.vrsqrt(input_x)))
|
|
||||||
|
|
||||||
res = dsl.vcmpsel(input_x, TWO, 'le', x_le_two, x_gt_two)
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_k0(x, output, kernel_name="bessel_k0"):
|
|
||||||
"""bessel_k0"""
|
|
||||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
|
||||||
|
|
||||||
res = bessel_k0_compute(data_x, output, kernel_name)
|
|
||||||
|
|
||||||
# auto schedule
|
|
||||||
with tvm.target.cce():
|
|
||||||
schedule = dsl.auto_schedule(res)
|
|
||||||
|
|
||||||
# operator build
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"tensor_list": [data_x, res]}
|
|
||||||
dsl.build(schedule, config)
|
|
|
@ -1,187 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesseK0e op"""
|
|
||||||
|
|
||||||
from tbe import dsl
|
|
||||||
from te import tvm
|
|
||||||
from te.platform.fusion_manager import fusion_manager
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
|
|
||||||
bessel_k0e_op_info = TBERegOp("BesselK0e") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_k0e.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_k0e") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_k0e_op_info)
|
|
||||||
def _bessel_k0e_tbe():
|
|
||||||
"""BesselK0e TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
A = [1.37446543561352307156E-16,
|
|
||||||
4.25981614279661018399E-14,
|
|
||||||
1.03496952576338420167E-11,
|
|
||||||
1.90451637722020886025E-9,
|
|
||||||
2.53479107902614945675E-7,
|
|
||||||
2.28621210311945178607E-5,
|
|
||||||
1.26461541144692592338E-3,
|
|
||||||
3.59799365153615016266E-2,
|
|
||||||
3.44289899924628486886E-1,
|
|
||||||
-5.35327393233902768720E-1]
|
|
||||||
B = [
|
|
||||||
5.30043377268626276149E-18, -1.64758043015242134646E-17,
|
|
||||||
5.21039150503902756861E-17, -1.67823109680541210385E-16,
|
|
||||||
5.51205597852431940784E-16, -1.84859337734377901440E-15,
|
|
||||||
6.34007647740507060557E-15, -2.22751332699166985548E-14,
|
|
||||||
8.03289077536357521100E-14, -2.98009692317273043925E-13,
|
|
||||||
1.14034058820847496303E-12, -4.51459788337394416547E-12,
|
|
||||||
1.85594911495471785253E-11, -7.95748924447710747776E-11,
|
|
||||||
3.57739728140030116597E-10, -1.69753450938905987466E-9,
|
|
||||||
8.57403401741422608519E-9, -4.66048989768794782956E-8,
|
|
||||||
2.76681363944501510342E-7, -1.83175552271911948767E-6,
|
|
||||||
1.39498137188764993662E-5, -1.28495495816278026384E-4,
|
|
||||||
1.56988388573005337491E-3, -3.14481013119645005427E-2,
|
|
||||||
2.44030308206595545468E0]
|
|
||||||
|
|
||||||
AA = [-4.41534164647933937950E-18, 3.33079451882223809783E-17,
|
|
||||||
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
|
|
||||||
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
|
|
||||||
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
|
|
||||||
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
|
|
||||||
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
|
|
||||||
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
|
|
||||||
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
|
|
||||||
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
|
|
||||||
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
|
|
||||||
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
|
|
||||||
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
|
|
||||||
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
|
|
||||||
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
|
|
||||||
-3.04682672343198398683E-1, 6.76795274409476084995E-1
|
|
||||||
]
|
|
||||||
BB = [
|
|
||||||
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
|
|
||||||
4.46562142029675999901E-17, 3.46122286769746109310E-17,
|
|
||||||
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
|
|
||||||
1.77256013305652638360E-15, 3.81168066935262242075E-15,
|
|
||||||
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
|
|
||||||
1.54008621752140982691E-14, 3.85277838274214270114E-13,
|
|
||||||
7.18012445138366623367E-13, -1.79417853150680611778E-12,
|
|
||||||
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
|
|
||||||
1.18891471078464383424E-11, 4.94060238822496958910E-10,
|
|
||||||
3.39623202570838634515E-9, 2.26666899049817806459E-8,
|
|
||||||
2.04891858946906374183E-7, 2.89137052083475648297E-6,
|
|
||||||
6.88975834691682398426E-5, 3.36911647825569408990E-3,
|
|
||||||
8.04490411014108831608E-1]
|
|
||||||
|
|
||||||
MAXNUM = 4294967295.0
|
|
||||||
TWO = 2.0
|
|
||||||
|
|
||||||
|
|
||||||
def chebevl(x, n, coef, shape, dtype):
|
|
||||||
"""chebevl"""
|
|
||||||
broad_coef = dsl.broadcast(coef[0], shape, dtype)
|
|
||||||
broad_zero = dsl.broadcast(0, shape, dtype)
|
|
||||||
none_signal = None
|
|
||||||
for i in range(1, n):
|
|
||||||
none_signal = broad_zero
|
|
||||||
broad_zero = broad_coef
|
|
||||||
coef_i = dsl.broadcast(coef[i], shape, dtype)
|
|
||||||
broad_coef = dsl.vsub(dsl.vadd(dsl.vmul(x, broad_zero), coef_i), none_signal)
|
|
||||||
return dsl.vmuls(dsl.vsub(broad_coef, none_signal), 0.5)
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_i0_compute(input_x):
|
|
||||||
"""bessel_i0_compute"""
|
|
||||||
dtype = input_x.dtype
|
|
||||||
shape = input_x.shape
|
|
||||||
|
|
||||||
k0e_has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
k0e_has_improve_precision = True
|
|
||||||
|
|
||||||
y = dsl.vabs(input_x)
|
|
||||||
|
|
||||||
y_le_eight_in = dsl.vmuls(y, 0.5)
|
|
||||||
y_le_eight_in = dsl.vadds(y_le_eight_in, -2.0)
|
|
||||||
y_le_eight = chebevl(y_le_eight_in, 30, AA, shape, dtype)
|
|
||||||
|
|
||||||
y_gt_eight_in = dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0)
|
|
||||||
y_gt_eight = chebevl(y_gt_eight_in, 25, BB, shape, dtype)
|
|
||||||
y_gt_eight = dsl.vmul(y_gt_eight, dsl.vrsqrt(y))
|
|
||||||
|
|
||||||
res = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
|
|
||||||
res = dsl.vmul(res, dsl.vexp(y))
|
|
||||||
|
|
||||||
if k0e_has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@fusion_manager.register("bessel_k0e")
|
|
||||||
def bessel_k0e_compute(input_x, output_y, kernel_name="bessel_k0e"):
|
|
||||||
"""bessel_k0e_compute"""
|
|
||||||
shape = input_x.shape
|
|
||||||
dtype = input_x.dtype
|
|
||||||
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
x_le_two = chebevl(dsl.vadds(dsl.vmul(input_x, input_x), -2.0), 10, A, shape, dtype)
|
|
||||||
x_le_two = dsl.vadd(dsl.vmul(bessel_i0_compute(input_x), dsl.vmuls(dsl.vlog(dsl.vmuls(input_x, 0.5)), -1.0)),
|
|
||||||
x_le_two)
|
|
||||||
x_le_two = dsl.vmul(dsl.vexp(input_x), x_le_two)
|
|
||||||
x_le_two = dsl.vcmpsel(input_x, 0.0, 'le', MAXNUM, x_le_two)
|
|
||||||
x_gt_two = dsl.vmul(dsl.vmul(dsl.vexp(dsl.vmuls(input_x, -1.0)), chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(input_x),
|
|
||||||
8.0), -2.0), 25, B,
|
|
||||||
shape, dtype)), (dsl.vrsqrt(input_x)))
|
|
||||||
|
|
||||||
res = dsl.vcmpsel(input_x, TWO, 'le', x_le_two, x_gt_two)
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_k0e(x, output, kernel_name="bessel_k0e"):
|
|
||||||
"""bessel_k0e"""
|
|
||||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
|
||||||
|
|
||||||
res = bessel_k0e_compute(data_x, output, kernel_name)
|
|
||||||
|
|
||||||
# auto schedule
|
|
||||||
with tvm.target.cce():
|
|
||||||
schedule = dsl.auto_schedule(res)
|
|
||||||
|
|
||||||
# operator build
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"tensor_list": [data_x, res]}
|
|
||||||
dsl.build(schedule, config)
|
|
|
@ -1,178 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesseK1 op"""
|
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
from tbe import dsl
|
|
||||||
from te import tvm
|
|
||||||
from te.platform.fusion_manager import fusion_manager
|
|
||||||
|
|
||||||
bessel_k1_op_info = TBERegOp("BesselK1") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_k1.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_k1") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_k1_op_info)
|
|
||||||
def _bessel_k1_tbe():
|
|
||||||
"""BesselK1 TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
AA = [2.77791411276104639959E-18, -2.11142121435816608115E-17,
|
|
||||||
1.55363195773620046921E-16, -1.10559694773538630805E-15,
|
|
||||||
7.60068429473540693410E-15, -5.04218550472791168711E-14,
|
|
||||||
3.22379336594557470981E-13, -1.98397439776494371520E-12,
|
|
||||||
1.17361862988909016308E-11, -6.66348972350202774223E-11,
|
|
||||||
3.62559028155211703701E-10, -1.88724975172282928790E-9,
|
|
||||||
9.38153738649577178388E-9, -4.44505912879632808065E-8,
|
|
||||||
2.00329475355213526229E-7, -8.56872026469545474066E-7,
|
|
||||||
3.47025130813767847674E-6, -1.32731636560394358279E-5,
|
|
||||||
4.78156510755005422638E-5, -1.61760815825896745588E-4,
|
|
||||||
5.12285956168575772895E-4, -1.51357245063125314899E-3,
|
|
||||||
4.15642294431288815669E-3, -1.05640848946261981558E-2,
|
|
||||||
2.47264490306265168283E-2, -5.29459812080949914269E-2,
|
|
||||||
1.02643658689847095384E-1, -1.76416518357834055153E-1,
|
|
||||||
2.52587186443633654823E-1]
|
|
||||||
|
|
||||||
BB = [
|
|
||||||
7.51729631084210481353E-18, 4.41434832307170791151E-18,
|
|
||||||
-4.65030536848935832153E-17, -3.20952592199342395980E-17,
|
|
||||||
2.96262899764595013876E-16, 3.30820231092092828324E-16,
|
|
||||||
-1.88035477551078244854E-15, -3.81440307243700780478E-15,
|
|
||||||
1.04202769841288027642E-14, 4.27244001671195135429E-14,
|
|
||||||
-2.10154184277266431302E-14, -4.08355111109219731823E-13,
|
|
||||||
-7.19855177624590851209E-13, 2.03562854414708950722E-12,
|
|
||||||
1.41258074366137813316E-11, 3.25260358301548823856E-11,
|
|
||||||
-1.89749581235054123450E-11, -5.58974346219658380687E-10,
|
|
||||||
-3.83538038596423702205E-9, -2.63146884688951950684E-8,
|
|
||||||
-2.51223623787020892529E-7, -3.88256480887769039346E-6,
|
|
||||||
-1.10588938762623716291E-4, -9.76109749136146840777E-3,
|
|
||||||
7.78576235018280120474E-1]
|
|
||||||
|
|
||||||
A = [-7.02386347938628759343E-18, -2.42744985051936593393E-15,
|
|
||||||
-6.66690169419932900609E-13, -1.41148839263352776110E-10,
|
|
||||||
-2.21338763073472585583E-8, -2.43340614156596823496E-6,
|
|
||||||
-1.73028895751305206302E-4, -6.97572385963986435018E-3,
|
|
||||||
-1.22611180822657148235E-1, -3.53155960776544875667E-1,
|
|
||||||
1.52530022733894777053E0]
|
|
||||||
|
|
||||||
B = [-5.75674448366501715755E-18, 1.79405087314755922667E-17,
|
|
||||||
-5.68946255844285935196E-17, 1.83809354436663880070E-16,
|
|
||||||
-6.05704724837331885336E-16, 2.03870316562433424052E-15,
|
|
||||||
-7.01983709041831346144E-15, 2.47715442448130437068E-14,
|
|
||||||
-8.97670518232499435011E-14, 3.34841966607842919884E-13,
|
|
||||||
-1.28917396095102890680E-12, 5.13963967348173025100E-12,
|
|
||||||
-2.12996783842756842877E-11, 9.21831518760500529508E-11,
|
|
||||||
-4.19035475934189648750E-10, 2.01504975519703286596E-9,
|
|
||||||
-1.03457624656780970260E-8, 5.74108412545004946722E-8,
|
|
||||||
-3.50196060308781257119E-7, 2.40648494783721712015E-6,
|
|
||||||
-1.93619797416608296024E-5, 1.95215518471351631108E-4,
|
|
||||||
-2.85781685962277938680E-3, 1.03923736576817238437E-1,
|
|
||||||
2.72062619048444266945E0]
|
|
||||||
|
|
||||||
MAX_NUM = 4294967295.0
|
|
||||||
NUM_TWO = 2.0
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_k1_chebevl(x, n, coef, shape, dtype):
|
|
||||||
"""chebevl"""
|
|
||||||
k1_broad_coef = dsl.broadcast(coef[0], shape, dtype)
|
|
||||||
k1_broad_zero = dsl.broadcast(0, shape, dtype)
|
|
||||||
k1_none_signal = None
|
|
||||||
for i in range(1, n):
|
|
||||||
k1_none_signal = k1_broad_zero
|
|
||||||
k1_broad_zero = k1_broad_coef
|
|
||||||
coef_i = dsl.broadcast(coef[i], shape, dtype)
|
|
||||||
k1_broad_coef = dsl.vsub(dsl.vadd(dsl.vmul(x, k1_broad_zero), coef_i), k1_none_signal)
|
|
||||||
return dsl.vmuls(dsl.vsub(k1_broad_coef, k1_none_signal), 0.5)
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_i1_compute(input_x):
|
|
||||||
"""bessel_i1_compute"""
|
|
||||||
dtype = input_x.dtype
|
|
||||||
shape = input_x.shape
|
|
||||||
|
|
||||||
k1_has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
k1_has_improve_precision = True
|
|
||||||
|
|
||||||
y = dsl.vabs(input_x)
|
|
||||||
|
|
||||||
y_le_eight = dsl.vmul(y, bessel_k1_chebevl(dsl.vadds(dsl.vmuls(y, 0.5), -2), 29, AA, shape, dtype))
|
|
||||||
y_gt_eight = dsl.vmul(bessel_k1_chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0), 25, BB,
|
|
||||||
shape, dtype), dsl.vrsqrt(y))
|
|
||||||
|
|
||||||
y = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
|
|
||||||
res = dsl.vcmpsel(input_x, 0, 'lt', dsl.vmuls(y, -1.0), y)
|
|
||||||
res = dsl.vmul(res, dsl.vexp(dsl.vabs(input_x)))
|
|
||||||
|
|
||||||
if k1_has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@fusion_manager.register("bessel_k1")
|
|
||||||
def bessel_k1_compute(input_x, output_y, kernel_name="bessel_k1"):
|
|
||||||
"""bessel_k1_compute"""
|
|
||||||
shape = input_x.shape
|
|
||||||
dtype = input_x.dtype
|
|
||||||
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
x_le_two = dsl.vdiv(bessel_k1_chebevl(dsl.vadds(dsl.vmul(input_x, input_x), -2.0), 11, A, shape, dtype), input_x)
|
|
||||||
x_le_two = dsl.vadd(dsl.vmul(bessel_i1_compute(input_x), dsl.vlog(dsl.vmuls(input_x, 0.5))), x_le_two)
|
|
||||||
x_le_two = dsl.vcmpsel(input_x, 0.0, 'le', MAX_NUM, x_le_two)
|
|
||||||
x_gt_two = dsl.vmul(dsl.vmul(dsl.vexp(dsl.vmuls(input_x, -1.0)),
|
|
||||||
bessel_k1_chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(input_x), 8.0), -2.0), 25,
|
|
||||||
B, shape, dtype)), (dsl.vrsqrt(input_x)))
|
|
||||||
|
|
||||||
res = dsl.vcmpsel(input_x, NUM_TWO, 'le', x_le_two, x_gt_two)
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_k1(x, output, kernel_name="bessel_k1"):
|
|
||||||
"""bessel_k1"""
|
|
||||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
|
||||||
|
|
||||||
res = bessel_k1_compute(data_x, output, kernel_name)
|
|
||||||
|
|
||||||
# auto schedule
|
|
||||||
with tvm.target.cce():
|
|
||||||
schedule = dsl.auto_schedule(res)
|
|
||||||
|
|
||||||
# operator build
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"tensor_list": [data_x, res]}
|
|
||||||
dsl.build(schedule, config)
|
|
|
@ -1,178 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesseK1e op"""
|
|
||||||
|
|
||||||
from tbe import dsl
|
|
||||||
from te import tvm
|
|
||||||
from te.platform.fusion_manager import fusion_manager
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
|
|
||||||
bessel_k1e_op_info = TBERegOp("BesselK1e") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_k1e.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_k1e") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_k1e_op_info)
|
|
||||||
def _bessel_k1e_tbe():
|
|
||||||
"""BesselK1e TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
AA = [2.77791411276104639959E-18, -2.11142121435816608115E-17,
|
|
||||||
1.55363195773620046921E-16, -1.10559694773538630805E-15,
|
|
||||||
7.60068429473540693410E-15, -5.04218550472791168711E-14,
|
|
||||||
3.22379336594557470981E-13, -1.98397439776494371520E-12,
|
|
||||||
1.17361862988909016308E-11, -6.66348972350202774223E-11,
|
|
||||||
3.62559028155211703701E-10, -1.88724975172282928790E-9,
|
|
||||||
9.38153738649577178388E-9, -4.44505912879632808065E-8,
|
|
||||||
2.00329475355213526229E-7, -8.56872026469545474066E-7,
|
|
||||||
3.47025130813767847674E-6, -1.32731636560394358279E-5,
|
|
||||||
4.78156510755005422638E-5, -1.61760815825896745588E-4,
|
|
||||||
5.12285956168575772895E-4, -1.51357245063125314899E-3,
|
|
||||||
4.15642294431288815669E-3, -1.05640848946261981558E-2,
|
|
||||||
2.47264490306265168283E-2, -5.29459812080949914269E-2,
|
|
||||||
1.02643658689847095384E-1, -1.76416518357834055153E-1,
|
|
||||||
2.52587186443633654823E-1]
|
|
||||||
|
|
||||||
BB = [
|
|
||||||
7.51729631084210481353E-18, 4.41434832307170791151E-18,
|
|
||||||
-4.65030536848935832153E-17, -3.20952592199342395980E-17,
|
|
||||||
2.96262899764595013876E-16, 3.30820231092092828324E-16,
|
|
||||||
-1.88035477551078244854E-15, -3.81440307243700780478E-15,
|
|
||||||
1.04202769841288027642E-14, 4.27244001671195135429E-14,
|
|
||||||
-2.10154184277266431302E-14, -4.08355111109219731823E-13,
|
|
||||||
-7.19855177624590851209E-13, 2.03562854414708950722E-12,
|
|
||||||
1.41258074366137813316E-11, 3.25260358301548823856E-11,
|
|
||||||
-1.89749581235054123450E-11, -5.58974346219658380687E-10,
|
|
||||||
-3.83538038596423702205E-9, -2.63146884688951950684E-8,
|
|
||||||
-2.51223623787020892529E-7, -3.88256480887769039346E-6,
|
|
||||||
-1.10588938762623716291E-4, -9.76109749136146840777E-3,
|
|
||||||
7.78576235018280120474E-1]
|
|
||||||
|
|
||||||
A = [-7.02386347938628759343E-18, -2.42744985051936593393E-15,
|
|
||||||
-6.66690169419932900609E-13, -1.41148839263352776110E-10,
|
|
||||||
-2.21338763073472585583E-8, -2.43340614156596823496E-6,
|
|
||||||
-1.73028895751305206302E-4, -6.97572385963986435018E-3,
|
|
||||||
-1.22611180822657148235E-1, -3.53155960776544875667E-1,
|
|
||||||
1.52530022733894777053E0]
|
|
||||||
|
|
||||||
B = [-5.75674448366501715755E-18, 1.79405087314755922667E-17,
|
|
||||||
-5.68946255844285935196E-17, 1.83809354436663880070E-16,
|
|
||||||
-6.05704724837331885336E-16, 2.03870316562433424052E-15,
|
|
||||||
-7.01983709041831346144E-15, 2.47715442448130437068E-14,
|
|
||||||
-8.97670518232499435011E-14, 3.34841966607842919884E-13,
|
|
||||||
-1.28917396095102890680E-12, 5.13963967348173025100E-12,
|
|
||||||
-2.12996783842756842877E-11, 9.21831518760500529508E-11,
|
|
||||||
-4.19035475934189648750E-10, 2.01504975519703286596E-9,
|
|
||||||
-1.03457624656780970260E-8, 5.74108412545004946722E-8,
|
|
||||||
-3.50196060308781257119E-7, 2.40648494783721712015E-6,
|
|
||||||
-1.93619797416608296024E-5, 1.95215518471351631108E-4,
|
|
||||||
-2.85781685962277938680E-3, 1.03923736576817238437E-1,
|
|
||||||
2.72062619048444266945E0]
|
|
||||||
|
|
||||||
MAXNUM = 4294967295.0
|
|
||||||
TWO = 2.0
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_k1e_chebevl(x, n, coef, shape, dtype):
|
|
||||||
"""chebevl"""
|
|
||||||
broad_coef = dsl.broadcast(coef[0], shape, dtype)
|
|
||||||
broad_zero = dsl.broadcast(0, shape, dtype)
|
|
||||||
none_signal = None
|
|
||||||
for i in range(1, n):
|
|
||||||
none_signal = broad_zero
|
|
||||||
broad_zero = broad_coef
|
|
||||||
coef_i = dsl.broadcast(coef[i], shape, dtype)
|
|
||||||
broad_coef = dsl.vsub(dsl.vadd(dsl.vmul(x, broad_zero), coef_i), none_signal)
|
|
||||||
return dsl.vmuls(dsl.vsub(broad_coef, none_signal), 0.5)
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_i1_compute(input_x):
|
|
||||||
"""bessel_i1_compute"""
|
|
||||||
dtype = input_x.dtype
|
|
||||||
shape = input_x.shape
|
|
||||||
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
y = dsl.vabs(input_x)
|
|
||||||
|
|
||||||
y_le_eight = dsl.vmul(y, bessel_k1e_chebevl(dsl.vadds(dsl.vmuls(y, 0.5), -2), 29, AA, shape, dtype))
|
|
||||||
y_gt_eight = dsl.vmul(bessel_k1e_chebevl(dsl.vadds(dsl.vmuls(dsl.vrec(y), 32.0), -2.0), 25, BB, shape, dtype),
|
|
||||||
dsl.vrsqrt(y))
|
|
||||||
|
|
||||||
y = dsl.vcmpsel(y, 8.0, 'le', y_le_eight, y_gt_eight)
|
|
||||||
res = dsl.vcmpsel(input_x, 0, 'lt', dsl.vmuls(y, -1.0), y)
|
|
||||||
res = dsl.vmul(res, dsl.vexp(dsl.vabs(input_x)))
|
|
||||||
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@fusion_manager.register("bessel_k1e")
|
|
||||||
def bessel_k1e_compute(input_x, output_y, kernel_name="bessel_k1e"):
|
|
||||||
"""bessel_k1e_compute"""
|
|
||||||
shape = input_x.shape
|
|
||||||
dtype = input_x.dtype
|
|
||||||
|
|
||||||
has_improve_precision = False
|
|
||||||
if dtype != "float32":
|
|
||||||
input_x = dsl.cast_to(input_x, "float32")
|
|
||||||
dtype = "float32"
|
|
||||||
has_improve_precision = True
|
|
||||||
|
|
||||||
x_le_two = dsl.vdiv(bessel_k1e_chebevl(dsl.vadds(dsl.vmul(input_x, input_x), -2.0), 11, A, shape, dtype), input_x)
|
|
||||||
x_le_two = dsl.vadd(dsl.vmul(bessel_i1_compute(input_x), dsl.vlog(dsl.vmuls(input_x, 0.5))), x_le_two)
|
|
||||||
x_le_two = dsl.vmul(x_le_two, dsl.vexp(input_x))
|
|
||||||
x_le_two = dsl.vcmpsel(input_x, 0.0, 'le', MAXNUM, x_le_two)
|
|
||||||
x_gt_two = dsl.vmul(dsl.vmul(dsl.vexp(dsl.vmuls(input_x, -1.0)), bessel_k1e_chebevl(dsl.vadds(dsl.vmuls(
|
|
||||||
dsl.vrec(input_x), 8.0), -2.0), 25, B, shape, dtype)), (dsl.vrsqrt(input_x)))
|
|
||||||
|
|
||||||
res = dsl.vcmpsel(input_x, TWO, 'le', x_le_two, x_gt_two)
|
|
||||||
if has_improve_precision:
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_k1e(x, output, kernel_name="bessel_k1e"):
|
|
||||||
"""bessel_k1e"""
|
|
||||||
data_x = tvm.placeholder(x.get("shape"), dtype=x.get("dtype"), name="data_x")
|
|
||||||
|
|
||||||
res = bessel_k1e_compute(data_x, output, kernel_name)
|
|
||||||
|
|
||||||
# auto schedule
|
|
||||||
with tvm.target.cce():
|
|
||||||
schedule = dsl.auto_schedule(res)
|
|
||||||
|
|
||||||
# operator build
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"tensor_list": [data_x, res]}
|
|
||||||
dsl.build(schedule, config)
|
|
|
@ -1,424 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesselY0 op"""
|
|
||||||
|
|
||||||
import te.lang.cce as tbe
|
|
||||||
from te import tvm
|
|
||||||
from te.utils import para_check
|
|
||||||
from te.utils import shape_util
|
|
||||||
from tbe.common.platform import api_check_support
|
|
||||||
from tbe.common.register import register_op_compute
|
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
|
|
||||||
bessel_y0_op_info = TBERegOp("BesselY0") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_y0.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_y0") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_y0_op_info)
|
|
||||||
def _bessel_y0_tbe():
|
|
||||||
"""BesselY0 TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
ITR_A1 = (57568490574.0, -13362590354.0, 651619640.7, -11214424.18, 77392.33017, -184.9052456)
|
|
||||||
ITR_A2 = (57568490411.0, 1029532985.0, 9494680.718, 59272.64853, 267.8532712, 1.0)
|
|
||||||
ITR_A33 = (1.0, -0.1098628627e-2, 0.2734510407e-4, -0.2073370639e-5, 0.2093887211e-6)
|
|
||||||
ITR_A44 = (-0.1562499995e-1, 0.1430488765e-3, -0.6911147651e-5, 0.7621095161e-6, 0.934935152e-7)
|
|
||||||
ITR_A5 = (-2957821389.0, 7062834065.0, -512359803.6, 10879881.29, -86327.92757, 228.4622733)
|
|
||||||
ITR_A6 = (40076544269.0, 745249964.8, 7189466.438, 47447.26470, 226.1030244, 1.0)
|
|
||||||
ITR_A = (-0.785398164, 0.636619772)
|
|
||||||
LEN_A1256 = 6
|
|
||||||
LEN_A34 = 5
|
|
||||||
EIGHT = 8.0
|
|
||||||
ZERO = 0
|
|
||||||
NEG_ONE = -1.0
|
|
||||||
ONE = 1.0
|
|
||||||
PI2 = 6.2831853071796
|
|
||||||
HALF_PI = 1.5707963267948966192313216916398
|
|
||||||
BOUNDARY_1 = 0.70710678118654752440084436210485
|
|
||||||
# Taylor coefficient
|
|
||||||
COEF = (1.0,
|
|
||||||
0.16666666666666666666666666666667,
|
|
||||||
0.075,
|
|
||||||
0.04464285714285714285714285714286,
|
|
||||||
0.03038194444444444444444444444444,
|
|
||||||
0.02237215909090909090909090909091,
|
|
||||||
0.01735276442307692307692307692308,
|
|
||||||
0.01396484375)
|
|
||||||
# TAYLOR COUNT
|
|
||||||
TAYLOR_COUNT = 7
|
|
||||||
# negative min float16 value
|
|
||||||
NEG_MIN_FP16 = -2 ** (-24)
|
|
||||||
# min float16 * 2
|
|
||||||
TWO_MIN_FP16 = 2 ** (-23)
|
|
||||||
|
|
||||||
|
|
||||||
def _taylor_compute(data_x, x_square=None):
|
|
||||||
"""_taylor_compute"""
|
|
||||||
if x_square is None:
|
|
||||||
x_square = tbe.vmul(data_x, data_x)
|
|
||||||
|
|
||||||
res = tbe.vmuls(x_square, tvm.const(COEF[TAYLOR_COUNT], "float32"))
|
|
||||||
for temp in reversed(range(TAYLOR_COUNT)):
|
|
||||||
res = tbe.vadds(res, tvm.const(COEF[temp], "float32"))
|
|
||||||
if temp == 0:
|
|
||||||
res = tbe.vmul(res, data_x)
|
|
||||||
else:
|
|
||||||
res = tbe.vmul(x_square, res)
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def acos_compute(x):
|
|
||||||
"""
|
|
||||||
do element-wise acos compute using asin op
|
|
||||||
acos(x) = HALF_PI - asin(x)
|
|
||||||
|
|
||||||
asin(x) = | arcsin(sqrt(1-x^2)) - HALF_PI, x belongs to (-1, -2^(-0.5))
|
|
||||||
| the 15th order taylor expansion, x belongs to (-2^(-0.5),
|
|
||||||
| 2^(-0.5))
|
|
||||||
| HALF_PI - arcsin(sqrt(1-x^2)), x belongs to (2^(-0.5), 1)
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
----------
|
|
||||||
x: the placeholder of data input
|
|
||||||
|
|
||||||
y : the dict of output
|
|
||||||
|
|
||||||
Returns : A Tensor. Has the same type as x.
|
|
||||||
-------
|
|
||||||
"""
|
|
||||||
shape = x.shape
|
|
||||||
dtype = x.dtype
|
|
||||||
# Change dtype to float32
|
|
||||||
if (dtype in ('float16', 'double')) and \
|
|
||||||
api_check_support("te.lang.cce.vadd", "float32"):
|
|
||||||
x = tbe.cast_to(x, "float32")
|
|
||||||
|
|
||||||
# to fix bug for input data is 1.0
|
|
||||||
x = tbe.vadds(x, NEG_MIN_FP16)
|
|
||||||
# Sign mask
|
|
||||||
sign = tbe.vcmpsel(x, 0, 'lt', NEG_ONE, ONE)
|
|
||||||
# All positive
|
|
||||||
x = tbe.vmul(x, sign)
|
|
||||||
|
|
||||||
# x belongs to (0, 2^(-0.5))
|
|
||||||
if api_check_support("te.lang.cce.vmins", x.dtype):
|
|
||||||
choice_1 = tbe.vmins(x, tvm.const(BOUNDARY_1, x.dtype))
|
|
||||||
else:
|
|
||||||
boundary_mask1 = tbe.broadcast(tvm.const(BOUNDARY_1, x.dtype), shape)
|
|
||||||
choice_1 = tbe.vmin(x, boundary_mask1)
|
|
||||||
|
|
||||||
if api_check_support("te.lang.cce.vsubs", choice_1.dtype):
|
|
||||||
choice_1 = tbe.vsubs(choice_1, tvm.const(BOUNDARY_1, choice_1.dtype))
|
|
||||||
else:
|
|
||||||
boundary_mask1 = tbe.broadcast(tvm.const(BOUNDARY_1, choice_1.dtype), shape)
|
|
||||||
choice_1 = tbe.vsub(choice_1, boundary_mask1)
|
|
||||||
|
|
||||||
choice_1 = tbe.vmuls(tbe.floor(choice_1), NEG_ONE)
|
|
||||||
|
|
||||||
res_1 = _taylor_compute(x)
|
|
||||||
res_1 = tbe.vmul(res_1, choice_1)
|
|
||||||
|
|
||||||
# x belongs to (2^(-0.5), 1)
|
|
||||||
choice_2 = tbe.vmuls(choice_1, tvm.const(NEG_ONE, x.dtype))
|
|
||||||
choice_2 = tbe.vadds(choice_2, tvm.const(ONE, x.dtype))
|
|
||||||
|
|
||||||
# to fix bug for input data is 1.0
|
|
||||||
x = tbe.vadds(x, TWO_MIN_FP16)
|
|
||||||
res_2 = tbe.vmul(x, x)
|
|
||||||
res_2 = tbe.vmuls(res_2, tvm.const(NEG_ONE, x.dtype))
|
|
||||||
res_2 = tbe.vadds(res_2, tvm.const(ONE, x.dtype))
|
|
||||||
res_2_sqrt = tbe.vsqrt(res_2, 1)
|
|
||||||
|
|
||||||
res_2 = _taylor_compute(res_2_sqrt, res_2)
|
|
||||||
res_2 = tbe.vmuls(res_2, tvm.const(NEG_ONE, x.dtype))
|
|
||||||
res_2 = tbe.vadds(res_2, tvm.const(HALF_PI, x.dtype))
|
|
||||||
res_2 = tbe.vmul(res_2, choice_2)
|
|
||||||
|
|
||||||
# Restore sign of asin
|
|
||||||
res_1 = tbe.vadd(res_1, res_2)
|
|
||||||
res_1 = tbe.vmul(res_1, sign)
|
|
||||||
res_1 = tbe.vmuls(res_1, tvm.const(NEG_ONE, x.dtype))
|
|
||||||
res_1 = tbe.vadds(res_1, tvm.const(HALF_PI, x.dtype))
|
|
||||||
|
|
||||||
return res_1
|
|
||||||
|
|
||||||
|
|
||||||
def asin_compute(x):
|
|
||||||
"""
|
|
||||||
do element-wise asin compute
|
|
||||||
asin(x) = | arcsin(sqrt(1-x^2)) - HALF_PI, x belongs to (-1, -2^(-0.5))
|
|
||||||
| the 15th order taylor expansion, x belongs to (-2^(-0.5), 2^(-0.5))
|
|
||||||
| HALF_PI - arcsin(sqrt(1-x^2)), x belongs to (2^(-0.5), 1)
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
----------
|
|
||||||
x: the placeholder of data input
|
|
||||||
|
|
||||||
y : the dict of output
|
|
||||||
|
|
||||||
Returns : A Tensor. Has the same type as data_input.
|
|
||||||
-------
|
|
||||||
"""
|
|
||||||
shape = x.shape
|
|
||||||
dtype = x.dtype
|
|
||||||
# Change dtype to float32
|
|
||||||
if (dtype in ('float16', 'double')) and api_check_support("te.lang.cce.vadd", "float32"):
|
|
||||||
x = tbe.cast_to(x, "float32")
|
|
||||||
|
|
||||||
# Sign mask
|
|
||||||
bessely0_sign = tbe.vcmpsel(x, 0, 'lt', NEG_ONE, ONE)
|
|
||||||
|
|
||||||
# All positive
|
|
||||||
x = tbe.vmul(x, bessely0_sign)
|
|
||||||
|
|
||||||
# x belongs to (0, 2^(-0.5))
|
|
||||||
if api_check_support("te.lang.cce.vmins", x.dtype):
|
|
||||||
y0_choice_1 = tbe.vmins(x, tvm.const(BOUNDARY_1, x.dtype))
|
|
||||||
else:
|
|
||||||
boundary_mask1 = tbe.broadcast(tvm.const(BOUNDARY_1, x.dtype), shape)
|
|
||||||
y0_choice_1 = tbe.vmin(x, boundary_mask1)
|
|
||||||
|
|
||||||
if api_check_support("te.lang.cce.vsubs", y0_choice_1.dtype):
|
|
||||||
y0_choice_1 = tbe.vsubs(y0_choice_1, tvm.const(BOUNDARY_1, y0_choice_1.dtype))
|
|
||||||
else:
|
|
||||||
boundary_mask1 = tbe.broadcast(tvm.const(BOUNDARY_1, y0_choice_1.dtype), shape)
|
|
||||||
y0_choice_1 = tbe.vsub(y0_choice_1, boundary_mask1)
|
|
||||||
|
|
||||||
y0_choice_1 = tbe.vmuls(tbe.floor(y0_choice_1), NEG_ONE)
|
|
||||||
|
|
||||||
res_1 = _taylor_compute(x)
|
|
||||||
res_1 = tbe.vmul(res_1, y0_choice_1)
|
|
||||||
|
|
||||||
# x belongs to (2^(-0.5), 1)
|
|
||||||
choice_2 = tbe.vmuls(y0_choice_1, tvm.const(NEG_ONE, x.dtype))
|
|
||||||
choice_2 = tbe.vadds(choice_2, tvm.const(ONE, x.dtype))
|
|
||||||
|
|
||||||
res_2 = tbe.vmul(x, x)
|
|
||||||
res_2 = tbe.vmuls(res_2, tvm.const(NEG_ONE, x.dtype))
|
|
||||||
res_2 = tbe.vadds(res_2, tvm.const(ONE, x.dtype))
|
|
||||||
res_2_sqrt = tbe.vsqrt(res_2)
|
|
||||||
|
|
||||||
res_2 = _taylor_compute(res_2_sqrt, res_2)
|
|
||||||
|
|
||||||
res_2 = tbe.vmuls(res_2, tvm.const(NEG_ONE, x.dtype))
|
|
||||||
res_2 = tbe.vadds(res_2, tvm.const(HALF_PI, x.dtype))
|
|
||||||
res_2 = tbe.vmul(res_2, choice_2)
|
|
||||||
|
|
||||||
# Restore sign
|
|
||||||
res_1 = tbe.vadd(res_1, res_2)
|
|
||||||
res_1 = tbe.vmul(res_1, bessely0_sign)
|
|
||||||
|
|
||||||
return res_1
|
|
||||||
|
|
||||||
|
|
||||||
def _besselj0(x):
|
|
||||||
"""
|
|
||||||
Algrithm:
|
|
||||||
y = x * x;
|
|
||||||
ans1 = 57568490574.0 + y * (-13362590354.0 + y * (651619640.7 + y * (-11214424.18 + y *
|
|
||||||
(77392.33017 + y * (-184.9052456)))))
|
|
||||||
ans2 = 57568490411.0 + y * (1029532985.0 + y * (9494680.718 + y * (59272.64853 + y * (267.8532712 + y * 1.0))))
|
|
||||||
ans = ans1 / ans2 (x < 8.0)
|
|
||||||
z = 8.0 / x
|
|
||||||
y = z * z
|
|
||||||
xx = ax - 0.785398164;
|
|
||||||
ans1 = 1.0 + y * (-0.1098628627e-2 + y * (0.2734510407e-4 + y * (-0.2073370639e-5 + y * 0.2093887211e-6)))
|
|
||||||
ans2 = -0.1562499995e-1 + y * (0.1430488765e-3 + y * (-0.6911147651e-5 +
|
|
||||||
y * (0.7621095161e-6 - y * 0.934935152e-7)));
|
|
||||||
ans = sqrt(0.636619772 / ax) * (cos(xx) * ans1 - z * sin(xx) * ans2), (x >= 8.0)
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
x: the placeholder of data input
|
|
||||||
|
|
||||||
y : the dict of output
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
A tensor. Has the same type as x.
|
|
||||||
|
|
||||||
"""
|
|
||||||
jax = tbe.vabs(x)
|
|
||||||
tensor_eight = tbe.broadcast(tvm.const(EIGHT, jax.dtype), jax.shape)
|
|
||||||
first_res = tbe.vmax(jax, tensor_eight)
|
|
||||||
jz = tbe.vdiv(tensor_eight, first_res)
|
|
||||||
jy = tbe.vmul(jz, jz)
|
|
||||||
jxx = tbe.vadds(first_res, ITR_A[0])
|
|
||||||
jans1 = tbe.vmuls(jy, tvm.const(ITR_A33[LEN_A34 - 1]))
|
|
||||||
jans1 = tbe.vadds(jans1, ITR_A33[LEN_A34 - 2])
|
|
||||||
for index in reversed(range(LEN_A34 - 2)):
|
|
||||||
jans1 = tbe.vmul(jans1, jy)
|
|
||||||
jans1 = tbe.vadds(jans1, ITR_A33[index])
|
|
||||||
jans2 = tbe.vmuls(jy, tvm.const(ITR_A44[LEN_A34 - 1]))
|
|
||||||
jans2 = tbe.vadds(jans2, ITR_A44[LEN_A34 - 2])
|
|
||||||
for index in reversed(range(LEN_A34 - 2)):
|
|
||||||
jans2 = tbe.vmul(jans2, jy)
|
|
||||||
jans2 = tbe.vadds(jans2, ITR_A44[index])
|
|
||||||
jansres1 = tbe.vmul(tbe.vsqrt(tbe.vmuls(tbe.vrec(first_res), ITR_A[1])),
|
|
||||||
tbe.vsub(tbe.vmul(acos_compute(jxx), jans1), tbe.vmul(jz, tbe.vmul(asin_compute(jxx), jans2))))
|
|
||||||
|
|
||||||
first_res = tbe.vmin(jax, tensor_eight)
|
|
||||||
jy = tbe.vmul(first_res, first_res)
|
|
||||||
jans1 = tbe.vmuls(jy, tvm.const(ITR_A1[LEN_A1256 - 1]))
|
|
||||||
jans1 = tbe.vadds(jans1, ITR_A1[LEN_A1256 - 2])
|
|
||||||
for index in reversed(range(LEN_A1256 - 2)):
|
|
||||||
jans1 = tbe.vmul(jans1, jy)
|
|
||||||
jans1 = tbe.vadds(jans1, ITR_A1[index])
|
|
||||||
jans2 = tbe.vmuls(jy, tvm.const(ITR_A2[LEN_A1256 - 1]))
|
|
||||||
jans2 = tbe.vadds(jans2, ITR_A2[LEN_A1256 - 2])
|
|
||||||
for index in reversed(range(LEN_A1256 - 2)):
|
|
||||||
jans2 = tbe.vmul(jans2, jy)
|
|
||||||
jans2 = tbe.vadds(jans2, ITR_A2[index])
|
|
||||||
jansres2 = tbe.vdiv(jans1, jans2)
|
|
||||||
res = tbe.vcmpsel(jax, tensor_eight, 'lt', jansres2, jansres1)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@register_op_compute("bessel_y0")
|
|
||||||
def bessel_y0_compute(x, y, kernel_name="bessel_y0"):
|
|
||||||
"""
|
|
||||||
Algrithm:
|
|
||||||
y = x * x;
|
|
||||||
ans1 = -2957821389.0 + y * (7062834065.0 + y * (-512359803.6 + y * (10879881.29 +
|
|
||||||
y * (-86327.92757 + y * 228.4622733))))
|
|
||||||
ans2 = 40076544269.0 + y * (745249964.8 + y * (7189466.438 + y * (47447.26470 + y * (226.1030244 + y * 1.0))))
|
|
||||||
ans = (ans1 / ans2) + 0.636619772 * bessj0(x) * math.log(x), (x < 8.0)
|
|
||||||
|
|
||||||
z = 8.0 / x
|
|
||||||
y = z * z
|
|
||||||
xx = x - 0.785398164
|
|
||||||
ans1 = 1.0 + y * (-0.1098628627e-2 + y * (0.2734510407e-4 + y * (-0.2073370639e-5 + y * 0.2093887211e-6)))
|
|
||||||
ans2 = -0.1562499995e-1 + y * (0.1430488765e-3 + y * (-0.6911147651e-5 +
|
|
||||||
y * (0.7621095161e-6 + y * (-0.934945152e-7))))
|
|
||||||
ans = math.sqrt(0.636619772 / x) * (math.asin(xx) * ans1 + z * math.acos(xx) * ans2), (x >= 8.0)
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
x: the placeholder of data input
|
|
||||||
|
|
||||||
y : the dict of output
|
|
||||||
|
|
||||||
kernel_name : cce kernel name, default value is "bessel_y0"
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
A tensor. Has the same type as x.
|
|
||||||
|
|
||||||
"""
|
|
||||||
dtype_input = x.dtype
|
|
||||||
|
|
||||||
# chose the type of data in begin
|
|
||||||
if (dtype_input in ('float16', 'double', 'float32')) \
|
|
||||||
and api_check_support("te.lang.cce.vadd", "float32"):
|
|
||||||
x = tbe.cast_to(x, "float32")
|
|
||||||
else:
|
|
||||||
raise RuntimeError("BesselY0 kernel data type [%s] not support." % dtype_input)
|
|
||||||
x = tbe.vabs(x)
|
|
||||||
tensor_eight = tbe.broadcast(tvm.const(EIGHT, x.dtype), x.shape)
|
|
||||||
first_res = tbe.vmin(x, tensor_eight)
|
|
||||||
yy = tbe.vmul(first_res, first_res)
|
|
||||||
ans1 = tbe.vmuls(yy, tvm.const(ITR_A5[LEN_A1256 - 1]))
|
|
||||||
ans1 = tbe.vadds(ans1, ITR_A5[LEN_A1256 - 2])
|
|
||||||
for index in reversed(range(LEN_A1256 - 2)):
|
|
||||||
ans1 = tbe.vmul(ans1, yy)
|
|
||||||
ans1 = tbe.vadds(ans1, ITR_A5[index])
|
|
||||||
ans2 = tbe.vmuls(yy, tvm.const(ITR_A6[LEN_A1256 - 1]))
|
|
||||||
ans2 = tbe.vadds(ans2, ITR_A6[LEN_A1256 - 2])
|
|
||||||
for index in reversed(range(LEN_A1256 - 2)):
|
|
||||||
ans2 = tbe.vmul(ans2, yy)
|
|
||||||
ans2 = tbe.vadds(ans2, ITR_A6[index])
|
|
||||||
res1 = tbe.vadd(tbe.vdiv(ans1, ans2), tbe.vmuls(tbe.vmul(_besselj0(first_res), tbe.vlog(first_res)), ITR_A[1]))
|
|
||||||
|
|
||||||
first_res = tbe.vmax(x, tensor_eight)
|
|
||||||
z = tbe.vdiv(tensor_eight, first_res)
|
|
||||||
y = tbe.vmul(z, z)
|
|
||||||
xx = tbe.vadds(first_res, ITR_A[0])
|
|
||||||
ans1 = tbe.vmuls(y, tvm.const(ITR_A33[LEN_A34 - 1]))
|
|
||||||
ans1 = tbe.vadds(ans1, ITR_A33[LEN_A34 - 2])
|
|
||||||
for index in reversed(range(LEN_A34 - 2)):
|
|
||||||
ans1 = tbe.vmul(ans1, y)
|
|
||||||
ans1 = tbe.vadds(ans1, ITR_A33[index])
|
|
||||||
ans2 = tbe.vmuls(y, tvm.const(ITR_A44[LEN_A34 - 1]))
|
|
||||||
ans2 = tbe.vadds(ans2, ITR_A44[LEN_A34 - 2])
|
|
||||||
for index in reversed(range(LEN_A34 - 2)):
|
|
||||||
ans2 = tbe.vmul(ans2, y)
|
|
||||||
ans2 = tbe.vadds(ans2, ITR_A44[index])
|
|
||||||
res2 = tbe.vmul(tbe.vsqrt(tbe.vmuls(tbe.vrec(first_res), ITR_A[1])),
|
|
||||||
tbe.vadd(tbe.vmul(asin_compute(xx), ans1),
|
|
||||||
tbe.vmul(z, tbe.vmul(acos_compute(xx), ans2))))
|
|
||||||
res = tbe.vcmpsel(x, tensor_eight, 'lt', res1, res2)
|
|
||||||
|
|
||||||
# Restore dtype
|
|
||||||
if dtype_input == "float16":
|
|
||||||
res = tbe.cast_to(res, "float16")
|
|
||||||
if dtype_input == "double":
|
|
||||||
res = tbe.cast_to(res, "double")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME)
|
|
||||||
def bessel_y0(x, y, kernel_name="bessel_y0"):
|
|
||||||
"""
|
|
||||||
Computes the Bessel y0 function of x element-wise.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
x: only support float16, float32, double
|
|
||||||
|
|
||||||
y : output
|
|
||||||
|
|
||||||
kernel_name : cce kernel name, default value is "bessel_y0"
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
shape_input = x.get("shape")
|
|
||||||
dtype_input = x.get("dtype")
|
|
||||||
|
|
||||||
para_check.check_shape(shape_input, param_name="x")
|
|
||||||
shape_input, _ = shape_util.refine_shape_axes(shape_input, [])
|
|
||||||
|
|
||||||
check_list = ("float16", "float32", "double")
|
|
||||||
para_check.check_dtype(dtype_input, check_list, param_name="x")
|
|
||||||
|
|
||||||
input_dtype = dtype_input.lower()
|
|
||||||
data = tvm.placeholder(shape_input, dtype=input_dtype, name="data_input")
|
|
||||||
|
|
||||||
res = bessel_y0_compute(data, y, kernel_name)
|
|
||||||
|
|
||||||
with tvm.target.cce():
|
|
||||||
sch = tbe.auto_schedule(res)
|
|
||||||
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"print_ir": False,
|
|
||||||
"tensor_list": (data, res),
|
|
||||||
"bool_storage_as_1bit": False}
|
|
||||||
tbe.cce_build_code(sch, config)
|
|
|
@ -1,281 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""BesselY1 op"""
|
|
||||||
|
|
||||||
import tbe.dsl as dsl
|
|
||||||
from tbe.common.platform import api_check_support
|
|
||||||
from tbe.common.register import register_op_compute
|
|
||||||
from tbe.common.utils import para_check
|
|
||||||
from tbe.common.utils import shape_util
|
|
||||||
from te import tvm
|
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
||||||
|
|
||||||
bessel_y1_op_info = TBERegOp("BesselY1") \
|
|
||||||
.fusion_type("ELEMWISE") \
|
|
||||||
.async_flag(False) \
|
|
||||||
.binfile_name("bessel_y1.so") \
|
|
||||||
.compute_cost(10) \
|
|
||||||
.kernel_name("bessel_y1") \
|
|
||||||
.partial_flag(True) \
|
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "x", False, "required", "all") \
|
|
||||||
.output(0, "y", False, "required", "all") \
|
|
||||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
||||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
||||||
.get_op_info()
|
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(bessel_y1_op_info)
|
|
||||||
def _bessel_y1_tbe():
|
|
||||||
"""BesselY1 TBE register"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
CONST_LIMIT = 8.0
|
|
||||||
ITR_BEFORE1_J = [72362614232.0, -7895059235.0, 242396853.1, -2972611.439, 15704.48260, -30.16036606]
|
|
||||||
ITR_BEFORE2_J = [144725228442.0, 2300535178.0, 18583304.74, 99447.43394, 376.9991397, 1.0]
|
|
||||||
ITR_AFTER1_J = [1.0, 0.183105e-2, -0.3516396496e-4, 0.2457520174e-5, -0.240337019e-6]
|
|
||||||
ITR_AFTER2_J = [0.04687499995, -0.2002690873e-3, 0.8449199096e-5, -0.88228987e-6, 0.105787412e-6]
|
|
||||||
ITR_BEFORE1 = [-0.4900604943e13, 0.1275274390e13, -0.5153438139e11, 0.7349264551e9,
|
|
||||||
-0.4237922726e7, 0.8511937935e4]
|
|
||||||
ITR_BEFORE2 = [0.2499580570e14, 0.4244419664e12, 0.3733650367e10, 0.2245904002e8,
|
|
||||||
0.1020426050e6, 0.3549632885e3, 1.0]
|
|
||||||
THREE_QUARTERS_PI = 2.356194490192345
|
|
||||||
ITR_AFTER1 = [1.0, 0.183105e-2, -0.3516396496e-4, 0.2457520174e-5, -0.240337019e-6]
|
|
||||||
ITR_AFTER2 = [0.04687499995, -0.2002690873e-3, 0.8449199096e-5, -0.88228987e-6, 0.105787412e-6]
|
|
||||||
TOW_OVER_PI = 0.636619772367581
|
|
||||||
PI = 3.141592653589793
|
|
||||||
|
|
||||||
|
|
||||||
def angle_trans_cal(x):
|
|
||||||
"""angle_trans_cal"""
|
|
||||||
consult = dsl.vdiv(x, dsl.broadcast(tvm.const(PI * 2), x.shape))
|
|
||||||
floor_consult = dsl.cast_to(dsl.floor(consult), 'float32')
|
|
||||||
fixed_x = dsl.vsub(x, dsl.vmuls(floor_consult, PI * 2))
|
|
||||||
|
|
||||||
coe = -0.707106781186548 # -sqrt(2)/2
|
|
||||||
quarter_x = dsl.vmuls(x, 0.25)
|
|
||||||
sin_quar_x, cos_quar_x = cordic(quarter_x)
|
|
||||||
cos_quar_x2 = dsl.vmul(cos_quar_x, cos_quar_x)
|
|
||||||
sin_quar_x2 = dsl.vmul(sin_quar_x, sin_quar_x)
|
|
||||||
cos_quar_x4 = dsl.vmul(cos_quar_x2, cos_quar_x2)
|
|
||||||
|
|
||||||
temp_res1 = dsl.vadds(dsl.vadd(dsl.vmuls(cos_quar_x2, -8.0), dsl.vmuls(cos_quar_x4, 8.0)), 1)
|
|
||||||
temp_res2 = dsl.vmuls(sin_quar_x2, 2.0)
|
|
||||||
temp_res2 = dsl.vadds(temp_res2, -1.0)
|
|
||||||
temp_res2 = dsl.vmul(temp_res2, sin_quar_x)
|
|
||||||
temp_res2 = dsl.vmul(temp_res2, cos_quar_x)
|
|
||||||
temp_res2 = dsl.vmuls(temp_res2, -4.0)
|
|
||||||
|
|
||||||
sin_res = dsl.vmuls(dsl.vadd(temp_res1, temp_res2), coe)
|
|
||||||
cos_res = dsl.vmuls(dsl.vsub(temp_res2, temp_res1), coe)
|
|
||||||
return sin_res, cos_res
|
|
||||||
|
|
||||||
|
|
||||||
def cordic(angle):
|
|
||||||
"""cordic"""
|
|
||||||
shape = angle.shape
|
|
||||||
dtype = angle.dtype
|
|
||||||
ceof = [1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32, 1 / 64, 1 / 128, 1 / 256,
|
|
||||||
1 / 512, 1 / 1024, 1 / 2048, 1 / 4096, 1 / 8192, 1 / 16384, 1 / 32768,
|
|
||||||
1 / 65536, 1 / 131072, 1 / 262144, 1 / 524288, 1 / 1048576]
|
|
||||||
|
|
||||||
dangle = [45, 26.565051177078, 14.0362434679265, 7.1250163489018,
|
|
||||||
3.57633437499735, 1.78991060824607, 0.8951737102111,
|
|
||||||
0.4476141708606, 0.2238105003685, 0.1119056770662,
|
|
||||||
0.0559528918938, 0.027976452617, 0.01398822714227,
|
|
||||||
0.006994113675353, 0.003497056950704, 0.001748528426980,
|
|
||||||
0.000874264213694, 0.000437132106872, 0.000218566053439,
|
|
||||||
0.000109283026720, 0.000054641513360]
|
|
||||||
|
|
||||||
k = 0.60725293500888
|
|
||||||
x = dsl.broadcast(1.0, shape, dtype)
|
|
||||||
y = dsl.broadcast(0.0, shape, dtype)
|
|
||||||
z = angle
|
|
||||||
|
|
||||||
for i in range(10):
|
|
||||||
ones = dsl.broadcast(1.0, shape, dtype)
|
|
||||||
nones = dsl.broadcast(-1.0, shape, dtype)
|
|
||||||
cmp = dsl.vcmp(z, dsl.broadcast(0.0, shape, dtype))
|
|
||||||
d = dsl.vsel(cmp, ones, nones)
|
|
||||||
|
|
||||||
xn = x
|
|
||||||
|
|
||||||
d_ceof = dsl.vmuls(d, ceof[i])
|
|
||||||
x = dsl.vsub(xn, dsl.vmul(y, d_ceof))
|
|
||||||
y = dsl.vadd(y, dsl.vmul(xn, d_ceof))
|
|
||||||
z = dsl.vsub(z, dsl.vmuls(d, dangle[i]))
|
|
||||||
|
|
||||||
return dsl.vmuls(y, k), dsl.vmuls(x, k)
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=locally-disabled,too-many-arguments,unused-argument,invalid-name,too-many-locals,
|
|
||||||
def prod(data, iter_arr):
|
|
||||||
"""prod"""
|
|
||||||
input_shape = data.shape
|
|
||||||
input_dtype = data.dtype
|
|
||||||
res = dsl.broadcast(tvm.const(iter_arr[-1], input_dtype), input_shape)
|
|
||||||
for addition in reversed(iter_arr[:-1]):
|
|
||||||
res = dsl.vmul(res, data)
|
|
||||||
res = dsl.vadd(res, dsl.broadcast(tvm.const(addition, input_dtype), input_shape))
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def bessel_j1(x):
|
|
||||||
"""bessel_j1"""
|
|
||||||
shape_input = x.shape
|
|
||||||
dtype_input = x.dtype
|
|
||||||
# 1. chose the type of data in begin
|
|
||||||
if (dtype_input in ('float16', 'float32')) and api_check_support("te.lang.cce.vadd", "float32"):
|
|
||||||
x = dsl.cast_to(x, "float32")
|
|
||||||
else:
|
|
||||||
raise RuntimeError("BesselY0 kernel data type [%s] not support." % dtype_input)
|
|
||||||
abs_data = dsl.vabs(x)
|
|
||||||
|
|
||||||
# 2. compute bessel_j1 for data in (-8, 8)
|
|
||||||
broad_const_limit = dsl.broadcast(tvm.const(CONST_LIMIT, x.dtype), shape_input)
|
|
||||||
before_abs_data = dsl.vmin(abs_data, broad_const_limit)
|
|
||||||
square_data = dsl.vmul(before_abs_data, before_abs_data) # x * x
|
|
||||||
before_res1 = prod(square_data, ITR_BEFORE1_J)
|
|
||||||
before_res1 = dsl.vmul(before_res1, before_abs_data)
|
|
||||||
|
|
||||||
before_res2 = prod(square_data, ITR_BEFORE2_J)
|
|
||||||
before_final_res = dsl.vdiv(before_res1, before_res2)
|
|
||||||
|
|
||||||
# 3. compute bessel_j1 for data in (-inf, -8) or (8, inf)
|
|
||||||
div_data = dsl.vdiv(dsl.broadcast(tvm.const(8.0), shape_input), abs_data)
|
|
||||||
square_div_data = dsl.vmul(div_data, div_data)
|
|
||||||
minus_pi_data = dsl.vsub(abs_data, dsl.broadcast(tvm.const(THREE_QUARTERS_PI), shape_input))
|
|
||||||
after_res1 = prod(square_div_data, ITR_AFTER1_J)
|
|
||||||
after_res2 = prod(square_div_data, ITR_AFTER2_J)
|
|
||||||
# 3.1 sqrt(0.636619772/ax)
|
|
||||||
tmp_res1 = dsl.vsqrt(dsl.vdiv(dsl.broadcast(tvm.const(TOW_OVER_PI), shape_input), abs_data),
|
|
||||||
impl_mode='high_precision')
|
|
||||||
# 3.2 cos(xx)*ans1
|
|
||||||
sinv, cosv = angle_trans_cal(abs_data)
|
|
||||||
tmp_res2 = dsl.vmul(cosv, after_res1)
|
|
||||||
# 3.3 z*math.sin(xx)*ans2
|
|
||||||
|
|
||||||
tmp_res3 = dsl.vmul(dsl.vmul(div_data, sinv), after_res2)
|
|
||||||
after_final_res = dsl.vmul(tmp_res1, dsl.vsub(tmp_res2, tmp_res3))
|
|
||||||
|
|
||||||
zero = dsl.broadcast(0.0, shape_input, 'float32')
|
|
||||||
neg_cond = dsl.vcmp(after_final_res, zero, operation='lt', mode='bool')
|
|
||||||
neg_after_res = dsl.vmuls(after_final_res, -1.0)
|
|
||||||
after_final_res = dsl.vsel(neg_cond, neg_after_res, after_final_res)
|
|
||||||
|
|
||||||
# 5. select res
|
|
||||||
# 5.1 compare with limit
|
|
||||||
select_condition = dsl.vcmp(abs_data, broad_const_limit, operation='lt', mode='bool')
|
|
||||||
|
|
||||||
# 5.2 select
|
|
||||||
res = dsl.vsel(select_condition, before_final_res, after_final_res)
|
|
||||||
|
|
||||||
# 6. chose the type of data in end
|
|
||||||
if dtype_input == "float16":
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@register_op_compute("bessel_y1")
|
|
||||||
def bessel_y1_compute(x, y, kernel_name="bessel_y1"):
|
|
||||||
"""bessel_y1_compute"""
|
|
||||||
shape_input = x.shape
|
|
||||||
dtype_input = x.dtype
|
|
||||||
# chose the type of data in begin
|
|
||||||
if (dtype_input in ('float16', 'float32')) and api_check_support("te.lang.cce.vadd", "float32"):
|
|
||||||
x = dsl.cast_to(x, "float32")
|
|
||||||
else:
|
|
||||||
raise RuntimeError("BesselY0 kernel data type [%s] not support." % dtype_input)
|
|
||||||
x = dsl.vabs(x)
|
|
||||||
|
|
||||||
# compute bessel_y1 for data in (0, 8)
|
|
||||||
broad_const_limit = dsl.broadcast(tvm.const(CONST_LIMIT, x.dtype), shape_input)
|
|
||||||
square_data = dsl.vmul(x, x) # y = x * x
|
|
||||||
before_res1 = prod(square_data, ITR_BEFORE1)
|
|
||||||
before_res1 = dsl.vmul(before_res1, x)
|
|
||||||
before_res2 = prod(square_data, ITR_BEFORE2)
|
|
||||||
before_final_res = dsl.vmul(bessel_j1(x), dsl.vlog(x, impl_mode="high_precision"))
|
|
||||||
|
|
||||||
before_final_res = dsl.vsub(before_final_res, dsl.vrec(x, impl_mode="high_precision"))
|
|
||||||
before_final_res = dsl.vmuls(before_final_res, TOW_OVER_PI)
|
|
||||||
before_final_res = dsl.vadd(before_final_res, dsl.vdiv(before_res1, before_res2))
|
|
||||||
|
|
||||||
|
|
||||||
# compute bessel_y1 for data in (8, inf)
|
|
||||||
div_data = dsl.vdiv(broad_const_limit, x) # z = 8.0 / x
|
|
||||||
square_div_data = dsl.vmul(div_data, div_data) # y = z * z
|
|
||||||
minus_pi_data = dsl.vsub(x, dsl.broadcast(tvm.const(THREE_QUARTERS_PI), shape_input)) # xx
|
|
||||||
after_res1 = prod(square_div_data, ITR_AFTER1)
|
|
||||||
after_res2 = prod(square_div_data, ITR_AFTER2)
|
|
||||||
tmp_res1 = dsl.vsqrt(dsl.vdiv(dsl.broadcast(tvm.const(TOW_OVER_PI), shape_input), x), impl_mode="high_precision")
|
|
||||||
sinv, cosv = angle_trans_cal(x)
|
|
||||||
tmp_res2 = dsl.vmul(sinv, after_res1)
|
|
||||||
tmp_res3 = dsl.vmul(dsl.vmul(div_data, cosv), after_res2)
|
|
||||||
after_final_res = dsl.vmul(tmp_res1, dsl.vadd(tmp_res2, tmp_res3))
|
|
||||||
|
|
||||||
# select res
|
|
||||||
# compare with limit
|
|
||||||
select_condition = dsl.vcmp(x, broad_const_limit, operation='lt', mode='bool')
|
|
||||||
|
|
||||||
# select
|
|
||||||
res = dsl.vsel(select_condition, before_final_res, after_final_res)
|
|
||||||
|
|
||||||
# chose the type of data in end
|
|
||||||
if dtype_input == "float16":
|
|
||||||
res = dsl.cast_to(res, "float16")
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME)
|
|
||||||
def bessel_y1(x, y, kernel_name="bessel_y1"):
|
|
||||||
"""
|
|
||||||
Computes the Bessel j1 function of x element-wise.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
x: the dict of input, only support float16, float32
|
|
||||||
|
|
||||||
y : the dict of output
|
|
||||||
|
|
||||||
kernel_name : cce kernel name, default value is "bessel_y1"
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
shape_input = x.get("shape")
|
|
||||||
dtype_input = x.get("dtype")
|
|
||||||
|
|
||||||
para_check.check_shape(shape_input, param_name="x")
|
|
||||||
shape_input, _ = shape_util.refine_shape_axes(shape_input, [])
|
|
||||||
|
|
||||||
check_list = ("float16", "float32")
|
|
||||||
para_check.check_dtype(dtype_input, check_list, param_name="x")
|
|
||||||
|
|
||||||
input_dtype = dtype_input.lower()
|
|
||||||
data = tvm.placeholder(shape_input, dtype=input_dtype, name="data_input")
|
|
||||||
|
|
||||||
res = bessel_y1_compute(data, y, kernel_name)
|
|
||||||
|
|
||||||
with tvm.target.cce():
|
|
||||||
sch = dsl.auto_schedule(res)
|
|
||||||
|
|
||||||
config = {"name": kernel_name,
|
|
||||||
"print_ir": False,
|
|
||||||
"tensor_list": (data, res)}
|
|
||||||
dsl.build(sch, config)
|
|
|
@ -4767,7 +4767,7 @@ class BesselI0(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``CPU`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_i0 = ops.BesselI0()
|
>>> bessel_i0 = ops.BesselI0()
|
||||||
|
@ -4798,7 +4798,7 @@ class BesselI1(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``CPU`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_i1 = ops.BesselI1()
|
>>> bessel_i1 = ops.BesselI1()
|
||||||
|
@ -4837,7 +4837,7 @@ class BesselI0e(Primitive):
|
||||||
TypeError: If dtype of `x` is not float16, float32 or float64.
|
TypeError: If dtype of `x` is not float16, float32 or float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``GPU`` ``CPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_i0e = ops.BesselI0e()
|
>>> bessel_i0e = ops.BesselI0e()
|
||||||
|
@ -4877,7 +4877,7 @@ class BesselI1e(Primitive):
|
||||||
TypeError: If dtype of `x` is not float16, float32 or float64.
|
TypeError: If dtype of `x` is not float16, float32 or float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``CPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_i1e = ops.BesselI1e()
|
>>> bessel_i1e = ops.BesselI1e()
|
||||||
|
@ -4909,7 +4909,7 @@ class BesselK0(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_k0 = ops.BesselK0()
|
>>> bessel_k0 = ops.BesselK0()
|
||||||
|
@ -4940,7 +4940,7 @@ class BesselK1(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_k1 = ops.BesselK1()
|
>>> bessel_k1 = ops.BesselK1()
|
||||||
|
@ -4971,7 +4971,7 @@ class BesselK0e(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_k0e = ops.BesselK0e()
|
>>> bessel_k0e = ops.BesselK0e()
|
||||||
|
@ -5002,7 +5002,7 @@ class BesselK1e(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_k1e = ops.BesselK1e()
|
>>> bessel_k1e = ops.BesselK1e()
|
||||||
|
@ -5033,7 +5033,7 @@ class BesselJ0(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``CPU`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_j0 = ops.BesselJ0()
|
>>> bessel_j0 = ops.BesselJ0()
|
||||||
|
@ -5065,7 +5065,7 @@ class BesselJ1(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
TypeError: If `x` is not a Tensor of float16, float32 or float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``CPU`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_j1 = ops.BesselJ1()
|
>>> bessel_j1 = ops.BesselJ1()
|
||||||
|
@ -5097,7 +5097,7 @@ class BesselY0(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``CPU`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_y0 = ops.BesselY0()
|
>>> bessel_y0 = ops.BesselY0()
|
||||||
|
@ -5129,7 +5129,7 @@ class BesselY1(Primitive):
|
||||||
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
TypeError: If `x` is not a Tensor of float16, float32, float64.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``CPU`` ``GPU``
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> bessel_y1 = ops.BesselY1()
|
>>> bessel_y1 = ops.BesselY1()
|
||||||
|
|
Loading…
Reference in New Issue