forked from mindspore-Ecosystem/mindspore

add ctcloss

parent 890e5a124c
commit a9b4a4b0f9
@@ -234,6 +234,7 @@ Dropout Layer

    mindspore.nn.BCEWithLogitsLoss
    mindspore.nn.CosineEmbeddingLoss
    mindspore.nn.CrossEntropyLoss
+   mindspore.nn.CTCLoss
    mindspore.nn.DiceLoss
    mindspore.nn.FocalLoss
    mindspore.nn.GaussianNLLLoss

@@ -0,0 +1,32 @@
mindspore.nn.CTCLoss
====================

.. py:class:: mindspore.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)

    CTCLoss loss function.

    For details about the CTCLoss algorithm, refer to `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ .

    Args:
        - **blank** (int) - The blank label. Default: 0.
        - **reduction** (str) - Apply specific reduction method to the output: "none", "mean" or "sum". Default: "mean".
        - **zero_infinity** (bool) - Whether to set infinite losses and the correlated gradients to zero. Default: False.

    Inputs:
        - **log_probs** (Tensor) - Input Tensor of shape :math:`(T, N, C)`, where T is the input length, N is the batch size and C is the number of classes (including blank).
        - **targets** (Tensor) - Target Tensor of shape :math:`(N, S)`, where S is the maximum target length.
        - **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape :math:`(N)`. The lengths of the inputs.
        - **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape :math:`(N)`. The lengths of the targets.

    Outputs:
        - **neg_log_likelihood** (Tensor) - A loss value which is differentiable with respect to each input node.

    Raises:
        - **TypeError** - `zero_infinity` is not a bool, or `reduction` is not a str.
        - **TypeError** - The dtype of `log_probs` is not float or double.
        - **TypeError** - The dtype of `targets`, `input_lengths` or `target_lengths` is not int32 or int64.
        - **ValueError** - `reduction` is not "none", "mean" or "sum".
        - **ValueError** - The dtypes of `targets`, `input_lengths` and `target_lengths` are different.
        - **ValueError** - The value of `blank` is not in range [0, C).
        - **ValueError** - Any value of `input_lengths` is larger than T.
        - **ValueError** - Any `target_lengths[i]` is not in range [0, `input_lengths[i]`].
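For orientation, a minimal usage sketch of the interface documented above, mirroring the Examples block added to the Python docstring below (shapes follow the documented (T, N, C) layout; values are illustrative only):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore import dtype as mstype

    T, C, N, S = 5, 2, 2, 3   # input length, classes, batch size, max target length
    log_probs = Tensor(np.random.randn(T, N, C), dtype=mstype.float32)
    targets = Tensor(np.random.randint(1, C, size=(N, S)), dtype=mstype.int32)
    input_lengths = Tensor(np.full((N,), T), dtype=mstype.int32)
    target_lengths = Tensor(np.full((N,), 2), dtype=mstype.int32)

    ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)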

@@ -234,6 +234,7 @@ Loss Function

    mindspore.nn.BCEWithLogitsLoss
    mindspore.nn.CosineEmbeddingLoss
    mindspore.nn.CrossEntropyLoss
+   mindspore.nn.CTCLoss
    mindspore.nn.DiceLoss
    mindspore.nn.FocalLoss
    mindspore.nn.GaussianNLLLoss

@@ -20,12 +20,13 @@ It shows how well the model works on a dataset and the optimization target which
 """
 from __future__ import absolute_import

-from mindspore.nn.loss.loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss, \
+from mindspore.nn.loss.loss import LossBase, L1Loss, CTCLoss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss, \
     SoftmaxCrossEntropyWithLogits, BCELoss, MultiMarginLoss, CosineEmbeddingLoss, \
     SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss, MultilabelMarginLoss, \
     RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss, MarginRankingLoss, GaussianNLLLoss

-__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
+__all__ = ['LossBase', 'L1Loss', 'CTCLoss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss', 'MultiMarginLoss',
            'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss', 'MultilabelMarginLoss',
            'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss', 'MarginRankingLoss',
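With the import and `__all__` entries above, `CTCLoss` is re-exported from the package root; a quick sanity check (assuming a build that contains this commit):

    import mindspore.nn as nn
    from mindspore.nn.loss import CTCLoss

    assert nn.CTCLoss is CTCLoss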

@@ -25,6 +25,7 @@ from mindspore.ops import operations as P
 from mindspore.ops.operations.nn_ops import MultiMarginLoss as MultiMarginLossOp
 from mindspore.ops.operations.nn_ops import MultilabelMarginLoss as MultilabelMarginLossOp
 from mindspore.ops.operations.nn_ops import TripletMarginLoss as TripletMarginLossOp
+from mindspore.ops.operations.nn_ops import CTCLossV2
 from mindspore.ops import functional as F
 from mindspore import nn
 from mindspore.ops.primitive import constexpr


@@ -2242,6 +2243,99 @@ class KLDivLoss(LossBase):
        return F.kl_div(logits, labels, self.reduction)

class CTCLoss(LossBase):
    """
    Calculates the CTC (Connectionist Temporal Classification) loss.

    For the CTC algorithm, refer to `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with
    Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ .

    Args:
        blank (int): The blank label. Default: 0.
        reduction (str): Apply specific reduction method to the output: 'none', 'mean', or 'sum'. Default: 'mean'.
        zero_infinity (bool): Whether to set infinite losses and the correlated gradients to zero. Default: False.

    Inputs:
        - **log_probs** (Tensor) - A tensor of shape (T, N, C) or (T, C), where T is the input length, N is the
          batch size and C is the number of classes (including blank).
        - **targets** (Tensor) - A tensor of shape (N, S) or (sum(`target_lengths`),), containing the target
          sequences, where S is the max target length.
        - **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape (N), or an int. The lengths of
          the inputs.
        - **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape (N), or an int. The lengths of
          the targets.

    Outputs:
        - **neg_log_likelihood** (Tensor) - A loss value which is differentiable with respect to each input node.

    Raises:
        TypeError: If `zero_infinity` is not a bool or `reduction` is not a str.
        TypeError: If the dtype of `log_probs` is not float or double.
        TypeError: If the dtype of `targets`, `input_lengths` or `target_lengths` is not int32 or int64.
        ValueError: If `reduction` is not "none", "mean" or "sum".
        ValueError: If the dtypes of `targets`, `input_lengths` and `target_lengths` are different.
        ValueError: If the value of `blank` is not in range [0, C).
        ValueError: If any value of `input_lengths` is larger than T.
        ValueError: If any `target_lengths[i]` is not in range [0, `input_lengths[i]`].

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import dtype as mstype
        >>> from mindspore.nn.loss import CTCLoss
        >>> T = 5      # Input sequence length
        >>> C = 2      # Number of classes
        >>> N = 2      # Batch size
        >>> S = 3      # Target sequence length of longest target in batch (padding length)
        >>> S_min = 2  # Minimum target length, for demonstration purposes
        >>> arr = np.arange(T*N*C).reshape((T, N, C))
        >>> ms_input = Tensor(arr, dtype=mstype.float32)
        >>> input_lengths = np.full(shape=(N,), fill_value=T)
        >>> input_lengths = Tensor(input_lengths, dtype=mstype.int32)
        >>> target_lengths = np.full(shape=(N,), fill_value=S_min)
        >>> target_lengths = Tensor(target_lengths, dtype=mstype.int32)
        >>> target = np.random.randint(1, C, size=(N, S))
        >>> target = Tensor(target, dtype=mstype.int32)
        >>> ctc_loss = CTCLoss(blank=0, reduction='none', zero_infinity=False)
        >>> loss = ctc_loss(ms_input, target, input_lengths, target_lengths)
        >>> print(loss)
        Tensor(shape=[2], dtype=Float32, value= [-4.57949715e+001, -5.57949677e+001])
        >>> arr = np.arange(T*C).reshape((T, C))
        >>> ms_input = Tensor(arr, dtype=mstype.float32)
        >>> input_lengths = T
        >>> target_lengths = S_min
        >>> target = np.random.randint(1, C, size=(S_min,))
        >>> target = Tensor(target, dtype=mstype.int32)
        >>> ctc_loss = CTCLoss(blank=0, reduction='none', zero_infinity=False)
        >>> loss = ctc_loss(ms_input, target, input_lengths, target_lengths)
        >>> print(loss)
        Tensor(shape=[1], dtype=Float32, value= [-2.57949677e+001])
    """

    def __init__(self, blank=0, reduction='mean', zero_infinity=False):
        super().__init__(reduction)
        # Run the underlying op with reduction='none'; the configured reduction
        # is applied afterwards by LossBase.get_loss().
        self.ctcloss = CTCLossV2(blank=blank, reduction='none', zero_infinity=zero_infinity)

    def construct(self, log_probs, targets, input_lengths, target_lengths):
        if len(log_probs.shape) == 2:
            # Promote an unbatched (T, C) input to a batch of one: (T, 1, C).
            t, c = log_probs.shape
            log_probs = log_probs.reshape((t, 1, c))
            targets = targets.reshape((1, targets.shape[0]))
            if isinstance(input_lengths, int):
                input_lengths = Tensor([input_lengths], mstype.int32)
            else:
                raise ValueError("For a 2-D (T, C) input, input_lengths should be an int.")
            if isinstance(target_lengths, int):
                target_lengths = Tensor([target_lengths], mstype.int32)
            else:
                raise ValueError("For a 2-D (T, C) input, target_lengths should be an int.")
        neg_log_likelihood, _ = self.ctcloss(log_probs, targets, input_lengths, target_lengths)
        return self.get_loss(neg_log_likelihood)
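The 2-D branch above promotes an unbatched (T, C) call to a batch of one before dispatching to CTCLossV2; a standalone NumPy sketch of that shape promotion (shapes only, not MindSpore code):

    import numpy as np

    T, C, S_min = 5, 2, 2
    log_probs = np.zeros((T, C))                      # unbatched (T, C) input
    log_probs = log_probs.reshape((T, 1, C))          # -> (T, N=1, C)
    targets = np.array([1, 1])                        # 1-D targets of length S_min
    targets = targets.reshape((1, targets.shape[0]))  # -> (N=1, S_min)
    input_lengths = [T]                               # python ints become length-1 vectors
    target_lengths = [S_min]

Note that the configured reduction is applied by LossBase.get_loss over the per-sample loss vector, so 'mean' is a plain average of the per-sample losses rather than a target-length-normalized mean; the test expectations below are consistent with that.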


class GaussianNLLLoss(LossBase):
    r"""Gaussian negative log likelihood loss.


@@ -0,0 +1,106 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype


class CTCLossNet(nn.Cell):
    def __init__(self, reduction="none"):
        super(CTCLossNet, self).__init__()
        self.ctcloss = nn.CTCLoss(blank=0, reduction=reduction, zero_infinity=False)

    def construct(self, log_probs, target, input_length, target_length):
        return self.ctcloss(log_probs, target, input_length, target_length)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('reduct', ["none", "mean", "sum"])
def test_ctc_loss_tnc(mode, reduct):
    """
    Feature: test CTCLoss op with input shape (T, N, C).
    Description: Verify the result of CTCLoss.
    Expectation: expect correct forward result.
    """
    ms.set_context(mode=mode)
    loss = CTCLossNet(reduction=reduct)

    t = 10     # Input sequence length
    c = 4      # Number of classes
    n = 2      # Batch size
    s = 5      # Target sequence length of longest target in batch (padding length)
    s_min = 3  # Minimum target length, for demonstration purposes
    arr = np.arange(t * n * c).reshape((t, n, c))
    inputs = Tensor(arr, dtype=mstype.float32)
    input_lengths = np.full(shape=(n,), fill_value=t)
    input_lengths = Tensor(input_lengths, dtype=mstype.int32)
    target_lengths = np.full(shape=(n,), fill_value=s_min)
    target_lengths = Tensor(target_lengths, dtype=mstype.int32)
    arr = np.arange(n * s).reshape((n, s))
    targets = Tensor(arr, dtype=mstype.int32)

    output = loss(inputs, targets, input_lengths, target_lengths)

    if reduct == "none":
        expect_output = np.array([-3.78184143e+002, -4.60606476e+002])
    elif reduct == "mean":
        expect_output = np.array([-419.395])
    else:
        expect_output = np.array([-838.791])
    assert np.allclose(output.asnumpy(), expect_output)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('reduct', ["none", "mean", "sum"])
def test_ctc_loss_tc(mode, reduct):
    """
    Feature: test CTCLoss op with input shape (T, C).
    Description: Verify the result of CTCLoss.
    Expectation: expect correct forward result.
    """
    ms.set_context(mode=mode)
    loss = CTCLossNet(reduction=reduct)

    t = 10     # Input sequence length
    c = 4      # Number of classes
    s_min = 3  # Minimum target length, for demonstration purposes
    arr = np.arange(t * c).reshape((t, c))
    inputs = Tensor(arr, dtype=mstype.float32)
    input_lengths = t
    target_lengths = s_min
    arr = np.arange(s_min).reshape((s_min,))
    targets = Tensor(arr, dtype=mstype.int32)

    output = loss(inputs, targets, input_lengths, target_lengths)

    expect_output = np.array([-1.98184158e+002])
    assert np.allclose(output.asnumpy(), expect_output)
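Note that the expected "mean" and "sum" values in test_ctc_loss_tnc are plain aggregations of the "none" outputs, consistent with the reduction being applied by LossBase.get_loss (a quick standalone check):

    import numpy as np

    none_out = np.array([-3.78184143e+002, -4.60606476e+002])
    assert np.isclose(none_out.mean(), -419.395, atol=1e-3)  # reduction='mean'
    assert np.isclose(none_out.sum(), -838.791, atol=1e-3)   # reduction='sum'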

@@ -308,6 +308,29 @@ def test_margin_ranking_loss():
    loss(input1, input2, target)


def test_ctc_loss():
    """
    Feature: Test CTCLoss.
    Description: Test CTCLoss functional.
    Expectation: Success.
    """
    t = 10     # Input sequence length
    c = 4      # Number of classes
    n = 2      # Batch size
    s = 5      # Target sequence length of longest target in batch
    s_min = 3  # Minimum target length, for demonstration purposes
    arr = np.random.randn(t * n * c).reshape((t, n, c))
    inputs = Tensor(arr, dtype=mstype.float32)
    input_lengths = np.full(shape=n, fill_value=t)
    input_lengths = Tensor(input_lengths, dtype=mstype.int32)
    target_lengths = np.full(shape=n, fill_value=s_min)
    target_lengths = Tensor(target_lengths, dtype=mstype.int32)
    target = np.random.randint(1, c, size=(n, s))
    target = Tensor(target, dtype=mstype.int32)
    ctc_loss = nn.CTCLoss(blank=0, reduction='none', zero_infinity=False)
    ctc_loss(inputs, target, input_lengths, target_lengths)


def test_gaussian_nll_loss():
    """
    Feature: Test GaussianNLLLoss.