mindspore/docs/api/api_python/nn/mindspore.nn.FocalLoss.rst

38 lines
2.0 KiB
ReStructuredText
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

mindspore.nn.FocalLoss
=======================
.. py:class:: mindspore.nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')
FocalLoss函数。
解决了类别不平衡的问题。
FocalLoss函数在论文 `Focal Loss for Dense Object Detection <https://arxiv.org/pdf/1708.02002.pdf>`_ 中提出,提高了图像目标检测的效果。
函数如下:
.. math::
FL(p_t) = -(1-p_t)^\gamma log(p_t)
**参数:**
- **gamma** (float) - gamma用于调整Focal Loss的权重曲线的陡峭程度。默认值2.0。
- **weight** (Union[Tensor, None]) - Focal Loss的权重维度为1。如果为None则不使用权重。默认值None。
- **reduction** (str) - loss的计算方式。取值为"mean""sum",或"none"。默认值:"mean"。
**输入:**
- **logits** (Tensor) - shape为 :math:`(N, C)`:math:`(N, C, H)` 、或 :math:`(N, C, H, W)` 的Tensor其中 :math:`C` 是分类的数量值大于1。如果shape为 :math:`(N, C, H, W)`:math:`(N, C, H)` ,则 :math:`H`:math:`H`:math:`W` 的乘积应与 `labels` 的相同。
- **labels** (Tensor) - shape为 :math:`(N, C)`:math:`(N, C, H)` 、或 :math:`(N, C, H, W)` 的Tensor :math:`C` 的值为1或者与 `logits`:math:`C` 相同。如果 :math:`C` 不为1则shape应与 `logits` 的shape相同其中 :math:`C` 是分类的数量。如果shape为 :math:`(N, C, H, W)`:math:`(N, C, H)` ,则 :math:`H`:math:`H`:math:`W` 的乘积应与 `logits` 相同。
**输出:**
Tensor或Scalar如果 `reduction` 为"none"其shape与 `logits` 相同。否则将返回Scalar。
**异常:**
- **TypeError** - `gamma` 的数据类型不是float。
- **TypeError** - `weight` 不是Tensor。
- **ValueError** - `labels` 维度与 `logits` 不同。
- **ValueError** - `labels` 通道不为1`labels` 的shape与 `logits` 不同。
- **ValueError** - `reduction` 不为"mean""sum",或"none"。