diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index ddcf2772423..00745b4fa55 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -423,6 +423,7 @@ inline const PrimitivePtr kPrimElu = std::make_shared("Elu"); inline const PrimitivePtr kPrimRelu6 = std::make_shared(kReLU6); inline const PrimitivePtr kPrimReluV2 = std::make_shared(kReLUV2); inline const PrimitivePtr kPrimPRelu = std::make_shared("PReLU"); +inline const PrimitivePtr kPrimSelu = std::make_shared("SeLU"); inline const PrimitivePtr kPrimSoftplus = std::make_shared("Softplus"); inline const PrimitivePtr kPrimSoftplusGrad = std::make_shared("SoftplusGrad"); inline const PrimitivePtr kPrimZeros = std::make_shared("Zeros"); diff --git a/mindspore/core/ops/selu.cc b/mindspore/core/ops/selu.cc new file mode 100644 index 00000000000..51794d01cd7 --- /dev/null +++ b/mindspore/core/ops/selu.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ops/selu.h" +#include +#include +#include +#include +#include + +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr SeLUInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 1, prim_name); + (void)CheckAndConvertUtils::CheckArgs(prim_name, input_args, 0); + auto x = input_args[0]->BuildShape(); + MS_EXCEPTION_IF_NULL(x); + auto shape_element = x->cast(); + MS_EXCEPTION_IF_NULL(shape_element); + return shape_element; +} +TypePtr SeLUInferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + const int64_t input_num = 1; + (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, + prim_name); + MS_EXCEPTION_IF_NULL(input_args[0]); + auto x_type = input_args[0]->BuildType(); + const std::set valid_types = {kFloat16, kFloat32}; + (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name); + return x_type; +} +} // namespace +AbstractBasePtr SeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + auto type = SeLUInferType(primitive, input_args); + auto shape = SeLUInferShape(primitive, input_args); + return abstract::MakeAbstract(shape, type); +} +REGISTER_PRIMITIVE_EVAL_IMPL(SeLU, prim::kPrimSelu, SeLUInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/selu.h b/mindspore/core/ops/selu.h new file mode 100644 index 00000000000..36ae6ff72b7 --- /dev/null +++ b/mindspore/core/ops/selu.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_OPS_SELU_H_ +#define MINDSPORE_CORE_OPS_SELU_H_ +#include +#include +#include +#include +#include "ops/primitive_c.h" +#include "ops/op_utils.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSeLU = "SeLU"; +class SeLU : public PrimitiveC { + public: + SeLU() : PrimitiveC(kNameSeLU) { InitIOName({"x"}, {"output"}); } + ~SeLU() = default; + MS_DECLARE_PARENT(SeLU, PrimitiveC); + void Init() {} +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SELU_H_ diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py index aba65cbcb59..45c50813d53 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py @@ -513,6 +513,7 @@ from .mish import _mish_tbe from .mul_no_nan import _mul_no_nan_tbe from .mul_no_nan_ds import _mul_no_nan_ds_tbe from .selu import _selu_tbe +from .selu_ds import _selu_ds_tbe from .centralization import _centralization_tbe from .exp_ds import _exp_ds_tbe from .log_ds import _log_ds_tbe diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/selu_ds.py b/mindspore/python/mindspore/ops/_op_impl/tbe/selu_ds.py new file mode 100644 index 00000000000..83c37ddfac7 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/selu_ds.py @@ -0,0 +1,40 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SeLU op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +selu_op_info = TBERegOp("SeLU") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("selu.so") \ + .compute_cost(10) \ + .kernel_name("selu") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("formatAgnostic") \ + .dtype_format(DataType.I8_None, DataType.I8_None) \ + .dtype_format(DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(selu_op_info) +def _selu_ds_tbe(): + """Selu TBE register""" + return diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py index 3ab0173010d..63fc8418744 100644 --- a/mindspore/python/mindspore/ops/operations/nn_ops.py +++ b/mindspore/python/mindspore/ops/operations/nn_ops.py @@ -563,7 +563,7 @@ class Mish(PrimitiveWithInfer): return x_dtype -class SeLU(PrimitiveWithInfer): +class SeLU(Primitive): r""" Computes SeLU (scaled exponential Linear Unit) of input tensors element-wise. @@ -609,14 +609,6 @@ class SeLU(PrimitiveWithInfer): """Initialize SeLU""" self.init_prim_io_names(inputs=['x'], outputs=['output']) - def infer_shape(self, x_shape): - return x_shape - - def infer_dtype(self, x_dtype): - valid_dtypes = [mstype.float16, mstype.float32] - validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name) - return x_dtype - class ReLU6(PrimitiveWithCheck): r"""