diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index c9900dab750..dc49cd4b11a 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -32,6 +32,7 @@
 #include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h"
 #include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h"
 #include "backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.h"
+#include "backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h"
 #include "backend/optimizer/pass/communication_op_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
@@ -191,6 +192,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
 }
 }  // namespace
 void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
@@ -333,6 +335,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
 
   optimizer->AddPassManager(ir_fusion_pm);
   (void)optimizer->Optimize(kernel_graph);
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.cc
new file mode 100644
index 00000000000..41920ec9e94
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.cc
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h"
+#include <vector>
+#include <memory>
+#include <string>
+#include <algorithm>
+#include "utils/utils.h"
+#include "utils/ms_context.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/trace_base.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  // Copy a new sigmoid node, shape of output is the same as input
+  std::vector<AnfNodePtr> new_sigmoid_inputs = {
+    NewValueNode(std::make_shared<Primitive>(prim::kPrimBCEWithLogitsLoss->name()))};
+  new_sigmoid_inputs.insert(new_sigmoid_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
+  CNodePtr new_cnode = func_graph->NewCNode(new_sigmoid_inputs);
+  MS_EXCEPTION_IF_NULL(new_cnode);
+  auto predict_input = cnode->inputs()[1];
+  auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)};
+  auto new_node_shape = {AnfAlgo::GetOutputInferShape(predict_input, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(new_node_dtype, new_node_shape, new_cnode.get());
+
+  // Add reduce node
+  std::string reduction = AnfAlgo::GetNodeAttr<std::string>(node, kAttrReduction);
+  MS_LOG(INFO) << "Create reduce node, reduction attr is: " << reduction;
+  std::vector<AnfNodePtr> reduce_inputs;
+  if (reduction == "sum") {
+    reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode};
+  } else if (reduction == "mean") {
+    reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode};
+  } else {
+    MS_LOG(INFO) << "Reduction attr is not mean or sum, can not do fission.";
+    return nullptr;
+  }
+  auto reduce_node = func_graph->NewCNode(reduce_inputs);
+  MS_EXCEPTION_IF_NULL(reduce_node);
+  auto type = AnfAlgo::GetOutputInferDataType(node, 0);
+  if (type == kNumberTypeFloat16) {
+    type = kNumberTypeFloat32;
+  }
+  auto shape = {AnfAlgo::GetOutputInferShape(node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape({type}, shape, reduce_node.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{}), reduce_node);
+  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_node);
+  reduce_node->set_scope(cnode->scope());
+  return reduce_node;
+}
+}  // namespace
+
+const BaseRef BCEWithLogitsLossFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  return VectorRef({prim::kPrimBCEWithLogitsLoss, Xs});
+}
+
+const AnfNodePtr BCEWithLogitsLossFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                   const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (GetBoolAttr(cnode, kAttrVisited)) {
+    return nullptr;
+  }
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
+  if (cnode->inputs().size() == 0) {
+    return nullptr;
+  }
+  if (!AnfAlgo::HasNodeAttr("reduction", cnode)) {
+    MS_LOG(INFO) << "Has no reduction attr.";
+    return nullptr;
+  }
+  return AddReduceNode(func_graph, node);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h
new file mode 100644
index 00000000000..44f32a03f4a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BCE_WITH_LOGITS_LOSS_FISSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BCE_WITH_LOGITS_LOSS_FISSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+class BCEWithLogitsLossFission : public PatternProcessPass {
+ public:
+  explicit BCEWithLogitsLossFission(bool multigraph = true)
+      : PatternProcessPass("bce_with_logits_loss_fission", multigraph) {}
+  ~BCEWithLogitsLossFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BCE_WITH_LOGITS_LOSS_FISSION_H_
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 4facb834b7e..57c703f0718 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -344,6 +344,7 @@ constexpr auto kAttrWaitEventStream = "wait_event_stream";
 constexpr auto kAttrIndex = "index";
 constexpr auto kAttrSplitDim = "split_dim";
 constexpr auto kAttrNumSplit = "num_split";
+constexpr auto kAttrReduction = "reduction";
 constexpr auto kAttrOutputNum = "output_num";
 constexpr auto kAttrSizeSplits = "size_splits";
 constexpr auto kAttrOutputDefault = "output_default";
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 19d7d2464af..39f7fe4a61d 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -282,6 +282,7 @@ inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
 inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAdam");
 inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
 inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
+inline const PrimitivePtr kPrimBCEWithLogitsLoss = std::make_shared<Primitive>("BCEWithLogitsLoss");
 inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum");
 inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove");
 inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize");
diff --git a/mindspore/nn/loss/__init__.py b/mindspore/nn/loss/__init__.py
index fe3b9d983f5..8f1530de588 100644
--- a/mindspore/nn/loss/__init__.py
+++ b/mindspore/nn/loss/__init__.py
@@ -21,8 +21,8 @@ It shows how well the model works on a dataset and the optimization target which
 from .loss import L1Loss, MSELoss, SmoothL1Loss, \
     SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
-    SampledSoftmaxLoss, DiceLoss
+    SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss
 
 __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
-           'SoftmaxCrossEntropyWithLogits', 'BCELoss',
+           'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
            'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss']
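The backend pass above splits a BCEWithLogitsLoss node whose `reduction` attribute is "mean" or "sum" into an unreduced BCEWithLogitsLoss followed by a ReduceMean or ReduceSum over all axes. A rough numpy sketch of the identity this relies on (a standalone illustration, not part of the patch; the helper name is made up):

import numpy as np

def bce_with_logits_none(predict, target, weight, pos_weight):
    # Element-wise loss, i.e. reduction='none'.
    log_sig = -np.logaddexp(0.0, -predict)        # log(sigmoid(predict))
    log_not_sig = -np.logaddexp(0.0, predict)     # log(1 - sigmoid(predict))
    return -weight * (pos_weight * target * log_sig + (1.0 - target) * log_not_sig)

predict = np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]], np.float32)
target = np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]], np.float32)
ones = np.ones_like(predict)

unreduced = bce_with_logits_none(predict, target, ones, ones)
# After fission, the graph computes the unreduced loss and then a ReduceMean
# (reduction='mean') or ReduceSum (reduction='sum') over all axes:
print(unreduced.mean())   # ~0.346361, the value quoted in the docstring examples below
print(unreduced.sum())    # ~2.078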
diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py
index ade650fda07..7b8c19c2015 100644
--- a/mindspore/nn/loss/loss.py
+++ b/mindspore/nn/loss/loss.py
@@ -15,6 +15,7 @@
 """loss"""
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops.primitive import constexpr
@@ -739,3 +740,86 @@ class CosineEmbeddingLoss(_Loss):
         output_unreduced = pos_part + neg_part
 
         return self.get_loss(output_unreduced)
+
+
+class BCEWithLogitsLoss(_Loss):
+    r"""
+    Applies the sigmoid activation to the input `predict` (the logits), and computes the binary cross entropy
+    between the result and the target.
+
+    Set input predict as `X`, input target as `Y`, output as `L`. Then,
+
+    .. math::
+        p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}
+
+    .. math::
+        L_{ij} = -[Y_{ij} * ln(p_{ij}) + (1 - Y_{ij})ln(1 - p_{ij})]
+
+    Then,
+
+    .. math::
+        \ell(x, y) = \begin{cases}
+        L, & \text{if reduction} = \text{`none';}\\
+        \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
+        \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
+        \end{cases}
+
+    Args:
+        reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and
+            "none". If "none", do not perform reduction. Default: "mean".
+        weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
+            If not None, it must be broadcastable to a tensor with the shape of `predict`,
+            and its data type must be float16 or float32. Default: None.
+        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector with length equal to the
+            number of classes. If not None, it must be broadcastable to a tensor with the shape of `predict`,
+            and its data type must be float16 or float32. Default: None.
+
+    Inputs:
+        - **predict** (Tensor) - Input logits. The data type must be float16 or float32.
+        - **target** (Tensor) - Ground truth label. Has the same data type and shape as `predict`.
+
+    Outputs:
+        Scalar. If reduction is "none", it is a tensor with the same shape and type as input `predict`.
+
+    Raises:
+        TypeError: If data type of `predict` or `target` is neither float16 nor float32.
+        TypeError: If `weight` or `pos_weight` is a Parameter.
+        TypeError: If data type of `weight` or `pos_weight` is neither float16 nor float32.
+        ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `predict`.
+        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
+        >>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
+        >>> loss = nn.BCEWithLogitsLoss()
+        >>> output = loss(predict, target)
+        >>> print(output)
+        0.3463612
+    """
+
+    def __init__(self, reduction='mean', weight=None, pos_weight=None):
+        super(BCEWithLogitsLoss, self).__init__()
+        self.bce_with_logits_loss = P.BCEWithLogitsLoss(reduction=reduction)
+        if isinstance(weight, Parameter):
+            raise TypeError(f"For {self.cls_name}, weight can not be Parameter.")
+        if isinstance(pos_weight, Parameter):
+            raise TypeError(f"For {self.cls_name}, pos_weight can not be Parameter.")
+        self.weight = weight
+        self.pos_weight = pos_weight
+        self.ones = P.OnesLike()
+
+    def construct(self, predict, target):
+        ones_input = self.ones(predict)
+        if self.weight is not None:
+            weight = self.weight
+        else:
+            weight = ones_input
+        if self.pos_weight is not None:
+            pos_weight = self.pos_weight
+        else:
+            pos_weight = ones_input
+        loss = self.bce_with_logits_loss(predict, target, weight, pos_weight)
+        return loss
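A short usage sketch for the nn.BCEWithLogitsLoss cell above, including the optional `weight` and `pos_weight` rescaling tensors (values are illustrative; the underlying primitive is only registered for Ascend in this patch):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]], np.float32))
target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]], np.float32))
# Per-class tensors; they are broadcast against the shape of `predict`.
weight = Tensor(np.array([1.0, 2.0, 1.0], np.float32))
pos_weight = Tensor(np.array([1.0, 1.0, 2.0], np.float32))

loss_mean = nn.BCEWithLogitsLoss(reduction='mean', weight=weight, pos_weight=pos_weight)
loss_none = nn.BCEWithLogitsLoss(reduction='none')
print(loss_mean(predict, target))   # a scalar tensor
print(loss_none(predict, target))   # a tensor with the same shape as `predict`

When `weight` or `pos_weight` is left as None, the construct method above substitutes a OnesLike tensor, so all four inputs of the primitive are always provided.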
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 52a7d88b9b9..6a28a3d57c0 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -1212,6 +1212,32 @@ def get_bprop_binary_cross_entropy(self):
     return bprop
 
 
+@bprop_getters.register(P.BCEWithLogitsLoss)
+def get_bprop_ce_with_logits_loss(self):
+    """Grad definition for `BCEWithLogitsLoss` operation."""
+    reduction = self.reduction
+    mul = P.Mul()
+    sigmoid = P.Sigmoid()
+    add = P.TensorAdd()
+    sub = P.Sub()
+    size = P.Size()
+
+    def bprop(predict, target, weight, pos_weight, out, dout):
+        sigmoid_input = sigmoid(predict)
+        if pos_weight is not None:
+            t = mul(target, pos_weight)
+            dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout)
+        else:
+            dx = mul((sigmoid_input - target), dout)
+        if weight is not None:
+            dx = mul(dx, weight)
+        if reduction == 'mean':
+            dx = dx / size(dx)
+        return dx, zeros_like(target), zeros_like(weight), zeros_like(pos_weight)
+
+    return bprop
+
+
 @bprop_getters.register(P.KLDivLoss)
 def get_bprop_kl_div_loss(self):
     """Grad definition for `KLDivLoss` operation."""
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index a6491b4799f..1e4b1baddb3 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -254,6 +254,7 @@ from .prelu import _prelu_tbe
 from .prelu_grad import _prelu_grad_tbe
 from .binary_cross_entropy import _binary_cross_entropy_tbe
 from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe
+from .bce_with_logits_loss import _bce_with_logits_loss_op_tbe
 from .sin import _sin_tbe
 from .cos import _cos_tbe
 from .tan import _tan_tbe
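The closed form returned by the bprop above for the `predict` input, dx = ((pos_weight*target + 1 - target) * sigmoid(predict) - pos_weight*target) * weight * dout, divided by the element count when reduction='mean', can be sanity-checked against finite differences in plain numpy; a hedged, self-contained sketch (not part of the patch):

import numpy as np

def bce_with_logits_mean(predict, target, weight, pos_weight):
    # Reference forward pass, reduction='mean'.
    log_sig = -np.logaddexp(0.0, -predict)        # log(sigmoid(predict))
    log_not_sig = -np.logaddexp(0.0, predict)     # log(1 - sigmoid(predict))
    loss = -weight * (pos_weight * target * log_sig + (1.0 - target) * log_not_sig)
    return loss.mean()

rng = np.random.default_rng(0)
predict = rng.normal(size=(2, 3))
target = rng.uniform(size=(2, 3))
weight = np.ones_like(predict)
pos_weight = np.full((3,), 1.5)

# Closed form used by the bprop (dout = 1, reduction='mean').
sig = 1.0 / (1.0 + np.exp(-predict))
t = target * pos_weight
dx = ((t + 1.0 - target) * sig - t) * weight / predict.size

# Central finite differences.
numeric = np.zeros_like(predict)
eps = 1e-6
for idx in np.ndindex(predict.shape):
    delta = np.zeros_like(predict)
    delta[idx] = eps
    numeric[idx] = (bce_with_logits_mean(predict + delta, target, weight, pos_weight) -
                    bce_with_logits_mean(predict - delta, target, weight, pos_weight)) / (2 * eps)

print(np.max(np.abs(dx - numeric)))   # expected to be tiny, on the order of 1e-9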
diff --git a/mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py b/mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py
new file mode 100644
index 00000000000..eda25f98261
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BCEWithLogitsLoss op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+bce_with_logits_loss_op_info = TBERegOp("BCEWithLogitsLoss") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("sigmoid_cross_entropy_with_logits_v2.so") \
+    .compute_cost(10) \
+    .kernel_name("sigmoid_cross_entropy_with_logits_v2") \
+    .partial_flag(True) \
+    .op_pattern("dynamicFormat") \
+    .attr("reduction", "optional", "str", "all", "mean") \
+    .input(0, "predict", False, "required", "all") \
+    .input(1, "target", False, "required", "all") \
+    .input(2, "weight", False, "optional", "all") \
+    .input(3, "pos_weight", False, "optional", "all") \
+    .output(0, "loss", False, "required", "all") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None) \
+    .get_op_info()
+
+
+@op_info_register(bce_with_logits_loss_op_info)
+def _bce_with_logits_loss_op_tbe():
+    """BCEWithLogitsLoss TBE register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 71adbf54b79..1fc6083a975 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -74,7 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                      AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
                      MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2,
                      HSwish, HSigmoid, ResizeBilinear, Sigmoid, SeLU,
-                     SigmoidCrossEntropyWithLogits, NLLLoss,
+                     SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
                      SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
                      SoftmaxCrossEntropyWithLogits, ROIAlign,
                      SparseSoftmaxCrossEntropyWithLogits, Tanh,
@@ -149,6 +149,7 @@ __all__ = [
     'Softsign',
     'LogSoftmax',
     'SoftmaxCrossEntropyWithLogits',
+    'BCEWithLogitsLoss',
     'ROIAlign',
     'SparseSoftmaxCrossEntropyWithLogits',
     'NLLLoss',
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index d50552835d9..cc223fb8d22 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -20,7 +20,6 @@ import operator
 from functools import reduce, partial
 from mindspore import log as logger
 from mindspore._checkparam import _check_3d_int_or_tuple
-from mindspore import log as logger
 import numpy as np
 from ... import context
 from .. import signature as sig
@@ -3701,6 +3700,99 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
         return x_dtype
 
 
+class BCEWithLogitsLoss(PrimitiveWithInfer):
+    r"""
+    Applies the sigmoid activation to the input `predict` (the logits), and computes the binary cross entropy
+    between the result and the target.
+
+    Set input predict as `X`, input target as `Y`, output as `L`. Then,
+
+    .. math::
+        p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}
+
+    .. math::
+        L_{ij} = -[Y_{ij} * log(p_{ij}) + (1 - Y_{ij})log(1 - p_{ij})]
+
+    Then,
+
+    .. math::
+        \ell(x, y) = \begin{cases}
+        L, & \text{if reduction} = \text{`none';}\\
+        \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
+        \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
+        \end{cases}
+
+    Args:
+        reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and
+            "none". If "none", do not perform reduction. Default: "mean".
+
+    Inputs:
+        - **predict** (Tensor) - Input logits. Data type must be float16 or float32.
+        - **target** (Tensor) - Ground truth label. Has the same shape as `predict`.
+          Data type must be float16 or float32.
+        - **weight** (Tensor) - A rescaling weight applied to the loss of each batch element. It must be
+          broadcastable to a tensor with the shape of `predict`. Data type must be float16 or float32.
+        - **pos_weight** (Tensor) - A weight of positive examples. Must be a vector with length equal to the
+          number of classes. It must be broadcastable to a tensor with the shape of `predict`.
+          Data type must be float16 or float32.
+
+    Outputs:
+        Scalar. If reduction is "none", it is a tensor with the same shape and type as input `predict`.
+
+    Raises:
+        TypeError: If data type of any input is neither float16 nor float32.
+        ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `predict`.
+        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
+        >>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
+        >>> weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
+        >>> pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
+        >>> loss = ops.BCEWithLogitsLoss()
+        >>> output = loss(predict, target, weight, pos_weight)
+        >>> print(output)
+        0.3463612
+    """
+
+    @prim_attr_register
+    def __init__(self, reduction='mean'):
+        """Initialize BCEWithLogitsLoss"""
+        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
+
+    def infer_shape(self, predict, target, weight, pos_weight):
+        validator.check('predict_shape', predict, 'target_shape', target, Rel.EQ, self.name)
+        reversed_weight_shape = tuple(reversed(weight))
+        reversed_target = tuple(reversed(predict))
+        for i, v in enumerate(reversed_weight_shape):
+            if v not in (reversed_target[i], 1):
+                raise ValueError(f"For {self.name}, shapes can not broadcast. "
+                                 f"predict: {tuple(predict)}, weight shape: {tuple(weight)}.")
+
+        reversed_pos_shape = tuple(reversed(pos_weight))
+        reversed_target = tuple(reversed(predict))
+        for i, v in enumerate(reversed_pos_shape):
+            if v not in (reversed_target[i], 1):
+                raise ValueError(f"For {self.name}, shapes can not broadcast. "
+                                 f"predict: {tuple(predict)}, pos_weight shape: {tuple(pos_weight)}.")
+
+        if self.reduction in ('mean', 'sum'):
+            shape = []
+        else:
+            shape = predict
+        return shape
+
+    def infer_dtype(self, predict, target, weight, pos_weight):
+        validator.check_tensor_dtype_valid('predict dtype', predict, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('target dtype', target, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('weight dtype', weight, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('pos_weight dtype', pos_weight, [mstype.float16, mstype.float32], self.name)
+        return predict
+
+
 class Pad(PrimitiveWithInfer):
     """
     Pads the input tensor according to the paddings.
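The two reversed() loops in infer_shape above implement a right-aligned broadcast check of `weight` and `pos_weight` against `predict`. A standalone Python sketch of the same rule (illustrative only; it adds the rank guard that infer_shape implicitly assumes):

def can_broadcast_to(shape, target_shape):
    # Right-align the two shapes and require each dimension to match or be 1,
    # mirroring the reversed() comparisons in infer_shape.
    rev_shape = tuple(reversed(shape))
    rev_target = tuple(reversed(target_shape))
    if len(rev_shape) > len(rev_target):
        return False
    return all(v in (rev_target[i], 1) for i, v in enumerate(rev_shape))

print(can_broadcast_to((3,), (2, 3)))      # True  -> a per-class weight/pos_weight vector
print(can_broadcast_to((2, 1), (2, 3)))    # True
print(can_broadcast_to((2,), (2, 3)))      # False -> infer_shape raises ValueError for this case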
" + f"predict: {tuple(predict)}, weight shape {tuple(weight)}.") + + if self.reduction in ('mean', 'sum'): + shape = [] + else: + shape = predict + return shape + + def infer_dtype(self, predict, target, weight, pos_weight): + validator.check_tensor_dtype_valid('predict dtype', predict, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('target dtype', target, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('weight dtype', weight, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('pos_weight dtype', pos_weight, [mstype.float16, mstype.float32], self.name) + return predict + + class Pad(PrimitiveWithInfer): """ Pads the input tensor according to the paddings. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index bfcc0be7047..503765b322c 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -2058,6 +2058,10 @@ test_case_nn_ops = [ 'block': P.L2Loss(), 'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)], 'desc_bprop': []}), + ('BCEWithLogitsLoss', { + 'block': P.BCEWithLogitsLoss(), + 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], + 'desc_bprop': []}), ('ResizeBilinear', { 'block': P.ResizeBilinear((5, 5)), 'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)],