add functional api and vmap rule for softmax

hanhuiyu1996 2022-06-13 21:58:02 +08:00
parent f462ef4166
commit 348e6d9e26
9 changed files with 151 additions and 36 deletions

View File

@@ -91,6 +91,7 @@ functional operators are Primitives that have been initialized and can be called directly as functions
mindspore.ops.mish
mindspore.ops.selu
mindspore.ops.soft_shrink
mindspore.ops.softmax
mindspore.ops.softsign
mindspore.ops.tanh

View File

@@ -0,0 +1,26 @@
mindspore.ops.softmax
=====================
.. py:function:: mindspore.ops.softmax(x, axis=-1)

    Softmax function.

    Applies the Softmax function to normalize the input along the specified axis. Given a slice :math:`x` along the specified axis, the Softmax value of each element :math:`x_i` is computed as:

    .. math::
        \text{output}(x_i) = \frac{\exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)},

    where :math:`N` is the length of the Tensor.

    Parameters:
        - **x** (Tensor) - The input of Softmax, a Tensor of arbitrary dimension. Its data type must be float16 or float32.
        - **axis** (int) - The axis on which to perform the Softmax operation. Default: -1.

    Returns:
        Tensor, with the same data type and shape as `x`.

    Raises:
        - **TypeError** - If `axis` is not an int.
        - **TypeError** - If the data type of `x` is neither float16 nor float32.
        - **ValueError** - If `axis` is a tuple whose length is less than 1.
        - **ValueError** - If `axis` is a tuple whose elements are not all in the range [-len(x.shape), len(x.shape)).

View File

@@ -91,6 +91,7 @@ Activation Functions
mindspore.ops.selu
mindspore.ops.softsign
mindspore.ops.soft_shrink
mindspore.ops.softmax
mindspore.ops.tanh
Sampling Functions

View File

@@ -136,6 +136,10 @@ class SoftmaxGpuKernelMod : public DeprecatedNativeGpuKernelMod {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'axis' cannot be equal to 0, but got "
<< axis.size();
}
if (axis.size() > 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'axis' cannot be greater than 1, but got "
<< axis.size();
}
InitSizeByAxis(input_shape, axis[0]);
}
CHECK_CUDNN_RET_WITH_EXCEPT(
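The new check restricts this GPU kernel to a single softmax axis. A minimal Python sketch of what that implies at the operator level (my illustration, not part of the commit; it assumes a GPU context is available):

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(device_target="GPU")
x = Tensor(np.random.rand(2, 3).astype(np.float32))

# A single-element axis tuple is accepted and normalized over that axis.
print(P.Softmax(axis=(1,))(x).asnumpy().sum(axis=1))  # ~[1. 1.]

# With this commit, a multi-element axis tuple is expected to raise on GPU:
# P.Softmax(axis=(0, 1))(x)  # "the length of 'axis' cannot be greater than 1"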

View File

@@ -25,6 +25,7 @@ from mindspore.ops import functional as F
from mindspore.ops import constexpr
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, _bdim_at_any, \
_bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, _raise_value_error, _vmap_clone_prim
from .._vmap.vmap_array_ops import _get_reduce_batch_axis
from ..primitive import Primitive
@@ -595,6 +596,26 @@ def get_matmul_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(P.Softmax)
def get_softmax_vmap_rule(prim, axis_size):
"""VmapRule for `Softmax`"""
axis = prim.axis[0]
if isinstance(axis, tuple):
axis = axis[0]
def vmap_rule(x_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
if is_all_none:
return result
x, x_dim = x_bdim
x_ndim = F.rank(x)
batch_axis = _get_reduce_batch_axis(axis, x_dim, x_ndim)
out = P.Softmax(batch_axis)(x)
return out, x_dim
return vmap_rule
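For context on what `_get_reduce_batch_axis` has to accomplish here: the softmax axis is given relative to the unbatched input, so once vmap inserts a batch dimension at `x_dim`, the axis must be shifted past it. A hypothetical stand-in for the helper (names and logic are my illustration, not the actual implementation):

def remap_softmax_axis(axis, batch_dim, batched_ndim):
    """Map a softmax axis defined on the unbatched input onto the batched
    tensor, skipping over the batch dimension inserted at `batch_dim`."""
    if axis < 0:
        axis += batched_ndim - 1  # normalize against the unbatched rank
    return axis if axis < batch_dim else axis + 1

# axis=-1 on a 2-D input with the batch dimension at position 0 -> axis 2 of the 3-D tensor
assert remap_softmax_axis(-1, batch_dim=0, batched_ndim=3) == 2
# axis=1 on a 2-D input with the batch dimension at position 1 -> axis 2
assert remap_softmax_axis(1, batch_dim=1, batched_ndim=3) == 2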
@vmap_rules_getters.register(P.AdaptiveAvgPool2D)
def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
"""VmapRule for `AdaptiveAvgPool2D` operation."""

View File

@@ -278,6 +278,7 @@ from .nn_func import (
hardswish,
softsign,
selu,
softmax,
pdist,
pad,
nll_loss,

View File

@@ -858,6 +858,48 @@ def softsign(x):
return softsign_(x)
def softmax(x, axis=-1):
r"""
Softmax operation.
Applies the Softmax operation to the input tensor on the specified axis.
Given a slice :math:`x` along the specified axis, the Softmax function for each
element :math:`x_i` is computed as:
.. math::
\text{output}(x_i) = \frac{\exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)},
where :math:`N` is the length of the tensor.
Args:
x (Tensor): Tensor of shape :math:`(N, *)`, where :math:`*` means any number of
additional dimensions, with float16 or float32 data type.
axis (int): The axis to perform the Softmax operation on. Default: -1.
Returns:
Tensor, with the same type and shape as `x`.
Raises:
TypeError: If `axis` is not an int.
TypeError: If dtype of `x` is neither float16 nor float32.
ValueError: If `axis` is a tuple whose length is less than 1.
ValueError: If `axis` is a tuple whose elements are not all in range [-len(x.shape), len(x.shape)).
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
>>> output = ops.softmax(x)
>>> print(output)
[0.01165623 0.03168492 0.08612854 0.23412167 0.6364086 ]
"""
validator.check_value_type("axis", axis, int)
softmax_ = P.Softmax(axis=axis)
return softmax_(x)
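As a sanity check on the formula above, the docstring example can be reproduced with plain NumPy (my sketch, not part of the commit):

import numpy as np

x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
out = np.exp(x) / np.exp(x).sum()  # output(x_i) = exp(x_i) / sum_j exp(x_j)
print(out)  # ~[0.01165623 0.03168492 0.08612854 0.23412167 0.6364086 ]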
def soft_shrink(x, lambd=0.5):
r"""
Applies the SoftShrink function element-wise.
@@ -1920,6 +1962,7 @@ __all__ = [
'hardswish',
'softsign',
'selu',
'softmax',
'pdist',
'pad',
'cross_entropy',

View File

@@ -396,32 +396,9 @@ class AdaptiveMaxPool3D(Primitive):
class Softmax(Primitive):
r"""
Softmax operation.
Applies the Softmax operation to the input tensor on the specified axis.
Suppose a slice in the given axis :math:`x`, then for each element :math:`x_i`,
the Softmax function is shown as follows:
.. math::
\text{output}(x_i) = \frac{exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)},
where :math:`N` is the length of the tensor.
Args:
axis (Union[int, tuple]): The axis to perform the Softmax operation. Default: -1.
Inputs:
- **logits** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
additional dimensions, with float16 or float32 data type.
Outputs:
Tensor, with the same type and shape as the logits.
Raises:
TypeError: If `axis` is neither an int nor a tuple.
TypeError: If dtype of `logits` is neither float16 nor float32.
ValueError: If `axis` is a tuple whose length is less than 1.
ValueError: If `axis` is a tuple whose elements are not all in range [-len(logits.shape), len(logits.shape)).
Refer to :func:`mindspore.ops.softmax` for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
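Since the new functional API simply wraps `P.Softmax`, the two are expected to agree for an int axis; a minimal sketch (mine, not part of the commit):

import numpy as np
from mindspore import Tensor, ops
from mindspore.ops import operations as P

x = Tensor(np.array([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]], dtype=np.float32))
assert np.allclose(P.Softmax(axis=-1)(x).asnumpy(),
                   ops.softmax(x, axis=-1).asnumpy())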

View File

@@ -21,6 +21,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class NetSoftmax(nn.Cell):
@@ -47,22 +48,22 @@ def test_softmax():
error2 = expect2 * 1.0e-6
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
Softmax = NetSoftmax()
output = Softmax(x)
outputSum1 = output[0].asnumpy().sum(axis=1)
outputSum2 = output[1].asnumpy().sum(axis=0)
diff1 = np.abs(outputSum1 - expect1)
diff2 = np.abs(outputSum2 - expect2)
softmax = NetSoftmax()
output = softmax(x)
output_sum1 = output[0].asnumpy().sum(axis=1)
output_sum2 = output[1].asnumpy().sum(axis=0)
diff1 = np.abs(output_sum1 - expect1)
diff2 = np.abs(output_sum2 - expect2)
assert np.all(diff1 < error1)
assert np.all(diff2 < error2)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
Softmax = NetSoftmax()
output = Softmax(x)
outputSum1 = output[0].asnumpy().sum(axis=1)
outputSum2 = output[1].asnumpy().sum(axis=0)
diff1 = np.abs(outputSum1 - expect1)
diff2 = np.abs(outputSum2 - expect2)
softmax = NetSoftmax()
output = softmax(x)
output_sum1 = output[0].asnumpy().sum(axis=1)
output_sum2 = output[1].asnumpy().sum(axis=0)
diff1 = np.abs(output_sum1 - expect1)
diff2 = np.abs(output_sum2 - expect2)
assert np.all(diff1 < error1)
assert np.all(diff2 < error2)
@@ -198,3 +199,43 @@ def test_softmax_4d():
dx = Grad(Net())(Tensor(x), Tensor(dy))
assert np.allclose(dx[0].asnumpy(), expect_dx)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_softmax_functional():
"""
Feature: softmax function
Description: test the functional api of softmax
Expectation: run success
"""
x = Tensor(np.array([[0.1, 0.3, 0.6, -0.3],
[0.2, -0.6, 0.8, 0.6],
[0.6, -1.2, 0.4, 0.6]]).astype(np.float32))
expect1 = np.ones(3)
expect2 = np.ones(4)
error1 = expect1 * 1.0e-6
error2 = expect2 * 1.0e-6
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
softmax = F.softmax
output_0 = softmax(x)
output_1 = softmax(x, axis=-2)
output_sum1 = output_0.asnumpy().sum(axis=1)
output_sum2 = output_1.asnumpy().sum(axis=0)
diff1 = np.abs(output_sum1 - expect1)
diff2 = np.abs(output_sum2 - expect2)
assert np.all(diff1 < error1)
assert np.all(diff2 < error2)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
softmax = F.softmax
output_0 = softmax(x)
output_1 = softmax(x, axis=-2)
output_sum1 = output_0.asnumpy().sum(axis=1)
output_sum2 = output_1.asnumpy().sum(axis=0)
diff1 = np.abs(output_sum1 - expect1)
diff2 = np.abs(output_sum2 - expect2)
assert np.all(diff1 < error1)
assert np.all(diff2 < error2)
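A hedged sketch of a follow-up test exercising the new vmap rule in the same style as this file; it assumes `mindspore.ops.vmap` is available in this version and is not part of the commit:

import numpy as np
import pytest
from mindspore import Tensor, context, ops


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_softmax_vmap():
    """
    Feature: softmax vmap rule
    Description: map softmax over a leading batch axis with vmap
    Expectation: every mapped slice sums to 1 along the softmax axis
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x = Tensor(np.random.rand(3, 2, 4).astype(np.float32))
    vmap_softmax = ops.vmap(lambda t: ops.softmax(t, axis=-1), in_axes=0, out_axes=0)
    output = vmap_softmax(x)
    assert np.allclose(output.asnumpy().sum(axis=-1), np.ones((3, 2)), atol=1.0e-6)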