From b6acac4960394d1e602cb4871938ed6d41c3b102 Mon Sep 17 00:00:00 2001
From: zx <2597798649@qq.com>
Date: Tue, 6 Jul 2021 21:21:34 +0800
Subject: [PATCH] [fix][assistant][I3PYDB] fix bug in the Ascend operator SoftShrink and SoftShrinkGrad

---
 mindspore/core/ops/soft_shrink.h              |  1 -
 mindspore/nn/layer/activation.py              | 49 +++++++++++++++++++
 .../ops/_grad_experimental/grad_nn_ops.py     |  1 -
 mindspore/ops/operations/_grad_ops.py         |  3 +-
 mindspore/ops/operations/nn_ops.py            |  1 +
 5 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/mindspore/core/ops/soft_shrink.h b/mindspore/core/ops/soft_shrink.h
index 9d1aaa03534..ce9531d6324 100644
--- a/mindspore/core/ops/soft_shrink.h
+++ b/mindspore/core/ops/soft_shrink.h
@@ -37,7 +37,6 @@ AbstractBasePtr SoftShrinkInfer(const abstract::AnalysisEnginePtr &, const Primi
                                 const std::vector<AbstractBasePtr> &input_args);
 
 using PrimSoftShrinkPtr = std::shared_ptr<SoftShrink>;
-
 }  // namespace ops
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_OPS_SOFTSHRINK_H_
diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py
index 8ad278500d2..0b95d5ec8c5 100644
--- a/mindspore/nn/layer/activation.py
+++ b/mindspore/nn/layer/activation.py
@@ -39,6 +39,7 @@ __all__ = ['Softmax',
            'HSwish',
            'ELU',
            'LogSigmoid',
+           'SoftShrink',
            ]
 
 
@@ -754,6 +755,53 @@ class LogSigmoid(Cell):
         ret = self.log(rec_exp_neg_input_1)
         return ret
 
+class SoftShrink(Cell):
+    r"""
+    Applies the soft shrinkage function elementwise.
+
+    .. math::
+        \text{SoftShrink}(x) =
+        \begin{cases}
+        x - \lambda, & \text{ if } x > \lambda \\
+        x + \lambda, & \text{ if } x < -\lambda \\
+        0, & \text{ otherwise }
+        \end{cases}
+
+    Args:
+        lambd (float): The :math:`\lambda` value for the Softshrink formulation, which must be no less than zero. Default: 0.5.
+
+    Inputs:
+        - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
+          It can have any number of additional dimensions.
+
+    Outputs:
+        Tensor, has the same shape and data type as `input_x`.
+
+    Raises:
+        TypeError: If lambd is not a float.
+        TypeError: If input_x is not a Tensor.
+        TypeError: If dtype of input_x is neither float16 nor float32.
+        ValueError: If lambd is less than 0.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> input_x = Tensor(np.array([[ 0.5297,  0.7871,  1.1754], [ 0.7836,  0.6218, -1.1542]]), mstype.float16)
+        >>> softshrink = nn.SoftShrink()
+        >>> output = softshrink(input_x)
+        >>> print(output)
+        [[ 0.02979  0.287    0.676  ]
+         [ 0.2837   0.1216  -0.6543 ]]
+    """
+
+    def __init__(self, lambd=0.5):
+        super(SoftShrink, self).__init__()
+        self.softshrink = P.SoftShrink(lambd)
+
+    def construct(self, input_x):
+        output = self.softshrink(input_x)
+        return output
 
 _activation = {
     'softmax': Softmax,
@@ -770,6 +818,7 @@ _activation = {
     'hswish': HSwish,
     'hsigmoid': HSigmoid,
     'logsigmoid': LogSigmoid,
+    'softshrink': SoftShrink,
 }
 
 
diff --git a/mindspore/ops/_grad_experimental/grad_nn_ops.py b/mindspore/ops/_grad_experimental/grad_nn_ops.py
index 3ad9a853830..acb3f84dc31 100644
--- a/mindspore/ops/_grad_experimental/grad_nn_ops.py
+++ b/mindspore/ops/_grad_experimental/grad_nn_ops.py
@@ -33,7 +33,6 @@ def get_bprop_ctc_loss_v2(self):
 
     return bprop
 
-"""nn_ops"""
 
 @bprop_getters.register(P.SoftShrink)
 def get_bprop_softshrink(self):
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 9eaa310d745..3da2deb2403 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -2206,8 +2206,9 @@ class SoftShrinkGrad(Primitive):
     Supported Platforms:
         ``Ascend``
     """
+
     @prim_attr_register
     def __init__(self, lambd=0.5):
         self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
         validator.check_value_type("lambd", lambd, [float], self.name)
-        validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
\ No newline at end of file
+        validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index ba4efbe9c99..1ee7a2fe3c7 100755
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -8670,6 +8670,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
         }
         return out
 
+
 class SoftShrink(Primitive):
     r"""
     Applies the soft shrinkage function elementwise.
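
A minimal usage sketch of the new nn.SoftShrink layer introduced by this patch (not part of the patch itself; it assumes a standard MindSpore installation and mirrors the values from the docstring example):

# Illustrative sketch only: exercises the nn.SoftShrink layer added in
# mindspore/nn/layer/activation.py above. With the default lambd=0.5,
# elements inside [-0.5, 0.5] map to 0; the rest are shrunk toward 0 by 0.5.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, dtype as mstype

input_x = Tensor(np.array([[0.5297, 0.7871, 1.1754],
                           [0.7836, 0.6218, -1.1542]]), mstype.float16)
softshrink = nn.SoftShrink(lambd=0.5)  # lambd must be no less than 0
output = softshrink(input_x)           # same shape and dtype as input_x
print(output)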