From 9da9e946068ae80b4c87bd4f649be1caf24538e4 Mon Sep 17 00:00:00 2001
From: wangshuide2020
Date: Thu, 18 Mar 2021 17:15:02 +0800
Subject: [PATCH] fix the bug that SeLU supports only the float16 and float32
 data types.

---
 mindspore/ops/operations/nn_ops.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 72b41e31620..724a30ffce9 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -428,6 +428,11 @@ class SeLU(PrimitiveWithInfer):
         \text{alpha} * (\exp(x_i) - 1), &\text{otherwise.}
         \end{cases}
 
+    where :math:`alpha` and :math:`scale` are pre-defined constants (:math:`alpha=1.67326324`
+    and :math:`scale=1.05070098`).
+
+    See more details in `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_.
+
     Inputs:
         - **input_x** (Tensor) - The input tensor.
 
@@ -438,7 +443,7 @@ class SeLU(PrimitiveWithInfer):
         ``Ascend``
 
     Raise:
-        TypeError: If num_features data type not int8, int32, float16 and float32 Tensor.
+        TypeError: If dtype of `input_x` is neither float16 nor float32.
 
     Examples:
         >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
@@ -458,7 +463,7 @@ class SeLU(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
+        valid_dtypes = [mstype.float16, mstype.float32]
         validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
         return x_dtype
 
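
As a quick sanity check after applying the patch (not part of the diff itself), the
tightened dtype validation can be exercised as below. This is a minimal sketch assuming
an Ascend-enabled MindSpore build that includes this change (the docstring lists Ascend
as the only supported platform); the variable names are illustrative.

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P

    selu = P.SeLU()

    # float32 input passes the infer_dtype check and returns the scaled ELU values.
    x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
    print(selu(x))

    # int32 input, previously listed in valid_dtypes, is now rejected by
    # validator.check_tensor_dtype_valid with a TypeError.
    try:
        selu(Tensor(np.array([1, 2, 3]), mindspore.int32))
    except TypeError as e:
        print(e)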