!34210 [docs]add api docs of jet and derivative

Merge pull request !34210 from chenzhuo/master
This commit is contained in:
i-robot 2022-05-12 12:32:39 +00:00 committed by Gitee
commit 424598255a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 116 additions and 24 deletions

View File

@ -432,8 +432,10 @@ Parameter操作算子
mindspore.ops.core
mindspore.ops.count_nonzero
mindspore.ops.cummin
mindspore.ops.derivative
mindspore.ops.dot
mindspore.ops.grad
mindspore.ops.jet
mindspore.ops.jvp
mindspore.ops.laplace
mindspore.ops.narrow

View File

@ -0,0 +1,28 @@
mindspore.ops.derivative
========================
.. py:function:: mindspore.ops.derivative(fn, primals, order)
计算函数或网络输出对输入的高阶微分。给定待求导函数的原始输入和求导的阶数n将返回函数输出对输入的第n阶导数。
.. note::
- 若 `primals` 是int型的Tensor会被转化成float32格式进行计算。
**参数:**
- **fn** (Union[Function, Cell]) - 待求导的函数或网络。
- **primals** (Union[Tensor, tuple[Tensor]]) - `fn` 的输入单输入的type为Tensor多输入的type为Tensor组成的tuple。
- **order** (int) - 求导的阶数。
**返回:**
tuple`out_primals``out_series` 组成。
- **out_primals** (Union[Tensor, list[Tensor]]) - `fn(primals)` 的结果。
- **out_series** (Union[Tensor, list[Tensor]]) - `fn` 输出对输入的第n阶导数。
**异常:**
- **TypeError** - `primals` 不是Tensor或tuple。
- **TypeError** - `order` 不是int。
- **ValueError** - `order` 不是正数。

View File

@ -0,0 +1,27 @@
mindspore.ops.jet
=================
.. py:function:: mindspore.ops.jet(fn, primals, series)
计算函数或网络输出对输入的高阶微分。给定待求导函数的原始输入和自定义的1到n阶导数将返回函数输出对输入的第1到n阶导数。一般情况建议输入的1阶导数值为全1更高阶的导数值为全0这与输入对本身的导数情况是一致的。
.. note::
- 若 `primals` 是int型的Tensor会被转化成float32格式进行计算。
**参数:**
- **fn** (Union[Function, Cell]) - 待求导的函数或网络。
- **primals** (Union[Tensor, tuple[Tensor]]) - `fn` 的输入单输入的type为Tensor多输入的type为Tensor组成的tuple。
- **series** (Union[Tensor, tuple[Tensor]]) - 输入的原始第1到第n阶导数。type与 `primals` 一致长度表示待求导的阶数n。
**返回:**
tuple`out_primals``out_series` 组成。
- **out_primals** (Union[Tensor, list[Tensor]]) - `fn(primals)` 的结果。
- **out_series** (Union[Tensor, list[Tensor]]) - `fn` 输出对输入的第1到n阶导数。
**异常:**
- **TypeError** - `primals` 不是Tensor或tuple。
- **TypeError** - `primals``series` 的type不一致。

View File

@ -431,8 +431,10 @@ Other Operators
mindspore.ops.core
mindspore.ops.count_nonzero
mindspore.ops.cummin
mindspore.ops.derivative
mindspore.ops.dot
mindspore.ops.grad
mindspore.ops.jet
mindspore.ops.jvp
mindspore.ops.laplace
mindspore.ops.narrow

View File

