[fix][assistant][I3PYDB] fix bug in the Ascend operator SoftShrink and SoftShrinkGrad

zx 2021-07-06 21:21:34 +08:00
parent 5d00d482e4
commit b6acac4960
5 changed files with 52 additions and 3 deletions


@@ -37,7 +37,6 @@ AbstractBasePtr SoftShrinkInfer(const abstract::AnalysisEnginePtr &, const Primi
const std::vector<AbstractBasePtr> &input_args);
using PrimSoftShrinkPtr = std::shared_ptr<SoftShrink>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SOFTSHRINK_H_


@@ -39,6 +39,7 @@ __all__ = ['Softmax',
           'HSwish',
           'ELU',
           'LogSigmoid',
           'SoftShrink',
           ]
@@ -754,6 +755,53 @@ class LogSigmoid(Cell):
        ret = self.log(rec_exp_neg_input_1)
        return ret

class SoftShrink(Cell):
    r"""
    Applies the soft shrinkage function elementwise.

    .. math::
        \text{SoftShrink}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd (float): The :math:`\lambda` value for the SoftShrink formulation, which must be no less than zero.
            Default: 0.5.

    Inputs:
        - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
          Any number of dimensions is supported.

    Outputs:
        Tensor, has the same shape and data type as `input_x`.

    Raises:
        TypeError: If `lambd` is not a float.
        TypeError: If `input_x` is not a Tensor.
        TypeError: If dtype of `input_x` is neither float16 nor float32.
        ValueError: If `lambd` is less than 0.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor(np.array([[ 0.5297,  0.7871,  1.1754], [ 0.7836,  0.6218, -1.1542]]), mstype.float16)
        >>> softshrink = nn.SoftShrink()
        >>> output = softshrink(input_x)
        >>> print(output)
        [[ 0.02979  0.287    0.676  ]
         [ 0.2837   0.1216  -0.6543 ]]
    """

    def __init__(self, lambd=0.5):
        super(SoftShrink, self).__init__()
        self.softshrink = P.SoftShrink(lambd)

    def construct(self, input_x):
        output = self.softshrink(input_x)
        return output

_activation = {
    'softmax': Softmax,
@@ -770,6 +818,7 @@ _activation = {
    'hswish': HSwish,
    'hsigmoid': HSigmoid,
    'logsigmoid': LogSigmoid,
    'softshrink': SoftShrink,
}
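
As a sanity check on the formula and the docstring example above, here is a small NumPy sketch of the same soft-shrinkage rule (an editorial illustration, not part of the commit):

import numpy as np

def soft_shrink(x, lambd=0.5):
    """Reference soft shrinkage: shift values toward zero by lambd, zero out |x| <= lambd."""
    return np.where(x > lambd, x - lambd, np.where(x < -lambd, x + lambd, 0.0))

x = np.array([[0.5297, 0.7871, 1.1754], [0.7836, 0.6218, -1.1542]], dtype=np.float16)
print(soft_shrink(x))
# Matches the docstring example up to float16 rounding:
# roughly [[ 0.0298  0.287   0.676 ]
#          [ 0.2837  0.1216 -0.6543]]

Because the class is also added to the _activation lookup table, nn.get_activation('softshrink') should now resolve to the new cell as well (assuming the usual get_activation helper that reads this dict).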


@@ -33,7 +33,6 @@ def get_bprop_ctc_loss_v2(self):
    return bprop
"""nn_ops"""
@bprop_getters.register(P.SoftShrink)
def get_bprop_softshrink(self):
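
The hunk above only shows the registration; the body of get_bprop_softshrink lies outside the diff context. A typical MindSpore bprop getter for an op with a dedicated grad primitive looks roughly like the sketch below. This is an assumption about the usual pattern, not the literal body from this commit; the G.SoftShrinkGrad name and the (dout, input_x) call order are inferred from the grad op shown later in the diff.

@bprop_getters.register(P.SoftShrink)
def get_bprop_softshrink(self):
    """Grad definition for SoftShrink (sketch)."""
    input_grad = G.SoftShrinkGrad(self.lambd)  # assumes the grad op is exposed via the grad-ops module G

    def bprop(input_x, out, dout):
        # argument order follows init_prim_io_names(inputs=['input_grad', 'input_x'], ...)
        dx = input_grad(dout, input_x)
        return (dx,)

    return bprop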


@@ -2206,8 +2206,9 @@ class SoftShrinkGrad(Primitive):
    Supported Platforms:
        ``Ascend``
    """

    @prim_attr_register
    def __init__(self, lambd=0.5):
        self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
        validator.check_value_type("lambd", lambd, [float], self.name)
        validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
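
For reference, the gradient of soft shrinkage is 1 wherever |x| exceeds lambd and 0 inside the dead zone, so SoftShrinkGrad effectively masks the incoming gradient. A NumPy sketch of that behaviour (editorial illustration only, not code from this commit):

import numpy as np

def soft_shrink_grad(dout, x, lambd=0.5):
    """Pass the upstream gradient through where |x| > lambd, zero it elsewhere."""
    return np.where(np.abs(x) > lambd, dout, 0.0)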


@@ -8670,6 +8670,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
        }
        return out


class SoftShrink(Primitive):
    r"""
    Applies the soft shrinkage function elementwise.
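
The primitive mirrors the nn.SoftShrink cell shown earlier. A minimal direct-use sketch follows; it assumes the imports used elsewhere in this diff (operations as P, dtype as mstype) and an Ascend device, since that is the only platform listed as supported:

import numpy as np
from mindspore import Tensor, dtype as mstype
from mindspore.ops import operations as P

softshrink = P.SoftShrink(lambd=0.5)  # lambd must be a float no less than 0
x = Tensor(np.array([0.3, -0.2, 0.9, -1.2]), mstype.float16)
print(softshrink(x))
# Expected (up to float16 rounding): [ 0.   0.   0.4 -0.7]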