forked from mindspore-Ecosystem/mindspore
!7051 Fixed zero plus neg_inf issue under fp16
Merge pull request !7051 from XunDeng/pp_issue_branch
This commit is contained in:
commit
a75b3161e1
|
@ -13,7 +13,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Transformed Distribution"""
|
||||
import numpy as np
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import raise_not_impl_error
|
||||
|
@ -80,6 +82,8 @@ class TransformedDistribution(Distribution):
|
|||
self.parameter_names = distribution.parameter_names
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
self.equal_base = P.Equal()
|
||||
self.select_base = P.Select()
|
||||
|
||||
@property
|
||||
def bijector(self):
|
||||
|
@ -125,7 +129,8 @@ class TransformedDistribution(Distribution):
|
|||
inverse_value = self.bijector("inverse", value)
|
||||
unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs)
|
||||
log_jacobian = self.bijector("inverse_log_jacobian", value)
|
||||
return unadjust_prob + log_jacobian
|
||||
isneginf = self.equal_base(unadjust_prob, -np.inf)
|
||||
return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian)
|
||||
|
||||
def _prob(self, value, *args, **kwargs):
|
||||
return self.exp(self._log_prob(value, *args, **kwargs))
|
||||
|
|
Loading…
Reference in New Issue