!14 add adapter for ge atan2

Merge pull request !14 from zhaozhenlong/op/atan2
This commit is contained in:
mindspore-ci-bot 2020-04-01 18:24:55 +08:00 committed by Gitee
commit 69d5403319
7 changed files with 56 additions and 4 deletions

View File

@ -182,6 +182,7 @@ const char kNameDiag[] = "Diag";
const char kNameDiagPart[] = "DiagPart";
const char kNameSpaceToBatch[] = "SpaceToBatch";
const char kNameBatchToSpace[] = "BatchToSpace";
const char kNameAtan2[] = "Atan2";
// -----------------OpAdapter initialization--------------
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
@ -365,7 +366,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameDiag), ADPT_DESC(Diag)},
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}};
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
{string(kNameAtan2), ADPT_DESC(Atan2)}};
#ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
#endif

View File

@ -1196,6 +1196,12 @@ ATTR_MAP(BatchToSpaceD) = {
{"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())},
{"crops", ATTR_DESC(crops, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(BatchToSpaceD) = {{0, OUTPUT_DESC(y)}};
// Atan2
INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}};
#ifdef ENABLE_GE
// Print
INPUT_MAP(Print) = EMPTY_INPUT_MAP;

View File

@ -443,6 +443,8 @@ DECLARE_OP_ADAPTER(SpaceToBatchD)
DECLARE_OP_USE_OUTPUT(SpaceToBatchD)
DECLARE_OP_ADAPTER(BatchToSpaceD)
DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
DECLARE_OP_ADAPTER(Atan2)
DECLARE_OP_USE_OUTPUT(Atan2)
#ifdef ENABLE_GE
DECLARE_OP_ADAPTER(Print)
DECLARE_OP_USE_DYN_INPUT(Print)

View File

@ -738,3 +738,16 @@ def get_bprop_round(self):
def bprop(x, out, dout):
return (zeros_like(x),)
return bprop
@bprop_getters.register(P.Atan2)
def get_bprop_atan2(self):
"""Generate bprop for Atan2"""
square = P.Square()
def bprop(x, y, out, dout):
tmp = dout / (square(x) + square(y))
dx = tmp * y
dy = tmp * (-x)
return (dx, dy)
return bprop

View File

@ -37,7 +37,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, ScalarSummary,
TensorSummary, Print)
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, BatchMatMul,
from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
Cos, Div, Equal, EqualCount, Exp, Floor, FloorDiv,
Greater, GreaterEqual, Less, LessEqual, Log, LogicalAnd,
@ -226,7 +226,8 @@ __all__ = [
"Round",
"ApplyFtrl",
"SpaceToBatch",
"BatchToSpace"
"BatchToSpace",
"Atan2",
]
__all__.sort()

View File

@ -1858,3 +1858,26 @@ class Round(PrimitiveWithInfer):
validator.check_subclass("x_dtype", x_type, mstype.tensor)
validator.check_typename('x_dtype', x_type, mstype.number_type)
return x_type
class Atan2(_MathBinaryOp):
r"""
Returns arctangent of input_x/input_y element-wise.
It returns :math:`\theta\ \in\ (-\frac{\pi}{2}, \frac{\pi}{2})`
such that :math:`x = r*\sin(\theta), y = r*\cos(\theta)`, where :math:`r = \sqrt{x^2 + y^2}`.
Inputs:
- **input_x** (Tensor) - The input tensor.
- **input_y** (Tensor) - The input tensor.
Outputs:
Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'.
Examples:
>>> input_x = Tensor(np.array([[0, 1]]), mstype.float32)
>>> input_y = Tensor(np.array([[1, 1]]), mstype.float32)
>>> atan2 = Atan2()
>>> atan2(input_x, input_y)
[[0. 0.7853982]]
"""

View File

@ -481,7 +481,12 @@ test_case_math_ops = [
('Round', {
'block': P.Round(),
'desc_inputs': [[3]],
'desc_bprop': [[3]]})
'desc_bprop': [[3]]}),
('Atan2', {
'block': P.Atan2(),
'desc_inputs': [Tensor(np.array([0, 1]).astype(np.float32)),
Tensor(np.array([1, 1]).astype(np.float32))],
'desc_bprop': [[2]]})
]
test_case_nn_ops = [