!35705 [Prob]fix zeros generation in power_transform

Merge pull request !35705 from zichun_ye/power_transform_fix
This commit is contained in:
i-robot 2022-06-10 07:25:05 +00:00 committed by Gitee
commit e086ba7cf7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 5 additions and 5 deletions

View File

@ -133,11 +133,11 @@ class PowerTransform(Bijector):
self.shape(x + power_local), 1.)
power_local = power_local * ones
x = x * ones
safe_power = self.select_base(self.equal_base(power_local, 0.),
safe_power = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
ones,
power_local)
forward_v = self.select_base(self.equal_base(power_local, 0.),
forward_v = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
self.exp(x),
self.exp(self.log1p(x * safe_power) / safe_power))
return forward_v
@ -154,11 +154,11 @@ class PowerTransform(Bijector):
self.shape(y + power_local), 1.)
power_local = power_local * ones
y = y * ones
safe_power = self.select_base(self.equal_base(power_local, 0.),
safe_power = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
ones,
power_local)
inverse_v = self.select_base(self.equal_base(power_local, 0.),
inverse_v = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
self.log(y),
self.expm1(self.log(y) * safe_power) / safe_power)
@ -185,7 +185,7 @@ class PowerTransform(Bijector):
power_local = power_local * ones
x = x * ones
forward_log_j = self.select_base(self.equal_base(power_local, 0.),
forward_log_j = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
x,
(1. / power_local - 1) * self.log1p(x * power_local))