forked from mindspore-Ecosystem/mindspore
!30628 Add CPU support for BCEWithLogitsLoss
Merge pull request !30628 from 吕昱峰(Nate.River)/bceloss
This commit is contained in:
commit a7bea38992
@@ -1251,7 +1251,7 @@ class BCEWithLogitsLoss(LossBase):
         ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
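With ``CPU`` added to the supported platforms, the loss is meant to be constructed and called the same way once the context targets CPU. A minimal usage sketch, assuming a MindSpore build with CPU kernels for the underlying ops; the `labels` values below are illustrative and not taken from this diff:

import numpy as np
from mindspore import Tensor, context, nn

# Target the newly supported CPU backend.
context.set_context(device_target="CPU")

loss = nn.BCEWithLogitsLoss(reduction='mean')
logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
labels = Tensor(np.array([[0, 1, 0], [0, 0, 1]]).astype(np.float32))  # illustrative values
output = loss(logits, labels)
print(output)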
@@ -1265,6 +1265,7 @@ class BCEWithLogitsLoss(LossBase):
     def __init__(self, reduction='mean', weight=None, pos_weight=None):
         """Initialize BCEWithLogitsLoss."""
         super(BCEWithLogitsLoss, self).__init__()
+        self.reduction = reduction
         self.bce_with_logits_loss = P.BCEWithLogitsLoss(reduction=reduction)
         if isinstance(weight, Parameter):
             raise TypeError(f"For '{self.cls_name}', the 'weight' can not be a Parameter.")
@@ -1273,10 +1274,42 @@ class BCEWithLogitsLoss(LossBase):
         self.weight = weight
         self.pos_weight = pos_weight
         self.ones = P.OnesLike()
+        self.is_cpu = context.get_context("device_target") == "CPU"
 
+    def _construct_cpu(self, logits, labels):
+        """Use native implementation for CPU."""
+        max_val = F.maximum(-logits, 0)
+
+        if self.pos_weight is not None:
+            log_weight = ((self.pos_weight - 1) * labels) + 1
+            loss = (1 - labels) * logits
+            loss_1 = F.log(F.exp(F.neg_tensor(max_val)) + F.exp(F.neg_tensor(logits) - max_val)) + max_val
+            loss += log_weight * loss_1
+        else:
+            loss = (1 - labels) * logits
+            loss += max_val
+            loss += F.log(F.exp(F.neg_tensor(max_val)) + F.exp(F.neg_tensor(logits) - max_val))
+
+        if self.weight is not None:
+            output = loss * self.weight
+        else:
+            output = loss
+
+        if self.reduction == "mean":
+            return F.reduce_mean(output)
+        if self.reduction == "sum":
+            return F.reduce_sum(output)
+        return output
+
     def construct(self, logits, labels):
         _check_is_tensor('logits', logits, self.cls_name)
         _check_is_tensor('labels', labels, self.cls_name)
+        if self.is_cpu:
+            return self._construct_cpu(logits, labels)
+        return self._construct_gpu_ascend(logits, labels)
+
+    def _construct_gpu_ascend(self, logits, labels):
+        """Use P.BCEWithLogitsLoss for Ascend and GPU."""
         ones_input = self.ones(logits)
         if self.weight is not None:
             weight = self.weight
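The `_construct_cpu` path above evaluates the log-sum-exp form of sigmoid cross entropy, (1 - y) * x + max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))), instead of computing sigmoid and log separately, so large-magnitude logits do not overflow. A standalone NumPy sketch checking that this form, including the `pos_weight` branch, agrees with the textbook definition; the input values are assumptions for illustration, not MindSpore code:

import numpy as np

def stable_bce_with_logits(x, y, pos_weight=None):
    # Same algebra as the diff's CPU branch: log-sum-exp form, no sigmoid call.
    max_val = np.maximum(-x, 0)
    log_term = np.log(np.exp(-max_val) + np.exp(-x - max_val)) + max_val
    if pos_weight is not None:
        return (1 - y) * x + ((pos_weight - 1) * y + 1) * log_term
    return (1 - y) * x + log_term

def naive_bce_with_logits(x, y, pos_weight=None):
    # Textbook definition via sigmoid; overflows once |x| gets large.
    p = 1.0 / (1.0 + np.exp(-x))
    w = pos_weight if pos_weight is not None else 1.0
    return -(w * y * np.log(p) + (1 - y) * np.log(1 - p))

x = np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]])
y = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
pw = np.array([1.0, 2.0, 0.5])
assert np.allclose(stable_bce_with_logits(x, y), naive_bce_with_logits(x, y))
assert np.allclose(stable_bce_with_logits(x, y, pw), naive_bce_with_logits(x, y, pw))

For the 'mean' and 'sum' reductions the diff applies F.reduce_mean or F.reduce_sum to this per-element result, which corresponds to np.mean / np.sum in the sketch.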