support operator Tan in Taylor differentiation
This commit is contained in:
parent
87f2993534
commit
4be29b46b5
|
@ -26,7 +26,7 @@ namespace internal {
|
|||
// White list of ops with taylor rule.
|
||||
mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(),
|
||||
prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(),
|
||||
prim::kPrimExp->name(), prim::kPrimLog->name()};
|
||||
prim::kPrimTan->name(), prim::kPrimExp->name(), prim::kPrimLog->name()};
|
||||
// The ops below are excluded when considering taylor rules.
|
||||
mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(),
|
||||
prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()};
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Define the taylor rules of operations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
@ -205,3 +206,15 @@ def taylor_cos(self):
|
|||
return series_cos
|
||||
|
||||
return taylor_fprop_cos
|
||||
|
||||
|
||||
@taylor_fprop_getters.register(P.Tan)
|
||||
def taylor_tan(self):
|
||||
"""Higher order derivatives rule definition for `Tan` operation."""
|
||||
|
||||
def taylor_fprop_tan(inputs):
|
||||
series_sin_cos = _taylor_fprop_sin_cos(inputs)
|
||||
series_tan = _taylor_fprop_realdiv(series_sin_cos[0], series_sin_cos[1])
|
||||
return series_tan
|
||||
|
||||
return taylor_fprop_tan
|
||||
|
|
|
@ -17,7 +17,7 @@ import pytest
|
|||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import ops
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.functional import jet, derivative
|
||||
|
||||
|
@ -27,9 +27,9 @@ context.set_context(mode=context.GRAPH_MODE)
|
|||
class MultipleInputSingleOutputNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MultipleInputSingleOutputNet, self).__init__()
|
||||
self.sin = P.Sin()
|
||||
self.cos = P.Cos()
|
||||
self.exp = P.Exp()
|
||||
self.sin = ops.Sin()
|
||||
self.cos = ops.Cos()
|
||||
self.exp = ops.Exp()
|
||||
|
||||
def construct(self, x, y):
|
||||
out1 = self.sin(x)
|
||||
|
@ -42,8 +42,8 @@ class MultipleInputSingleOutputNet(nn.Cell):
|
|||
class MultipleInputMultipleOutputNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MultipleInputMultipleOutputNet, self).__init__()
|
||||
self.sin = P.Sin()
|
||||
self.cos = P.Cos()
|
||||
self.sin = ops.Sin()
|
||||
self.cos = ops.Cos()
|
||||
|
||||
def construct(self, x, y):
|
||||
out1 = self.sin(x)
|
||||
|
@ -54,9 +54,9 @@ class MultipleInputMultipleOutputNet(nn.Cell):
|
|||
class SingleInputSingleOutputNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SingleInputSingleOutputNet, self).__init__()
|
||||
self.sin = P.Sin()
|
||||
self.cos = P.Cos()
|
||||
self.exp = P.Exp()
|
||||
self.sin = ops.Sin()
|
||||
self.cos = ops.Cos()
|
||||
self.exp = ops.Exp()
|
||||
|
||||
def construct(self, x):
|
||||
out1 = self.sin(x)
|
||||
|
@ -66,10 +66,16 @@ class SingleInputSingleOutputNet(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
def function_graph(x):
|
||||
y = ops.exp(x)
|
||||
z = ops.tan(y)
|
||||
return z
|
||||
|
||||
|
||||
class SingleInputSingleOutputWithScalarNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SingleInputSingleOutputWithScalarNet, self).__init__()
|
||||
self.log = P.Log()
|
||||
self.log = ops.Log()
|
||||
|
||||
def construct(self, x):
|
||||
out1 = self.log(x)
|
||||
|
@ -258,3 +264,22 @@ def test_derivative_construct_graph_mode():
|
|||
assert np.allclose(out_primals[1].asnumpy(), expected_primals_y, atol=1.e-4)
|
||||
assert np.allclose(out_series[0].asnumpy(), expected_series_x, atol=1.e-4)
|
||||
assert np.allclose(out_series[1].asnumpy(), expected_series_y, atol=1.e-4)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jet_function_graph_mode():
|
||||
"""
|
||||
Features: Function jet
|
||||
Description: Test function in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
primals = Tensor([1., 1.])
|
||||
series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
|
||||
out_primals, out_series = jet(function_graph, primals, series)
|
||||
expected_primals = np.array([-0.450549, -0.450549]).astype(np.float32)
|
||||
expected_series = np.array([[3.270079, 3.270079], [-4.739784, -4.739784],
|
||||
[56.995613, 56.995613]]).astype(np.float32)
|
||||
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
|
||||
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
|
||||
|
|
Loading…
Reference in New Issue