!44537 code_docs_ctcloss
Merge pull request !44537 from yide12/code_docs_ctcloss
This commit is contained in:
commit
897b7e28df
|
@ -10,23 +10,23 @@ mindspore.nn.CTCLoss
|
|||
参数:
|
||||
- **blank** (int) - 空白标签。默认值:0。
|
||||
- **reduction** (str) - 指定输出结果的计算方式。可选值为"none"、"mean"或"sum"。默认值:"mean"。
|
||||
- **zero_infinity** (bool) - 是否设置无限损失和相关梯度为零。默认值:"False"。
|
||||
- **zero_infinity** (bool) - 是否设置无限损失和相关梯度为零。默认值:False。
|
||||
|
||||
输入:
|
||||
- **log_probs** (Tensor) - 输入Tensor,shape :math:`(T, N, C)` 。其中T表示输入长度,N表示批次大小,C是分类数。
|
||||
- **target** (Tensor) - 目标Tensor,shape :math:`(N, S)` 。其中S表示最大目标长度。
|
||||
- **input_lengths** (Union(Tuple, Tensor)) - shape为N的Tensor或tuple。表示输入长度。
|
||||
- **target_lengths** (Union(Tuple, Tensor)) - shape为N的Tensor或tuple。表示目标长度。
|
||||
- **log_probs** (Tensor) - 输入Tensor,shape :math:`(T, N, C)` 或 :math:`(T, C)` 。其中T表示输入长度,N表示批次大小,C是分类数。T,N,C均为正整数。
|
||||
- **target** (Tensor) - 目标Tensor,shape :math:`(N, S)` 或 (sum( `target_lengths` ))。其中S表示最大目标长度。
|
||||
- **input_lengths** (Union[tuple, Tensor, int]) - shape为N的Tensor或tuple或者是一个正整数。表示输入长度。
|
||||
- **target_lengths** (Union[tuple, Tensor, int]) - shape为N的Tensor或tuple或者是一个正整数。表示目标长度。
|
||||
|
||||
输出:
|
||||
- **neg_log_likelihood** (Tensor) - 对每一个输入节点可微调的损失值。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `zero_infinity` 不是布尔值, `reduction` 不是字符串。
|
||||
- **TypeError** - `log_probs` 的数据类型不是float或bouble。
|
||||
- **TypeError** - `log_probs` 的数据类型不是float或double。
|
||||
- **TypeError** - `targets` , `input_lengths` 或 `target_lengths` 数据类型不是int32或int64。
|
||||
- **ValueError** - `reduction` 不为"none","mean"或"sum"。
|
||||
- **ValueError** - `targets` , `input_lengths` 或 `target_lengths` 的数据类型是不同的。
|
||||
- **ValueError** - `blank` 值不介于0到C之间。
|
||||
- **ValueError** - `input_lengths` 的值大于C。
|
||||
- **ValueError** - `target_lengths[i]` 不在值不介于0到 `input_length[i]` 之间。
|
||||
- **ValueError** - `blank` 值不介于0到C之间。C是 `log_probs` 的分类数。
|
||||
- **ValueError** - `input_lengths` 的值大于C。C是 `log_probs` 的分类数。
|
||||
- **ValueError** - `target_lengths[i]` 的值不介于0到 `input_length[i]` 之间。
|
||||
|
|
|
@ -2258,12 +2258,12 @@ class CTCLoss(LossBase):
|
|||
|
||||
Inputs:
|
||||
- **log_probs** (Tensor) - A tensor of shape (T, N, C) or (T, C), where T is input length, N is batch size and
|
||||
C is number of classes (including blank).
|
||||
C is number of classes (including blank). T, N and C are positive integers.
|
||||
- **targets** (Tensor) - A tensor of shape (N, S) or (sum( `target_lengths` )), where S is max target length,
|
||||
means the target sequences.
|
||||
- **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape(N), or a number.
|
||||
- **input_lengths** (Union[tuple, Tensor, int]) - A tuple or Tensor of shape(N), or a number.
|
||||
It means the lengths of the input.
|
||||
- **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape(N), or a number.
|
||||
- **target_lengths** (Union[tuple, Tensor, int]) - A tuple or Tensor of shape(N), or a number.
|
||||
It means the lengths of the target.
|
||||
|
||||
Outputs:
|
||||
|
@ -2275,8 +2275,8 @@ class CTCLoss(LossBase):
|
|||
TypeError: If the dtype of `targets`, `input_lengths` or `target_lengths` is not int32 or int64.
|
||||
ValueError: If `reduction` is not "none", "mean" or "sum".
|
||||
ValueError: If the types of `targets`, `input_lengths` or `target_lengths` are different.
|
||||
ValueError: If the value of `blank` is not in range [0, C).
|
||||
ValueError: If any value of `input_lengths` is larger than C.
|
||||
ValueError: If the value of `blank` is not in range [0, C). C is number of classes of `log_probs` .
|
||||
ValueError: If any value of `input_lengths` is larger than C. C is number of classes of `log_probs` .
|
||||
ValueError: If any target_lengths[i] is not in range [0, input_length[i]].
|
||||
|
||||
Supported Platforms:
|
||||
|
|
Loading…
Reference in New Issue