!11959 Add BCEWithLogitsLoss op for Ascend.

From: @liu_xiao_93
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-08 09:05:31 +08:00 committed by Gitee
commit 61ab50915f
13 changed files with 392 additions and 4 deletions

View File

@ -32,6 +32,7 @@
#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h"
#include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h"
#include "backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.h"
#include "backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h"
#include "backend/optimizer/pass/communication_op_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
@ -191,6 +192,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>());
ir_fusion_pm->AddPass(std::make_shared<UnsortSegmentSumFission>());
ir_fusion_pm->AddPass(std::make_shared<GatherV2DsFission>());
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
}
} // namespace
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
@ -333,6 +335,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
optimizer->AddPassManager(ir_fusion_pm);
(void)optimizer->Optimize(kernel_graph);

View File

@ -0,0 +1,100 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h"
#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/trace_base.h"
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// Copy a new sigmoid node, shape of output is the same as input
std::vector<AnfNodePtr> new_simoid_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimBCEWithLogitsLoss->name()))};
new_simoid_inputs.insert(new_simoid_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
CNodePtr new_cnode = func_graph->NewCNode(new_simoid_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
auto predict_input = cnode->inputs()[1];
auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)};
auto new_node_shape = {AnfAlgo::GetOutputInferShape(predict_input, 0)};
AnfAlgo::SetOutputInferTypeAndShape(new_node_dtype, new_node_shape, new_cnode.get());
// Add reduce node
string reduction = AnfAlgo::GetNodeAttr<std::string>(node, kAttrReduction);
MS_LOG(INFO) << "Create reduce node, reduction attr is: " << reduction;
std::vector<AnfNodePtr> reduce_inputs;
if (reduction == "sum") {
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode};
} else if (reduction == "mean") {
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode};
} else {
MS_LOG(INFO) << "Reduction attr is not mean or sum, can not do fission.";
return nullptr;
}
auto reduce_node = func_graph->NewCNode(reduce_inputs);
MS_EXCEPTION_IF_NULL(reduce_node);
auto type = AnfAlgo::GetOutputInferDataType(node, 0);
if (type == kNumberTypeFloat16) {
type = kNumberTypeFloat32;
}
auto shape = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape({type}, shape, reduce_node.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{}), reduce_node);
AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_node);
reduce_node->set_scope(cnode->scope());
return reduce_node;
}
} // namespace
const BaseRef BCEWithLogitsLossFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
return VectorRef({prim::kPrimBCEWithLogitsLoss, Xs});
}
const AnfNodePtr BCEWithLogitsLossFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (GetBoolAttr(cnode, kAttrVisited)) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
if (cnode->inputs().size() == 0) {
return nullptr;
}
if (!AnfAlgo::HasNodeAttr("reduction", cnode)) {
MS_LOG(INFO) << "Has no reduction attr.";
return nullptr;
}
return AddReduceNode(func_graph, node);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_BCE_WITH_LOGITS_LOSS_FISSION_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_BCE_WITH_LOGITS_LOSS_FISSION_FISSION_H_
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
class BCEWithLogitsLossFission : public PatternProcessPass {
public:
explicit BCEWithLogitsLossFission(bool multigraph = true)
: PatternProcessPass("bce_with_logits_loss_fission", multigraph) {}
~BCEWithLogitsLossFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_BCE_WITH_LOGITS_LOSS_FISSION_FISSION_H_

View File

@ -344,6 +344,7 @@ constexpr auto kAttrWaitEventStream = "wait_event_stream";
constexpr auto kAttrIndex = "index";
constexpr auto kAttrSplitDim = "split_dim";
constexpr auto kAttrNumSplit = "num_split";
constexpr auto kAttrReduction = "reduction";
constexpr auto kAttrOutputNum = "output_num";
constexpr auto kAttrSizeSplits = "size_splits";
constexpr auto kAttrOutputDefault = "output_default";

View File

@ -282,6 +282,7 @@ inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Pri
inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAdam");
inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
inline const PrimitivePtr kPrimBCEWithLogitsLoss = std::make_shared<Primitive>("BCEWithLogitsLoss");
inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum");
inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove");
inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize");

View File

@ -21,8 +21,8 @@ It shows how well the model works on a dataset and the optimization target which
from .loss import L1Loss, MSELoss, SmoothL1Loss, \
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
SampledSoftmaxLoss, DiceLoss
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss
__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
'SoftmaxCrossEntropyWithLogits', 'BCELoss',
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss']

View File

@ -15,6 +15,7 @@
"""loss"""
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
@ -739,3 +740,86 @@ class CosineEmbeddingLoss(_Loss):
output_unreduced = pos_part + neg_part
return self.get_loss(output_unreduced)
class BCEWithLogitsLoss(_Loss):
r"""
Adds sigmoid activation function to input `predict`, and uses the given logits to compute binary cross entropy
between the target and the output.
Sets input predict as `X`, input target as `Y`, output as `L`. Then,
.. math::
p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}
.. math::
L_{ij} = -[Y_{ij} * ln(p_{ij}) + (1 - Y_{ij})ln(1 - p_{ij})]
Then,
.. math::
\ell(x, y) = \begin{cases}
L, & \text{if reduction} = \text{`none';}\\
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
\end{cases}
Args:
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
If "none", do not perform reduction. Default:`mean`.
weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
If not None, it must can be broadcast to a tensor with shape of `predict`,
data type must be float16 or float32. Default: None.
pos_weight (Tensor, optional): A weight of positive examples. Must be a vector with length equal to the
number of classes. If not None, it must can be broadcast to a tensor with shape of `predict`,
data type must be float16 or float32. Default: None.
Inputs:
- **predict** (Tensor) - Input logits. The data type must be float16 or float32.
- **target** (Tensor) - Ground truth label. Has the same data type and shape with `predict`.
Outputs:
Scalar. If reduction is "none", it's a tensor with the same shape and type as input `predict`.
Raises:
TypeError: If data type of `predict` or `target` is neither float16 nor float32.
TypeError: If `weight` or `pos_weight` is Parameter.
TypeError: If data type of `weight` or `pos_weight` is neither float16 nor float32.
ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `predict`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend``
Examples:
>>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
>>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
>>> loss = nn.BCEWithLogitsLoss()
>>> output = loss(inputs, labels)
>>> print(output)
0.3463612
"""
def __init__(self, reduction='mean', weight=None, pos_weight=None):
super(BCEWithLogitsLoss, self).__init__()
self.bce_with_logits_loss = P.BCEWithLogitsLoss(reduction=reduction)
if isinstance(weight, Parameter):
raise TypeError(f"For {self.cls_name}, weight can not be Parameter.")
if isinstance(pos_weight, Parameter):
raise TypeError(f"For {self.cls_name}, pos_weight can not be Parameter.")
self.weight = weight
self.pos_weight = pos_weight
self.ones = P.OnesLike()
def construct(self, predict, target):
ones_input = self.ones(predict)
if self.weight is not None:
weight = self.weight
else:
weight = ones_input
if self.pos_weight is not None:
pos_weight = self.pos_weight
else:
pos_weight = ones_input
loss = self.bce_with_logits_loss(predict, target, weight, pos_weight)
return loss

View File

@ -1212,6 +1212,32 @@ def get_bprop_binary_cross_entropy(self):
return bprop
@bprop_getters.register(P.BCEWithLogitsLoss)
def get_bprop_ce_with_logits_loss(self):
"""Grad definition for `BCEWithLogitsLoss` operation."""
reduction = self.reduction
mul = P.Mul()
sigmoid = P.Sigmoid()
add = P.TensorAdd()
sub = P.Sub()
size = P.Size()
def bprop(predict, target, weight, pos_weight, out, dout):
sigmoid_input = sigmoid(predict)
if pos_weight is not None:
t = mul(target, pos_weight)
dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout)
else:
dx = mul((sigmoid_input - target), dout)
if weight is not None:
dx = mul(dx, weight)
if reduction == 'mean':
dx = dx / size(dx)
return dx, zeros_like(target), zeros_like(weight), zeros_like(pos_weight)
return bprop
@bprop_getters.register(P.KLDivLoss)
def get_bprop_kl_div_loss(self):
"""Grad definition for `KLDivLoss` operation."""

