forked from mindspore-Ecosystem/mindspore
!1518 add register info for Atan and AtanGrad and Atanh
Merge pull request !1518 from zhouneng/add_vm_support_Atan_AtanGrad_Atanh
This commit is contained in:
commit
efaaf58074
|
@ -912,6 +912,17 @@ def get_bprop_bessel_i0e(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Atan)
|
||||
def get_bprop_atan(self):
|
||||
"""Grad definition for `Atan` operation."""
|
||||
input_grad = G.AtanGrad()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = input_grad(x, dout)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.BesselI1e)
|
||||
def get_bprop_bessel_i1e(self):
|
||||
"""Generate bprop for BesselI1e"""
|
||||
|
@ -934,3 +945,16 @@ def get_bprop_bessel_i1e(self):
|
|||
dx = select(x_is_valid, tmp, 0.5 + zeros)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Atanh)
|
||||
def get_bprop_atanh(self):
|
||||
"""Grad definition for `Atanh` operation."""
|
||||
power = P.Pow()
|
||||
div = P.Div()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
tmp = 1 - power(x, 2)
|
||||
dx = div(1, tmp) * dout
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
|
|
@ -221,3 +221,6 @@ from .asin import _asin_tbe
|
|||
from .asin_grad import _asin_grad_tbe
|
||||
from .asinh import _asinh_tbe
|
||||
from .asinh_grad import _asinh_grad_tbe
|
||||
from .atan import _atan_tbe
|
||||
from .atan_grad import _atan_grad_tbe
|
||||
from .atanh import _atanh_tbe
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Atan op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
atan_op_info = TBERegOp("Atan") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("atan.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("atan") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(atan_op_info)
|
||||
def _atan_tbe():
|
||||
"""Atan TBE register"""
|
||||
return
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""AtanGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
atan_grad_op_info = TBERegOp("AtanGrad") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("atan_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("atan_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "y", False, "required", "all") \
|
||||
.input(1, "dy", False, "required", "all") \
|
||||
.output(0, "z", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracNZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracNZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(atan_grad_op_info)
|
||||
def _atan_grad_tbe():
|
||||
"""AtanGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Atanh op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
atanh_op_info = TBERegOp("Atanh") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("atanh.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("atanh") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(atanh_op_info)
|
||||
def _atanh_tbe():
|
||||
"""Atanh TBE register"""
|
||||
return
|
|
@ -50,7 +50,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2
|
|||
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
||||
Reciprocal, CumSum,
|
||||
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
|
||||
Square, Sub, TensorAdd, Sign, Round, SquareSumAll)
|
||||
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh)
|
||||
from .random_ops import (RandomChoiceWithMask)
|
||||
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
||||
BiasAdd, Conv2D,
|
||||
|
@ -277,6 +277,8 @@ __all__ = [
|
|||
"BitwiseXor",
|
||||
"BesselI0e",
|
||||
"BesselI1e",
|
||||
"Atan",
|
||||
"Atanh"
|
||||
]
|
||||
|
||||
__all__.extend(_quant_ops.__all__)
|
||||
|
|
|
@ -1151,3 +1151,25 @@ class RefToEmbed(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class AtanGrad(PrimitiveWithInfer):
|
||||
"""
|
||||
Computes AtanGrad of input element-wise.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same type as input.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init AtanGrad"""
|
||||
|
||||
def infer_shape(self, x, dout):
|
||||
validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
|
||||
return x
|
||||
|
||||
def infer_dtype(self, x, dout):
|
||||
args = {"x": x, "dout": dout}
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return x
|
||||
|
|
|
@ -2207,6 +2207,66 @@ class Round(PrimitiveWithInfer):
|
|||
return x_type
|
||||
|
||||
|
||||
class Atan(PrimitiveWithInfer):
|
||||
"""
|
||||
Computes the trignometric inverse tangent of x element-wise.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor): The input tensor.
|
||||
|
||||
Outputs:
|
||||
A Tensor. Has the same type as x.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([1.047, 0.785]), mindspore.float32)
|
||||
>>> tan = P.Tan()
|
||||
>>> output_y = tan(input_x)
|
||||
>>> atan = P.Atan()
|
||||
>>> atan(output_y)
|
||||
[[1.047, 07850001]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
|
||||
return x_type
|
||||
|
||||
|
||||
class Atanh(PrimitiveWithInfer):
|
||||
"""
|
||||
Computes inverse hyperbolic tangent of x element-wise.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor): The input tensor.
|
||||
|
||||
Outputs:
|
||||
A Tensor. Has the same type as x.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([1.047, 0.785]), mindspore.float32)
|
||||
>>> atanh = P.Atanh()
|
||||
>>> atanh(input_x)
|
||||
[[1.8869909 1.058268]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
|
||||
return x_type
|
||||
|
||||
|
||||
class Atan2(_MathBinaryOp):
|
||||
r"""
|
||||
Returns arctangent of input_x/input_y element-wise.
|
||||
|
|
|
@ -680,6 +680,18 @@ test_case_math_ops = [
|
|||
'block': P.BesselI1e(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
'desc_bprop': [[2, 3]]}),
|
||||
('Atan', {
|
||||
'block': P.Atan(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
'desc_bprop': [[2, 3]]}),
|
||||
('AtanGrad', {
|
||||
'block': G.AtanGrad(),
|
||||
'desc_inputs': [[2, 3], [2, 3]],
|
||||
'skip': ['backward']}),
|
||||
('Atanh', {
|
||||
'block': P.Atanh(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
'desc_bprop': [[2, 3]]}),
|
||||
]
|
||||
|
||||
test_case_nn_ops = [
|
||||
|
|
Loading…
Reference in New Issue