forked from mindspore-Ecosystem/mindspore
modified dtype cast logic in custom_log and custom_exp and fixed dtype issues in softplus bijector
This commit is contained in:
parent
1256737a7c
commit
239826515a
|
@ -15,7 +15,6 @@
|
|||
"""Softplus Bijector"""
|
||||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.layer.activation import LogSigmoid
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ..distribution._utils.utils import cast_to_tensor
|
||||
|
@ -71,6 +70,7 @@ class Softplus(Bijector):
|
|||
self.log = log_generic
|
||||
self.expm1 = expm1_generic
|
||||
self.abs = P.Abs()
|
||||
self.dtypeop = P.DType()
|
||||
self.fill = P.Fill()
|
||||
self.greater = P.Greater()
|
||||
self.less = P.Less()
|
||||
|
@ -90,7 +90,7 @@ class Softplus(Bijector):
|
|||
too_large = self.greater(x, -self.threshold)
|
||||
too_small_value = self.exp(x)
|
||||
too_large_value = x
|
||||
ones = self.fill(mstype.float32, self.shape(x), 1.0)
|
||||
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
|
||||
too_small_or_too_large = self.logicalor(too_small, too_large)
|
||||
x = self.select(too_small_or_too_large, ones, x)
|
||||
y = self.log(self.exp(x) + 1.0)
|
||||
|
@ -106,7 +106,7 @@ class Softplus(Bijector):
|
|||
too_large = self.greater(x, -self.threshold)
|
||||
too_small_value = self.log(x)
|
||||
too_large_value = x
|
||||
ones = self.fill(mstype.float32, self.shape(x), 1.0)
|
||||
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
|
||||
too_small_or_too_large = self.logicalor(too_small, too_large)
|
||||
x = self.select(too_small_or_too_large, ones, x)
|
||||
y = x + self.log(self.abs(self.expm1(-x)))
|
||||
|
|
|
@ -24,8 +24,11 @@ def exp_generic(input_x):
|
|||
"""
|
||||
exp = P.Exp()
|
||||
cast = P.Cast()
|
||||
dtype = P.DType()
|
||||
checktype = P.IsSubClass()
|
||||
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
if not checktype(dtype(input_x), mstype.float_):
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
return exp(input_x)
|
||||
|
||||
|
||||
|
@ -51,8 +54,10 @@ def log_generic(input_x):
|
|||
dtype = P.DType()
|
||||
shape = P.Shape()
|
||||
select = P.Select()
|
||||
checktype = P.IsSubClass()
|
||||
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
if not checktype(dtype(input_x), mstype.float_):
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
nan = fill(dtype(input_x), shape(input_x), np.nan)
|
||||
inf = fill(dtype(input_x), shape(input_x), np.inf)
|
||||
neg_x = less(input_x, 0.0)
|
||||
|
|
|
@ -222,7 +222,7 @@ class Bernoulli(Distribution):
|
|||
pmf(k) = probs0 if k = 0;
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, mstype.float32)
|
||||
value = self.cast(value, self.parameter_type)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
probs0 = 1.0 - probs1
|
||||
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
|
||||
|
@ -241,7 +241,7 @@ class Bernoulli(Distribution):
|
|||
cdf(k) = 1 if k >=1;
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, mstype.float32)
|
||||
value = self.cast(value, self.parameter_type)
|
||||
value = self.floor(value)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
prob_type = self.dtypeop(probs1)
|
||||
|
|
|
@ -225,7 +225,7 @@ class Geometric(Distribution):
|
|||
pmf(k) = 0 if k < 0.
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, mstype.float32)
|
||||
value = self.cast(value, self.parameter_type)
|
||||
value = self.floor(value)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
|
||||
|
@ -247,7 +247,7 @@ class Geometric(Distribution):
|
|||
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, mstype.float32)
|
||||
value = self.cast(value, self.parameter_type)
|
||||
value = self.floor(value)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
probs0 = 1.0 - probs1
|
||||
|
|
Loading…
Reference in New Issue