View File

@ -254,6 +254,7 @@ from .prelu import _prelu_tbe
from .prelu_grad import _prelu_grad_tbe
from .binary_cross_entropy import _binary_cross_entropy_tbe
from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe
from .bce_with_logits_loss import _bce_with_logits_loss_op_tbe
from .sin import _sin_tbe
from .cos import _cos_tbe
from .tan import _tan_tbe

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed unde:q!r the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BCEWithLogitsLoss op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
bce_with_logits_loss_op_info = TBERegOp("BCEWithLogitsLoss") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("sigmoid_cross_entropy_with_logits_v2.so") \
.compute_cost(10) \
.kernel_name("sigmoid_cross_entropy_with_logits_v2") \
.partial_flag(True) \
.op_pattern("dynamicFormat") \
.attr("reduction", "optional", "str", "all", "mean") \
.input(0, "predict", False, "required", "all") \
.input(1, "target", False, "required", "all") \
.input(2, "weight", False, "optional", "all") \
.input(3, "pos_weight", False, "optional", "all") \
.output(0, "loss", False, "required", "all") \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
DataType.None_None) \
.get_op_info()
@op_info_register(bce_with_logits_loss_op_info)
def _bce_with_logits_loss_op_tbe():
"""BCEWithLogitsLoss TBE register"""
return

