!41988 add strided_slice and KLDivLoss interface

Merge pull request !41988 from 范吉斌/add_interface
i-robot 2022-09-15 14:05:09 +00:00 committed by Gitee
commit cc826879ef
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 210 additions and 3 deletions

View File

@@ -23,10 +23,10 @@ from __future__ import absolute_import
from mindspore.nn.loss.loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss,\
    SoftmaxCrossEntropyWithLogits, BCELoss, MultiMarginLoss, CosineEmbeddingLoss, \
    SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
-   RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss
+   RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss
__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
           'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss', 'MultiMarginLoss',
           'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
-          'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss']
+          'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss']

View File

@@ -2079,3 +2079,69 @@ class CrossEntropyLoss(LossBase):
        if logits.ndim == labels.ndim and self.ignore_index > 0:
            _cross_entropy_ignore_index_warning(self.cls_name)
        return F.cross_entropy(logits, labels, self.weight, self.ignore_index, self.reduction, self.label_smoothing)


class KLDivLoss(LossBase):
    r"""
    Computes the Kullback-Leibler divergence between the logits and the labels.

    KLDivLoss is computed as follows:

    .. math::
        L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = target_n \cdot (\log target_n - x_n)

    Then,

    .. math::
        \ell(x, target) = \begin{cases}
        L, & \text{if reduction} = \text{'none';}\\
        \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
        \operatorname{batchmean}(L), & \text{if reduction} = \text{'batchmean';}\\
        \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
        \end{cases}

    where :math:`x` represents `logits`, :math:`target` represents `labels`, and
    :math:`\ell(x, target)` represents the output.

    Note:
        Currently, float64 input is not supported on `Ascend`.
        The output matches the mathematical definition of Kullback-Leibler divergence
        only when `reduction` is set to 'batchmean'.

    Args:
        reduction (str): Specifies the reduction to be applied to the output.
            Its value must be one of 'none', 'mean', 'batchmean' or 'sum'. Default: 'mean'.

    Inputs:
        - **logits** (Tensor) - The input Tensor. The data type must be float16, float32 or float64.
        - **labels** (Tensor) - The label Tensor, which has the same shape and data type as `logits`.

    Outputs:
        Tensor or Scalar. If `reduction` is 'none', the output is a tensor with the same shape
        as `logits`. Otherwise, it is a scalar.

    Raises:
        TypeError: If `reduction` is not a str.
        TypeError: If `logits` or `labels` is not a Tensor.
        TypeError: If the dtype of `logits` or `labels` is not float16, float32 or float64.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> logits = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
        >>> labels = Tensor(np.array([0., 1., 0.]), mindspore.float32)
        >>> loss = nn.KLDivLoss(reduction='mean')
        >>> output = loss(logits, labels)
        >>> print(output)
        -0.23333333
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def construct(self, logits, labels):
        _check_is_tensor('logits', logits, self.cls_name)
        _check_is_tensor('labels', labels, self.cls_name)
        return F.kl_div(logits, labels, self.reduction)
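A minimal NumPy cross-check of the reduction formulas above (not part of the diff), assuming the standard convention 0 * log(0) = 0 for zero-valued labels; the input values are illustrative.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

x = np.log(np.array([[0.3, 0.7], [0.5, 0.5]], dtype=np.float32))  # logits
t = np.array([[0.0, 1.0], [1.0, 0.5]], dtype=np.float32)          # labels

# Element-wise term l_n = t_n * (log(t_n) - x_n), taking 0 * log(0) = 0.
l = np.where(t > 0, t * (np.log(np.where(t > 0, t, 1.0)) - x), 0.0)

for reduction, expect in [("none", l),
                          ("mean", l.mean()),
                          ("sum", l.sum()),
                          ("batchmean", l.sum() / l.shape[0])]:
    out = nn.KLDivLoss(reduction=reduction)(Tensor(x), Tensor(t))
    assert np.allclose(out.asnumpy(), expect), reduction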

View File

@@ -51,6 +51,7 @@ from .array_func import (
    stack,
    unstack,
    tensor_slice,
+   strided_slice,
    slice,
    scalar_to_array,
    scalar_to_tensor,

View File

@@ -1060,6 +1060,108 @@ def select(cond, x, y):
    return tensor_select_(cond, input_x, input_y)


def strided_slice(input_x, begin, end, strides):
r"""
Extracts a strided slice of a tensor.
This operation extracts a fragment of size (end-begin)/stride from the given 'input_tensor'.
Starting from the beginning position, the fragment continues adding stride to the index until
all dimensions are not less than the ending position.
Note:
The stride may be negative value, which causes reverse slicing.
The shape of `begin`, `end` and `strides` must be the same.
`begin` and `end` are zero-indexed. The element of `strides` must be non-zero.
Args:
input_x (Tensor): The input Tensor.
begin (tuple[int]): A tuple which represents the location where to start. Only
constant value is allowed.
end (tuple[int]): A tuple or which represents the maximum location where to end.
Only constant value is allowed.
strides (tuple[int]): - A tuple which represents the stride is continuously added
before reaching the maximum location. Only constant value is allowed.
Returns:
Tensor, The output is explained by following example.
In the 0th dimension, begin is 1, end is 2, and strides is 1,
because :math:`1+1=2\geq2`, the interval is :math:`[1,2)`.
Thus, return the element with :math:`index = 1` in 0th dimension, i.e., [[3, 3, 3], [4, 4, 4]].
In the 1st dimension, similarly, the interval is :math:`[0,1)`.
Based on the return value of the 0th dimension, return the element with :math:`index = 0`,
i.e., [3, 3, 3].
In the 2nd dimension, similarly, the interval is :math:`[0,3)`.
Based on the return value of the 1st dimension, return the element with :math:`index = 0,1,2`,
i.e., [3, 3, 3].
Finally, the output is [3, 3, 3].
    Raises:
        TypeError: If `begin`, `end` or `strides` is not a tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
        ...                   [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
        >>> output = ops.strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
        >>> # Take " output = ops.strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1)) " as an example:
        >>> # start = [1, 0, 2], end = [3, 1, 3], stride = [1, 1, 1]. Find the segment (start, end);
        >>> # note that end is an open interval.
        >>> # To facilitate understanding, this operator can be divided into three steps:
        >>> # Step 1: Calculation in the first dimension:
        >>> # start = 1, end = 3, stride = 1, so rows 1 and 2 are taken, giving the output at this step:
        >>> # output_1st =
        >>> # [
        >>> #     [
        >>> #         [3,3,3]
        >>> #         [4,4,4]
        >>> #     ]
        >>> #     [
        >>> #         [5,5,5]
        >>> #         [6,6,6]
        >>> #     ]
        >>> # ]
        >>> # Step 2: Calculation in the second dimension:
        >>> # start = 0, end = 1, stride = 1, so only row 0 is taken, giving the output at this step:
        >>> # output_2nd =
        >>> # [
        >>> #     [
        >>> #         [3,3,3]
        >>> #     ]
        >>> #     [
        >>> #         [5,5,5]
        >>> #     ]
        >>> # ]
        >>> # Step 3: Calculation in the third dimension:
        >>> # start = 2, end = 3, stride = 1, so only column 2 is taken,
        >>> # giving the final output:
        >>> # output_3rd =
        >>> # [
        >>> #     [
        >>> #         [3]
        >>> #     ]
        >>> #     [
        >>> #         [5]
        >>> #     ]
        >>> # ]
        >>> print(output)
        [[[3.]]
         [[5.]]]
        >>> # Another example:
        >>> output = ops.strided_slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
        >>> print(output)
        [[[3. 3. 3.]]]
    """
    strided_slice_ = _get_cache_prim(P.StridedSlice)()
    return strided_slice_(input_x, begin, end, strides)
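To illustrate the note above that a negative stride causes reverse slicing, here is a small sketch (not part of the diff); the backward-walking boundary behavior shown is an assumption based on the usual StridedSlice convention that `end` stays exclusive.

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([1., 2., 3., 4., 5.]), mindspore.float32)

# Positive stride 2 over the interval [0, 5): indices 0, 2, 4.
print(ops.strided_slice(x, (0,), (5,), (2,)))   # [1. 3. 5.]

# Negative stride -1: walk backwards from index 3 down to, but excluding,
# index 0, i.e. indices 3, 2, 1 -- reverse slicing.
print(ops.strided_slice(x, (3,), (0,), (-1,)))  # [4. 3. 2.]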
def slice(input_x, begin, size):
    r"""
    Slices a tensor in the specified shape.
@@ -4311,6 +4413,7 @@ __all__ = [
    'reshape_',
    'flatten',
    'tensor_slice',
+   'strided_slice',
    'slice',
    'concat',
    'stack',

View File

@@ -52,7 +52,6 @@ isinstance_ = P.IsInstance()
merge = P.Merge()
geswitch = P.GeSwitch()
-strided_slice = P.StridedSlice()
check_bprop = P.CheckBprop()
sqrt = P.Sqrt()
reduce_sum = P.ReduceSum()

View File

@@ -211,3 +211,24 @@ def test_mode_batchmean_and_dtype_with_dynamic_input(mode, dtype):
    loss = net(Tensor(prediction), Tensor(target))
    expect = np.array([0.52491106]).astype(dtype)
    assert np.allclose(loss.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_mode_batchmean_with_nn_interface(mode, dtype):
    """
    Feature: test batchmean mode with nn interface.
    Description: test batchmean mode with nn interface.
    Expectation: success.
    """
    context.set_context(mode=mode)
    np.random.seed(42)
    prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
    target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
    net = nn.KLDivLoss("batchmean")
    loss = net(prediction, target)
    expect = np.array([0.52491106]).astype(dtype)
    assert np.allclose(loss.asnumpy(), expect)
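For reference, the expected value in this test can be reproduced by hand: the entries with target -1 contribute zero (which the expected value itself confirms), so only the two target-1 entries remain, and batchmean divides their sum by the batch size of 2.

import numpy as np

# l = target * (log(target) - x); only the target == 1 entries survive:
#   (0, 1): 1 * (log 1 - log 0.7) = -log 0.7
#   (1, 0): 1 * (log 1 - log 0.5) = -log 0.5
expect = (-np.log(0.7) - np.log(0.5)) / 2  # batchmean: sum / batch_size
print(expect)  # 0.52491106...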

View File

@@ -19,6 +19,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
+from mindspore import ops
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap
@@ -61,3 +62,19 @@ def test_slice_vmap():
    output = stridedslice_vmap(x)
    expect = np.ones((16, 1, 2, 3))
    assert np.allclose(output.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice_functional():
    """
    Feature: Test strided_slice functional interface.
    Description: Test strided_slice functional interface.
    Expectation: success.
    """
    x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
                         [[5, 5, 5], [6, 7, 8]]]).astype(np.float32))
    output = ops.strided_slice(x, (2, 0, 0), (3, 2, 3), (1, 1, 1))
    expect = [[[5., 5., 5.],
               [6., 7., 8.]]]
    assert (output.asnumpy() == expect).all()
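As an aside (not part of the PR), for positive unit strides the begin/end/strides triple maps directly onto Python slice notation, which gives a quick NumPy cross-check of this test's expectation:

import numpy as np

np_x = np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 7, 8]]], dtype=np.float32)

# begin=(2, 0, 0), end=(3, 2, 3), strides=(1, 1, 1) is np_x[2:3:1, 0:2:1, 0:3:1].
assert (np_x[2:3, 0:2, 0:3] == np.array([[[5., 5., 5.], [6., 7., 8.]]])).all()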