!14 add adapter for ge atan2
Merge pull request !14 from zhaozhenlong/op/atan2
This commit is contained in:
commit
69d5403319
|
@ -182,6 +182,7 @@ const char kNameDiag[] = "Diag";
|
|||
const char kNameDiagPart[] = "DiagPart";
|
||||
const char kNameSpaceToBatch[] = "SpaceToBatch";
|
||||
const char kNameBatchToSpace[] = "BatchToSpace";
|
||||
const char kNameAtan2[] = "Atan2";
|
||||
|
||||
// -----------------OpAdapter initialization--------------
|
||||
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
|
||||
|
@ -365,7 +366,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{string(kNameDiag), ADPT_DESC(Diag)},
|
||||
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
|
||||
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
|
||||
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}};
|
||||
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
|
||||
{string(kNameAtan2), ADPT_DESC(Atan2)}};
|
||||
#ifdef ENABLE_GE
|
||||
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
||||
#endif
|
||||
|
|
|
@ -1196,6 +1196,12 @@ ATTR_MAP(BatchToSpaceD) = {
|
|||
{"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())},
|
||||
{"crops", ATTR_DESC(crops, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
OUTPUT_MAP(BatchToSpaceD) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
// Atan2
|
||||
INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
|
||||
ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
#ifdef ENABLE_GE
|
||||
// Print
|
||||
INPUT_MAP(Print) = EMPTY_INPUT_MAP;
|
||||
|
|
|
@ -443,6 +443,8 @@ DECLARE_OP_ADAPTER(SpaceToBatchD)
|
|||
DECLARE_OP_USE_OUTPUT(SpaceToBatchD)
|
||||
DECLARE_OP_ADAPTER(BatchToSpaceD)
|
||||
DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
|
||||
DECLARE_OP_ADAPTER(Atan2)
|
||||
DECLARE_OP_USE_OUTPUT(Atan2)
|
||||
#ifdef ENABLE_GE
|
||||
DECLARE_OP_ADAPTER(Print)
|
||||
DECLARE_OP_USE_DYN_INPUT(Print)
|
||||
|
|
|
@ -738,3 +738,16 @@ def get_bprop_round(self):
|
|||
def bprop(x, out, dout):
|
||||
return (zeros_like(x),)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Atan2)
|
||||
def get_bprop_atan2(self):
|
||||
"""Generate bprop for Atan2"""
|
||||
|
||||
square = P.Square()
|
||||
def bprop(x, y, out, dout):
|
||||
tmp = dout / (square(x) + square(y))
|
||||
dx = tmp * y
|
||||
dy = tmp * (-x)
|
||||
return (dx, dy)
|
||||
return bprop
|
||||
|
|
|
@ -37,7 +37,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, ScalarSummary,
|
|||
TensorSummary, Print)
|
||||
from .control_ops import ControlDepend, GeSwitch, Merge
|
||||
from .inner_ops import ScalarCast
|
||||
from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, BatchMatMul,
|
||||
from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
|
||||
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
|
||||
Cos, Div, Equal, EqualCount, Exp, Floor, FloorDiv,
|
||||
Greater, GreaterEqual, Less, LessEqual, Log, LogicalAnd,
|
||||
|
@ -226,7 +226,8 @@ __all__ = [
|
|||
"Round",
|
||||
"ApplyFtrl",
|
||||
"SpaceToBatch",
|
||||
"BatchToSpace"
|
||||
"BatchToSpace",
|
||||
"Atan2",
|
||||
]
|
||||
|
||||
__all__.sort()
|
||||
|
|
|
@ -1858,3 +1858,26 @@ class Round(PrimitiveWithInfer):
|
|||
validator.check_subclass("x_dtype", x_type, mstype.tensor)
|
||||
validator.check_typename('x_dtype', x_type, mstype.number_type)
|
||||
return x_type
|
||||
|
||||
|
||||
class Atan2(_MathBinaryOp):
|
||||
r"""
|
||||
Returns arctangent of input_x/input_y element-wise.
|
||||
|
||||
It returns :math:`\theta\ \in\ (-\frac{\pi}{2}, \frac{\pi}{2})`
|
||||
such that :math:`x = r*\sin(\theta), y = r*\cos(\theta)`, where :math:`r = \sqrt{x^2 + y^2}`.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor.
|
||||
- **input_y** (Tensor) - The input tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[0, 1]]), mstype.float32)
|
||||
>>> input_y = Tensor(np.array([[1, 1]]), mstype.float32)
|
||||
>>> atan2 = Atan2()
|
||||
>>> atan2(input_x, input_y)
|
||||
[[0. 0.7853982]]
|
||||
"""
|
||||
|
|
|
@ -481,7 +481,12 @@ test_case_math_ops = [
|
|||
('Round', {
|
||||
'block': P.Round(),
|
||||
'desc_inputs': [[3]],
|
||||
'desc_bprop': [[3]]})
|
||||
'desc_bprop': [[3]]}),
|
||||
('Atan2', {
|
||||
'block': P.Atan2(),
|
||||
'desc_inputs': [Tensor(np.array([0, 1]).astype(np.float32)),
|
||||
Tensor(np.array([1, 1]).astype(np.float32))],
|
||||
'desc_bprop': [[2]]})
|
||||
]
|
||||
|
||||
test_case_nn_ops = [
|
||||
|
|
Loading…
Reference in New Issue