diff --git a/.gitignore b/.gitignore
index 9a0fc864723..43b0d1e5a99 100644
--- a/.gitignore
+++ b/.gitignore
@@ -104,6 +104,9 @@ mindspore/.commit_id
 # lite test file
 mindspore/lite/test/do_test/
 
+HSigmoid_Test/
+.vs
+
 # lite opencl compile file
 *.cl.inc
 
diff --git a/mindspore/core/ops/grad/hsigmoid_grad.cc b/mindspore/core/ops/grad/hsigmoid_grad.cc
index 562aeed3635..c129d61988b 100644
--- a/mindspore/core/ops/grad/hsigmoid_grad.cc
+++ b/mindspore/core/ops/grad/hsigmoid_grad.cc
@@ -60,4 +60,4 @@ AbstractBasePtr HSigmoidGradInfer(const abstract::AnalysisEnginePtr &, const Pri
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(HSigmoidGrad, prim::kPrimHSigmoidGrad, HSigmoidGradInfer, nullptr, true);
 }  // namespace ops
-}  // namespace mindspore
\ No newline at end of file
+}  // namespace mindspore
diff --git a/mindspore/core/ops/grad/hsigmoid_grad.h b/mindspore/core/ops/grad/hsigmoid_grad.h
index 1e35fb25f54..0aebfa4e4eb 100644
--- a/mindspore/core/ops/grad/hsigmoid_grad.h
+++ b/mindspore/core/ops/grad/hsigmoid_grad.h
@@ -40,4 +40,4 @@
 using PrimHSigmoidGradPtr = std::shared_ptr<HSigmoidGrad>;
 }  // namespace ops
 }  // namespace mindspore
-#endif  // MINDSPORE_CORE_OPS_HSIGMOID_GRAD_H_
\ No newline at end of file
+#endif  // MINDSPORE_CORE_OPS_HSIGMOID_GRAD_H_
diff --git a/mindspore/core/ops/hsigmoid.cc b/mindspore/core/ops/hsigmoid.cc
new file mode 100644
index 00000000000..83818886a7f
--- /dev/null
+++ b/mindspore/core/ops/hsigmoid.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ops/hsigmoid.h"
+#include <algorithm>
+#include <map>
+#include <set>
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
+  return std::make_shared<abstract::Shape>(in_shape);
+}
+TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
+    MS_LOG(EXCEPTION) << "nullptr";
+  }
+  std::map<std::string, TypePtr> types;
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  types.emplace("input_x", input_args[0]->BuildType());
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
+}
+
+}  // namespace
+AbstractBasePtr HSigmoidInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const std::vector<AbstractBasePtr> &input_args) {
+  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
+                                                    InferShape(primitive, input_args)->shape());
+}
+
+REGISTER_PRIMITIVE_EVAL_IMPL(HSigmoid, prim::kPrimHSigmoid, HSigmoidInfer, nullptr, true);
+
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/hsigmoid.h b/mindspore/core/ops/hsigmoid.h
new file mode 100644
index 00000000000..f094e168e3b
--- /dev/null
+++ b/mindspore/core/ops/hsigmoid.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>
+#include <vector>
+
+#include "ops/primitive_c.h"
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameHSigmoid = "HSigmoid";
+class HSigmoid : public PrimitiveC {
+ public:
+  HSigmoid() : PrimitiveC(kNameHSigmoid) { InitIOName({"input_x"}, {"output"}); }
+  ~HSigmoid() = default;
+  MS_DECLARE_PARENT(HSigmoid, PrimitiveC);  // come from ops/primitive_c.h
+};
+
+AbstractBasePtr HSigmoidInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const std::vector<AbstractBasePtr> &input_args);
+
+using PrimHSigmoidPtr = std::shared_ptr<HSigmoid>;
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py
index 0b95d5ec8c5..fbd28740fe5 100644
--- a/mindspore/nn/layer/activation.py
+++ b/mindspore/nn/layer/activation.py
@@ -685,14 +685,14 @@ class HSigmoid(Cell):
         TypeError: If dtype of `x` is neither float16 nor float32.
 
     Supported Platforms:
-        ``GPU`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
         >>> hsigmoid = nn.HSigmoid()
         >>> result = hsigmoid(x)
         >>> print(result)
-        [0.3333 0.1666 0.5 0.833 0.6665]
+        [0.3333 0.1666 0.5 0.8335 0.6665]
     """
 
     def __init__(self):
@@ -700,8 +700,8 @@ class HSigmoid(Cell):
         super(HSigmoid, self).__init__()
         self.hsigmoid = P.HSigmoid()
 
-    def construct(self, x):
-        return self.hsigmoid(x)
+    def construct(self, input_x):
+        return self.hsigmoid(input_x)
 
 
 class LogSigmoid(Cell):
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 747d4c29871..6c1816ed539 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -392,3 +392,4 @@ from .ctc_loss_v2_grad import _ctc_loss_v2_grad_tbe
 from .soft_shrink import _soft_shrink_tbe
 from .soft_shrink_grad import _soft_shrink_grad_tbe
 from .hsigmoid_grad import _hsigmoid_grad_tbe
+from .hsigmoid import _hsigmoid_tbe
diff --git a/mindspore/ops/_op_impl/tbe/hsigmoid.py b/mindspore/ops/_op_impl/tbe/hsigmoid.py
new file mode 100644
index 00000000000..f5c8d44a10d
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/hsigmoid.py
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""HSigmoid op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+hsigmoid_op_info = TBERegOp("HSigmoid") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("hard_sigmoid.so") \
+    .compute_cost(10) \
+    .kernel_name("hard_sigmoid") \
+    .partial_flag(True) \
+    .attr("alpha", "optional", "float", "all", "0.16666666") \
+    .attr("beta", "optional", "float", "all", "0.5") \
+    .input(0, "input_x", False, "required", "all") \
+    .output(0, "output_y", False, "required", "all") \
+    .op_pattern("formatAgnostic") \
+    .dtype_format(DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(hsigmoid_op_info)
+def _hsigmoid_tbe():
+    """HSigmoid TBE register"""
+    return
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 3a597ae474a..cb469ae76f0 100755
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -794,55 +794,6 @@ class Sigmoid(PrimitiveWithInfer):
         return input_x
 
 
-class HSigmoid(PrimitiveWithInfer):
-    r"""
-    Hard sigmoid activation function.
-
-    Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape.
-
-    Hard sigmoid is defined as:
-
-    .. math::
-
-        \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{x_{i} + 3}{6})),
-
-    where :math:`x_i` is an element of the input Tensor.
-
-    Inputs:
-        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
-          additional dimensions, with float16 or float32 data type.
-
-    Outputs:
-        Tensor, with the same type and shape as the `input_x`.
-
-    Raises:
-        TypeError: If `input_x` is not a Tensor.
-        TypeError: If dtype of `input_x` is neither float16 nor float32.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> hsigmoid = ops.HSigmoid()
-        >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
-        >>> result = hsigmoid(input_x)
-        >>> print(result)
-        [0.3333 0.1666 0.5 0.833 0.6665]
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        """Initialize HSigmoid."""
-        self.init_prim_io_names(inputs=['x'], outputs=['output'])
-
-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
-        return x_dtype
-
-
 class Tanh(PrimitiveWithInfer):
     r"""
     Tanh activation function.
@@ -8716,3 +8667,43 @@ class SoftShrink(Primitive):
         """Initialize SoftShrink"""
         validator.check_value_type("lambd", lambd, [float], self.name)
         validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
+
+class HSigmoid(Primitive):
+    r"""
+    Hard sigmoid activation function.
+
+    Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape.
+
+    Hard sigmoid is defined as:
+
+    .. math::
+
+        \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{x_{i} + 3}{6})),
+
+    where :math:`x_i` is an element of the input Tensor.
+
+    Inputs:
+        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
+          additional dimensions, with float16 or float32 data type.
+
+    Outputs:
+        Tensor, with the same type and shape as the `input_x`.
+
+    Raises:
+        TypeError: If `input_x` is not a Tensor.
+        TypeError: If dtype of `input_x` is neither float16 nor float32.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> hsigmoid = ops.HSigmoid()
+        >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mstype.float16)
+        >>> result = hsigmoid(input_x)
+        >>> print(result)
+        [0.3333 0.1666 0.5 0.8335 0.6665]
+    """
+    @prim_attr_register
+    def __init__(self):
+        """Initialize HSigmoid."""
+        self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 8bb7e6d21d7..5820433ea9f 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -2159,6 +2159,11 @@ test_case_nn_ops = [
         'desc_inputs': [Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]), mstype.float16),
                         Tensor(np.array([[-4, -3, -2], [1, 2, 4]]), mstype.float16)],
         'skip': ['backward']}),
+    ('HSigmoid', {
+        'block': P.HSigmoid(),
+        'desc_inputs': [Tensor(np.array([[-4, 4, 1]]), mstype.float32)],
+        'desc_bprop': [Tensor(np.array([[0, 1, 0.6666]]), mstype.float32)],
+        'skip': ['backward']}),
 ]
 
 test_case_array_ops = [
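
Note on the corrected docstring output: the expected values can be checked directly against the documented formula hsigmoid(x) = max(0, min(1, (x + 3) / 6)). Below is a minimal NumPy sketch, not part of the patch (the helper name hsigmoid_reference is hypothetical), that reproduces the float16 example values used in the docstrings above, including the 0.833 -> 0.8335 fix.

import numpy as np

def hsigmoid_reference(x):
    """Reference hard sigmoid: max(0, min(1, (x + 3) / 6)), computed in float32."""
    return np.clip((x.astype(np.float32) + 3.0) / 6.0, 0.0, 1.0)

# Same sample input as the docstring examples in activation.py and nn_ops.py.
x = np.array([-1, -2, 0, 2, 1], dtype=np.float16)
print(hsigmoid_reference(x).astype(np.float16))
# float16 rounding gives approximately [0.3333 0.1666 0.5 0.8335 0.6665].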