forked from mindspore-Ecosystem/mindspore
!35705 [Prob]fix zeros generation in power_transform
Merge pull request !35705 from zichun_ye/power_transform_fix
This commit is contained in:
commit
e086ba7cf7
|
@ -133,11 +133,11 @@ class PowerTransform(Bijector):
|
|||
self.shape(x + power_local), 1.)
|
||||
power_local = power_local * ones
|
||||
x = x * ones
|
||||
safe_power = self.select_base(self.equal_base(power_local, 0.),
|
||||
safe_power = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
|
||||
ones,
|
||||
power_local)
|
||||
|
||||
forward_v = self.select_base(self.equal_base(power_local, 0.),
|
||||
forward_v = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
|
||||
self.exp(x),
|
||||
self.exp(self.log1p(x * safe_power) / safe_power))
|
||||
return forward_v
|
||||
|
@ -154,11 +154,11 @@ class PowerTransform(Bijector):
|
|||
self.shape(y + power_local), 1.)
|
||||
power_local = power_local * ones
|
||||
y = y * ones
|
||||
safe_power = self.select_base(self.equal_base(power_local, 0.),
|
||||
safe_power = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
|
||||
ones,
|
||||
power_local)
|
||||
|
||||
inverse_v = self.select_base(self.equal_base(power_local, 0.),
|
||||
inverse_v = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
|
||||
self.log(y),
|
||||
self.expm1(self.log(y) * safe_power) / safe_power)
|
||||
|
||||
|
@ -185,7 +185,7 @@ class PowerTransform(Bijector):
|
|||
power_local = power_local * ones
|
||||
x = x * ones
|
||||
|
||||
forward_log_j = self.select_base(self.equal_base(power_local, 0.),
|
||||
forward_log_j = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
|
||||
x,
|
||||
(1. / power_local - 1) * self.log1p(x * power_local))
|
||||
|
||||
|
|
Loading…
Reference in New Issue