forked from mindspore-Ecosystem/mindspore
!31312 [MS]Simplify taylor rule and fix code warning
Merge pull request !31312 from chenzhuo/taylor_0314
This commit is contained in:
commit
32e0f0b415
|
@ -592,6 +592,7 @@ class _PynativeExecutor:
|
|||
return self._executor.check_run(grad, obj, *args, *(kwargs.values()))
|
||||
|
||||
def set_grad_position(self, grad, grad_position):
    """Set position of grad.

    Forwards both arguments unchanged to the wrapped executor object.

    Args:
        grad: Grad operation object handed to the underlying executor.
        grad_position: Position value handed to the underlying executor.

    Returns:
        Whatever the underlying executor's `set_grad_position` returns.
    """
    return self._executor.set_grad_position(grad, grad_position)
|
||||
|
||||
def grad(self, grad, obj, weights, grad_position, *args, **kwargs):
|
||||
|
|
|
@ -22,12 +22,19 @@ from .. import operations as P
|
|||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from .grad_base import taylor_fprop_getters
|
||||
|
||||
# Operator instances created once at module level so that the taylor rules
# below share them instead of building a new primitive on every call.
ones_op = P.Ones()
concat_op = P.Concat()
mul_op = P.Mul()
div_op = P.Div()
sin_op = P.Sin()
cos_op = P.Cos()
exp_op = P.Exp()
log_op = P.Log()
|
||||
|
||||
|
||||
def _factorial(order):
|
||||
"""Return [0!, 1!, 2!,..., order!]."""
|
||||
range_op = nn.Range(1, order + 1)
|
||||
ones_op = P.Ones()
|
||||
concat_op = P.Concat()
|
||||
factorial_zero = ones_op(1, ms.float32)
|
||||
factorial_positive = range_op().astype(ms.float32)
|
||||
for i in range(1, order):
|
||||
|
@ -44,6 +51,7 @@ def taylor_add_or_sub(self):
|
|||
prim = Primitive(self)
|
||||
else:
|
||||
prim = self
|
||||
|
||||
def taylor_fprop_add_or_sub(input_x, input_y):
|
||||
series = prim(input_x, input_y)
|
||||
return series
|
||||
|
@ -51,43 +59,53 @@ def taylor_add_or_sub(self):
|
|||
return taylor_fprop_add_or_sub
|
||||
|
||||
|
||||
def _taylor_fprop_mul(input_x, input_y):
    """The rule to generate `Mul` taylor rule.

    Args:
        input_x: Indexable series of the first operand; element 0 is used as
            the primal value and elements 1..K as the higher order terms.
        input_y: Series of the second operand, indexed the same way
            (assumed to have the same length as `input_x` -- TODO confirm).

    Returns:
        A series shaped like `input_x` where
        series[k] = k! * sum_{i=0..k} input_x[i] * input_y[k - i] / (i! * (k - i)!).
    """
    primals = mul_op(input_x[0], input_y[0])
    series_num = len(input_x) - 1
    # factorial[j] holds j! for j = 0..series_num.
    factorial = _factorial(series_num)
    series = zeros_like(input_x)
    series[0] = primals
    for k in range(1, series_num + 1):
        # Accumulate the normalized products first, then rescale by k! once.
        for i in range(0, k + 1):
            tmp = input_x[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
            series[k] += tmp
        series[k] *= factorial[k]
    return series
|
||||
|
||||
|
||||
@taylor_fprop_getters.register(P.Mul)
def taylor_mul(self):
    """Higher order derivatives rule definition for `Mul` operation."""

    def taylor_fprop_mul(input_x, input_y):
        # The whole rule lives in the shared helper; computing the series
        # inline as well would only produce a value that is thrown away.
        series = _taylor_fprop_mul(input_x, input_y)
        return series

    return taylor_fprop_mul
|
||||
|
||||
|
||||
def _taylor_fprop_realdiv(input_x, input_y):
    """The rule to generate `RealDiv` taylor rule.

    Args:
        input_x: Indexable series of the dividend; element 0 is used as the
            primal value and elements 1..K as the higher order terms.
        input_y: Series of the divisor, indexed the same way
            (assumed to have the same length as `input_x` -- TODO confirm).

    Returns:
        A series shaped like `input_x` where series[0] is
        input_x[0] / input_y[0] and, for k >= 1,
        series[k] = (input_x[k]
                     - k! * sum_{i=0..k-1} series[i] * input_y[k - i] / (i! * (k - i)!))
                    / input_y[0].
    """
    primals = div_op(input_x[0], input_y[0])
    series_num = len(input_x) - 1
    # factorial[j] holds j! for j = 0..series_num.
    factorial = _factorial(series_num)
    series = zeros_like(input_x)
    series[0] = primals
    for k in range(1, series_num + 1):
        # The inner sum only reads series[0..k-1], which are already final.
        for i in range(0, k):
            tmp = series[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
            series[k] += tmp
        series[k] = (input_x[k] - factorial[k] * series[k]) / input_y[0]
    return series
|
||||
|
||||
|
||||
@taylor_fprop_getters.register(P.RealDiv)
def taylor_realdiv(self):
    """Higher order derivatives rule definition for `RealDiv` operation."""

    def taylor_fprop_realdiv(input_x, input_y):
        # The whole rule lives in the shared helper; computing the series
        # inline as well would only produce a value that is thrown away.
        series = _taylor_fprop_realdiv(input_x, input_y)
        return series

    return taylor_fprop_realdiv
|
||||
|
@ -96,10 +114,9 @@ def taylor_realdiv(self):
|
|||
@taylor_fprop_getters.register(P.Exp)
|
||||
def taylor_exp(self):
|
||||
"""Higher order derivatives rule definition for `Exp` operation."""
|
||||
exp_ = P.Exp()
|
||||
|
||||
def taylor_fprop_exp(inputs):
|
||||
primals = exp_(inputs[0])
|
||||
primals = exp_op(inputs[0])
|
||||
series_num = len(inputs) - 1
|
||||
factorial = _factorial(series_num)
|
||||
series = zeros_like(inputs)
|
||||
|
@ -114,27 +131,31 @@ def taylor_exp(self):
|
|||
return taylor_fprop_exp
|
||||
|
||||
|
||||
def _taylor_fprop_sin_cos(inputs):
    """Common rule to generate `Sin` and `Cos` taylor rules.

    Both series are built together because each order of one depends on the
    lower orders of the other.

    Args:
        inputs: Indexable series; element 0 is used as the primal value and
            elements 1..K as the higher order terms.

    Returns:
        A tuple `(series_sin, series_cos)`, each shaped like `inputs`, where
        for k >= 1
        series_sin[k] = (k-1)! * sum_{i=1..k} i * inputs[i] * series_cos[k - i] / (i! * (k - i)!)
        series_cos[k] = -(k-1)! * sum_{i=1..k} i * inputs[i] * series_sin[k - i] / (i! * (k - i)!).
    """
    primal_sin = sin_op(inputs[0])
    primal_cos = cos_op(inputs[0])
    series_sin = zeros_like(inputs)
    series_cos = zeros_like(inputs)
    series_sin[0] = primal_sin
    series_cos[0] = primal_cos
    series_num = len(inputs) - 1
    # factorial[j] holds j! for j = 0..series_num.
    factorial = _factorial(series_num)
    for k in range(1, series_num + 1):
        # i starts at 1, so only orders below k of the other series are read;
        # those were finalized (rescaled) in earlier iterations of k.
        for i in range(1, k + 1):
            series_sin[k] += i * inputs[i] * series_cos[k - i] / (factorial[i] * factorial[k - i])
            series_cos[k] -= i * inputs[i] * series_sin[k - i] / (factorial[i] * factorial[k - i])
        series_sin[k] *= factorial[k - 1]
        series_cos[k] *= factorial[k - 1]
    return series_sin, series_cos
|
||||
|
||||
|
||||
@taylor_fprop_getters.register(P.Sin)
def taylor_sin(self):
    """Higher order derivatives rule definition for `Sin` operation."""

    def taylor_fprop_sin(inputs):
        # The shared helper returns (series_sin, series_cos); only the sine
        # series is needed here, so no separate inline computation is kept.
        series_sin = _taylor_fprop_sin_cos(inputs)[0]
        return series_sin

    return taylor_fprop_sin
|
||||
|
@ -143,24 +164,9 @@ def taylor_sin(self):
|
|||
@taylor_fprop_getters.register(P.Cos)
def taylor_cos(self):
    """Higher order derivatives rule definition for `Cos` operation."""

    def taylor_fprop_cos(inputs):
        # The shared helper returns (series_sin, series_cos); only the cosine
        # series is needed here, so no separate inline computation is kept.
        series_cos = _taylor_fprop_sin_cos(inputs)[1]
        return series_cos

    return taylor_fprop_cos
|
||||
|
|
|
@ -415,10 +415,12 @@ class _TaylorOperation(TaylorOperation_):
|
|||
if self.grad_fn is not None and self.fn == fn:
|
||||
return self.grad_fn
|
||||
taylor_grad_ = _TaylorOperation()
|
||||
|
||||
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
|
||||
@ms_function
|
||||
def after_taylor_grad(*args):
|
||||
return taylor_grad_(fn)(*args)
|
||||
|
||||
self.grad_fn = after_taylor_grad
|
||||
self.fn = fn
|
||||
return self.grad_fn
|
||||
|
@ -449,28 +451,6 @@ class _Grad(GradOperation_):
|
|||
self.pynative_ = False
|
||||
self.grad_position = None
|
||||
|
||||
def _pynative_forward_run(self, grad, args, kwargs, fn, grad_position):
    """ Pynative forward run to build grad graph.

    Args:
        grad: Grad operation object passed to `_pynative_executor.check_run`.
        args: Positional arguments for `fn`. When `self.sens_param` is set and
            no 'sens' keyword is present, the last element is dropped.
        kwargs: Keyword arguments for `fn`. A 'sens' entry is removed from a
            copy when `self.sens_param` is set; the caller's dict is untouched.
        fn: A plain Python function, or an object providing `set_grad` that
            is callable.
        grad_position: Not read by this method.
    """
    new_kwargs = kwargs
    if self.sens_param:
        if 'sens' not in kwargs:
            args = args[:-1]
        else:
            new_kwargs = kwargs.copy()
            new_kwargs.pop('sens')
    if isinstance(fn, FunctionType):
        if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
            _pynative_executor.set_grad_flag(True)
            _pynative_executor.new_graph(fn, *args, **new_kwargs)
            output = fn(*args, **new_kwargs)
            _pynative_executor.end_graph(fn, output, *args, **new_kwargs)
    else:
        # Check if fn have run already
        if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
            fn.set_grad()
            fn(*args, **new_kwargs)
            fn.set_grad(False)
|
||||
|
||||
def __call__(self, fn, weights=None, grad_position=0):
|
||||
if self.grad_fn is not None and self.fn == fn and self.grad_position == grad_position:
|
||||
return self.grad_fn
|
||||
|
@ -499,7 +479,7 @@ class _Grad(GradOperation_):
|
|||
def after_grad(*args, **kwargs):
|
||||
if _pynative_executor.check_graph(fn, *args, **kwargs):
|
||||
print("Another grad step is running")
|
||||
self._pynative_forward_run(grad_, args, kwargs, fn, grad_position)
|
||||
self._pynative_forward_run(grad_, args, kwargs, fn)
|
||||
_pynative_executor.grad(grad_, fn, weights, grad_position, *args, **kwargs)
|
||||
out = _pynative_executor(fn, *args, **kwargs)
|
||||
_pynative_executor.clear_grad(fn, *args, **kwargs)
|
||||
|
@ -523,6 +503,28 @@ class _Grad(GradOperation_):
|
|||
self.grad_position = grad_position
|
||||
return self.grad_fn
|
||||
|
||||
def _pynative_forward_run(self, grad, args, kwargs, fn):
    """ Pynative forward runs to build grad graph.

    Args:
        grad: Grad operation object passed to `_pynative_executor.check_run`.
        args: Positional arguments for `fn`. When `self.sens_param` is set and
            no 'sens' keyword is present, the last element is dropped.
        kwargs: Keyword arguments for `fn`. A 'sens' entry is removed from a
            copy when `self.sens_param` is set; the caller's dict is untouched.
        fn: A plain Python function, or an object providing `set_grad` that
            is callable.
    """
    new_kwargs = kwargs
    if self.sens_param:
        if 'sens' in kwargs:
            new_kwargs = kwargs.copy()
            new_kwargs.pop('sens')
        else:
            args = args[:-1]
    if isinstance(fn, FunctionType):
        if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
            _pynative_executor.set_grad_flag(True)
            _pynative_executor.new_graph(fn, *args, **new_kwargs)
            outputs = fn(*args, **new_kwargs)
            _pynative_executor.end_graph(fn, outputs, *args, **new_kwargs)
    else:
        # Check if fn has run already.
        if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
            fn.set_grad()
            fn(*args, **new_kwargs)
            fn.set_grad(False)
|
||||
|
||||
|
||||
class _Vmap(VmapOperation_):
|
||||
"""
|
||||
|
|
|
@ -282,6 +282,7 @@ def grad(fn, grad_position=0, sens_param=False):
|
|||
return grad_by_position_with_sens(fn, None, grad_position)
|
||||
return grad_by_position(fn, None, grad_position)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _trans_jet_inputs(primals_item, series_item):
|
||||
"""Trans inputs of jet"""
|
||||
|
@ -294,6 +295,7 @@ def _trans_jet_inputs(primals_item, series_item):
|
|||
return cast(primals_item, mstype.float64), cast(series_item, mstype.float64)
|
||||
return primals_item, series_item
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_jet_inputs(primals, series):
|
||||
"""Check inputs of jet"""
|
||||
|
@ -315,8 +317,10 @@ def _check_jet_inputs(primals, series):
|
|||
check_series.append(trans_series_item)
|
||||
return check_primals, check_series
|
||||
|
||||
|
||||
_taylor = _TaylorOperation()
|
||||
|
||||
|
||||
def jet(fn, primals, series):
|
||||
"""
|
||||
This function is designed to calculate the higher order differentiation of given composite function. To figure out
|
||||
|
|
Loading…
Reference in New Issue