From d3f42e7d6b3e129903a057096a2bc1e57bf6fc9f Mon Sep 17 00:00:00 2001
From: jiangzhenguang
Date: Thu, 31 Dec 2020 16:02:43 +0800
Subject: [PATCH] add nll_loss operation.

---
 .../ccsrc/transform/graph_ir/op_adapter_map.h |  2 +
 .../graph_ir/op_declare/math_ops_declare.cc   | 35 +++++++++++
 .../graph_ir/op_declare/math_ops_declare.h    | 31 ++++++++++
 mindspore/ops/_grad/grad_nn_ops.py            | 14 +++++
 mindspore/ops/_op_impl/tbe/__init__.py        |  2 +
 mindspore/ops/_op_impl/tbe/nll_loss.py        | 40 ++++++++++++
 mindspore/ops/_op_impl/tbe/nll_loss_grad.py   | 41 ++++++++++++
 mindspore/ops/operations/__init__.py          |  3 +-
 mindspore/ops/operations/_grad_ops.py         | 29 +++++++++
 mindspore/ops/operations/nn_ops.py            | 62 ++++++++++++++++++-
 tests/ut/python/ops/test_ops.py               | 18 ++++++
 11 files changed, 275 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ccsrc/transform/graph_ir/op_declare/math_ops_declare.cc
 create mode 100644 mindspore/ccsrc/transform/graph_ir/op_declare/math_ops_declare.h
 create mode 100644 mindspore/ops/_op_impl/tbe/nll_loss.py
 create mode 100644 mindspore/ops/_op_impl/tbe/nll_loss_grad.py

diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h
index 1e8e3a6b04c..ed48b9de724 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h
+++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h
@@ -147,6 +147,8 @@ constexpr const char kNameCumSum[] = "CumSum";
 constexpr const char kNameHuberLossGrad[] = "HuberLossGrad";
 constexpr const char kNameSparseSoftmaxCrossEntropy[] = "SparseSoftmaxCrossEntropy";
 constexpr const char kNameSparseSoftmaxCrossEntropyGrad[] = "SparseSoftmaxCrossEntropyGrad";
+constexpr const char kNameNLLLoss[] = "NLLLoss";
+constexpr const char kNameNLLLossGrad[] = "NLLLossGrad";
 constexpr const char kNameTopK[] = "TopK";
 constexpr const char kNameSoftmaxGrad[] = "SoftmaxGrad";
 constexpr const char kNameMaxPool[] = "MaxPool";
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/math_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/math_ops_declare.cc
new file mode 100644
index 00000000000..129cca0c216
--- /dev/null
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare/math_ops_declare.cc
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "transform/graph_ir/op_declare/math_ops_declare.h"
+
+namespace mindspore::transform {
+// NLLLoss
+INPUT_MAP(NLLLoss) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(target)}, {3, INPUT_DESC(weight)}};
+ATTR_MAP(NLLLoss) = {{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
+OUTPUT_MAP(NLLLoss) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(total_weight)}};
+REG_ADPT_DESC(NLLLoss, kNameNLLLoss, ADPT_DESC(NLLLoss))
+
+// NLLLossGrad
+INPUT_MAP(NLLLossGrad) = {{1, INPUT_DESC(x)},
+                          {2, INPUT_DESC(y_grad)},
+                          {3, INPUT_DESC(target)},
+                          {4, INPUT_DESC(weight)},
+                          {5, INPUT_DESC(total_weight)}};
+ATTR_MAP(NLLLossGrad) = {{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
+OUTPUT_MAP(NLLLossGrad) = {{0, OUTPUT_DESC(x_grad)}};
+REG_ADPT_DESC(NLLLossGrad, kNameNLLLossGrad, ADPT_DESC(NLLLossGrad))
+}  // namespace mindspore::transform
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/math_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/math_ops_declare.h
new file mode 100644
index 00000000000..dcd739413cf
--- /dev/null
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare/math_ops_declare.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_
+#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_
+
+#include <string>
+#include <unordered_map>
+#include "transform/graph_ir/op_declare/op_declare_macro.h"
+#include "ops/math_ops.h"
+
+namespace mindspore::transform {
+DECLARE_OP_ADAPTER(NLLLoss)
+DECLARE_OP_USE_OUTPUT(NLLLoss)
+DECLARE_OP_ADAPTER(NLLLossGrad)
+DECLARE_OP_USE_OUTPUT(NLLLossGrad)
+}  // namespace mindspore::transform
+#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 2ed2e3ad583..4fc558b9e7e 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -736,6 +736,20 @@ def get_bprop_softmax_cross_entropy_with_logits(self):
     return bprop
 
 
+@bprop_getters.register(P.NLLLoss)
+def get_bprop_nll_loss(self):
+    """Grad definition for `NLLLoss` operation."""
+    nll_loss_grad = G.NLLLossGrad(reduction=self.reduction)
+
+    def bprop(x, target, weight, out, dout):
+        total_weight = out[1]
+        dout_x = dout[0]
+        dx = nll_loss_grad(x, dout_x, target, weight, total_weight)
+        return dx, zeros_like(target), zeros_like(weight)
+
+    return bprop
+
+
 @bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
 def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
     """Grad definition for `SparseSoftmaxCrossEntropyWithLogits` operation."""
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 0f27fd921b7..f4a18110730 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -353,3 +353,5 @@ from .conv3d_backprop_filter import _conv3d_backprop_filter_tbe
 from .conv3d_transpose import _conv3d_transpose_tbe
 from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe
 from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe
+from .nll_loss import _nll_loss_tbe
+from .nll_loss_grad import _nll_loss_grad_tbe
diff --git a/mindspore/ops/_op_impl/tbe/nll_loss.py b/mindspore/ops/_op_impl/tbe/nll_loss.py
new file mode 100644
index 00000000000..fa0aa40b913
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/nll_loss.py
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""NLLLoss op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+nll_loss_op_info = TBERegOp("NLLLoss") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("nll_loss.so") \
+    .compute_cost(10) \
+    .kernel_name("nll_loss") \
+    .partial_flag(True) \
+    .attr("reduction", "optional", "str", "all") \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "target", False, "required", "all") \
+    .input(2, "weight", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .output(1, "total_weight", False, "optional", "all") \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(nll_loss_op_info)
+def _nll_loss_tbe():
+    """NLLLoss TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/nll_loss_grad.py b/mindspore/ops/_op_impl/tbe/nll_loss_grad.py
new file mode 100644
index 00000000000..d394b27abda
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/nll_loss_grad.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""NLLLossGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+nll_loss_grad_op_info = TBERegOp("NLLLossGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("nll_loss_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("nll_loss_grad") \
+    .partial_flag(True) \
+    .attr("reduction", "optional", "str", "all") \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "y_grad", False, "required", "all") \
+    .input(2, "target", False, "required", "all") \
+    .input(3, "weight", False, "required", "all") \
+    .input(4, "total_weight", False, "required", "all") \
+    .output(0, "x_grad", False, "required", "all") \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(nll_loss_grad_op_info)
+def _nll_loss_grad_tbe():
+    """NLLLossGrad TBE register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 00d37aefe51..5fd6067186d 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                      AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
                      MaxPoolWithArgmax, OneHot, Pad, MirrorPad,
                      PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, ResizeBilinear, Sigmoid,
-                     SigmoidCrossEntropyWithLogits,
+                     SigmoidCrossEntropyWithLogits, NLLLoss,
                      SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
                      SoftmaxCrossEntropyWithLogits, ROIAlign,
                      SparseSoftmaxCrossEntropyWithLogits, Tanh,
@@ -147,6 +147,7 @@ __all__ = [
     'SoftmaxCrossEntropyWithLogits',
     'ROIAlign',
     'SparseSoftmaxCrossEntropyWithLogits',
+    'NLLLoss',
     'SGD',
     'ApplyMomentum',
     'ExpandDims',
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 8ba8fc3de23..6cce3c7f24a 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -1746,6 +1746,35 @@ class SliceGrad(PrimitiveWithInfer):
             'value': None}
 
 
+class NLLLossGrad(PrimitiveWithInfer):
+    """Computes the gradients of `NLLLoss`."""
+
+    @prim_attr_register
+    def __init__(self, reduction="mean"):
+        """Initialize NLLLossGrad"""
+        self.init_prim_io_names(inputs=['x', 'y_grad', 'target', 'weight', 'total_weight'], outputs=['x_grad'])
+        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
+        self.add_prim_attr('reduction', self.reduction)
+
+    def infer_shape(self, x_shape, y_grad_shape, t_shape, w_shape, tw_shape):
+        validator.check_int(len(x_shape), [1, 2], Rel.IN, "x rank", self.name)
+        validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
+        validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
+        validator.check("input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
+        validator.check("input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, y_grad_dtype, t_dtype, w_dtype, tw_dtype):
+        valid_dtypes = (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("y_grad_dtype", y_grad_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("t_dtype", t_dtype, mstype.int32, self.name)
validator.check_tensor_dtype_valid("w_dtype", w_dtype, valid_dtypes, self.name) + validator.check_tensor_dtype_valid("tw_dtype", tw_dtype, valid_dtypes, self.name) + validator.check('tw_shape_dtype', tw_dtype, 'w_shape_dtype', w_dtype, Rel.EQ, self.name) + return x_dtype + + class SmoothL1LossGrad(PrimitiveWithInfer): """Computes gradient for prediction on SmoothL1Loss.""" diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 784598b8dfb..906cd0a0fcc 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1917,6 +1917,66 @@ class TopK(PrimitiveWithInfer): 'value': None} +class NLLLoss(PrimitiveWithInfer): + r""" + Gets the negative log likelihood loss between logits and labels. + + Args: + reduction (string): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean". + + Inputs: + - **input** (Tensor) - Input logits, with shape :math:`(N, C)`. Data type only support float32 or float16. + - **target** (Tensor) - Ground truth labels, with shape :math:`(N)`. Data type only support int32. + - **weight** (Tensor) - The rescaling weight to each class, with shape :math:`(C)` and data type only + support float32 or float16`. + + Outputs: + Tuple of 2 tensors composed with `loss` and `total_weight`. when `reduction` is `none` and `input` is 2D + tensor, the `loss` shape is `(N,)`. Otherwise, the `loss` and the `total_weight` is a scalar. The data type + of `loss` and `total_weight` are same with `input's` and `weight's` respectively. 
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> input = Tensor(np.array([[0.5488135, 0.71518934],
+        ...                          [0.60276335, 0.5448832],
+        ...                          [0.4236548, 0.6458941]]).astype(np.float32))
+        >>> target = Tensor(np.array([0, 0, 0]).astype(np.int32))
+        >>> weight = Tensor(np.array([0.3834415, 0.79172504]).astype(np.float32))
+        >>> nll_loss = ops.NLLLoss(reduction="mean")
+        >>> loss, total_weight = nll_loss(input, target, weight)
+        >>> print(loss)
+        -0.52507716
+        >>> print(total_weight)
+        1.1503246
+    """
+
+    @prim_attr_register
+    def __init__(self, reduction="mean"):
+        """Initialize NLLLoss"""
+        self.init_prim_io_names(inputs=['x', 'target', 'weight'], outputs=['loss', 'total_weight'])
+        self.reduction = validator.check_string(reduction.lower(), ['none', 'sum', 'mean'], 'reduction', self.name)
+        self.add_prim_attr('reduction', self.reduction)
+
+    def infer_shape(self, x_shape, t_shape, w_shape):
+        validator.check_int(len(x_shape), [1, 2], Rel.IN, "x rank", self.name)
+        validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
+        validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
+        validator.check("input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
+        validator.check("input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
+        if self.reduction == "none":
+            return t_shape, ()
+        return (), ()
+
+    def infer_dtype(self, x_dtype, t_dtype, w_dtype):
+        valid_dtypes = (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("t_dtype", t_dtype, mstype.int32, self.name)
+        validator.check_tensor_dtype_valid("w_dtype", w_dtype, valid_dtypes, self.name)
+        return x_dtype, w_dtype
+
+
 class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
     r"""
     Gets the softmax cross-entropy value between logits and labels with one-hot encoding.
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 6c647b3b677..33e2af9a6e4 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -242,6 +242,18 @@ class BatchNorm3d(nn.Cell):
         return bn3d_out
 
 
+class NLLLoss(nn.Cell):
+    """NLLLoss net definition"""
+
+    def __init__(self, reduction):
+        super(NLLLoss, self).__init__()
+        self.nll_loss = P.NLLLoss(reduction=reduction)
+
+    def construct(self, input_x, target, weight):
+        loss = self.nll_loss(input_x, target, weight)
+        return loss
+
+
 class ClipByNorm(nn.Cell):
     """ClipByNorm net definition"""
 
@@ -1253,6 +1265,12 @@ test_case_math_ops = [
         'block': Moments(axis=(), keep_dims=False),
         'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
         'skip': ['backward']}),
+    ('NLLLoss', {
+        'block': NLLLoss(reduction="mean"),
+        'desc_inputs': [Tensor(np.random.rand(3, 16), mstype.float32),
+                        Tensor(np.random.rand(3), mstype.int32),
+                        Tensor(np.random.rand(16), mstype.float32)],
+        'desc_bprop': [(Tensor(np.random.rand(1), mstype.float32), Tensor(np.random.rand(1), mstype.float32))]}),
     ('BatchNorm3d', {
         'block': BatchNorm3d(num_features=3),
         'desc_inputs': [Tensor(np.random.rand(3, 3, 3, 5, 4).astype(np.float32))],
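
---

Note on the docstring example (not part of the patch): with reduction="mean", NLLLoss computes
loss = -sum_i(weight[target_i] * x[i, target_i]) / sum_i(weight[target_i]) and
total_weight = sum_i(weight[target_i]). The values printed in the example can be checked against a
plain NumPy sketch of these semantics; the helper name nll_loss_reference is illustrative, not a
MindSpore API:

    import numpy as np

    def nll_loss_reference(x, target, weight, reduction="mean"):
        """NumPy model of NLLLoss forward: loss_i = -weight[target_i] * x[i, target_i]."""
        sample_weight = weight[target]                             # per-sample class weight, shape (N,)
        loss = -sample_weight * x[np.arange(x.shape[0]), target]   # per-sample negative log likelihood
        total_weight = sample_weight.sum()
        if reduction == "none":
            return loss, total_weight
        if reduction == "sum":
            return loss.sum(), total_weight
        return loss.sum() / total_weight, total_weight             # "mean": weighted average

    x = np.array([[0.5488135, 0.71518934],
                  [0.60276335, 0.5448832],
                  [0.4236548, 0.6458941]], np.float32)
    target = np.array([0, 0, 0], np.int32)
    weight = np.array([0.3834415, 0.79172504], np.float32)
    print(nll_loss_reference(x, target, weight))

All three targets select class 0, so total_weight = 3 * 0.3834415 ~= 1.1503246 and the weighted mean
loss ~= -0.52507716, matching the docstring example above.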