[fix][assistant][I3PYDB] fix bug in the Ascend operator SoftShrink and SoftShrinkGrad
This commit is contained in:
parent
5d00d482e4
commit
b6acac4960
|
@ -37,7 +37,6 @@ AbstractBasePtr SoftShrinkInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
const std::vector<AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimSoftShrinkPtr = std::shared_ptr<SoftShrink>;
|
||||
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_SOFTSHRINK_H_
|
||||
|
|
|
@ -39,6 +39,7 @@ __all__ = ['Softmax',
|
|||
'HSwish',
|
||||
'ELU',
|
||||
'LogSigmoid',
|
||||
'SoftShrink',
|
||||
]
|
||||
|
||||
|
||||
|
@ -754,6 +755,53 @@ class LogSigmoid(Cell):
|
|||
ret = self.log(rec_exp_neg_input_1)
|
||||
return ret
|
||||
|
||||
class SoftShrink(Cell):
|
||||
r"""
|
||||
Applies the soft shrinkage function elementwise.
|
||||
|
||||
.. math::
|
||||
\text{SoftShrink}(x) =
|
||||
\begin{cases}
|
||||
x - \lambda, & \text{ if } x > \lambda \\
|
||||
x + \lambda, & \text{ if } x < -\lambda \\
|
||||
0, & \text{ otherwise }
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
lambd: the :math:`\lambda` must be no less than zero value for the Softshrink formulation. Default: 0.5.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
|
||||
Any number of additional dimensions.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and data type as `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If lambd is not a float.
|
||||
TypeError: If input_x is not a Tensor.
|
||||
TypeError: If dtype of input_x is neither float16 nor float32.
|
||||
ValueError: If lambd is less than 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[ 0.5297, 0.7871, 1.1754], [ 0.7836, 0.6218, -1.1542]]), mstype.float16)
|
||||
>>> softshrink = nn.SoftShrink()
|
||||
>>> output = softshrink(input_x)
|
||||
>>> print(output)
|
||||
[[ 0.02979 0.287 0.676 ]
|
||||
[ 0.2837 0.1216 -0.6543 ]]
|
||||
"""
|
||||
|
||||
def __init__(self, lambd=0.5):
|
||||
super(SoftShrink, self).__init__()
|
||||
self.softshrink = P.SoftShrink(lambd)
|
||||
|
||||
def construct(self, input_x):
|
||||
output = self.softshrink(input_x)
|
||||
return output
|
||||
|
||||
_activation = {
|
||||
'softmax': Softmax,
|
||||
|
@ -770,6 +818,7 @@ _activation = {
|
|||
'hswish': HSwish,
|
||||
'hsigmoid': HSigmoid,
|
||||
'logsigmoid': LogSigmoid,
|
||||
'softshrink': SoftShrink,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -33,7 +33,6 @@ def get_bprop_ctc_loss_v2(self):
|
|||
|
||||
return bprop
|
||||
|
||||
"""nn_ops"""
|
||||
|
||||
@bprop_getters.register(P.SoftShrink)
|
||||
def get_bprop_softshrink(self):
|
||||
|
|
|
@ -2206,8 +2206,9 @@ class SoftShrinkGrad(Primitive):
|
|||
Supported Platforms:
|
||||
``Ascend``
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lambd=0.5):
|
||||
self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
|
||||
validator.check_value_type("lambd", lambd, [float], self.name)
|
||||
validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
|
||||
validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
|
||||
|
|
|
@ -8670,6 +8670,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
|
|||
}
|
||||
return out
|
||||
|
||||
|
||||
class SoftShrink(Primitive):
|
||||
r"""
|
||||
Applies the soft shrinkage function elementwise.
|
||||
|
|
Loading…
Reference in New Issue