diff --git a/docs/api/api_python/mindspore.ops.rst b/docs/api/api_python/mindspore.ops.rst
index b7f85923afc..af9254a57d0 100644
--- a/docs/api/api_python/mindspore.ops.rst
+++ b/docs/api/api_python/mindspore.ops.rst
@@ -84,6 +84,7 @@ MindSpore中 `mindspore.ops` 接口与上一版本相比,新增、删除和支
     mindspore.ops.BCEWithLogitsLoss
     mindspore.ops.BinaryCrossEntropy
     mindspore.ops.CTCLoss
+    mindspore.ops.CTCLossV2
     mindspore.ops.KLDivLoss
     mindspore.ops.L2Loss
     mindspore.ops.NLLLoss
diff --git a/docs/api/api_python/ops/mindspore.ops.CTCLossV2.rst b/docs/api/api_python/ops/mindspore.ops.CTCLossV2.rst
new file mode 100644
index 00000000000..fa7d65859bf
--- /dev/null
+++ b/docs/api/api_python/ops/mindspore.ops.CTCLossV2.rst
@@ -0,0 +1,37 @@
+mindspore.ops.CTCLossV2
+=======================
+
+.. py:class:: mindspore.ops.CTCLossV2(blank=0, reduction="none", zero_infinity=False)
+
+    计算CTC(Connectionist Temporal Classification)损失和梯度。
+
+    CTC算法是在 `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with Recurrent Neural Networks `_ 中提出的。
+
+    参数:
+        - **blank** (int,可选) - 空白标签。默认值:0。
+        - **reduction** (str,可选) - 对输出应用特定的缩减方法。目前仅支持“none”,不区分大小写。默认值:“none”。
+        - **zero_infinity** (bool,可选) - 是否将无限损失和相关梯度设置为零。默认值:False。
+
+    输入:
+        - **log_probs** (Tensor) - 输入Tensor,其shape为 :math:`(T, C, N)` 的三维Tensor。 :math:`T` 表示输入长度, :math:`N` 表示批大小, :math:`C` 表示类别数,包含空白标签。
+        - **targets** (Tensor) - 标签序列。其shape为 :math:`(N, S)` 的二维Tensor。 :math:`S` 表示最大标签长度。
+        - **input_lengths** (Union(Tuple, Tensor)) - 输入的长度。其shape为 :math:`(N)` 。
+        - **target_lengths** (Union(Tuple, Tensor)) - 标签的长度。其shape为 :math:`(N)` 。
+
+    输出:
+        - **neg_log_likelihood** (Tensor) - 相对于每个输入节点可微分的损失值。
+        - **log_alpha** (Tensor) - 输入到目标的可能跟踪概率。
+
+    异常:
+        - **TypeError** - 如果 `zero_infinity` 不是bool类型。
+        - **TypeError** - 如果 `reduction` 不是string类型。
+        - **TypeError** - 如果 `log_probs` 的dtype不是float类型或double类型。
+        - **TypeError** - 如果 `targets`、 `input_lengths` 或 `target_lengths` 的dtype不是int32类型或int64类型。
+        - **ValueError** - 如果 `log_probs` 的秩不等于3。
+        - **ValueError** - 如果 `targets` 的秩不等于2。
+        - **ValueError** - 如果 `input_lengths` 的shape与批大小 :math:`N` 不匹配。
+        - **ValueError** - 如果 `target_lengths` 的shape与批大小 :math:`N` 不匹配。
+        - **TypeError** - 如果 `targets`、 `input_lengths` 或 `target_lengths` 的类型不同。
+        - **ValueError** - 如果 `blank` 的数值不是介于0和 :math:`C` 之间。
+        - **RuntimeError** - 如果 `input_lengths` 的任何值大于 :math:`C` 。
+        - **RuntimeError** - 如果任何 `target_lengths[i]` 不在 [0, `input_length[i]` ] 范围内。
diff --git a/docs/api/api_python_en/mindspore.ops.rst b/docs/api/api_python_en/mindspore.ops.rst
index b940fdf15aa..65daa6966b4 100644
--- a/docs/api/api_python_en/mindspore.ops.rst
+++ b/docs/api/api_python_en/mindspore.ops.rst
@@ -84,6 +84,7 @@ Loss Function
     mindspore.ops.BCEWithLogitsLoss
     mindspore.ops.BinaryCrossEntropy
     mindspore.ops.CTCLoss
+    mindspore.ops.CTCLossV2
     mindspore.ops.KLDivLoss
     mindspore.ops.L2Loss
     mindspore.ops.NLLLoss
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index c2390f6aa13..3100bac1f6e 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -8002,34 +8002,38 @@ class CTCLossV2(Primitive):
     Recurrent Neural Networks `_.
 
     Args:
-        blank (int): The blank label. Default: 0.
-        reduction (string): Apply specific reduction method to the output. Currently only support 'none',
+        blank (int, optional): The blank label. Default: 0.
+        reduction (string, optional): Apply specific reduction method to the output. Currently only support 'none',
             not case sensitive. Default: "none".
-        zero_infinity (bool): Whether to set infinite loss and correlation gradient to zero. Default: False.
+        zero_infinity (bool, optional): Whether to set infinite loss and correlation gradient to zero. Default: False.
 
     Inputs:
-        - **log_probs** (Tensor) - A tensor of shape (T, N, C), where T is input length, N is batch size and C is number
-          of classes (including blank).
-        - **targets** (Tensor) - A tensor of shape (N, S), where S is max target length, means the target sequences.
-        - **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape(N). It means the lengths of the input.
-        - **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape(N). It means the lengths of the target.
+        - **log_probs** (Tensor) - A tensor of shape :math:`(T, C, N)`, where :math:`T` is input length, :math:`N` is
+          batch size and :math:`C` is number of classes (including blank).
+        - **targets** (Tensor) - A tensor of shape :math:`(N, S)`, where :math:`S` is max target length,
+          means the target sequences.
+        - **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape :math:`(N)`.
+          It means the lengths of the input.
+        - **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape :math:`(N)`.
+          It means the lengths of the target.
 
     Outputs:
         - **neg_log_likelihood** (Tensor) - A loss value which is differentiable with respect to each input node.
         - **log_alpha** (Tensor) - The probability of possible trace of input to target.
 
     Raises:
-        TypeError: If `zero_infinity` is not a bool, reduction is not string.
+        TypeError: If `zero_infinity` is not a bool.
+        TypeError: If `reduction` is not string.
         TypeError: If the dtype of `log_probs` is not float or double.
         TypeError: If the dtype of `targets`, `input_lengths` or `target_lengths` is not int32 or int64.
         ValueError: If the rank of `log_probs` is not 3.
         ValueError: If the rank of `targets` is not 2.
         ValueError: If the shape of `input_lengths` does not match {batch_size|N}.
         ValueError: If the shape of `target_lengths` does not match {batch_size|N}.
-        TypeError: If the types of `targets`, `input_lengths`, `grad_out` or `target_lengths` are different.
+        TypeError: If the types of `targets`, `input_lengths` or `target_lengths` are different.
         ValueError: If the value of `blank` is not in range [0, num_labels|C).
         RuntimeError: If any value of `input_lengths` is larger than (num_labels|C).
-        RuntimeError: If any target_lengths[i] is not in range [0, input_length[i]].
+        RuntimeError: If any `target_lengths[i]` is not in range [0, `input_length[i]` ].
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
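
Review note: the two outputs documented in this patch, `neg_log_likelihood` and `log_alpha`, can be illustrated with a minimal pure-Python sketch of the standard CTC forward recursion for a single sample. This is not MindSpore's kernel and does not call MindSpore. The names `ctc_forward` and `_logsumexp` are hypothetical, the input is assumed to be a plain `T x C` nested list of log-probabilities rather than a batched Tensor, and the `zero_infinity` branch only mirrors the documented behaviour for the loss value (no gradients are computed).

```python
import math

NEG_INF = float("-inf")


def _logsumexp(values):
    """Numerically stable log(sum(exp(v))) that tolerates -inf entries."""
    m = max(values)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(v - m) for v in values))


def ctc_forward(log_probs, target, blank=0, zero_infinity=False):
    """Hypothetical single-sample reference, not the MindSpore operator.

    log_probs: list of T rows, each a list of C log-probabilities.
    target:    list of label indices (no blanks), length <= S.
    Returns (neg_log_likelihood, log_alpha), where log_alpha is a
    T x (2 * len(target) + 1) table of forward log-probabilities.
    """
    num_steps = len(log_probs)
    num_classes = len(log_probs[0])
    if not 0 <= blank < num_classes:
        raise ValueError("blank must be in range [0, C)")

    # Extended label sequence: a blank before, between and after the labels.
    ext = [blank]
    for label in target:
        ext += [label, blank]
    width = len(ext)

    log_alpha = [[NEG_INF] * width for _ in range(num_steps)]
    log_alpha[0][0] = log_probs[0][blank]
    if width > 1:
        log_alpha[0][1] = log_probs[0][ext[1]]

    for t in range(1, num_steps):
        for s in range(width):
            # Stay on the same symbol, or advance by one position.
            candidates = [log_alpha[t - 1][s]]
            if s >= 1:
                candidates.append(log_alpha[t - 1][s - 1])
            # Skipping a blank is allowed only between two different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(log_alpha[t - 1][s - 2])
            log_alpha[t][s] = _logsumexp(candidates) + log_probs[t][ext[s]]

    # A valid path ends on the last label or on the trailing blank.
    tail = [log_alpha[-1][width - 1]]
    if width > 1:
        tail.append(log_alpha[-1][width - 2])
    nll = -_logsumexp(tail)

    if zero_infinity and math.isinf(nll):
        nll = 0.0  # mirrors the documented zero_infinity option for the loss
    return nll, log_alpha
```

As a sanity check, with two time steps, two classes with uniform probability 0.5 and the target `[1]`, three alignments (`1,0`, `0,1`, `1,1`) collapse to the target, so the likelihood is 0.75 and the loss is `-log(0.75)`. A target longer than the input has no valid alignment, which gives an infinite loss unless `zero_infinity=True`.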