forked from mindspore-Ecosystem/mindspore
!41988 add strided_slice and KLDivLoss interface
Merge pull request !41988 from 范吉斌/add_interface
commit cc826879ef
@@ -23,10 +23,10 @@ from __future__ import absolute_import
 from mindspore.nn.loss.loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss,\
     SoftmaxCrossEntropyWithLogits, BCELoss, MultiMarginLoss, CosineEmbeddingLoss, \
     SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
-    RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss
+    RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss
 
 
 __all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss', 'MultiMarginLoss',
            'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
-           'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss']
+           'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss']

@@ -2079,3 +2079,69 @@ class CrossEntropyLoss(LossBase):
         if logits.ndim == labels.ndim and self.ignore_index > 0:
             _cross_entropy_ignore_index_warning(self.cls_name)
         return F.cross_entropy(logits, labels, self.weight, self.ignore_index, self.reduction, self.label_smoothing)
+
+
+class KLDivLoss(LossBase):
+    r"""
+    Computes the Kullback-Leibler divergence between the logits and the labels.
+
+    KLDivLoss is computed as follows:
+
+    .. math::
+        L = \{l_1,\dots,l_N\}^\top, \quad
+        l_n = target_n \cdot (\log target_n - x_n)
+
+    Then,
+
+    .. math::
+        \ell(x, target) = \begin{cases}
+        L, & \text{if reduction} = \text{'none';}\\
+        \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+        \operatorname{batchmean}(L), & \text{if reduction} = \text{'batchmean';}\\
+        \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    where :math:`x` represents `logits`,
+    :math:`target` represents `labels`, and
+    :math:`\ell(x, target)` represents the output.
+
+    Note:
+        Currently, float64 input is not supported on `Ascend`.
+        The output matches the mathematical definition of Kullback-Leibler divergence only when `reduction` is set to 'batchmean'.
+
+    Args:
+        reduction (str): Specifies the reduction to be applied to the output.
+            Its value must be one of 'none', 'mean', 'batchmean' or 'sum'. Default: 'mean'.
+
+    Inputs:
+        - **logits** (Tensor) - The input Tensor. The data type must be float16, float32 or float64.
+        - **labels** (Tensor) - The label Tensor, which has the same shape and data type as `logits`.
+
+    Outputs:
+        Tensor or Scalar. If `reduction` is 'none', the output is a tensor with the same shape as `logits`.
+        Otherwise, it is a scalar.
+
+    Raises:
+        TypeError: If `reduction` is not a str.
+        TypeError: If `logits` or `labels` is not a Tensor.
+        TypeError: If dtype of `logits` or `labels` is not float16, float32 or float64.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> logits = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
+        >>> labels = Tensor(np.array([0., 1., 0.]), mindspore.float32)
+        >>> loss = nn.KLDivLoss(reduction='mean')
+        >>> output = loss(logits, labels)
+        >>> print(output)
+        -0.23333333
+    """
+    def __init__(self, reduction='mean'):
+        super().__init__()
+        self.reduction = reduction
+
+    def construct(self, logits, labels):
+        _check_is_tensor('logits', logits, self.cls_name)
+        _check_is_tensor('labels', labels, self.cls_name)
+        return F.kl_div(logits, labels, self.reduction)

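A quick sanity check of the formula and docstring example above (a NumPy sketch, not part of the patch; it assumes the usual convention that terms with a zero target contribute 0, which the -0.23333333 output implies):

    import numpy as np

    # l_n = target_n * (log(target_n) - x_n), with 0 * log(0) treated as 0
    x = np.array([0.2, 0.7, 0.1], dtype=np.float32)
    target = np.array([0., 1., 0.], dtype=np.float32)
    l = np.where(target > 0, target * (np.log(np.maximum(target, 1e-12)) - x), 0.0)
    print(l.mean())  # -0.23333333, matching reduction='mean' in the example
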
@@ -51,6 +51,7 @@ from .array_func import (
     stack,
     unstack,
     tensor_slice,
+    strided_slice,
     slice,
     scalar_to_array,
     scalar_to_tensor,

@@ -1060,6 +1060,108 @@ def select(cond, x, y):
     return tensor_select_(cond, input_x, input_y)
 
 
+def strided_slice(input_x, begin, end, strides):
+    r"""
+    Extracts a strided slice of a tensor.
+
+    This operation extracts a fragment of size (end-begin)/stride from the given `input_x`.
+    Starting from the beginning position, the fragment advances by the stride along each
+    dimension until every index is no longer less than the ending position.
+
+    Note:
+        The stride may be a negative value, which causes reverse slicing.
+        The shapes of `begin`, `end` and `strides` must be the same.
+        `begin` and `end` are zero-indexed. Each element of `strides` must be non-zero.
+
+    Args:
+        input_x (Tensor): The input Tensor.
+        begin (tuple[int]): A tuple which represents the location where to start. Only
+            constant values are allowed.
+        end (tuple[int]): A tuple which represents the maximum location where to end.
+            Only constant values are allowed.
+        strides (tuple[int]): A tuple which represents the strides that are continuously
+            added to the index before reaching the maximum location. Only constant values are allowed.
+
+    Returns:
+        Tensor. The output is explained by the following example.
+
+        In the 0th dimension, begin is 1, end is 2, and strides is 1;
+        because :math:`1+1=2 \geq 2`, the interval is :math:`[1,2)`.
+        Thus, return the element with :math:`index = 1` in the 0th dimension, i.e., [[3, 3, 3], [4, 4, 4]].
+
+        In the 1st dimension, similarly, the interval is :math:`[0,1)`.
+        Based on the return value of the 0th dimension, return the element with :math:`index = 0`,
+        i.e., [3, 3, 3].
+
+        In the 2nd dimension, similarly, the interval is :math:`[0,3)`.
+        Based on the return value of the 1st dimension, return the elements with :math:`index = 0, 1, 2`,
+        i.e., [3, 3, 3].
+
+        Finally, the output is [3, 3, 3].
+
+    Raises:
+        TypeError: If `begin`, `end` or `strides` is not a tuple.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
+        ...                   [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
+        >>> output = ops.strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
+        >>> # Take " output = strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1)) " as an example:
+        >>> # start = [1, 0, 2], end = [3, 1, 3], stride = [1, 1, 1]. Find the segment (start, end),
+        >>> # noting that end is an open interval.
+        >>> # To facilitate understanding, this operator can be divided into three steps:
+        >>> # Step 1: Calculation of the first dimension:
+        >>> # start = 1, end = 3, stride = 1, so the 1st and 2nd rows can be taken, giving the output at this time:
+        >>> # output_1st =
+        >>> # [
+        >>> #     [
+        >>> #         [3,3,3]
+        >>> #         [4,4,4]
+        >>> #     ]
+        >>> #     [
+        >>> #         [5,5,5]
+        >>> #         [6,6,6]
+        >>> #     ]
+        >>> # ]
+        >>> # Step 2: Calculation of the second dimension:
+        >>> # start = 0, end = 1, stride = 1, so only the 0th row can be taken, giving the output at this time:
+        >>> # output_2nd =
+        >>> # [
+        >>> #     [
+        >>> #         [3,3,3]
+        >>> #     ]
+        >>> #     [
+        >>> #         [5,5,5]
+        >>> #     ]
+        >>> # ]
+        >>> # Step 3: Calculation of the third dimension:
+        >>> # start = 2, end = 3, stride = 1, so only the 2nd column can be taken,
+        >>> # giving the final output:
+        >>> # output_3rd =
+        >>> # [
+        >>> #     [
+        >>> #         [3]
+        >>> #     ]
+        >>> #     [
+        >>> #         [5]
+        >>> #     ]
+        >>> # ]
+        >>> # The final output is:
+        >>> print(output)
+        [[[3.]]
+         [[5.]]]
+        >>> # Another example:
+        >>> output = ops.strided_slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
+        >>> print(output)
+        [[[3. 3. 3.]]]
+    """
+    strided_slice_ = _get_cache_prim(P.StridedSlice)()
+    return strided_slice_(input_x, begin, end, strides)
+
+
 def slice(input_x, begin, size):
     r"""
     Slices a tensor in the specified shape.

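For intuition, the worked example above is equivalent to ordinary start:stop:step slicing (a NumPy sketch; begin/end/strides map directly onto slice bounds):

    import numpy as np

    x = np.array([[[1, 1, 1], [2, 2, 2]],
                  [[3, 3, 3], [4, 4, 4]],
                  [[5, 5, 5], [6, 6, 6]]], dtype=np.float32)
    # begin=(1, 0, 2), end=(3, 1, 3), strides=(1, 1, 1)
    print(x[1:3:1, 0:1:1, 2:3:1])  # [[[3.]], [[5.]]], matching ops.strided_slice
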
@@ -4311,6 +4413,7 @@ __all__ = [
     'reshape_',
     'flatten',
     'tensor_slice',
+    'strided_slice',
     'slice',
     'concat',
     'stack',

@@ -52,7 +52,6 @@ isinstance_ = P.IsInstance()
 
 merge = P.Merge()
 geswitch = P.GeSwitch()
-strided_slice = P.StridedSlice()
 check_bprop = P.CheckBprop()
 sqrt = P.Sqrt()
 reduce_sum = P.ReduceSum()

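After this change, callers that previously relied on the removed `strided_slice` alias go through the functional interface instead. Both spellings should agree (a sketch, assuming the functional form simply wraps a cached P.StridedSlice primitive, as in the implementation added earlier in this PR):

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops
    from mindspore.ops import operations as P

    x = Tensor(np.arange(24).reshape(2, 3, 4), mindspore.float32)
    a = ops.strided_slice(x, (0, 0, 0), (1, 2, 4), (1, 1, 1))  # functional interface
    b = P.StridedSlice()(x, (0, 0, 0), (1, 2, 4), (1, 1, 1))   # primitive interface
    assert (a.asnumpy() == b.asnumpy()).all()
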
@@ -211,3 +211,24 @@ def test_mode_batchmean_and_dtype_with_dynamic_input(mode, dtype):
     loss = net(Tensor(prediction), Tensor(target))
     expect = np.array([0.52491106]).astype(dtype)
     assert np.allclose(loss.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
+def test_mode_batchmean_with_nn_interface(mode, dtype):
+    """
+    Feature: test batchmean mode with nn interface.
+    Description: test batchmean mode with nn interface.
+    Expectation: success.
+    """
+    context.set_context(mode=mode)
+    np.random.seed(42)
+    prediction = mindspore.Tensor(np.log(np.array([[0.3, 0.7], [0.5, 0.5]])).astype(dtype))
+    target = mindspore.Tensor(np.array([[-1, 1], [1, -1]]).astype(dtype))
+    net = nn.KLDivLoss("batchmean")
+    loss = net(Tensor(prediction), Tensor(target))
+    expect = np.array([0.52491106]).astype(dtype)
+    assert np.allclose(loss.asnumpy(), expect)

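The expected value 0.52491106 can be checked by hand (a NumPy sketch; it assumes terms with non-positive targets contribute 0 and that 'batchmean' divides the summed loss by the batch size, both of which this expectation implies):

    import numpy as np

    x = np.log(np.array([[0.3, 0.7], [0.5, 0.5]]))
    t = np.array([[-1., 1.], [1., -1.]])
    l = np.where(t > 0, t * (np.log(np.maximum(t, 1e-12)) - x), 0.0)
    print(l.sum() / x.shape[0])  # 0.52491106 = (-log 0.7 - log 0.5) / 2
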
@@ -19,6 +19,7 @@ import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import ops
 from mindspore.ops import operations as P
 from mindspore.ops.functional import vmap
 

@@ -61,3 +62,19 @@ def test_slice_vmap():
     output = stridedslice_vmap(x)
     expect = np.ones((16, 1, 2, 3))
     assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_slice_functional():
+    """
+    Feature: Test strided_slice functional interface.
+    Description: Test strided_slice functional interface.
+    Expectation: success.
+    """
+    x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32))
+    output = ops.strided_slice(x, (2, 0, 0), (3, 2, 3), (1, 1, 1))
+    expect = [[[5., 5., 5.],
+               [6., 7., 8.]]]
+    assert (output.asnumpy() == expect).all()
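The expected slice in this test is likewise plain NumPy basic indexing (a cross-check sketch, not part of the patch):

    import numpy as np

    x = np.array([[[1., 1., 1.], [2, 2, 2]],
                  [[3, 3, 3], [4, 4, 4]],
                  [[5, 5, 5], [6, 7, 8]]], dtype=np.float32)
    print(x[2:3, 0:2, 0:3])  # [[[5. 5. 5.] [6. 7. 8.]]]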