@ -222,7 +222,7 @@ def _trans_jet_inputs(primals_item, series_item):
if not dtype(primals_item) in value_type or dtype(primals_item) != dtype(series_item):
raise TypeError(f"For `F.jet`, the elements' types of primals and series must be the same and belong to "
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
f" {dtype(primals_item).__name__} and {dtype(series_item).__name__}.")
f" {dtype(primals_item)} and {dtype(series_item)}.")
if dtype(primals_item) in [mstype.int32, mstype.int64]:
return cast(primals_item, mstype.float32), cast(series_item, mstype.float32)
return primals_item, series_item
@ -235,18 +235,19 @@ def _check_jet_inputs(primals, series):
raise TypeError(f"For 'F.jet', the 'primals' and `series` must be both Tensor or tuple, "
f"but got {type(primals).__name__} and {type(series).__name__}.")
if isinstance(primals, Tensor):
if primals.shape != series.shape[1:]:
raise ValueError("The shape of each element must be the same as the primals.")
return _trans_jet_inputs(primals, series)
if isinstance(primals, tuple):
if len(primals) != len(series):
raise ValueError("The lengths of primals and series must be the same.")
check_primals = []
check_series = []
for i, j in zip(primals, series):
trans_primals_item, trans_series_item = _trans_jet_inputs(i, j)
check_primals.append(trans_primals_item)
check_series.append(trans_series_item)
if primals.shape == series.shape[1:]:
return _trans_jet_inputs(primals, series)
if primals.shape == series.shape:
return _trans_jet_inputs(primals, series.expand_dims(axis=0))
raise ValueError("In series, the shape of each element must be the same as the primals.")
if len(primals) != len(series):
raise ValueError("The lengths of primals and series must be the same.")
check_primals = []
check_series = []
for i, j in zip(primals, series):
trans_primals_item, trans_series_item = _trans_jet_inputs(i, j)
check_primals.append(trans_primals_item)
check_series.append(trans_series_item)
return check_primals, check_series
@ -258,22 +259,26 @@ def jet(fn, primals, series):
This function is designed to calculate the higher order differentiation of given composite function. To figure out
first to `n`-th order differentiations, original inputs and first to `n`-th order derivative of original inputs
must be provided together. Generally, it is recommended to set the values of given first order derivative to 1,
while the other to 0.
while the other to 0, which is like the derivative of origin input with respect to itself.
Args:
fn (Union(Cell, function)): Function to do TaylorOperation.
primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
series (Union(Tensor, Tuple of Tensors)): If tuple, the length and type of series should be the same as inputs.
fn (Union[Cell, function]): Function to do TaylorOperation.
primals (Union[Tensor, tuple[Tensor]]): The inputs to `fn`.
series (Union[Tensor, tuple[Tensor]]): If tuple, the length and type of series should be the same as inputs.
For each Tensor, the length of first dimension `i` represents the `1` to `i+1`-th order of derivative of
output with respect to the inputs will be figured out.
Returns:
Tuple, tuple of out_primals and out_series.
- **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
- **out_series** (Tensors or List of Tensors) - The `1` to `i+1`-th order of derivative of output with respect
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
- **out_series** (Union[Tensor, list[Tensor]]) - The `1` to `i+1`-th order of derivative of output with respect
to the inputs.
Raises:
TypeError: If `primals` is not a tensor or tuple of tensors.
TypeError: If type of `primals` is not the same as type of `series`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -299,6 +304,15 @@ def jet(fn, primals, series):
>>> net = Net()
>>> out_primals, out_series = jet(net, primals, series)
>>> print(out_primals, out_series)
[[2.319777 2.4825778]
[1.1515628 0.4691642]] [[[ 1.2533808 -1.0331168 ]
[-1.1400385 -0.3066662 ]]
[[-1.2748207 -1.8274734 ]
[ 0.966121 0.55551505]]
[[-4.0515366 3.6724353 ]
[ 0.5053504 -0.52061415]]]
"""
primals, series = _check_jet_inputs(primals, series)
derivative_fn = _taylor(fn)
@ -325,12 +339,22 @@ def _trans_derivative_inputs(primals_item):
if not dtype(primals_item) in value_type:
raise TypeError(f"For `F.derivative`, the elements of primals must belong to "
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
f" {dtype(primals_item).__name__}.")
f" {dtype(primals_item)}.")
if dtype(primals_item) in [mstype.int32, mstype.int64]:
return cast(primals_item, mstype.float32)
return primals_item
@constexpr
def _check_derivative_order(order):
"""check input order of derivative"""
if not isinstance(order, int):
raise TypeError(f"For `F.derivative`, the type of order must be int, but got {type(order).__name__}.")
if order < 1:
raise ValueError(f"For `F.derivative`, value of order should not be less than 1, but got {order}.")
return True
def derivative(fn, primals, order):
"""
This function is designed to calculate the higher order differentiation of given composite function. To figure out
@ -338,18 +362,23 @@ def derivative(fn, primals, order):
input first order derivative is set to 1, while the other to 0.
Args:
fn (Union(Cell, function)): Function to do TaylorOperation.
primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
fn (Union[Cell, function]): Function to do TaylorOperation.
primals (Union[Tensor, tuple[Tensor]]): The inputs to `fn`.
order (int): For each Tensor, the `order`-th order of derivative of output with respect to the inputs will be
figured out.
Returns:
Tuple, tuple of out_primals and out_series.
- **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
- **out_series** (Tensors or List of Tensors) - The `order`-th order of derivative of output with respect
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
- **out_series** (Union[Tensor, list[Tensor]]) - The `order`-th order of derivative of output with respect
to the inputs.
Raises:
TypeError: If `primals` is not a tensor or tuple of tensors.
TypeError: If `order` is not int.
ValueError: If `order` is less than 1.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -375,10 +404,14 @@ def derivative(fn, primals, order):
>>> net = Net()
>>> out_primals, out_series = derivative(net, primals, order)
>>> print(out_primals, out_series)
[[2.319777 2.4825778]
[1.1515628 0.4691642]] [[-4.0515366 3.6724353 ]
[ 0.5053504 -0.52061415]]
"""
derivative_fn = _taylor(fn)
concat_op = P.Concat()
series_one = 1
_check_derivative_order(order)
if isinstance(primals, tuple):
trans_primals = [_trans_derivative_inputs(item) for item in primals]
inputs = list(map(lambda x: concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x)))), trans_primals))