forked from mindspore-Ecosystem/mindspore
!49904 add_ops_ctc_loss_docs_master
Merge pull request !49904 from yide12/code_docs_nn_master
This commit is contained in:
commit
b92f5b24a9
|
@ -59,6 +59,7 @@ MindSpore中 `mindspore.ops` 接口与上一版本相比,新增、删除和支
|
|||
mindspore.ops.binary_cross_entropy_with_logits
|
||||
mindspore.ops.cosine_embedding_loss
|
||||
mindspore.ops.cross_entropy
|
||||
mindspore.ops.ctc_loss
|
||||
mindspore.ops.gaussian_nll_loss
|
||||
mindspore.ops.hinge_embedding_loss
|
||||
mindspore.ops.huber_loss
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
mindspore.ops.ctc_loss
|
||||
======================
|
||||
|
||||
.. py:function:: mindspore.ops.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False)
|
||||
|
||||
计算CTC(Connectist Temporal Classification)损失和梯度。
|
||||
|
||||
关于CTCLoss算法详细介绍,请参考 `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data withRecurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ 。
|
||||
|
||||
参数:
|
||||
- **log_probs** (Tensor) - 输入Tensor,shape :math:`(T, N, C)` 。其中T表示输入长度,N表示批次大小,C是分类数,包含空白。
|
||||
- **targets** (Tensor) - 目标Tensor,shape :math:`(N, S)` 。其中S表示最大目标长度。
|
||||
- **input_lengths** (Union[tuple, Tensor]) - shape为N的Tensor或tuple。表示输入长度。
|
||||
- **target_lengths** (Union[tuple, Tensor]) - shape为N的Tensor或tuple。表示目标长度。
|
||||
- **blank** (int) - 空白标签。默认值:0。
|
||||
- **reduction** (str) - 对输出应用归约方法。可选值为"none"、"mean"或"sum"。默认值:"mean"。
|
||||
- **zero_infinity** (bool) - 是否设置无限损失和相关梯度为零。默认值:False。
|
||||
|
||||
返回:
|
||||
- **neg_log_likelihood** (Tensor) - 对每一个输入节点可微调的损失值,shape是 :math:`(N)`。
|
||||
- **log_alpha** (Tensor) - shape为 :math:`(N, T, 2 * S + 1)` 的输入到输出的轨迹概率。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `zero_infinity` 不是布尔值, `reduction` 不是字符串。
|
||||
- **TypeError** - `log_probs` 的数据类型不是float或double。
|
||||
- **TypeError** - `targets` 、 `input_lengths` 或 `target_lengths` 数据类型不是int32或int64。
|
||||
- **ValueError** - `log_probs` 的秩不是3。
|
||||
- **ValueError** - `targets` 的秩不是2。
|
||||
- **ValueError** - `input_lengths` 的shape大小不等于N。N是 `log_probs` 的批次大小。
|
||||
- **ValueError** - `target_lengths` 的shape大小不等于N。N是 `log_probs` 的批次大小。
|
||||
- **ValueError** - `targets` 、 `input_lengths` 或 `target_lengths` 的数据类型是不同的。
|
||||
- **ValueError** - `blank` 值不介于0到C之间。C是 `log_probs` 的分类数。
|
||||
- **RuntimeError** - `input_lengths` 的值大于T。T是 `log_probs` 的长度。
|
||||
- **RuntimeError** - `target_lengths[i]` 的取值范围不在0到 `input_length[i]` 之间。
|
|
@ -59,6 +59,7 @@ Loss Functions
|
|||
mindspore.ops.binary_cross_entropy_with_logits
|
||||
mindspore.ops.cosine_embedding_loss
|
||||
mindspore.ops.cross_entropy
|
||||
mindspore.ops.ctc_loss
|
||||
mindspore.ops.gaussian_nll_loss
|
||||
mindspore.ops.hinge_embedding_loss
|
||||
mindspore.ops.huber_loss
|
||||
|
|
|
@ -2302,7 +2302,7 @@ class CTCLoss(LossBase):
|
|||
Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ .
|
||||
|
||||
Args:
|
||||
blank (int): The blank tag. Default: 0.
|
||||
blank (int): The blank label. Default: 0.
|
||||
reduction (str): Implements the reduction method to the output with 'none', 'mean', or 'sum'. Default: 'mean'.
|
||||
zero_infinity (bool): Whether to set infinite loss and correlation gradient to 0. Default: False.
|
||||
|
||||
|
|
|
@ -3974,26 +3974,27 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
|
|||
input_lengths (Union(Tuple, Tensor)): A tuple or Tensor of shape(N). It means the lengths of the input.
|
||||
target_lengths (Union(Tuple, Tensor)): A tuple or Tensor of shape(N). It means the lengths of the target.
|
||||
blank (int): The blank label. Default: 0.
|
||||
reduction (string): Apply specific reduction method to the output: 'none', 'mean', or 'sum'. Default: 'mean'.
|
||||
zero_infinity (bool): Whether to set infinite loss and correlation gradient to zero. Default: False.
|
||||
reduction (string): Implements the reduction method to the output with 'none', 'mean', or 'sum'.
|
||||
Default: 'mean'.
|
||||
zero_infinity (bool): Whether to set infinite loss and correlation gradient to 0. Default: False.
|
||||
|
||||
Returns:
|
||||
neg_log_likelihood (Tensor), A loss value with shape (N), which is differentiable with respect to
|
||||
neg_log_likelihood (Tensor), A loss value with shape :math:`(N)` , which is differentiable with respect to
|
||||
each input node.
|
||||
|
||||
log_alpha (Tensor), The probability of possible trace of input to target with shape (N, T, 2 * S + 1).
|
||||
log_alpha (Tensor), The probability of possible trace of input to target with shape :math:`(N, T, 2 * S + 1)` .
|
||||
|
||||
Raises:
|
||||
TypeError: If `zero_infinity` is not a bool, reduction is not string.
|
||||
TypeError: If the dtype of `log_probs` or `grad_out` is not float or double.
|
||||
TypeError: If `zero_infinity` is not a bool, `reduction` is not string.
|
||||
TypeError: If the dtype of `log_probs` is not float or double.
|
||||
TypeError: If the dtype of `targets`, `input_lengths` or `target_lengths` is not int32 or int64.
|
||||
ValueError: If the rank of `log_probs` is not 3.
|
||||
ValueError: If the rank of `targets` is not 2.
|
||||
ValueError: If the shape of `input_lengths` does not match {batch_size|N}.
|
||||
ValueError: If the shape of `target_lengths` does not match {batch_size|N}.
|
||||
TypeError: If the types of `targets`, `input_lengths`, `grad_out` or `target_lengths` are different.
|
||||
ValueError: If the value of `blank` is not in range [0, num_labels|C).
|
||||
RuntimeError: If any value of `input_lengths` is larger than (num_labels|C).
|
||||
ValueError: If the shape of `input_lengths` does not match N. N is batch size of `log_probs` .
|
||||
ValueError: If the shape of `target_lengths` does not match N. N is batch size of `log_probs` .
|
||||
TypeError: If the types of `targets`, `input_lengths` or `target_lengths` are different.
|
||||
ValueError: If the value of `blank` is not in range [0, num_labels|C). C is number of classes of `log_probs` .
|
||||
RuntimeError: If any value of `input_lengths` is larger than T. T is the length of `log_probs`.
|
||||
RuntimeError: If any target_lengths[i] is not in range [0, input_length[i]].
|
||||
|
||||
Supported Platforms:
|
||||
|
|
Loading…
Reference in New Issue