forked from mindspore-Ecosystem/mindspore
fixed prob, survival function of exponential distribution
This commit is contained in:
parent
39e2791149
commit
9083e9dcd6
|
@ -198,9 +198,9 @@ class Exponential(Distribution):
|
|||
return self._entropy(rate) + self._kl_loss(dist, rate_b, rate)
|
||||
|
||||
|
||||
def _prob(self, value, rate=None):
|
||||
def _log_prob(self, value, rate=None):
|
||||
r"""
|
||||
pdf of Exponential distribution.
|
||||
log_pdf of Exponential distribution.
|
||||
|
||||
Args:
|
||||
Args:
|
||||
|
@ -211,15 +211,16 @@ class Exponential(Distribution):
|
|||
Value should be greater or equal to zero.
|
||||
|
||||
.. math::
|
||||
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
|
||||
log_pdf(x) = \log(rate) - rate * x if x >= 0 else 0
|
||||
"""
|
||||
value = self._check_value(value, "value")
|
||||
value = self.cast(value, self.dtype)
|
||||
rate = self._check_param(rate)
|
||||
prob = self.exp(self.log(rate) - rate * value)
|
||||
prob = self.log(rate) - rate * value
|
||||
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
|
||||
neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, prob)
|
||||
return self.select(comp, neginf, prob)
|
||||
|
||||
def _cdf(self, value, rate=None):
|
||||
r"""
|
||||
|
@ -243,6 +244,27 @@ class Exponential(Distribution):
|
|||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, cdf)
|
||||
|
||||
def _log_survival(self, value, rate=None):
|
||||
r"""
|
||||
log survival_function of Exponential distribution.
|
||||
|
||||
Args:
|
||||
value (Tensor): value to be evaluated.
|
||||
rate (Tensor): rate of the distribution. Default: self.rate.
|
||||
|
||||
Note:
|
||||
Value should be greater or equal to zero.
|
||||
|
||||
.. math::
|
||||
log_survival_function(x) = -1 * \lambda * x if x >= 0 else 0
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
rate = self._check_param(rate)
|
||||
sf = -1. * rate * value
|
||||
zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, sf)
|
||||
|
||||
def _kl_loss(self, dist, rate_b, rate=None):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue