diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index e754f994382..a5aa26c06f1 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -371,6 +371,7 @@ inline const PrimitivePtr kSoftmaxGradExt = std::make_shared<Primitive>("Softmax
 inline const PrimitivePtr kSquareSumV1 = std::make_shared<Primitive>("SquareSumV1");
 inline const PrimitivePtr kFusedMulAdd = std::make_shared<Primitive>("FusedMulAdd");
 inline const PrimitivePtr kPrimSoftShrink = std::make_shared<Primitive>("SoftShrink");
+inline const PrimitivePtr kPrimSoftShrinkGrad = std::make_shared<Primitive>("SoftShrinkGrad");
 
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
diff --git a/mindspore/core/ops/grad/soft_shrink_grad.cc b/mindspore/core/ops/grad/soft_shrink_grad.cc
new file mode 100644
index 00000000000..52508fab8e3
--- /dev/null
+++ b/mindspore/core/ops/grad/soft_shrink_grad.cc
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/grad/soft_shrink_grad.h"
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/primitive_infer_map.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::ShapePtr SoftShrinkGradInferShape(const PrimitivePtr &primitive,
+                                            const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, primitive->name());
+  auto input_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
+  auto prim_name = primitive->name();
+  CheckAndConvertUtils::Check("input_grad_shape", input_grad_shape, kEqual, "input_x_shape", input_x_shape, prim_name,
+                              TypeError);
+  return std::make_shared<abstract::Shape>(input_grad_shape);
+}
+
+TypePtr SoftShrinkGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  std::map<std::string, TypePtr> types;
+  types.emplace("input_grad", input_args[0]->BuildType());
+  types.emplace("input_x", input_args[1]->BuildType());
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
+}
+}  // namespace
+
+AbstractBasePtr SoftShrinkGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const std::vector<AbstractBasePtr> &input_args) {
+  return std::make_shared<abstract::AbstractTensor>(SoftShrinkGradInferType(primitive, input_args),
+                                                    SoftShrinkGradInferShape(primitive, input_args)->shape());
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(SoftShrinkGrad, prim::kPrimSoftShrinkGrad, SoftShrinkGradInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/grad/soft_shrink_grad.h b/mindspore/core/ops/grad/soft_shrink_grad.h
new file mode 100644
index 00000000000..248e6983162
--- /dev/null
+++ b/mindspore/core/ops/grad/soft_shrink_grad.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_SOFTSHRINK_GRAD_H_
+#define MINDSPORE_CORE_OPS_SOFTSHRINK_GRAD_H_
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameSoftShrinkGrad = "SoftShrinkGrad";
+class SoftShrinkGrad : public PrimitiveC {
+ public:
+  SoftShrinkGrad() : PrimitiveC(kNameSoftShrinkGrad) { InitIOName({"input_grad", "input_x"}, {"output"}); }
+  ~SoftShrinkGrad() = default;
+  MS_DECLARE_PARENT(SoftShrinkGrad, PrimitiveC);
+};
+AbstractBasePtr SoftShrinkGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const std::vector<AbstractBasePtr> &input_args);
+using PrimSoftShrinkGradPtr = std::shared_ptr<SoftShrinkGrad>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_SOFTSHRINK_GRAD_H_
diff --git a/mindspore/ops/_grad_experimental/grad_nn_ops.py b/mindspore/ops/_grad_experimental/grad_nn_ops.py
index 7b811c0b35d..3ad9a853830 100644
--- a/mindspore/ops/_grad_experimental/grad_nn_ops.py
+++ b/mindspore/ops/_grad_experimental/grad_nn_ops.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 # ============================================================================
+
 """Define the grad rules of neural network related operations."""
 from .._grad.grad_base import bprop_getters
 from .. import operations as P
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
-
+from ..operations import _grad_ops as G
 
 
 @bprop_getters.register(P.CTCLossV2)
 def get_bprop_ctc_loss_v2(self):
@@ -31,3 +32,16 @@ def get_bprop_ctc_loss_v2(self):
         return grad, zeros_like(targets), zeros_like(input_lengths), zeros_like(target_lengths)
 
     return bprop
+
+"""nn_ops"""
+
+@bprop_getters.register(P.SoftShrink)
+def get_bprop_softshrink(self):
+    """Grad definition for `SoftShrink` operation."""
+    input_grad = G.SoftShrinkGrad(self.lambd)
+
+    def bprop(input_x, out, dout):
+        dx = input_grad(dout, input_x)
+        return (dx,)
+
+    return bprop
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 29dfb4b71be..d1fee1221a9 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -387,3 +387,4 @@ from .reciprocal_ds import _reciprocal_ds_tbe
 from .ctc_loss_v2 import _ctc_loss_v2_tbe
 from .ctc_loss_v2_grad import _ctc_loss_v2_grad_tbe
 from .soft_shrink import _soft_shrink_tbe
+from .soft_shrink_grad import _soft_shrink_grad_tbe
diff --git a/mindspore/ops/_op_impl/tbe/soft_shrink_grad.py b/mindspore/ops/_op_impl/tbe/soft_shrink_grad.py
new file mode 100644
index 00000000000..703d08ab8c3
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/soft_shrink_grad.py
@@ -0,0 +1,38 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SoftShrinkGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+soft_shrink_grad_op_info = TBERegOp("SoftShrinkGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("soft_shrink_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("soft_shrink_grad") \
+    .partial_flag(True) \
+    .attr("lambd", "optional", "float", "all", "0.5") \
+    .input(0, "input_grad", False, "required", "all") \
+    .input(1, "input_x", False, "required", "all") \
+    .output(0, "output", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(soft_shrink_grad_op_info)
+def _soft_shrink_grad_tbe():
+    """SoftShrinkGrad TBE register"""
+    return
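The bprop registered in grad_nn_ops.py feeds dout and the original input_x into SoftShrinkGrad, and the TBE kernel ships as a prebuilt binary, so the element-wise rule itself does not appear in this diff. Since SoftShrink(x) equals x - lambd for x > lambd, x + lambd for x < -lambd, and 0 otherwise, its derivative is 1 where |x| > lambd and 0 elsewhere, and SoftShrinkGrad is expected to apply exactly that mask to the incoming gradient. A minimal NumPy sketch of this assumed rule (soft_shrink_grad_reference is an illustrative name, not part of the patch):

import numpy as np

def soft_shrink_grad_reference(dout, x, lambd=0.5):
    """Reference-only sketch: SoftShrink's derivative is 1 where |x| > lambd, 0 elsewhere."""
    return dout * (np.abs(x) > lambd).astype(dout.dtype)

# With lambd=0.5, gradient only flows through elements whose |x| exceeds 0.5.
dout = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)
x = np.array([[-3.0, -2.0, 0.0], [1.0, 2.0, 4.0]], dtype=np.float32)
print(soft_shrink_grad_reference(dout, x))  # [[0.1 0.2 0. ] [0.4 0.5 0.6]]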
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 02aaa469f77..9eaa310d745 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -2181,3 +2181,33 @@ class MaskedSelectGrad(PrimitiveWithInfer):
 
     def infer_dtype(self, x, mask, grad):
         return x
+
+
+class SoftShrinkGrad(Primitive):
+    r"""
+    Gradients for SoftShrink operation.
+
+    Args:
+        lambd (float): The :math:`\lambda` value for the SoftShrink formulation, must be no less than zero. Default: 0.5.
+
+    Inputs:
+        - **input_grad** (Tensor) - The input gradient.
+        - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
+          Any number of additional dimensions.
+
+    Outputs:
+        - **output** (Tensor) - Has the same shape and data type as `input_x`.
+
+    Raises:
+        TypeError: If `lambd` is not a float.
+        TypeError: If dtype of `input_x` is neither float16 nor float32.
+        ValueError: If `lambd` is less than 0.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+    @prim_attr_register
+    def __init__(self, lambd=0.5):
+        self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
+        validator.check_value_type("lambd", lambd, [float], self.name)
+        validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
\ No newline at end of file
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 30aea8c2cc3..d92b048cca3 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -2149,6 +2149,12 @@ test_case_nn_ops = [
         'block': P.SoftShrink(),
         'desc_inputs': [Tensor(np.array([[0.5297, 0.7871, 1.1754], [0.7836, 0.6218, -1.1542]]), mstype.float32)],
         'desc_bprop': [Tensor(np.array([[0, 0.4, 1], [1, 2, 4]]), mstype.float32)]}),
+    ('SoftShrinkGrad', {
+        'block': G.SoftShrinkGrad(),
+        'desc_inputs': [Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]), mstype.float16),
+                        Tensor(np.array([[-3, -2, 0], [1, 2, 4]]), mstype.float16)],
+        'desc_bprop': [],
+        'skip': ['backward']}),
 ]
 
 test_case_array_ops = [
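End to end, the pieces above are exercised whenever a graph containing P.SoftShrink is differentiated: GradOperation looks up the registered bprop, which calls the new SoftShrinkGrad primitive and, on Ascend, the soft_shrink_grad TBE kernel. A minimal usage sketch, assuming this patch is applied and an Ascend backend is configured (SoftShrinkNet is an illustrative wrapper, not part of the change):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.ops import GradOperation

class SoftShrinkNet(nn.Cell):
    """Thin wrapper so GradOperation has a cell to differentiate."""
    def __init__(self, lambd=0.5):
        super().__init__()
        self.softshrink = P.SoftShrink(lambd)

    def construct(self, x):
        return self.softshrink(x)

x = Tensor(np.array([[0.52, 0.78, 0.11], [0.78, 0.62, -1.15]]), ms.float32)
# Differentiating the cell triggers the bprop registered in grad_nn_ops.py,
# which forwards (dout, input_x) to the new SoftShrinkGrad primitive.
dx = GradOperation()(SoftShrinkNet())(x)
print(dx)  # expected: 1.0 where |x| > 0.5, 0.0 elsewhere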