Change the interfaces in trasformation base class

This commit is contained in:
peixu_ren 2020-08-20 16:07:24 -04:00
parent 5a0fe979ab
commit 4aa339cb5a
2 changed files with 32 additions and 15 deletions

View File

@ -272,6 +272,10 @@ def check_type(data_type, value_type, name):
def raise_none_error(name): def raise_none_error(name):
raise ValueError(f"{name} should be specified. Value cannot be None") raise ValueError(f"{name} should be specified. Value cannot be None")
@constexpr
def raise_not_impl_error(name):
raise ValueError(f"{name} function should be implemented for non-linear transformation")
@constexpr @constexpr
def check_distribution_name(name, expected_name): def check_distribution_name(name, expected_name):
if name != expected_name: if name != expected_name:

View File

@ -18,7 +18,7 @@ from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_type from ._utils.utils import check_type, raise_not_impl_error
class TransformedDistribution(Distribution): class TransformedDistribution(Distribution):
""" """
@ -56,6 +56,7 @@ class TransformedDistribution(Distribution):
self._distribution = distribution self._distribution = distribution
self._is_linear_transformation = bijector.is_constant_jacobian self._is_linear_transformation = bijector.is_constant_jacobian
self.exp = P.Exp() self.exp = P.Exp()
self.log = P.Log()
@property @property
def bijector(self): def bijector(self):
@ -69,37 +70,49 @@ class TransformedDistribution(Distribution):
def is_linear_transformation(self): def is_linear_transformation(self):
return self._is_linear_transformation return self._is_linear_transformation
def _cdf(self, value): def _cdf(self, *args, **kwargs):
r""" r"""
.. math:: .. math::
Y = g(X) Y = g(X)
P(Y <= a) = P(X <= g^{-1}(a)) P(Y <= a) = P(X <= g^{-1}(a))
""" """
inverse_value = self.bijector.inverse(value) inverse_value = self.bijector("inverse", *args, **kwargs)
return self.distribution.cdf(inverse_value) return self.distribution("cdf", inverse_value)
def _log_prob(self, value): def _log_cdf(self, *args, **kwargs):
return self.log(self._cdf(*args, **kwargs))
def _survival_function(self, *args, **kwargs):
return 1.0 - self._cdf(*args, **kwargs)
def _log_survival(self, *args, **kwargs):
return self.log(self._survival_function(*args, **kwargs))
def _log_prob(self, *args, **kwargs):
r""" r"""
.. math:: .. math::
Y = g(X) Y = g(X)
Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a) Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a)
\log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a)) \log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a))
""" """
inverse_value = self.bijector.inverse(value) inverse_value = self.bijector("inverse", *args, **kwargs)
unadjust_prob = self.distribution.log_prob(inverse_value) unadjust_prob = self.distribution("log_prob", inverse_value)
log_jacobian = self.bijector.inverse_log_jacobian(value) log_jacobian = self.bijector("inverse_log_jacobian", *args, **kwargs)
return unadjust_prob + log_jacobian return unadjust_prob + log_jacobian
def _prob(self, value): def _prob(self, *args, **kwargs):
return self.exp(self._log_prob(value)) return self.exp(self._log_prob(*args, **kwargs))
def _sample(self, shape): def _sample(self, *args, **kwargs):
org_sample = self.distribution.sample(shape) org_sample = self.distribution("sample", shape)
return self.bijector.forward(org_sample) return self.bijector("forward", org_sample)
def _mean(self): def _mean(self, *args, **kwargs):
""" """
Note: Note:
This function maybe overridden by derived class. This function maybe overridden by derived class.
""" """
return self.bijector.forward(self.distribution.mean()) if not self.is_linear_transformation:
raise_not_impl_error(mean)
return self.bijector("forward", self.distribution("mean"))