From b8fd0c196c2187d4c31b0f14c91a77002b1353e4 Mon Sep 17 00:00:00 2001 From: Zichun Ye Date: Sun, 21 Feb 2021 17:40:30 -0500 Subject: [PATCH] update bernoulli dist: clamp prob for log_prob/prob fix doc --- .../nn/probability/distribution/bernoulli.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index e6285257bbb..9e67f9e1d95 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -18,7 +18,7 @@ from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore._checkparam import Validator from .distribution import Distribution -from ._utils.utils import check_prob, check_distribution_name +from ._utils.utils import check_prob, check_distribution_name, clamp_probs from ._utils.custom_ops import exp_generic, log_generic @@ -86,7 +86,6 @@ class Bernoulli(Distribution): >>> ans = b2.mean(probs_a) >>> print(ans.shape) (1,) - >>> print(ans.shape) >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows: >>> # Args: >>> # dist (str): the name of the distribution. Only 'Bernoulli' is supported. 
@@ -132,7 +131,8 @@ class Bernoulli(Distribution): param = dict(locals()) param['param_dict'] = {'probs': probs} valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type - Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) + Validator.check_type_name( + "dtype", dtype, valid_dtype, type(self).__name__) super(Bernoulli, self).__init__(seed, dtype, name, param) self._probs = self._add_parameter(probs, 'probs') @@ -241,6 +241,9 @@ class Bernoulli(Distribution): value = self._check_value(value, 'value') value = self.cast(value, self.parameter_type) probs1 = self._check_param_type(probs1) + + # clamp value for numerical stability + probs1 = clamp_probs(probs1) probs0 = 1.0 - probs1 return self.log(probs1) * value + self.log(probs0) * (1.0 - value) @@ -266,8 +269,10 @@ class Bernoulli(Distribution): probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor) comp_zero = self.less(value, 0.0) comp_one = self.less(value, 1.0) - zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0) - ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0) + zeros = self.fill(self.parameter_type, self.shape( + broadcast_shape_tensor), 0.0) + ones = self.fill(self.parameter_type, self.shape( + broadcast_shape_tensor), 1.0) less_than_zero = self.select(comp_zero, zeros, probs0) return self.select(comp_one, less_than_zero, ones)