!7051 Fixed zero plus neg_inf issue under fp16

Merge pull request !7051 from XunDeng/pp_issue_branch
This commit is contained in:
mindspore-ci-bot 2020-09-30 08:37:05 +08:00 committed by Gitee
commit a75b3161e1
1 changed files with 6 additions and 1 deletions

View File

@ -13,7 +13,9 @@
# limitations under the License.
# ============================================================================
"""Transformed Distribution"""
import numpy as np
from mindspore._checkparam import Validator as validator
from mindspore.ops import operations as P
import mindspore.nn as nn
from .distribution import Distribution
from ._utils.utils import raise_not_impl_error
@ -80,6 +82,8 @@ class TransformedDistribution(Distribution):
self.parameter_names = distribution.parameter_names
self.exp = exp_generic
self.log = log_generic
self.equal_base = P.Equal()
self.select_base = P.Select()
@property
def bijector(self):
@ -125,7 +129,8 @@ class TransformedDistribution(Distribution):
inverse_value = self.bijector("inverse", value)
unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs)
log_jacobian = self.bijector("inverse_log_jacobian", value)
return unadjust_prob + log_jacobian
isneginf = self.equal_base(unadjust_prob, -np.inf)
return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian)
def _prob(self, value, *args, **kwargs):
return self.exp(self._log_prob(value, *args, **kwargs))