!35667 [kernel] add celu vmap

Merge pull request !35667 from 张学同/celu
i-robot 2022-06-10 08:21:43 +00:00 committed by Gitee
commit 402431e754
15 changed files with 317 additions and 10 deletions

View File

@ -57,6 +57,7 @@ functional operators are initialized Primitives that can be called directly as functions
:nosignatures:
:template: classtemplate.rst
mindspore.ops.celu
mindspore.ops.fast_gelu
mindspore.ops.hardshrink
mindspore.ops.hardswish

View File

@ -277,6 +277,28 @@ mindspore.Tensor
- **ValueError** - If the dimension of the current Tensor is not the same as that of the input `input_y`.
- **ValueError** - If the current Tensor or the input `input_y` is not 2-dimensional or 3-dimensional.
.. py:method:: celu(alpha=1.0)
Computes the celu activation function for each element of the input. The formula is defined as follows:
.. math::
\text{CeLU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
**Parameters:**
- **alpha** (float) - The :math:`\alpha` value for the celu formulation. Default: 1.0.
**Returns:**
Tensor, with the same shape and data type as the input.
**Raises:**
- **TypeError** - If `alpha` is not a float.
- **ValueError** - If the value of `alpha` is 0.
- **TypeError** - If the dtype of this Tensor is neither float16 nor float32.
.. py:method:: clip(xmin, xmax, dtype=None)
裁剪Tensor中的值。

View File

@ -0,0 +1,25 @@
mindspore.ops.celu
========================
.. py:function:: mindspore.ops.celu(x, alpha=1.0)
Computes the celu activation function for each element of the input. The formula is defined as follows:
.. math::
\text{CeLU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
**Parameters:**
- **x** (Tensor) - The input of celu, with data type of float16 or float32.
- **alpha** (float) - The :math:`\alpha` value for the celu formulation. Default: 1.0.
**Returns:**
Tensor, with the same shape and data type as the input.
**Raises:**
- **TypeError** - If `alpha` is not a float.
- **ValueError** - If the value of `alpha` is 0.
- **TypeError** - If `x` is not a Tensor.
- **TypeError** - If the dtype of `x` is neither float16 nor float32.
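
As a quick cross-check of the formula and the documented return values, here is a plain NumPy reference sketch (an editorial illustration, not part of the MindSpore docs):

import numpy as np

def celu_ref(x, alpha=1.0):
    # CeLU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    return np.maximum(0, x) + np.minimum(0, alpha * np.expm1(x / alpha))

x = np.array([-2.0, -1.0, 1.0, 2.0], dtype=np.float32)
print(celu_ref(x))  # approx. [-0.8646647 -0.6321206  1.  2.]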

View File

@ -247,6 +247,7 @@ BuiltInTypeMap &GetMethodMap() {
{"repeat", std::string("repeat")}, // C.repeat_elements
{"bernoulli", prim::kPrimBernoulli}, // P.Bernoulli()
{"cdist", std::string("cdist")}, // P.cdist
{"celu", std::string("celu")}, // P.celu
{"hardshrink", std::string("hardshrink")}, // P.hshrink
{"hardswish", std::string("hardswish")}, // P.HSwish
{"soft_shrink", std::string("soft_shrink")}, // P.SoftShrink

View File

@ -147,11 +147,10 @@ int CdistCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::v
<< in_shape1.size() << ", kernel_name_ " << kernel_name_;
return KRET_RESIZE_FAILED;
}
batch_ = 0;
batch_ = 1;
for (size_t i = 0; i < in_shape_size - kCdistInputDimsMin; i++) {
batch_ += in_shape0[i];
batch_ *= in_shape0[i];
}
batch_ = (batch_ <= 0) ? 1 : batch_;
r0_ = in_shape0[in_shape_size - 2];
m_ = in_shape0[in_shape_size - 1];
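
For context on the fix above: for cdist, every dimension before the trailing (rows, features) pair is a batch dimension, and the number of independent distance computations is their product, not their sum. A minimal Python sketch of the corrected counting (illustrative only; shape and names are hypothetical):

# Illustrative sketch of the corrected batch counting for cdist (not kernel code).
in_shape0 = (2, 3, 4, 5)      # example: two leading batch dims, then (rows=4, features=5)
K_CDIST_INPUT_DIMS_MIN = 2    # the trailing (rows, features) pair

batch = 1
for dim in in_shape0[:len(in_shape0) - K_CDIST_INPUT_DIMS_MIN]:
    batch *= dim              # product: 2 * 3 = 6 (the old sum would give 0 + 2 + 3 = 5)

rows = in_shape0[-2]          # r0_
features = in_shape0[-1]      # m_
print(batch, rows, features)  # 6 4 5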

View File

@ -69,7 +69,9 @@ bool CeluCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
auto in_data = static_cast<float *>(inputs[0]->addr);
auto out_data = static_cast<float *>(outputs[0]->addr);
auto task = [this, in_data, out_data](size_t start, size_t end) { Celu(in_data, (end - start), out_data, alpha_); };
auto task = [this, in_data, out_data](size_t start, size_t end) {
Celu(in_data + start, (end - start), out_data + start, alpha_);
};
ParallelLaunchAutoSearch(task, input_elements_, this, &parallel_search_info_, pool_);
return true;
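
The corrected lambda offsets both the input and output pointers by `start`, so each parallel chunk processes its own slice; in the old version every chunk read from and wrote to the beginning of the buffers. A minimal Python sketch of the chunked element-wise launch pattern (illustrative; ParallelLaunchAutoSearch itself picks the ranges):

import numpy as np

x = np.linspace(-2.0, 2.0, num=8, dtype=np.float32)
out = np.empty_like(x)

# The launcher hands each task a half-open range [start, end).
for start, end in [(0, 3), (3, 6), (6, 8)]:
    # Read x[start:end] and write out[start:end]; without the offsets every
    # task would overwrite out[0:end-start] with results for x[0:end-start].
    out[start:end] = np.tanh(x[start:end])  # any element-wise op; tanh is a stand-in for Celu

np.testing.assert_allclose(out, np.tanh(x), rtol=1e-6)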

View File

@ -29,7 +29,7 @@ bool CeluGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::CeLU>(base_operator);
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' cast Cdist ops failed!";
MS_LOG(ERROR) << "For '" << kernel_name_ << "' cast celu ops failed!";
return false;
}
kernel_name_ = kernel_ptr->name();

View File

@ -1497,6 +1497,13 @@ def repeat(x, repeats, axis=None):
return P.Concat(axis)(repeated_subs)
def celu(x, alpha=1.0):
r"""
Computes CeLU (Continuously differentiable exponential linear units) of a tensor element-wise.
"""
return P.CeLU(alpha)(x)
def hardshrink(x, lambd=0.5):
r"""
Apply the Hard Shrink function for a tensor. Calculates the output according to the input elements.

View File

@ -3811,6 +3811,45 @@ class Tensor(Tensor_):
s, _, _ = svd_op(full_matrices, compute_uv)(self)
return s
def celu(self, alpha=1.0):
r"""
Computes celu (Continuously differentiable exponential linear units) of input tensors element-wise.
The formula is defined as follows:
.. math::
\text{CeLU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
It returns :math:`\max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))` element-wise.
For more details, see `celu <https://arxiv.org/abs/1704.07483>`_.
Args:
alpha (float): The :math:`\alpha` value for the Celu formulation. Default: 1.0.
Returns:
Tensor, has the same shape and data type as self.
Raises:
TypeError: If `alpha` is not a float.
ValueError: If `alpha` has the value of 0.
TypeError: If dtype of this Tensor is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]), ms.float32)
>>> print(x.celu())
[-0.86466473 -0.63212055 1. 2. ]
"""
self._init_check()
return tensor_operator_registry.get('celu')(alpha)(self)
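
For reference, the registry entry added later in this change maps 'celu' to the P.CeLU primitive class, so the method call above boils down to constructing the primitive with `alpha` and applying it to the tensor. A small illustrative sketch (assumes a working MindSpore install):

import numpy as np
import mindspore as ms
from mindspore.ops import operations as P

x = ms.Tensor(np.array([-2.0, -1.0, 1.0, 2.0]), ms.float32)
# Roughly what x.celu(1.0) dispatches to via tensor_operator_registry.
out = P.CeLU(1.0)(x)
print(out)  # approx. [-0.8647 -0.6321  1.  2.]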
def hardshrink(self, lambd=0.5):
r"""
Apply the Hard Shrink function for tensor. Calculates the output according to the input elements.

View File

@ -435,6 +435,7 @@ def get_lrn_grad_vmap_rule(prim, axis_size):
get_unop_vmap_rule = vmap_rules_getters.register(P.Elu)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU6)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.CeLU)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.SeLU)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.HSigmoid)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.Softplus)(get_unop_vmap_rule)
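
Because celu is an element-wise op, registering it with the generic unary-op vmap rule above is all that is needed to batch it. A short usage sketch mirroring the new tests (assumes a working MindSpore install):

import numpy as np
import mindspore as ms
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap

def cal_celu(x):
    return P.CeLU(1.0)(x)

# A batch of identical rows; vmap maps cal_celu over axis 0 of the input.
x = ms.Tensor(np.tile(np.array([-2.0, -1.0, 1.0, 2.0], dtype=np.float32), (3, 1)))
out = vmap(cal_celu, in_axes=0, out_axes=0)(x)
print(out.shape)  # (3, 4)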

View File

@ -206,6 +206,7 @@ from .math_func import (
)
from .nn_func import (
adaptive_avgpool2d,
celu,
deformable_conv2d,
fast_gelu,
hardshrink,

View File

@ -116,6 +116,44 @@ softsign_ = P.Softsign()
hardswish_ = P.HSwish()
def celu(x, alpha=1.0):
r"""
Computes celu (Continuously differentiable exponential linear units) of input tensors element-wise.
.. math::
\text{CeLU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
It returns :math:`\max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))` element-wise.
For more details, see `celu <https://arxiv.org/abs/1704.07483>`_.
Args:
x (Tensor): The input of celu with data type of float16 or float32.
alpha (float): The :math:`\alpha` value for the Celu formulation. Default: 1.0.
Returns:
Tensor, has the same data type and shape as the input.
Raises:
TypeError: If `alpha` is not a float.
ValueError: If `alpha` has the value of 0.
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]), mindspore.float32)
>>> output = ops.celu(x, alpha=1.0)
>>> print(output)
[-0.86466473 -0.63212055 1. 2. ]
"""
celu_op = P.CeLU(alpha)
return celu_op(x)
def fast_gelu(x):
r"""
Fast Gaussian Error Linear Units activation function.
@ -830,6 +868,7 @@ def grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zero
__all__ = [
'adaptive_avgpool2d',
'celu',
'deformable_conv2d',
'fast_gelu',
'hardshrink',

View File

@ -963,6 +963,7 @@ tensor_operator_registry.register('inv', inv)
tensor_operator_registry.register('invert', invert)
tensor_operator_registry.register('matrix_band_part', matrix_band_part)
tensor_operator_registry.register('padding', padding)
tensor_operator_registry.register('celu', P.CeLU)
tensor_operator_registry.register('hardshrink', P.HShrink)
tensor_operator_registry.register('hardswish', P.HSwish)
tensor_operator_registry.register('soft_shrink', P.SoftShrink)

View File

@ -17,7 +17,9 @@ import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor, context
from mindspore.ops.functional import vmap
from mindspore.ops import functional as F
from mindspore.common.api import ms_function
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@ -41,9 +43,8 @@ def test_celu_op(data_type):
Description: test the celu alpha = 1.0.
Expectation: match to np benchmark.
"""
error = 1e-6
if data_type == np.float16:
error = 1e-3
error = 1e-3
celu = CeluTEST(1.)
x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]).astype(data_type))
expect = np.array([-0.86468184, -0.6321212, 1., 2.]).astype(data_type)
@ -55,3 +56,86 @@ def test_celu_op(data_type):
output = celu(x)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize("data_type", [np.float32, np.float16])
def test_celu_func(data_type):
"""
Feature: Celu cpu kernel
Description: test the celu alpha = 1.0.
Expectation: match to np benchmark.
"""
error = 1e-3
x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]).astype(data_type))
expect = np.array([-0.86468184, -0.6321212, 1., 2.]).astype(data_type)
context.set_context(mode=context.GRAPH_MODE)
output = F.celu(x, 1.0)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
context.set_context(mode=context.PYNATIVE_MODE)
output = F.celu(x, 1.0)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize("data_type", [np.float32, np.float16])
def test_celu_tensor(data_type):
"""
Feature: Celu cpu kernel
Description: test the celu alpha = 1.0.
Expectation: match to np benchmark.
"""
error = 1e-3
x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]).astype(data_type))
expect = np.array([-0.86468184, -0.6321212, 1., 2.]).astype(data_type)
context.set_context(mode=context.GRAPH_MODE)
output = x.celu(1.0)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
context.set_context(mode=context.PYNATIVE_MODE)
output = x.celu(1.0)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_celu_vmap():
"""
Feature: celu cpu kernel.
Description: test celu vmap feature.
Expectation: Success.
"""
error = 1e-3
def cal_celu(x):
return P.CeLU(1.0)(x)
x = Tensor(np.array([[-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0],
[-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0],
[-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0]]).astype(np.float32))
expect = np.array([[-0.86468184, -0.6321212, 1., 2.], [-0.86468184, -0.6321212, 1., 2.],
[-0.86468184, -0.6321212, 1., 2.], [-0.86468184, -0.6321212, 1., 2.],
[-0.86468184, -0.6321212, 1., 2.], [-0.86468184, -0.6321212, 1., 2.],
[-0.86468184, -0.6321212, 1., 2.], [-0.86468184, -0.6321212, 1., 2.]]).astype(np.float32)
vmap_celu = vmap(cal_celu, in_axes=(0), out_axes=0)
output = vmap_celu(x)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
@ms_function
def manually_batched(xs):
output = []
for i in range(xs.shape[0]):
output.append(cal_celu(xs[i]))
return F.stack(output)
expect_m = manually_batched(x)
np.testing.assert_allclose(output.asnumpy(), expect_m.asnumpy(), rtol=error)

View File

@ -17,7 +17,9 @@ import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor, context
from mindspore.ops.functional import vmap
from mindspore.ops import functional as F
from mindspore.common.api import ms_function
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
@ -53,3 +55,86 @@ def test_celu_op(data_type):
output = celu(x)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize("data_type", [np.float32, np.float16])
def test_celu_func(data_type):
"""
Feature: Celu gpu kernel
Description: test the celu alpha = 1.0.
Expectation: match to np benchmark.
"""
error = 1e-3
x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]).astype(data_type))
expect = np.array([-0.86468184, -0.6321212, 1., 2.]).astype(data_type)
context.set_context(mode=context.GRAPH_MODE)
output = F.celu(x, 1.0)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
context.set_context(mode=context.PYNATIVE_MODE)
output = F.celu(x, 1.0)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize("data_type", [np.float32, np.float16])
def test_celu_tensor(data_type):
"""
Feature: Celu gpu kernel
Description: test the celu alpha = 1.0.
Expectation: match to np benchmark.
"""
error = 1e-3
x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]).astype(data_type))
expect = np.array([-0.86468184, -0.6321212, 1., 2.]).astype(data_type)
context.set_context(mode=context.GRAPH_MODE)
output = x.celu(1.0)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
context.set_context(mode=context.PYNATIVE_MODE)
output = x.celu(1.0)
print(output)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_celu_vmap():
"""
Feature: celu gpu kernel.
Description: test celu vmap feature.
Expectation: Success.
"""
error = 1e-3
def cal_celu(x):
return P.CeLU(1.0)(x)
x = Tensor(np.array([[-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0],
[-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0],
[-2.0, -1.0, 1.0, 2.0], [-2.0, -1.0, 1.0, 2.0]]).astype(np.float32))
expect = np.array([[-0.86468184, -0.6321212, 1., 2.], [-0.86468184, -0.6321212, 1., 2.],
[-0.86468184, -0.6321212, 1., 2.], [-0.86468184, -0.6321212, 1., 2.],
[-0.86468184, -0.6321212, 1., 2.], [-0.86468184, -0.6321212, 1., 2.],
[-0.86468184, -0.6321212, 1., 2.], [-0.86468184, -0.6321212, 1., 2.]]).astype(np.float32)
vmap_celu = vmap(cal_celu, in_axes=(0), out_axes=0)
output = vmap_celu(x)
np.testing.assert_allclose(output.asnumpy(), expect, rtol=error)
@ms_function
def manually_batched(xs):
output = []
for i in range(xs.shape[0]):
output.append(cal_celu(xs[i]))
return F.stack(output)
expect_m = manually_batched(x)
np.testing.assert_allclose(output.asnumpy(), expect_m.asnumpy(), rtol=error)