!1518 add register info for Atan and AtanGrad and Atanh

Merge pull request !1518 from zhouneng/add_vm_support_Atan_AtanGrad_Atanh
mindspore-ci-bot 2020-05-30 16:48:09 +08:00 committed by Gitee
commit efaaf58074
9 changed files with 241 additions and 1 deletion

View File

@@ -912,6 +912,17 @@ def get_bprop_bessel_i0e(self):
    return bprop


@bprop_getters.register(P.Atan)
def get_bprop_atan(self):
    """Grad definition for `Atan` operation."""
    input_grad = G.AtanGrad()

    def bprop(x, out, dout):
        dx = input_grad(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.BesselI1e)
def get_bprop_bessel_i1e(self):
    """Generate bprop for BesselI1e"""

@@ -934,3 +945,16 @@ def get_bprop_bessel_i1e(self):
        dx = select(x_is_valid, tmp, 0.5 + zeros)
        return (dx,)

    return bprop


@bprop_getters.register(P.Atanh)
def get_bprop_atanh(self):
    """Grad definition for `Atanh` operation."""
    power = P.Pow()
    div = P.Div()

    def bprop(x, out, dout):
        tmp = 1 - power(x, 2)
        dx = div(1, tmp) * dout
        return (dx,)

    return bprop
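The two bprops encode the textbook derivatives d/dx atan(x) = 1/(1+x²) and d/dx atanh(x) = 1/(1−x²). A minimal NumPy cross-check of those formulas (illustrative only, not part of the patch):

```python
import numpy as np

def numeric_grad(f, x, eps=1e-6):
    """Central-difference derivative, used only to sanity-check the formulas."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = np.array([0.1, 0.3, 0.5])

# Analytic derivatives as encoded in the bprops above (with dout = 1).
atan_dx = 1.0 / (1.0 + x ** 2)    # what G.AtanGrad(x, dout) is expected to return
atanh_dx = 1.0 / (1.0 - x ** 2)   # matches div(1, 1 - power(x, 2)) * dout

assert np.allclose(atan_dx, numeric_grad(np.arctan, x))
assert np.allclose(atanh_dx, numeric_grad(np.arctanh, x))
```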

View File

@@ -221,3 +221,6 @@ from .asin import _asin_tbe
from .asin_grad import _asin_grad_tbe
from .asinh import _asinh_tbe
from .asinh_grad import _asinh_grad_tbe
from .atan import _atan_tbe
from .atan_grad import _atan_grad_tbe
from .atanh import _atanh_tbe

View File

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Atan op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
atan_op_info = TBERegOp("Atan") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("atan.so") \
    .compute_cost(10) \
    .kernel_name("atan") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(atan_op_info)
def _atan_tbe():
    """Atan TBE register"""
    return
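With this info registered, `P.Atan` dispatches to the TBE kernel like any other elementwise primitive. A minimal usage sketch, assuming a configured Ascend context (not part of the patch):

```python
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

# op_pattern("formatAgnostic") plus the F16/F32 entries above mean the kernel
# accepts float16/float32 inputs independent of data layout.
atan = P.Atan()
output = atan(Tensor(np.array([0.0, 1.0]), mindspore.float32))  # ~[0.0, 0.7853982]
```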

View File

@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AtanGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
atan_grad_op_info = TBERegOp("AtanGrad") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("atan_grad.so") \
    .compute_cost(10) \
    .kernel_name("atan_grad") \
    .partial_flag(True) \
    .input(0, "y", False, "required", "all") \
    .input(1, "dy", False, "required", "all") \
    .output(0, "z", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(atan_grad_op_info)
def _atan_grad_tbe():
    """AtanGrad TBE register"""
    return

View File

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Atanh op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
atanh_op_info = TBERegOp("Atanh") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("atanh.so") \
    .compute_cost(10) \
    .kernel_name("atanh") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(atanh_op_info)
def _atanh_tbe():
    """Atanh TBE register"""
    return

View File

@@ -50,7 +50,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum,
                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
-                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll)
+                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh)
from .random_ops import (RandomChoiceWithMask)
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                     BiasAdd, Conv2D,

@@ -277,6 +277,8 @@ __all__ = [
    "BitwiseXor",
    "BesselI0e",
    "BesselI1e",
    "Atan",
    "Atanh"
]

__all__.extend(_quant_ops.__all__)

View File

@@ -1151,3 +1151,25 @@ class RefToEmbed(Primitive):
    @prim_attr_register
    def __init__(self):
        pass


class AtanGrad(PrimitiveWithInfer):
    """
    Computes the gradient of Atan element-wise.

    Returns:
        Tensor, has the same type as the input.
    """

    @prim_attr_register
    def __init__(self):
        """init AtanGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x
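As wired in the bprop earlier in this patch, `AtanGrad` receives the forward input and the incoming gradient and is expected to return dout / (1 + x²). A sketch of calling it directly, assuming a configured device context (the same pattern the test table below uses):

```python
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

atan_grad = G.AtanGrad()
x = Tensor(np.array([0.0, 1.0]), mindspore.float32)
dout = Tensor(np.array([1.0, 1.0]), mindspore.float32)
dx = atan_grad(x, dout)   # expected ~[1.0, 0.5], i.e. dout / (1 + x**2)
```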

View File

@@ -2207,6 +2207,66 @@ class Round(PrimitiveWithInfer):
        return x_type


class Atan(PrimitiveWithInfer):
    """
    Computes the trigonometric inverse tangent of x element-wise.

    Inputs:
        - **input_x** (Tensor): The input tensor.

    Outputs:
        A Tensor, has the same shape and type as `input_x`.

    Examples:
        >>> input_x = Tensor(np.array([1.047, 0.785]), mindspore.float32)
        >>> tan = P.Tan()
        >>> output_y = tan(input_x)
        >>> atan = P.Atan()
        >>> atan(output_y)
        [1.047, 0.7850001]
    """

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
        return x_type


class Atanh(PrimitiveWithInfer):
    """
    Computes inverse hyperbolic tangent of x element-wise.

    Inputs:
        - **input_x** (Tensor): The input tensor.

    Outputs:
        A Tensor, has the same shape and type as `input_x`.

    Examples:
        >>> input_x = Tensor(np.array([1.047, 0.785]), mindspore.float32)
        >>> atanh = P.Atanh()
        >>> atanh(input_x)
        [1.8869909, 1.058268]
    """

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
        return x_type


class Atan2(_MathBinaryOp):
    r"""
    Returns arctangent of input_x/input_y element-wise.
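The docstring examples can be sanity-checked against NumPy. Note that real-valued atanh is only defined on (-1, 1), so the first Atanh output value (for input 1.047) reflects backend-specific behaviour outside that domain:

```python
import numpy as np

x = np.array([1.047, 0.785], dtype=np.float32)

# Atan example: atan(tan(x)) recovers x for |x| < pi/2.
print(np.arctan(np.tan(x)))      # ~[1.047  0.785]

# Atanh example: only 0.785 lies inside atanh's real domain (-1, 1);
# np.arctanh(1.047) would produce nan.
print(np.arctanh(x[1:]))         # ~[1.0583], matching the second value
```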

View File

@@ -680,6 +680,18 @@ test_case_math_ops = [
        'block': P.BesselI1e(),
        'desc_inputs': [[2, 3]],
        'desc_bprop': [[2, 3]]}),
    ('Atan', {
        'block': P.Atan(),
        'desc_inputs': [[2, 3]],
        'desc_bprop': [[2, 3]]}),
    ('AtanGrad', {
        'block': G.AtanGrad(),
        'desc_inputs': [[2, 3], [2, 3]],
        'skip': ['backward']}),
    ('Atanh', {
        'block': P.Atanh(),
        'desc_inputs': [[2, 3]],
        'desc_bprop': [[2, 3]]}),
]
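In this test table, a bare shape such as [2, 3] stands for a randomly generated tensor of that shape. Conceptually, each entry exercises something like the following hand-written equivalent (a sketch of what the harness does, not the harness itself):

```python
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.random.rand(2, 3).astype(np.float32))
out = P.Atan()(x)                        # forward check driven by 'block'
assert out.asnumpy().shape == (2, 3)     # infer_shape passes the shape through
```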
test_case_nn_ops = [