diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 58d1c7cd014..e4fd79154ec 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -272,6 +272,10 @@ def check_type(data_type, value_type, name): def raise_none_error(name): raise ValueError(f"{name} should be specified. Value cannot be None") +@constexpr +def raise_not_impl_error(name): + raise ValueError(f"{name} function should be implemented for non-linear transformation") + @constexpr def check_distribution_name(name, expected_name): if name != expected_name: diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py index baed5f10d13..259f105d4e8 100644 --- a/mindspore/nn/probability/distribution/transformed_distribution.py +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -18,7 +18,7 @@ from mindspore._checkparam import Validator as validator from mindspore.common import dtype as mstype import mindspore.nn as nn from .distribution import Distribution -from ._utils.utils import check_type +from ._utils.utils import check_type, raise_not_impl_error class TransformedDistribution(Distribution): """ @@ -56,6 +56,7 @@ class TransformedDistribution(Distribution): self._distribution = distribution self._is_linear_transformation = bijector.is_constant_jacobian self.exp = P.Exp() + self.log = P.Log() @property def bijector(self): @@ -69,37 +70,49 @@ class TransformedDistribution(Distribution): def is_linear_transformation(self): return self._is_linear_transformation - def _cdf(self, value): + def _cdf(self, *args, **kwargs): r""" .. math:: Y = g(X) P(Y <= a) = P(X <= g^{-1}(a)) """ - inverse_value = self.bijector.inverse(value) - return self.distribution.cdf(inverse_value) + inverse_value = self.bijector("inverse", *args, **kwargs) + return self.distribution("cdf", inverse_value) - def _log_prob(self, value): + def _log_cdf(self, *args, **kwargs): + return self.log(self._cdf(*args, **kwargs)) + + def _survival_function(self, *args, **kwargs): + return 1.0 - self._cdf(*args, **kwargs) + + def _log_survival(self, *args, **kwargs): + return self.log(self._survival_function(*args, **kwargs)) + + def _log_prob(self, *args, **kwargs): r""" .. math:: Y = g(X) Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a) \log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a)) """ - inverse_value = self.bijector.inverse(value) - unadjust_prob = self.distribution.log_prob(inverse_value) - log_jacobian = self.bijector.inverse_log_jacobian(value) + inverse_value = self.bijector("inverse", *args, **kwargs) + unadjust_prob = self.distribution("log_prob", inverse_value) + log_jacobian = self.bijector("inverse_log_jacobian", *args, **kwargs) return unadjust_prob + log_jacobian - def _prob(self, value): - return self.exp(self._log_prob(value)) + def _prob(self, *args, **kwargs): + return self.exp(self._log_prob(*args, **kwargs)) - def _sample(self, shape): - org_sample = self.distribution.sample(shape) - return self.bijector.forward(org_sample) + def _sample(self, *args, **kwargs): + org_sample = self.distribution("sample", shape) + return self.bijector("forward", org_sample) - def _mean(self): + def _mean(self, *args, **kwargs): """ Note: This function maybe overridden by derived class. """ - return self.bijector.forward(self.distribution.mean()) + if not self.is_linear_transformation: + raise_not_impl_error(mean) + + return self.bijector("forward", self.distribution("mean"))