!30628 Add CPU support for BCEWithLogitsLoss

Merge pull request !30628 from 吕昱峰(Nate.River)/bceloss
i-robot 2022-02-28 01:52:13 +00:00 committed by Gitee
commit a7bea38992
1 changed file with 34 additions and 1 deletion


@@ -1251,7 +1251,7 @@ class BCEWithLogitsLoss(LossBase):
        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.

    Supported Platforms:
        ``Ascend`` ``GPU``
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
@@ -1265,6 +1265,7 @@ class BCEWithLogitsLoss(LossBase):
    def __init__(self, reduction='mean', weight=None, pos_weight=None):
        """Initialize BCEWithLogitsLoss."""
        super(BCEWithLogitsLoss, self).__init__()
        self.reduction = reduction
        self.bce_with_logits_loss = P.BCEWithLogitsLoss(reduction=reduction)
        if isinstance(weight, Parameter):
            raise TypeError(f"For '{self.cls_name}', the 'weight' can not be a Parameter.")
@@ -1273,10 +1274,42 @@ class BCEWithLogitsLoss(LossBase):
        self.weight = weight
        self.pos_weight = pos_weight
        self.ones = P.OnesLike()
        self.is_cpu = context.get_context("device_target") == "CPU"
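Note (reviewer sketch, not part of the diff): the branch is chosen from the device target read at construction time, so the new CPU path only runs if the context is set to CPU before the cell is built. A minimal usage sketch, assuming a standard MindSpore install; the input values are illustrative:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

# Hypothetical usage, not from the patch: select the CPU backend before
# constructing the cell so that is_cpu is True and the native
# _construct_cpu path handles the forward pass.
context.set_context(device_target="CPU")
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
labels = Tensor(np.array([[0.0, 1.0, 1.0], [0.0, 0.0, 1.0]]).astype(np.float32))
print(loss_fn(logits, labels))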
    def _construct_cpu(self, logits, labels):
        """Use native implementation for CPU."""
        max_val = F.maximum(-logits, 0)
        if self.pos_weight is not None:
            log_weight = ((self.pos_weight - 1) * labels) + 1
            loss = (1 - labels) * logits
            loss_1 = F.log(F.exp(F.neg_tensor(max_val)) + F.exp(F.neg_tensor(logits) - max_val)) + max_val
            loss += log_weight * loss_1
        else:
            loss = (1 - labels) * logits
            loss += max_val
            loss += F.log(F.exp(F.neg_tensor(max_val)) + F.exp(F.neg_tensor(logits) - max_val))
        if self.weight is not None:
            output = loss * self.weight
        else:
            output = loss
        if self.reduction == "mean":
            return F.reduce_mean(output)
        if self.reduction == "sum":
            return F.reduce_sum(output)
        return output
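Note (reviewer sketch, not part of the diff): the method above is the standard numerically stable rewrite of binary cross-entropy with logits, loss = (1 - y) * x + max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))), which equals -(y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))) without evaluating sigmoid directly. A small NumPy check of that identity (helper names are mine, not from the patch):

import numpy as np

def stable_bce_with_logits(x, y):
    # Mirrors the no-pos_weight branch of _construct_cpu above.
    m = np.maximum(-x, 0)
    return (1 - y) * x + m + np.log(np.exp(-m) + np.exp(-x - m))

def naive_bce_with_logits(x, y):
    # Direct definition via sigmoid; fine for moderate |x|, but overflows for large logits.
    p = 1.0 / (1.0 + np.exp(-x))
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

x = np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]])
y = np.array([[0.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
assert np.allclose(stable_bce_with_logits(x, y), naive_bce_with_logits(x, y))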
    def construct(self, logits, labels):
        _check_is_tensor('logits', logits, self.cls_name)
        _check_is_tensor('labels', labels, self.cls_name)
        if self.is_cpu:
            return self._construct_cpu(logits, labels)
        return self._construct_gpu_ascend(logits, labels)
    def _construct_gpu_ascend(self, logits, labels):
        """Use P.BCEWithLogitsLoss for Ascend and GPU."""
        ones_input = self.ones(logits)
        if self.weight is not None:
            weight = self.weight