!21607 [assistant][ops]New operator implementation, Celu

Merge pull request !21607 from yangwm/ops
i-robot 2021-09-29 06:54:40 +00:00 committed by Gitee
commit 1407dbec37
10 changed files with 271 additions and 1 deletion

@@ -266,6 +266,7 @@ inline const PrimitivePtr kPrimReal = std::make_shared<Primitive>(kReal);
inline const PrimitivePtr kPrimExtractVolumePatches = std::make_shared<Primitive>("ExtractVolumePatches");
// NN
inline const PrimitivePtr kPrimCeLU = std::make_shared<Primitive>("CeLU");
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");
inline const PrimitivePtr kPrimAudioSpectrogram = std::make_shared<Primitive>("AudioSpectrogram");
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");

@@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/celu.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  (void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 1, prim_name);
  auto shape_element = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
  return shape_element;
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  (void)CheckAndConvertUtils::CheckInteger("CeLU input numbers", input_args.size(), kEqual, 1, prim_name);
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  MS_EXCEPTION_IF_NULL(input_args[0]);
  auto x_type = input_args[0]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
  return x_type;
}
}  // namespace

AbstractBasePtr CeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto type = InferType(primitive, input_args);
  auto shape = InferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(CeLU, prim::kPrimCeLU, CeLUInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore

mindspore/core/ops/celu.h Normal file

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CELU_H_
#define MINDSPORE_CORE_OPS_CELU_H_
#include <map>
#include <memory>
#include <vector>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCeLU = "CeLU";
class CeLU : public PrimitiveC {
 public:
  CeLU() : PrimitiveC(kNameCeLU) { InitIOName({"x"}, {"output"}); }
  ~CeLU() = default;
  MS_DECLARE_PARENT(CeLU, PrimitiveC);
  void Init() {}
};

AbstractBasePtr CeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args);
using PrimCeLUPtr = std::shared_ptr<CeLU>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CELU_H_

@@ -41,9 +41,60 @@ __all__ = ['Softmax',
           'LogSigmoid',
           'SoftShrink',
           'HShrink',
           'CELU',
           ]

class CELU(Cell):
    r"""
    Continuously differentiable exponential linear units activation function.

    Applies the continuously differentiable exponential linear units function element-wise.

    .. math::

        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    It returns element-wise :math:`\max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`.

    For an illustration of the function, see `CELU <https://arxiv.org/abs/1704.07483>`_.

    Args:
        alpha (float): The :math:`\alpha` value for the CELU formulation. Default: 1.0.

    Inputs:
        - **x** (Tensor) - The input of CELU. The required dtype is float16 or float32.
          The shape is :math:`(N, *)`, where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, with the same type and shape as `x`.

    Raises:
        TypeError: If `alpha` is not a float.
        ValueError: If `alpha` is 0.
        TypeError: If `x` is not a Tensor.
        TypeError: If the dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]), mindspore.float32)
        >>> celu = nn.CELU()
        >>> output = celu(x)
        >>> print(output)
        [-0.86466473 -0.63212055 1. 2. ]
    """

    def __init__(self, alpha=1.0):
        """Initialize CELU."""
        super(CELU, self).__init__()
        self.celu = P.CeLU(alpha=alpha)

    def construct(self, x):
        return self.celu(x)
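
A quick off-device sanity check of the formula and the docstring example above. This is editorial illustration only, not part of the patch; celu_ref is just a local helper name, and only plain NumPy is needed:

import numpy as np

def celu_ref(x, alpha=1.0):
    # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    return np.maximum(0.0, x) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))

print(celu_ref(np.array([-2.0, -1.0, 1.0, 2.0])))
# approximately [-0.8646647 -0.6321206  1.  2.], matching the docstring example
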
class Softmax(Cell):
    r"""
    Softmax activation function.

@@ -70,3 +70,19 @@ def get_bprop_hshrink(self):
        return (dx,)

    return bprop


@bprop_getters.register(P.CeLU)
def get_bprop_celu(self):
    """Grad definition for `CeLU` operation."""
    alpha = self.alpha
    greater_equal = P.GreaterEqual()
    less = P.Less()

    def bprop(x, out, dout):
        greater = greater_equal(x, 0.0)
        lesser = less(x, 0.0)
        dx = dout * (greater * 1.0 + lesser * (out / alpha + 1.0))
        return (dx,)

    return bprop
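
For x < 0 the forward pass computes out = alpha * (exp(x / alpha) - 1), so exp(x / alpha) = out / alpha + 1; the bprop above reuses the saved output instead of recomputing the exponential, while x >= 0 contributes a gradient of exactly 1. A small NumPy check of this identity against a central-difference gradient (editorial illustration only, not part of the patch; helper names are ours):

import numpy as np

def celu_ref(x, alpha=1.0):
    return np.maximum(0.0, x) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))

def celu_grad_ref(x, out, dout, alpha=1.0):
    # Mirrors the bprop: 1 for x >= 0, exp(x / alpha) = out / alpha + 1 for x < 0.
    return dout * np.where(x >= 0.0, 1.0, out / alpha + 1.0)

x = np.array([-2.0, -1.0, 1.0, 2.0])
eps = 1e-6
numeric = (celu_ref(x + eps) - celu_ref(x - eps)) / (2.0 * eps)
analytic = celu_grad_ref(x, celu_ref(x), np.ones_like(x))
print(np.allclose(numeric, analytic, atol=1e-5))  # expected: True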

@@ -14,6 +14,7 @@
# ============================================================================
"""tbe ops"""
from .celu import _celu_tbe
from .abs import _abs_tbe
from .inplace_add import _inplace_add_tbe
from .inplace_sub import _inplace_sub_tbe

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Celu op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

celu_op_info = TBERegOp("CeLU") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("celu.so") \
    .compute_cost(10) \
    .kernel_name("celu") \
    .partial_flag(True) \
    .attr("alpha", "optional", "float", "all", "1.0") \
    .attr("alpha2", "optional", "float", "all", "1.0") \
    .attr("alpha3", "optional", "float", "all", "1.0") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(celu_op_info)
def _celu_tbe():
    """CeLU TBE register"""
    return

@@ -70,7 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                     DepthwiseConv2dNative,
                     DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
                     InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
                     GeLU, Gelu, FastGeLU, FastGelu, Elu, CeLU,
                     GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCLossV2Grad, CTCGreedyDecoder,
                     LogSoftmax, MaxPool3D, AvgPool3D,
                     MaxPool, DataFormatDimMap,
@@ -123,6 +123,7 @@ from .rl_ops import (BufferAppend, BufferGetItem, BufferSample)
from ._inner_ops import (MatmulDDS, DSDMatmul, NonZero)
__all__ = [
    'CeLU',
    'Ger',
    'Unique',
    'ReverseSequence',

@@ -89,6 +89,57 @@ def _update_attr_by_format(arg_value, arg_format):
    return ret


class CeLU(Primitive):
    r"""
    Computes CeLU (Continuously differentiable exponential linear units) of the input tensor element-wise.

    .. math::

        \text{CeLU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    It returns :math:`\max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))` element-wise.

    For an illustration of the function, see `CeLU <https://arxiv.org/abs/1704.07483>`_.

    Args:
        alpha (float): The :math:`\alpha` value for the CeLU formulation. Default: 1.0.

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means any number of
          additional dimensions, with dtype of float16 or float32.

    Outputs:
        Tensor, with the same type and shape as `input_x`.

    Raises:
        TypeError: If `alpha` is not a float.
        ValueError: If `alpha` is 0.
        TypeError: If `input_x` is not a Tensor.
        TypeError: If the dtype of `input_x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor(np.array([-2.0, -1.0, 1.0, 2.0]), mindspore.float32)
        >>> celu = ops.CeLU(alpha=1.0)
        >>> output = celu(input_x)
        >>> print(output)
        [-0.86466473 -0.63212055 1. 2. ]
    """

    @prim_attr_register
    def __init__(self, alpha=1.0):
        """Initialize CeLU"""
        validator.check_value_type("alpha", alpha, [float], self.name)
        validator.check_float(alpha, 0.0, Rel.NE, "alpha", self.name)
        self.alpha = alpha
        self.alpha2 = alpha
        self.add_prim_attr('alpha', self.alpha)
        self.add_prim_attr('alpha2', self.alpha2)
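
nn.CELU, added earlier in this patch, is a thin wrapper around this Primitive, and the patch registers only an Ascend TBE kernel for it. A hedged usage sketch comparing the operator against the NumPy form of the same formula follows; it assumes a MindSpore build with an Ascend backend available and is illustrative only:

import numpy as np
import mindspore
from mindspore import Tensor, ops

alpha = 2.0
x_np = np.array([-2.0, -1.0, 1.0, 2.0], dtype=np.float32)

celu = ops.CeLU(alpha=alpha)  # the Primitive added by this patch
device_out = celu(Tensor(x_np, mindspore.float32))

# NumPy reference of the same formula for comparison.
numpy_out = np.maximum(0.0, x_np) + np.minimum(0.0, alpha * (np.exp(x_np / alpha) - 1.0))
print(device_out)
print(numpy_out)  # the two should agree to float32 precision
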
class Flatten(PrimitiveWithInfer):
    r"""
    Flattens a tensor without changing its batch size on the 0-th axis.

@@ -1729,6 +1729,10 @@ test_case_math_ops = [
]

test_case_nn_ops = [
    ('CeLU', {
        'block': P.CeLU(),
        'desc_inputs': [[1, 2, 3]],
        'desc_bprop': [[1, 2, 3]]}),
    ('BiasAdd', {
        'block': P.BiasAdd(),
        'desc_inputs': [[1, 3, 3, 3], [3]],