update bernoull dist: clamp prob for log_prob/prob
fix doc
This commit is contained in:
parent
f9f24ca94d
commit
b8fd0c196c
|
@ -18,7 +18,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import Validator
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_prob, check_distribution_name
|
||||
from ._utils.utils import check_prob, check_distribution_name, clamp_probs
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
|
||||
|
@ -86,7 +86,6 @@ class Bernoulli(Distribution):
|
|||
>>> ans = b2.mean(probs_a)
|
||||
>>> print(ans.shape)
|
||||
(1,)
|
||||
>>> print(ans.shape)
|
||||
>>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
|
||||
>>> # Args:
|
||||
>>> # dist (str): the name of the distribution. Only 'Bernoulli' is supported.
|
||||
|
@ -132,7 +131,8 @@ class Bernoulli(Distribution):
|
|||
param = dict(locals())
|
||||
param['param_dict'] = {'probs': probs}
|
||||
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
Validator.check_type_name(
|
||||
"dtype", dtype, valid_dtype, type(self).__name__)
|
||||
super(Bernoulli, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._probs = self._add_parameter(probs, 'probs')
|
||||
|
@ -241,6 +241,9 @@ class Bernoulli(Distribution):
|
|||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.parameter_type)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
|
||||
# clamp value for numerical stability
|
||||
probs1 = clamp_probs(probs1)
|
||||
probs0 = 1.0 - probs1
|
||||
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
|
||||
|
||||
|
@ -266,8 +269,10 @@ class Bernoulli(Distribution):
|
|||
probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
|
||||
comp_zero = self.less(value, 0.0)
|
||||
comp_one = self.less(value, 1.0)
|
||||
zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0)
|
||||
ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0)
|
||||
zeros = self.fill(self.parameter_type, self.shape(
|
||||
broadcast_shape_tensor), 0.0)
|
||||
ones = self.fill(self.parameter_type, self.shape(
|
||||
broadcast_shape_tensor), 1.0)
|
||||
less_than_zero = self.select(comp_zero, zeros, probs0)
|
||||
return self.select(comp_one, less_than_zero, ones)
|
||||
|
||||
|
|
Loading…
Reference in New Issue