!49061 [bugfix] Fix I5YRQ8

Merge pull request !49061 from shaojunsong/fix/I5YRQ8
This commit is contained in:
i-robot 2023-02-20 11:09:14 +00:00 committed by Gitee
commit 1eeeba149e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 5 additions and 4 deletions

View File

@ -556,16 +556,17 @@ class RReLU(Cell):
raise ValueError(f"For {self.cls_name}, the value of 'upper' must be greater than 'lower', "
f"but got upper: {upper}, lower: {lower}. ")
self.lower = lower
self.upper = upper
self.lower = Tensor(lower)
self.upper = Tensor(upper)
self.sign = P.Sign()
def construct(self, x):
size = x.shape
_size = x.shape
_dtype = x.dtype
sign_matrix = self.sign(x)
negative_filter = sign_matrix.clip(None, 0)
positive_filter = sign_matrix.clip(0, None)
mask = P.Cast()(Tensor(np.random.uniform(self.lower, self.upper, size=size)), P.DType()(x))
mask = ops.uniform(_size, self.lower.astype(_dtype), self.upper.astype((_dtype)), dtype=_dtype)
negative_mask = negative_filter * mask * -1
total_mask = negative_mask + positive_filter
out = total_mask * x