Support the Tan operator in Taylor differentiation

This commit is contained in:
chenzhuo 2022-07-04 15:06:38 +08:00
parent 87f2993534
commit 4be29b46b5
3 changed files with 49 additions and 11 deletions

View File

@ -26,7 +26,7 @@ namespace internal {
// White list of ops with taylor rule.
mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(),
prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(),
prim::kPrimExp->name(), prim::kPrimLog->name()};
prim::kPrimTan->name(), prim::kPrimExp->name(), prim::kPrimLog->name()};
// The ops below are excluded when considering taylor rules.
mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(),
prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()};

View File

@ -15,6 +15,7 @@
"""Define the taylor rules of operations."""
from __future__ import absolute_import
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
@ -205,3 +206,15 @@ def taylor_cos(self):
return series_cos
return taylor_fprop_cos
@taylor_fprop_getters.register(P.Tan)
def taylor_tan(self):
    """Higher order derivatives rule definition for `Tan` operation."""
    def taylor_fprop_tan(inputs):
        # tan = sin / cos: reuse the joint sin/cos series and divide them.
        sin_cos = _taylor_fprop_sin_cos(inputs)
        return _taylor_fprop_realdiv(sin_cos[0], sin_cos[1])
    return taylor_fprop_tan

View File

@ -17,7 +17,7 @@ import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore import ops
from mindspore import Tensor
from mindspore.ops.functional import jet, derivative
@ -27,9 +27,9 @@ context.set_context(mode=context.GRAPH_MODE)
class MultipleInputSingleOutputNet(nn.Cell):
def __init__(self):
super(MultipleInputSingleOutputNet, self).__init__()
self.sin = P.Sin()
self.cos = P.Cos()
self.exp = P.Exp()
self.sin = ops.Sin()
self.cos = ops.Cos()
self.exp = ops.Exp()
def construct(self, x, y):
out1 = self.sin(x)
@ -42,8 +42,8 @@ class MultipleInputSingleOutputNet(nn.Cell):
class MultipleInputMultipleOutputNet(nn.Cell):
def __init__(self):
super(MultipleInputMultipleOutputNet, self).__init__()
self.sin = P.Sin()
self.cos = P.Cos()
self.sin = ops.Sin()
self.cos = ops.Cos()
def construct(self, x, y):
out1 = self.sin(x)
@ -54,9 +54,9 @@ class MultipleInputMultipleOutputNet(nn.Cell):
class SingleInputSingleOutputNet(nn.Cell):
def __init__(self):
super(SingleInputSingleOutputNet, self).__init__()
self.sin = P.Sin()
self.cos = P.Cos()
self.exp = P.Exp()
self.sin = ops.Sin()
self.cos = ops.Cos()
self.exp = ops.Exp()
def construct(self, x):
out1 = self.sin(x)
@ -66,10 +66,16 @@ class SingleInputSingleOutputNet(nn.Cell):
return out
def function_graph(x):
    """Return tan(exp(x)); composite target used by the jet tests."""
    exp_x = ops.exp(x)
    return ops.tan(exp_x)
class SingleInputSingleOutputWithScalarNet(nn.Cell):
def __init__(self):
super(SingleInputSingleOutputWithScalarNet, self).__init__()
self.log = P.Log()
self.log = ops.Log()
def construct(self, x):
out1 = self.log(x)
@ -258,3 +264,22 @@ def test_derivative_construct_graph_mode():
assert np.allclose(out_primals[1].asnumpy(), expected_primals_y, atol=1.e-4)
assert np.allclose(out_series[0].asnumpy(), expected_series_x, atol=1.e-4)
assert np.allclose(out_series[1].asnumpy(), expected_series_y, atol=1.e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_function_graph_mode():
    """
    Features: Function jet
    Description: Test function in graph mode.
    Expectation: No exception.
    """
    # Primal point and the Taylor-series seed coefficients fed into jet.
    primals = Tensor([1., 1.])
    series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
    out_primals, out_series = jet(function_graph, primals, series)
    # Reference values for tan(exp(x)) and its higher-order coefficients.
    expected_primals = np.array([-0.450549, -0.450549]).astype(np.float32)
    expected_series = np.array([[3.270079, 3.270079],
                                [-4.739784, -4.739784],
                                [56.995613, 56.995613]]).astype(np.float32)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)