View File

@ -74,7 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid, SeLU,
SigmoidCrossEntropyWithLogits, NLLLoss,
SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
SoftmaxCrossEntropyWithLogits, ROIAlign,
SparseSoftmaxCrossEntropyWithLogits, Tanh,
@ -149,6 +149,7 @@ __all__ = [
'Softsign',
'LogSoftmax',
'SoftmaxCrossEntropyWithLogits',
'BCEWithLogitsLoss',
'ROIAlign',
'SparseSoftmaxCrossEntropyWithLogits',
'NLLLoss',

View File

@ -20,7 +20,6 @@ import operator
from functools import reduce, partial
from mindspore import log as logger
from mindspore._checkparam import _check_3d_int_or_tuple
from mindspore import log as logger
import numpy as np
from ... import context
from .. import signature as sig
@ -3701,6 +3700,99 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
return x_dtype
class BCEWithLogitsLoss(PrimitiveWithInfer):
r"""
Adds sigmoid activation function to input `predict`, and uses the given logits to compute binary cross entropy
between the target and the output.
Sets input predict as `X`, input target as `Y`, output as `L`. Then,
.. math::
p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}
.. math::
L_{ij} = -[Y_{ij} * log(p_{ij}) + (1 - Y_{ij})log(1 - p_{ij})]
Then,
.. math::
\ell(x, y) = \begin{cases}
L, & \text{if reduction} = \text{`none';}\\
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
\end{cases}
Args:
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
If "none", do not perform reduction. Default:`mean`.
Inputs:
- **predict** (Tensor) - Input logits. Data type must be float16 or float32.
- **target** (Tensor) - Ground truth label. Has the same shape with `predict`.
Data type must be float16 or float32.
- **weight** (Tensor) - A rescaling weight applied to the loss of each batch element. It must can be
broadcast to a tensor with shape of `predict`. Data type must be float16 or float32.
- **pos_weight** (Tensor) - A weight of positive examples. Must be a vector with length equal to the
number of classes. It must can be broadcast to a tensor with shape of `predict`.
Data type must be float16 or float32.
Outputs:
Scalar. If reduction is "none", it's a tensor with the same shape and type as input `predict`.
Raises:
TypeError: If data type of any input is neither float16 nor float32.
ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `predict`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend``
Examples:
>>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
>>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
>>> weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
>>> pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
>>> loss = ops.BCEWithLogitsLoss()
>>> output = loss(predict, target, weight, pos_weight)
>>> print(output)
0.3463612
"""
@prim_attr_register
def __init__(self, reduction='mean'):
"""Initialize BCEWithLogitsLoss"""
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
def infer_shape(self, predict, target, weight, pos_weight):
validator.check('predict_shape', predict, 'target_shape', target, Rel.EQ, self.name)
reversed_weight_shape = tuple(reversed(weight))
reversed_target = tuple(reversed(predict))
for i, v in enumerate(reversed_weight_shape):
if v not in (reversed_target[i], 1):
raise ValueError(f"For {self.name}, shapes can not broadcast. "
f"predict: {tuple(predict)}, weight shape {tuple(weight)}.")
reversed_pos_shape = tuple(reversed(pos_weight))
reversed_target = tuple(reversed(predict))
for i, v in enumerate(reversed_pos_shape):
if v not in (reversed_target[i], 1):
raise ValueError(f"For {self.name}, shapes can not broadcast. "
f"predict: {tuple(predict)}, weight shape {tuple(weight)}.")
if self.reduction in ('mean', 'sum'):
shape = []
else:
shape = predict
return shape
def infer_dtype(self, predict, target, weight, pos_weight):
validator.check_tensor_dtype_valid('predict dtype', predict, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('target dtype', target, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('weight dtype', weight, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('pos_weight dtype', pos_weight, [mstype.float16, mstype.float32], self.name)
return predict
class Pad(PrimitiveWithInfer):
"""
Pads the input tensor according to the paddings.

View File

@ -2058,6 +2058,10 @@ test_case_nn_ops = [
'block': P.L2Loss(),
'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)],
'desc_bprop': []}),
('BCEWithLogitsLoss', {
'block': P.BCEWithLogitsLoss(),
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
'desc_bprop': []}),
('ResizeBilinear', {
'block': P.ResizeBilinear((5, 5)),
'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)],