support operator Tan in Taylor differentiation

This commit is contained in:
chenzhuo 2022-07-04 15:06:38 +08:00
parent 87f2993534
commit 4be29b46b5
3 changed files with 49 additions and 11 deletions

View File

@ -26,7 +26,7 @@ namespace internal {
// White list of ops with taylor rule. // White list of ops with taylor rule.
mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(), mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(),
prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(), prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(),
prim::kPrimExp->name(), prim::kPrimLog->name()}; prim::kPrimTan->name(), prim::kPrimExp->name(), prim::kPrimLog->name()};
// The ops below are excluded when considering taylor rules. // The ops below are excluded when considering taylor rules.
mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(), mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(),
prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()}; prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()};

View File

@ -15,6 +15,7 @@
"""Define the taylor rules of operations.""" """Define the taylor rules of operations."""
from __future__ import absolute_import
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops as ops import mindspore.ops as ops
@ -205,3 +206,15 @@ def taylor_cos(self):
return series_cos return series_cos
return taylor_fprop_cos return taylor_fprop_cos
@taylor_fprop_getters.register(P.Tan)
def taylor_tan(self):
"""Higher order derivatives rule definition for `Tan` operation."""
def taylor_fprop_tan(inputs):
series_sin_cos = _taylor_fprop_sin_cos(inputs)
series_tan = _taylor_fprop_realdiv(series_sin_cos[0], series_sin_cos[1])
return series_tan
return taylor_fprop_tan

View File

@ -17,7 +17,7 @@ import pytest
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore.ops import operations as P from mindspore import ops
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops.functional import jet, derivative from mindspore.ops.functional import jet, derivative
@ -27,9 +27,9 @@ context.set_context(mode=context.GRAPH_MODE)
class MultipleInputSingleOutputNet(nn.Cell): class MultipleInputSingleOutputNet(nn.Cell):
def __init__(self): def __init__(self):
super(MultipleInputSingleOutputNet, self).__init__() super(MultipleInputSingleOutputNet, self).__init__()
self.sin = P.Sin() self.sin = ops.Sin()
self.cos = P.Cos() self.cos = ops.Cos()
self.exp = P.Exp() self.exp = ops.Exp()
def construct(self, x, y): def construct(self, x, y):
out1 = self.sin(x) out1 = self.sin(x)
@ -42,8 +42,8 @@ class MultipleInputSingleOutputNet(nn.Cell):
class MultipleInputMultipleOutputNet(nn.Cell): class MultipleInputMultipleOutputNet(nn.Cell):
def __init__(self): def __init__(self):
super(MultipleInputMultipleOutputNet, self).__init__() super(MultipleInputMultipleOutputNet, self).__init__()
self.sin = P.Sin() self.sin = ops.Sin()
self.cos = P.Cos() self.cos = ops.Cos()
def construct(self, x, y): def construct(self, x, y):
out1 = self.sin(x) out1 = self.sin(x)
@ -54,9 +54,9 @@ class MultipleInputMultipleOutputNet(nn.Cell):
class SingleInputSingleOutputNet(nn.Cell): class SingleInputSingleOutputNet(nn.Cell):
def __init__(self): def __init__(self):
super(SingleInputSingleOutputNet, self).__init__() super(SingleInputSingleOutputNet, self).__init__()
self.sin = P.Sin() self.sin = ops.Sin()
self.cos = P.Cos() self.cos = ops.Cos()
self.exp = P.Exp() self.exp = ops.Exp()
def construct(self, x): def construct(self, x):
out1 = self.sin(x) out1 = self.sin(x)
@ -66,10 +66,16 @@ class SingleInputSingleOutputNet(nn.Cell):
return out return out
def function_graph(x):
y = ops.exp(x)
z = ops.tan(y)
return z
class SingleInputSingleOutputWithScalarNet(nn.Cell): class SingleInputSingleOutputWithScalarNet(nn.Cell):
def __init__(self): def __init__(self):
super(SingleInputSingleOutputWithScalarNet, self).__init__() super(SingleInputSingleOutputWithScalarNet, self).__init__()
self.log = P.Log() self.log = ops.Log()
def construct(self, x): def construct(self, x):
out1 = self.log(x) out1 = self.log(x)
@ -258,3 +264,22 @@ def test_derivative_construct_graph_mode():
assert np.allclose(out_primals[1].asnumpy(), expected_primals_y, atol=1.e-4) assert np.allclose(out_primals[1].asnumpy(), expected_primals_y, atol=1.e-4)
assert np.allclose(out_series[0].asnumpy(), expected_series_x, atol=1.e-4) assert np.allclose(out_series[0].asnumpy(), expected_series_x, atol=1.e-4)
assert np.allclose(out_series[1].asnumpy(), expected_series_y, atol=1.e-4) assert np.allclose(out_series[1].asnumpy(), expected_series_y, atol=1.e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_function_graph_mode():
"""
Features: Function jet
Description: Test function in graph mode.
Expectation: No exception.
"""
primals = Tensor([1., 1.])
series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
out_primals, out_series = jet(function_graph, primals, series)
expected_primals = np.array([-0.450549, -0.450549]).astype(np.float32)
expected_series = np.array([[3.270079, 3.270079], [-4.739784, -4.739784],
[56.995613, 56.995613]]).astype(np.float32)
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)