forked from mindspore-Ecosystem/mindspore
modified dtype cast logic in custom_log and custom_exp and fixed dtype issues in softplus bijector
This commit is contained in:
parent
1256737a7c
commit
239826515a
|
@ -15,7 +15,6 @@
|
||||||
"""Softplus Bijector"""
|
"""Softplus Bijector"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.common import dtype as mstype
|
|
||||||
from mindspore.nn.layer.activation import LogSigmoid
|
from mindspore.nn.layer.activation import LogSigmoid
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from ..distribution._utils.utils import cast_to_tensor
|
from ..distribution._utils.utils import cast_to_tensor
|
||||||
|
@ -71,6 +70,7 @@ class Softplus(Bijector):
|
||||||
self.log = log_generic
|
self.log = log_generic
|
||||||
self.expm1 = expm1_generic
|
self.expm1 = expm1_generic
|
||||||
self.abs = P.Abs()
|
self.abs = P.Abs()
|
||||||
|
self.dtypeop = P.DType()
|
||||||
self.fill = P.Fill()
|
self.fill = P.Fill()
|
||||||
self.greater = P.Greater()
|
self.greater = P.Greater()
|
||||||
self.less = P.Less()
|
self.less = P.Less()
|
||||||
|
@ -90,7 +90,7 @@ class Softplus(Bijector):
|
||||||
too_large = self.greater(x, -self.threshold)
|
too_large = self.greater(x, -self.threshold)
|
||||||
too_small_value = self.exp(x)
|
too_small_value = self.exp(x)
|
||||||
too_large_value = x
|
too_large_value = x
|
||||||
ones = self.fill(mstype.float32, self.shape(x), 1.0)
|
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
|
||||||
too_small_or_too_large = self.logicalor(too_small, too_large)
|
too_small_or_too_large = self.logicalor(too_small, too_large)
|
||||||
x = self.select(too_small_or_too_large, ones, x)
|
x = self.select(too_small_or_too_large, ones, x)
|
||||||
y = self.log(self.exp(x) + 1.0)
|
y = self.log(self.exp(x) + 1.0)
|
||||||
|
@ -106,7 +106,7 @@ class Softplus(Bijector):
|
||||||
too_large = self.greater(x, -self.threshold)
|
too_large = self.greater(x, -self.threshold)
|
||||||
too_small_value = self.log(x)
|
too_small_value = self.log(x)
|
||||||
too_large_value = x
|
too_large_value = x
|
||||||
ones = self.fill(mstype.float32, self.shape(x), 1.0)
|
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
|
||||||
too_small_or_too_large = self.logicalor(too_small, too_large)
|
too_small_or_too_large = self.logicalor(too_small, too_large)
|
||||||
x = self.select(too_small_or_too_large, ones, x)
|
x = self.select(too_small_or_too_large, ones, x)
|
||||||
y = x + self.log(self.abs(self.expm1(-x)))
|
y = x + self.log(self.abs(self.expm1(-x)))
|
||||||
|
|
|
@ -24,8 +24,11 @@ def exp_generic(input_x):
|
||||||
"""
|
"""
|
||||||
exp = P.Exp()
|
exp = P.Exp()
|
||||||
cast = P.Cast()
|
cast = P.Cast()
|
||||||
|
dtype = P.DType()
|
||||||
|
checktype = P.IsSubClass()
|
||||||
|
|
||||||
input_x = cast(input_x, mstype.float32)
|
if not checktype(dtype(input_x), mstype.float_):
|
||||||
|
input_x = cast(input_x, mstype.float32)
|
||||||
return exp(input_x)
|
return exp(input_x)
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,8 +54,10 @@ def log_generic(input_x):
|
||||||
dtype = P.DType()
|
dtype = P.DType()
|
||||||
shape = P.Shape()
|
shape = P.Shape()
|
||||||
select = P.Select()
|
select = P.Select()
|
||||||
|
checktype = P.IsSubClass()
|
||||||
|
|
||||||
input_x = cast(input_x, mstype.float32)
|
if not checktype(dtype(input_x), mstype.float_):
|
||||||
|
input_x = cast(input_x, mstype.float32)
|
||||||
nan = fill(dtype(input_x), shape(input_x), np.nan)
|
nan = fill(dtype(input_x), shape(input_x), np.nan)
|
||||||
inf = fill(dtype(input_x), shape(input_x), np.inf)
|
inf = fill(dtype(input_x), shape(input_x), np.inf)
|
||||||
neg_x = less(input_x, 0.0)
|
neg_x = less(input_x, 0.0)
|
||||||
|
|
|
@ -222,7 +222,7 @@ class Bernoulli(Distribution):
|
||||||
pmf(k) = probs0 if k = 0;
|
pmf(k) = probs0 if k = 0;
|
||||||
"""
|
"""
|
||||||
value = self._check_value(value, 'value')
|
value = self._check_value(value, 'value')
|
||||||
value = self.cast(value, mstype.float32)
|
value = self.cast(value, self.parameter_type)
|
||||||
probs1 = self._check_param_type(probs1)
|
probs1 = self._check_param_type(probs1)
|
||||||
probs0 = 1.0 - probs1
|
probs0 = 1.0 - probs1
|
||||||
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
|
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
|
||||||
|
@ -241,7 +241,7 @@ class Bernoulli(Distribution):
|
||||||
cdf(k) = 1 if k >=1;
|
cdf(k) = 1 if k >=1;
|
||||||
"""
|
"""
|
||||||
value = self._check_value(value, 'value')
|
value = self._check_value(value, 'value')
|
||||||
value = self.cast(value, mstype.float32)
|
value = self.cast(value, self.parameter_type)
|
||||||
value = self.floor(value)
|
value = self.floor(value)
|
||||||
probs1 = self._check_param_type(probs1)
|
probs1 = self._check_param_type(probs1)
|
||||||
prob_type = self.dtypeop(probs1)
|
prob_type = self.dtypeop(probs1)
|
||||||
|
|
|
@ -225,7 +225,7 @@ class Geometric(Distribution):
|
||||||
pmf(k) = 0 if k < 0.
|
pmf(k) = 0 if k < 0.
|
||||||
"""
|
"""
|
||||||
value = self._check_value(value, 'value')
|
value = self._check_value(value, 'value')
|
||||||
value = self.cast(value, mstype.float32)
|
value = self.cast(value, self.parameter_type)
|
||||||
value = self.floor(value)
|
value = self.floor(value)
|
||||||
probs1 = self._check_param_type(probs1)
|
probs1 = self._check_param_type(probs1)
|
||||||
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
|
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
|
||||||
|
@ -247,7 +247,7 @@ class Geometric(Distribution):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
value = self._check_value(value, 'value')
|
value = self._check_value(value, 'value')
|
||||||
value = self.cast(value, mstype.float32)
|
value = self.cast(value, self.parameter_type)
|
||||||
value = self.floor(value)
|
value = self.floor(value)
|
||||||
probs1 = self._check_param_type(probs1)
|
probs1 = self._check_param_type(probs1)
|
||||||
probs0 = 1.0 - probs1
|
probs0 = 1.0 - probs1
|
||||||
|
|
Loading…
Reference in New Issue