|
|
|
@ -222,7 +222,7 @@ def _trans_jet_inputs(primals_item, series_item):
|
|
|
|
|
if not dtype(primals_item) in value_type or dtype(primals_item) != dtype(series_item):
|
|
|
|
|
raise TypeError(f"For `F.jet`, the elements' types of primals and series must be the same and belong to "
|
|
|
|
|
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
|
|
|
|
|
f" {dtype(primals_item).__name__} and {dtype(series_item).__name__}.")
|
|
|
|
|
f" {dtype(primals_item)} and {dtype(series_item)}.")
|
|
|
|
|
if dtype(primals_item) in [mstype.int32, mstype.int64]:
|
|
|
|
|
return cast(primals_item, mstype.float32), cast(series_item, mstype.float32)
|
|
|
|
|
return primals_item, series_item
|
|
|
|
@ -235,18 +235,19 @@ def _check_jet_inputs(primals, series):
|
|
|
|
|
raise TypeError(f"For 'F.jet', the 'primals' and `series` must be both Tensor or tuple, "
|
|
|
|
|
f"but got {type(primals).__name__} and {type(series).__name__}.")
|
|
|
|
|
if isinstance(primals, Tensor):
|
|
|
|
|
if primals.shape != series.shape[1:]:
|
|
|
|
|
raise ValueError("The shape of each element must be the same as the primals.")
|
|
|
|
|
return _trans_jet_inputs(primals, series)
|
|
|
|
|
if isinstance(primals, tuple):
|
|
|
|
|
if len(primals) != len(series):
|
|
|
|
|
raise ValueError("The lengths of primals and series must be the same.")
|
|
|
|
|
check_primals = []
|
|
|
|
|
check_series = []
|
|
|
|
|
for i, j in zip(primals, series):
|
|
|
|
|
trans_primals_item, trans_series_item = _trans_jet_inputs(i, j)
|
|
|
|
|
check_primals.append(trans_primals_item)
|
|
|
|
|
check_series.append(trans_series_item)
|
|
|
|
|
if primals.shape == series.shape[1:]:
|
|
|
|
|
return _trans_jet_inputs(primals, series)
|
|
|
|
|
if primals.shape == series.shape:
|
|
|
|
|
return _trans_jet_inputs(primals, series.expand_dims(axis=0))
|
|
|
|
|
raise ValueError("In series, the shape of each element must be the same as the primals.")
|
|
|
|
|
if len(primals) != len(series):
|
|
|
|
|
raise ValueError("The lengths of primals and series must be the same.")
|
|
|
|
|
check_primals = []
|
|
|
|
|
check_series = []
|
|
|
|
|
for i, j in zip(primals, series):
|
|
|
|
|
trans_primals_item, trans_series_item = _trans_jet_inputs(i, j)
|
|
|
|
|
check_primals.append(trans_primals_item)
|
|
|
|
|
check_series.append(trans_series_item)
|
|
|
|
|
return check_primals, check_series
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -258,22 +259,26 @@ def jet(fn, primals, series):
|
|
|
|
|
This function is designed to calculate the higher order differentiation of given composite function. To figure out
|
|
|
|
|
first to `n`-th order differentiations, original inputs and first to `n`-th order derivative of original inputs
|
|
|
|
|
must be provided together. Generally, it is recommended to set the values of given first order derivative to 1,
|
|
|
|
|
while the other to 0.
|
|
|
|
|
while the other to 0, which is like the derivative of origin input with respect to itself.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
fn (Union(Cell, function)): Function to do TaylorOperation.
|
|
|
|
|
primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
|
|
|
|
|
series (Union(Tensor, Tuple of Tensors)): If tuple, the length and type of series should be the same as inputs.
|
|
|
|
|
fn (Union[Cell, function]): Function to do TaylorOperation.
|
|
|
|
|
primals (Union[Tensor, tuple[Tensor]]): The inputs to `fn`.
|
|
|
|
|
series (Union[Tensor, tuple[Tensor]]): If tuple, the length and type of series should be the same as inputs.
|
|
|
|
|
For each Tensor, the length of first dimension `i` represents the `1` to `i+1`-th order of derivative of
|
|
|
|
|
output with respect to the inputs will be figured out.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple, tuple of out_primals and out_series.
|
|
|
|
|
|
|
|
|
|
- **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
|
|
|
|
|
- **out_series** (Tensors or List of Tensors) - The `1` to `i+1`-th order of derivative of output with respect
|
|
|
|
|
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
|
|
|
|
|
- **out_series** (Union[Tensor, list[Tensor]]) - The `1` to `i+1`-th order of derivative of output with respect
|
|
|
|
|
to the inputs.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If `primals` is not a tensor or tuple of tensors.
|
|
|
|
|
TypeError: If type of `primals` is not the same as type of `series`.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``Ascend`` ``GPU`` ``CPU``
|
|
|
|
|
|
|
|
|
@ -299,6 +304,15 @@ def jet(fn, primals, series):
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> out_primals, out_series = jet(net, primals, series)
|
|
|
|
|
>>> print(out_primals, out_series)
|
|
|
|
|
[[2.319777 2.4825778]
|
|
|
|
|
[1.1515628 0.4691642]] [[[ 1.2533808 -1.0331168 ]
|
|
|
|
|
[-1.1400385 -0.3066662 ]]
|
|
|
|
|
|
|
|
|
|
[[-1.2748207 -1.8274734 ]
|
|
|
|
|
[ 0.966121 0.55551505]]
|
|
|
|
|
|
|
|
|
|
[[-4.0515366 3.6724353 ]
|
|
|
|
|
[ 0.5053504 -0.52061415]]]
|
|
|
|
|
"""
|
|
|
|
|
primals, series = _check_jet_inputs(primals, series)
|
|
|
|
|
derivative_fn = _taylor(fn)
|
|
|
|
@ -325,12 +339,22 @@ def _trans_derivative_inputs(primals_item):
|
|
|
|
|
if not dtype(primals_item) in value_type:
|
|
|
|
|
raise TypeError(f"For `F.derivative`, the elements of primals must belong to "
|
|
|
|
|
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
|
|
|
|
|
f" {dtype(primals_item).__name__}.")
|
|
|
|
|
f" {dtype(primals_item)}.")
|
|
|
|
|
if dtype(primals_item) in [mstype.int32, mstype.int64]:
|
|
|
|
|
return cast(primals_item, mstype.float32)
|
|
|
|
|
return primals_item
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _check_derivative_order(order):
|
|
|
|
|
"""check input order of derivative"""
|
|
|
|
|
if not isinstance(order, int):
|
|
|
|
|
raise TypeError(f"For `F.derivative`, the type of order must be int, but got {type(order).__name__}.")
|
|
|
|
|
if order < 1:
|
|
|
|
|
raise ValueError(f"For `F.derivative`, value of order should not be less than 1, but got {order}.")
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def derivative(fn, primals, order):
|
|
|
|
|
"""
|
|
|
|
|
This function is designed to calculate the higher order differentiation of given composite function. To figure out
|
|
|
|
@ -338,18 +362,23 @@ def derivative(fn, primals, order):
|
|
|
|
|
input first order derivative is set to 1, while the other to 0.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
fn (Union(Cell, function)): Function to do TaylorOperation.
|
|
|
|
|
primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
|
|
|
|
|
fn (Union[Cell, function]): Function to do TaylorOperation.
|
|
|
|
|
primals (Union[Tensor, tuple[Tensor]]): The inputs to `fn`.
|
|
|
|
|
order (int): For each Tensor, the `order`-th order of derivative of output with respect to the inputs will be
|
|
|
|
|
figured out.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple, tuple of out_primals and out_series.
|
|
|
|
|
|
|
|
|
|
- **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
|
|
|
|
|
- **out_series** (Tensors or List of Tensors) - The `order`-th order of derivative of output with respect
|
|
|
|
|
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
|
|
|
|
|
- **out_series** (Union[Tensor, list[Tensor]]) - The `order`-th order of derivative of output with respect
|
|
|
|
|
to the inputs.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If `primals` is not a tensor or tuple of tensors.
|
|
|
|
|
TypeError: If `order` is not int.
|
|
|
|
|
ValueError: If `order` is less than 1.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``Ascend`` ``GPU`` ``CPU``
|
|
|
|
|
|
|
|
|
@ -375,10 +404,14 @@ def derivative(fn, primals, order):
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> out_primals, out_series = derivative(net, primals, order)
|
|
|
|
|
>>> print(out_primals, out_series)
|
|
|
|
|
[[2.319777 2.4825778]
|
|
|
|
|
[1.1515628 0.4691642]] [[-4.0515366 3.6724353 ]
|
|
|
|
|
[ 0.5053504 -0.52061415]]
|
|
|
|
|
"""
|
|
|
|
|
derivative_fn = _taylor(fn)
|
|
|
|
|
concat_op = P.Concat()
|
|
|
|
|
series_one = 1
|
|
|
|
|
_check_derivative_order(order)
|
|
|
|
|
if isinstance(primals, tuple):
|
|
|
|
|
trans_primals = [_trans_derivative_inputs(item) for item in primals]
|
|
|
|
|
inputs = list(map(lambda x: concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x)))), trans_primals))
|
|
|
|
|