!44537 code_docs_ctcloss

Merge pull request !44537 from yide12/code_docs_ctcloss
This commit is contained in:
i-robot 2022-10-26 03:09:42 +00:00 committed by Gitee
commit 897b7e28df
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 14 additions and 14 deletions

View File

@ -10,23 +10,23 @@ mindspore.nn.CTCLoss
参数:
- **blank** (int) - 空白标签。默认值0。
- **reduction** (str) - 指定输出结果的计算方式。可选值为"none"、"mean"或"sum"。默认值:"mean"。
- **zero_infinity** (bool) - 是否设置无限损失和相关梯度为零。默认值:"False"
- **zero_infinity** (bool) - 是否设置无限损失和相关梯度为零。默认值False。
输入:
- **log_probs** (Tensor) - 输入Tensorshape :math:`(T, N, C)` 。其中T表示输入长度N表示批次大小C是分类数。
- **target** (Tensor) - 目标Tensorshape :math:`(N, S)` 。其中S表示最大目标长度。
- **input_lengths** (Union(Tuple, Tensor)) - shape为N的Tensor或tuple。表示输入长度。
- **target_lengths** (Union(Tuple, Tensor)) - shape为N的Tensor或tuple。表示目标长度。
- **log_probs** (Tensor) - 输入Tensorshape :math:`(T, N, C)` :math:`(T, C)` 。其中T表示输入长度N表示批次大小C是分类数。TNC均为正整数。
- **target** (Tensor) - 目标Tensorshape :math:`(N, S)` 或 (sum( `target_lengths` ))。其中S表示最大目标长度。
- **input_lengths** (Union[tuple, Tensor, int]) - shape为N的Tensor或tuple或者是一个正整数。表示输入长度。
- **target_lengths** (Union[tuple, Tensor, int]) - shape为N的Tensor或tuple或者是一个正整数。表示目标长度。
输出:
- **neg_log_likelihood** (Tensor) - 对每一个输入节点可微调的损失值。
异常:
- **TypeError** - `zero_infinity` 不是布尔值, `reduction` 不是字符串。
- **TypeError** - `log_probs` 的数据类型不是float或bouble。
- **TypeError** - `log_probs` 的数据类型不是float或double。
- **TypeError** - `targets` `input_lengths``target_lengths` 数据类型不是int32或int64。
- **ValueError** - `reduction` 不为"none""mean"或"sum"。
- **ValueError** - `targets` `input_lengths``target_lengths` 的数据类型是不同的。
- **ValueError** - `blank` 值不介于0到C之间。
- **ValueError** - `input_lengths` 的值大于C。
- **ValueError** - `target_lengths[i]` 不在值不介于0到 `input_length[i]` 之间。
- **ValueError** - `blank` 值不介于0到C之间。C是 `log_probs` 的分类数。
- **ValueError** - `input_lengths` 的值大于C。C是 `log_probs` 的分类数。
- **ValueError** - `target_lengths[i]` 值不介于0到 `input_length[i]` 之间。

View File

@ -2258,12 +2258,12 @@ class CTCLoss(LossBase):
Inputs:
- **log_probs** (Tensor) - A tensor of shape (T, N, C) or (T, C), where T is input length, N is batch size and
C is number of classes (including blank).
C is number of classes (including blank). T, N and C are positive integers.
- **targets** (Tensor) - A tensor of shape (N, S) or (sum( `target_lengths` )), where S is max target length,
means the target sequences.
- **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape(N), or a number.
- **input_lengths** (Union[tuple, Tensor, int]) - A tuple or Tensor of shape(N), or a number.
It means the lengths of the input.
- **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape(N), or a number.
- **target_lengths** (Union[tuple, Tensor, int]) - A tuple or Tensor of shape(N), or a number.
It means the lengths of the target.
Outputs:
@ -2275,8 +2275,8 @@ class CTCLoss(LossBase):
TypeError: If the dtype of `targets`, `input_lengths` or `target_lengths` is not int32 or int64.
ValueError: If `reduction` is not "none", "mean" or "sum".
ValueError: If the types of `targets`, `input_lengths` or `target_lengths` are different.
ValueError: If the value of `blank` is not in range [0, C).
ValueError: If any value of `input_lengths` is larger than C.
ValueError: If the value of `blank` is not in range [0, C). C is number of classes of `log_probs` .
ValueError: If any value of `input_lengths` is larger than C. C is number of classes of `log_probs` .
ValueError: If any target_lengths[i] is not in range [0, input_length[i]].
Supported Platforms: