forked from mindspore-Ecosystem/mindspore
Add BCEWithLogitsLoss
This commit is contained in:
parent
4ca4a148eb
commit
fabc25538e
|
@ -32,6 +32,7 @@
|
|||
#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h"
|
||||
#include "backend/optimizer/pass/communication_op_fusion.h"
|
||||
#include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h"
|
||||
#include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
|
||||
|
@ -191,6 +192,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<UnsortSegmentSumFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GatherV2DsFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
|
||||
}
|
||||
} // namespace
|
||||
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
|
@ -333,6 +335,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
|
||||
|
||||
optimizer->AddPassManager(ir_fusion_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/trace_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Copy a new sigmoid node, shape of output is the same as input
|
||||
std::vector<AnfNodePtr> new_simoid_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(prim::kPrimBCEWithLogitsLoss->name()))};
|
||||
new_simoid_inputs.insert(new_simoid_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
CNodePtr new_cnode = func_graph->NewCNode(new_simoid_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
auto predict_input = cnode->inputs()[1];
|
||||
auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)};
|
||||
auto new_node_shape = {AnfAlgo::GetOutputInferShape(predict_input, 0)};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(new_node_dtype, new_node_shape, new_cnode.get());
|
||||
|
||||
// Add reduce node
|
||||
string reduction = AnfAlgo::GetNodeAttr<std::string>(node, kAttrReduction);
|
||||
MS_LOG(INFO) << "Create reduce node, reduction attr is: " << reduction;
|
||||
std::vector<AnfNodePtr> reduce_inputs;
|
||||
if (reduction == "sum") {
|
||||
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode};
|
||||
} else if (reduction == "mean") {
|
||||
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode};
|
||||
} else {
|
||||
MS_LOG(INFO) << "Reduction attr is not mean or sum, can not do fission.";
|
||||
return nullptr;
|
||||
}
|
||||
auto reduce_node = func_graph->NewCNode(reduce_inputs);
|
||||
MS_EXCEPTION_IF_NULL(reduce_node);
|
||||
auto type = AnfAlgo::GetOutputInferDataType(node, 0);
|
||||
if (type == kNumberTypeFloat16) {
|
||||
type = kNumberTypeFloat32;
|
||||
}
|
||||
auto shape = {AnfAlgo::GetOutputInferShape(node, 0)};
|
||||
AnfAlgo::SetOutputInferTypeAndShape({type}, shape, reduce_node.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{}), reduce_node);
|
||||
AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node);
|
||||
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_node);
|
||||
reduce_node->set_scope(cnode->scope());
|
||||
return reduce_node;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef BCEWithLogitsLossFission::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
MS_EXCEPTION_IF_NULL(Xs);
|
||||
return VectorRef({prim::kPrimBCEWithLogitsLoss, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr BCEWithLogitsLossFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (GetBoolAttr(cnode, kAttrVisited)) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
if (cnode->inputs().size() == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!AnfAlgo::HasNodeAttr("reduction", cnode)) {
|
||||
MS_LOG(INFO) << "Has no reduction attr.";
|
||||
return nullptr;
|
||||
}
|
||||
return AddReduceNode(func_graph, node);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_BCE_WITH_LOGITS_LOSS_FISSION_FISSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_BCE_WITH_LOGITS_LOSS_FISSION_FISSION_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BCEWithLogitsLossFission : public PatternProcessPass {
|
||||
public:
|
||||
explicit BCEWithLogitsLossFission(bool multigraph = true)
|
||||
: PatternProcessPass("bce_with_logits_loss_fission", multigraph) {}
|
||||
~BCEWithLogitsLossFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_BCE_WITH_LOGITS_LOSS_FISSION_FISSION_H_
|
|
@ -344,6 +344,7 @@ constexpr auto kAttrWaitEventStream = "wait_event_stream";
|
|||
constexpr auto kAttrIndex = "index";
|
||||
constexpr auto kAttrSplitDim = "split_dim";
|
||||
constexpr auto kAttrNumSplit = "num_split";
|
||||
constexpr auto kAttrReduction = "reduction";
|
||||
constexpr auto kAttrOutputNum = "output_num";
|
||||
constexpr auto kAttrSizeSplits = "size_splits";
|
||||
constexpr auto kAttrOutputDefault = "output_default";
|
||||
|
|
|
@ -282,6 +282,7 @@ inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Pri
|
|||
inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAdam");
|
||||
inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
|
||||
inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
|
||||
inline const PrimitivePtr kPrimBCEWithLogitsLoss = std::make_shared<Primitive>("BCEWithLogitsLoss");
|
||||
inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum");
|
||||
inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove");
|
||||
inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize");
|
||||
|
|
|
@ -21,8 +21,8 @@ It shows how well the model works on a dataset and the optimization target which
|
|||
|
||||
from .loss import L1Loss, MSELoss, SmoothL1Loss, \
|
||||
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
|
||||
SampledSoftmaxLoss, DiceLoss
|
||||
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss
|
||||
|
||||
__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
|
||||
'SoftmaxCrossEntropyWithLogits', 'BCELoss',
|
||||
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
|
||||
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss']
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""loss"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
@ -739,3 +740,86 @@ class CosineEmbeddingLoss(_Loss):
|
|||
output_unreduced = pos_part + neg_part
|
||||
|
||||
return self.get_loss(output_unreduced)
|
||||
|
||||
|
||||
class BCEWithLogitsLoss(_Loss):
|
||||
r"""
|
||||
Adds sigmoid activation function to input `predict`, and uses the given logits to compute binary cross entropy
|
||||
between the target and the output.
|
||||
|
||||
Sets input predict as `X`, input target as `Y`, output as `L`. Then,
|
||||
|
||||
.. math::
|
||||
p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}
|
||||
|
||||
.. math::
|
||||
L_{ij} = -[Y_{ij} * ln(p_{ij}) + (1 - Y_{ij})ln(1 - p_{ij})]
|
||||
|
||||
Then,
|
||||
|
||||
.. math::
|
||||
\ell(x, y) = \begin{cases}
|
||||
L, & \text{if reduction} = \text{`none';}\\
|
||||
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
|
||||
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
|
||||
If "none", do not perform reduction. Default:`mean`.
|
||||
weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
|
||||
If not None, it must can be broadcast to a tensor with shape of `predict`,
|
||||
data type must be float16 or float32. Default: None.
|
||||
pos_weight (Tensor, optional): A weight of positive examples. Must be a vector with length equal to the
|
||||
number of classes. If not None, it must can be broadcast to a tensor with shape of `predict`,
|
||||
data type must be float16 or float32. Default: None.
|
||||
|
||||
Inputs:
|
||||
- **predict** (Tensor) - Input logits. The data type must be float16 or float32.
|
||||
- **target** (Tensor) - Ground truth label. Has the same data type and shape with `predict`.
|
||||
|
||||
Outputs:
|
||||
Scalar. If reduction is "none", it's a tensor with the same shape and type as input `predict`.
|
||||
|
||||
Raises:
|
||||
TypeError: If data type of `predict` or `target` is neither float16 nor float32.
|
||||
TypeError: If `weight` or `pos_weight` is Parameter.
|
||||
TypeError: If data type of `weight` or `pos_weight` is neither float16 nor float32.
|
||||
ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `predict`.
|
||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
|
||||
>>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
|
||||
>>> loss = nn.BCEWithLogitsLoss()
|
||||
>>> output = loss(inputs, labels)
|
||||
>>> print(output)
|
||||
0.3463612
|
||||
"""
|
||||
|
||||
def __init__(self, reduction='mean', weight=None, pos_weight=None):
|
||||
super(BCEWithLogitsLoss, self).__init__()
|
||||
self.bce_with_logits_loss = P.BCEWithLogitsLoss(reduction=reduction)
|
||||
if isinstance(weight, Parameter):
|
||||
raise TypeError(f"For {self.cls_name}, weight can not be Parameter.")
|
||||
if isinstance(pos_weight, Parameter):
|
||||
raise TypeError(f"For {self.cls_name}, pos_weight can not be Parameter.")
|
||||
self.weight = weight
|
||||
self.pos_weight = pos_weight
|
||||
self.ones = P.OnesLike()
|
||||
|
||||
def construct(self, predict, target):
|
||||
ones_input = self.ones(predict)
|
||||
if self.weight is not None:
|
||||
weight = self.weight
|
||||
else:
|
||||
weight = ones_input
|
||||
if self.pos_weight is not None:
|
||||
pos_weight = self.pos_weight
|
||||
else:
|
||||
pos_weight = ones_input
|
||||
loss = self.bce_with_logits_loss(predict, target, weight, pos_weight)
|
||||
return loss
|
||||
|
|
|
@ -1212,6 +1212,32 @@ def get_bprop_binary_cross_entropy(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.BCEWithLogitsLoss)
|
||||
def get_bprop_ce_with_logits_loss(self):
|
||||
"""Grad definition for `BCEWithLogitsLoss` operation."""
|
||||
reduction = self.reduction
|
||||
mul = P.Mul()
|
||||
sigmoid = P.Sigmoid()
|
||||
add = P.TensorAdd()
|
||||
sub = P.Sub()
|
||||
size = P.Size()
|
||||
|
||||
def bprop(predict, target, weight, pos_weight, out, dout):
|
||||
sigmoid_input = sigmoid(predict)
|
||||
if pos_weight is not None:
|
||||
t = mul(target, pos_weight)
|
||||
dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout)
|
||||
else:
|
||||
dx = mul((sigmoid_input - target), dout)
|
||||
if weight is not None:
|
||||
dx = mul(dx, weight)
|
||||
if reduction == 'mean':
|
||||
dx = dx / size(dx)
|
||||
return dx, zeros_like(target), zeros_like(weight), zeros_like(pos_weight)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.KLDivLoss)
|
||||
def get_bprop_kl_div_loss(self):
|
||||
"""Grad definition for `KLDivLoss` operation."""
|
||||
|
|
|
@ -254,6 +254,7 @@ from .prelu import _prelu_tbe
|
|||
from .prelu_grad import _prelu_grad_tbe
|
||||
from .binary_cross_entropy import _binary_cross_entropy_tbe
|
||||
from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe
|
||||
from .bce_with_logits_loss import _bce_with_logits_loss_op_tbe
|
||||
from .sin import _sin_tbe
|
||||
from .cos import _cos_tbe
|
||||
from .tan import _tan_tbe
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed unde:q!r the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""BCEWithLogitsLoss op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bce_with_logits_loss_op_info = TBERegOp("BCEWithLogitsLoss") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("sigmoid_cross_entropy_with_logits_v2.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("sigmoid_cross_entropy_with_logits_v2") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.attr("reduction", "optional", "str", "all", "mean") \
|
||||
.input(0, "predict", False, "required", "all") \
|
||||
.input(1, "target", False, "required", "all") \
|
||||
.input(2, "weight", False, "optional", "all") \
|
||||
.input(3, "pos_weight", False, "optional", "all") \
|
||||
.output(0, "loss", False, "required", "all") \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
|
||||
DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(bce_with_logits_loss_op_info)
|
||||
def _bce_with_logits_loss_op_tbe():
|
||||
"""BCEWithLogitsLoss TBE register"""
|
||||
return
|
|
@ -74,7 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
|
|||
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
|
||||
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
|
||||
ResizeBilinear, Sigmoid, SeLU,
|
||||
SigmoidCrossEntropyWithLogits, NLLLoss,
|
||||
SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
|
||||
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
|
||||
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
||||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||
|
@ -149,6 +149,7 @@ __all__ = [
|
|||
'Softsign',
|
||||
'LogSoftmax',
|
||||
'SoftmaxCrossEntropyWithLogits',
|
||||
'BCEWithLogitsLoss',
|
||||
'ROIAlign',
|
||||
'SparseSoftmaxCrossEntropyWithLogits',
|
||||
'NLLLoss',
|
||||
|
|
|
@ -20,7 +20,6 @@ import operator
|
|||
from functools import reduce, partial
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import _check_3d_int_or_tuple
|
||||
from mindspore import log as logger
|
||||
import numpy as np
|
||||
from ... import context
|
||||
from .. import signature as sig
|
||||
|
@ -3701,6 +3700,99 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class BCEWithLogitsLoss(PrimitiveWithInfer):
|
||||
r"""
|
||||
Adds sigmoid activation function to input `predict`, and uses the given logits to compute binary cross entropy
|
||||
between the target and the output.
|
||||
|
||||
Sets input predict as `X`, input target as `Y`, output as `L`. Then,
|
||||
|
||||
.. math::
|
||||
p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}
|
||||
|
||||
.. math::
|
||||
L_{ij} = -[Y_{ij} * log(p_{ij}) + (1 - Y_{ij})log(1 - p_{ij})]
|
||||
|
||||
Then,
|
||||
|
||||
.. math::
|
||||
\ell(x, y) = \begin{cases}
|
||||
L, & \text{if reduction} = \text{`none';}\\
|
||||
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
|
||||
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
|
||||
If "none", do not perform reduction. Default:`mean`.
|
||||
|
||||
Inputs:
|
||||
- **predict** (Tensor) - Input logits. Data type must be float16 or float32.
|
||||
- **target** (Tensor) - Ground truth label. Has the same shape with `predict`.
|
||||
Data type must be float16 or float32.
|
||||
- **weight** (Tensor) - A rescaling weight applied to the loss of each batch element. It must can be
|
||||
broadcast to a tensor with shape of `predict`. Data type must be float16 or float32.
|
||||
- **pos_weight** (Tensor) - A weight of positive examples. Must be a vector with length equal to the
|
||||
number of classes. It must can be broadcast to a tensor with shape of `predict`.
|
||||
Data type must be float16 or float32.
|
||||
|
||||
Outputs:
|
||||
Scalar. If reduction is "none", it's a tensor with the same shape and type as input `predict`.
|
||||
|
||||
Raises:
|
||||
TypeError: If data type of any input is neither float16 nor float32.
|
||||
ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `predict`.
|
||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
|
||||
>>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
|
||||
>>> weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
|
||||
>>> pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
|
||||
>>> loss = ops.BCEWithLogitsLoss()
|
||||
>>> output = loss(predict, target, weight, pos_weight)
|
||||
>>> print(output)
|
||||
0.3463612
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, reduction='mean'):
|
||||
"""Initialize BCEWithLogitsLoss"""
|
||||
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
||||
|
||||
def infer_shape(self, predict, target, weight, pos_weight):
|
||||
validator.check('predict_shape', predict, 'target_shape', target, Rel.EQ, self.name)
|
||||
reversed_weight_shape = tuple(reversed(weight))
|
||||
reversed_target = tuple(reversed(predict))
|
||||
for i, v in enumerate(reversed_weight_shape):
|
||||
if v not in (reversed_target[i], 1):
|
||||
raise ValueError(f"For {self.name}, shapes can not broadcast. "
|
||||
f"predict: {tuple(predict)}, weight shape {tuple(weight)}.")
|
||||
|
||||
reversed_pos_shape = tuple(reversed(pos_weight))
|
||||
reversed_target = tuple(reversed(predict))
|
||||
for i, v in enumerate(reversed_pos_shape):
|
||||
if v not in (reversed_target[i], 1):
|
||||
raise ValueError(f"For {self.name}, shapes can not broadcast. "
|
||||
f"predict: {tuple(predict)}, weight shape {tuple(weight)}.")
|
||||
|
||||
if self.reduction in ('mean', 'sum'):
|
||||
shape = []
|
||||
else:
|
||||
shape = predict
|
||||
return shape
|
||||
|
||||
def infer_dtype(self, predict, target, weight, pos_weight):
|
||||
validator.check_tensor_dtype_valid('predict dtype', predict, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_dtype_valid('target dtype', target, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_dtype_valid('weight dtype', weight, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_dtype_valid('pos_weight dtype', pos_weight, [mstype.float16, mstype.float32], self.name)
|
||||
return predict
|
||||
|
||||
|
||||
class Pad(PrimitiveWithInfer):
|
||||
"""
|
||||
Pads the input tensor according to the paddings.
|
||||
|
|
|
@ -2058,6 +2058,10 @@ test_case_nn_ops = [
|
|||
'block': P.L2Loss(),
|
||||
'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)],
|
||||
'desc_bprop': []}),
|
||||
('BCEWithLogitsLoss', {
|
||||
'block': P.BCEWithLogitsLoss(),
|
||||
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
|
||||
'desc_bprop': []}),
|
||||
('ResizeBilinear', {
|
||||
'block': P.ResizeBilinear((5, 5)),
|
||||
'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)],
|
||||
|
|
Loading…
Reference in New Issue