support operator Tan in Taylor differentiation
This commit is contained in:
parent
87f2993534
commit
4be29b46b5
|
@ -26,7 +26,7 @@ namespace internal {
|
||||||
// White list of ops with taylor rule.
|
// White list of ops with taylor rule.
|
||||||
mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(),
|
mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(),
|
||||||
prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(),
|
prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(),
|
||||||
prim::kPrimExp->name(), prim::kPrimLog->name()};
|
prim::kPrimTan->name(), prim::kPrimExp->name(), prim::kPrimLog->name()};
|
||||||
// The ops below are excluded when considering taylor rules.
|
// The ops below are excluded when considering taylor rules.
|
||||||
mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(),
|
mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(),
|
||||||
prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()};
|
prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()};
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
"""Define the taylor rules of operations."""
|
"""Define the taylor rules of operations."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.ops as ops
|
import mindspore.ops as ops
|
||||||
|
@ -205,3 +206,15 @@ def taylor_cos(self):
|
||||||
return series_cos
|
return series_cos
|
||||||
|
|
||||||
return taylor_fprop_cos
|
return taylor_fprop_cos
|
||||||
|
|
||||||
|
|
||||||
|
@taylor_fprop_getters.register(P.Tan)
|
||||||
|
def taylor_tan(self):
|
||||||
|
"""Higher order derivatives rule definition for `Tan` operation."""
|
||||||
|
|
||||||
|
def taylor_fprop_tan(inputs):
|
||||||
|
series_sin_cos = _taylor_fprop_sin_cos(inputs)
|
||||||
|
series_tan = _taylor_fprop_realdiv(series_sin_cos[0], series_sin_cos[1])
|
||||||
|
return series_tan
|
||||||
|
|
||||||
|
return taylor_fprop_tan
|
||||||
|
|
|
@ -17,7 +17,7 @@ import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
from mindspore.ops import operations as P
|
from mindspore import ops
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.ops.functional import jet, derivative
|
from mindspore.ops.functional import jet, derivative
|
||||||
|
|
||||||
|
@ -27,9 +27,9 @@ context.set_context(mode=context.GRAPH_MODE)
|
||||||
class MultipleInputSingleOutputNet(nn.Cell):
|
class MultipleInputSingleOutputNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MultipleInputSingleOutputNet, self).__init__()
|
super(MultipleInputSingleOutputNet, self).__init__()
|
||||||
self.sin = P.Sin()
|
self.sin = ops.Sin()
|
||||||
self.cos = P.Cos()
|
self.cos = ops.Cos()
|
||||||
self.exp = P.Exp()
|
self.exp = ops.Exp()
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
out1 = self.sin(x)
|
out1 = self.sin(x)
|
||||||
|
@ -42,8 +42,8 @@ class MultipleInputSingleOutputNet(nn.Cell):
|
||||||
class MultipleInputMultipleOutputNet(nn.Cell):
|
class MultipleInputMultipleOutputNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MultipleInputMultipleOutputNet, self).__init__()
|
super(MultipleInputMultipleOutputNet, self).__init__()
|
||||||
self.sin = P.Sin()
|
self.sin = ops.Sin()
|
||||||
self.cos = P.Cos()
|
self.cos = ops.Cos()
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
out1 = self.sin(x)
|
out1 = self.sin(x)
|
||||||
|
@ -54,9 +54,9 @@ class MultipleInputMultipleOutputNet(nn.Cell):
|
||||||
class SingleInputSingleOutputNet(nn.Cell):
|
class SingleInputSingleOutputNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SingleInputSingleOutputNet, self).__init__()
|
super(SingleInputSingleOutputNet, self).__init__()
|
||||||
self.sin = P.Sin()
|
self.sin = ops.Sin()
|
||||||
self.cos = P.Cos()
|
self.cos = ops.Cos()
|
||||||
self.exp = P.Exp()
|
self.exp = ops.Exp()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
out1 = self.sin(x)
|
out1 = self.sin(x)
|
||||||
|
@ -66,10 +66,16 @@ class SingleInputSingleOutputNet(nn.Cell):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def function_graph(x):
|
||||||
|
y = ops.exp(x)
|
||||||
|
z = ops.tan(y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
class SingleInputSingleOutputWithScalarNet(nn.Cell):
|
class SingleInputSingleOutputWithScalarNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SingleInputSingleOutputWithScalarNet, self).__init__()
|
super(SingleInputSingleOutputWithScalarNet, self).__init__()
|
||||||
self.log = P.Log()
|
self.log = ops.Log()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
out1 = self.log(x)
|
out1 = self.log(x)
|
||||||
|
@ -258,3 +264,22 @@ def test_derivative_construct_graph_mode():
|
||||||
assert np.allclose(out_primals[1].asnumpy(), expected_primals_y, atol=1.e-4)
|
assert np.allclose(out_primals[1].asnumpy(), expected_primals_y, atol=1.e-4)
|
||||||
assert np.allclose(out_series[0].asnumpy(), expected_series_x, atol=1.e-4)
|
assert np.allclose(out_series[0].asnumpy(), expected_series_x, atol=1.e-4)
|
||||||
assert np.allclose(out_series[1].asnumpy(), expected_series_y, atol=1.e-4)
|
assert np.allclose(out_series[1].asnumpy(), expected_series_y, atol=1.e-4)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_jet_function_graph_mode():
|
||||||
|
"""
|
||||||
|
Features: Function jet
|
||||||
|
Description: Test function in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
primals = Tensor([1., 1.])
|
||||||
|
series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
|
||||||
|
out_primals, out_series = jet(function_graph, primals, series)
|
||||||
|
expected_primals = np.array([-0.450549, -0.450549]).astype(np.float32)
|
||||||
|
expected_series = np.array([[3.270079, 3.270079], [-4.739784, -4.739784],
|
||||||
|
[56.995613, 56.995613]]).astype(np.float32)
|
||||||
|
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
|
||||||
|
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
|
||||||
|
|
Loading…
Reference in New Issue