forked from mindspore-Ecosystem/mindspore
!49061 [bugfix] Fix I5YRQ8
Merge pull request !49061 from shaojunsong/fix/I5YRQ8
This commit is contained in:
commit
1eeeba149e
|
@ -556,16 +556,17 @@ class RReLU(Cell):
|
|||
raise ValueError(f"For {self.cls_name}, the value of 'upper' must be greater than 'lower', "
|
||||
f"but got upper: {upper}, lower: {lower}. ")
|
||||
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.lower = Tensor(lower)
|
||||
self.upper = Tensor(upper)
|
||||
self.sign = P.Sign()
|
||||
|
||||
def construct(self, x):
|
||||
size = x.shape
|
||||
_size = x.shape
|
||||
_dtype = x.dtype
|
||||
sign_matrix = self.sign(x)
|
||||
negative_filter = sign_matrix.clip(None, 0)
|
||||
positive_filter = sign_matrix.clip(0, None)
|
||||
mask = P.Cast()(Tensor(np.random.uniform(self.lower, self.upper, size=size)), P.DType()(x))
|
||||
mask = ops.uniform(_size, self.lower.astype(_dtype), self.upper.astype((_dtype)), dtype=_dtype)
|
||||
negative_mask = negative_filter * mask * -1
|
||||
total_mask = negative_mask + positive_filter
|
||||
out = total_mask * x
|
||||
|
|
Loading…
Reference in New Issue